package org.apache.flink.state.benchmark;

import java.io.IOException;
import java.lang.invoke.SerializedLambda;
import java.util.Iterator;
import java.util.function.Supplier;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.operators.testutils.MockEnvironment;
import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder;
import org.apache.flink.runtime.state.CheckpointStorageAccess;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.StateBackend;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;

/**
 * Helper for benchmarking the restore of keyed operator state after a change of parallelism.
 *
 * <p>{@link #setUp()} runs {@code parallelismBefore} {@link KeyedProcessOperator} harnesses over the
 * records produced by a {@link StreamRecordGenerator}, snapshots every harness and repackages the
 * snapshots into one {@link OperatorSubtaskState}. {@link #prepareStateForOperator(int)} then
 * repartitions that state for a single subtask at {@code parallelismAfter} and sets up a fresh
 * harness for it, and {@link #rescale()} initializes that harness from the repartitioned state.
 *
 * @param <KEY> type of the records; each record is also used as its own key (identity key
 *     selector)
 */
public class RescalingBenchmark<KEY> {
    private final int maxParallelism;
    private final int parallelismBefore;
    private final int parallelismAfter;
    private final int managedMemorySize;
    private final StateBackend stateBackend;
    private final CheckpointStorageAccess checkpointStorageAccess;
    /** Repackaged snapshot of all pre-rescaling subtasks; null until {@link #setUp()} succeeds. */
    private OperatorSubtaskState stateForRescaling;
    /** State repartitioned for the subtask chosen in {@link #prepareStateForOperator(int)}. */
    private OperatorSubtaskState stateForSubtask;
    private KeyedOneInputStreamOperatorTestHarness<KEY, KEY, Void> subtaskHarness;
    private final StreamRecordGenerator<KEY> streamRecordGenerator;
    private final Supplier<KeyedProcessFunction<KEY, KEY, Void>> stateProcessFunctionSupplier;

    /** Supplies the records processed before the pre-rescaling snapshot is taken. */
    public interface StreamRecordGenerator<T> {
        /** Returns an iterator over the records to feed into the pre-rescaling operators. */
        Iterator<StreamRecord<T>> generate();

        /** Returns the type information of the generated records (also used as the key type). */
        TypeInformation<T> getTypeInformation();
    }

    /**
     * Creates the benchmark helper.
     *
     * @param parallelismBefore number of subtasks whose state is snapshotted in {@link #setUp()}
     * @param parallelismAfter number of subtasks the state is repartitioned for
     * @param maxParallelism maximum parallelism, i.e. the number of key groups
     * @param managedMemorySize managed memory size passed to each harness' mock environment
     * @param stateBackend state backend installed on every harness
     * @param checkpointStorageAccess checkpoint storage access installed on every mock environment
     * @param streamRecordGenerator source of the records processed before snapshotting
     * @param stateProcessFunctionSupplier creates a fresh process function for every harness
     */
    public RescalingBenchmark(
            int parallelismBefore,
            int parallelismAfter,
            int maxParallelism,
            int managedMemorySize,
            StateBackend stateBackend,
            CheckpointStorageAccess checkpointStorageAccess,
            StreamRecordGenerator<KEY> streamRecordGenerator,
            Supplier<KeyedProcessFunction<KEY, KEY, Void>> stateProcessFunctionSupplier) {
        this.parallelismBefore = parallelismBefore;
        this.parallelismAfter = parallelismAfter;
        this.maxParallelism = maxParallelism;
        this.managedMemorySize = managedMemorySize;
        this.stateBackend = stateBackend;
        this.checkpointStorageAccess = checkpointStorageAccess;
        this.streamRecordGenerator = streamRecordGenerator;
        this.stateProcessFunctionSupplier = stateProcessFunctionSupplier;
    }

    /** Builds the pre-rescaling state that later subtasks are restored from. */
    public void setUp() throws Exception {
        this.stateForRescaling = prepareState();
    }

    /**
     * Discards the state produced by {@link #setUp()}. Safe to call when {@link #setUp()} failed or
     * when the state was already discarded.
     */
    public void tearDown() throws IOException {
        OperatorSubtaskState state = this.stateForRescaling;
        if (state == null) {
            return;
        }
        // Clear the field first so a repeated tearDown does not discard the same handles twice.
        this.stateForRescaling = null;
        state.discardState();
    }

    /** Initializes the subtask harness from the state prepared by {@link #prepareStateForOperator}. */
    public void rescale() throws Exception {
        this.subtaskHarness.initializeState(this.stateForSubtask);
    }

    /** Closes the harness created by {@link #prepareStateForOperator(int)}. */
    public void closeOperator() throws Exception {
        this.subtaskHarness.close();
    }

    /**
     * Repartitions the pre-rescaling state for one subtask of the rescaled operator and creates and
     * sets up (but does not initialize) a harness for that subtask.
     *
     * @param subtaskIndex index of the subtask at {@code parallelismAfter}
     */
    public void prepareStateForOperator(int subtaskIndex) throws Exception {
        this.stateForSubtask =
                AbstractStreamOperatorTestHarness.repartitionOperatorState(
                        this.stateForRescaling,
                        this.maxParallelism,
                        this.parallelismBefore,
                        this.parallelismAfter,
                        subtaskIndex);
        this.subtaskHarness =
                getTestHarness(record -> record, this.maxParallelism, this.parallelismAfter, subtaskIndex);
        this.subtaskHarness.setStateBackend(this.stateBackend);
        this.subtaskHarness.setup();
    }

    /**
     * Runs the pre-rescaling operators, snapshots them and returns the repackaged snapshot. All
     * harnesses are closed on every path; if closing fails after an earlier failure, the close
     * failure is attached to the original exception as a suppressed exception.
     */
    private OperatorSubtaskState prepareState() throws Exception {
        @SuppressWarnings("unchecked") // generic array creation; every element is <KEY, KEY, Void>
        KeyedOneInputStreamOperatorTestHarness<KEY, KEY, Void>[] harnesses =
                new KeyedOneInputStreamOperatorTestHarness[this.parallelismBefore];
        OperatorSubtaskState repackaged;
        try {
            openHarnesses(harnesses);
            processRecords(harnesses);
            repackaged = snapshotHarnesses(harnesses);
        } catch (Throwable failure) {
            try {
                closeHarnessArray(harnesses);
            } catch (Exception closeFailure) {
                failure.addSuppressed(closeFailure);
            }
            throw failure;
        }
        closeHarnessArray(harnesses);
        return repackaged;
    }

    /** Creates, configures and opens one pre-rescaling harness per slot of {@code harnesses}. */
    private void openHarnesses(KeyedOneInputStreamOperatorTestHarness<KEY, KEY, Void>[] harnesses)
            throws Exception {
        for (int subtask = 0; subtask < harnesses.length; subtask++) {
            // Store before setup/open so the caller can close a partially initialized harness.
            harnesses[subtask] =
                    getTestHarness(record -> record, this.maxParallelism, this.parallelismBefore, subtask);
            harnesses[subtask].setStateBackend(this.stateBackend);
            harnesses[subtask].setup();
            harnesses[subtask].open();
        }
    }

    /** Routes every generated record to the pre-rescaling subtask that owns its key group. */
    private void processRecords(KeyedOneInputStreamOperatorTestHarness<KEY, KEY, Void>[] harnesses)
            throws Exception {
        Iterator<StreamRecord<KEY>> records = this.streamRecordGenerator.generate();
        while (records.hasNext()) {
            StreamRecord<KEY> record = records.next();
            // The key selector is the identity, so the record value is its key.
            int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(record.getValue(), this.maxParallelism);
            int subtask =
                    KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
                            this.maxParallelism, this.parallelismBefore, keyGroup);
            harnesses[subtask].processElement(record);
        }
    }

    /** Snapshots every harness (checkpoint 0, timestamp 1) and repackages the results into one state. */
    private OperatorSubtaskState snapshotHarnesses(
            KeyedOneInputStreamOperatorTestHarness<KEY, KEY, Void>[] harnesses) throws Exception {
        OperatorSubtaskState[] snapshots = new OperatorSubtaskState[harnesses.length];
        for (int subtask = 0; subtask < harnesses.length; subtask++) {
            snapshots[subtask] = harnesses[subtask].snapshot(0L, 1L);
        }
        return AbstractStreamOperatorTestHarness.repackageState(snapshots);
    }

    /**
     * Creates a keyed test harness wrapping a fresh {@link KeyedProcessOperator}.
     *
     * @param keySelector extracts the key from each record
     * @param maxParallelism maximum parallelism of the mock environment
     * @param parallelism parallelism of the mock environment
     * @param subtaskIndex subtask index of the mock environment
     */
    private KeyedOneInputStreamOperatorTestHarness<KEY, KEY, Void> getTestHarness(
            KeySelector<KEY, KEY> keySelector, int maxParallelism, int parallelism, int subtaskIndex)
            throws Exception {
        MockEnvironment environment =
                new MockEnvironmentBuilder()
                        .setTaskName("RescalingTask")
                        .setManagedMemorySize(this.managedMemorySize)
                        .setMaxParallelism(maxParallelism)
                        .setParallelism(parallelism)
                        .setSubtaskIndex(subtaskIndex)
                        .build();
        environment.setCheckpointStorageAccess(this.checkpointStorageAccess);
        return new KeyedOneInputStreamOperatorTestHarness<>(
                new KeyedProcessOperator<KEY, KEY, Void>(this.stateProcessFunctionSupplier.get()),
                keySelector,
                this.streamRecordGenerator.getTypeInformation(),
                environment);
    }

    /**
     * Closes every non-null harness. All harnesses are attempted even if some fail; the first
     * failure is rethrown with any later failures attached as suppressed exceptions.
     */
    private void closeHarnessArray(KeyedOneInputStreamOperatorTestHarness<?, ?, ?>[] harnesses)
            throws Exception {
        Exception firstFailure = null;
        for (KeyedOneInputStreamOperatorTestHarness<?, ?, ?> harness : harnesses) {
            if (harness == null) {
                continue;
            }
            try {
                harness.close();
            } catch (Exception e) {
                if (firstFailure == null) {
                    firstFailure = e;
                } else {
                    firstFailure.addSuppressed(e);
                }
            }
        }
        if (firstFailure != null) {
            throw firstFailure;
        }
    }
}
