package org.apache.flink.runtime.scheduler.adaptivebatch;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.flink.configuration.BatchExecutionOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
import org.apache.flink.runtime.executiongraph.ParallelismAndInputInfos;
import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
import org.apache.flink.runtime.io.network.partition.hybrid.tiered.storage.HashBufferAccumulatorTest;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.shaded.guava31.com.google.common.collect.Iterables;
import org.apache.flink.util.Preconditions;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.class */
/** Tests for {@link DefaultVertexParallelismAndInputInfosDecider}. */
class DefaultVertexParallelismAndInputInfosDeciderTest {

    private static final long BYTE_256_MB = 256 * 1024 * 1024L;
    private static final long BYTE_512_MB = 512 * 1024 * 1024L;
    private static final long BYTE_1_GB = 1024 * 1024 * 1024L;
    private static final long BYTE_8_GB = 8 * BYTE_1_GB;
    private static final long BYTE_1_TB = 1024 * BYTE_1_GB;

    private static final int MAX_PARALLELISM = 100;
    private static final int MIN_PARALLELISM = 3;
    private static final int DEFAULT_SOURCE_PARALLELISM = 10;
    private static final long DATA_VOLUME_PER_TASK = BYTE_1_GB;

    @Test
    void testDecideParallelism() {
        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_256_MB);
        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_256_MB + BYTE_8_GB);

        int parallelism =
                createDeciderAndDecideParallelism(Arrays.asList(resultInfo1, resultInfo2));

        Assertions.assertThat(parallelism).isEqualTo(9);
    }

    @Test
    void testInitiallyNormalizedParallelismIsLargerThanMaxParallelism() {
        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_256_MB);
        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_8_GB + BYTE_1_TB);

        int parallelism =
                createDeciderAndDecideParallelism(Arrays.asList(resultInfo1, resultInfo2));

        Assertions.assertThat(parallelism).isEqualTo(MAX_PARALLELISM);
    }

    @Test
    void testInitiallyNormalizedParallelismIsSmallerThanMinParallelism() {
        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_256_MB);
        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_512_MB);

        int parallelism =
                createDeciderAndDecideParallelism(Arrays.asList(resultInfo1, resultInfo2));

        Assertions.assertThat(parallelism).isEqualTo(MIN_PARALLELISM);
    }

    @Test
    void testNonBroadcastBytesCanNotDividedEvenly() {
        BlockingResultInfo resultInfo1 = createFromBroadcastResult(BYTE_512_MB);
        BlockingResultInfo resultInfo2 = createFromNonBroadcastResult(BYTE_256_MB + BYTE_8_GB);

        int parallelism =
                createDeciderAndDecideParallelism(Arrays.asList(resultInfo1, resultInfo2));

        Assertions.assertThat(parallelism).isEqualTo(9);
    }

    @Test
    void testDecideParallelismWithMaxSubpartitionLimitation() {
        BlockingResultInfo resultInfo1 = new TestingBlockingResultInfo(false, 1L, 1024, 1024);
        BlockingResultInfo resultInfo2 = new TestingBlockingResultInfo(false, 1L, 512, 512);

        int parallelism =
                createDeciderAndDecideParallelism(
                        1, MAX_PARALLELISM, BYTE_256_MB, Arrays.asList(resultInfo1, resultInfo2));

        Assertions.assertThat(parallelism).isEqualTo(32);
    }

    @Test
    void testAllEdgesAllToAll() {
        BlockingResultInfo resultInfo1 =
                createAllToAllBlockingResultInfo(new long[] {10, 15, 13, 12, 1, 10, 8, 20, 12, 17});
        BlockingResultInfo resultInfo2 =
                createAllToAllBlockingResultInfo(new long[] {8, 12, 21, 9, 13, 7, 19, 13, 14, 5});

        ParallelismAndInputInfos parallelismAndInputInfos =
                createDeciderAndDecideParallelismAndInputInfos(
                        1, 10, 60L, Arrays.asList(resultInfo1, resultInfo2));

        Assertions.assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(5);
        Assertions.assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(2);

        List<IndexRange> subpartitionRanges =
                Arrays.asList(
                        new IndexRange(0, 1),
                        new IndexRange(2, 3),
                        new IndexRange(4, 6),
                        new IndexRange(7, 8),
                        new IndexRange(9, 9));
        checkAllToAllJobVertexInputInfo(
                parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo1.getResultId()),
                subpartitionRanges);
        checkAllToAllJobVertexInputInfo(
                parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo2.getResultId()),
                subpartitionRanges);
    }

    @Test
    void testAllEdgesAllToAllAndDecidedParallelismIsMaxParallelism() {
        BlockingResultInfo resultInfo =
                createAllToAllBlockingResultInfo(new long[] {10, 15, 13, 12, 1, 10, 8, 20, 12, 17});

        ParallelismAndInputInfos parallelismAndInputInfos =
                createDeciderAndDecideParallelismAndInputInfos(
                        1, 2, 10L, Collections.singletonList(resultInfo));

        Assertions.assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(2);
        Assertions.assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(1);
        checkAllToAllJobVertexInputInfo(
                Iterables.getOnlyElement(
                        parallelismAndInputInfos.getJobVertexInputInfos().values()),
                Arrays.asList(new IndexRange(0, 5), new IndexRange(6, 9)));
    }

    @Test
    void testAllEdgesAllToAllAndDecidedParallelismIsMinParallelism() {
        BlockingResultInfo resultInfo =
                createAllToAllBlockingResultInfo(new long[] {10, 15, 13, 12, 1, 10, 8, 20, 12, 17});

        ParallelismAndInputInfos parallelismAndInputInfos =
                createDeciderAndDecideParallelismAndInputInfos(
                        4, 10, 1000L, Collections.singletonList(resultInfo));

        Assertions.assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(4);
        Assertions.assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(1);
        checkAllToAllJobVertexInputInfo(
                Iterables.getOnlyElement(
                        parallelismAndInputInfos.getJobVertexInputInfos().values()),
                Arrays.asList(
                        new IndexRange(0, 1),
                        new IndexRange(2, 5),
                        new IndexRange(6, 7),
                        new IndexRange(8, 9)));
    }

    @Test
    void testFallBackToEvenlyDistributeSubpartitions() {
        BlockingResultInfo resultInfo =
                createAllToAllBlockingResultInfo(new long[] {10, 1, 10, 1, 10, 1, 10, 1, 10, 1});

        ParallelismAndInputInfos parallelismAndInputInfos =
                createDeciderAndDecideParallelismAndInputInfos(
                        8, 8, 10L, Collections.singletonList(resultInfo));

        Assertions.assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(8);
        Assertions.assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(1);
        checkAllToAllJobVertexInputInfo(
                Iterables.getOnlyElement(
                        parallelismAndInputInfos.getJobVertexInputInfos().values()),
                Arrays.asList(
                        new IndexRange(0, 0),
                        new IndexRange(1, 1),
                        new IndexRange(2, 2),
                        new IndexRange(3, 4),
                        new IndexRange(5, 5),
                        new IndexRange(6, 6),
                        new IndexRange(7, 7),
                        new IndexRange(8, 9)));
    }

    @Test
    void testAllEdgesAllToAllAndOneIsBroadcast() {
        BlockingResultInfo resultInfo1 =
                createAllToAllBlockingResultInfo(new long[] {10, 15, 13, 12, 1, 10, 8, 20, 12, 17});
        BlockingResultInfo resultInfo2 = createAllToAllBlockingResultInfo(new long[] {10}, true);

        ParallelismAndInputInfos parallelismAndInputInfos =
                createDeciderAndDecideParallelismAndInputInfos(
                        1, 10, 60L, Arrays.asList(resultInfo1, resultInfo2));

        Assertions.assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(3);
        Assertions.assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(2);
        checkAllToAllJobVertexInputInfo(
                parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo1.getResultId()),
                Arrays.asList(new IndexRange(0, 4), new IndexRange(5, 8), new IndexRange(9, 9)));
        // Every subtask is expected to read the single broadcast subpartition.
        checkAllToAllJobVertexInputInfo(
                parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo2.getResultId()),
                Arrays.asList(new IndexRange(0, 0), new IndexRange(0, 0), new IndexRange(0, 0)));
    }

    @Test
    void testAllEdgesBroadcast() {
        BlockingResultInfo resultInfo1 = createAllToAllBlockingResultInfo(new long[] {10}, true);
        BlockingResultInfo resultInfo2 = createAllToAllBlockingResultInfo(new long[] {10}, true);

        ParallelismAndInputInfos parallelismAndInputInfos =
                createDeciderAndDecideParallelismAndInputInfos(
                        1, 10, 60L, Arrays.asList(resultInfo1, resultInfo2));

        Assertions.assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(1);
        Assertions.assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(2);
        checkAllToAllJobVertexInputInfo(
                parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo1.getResultId()),
                Collections.singletonList(new IndexRange(0, 0)));
        checkAllToAllJobVertexInputInfo(
                parallelismAndInputInfos.getJobVertexInputInfos().get(resultInfo2.getResultId()),
                Collections.singletonList(new IndexRange(0, 0)));
    }

    @Test
    void testHavePointwiseEdges() {
        BlockingResultInfo allToAllResultInfo =
                createAllToAllBlockingResultInfo(new long[] {10, 15, 13, 12, 1, 10, 8, 20, 12, 17});
        // Two partitions with five subpartitions each, passed via the varargs parameter.
        BlockingResultInfo pointwiseResultInfo =
                createPointwiseBlockingResultInfo(
                        new long[] {8, 12, 21, 9, 13}, new long[] {7, 19, 13, 14, 5});

        ParallelismAndInputInfos parallelismAndInputInfos =
                createDeciderAndDecideParallelismAndInputInfos(
                        1, 10, 60L, Arrays.asList(allToAllResultInfo, pointwiseResultInfo));

        Assertions.assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(4);
        Assertions.assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(2);
        checkAllToAllJobVertexInputInfo(
                parallelismAndInputInfos
                        .getJobVertexInputInfos()
                        .get(allToAllResultInfo.getResultId()),
                Arrays.asList(
                        new IndexRange(0, 1),
                        new IndexRange(2, 4),
                        new IndexRange(5, 6),
                        new IndexRange(7, 9)));
        checkPointwiseJobVertexInputInfo(
                parallelismAndInputInfos
                        .getJobVertexInputInfos()
                        .get(pointwiseResultInfo.getResultId()),
                Arrays.asList(
                        new IndexRange(0, 0),
                        new IndexRange(0, 0),
                        new IndexRange(1, 1),
                        new IndexRange(1, 1)),
                Arrays.asList(
                        new IndexRange(0, 1),
                        new IndexRange(2, 4),
                        new IndexRange(0, 1),
                        new IndexRange(2, 4)));
    }

    @Test
    void testParallelismAlreadyDecided() {
        BlockingResultInfo resultInfo =
                createAllToAllBlockingResultInfo(new long[] {10, 15, 13, 12, 1, 10, 8, 20, 12, 17});
        DefaultVertexParallelismAndInputInfosDecider decider =
                createDecider(MIN_PARALLELISM, MAX_PARALLELISM, DATA_VOLUME_PER_TASK);

        // The vertex's initial parallelism (3) is passed in explicitly here.
        ParallelismAndInputInfos parallelismAndInputInfos =
                decider.decideParallelismAndInputInfosForVertex(
                        new JobVertexID(),
                        Collections.singletonList(resultInfo),
                        3,
                        MAX_PARALLELISM);

        Assertions.assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(3);
        Assertions.assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).hasSize(1);
        checkAllToAllJobVertexInputInfo(
                Iterables.getOnlyElement(
                        parallelismAndInputInfos.getJobVertexInputInfos().values()),
                Arrays.asList(new IndexRange(0, 2), new IndexRange(3, 5), new IndexRange(6, 9)));
    }

    @Test
    void testSourceJobVertex() {
        ParallelismAndInputInfos parallelismAndInputInfos =
                createDeciderAndDecideParallelismAndInputInfos(
                        MIN_PARALLELISM,
                        MAX_PARALLELISM,
                        DATA_VOLUME_PER_TASK,
                        Collections.emptyList());

        Assertions.assertThat(parallelismAndInputInfos.getParallelism())
                .isEqualTo(DEFAULT_SOURCE_PARALLELISM);
        Assertions.assertThat(parallelismAndInputInfos.getJobVertexInputInfos()).isEmpty();
    }

    @Test
    void testEvenlyDistributeDataWithMaxSubpartitionLimitation() {
        final int numPartitions = 1024;
        final int numSubpartitions = 1024;

        long[] subpartitionBytes = new long[numSubpartitions];
        Arrays.fill(subpartitionBytes, 1L);
        AllToAllBlockingResultInfo resultInfo =
                new AllToAllBlockingResultInfo(
                        new IntermediateDataSetID(), numPartitions, numSubpartitions, false);
        for (int partitionIndex = 0; partitionIndex < numPartitions; partitionIndex++) {
            resultInfo.recordPartitionInfo(
                    partitionIndex, new ResultPartitionBytes(subpartitionBytes));
        }

        ParallelismAndInputInfos parallelismAndInputInfos =
                createDeciderAndDecideParallelismAndInputInfos(
                        1, MAX_PARALLELISM, BYTE_256_MB, Collections.singletonList(resultInfo));

        Assertions.assertThat(parallelismAndInputInfos.getParallelism()).isEqualTo(32);

        // 32 subtasks, each expected to consume 32 consecutive subpartitions of every partition.
        List<IndexRange> subpartitionRanges = new ArrayList<>();
        for (int subtaskIndex = 0; subtaskIndex < 32; subtaskIndex++) {
            subpartitionRanges.add(
                    new IndexRange(subtaskIndex * 32, (subtaskIndex + 1) * 32 - 1));
        }
        checkAllToAllJobVertexInputInfo(
                Iterables.getOnlyElement(
                        parallelismAndInputInfos.getJobVertexInputInfos().values()),
                new IndexRange(0, numPartitions - 1),
                subpartitionRanges);
    }

    /**
     * Asserts the execution vertex input infos of an all-to-all input whose subtasks each read
     * partition range {@code [0, 0]}.
     *
     * @param jobVertexInputInfo the actual input info to check
     * @param subpartitionRanges expected subpartition range per subtask, indexed by subtask
     */
    private static void checkAllToAllJobVertexInputInfo(
            JobVertexInputInfo jobVertexInputInfo, List<IndexRange> subpartitionRanges) {
        checkAllToAllJobVertexInputInfo(
                jobVertexInputInfo, new IndexRange(0, 0), subpartitionRanges);
    }

    /**
     * Asserts the execution vertex input infos of an all-to-all input in which every subtask reads
     * the same partition range.
     *
     * @param jobVertexInputInfo the actual input info to check
     * @param partitionRange the partition range expected for every subtask
     * @param subpartitionRanges expected subpartition range per subtask, indexed by subtask
     */
    private static void checkAllToAllJobVertexInputInfo(
            JobVertexInputInfo jobVertexInputInfo,
            IndexRange partitionRange,
            List<IndexRange> subpartitionRanges) {
        List<ExecutionVertexInputInfo> expected = new ArrayList<>();
        for (int subtaskIndex = 0; subtaskIndex < subpartitionRanges.size(); subtaskIndex++) {
            expected.add(
                    new ExecutionVertexInputInfo(
                            subtaskIndex, partitionRange, subpartitionRanges.get(subtaskIndex)));
        }
        Assertions.assertThat(jobVertexInputInfo.getExecutionVertexInputInfos())
                .containsExactlyInAnyOrderElementsOf(expected);
    }

    /**
     * Asserts the execution vertex input infos of a pointwise input.
     *
     * @param jobVertexInputInfo the actual input info to check
     * @param partitionRanges expected partition range per subtask, indexed by subtask
     * @param subpartitionRanges expected subpartition range per subtask, indexed by subtask; must
     *     have the same size as {@code partitionRanges}
     */
    private static void checkPointwiseJobVertexInputInfo(
            JobVertexInputInfo jobVertexInputInfo,
            List<IndexRange> partitionRanges,
            List<IndexRange> subpartitionRanges) {
        Assertions.assertThat(partitionRanges.size()).isEqualTo(subpartitionRanges.size());
        List<ExecutionVertexInputInfo> expected = new ArrayList<>();
        for (int subtaskIndex = 0; subtaskIndex < subpartitionRanges.size(); subtaskIndex++) {
            expected.add(
                    new ExecutionVertexInputInfo(
                            subtaskIndex,
                            partitionRanges.get(subtaskIndex),
                            subpartitionRanges.get(subtaskIndex)));
        }
        Assertions.assertThat(jobVertexInputInfo.getExecutionVertexInputInfos())
                .containsExactlyInAnyOrderElementsOf(expected);
    }

    /**
     * Creates a decider configured with {@link #DEFAULT_SOURCE_PARALLELISM} as the default source
     * parallelism.
     */
    static DefaultVertexParallelismAndInputInfosDecider createDecider(
            int minParallelism, int maxParallelism, long dataVolumePerTask) {
        return createDecider(
                minParallelism, maxParallelism, dataVolumePerTask, DEFAULT_SOURCE_PARALLELISM);
    }

    /**
     * Creates a decider from a {@link Configuration} holding the given adaptive auto-parallelism
     * options.
     *
     * @param minParallelism value for {@code ADAPTIVE_AUTO_PARALLELISM_MIN_PARALLELISM}
     * @param maxParallelism max parallelism handed to {@code
     *     DefaultVertexParallelismAndInputInfosDecider.from}
     * @param dataVolumePerTask value in bytes for {@code
     *     ADAPTIVE_AUTO_PARALLELISM_AVG_DATA_VOLUME_PER_TASK}
     * @param defaultSourceParallelism value for {@code
     *     ADAPTIVE_AUTO_PARALLELISM_DEFAULT_SOURCE_PARALLELISM}
     */
    static DefaultVertexParallelismAndInputInfosDecider createDecider(
            int minParallelism,
            int maxParallelism,
            long dataVolumePerTask,
            int defaultSourceParallelism) {
        Configuration configuration = new Configuration();
        configuration.set(
                BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_MIN_PARALLELISM, minParallelism);
        configuration.set(
                BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_AVG_DATA_VOLUME_PER_TASK,
                new MemorySize(dataVolumePerTask));
        configuration.set(
                BatchExecutionOptions.ADAPTIVE_AUTO_PARALLELISM_DEFAULT_SOURCE_PARALLELISM,
                defaultSourceParallelism);
        return DefaultVertexParallelismAndInputInfosDecider.from(maxParallelism, configuration);
    }

    /**
     * Decides parallelism using {@link #MIN_PARALLELISM}, {@link #MAX_PARALLELISM} and {@link
     * #DATA_VOLUME_PER_TASK}.
     */
    private static int createDeciderAndDecideParallelism(List<BlockingResultInfo> consumedResults) {
        return createDeciderAndDecideParallelism(
                MIN_PARALLELISM, MAX_PARALLELISM, DATA_VOLUME_PER_TASK, consumedResults);
    }

    /** Creates a decider and calls {@code decideParallelism} with the same min/max bounds. */
    private static int createDeciderAndDecideParallelism(
            int minParallelism,
            int maxParallelism,
            long dataVolumePerTask,
            List<BlockingResultInfo> consumedResults) {
        DefaultVertexParallelismAndInputInfosDecider decider =
                createDecider(minParallelism, maxParallelism, dataVolumePerTask);
        return decider.decideParallelism(
                new JobVertexID(), consumedResults, minParallelism, maxParallelism);
    }

    /**
     * Creates a decider and calls {@code decideParallelismAndInputInfosForVertex} with an initial
     * parallelism of {@code -1}.
     */
    private static ParallelismAndInputInfos createDeciderAndDecideParallelismAndInputInfos(
            int minParallelism,
            int maxParallelism,
            long dataVolumePerTask,
            List<BlockingResultInfo> consumedResults) {
        DefaultVertexParallelismAndInputInfosDecider decider =
                createDecider(minParallelism, maxParallelism, dataVolumePerTask);
        return decider.decideParallelismAndInputInfosForVertex(
                new JobVertexID(), consumedResults, -1, maxParallelism);
    }

    /** Creates a non-broadcast all-to-all result with one partition holding the given bytes. */
    private static AllToAllBlockingResultInfo createAllToAllBlockingResultInfo(
            long[] subpartitionBytes) {
        return createAllToAllBlockingResultInfo(subpartitionBytes, false);
    }

    /**
     * Creates an all-to-all result with a single partition whose subpartition sizes are {@code
     * subpartitionBytes}.
     */
    private static AllToAllBlockingResultInfo createAllToAllBlockingResultInfo(
            long[] subpartitionBytes, boolean isBroadcast) {
        AllToAllBlockingResultInfo resultInfo =
                new AllToAllBlockingResultInfo(
                        new IntermediateDataSetID(), 1, subpartitionBytes.length, isBroadcast);
        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(subpartitionBytes));
        return resultInfo;
    }

    /**
     * Creates a pointwise result with one partition per argument.
     *
     * @param subpartitionBytesByPartition per-partition subpartition sizes; every array must have
     *     the same length
     * @throws IllegalStateException if the arrays differ in length or none are given
     */
    private static PointwiseBlockingResultInfo createPointwiseBlockingResultInfo(
            long[]... subpartitionBytesByPartition) {
        Set<Integer> subpartitionCounts =
                Arrays.stream(subpartitionBytesByPartition)
                        .map(subpartitionBytes -> subpartitionBytes.length)
                        .collect(Collectors.toSet());
        Preconditions.checkState(
                subpartitionCounts.size() == 1,
                "All partitions must have the same number of subpartitions.");
        int numSubpartitions = subpartitionCounts.iterator().next();

        PointwiseBlockingResultInfo resultInfo =
                new PointwiseBlockingResultInfo(
                        new IntermediateDataSetID(),
                        subpartitionBytesByPartition.length,
                        numSubpartitions);
        int partitionIndex = 0;
        for (long[] subpartitionBytes : subpartitionBytesByPartition) {
            resultInfo.recordPartitionInfo(
                    partitionIndex++, new ResultPartitionBytes(subpartitionBytes));
        }
        return resultInfo;
    }

    private static BlockingResultInfo createFromBroadcastResult(long producedBytes) {
        return new TestingBlockingResultInfo(true, producedBytes);
    }

    private static BlockingResultInfo createFromNonBroadcastResult(long producedBytes) {
        return new TestingBlockingResultInfo(false, producedBytes);
    }

    /**
     * Minimal non-pointwise {@link BlockingResultInfo} stub reporting a fixed total byte count;
     * per-range byte queries are unsupported and partition-info updates are ignored.
     */
    private static class TestingBlockingResultInfo implements BlockingResultInfo {

        private final boolean isBroadcast;
        private final long producedBytes;
        private final int numPartitions;
        private final int numSubpartitions;

        private TestingBlockingResultInfo(boolean isBroadcast, long producedBytes) {
            this(isBroadcast, producedBytes, MAX_PARALLELISM, MAX_PARALLELISM);
        }

        private TestingBlockingResultInfo(
                boolean isBroadcast, long producedBytes, int numPartitions, int numSubpartitions) {
            this.isBroadcast = isBroadcast;
            this.producedBytes = producedBytes;
            this.numPartitions = numPartitions;
            this.numSubpartitions = numSubpartitions;
        }

        /** Returns a freshly generated ID on every call. */
        @Override
        public IntermediateDataSetID getResultId() {
            return new IntermediateDataSetID();
        }

        @Override
        public boolean isBroadcast() {
            return isBroadcast;
        }

        @Override
        public boolean isPointwise() {
            return false;
        }

        @Override
        public int getNumPartitions() {
            return numPartitions;
        }

        @Override
        public int getNumSubpartitions(int partitionIndex) {
            return numSubpartitions;
        }

        @Override
        public long getNumBytesProduced() {
            return producedBytes;
        }

        @Override
        public long getNumBytesProduced(
                IndexRange partitionIndexRange, IndexRange subpartitionIndexRange) {
            throw new UnsupportedOperationException();
        }

        @Override
        public void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes) {}

        @Override
        public void resetPartitionInfo(int partitionIndex) {}
    }
}
