package org.apache.flink.runtime.state.heap;

import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.runtime.state.StateSnapshotTransformer;

/* loaded from: input_file:org/apache/flink/runtime/state/heap/NestedStateMapSnapshot.class */
/**
 * Snapshot of a {@link NestedStateMap}, whose state is organized as namespace -> (key -> state).
 *
 * @param <K> type of key
 * @param <N> type of namespace
 * @param <S> type of state
 */
public class NestedStateMapSnapshot<K, N, S>
        extends StateMapSnapshot<K, N, S, NestedStateMap<K, N, S>> {

    /**
     * Creates a snapshot backed by the given map.
     *
     * @param owningStateMap the map this snapshot reads its mappings from.
     */
    public NestedStateMapSnapshot(NestedStateMap<K, N, S> owningStateMap) {
        super(owningStateMap);
    }

    /**
     * Writes all (namespace, key, state) entries of the owning map to {@code dov}.
     *
     * <p>Output layout: an {@code int} holding the number of entries, followed by that many
     * entries, each written as namespace, key, then state using the respective serializer. If a
     * transformer is given, only the entries it keeps (with their transformed values) are written
     * and counted.
     *
     * @param keySerializer serializer for the keys.
     * @param namespaceSerializer serializer for the namespaces.
     * @param stateSerializer serializer for the state values.
     * @param dov the output view to write to.
     * @param stateSnapshotTransformer optional filter/transformer applied to each state value;
     *     {@code null} writes the mappings unchanged.
     * @throws IOException if writing to {@code dov} fails.
     */
    @Override
    public void writeState(
            TypeSerializer<K> keySerializer,
            TypeSerializer<N> namespaceSerializer,
            TypeSerializer<S> stateSerializer,
            @Nonnull DataOutputView dov,
            @Nullable StateSnapshotTransformer<S> stateSnapshotTransformer)
            throws IOException {
        final Map<N, Map<K, S>> mappings =
                filterMappingsIfNeeded(owningStateMap.getNamespaceMap(), stateSnapshotTransformer);

        // Header: total number of entries that follow.
        dov.writeInt(countMappingsInKeyGroup(mappings));

        for (Map.Entry<N, Map<K, S>> namespaceEntry : mappings.entrySet()) {
            final N namespace = namespaceEntry.getKey();
            for (Map.Entry<K, S> keyEntry : namespaceEntry.getValue().entrySet()) {
                // Every entry carries its own namespace, so it is re-serialized per key.
                namespaceSerializer.serialize(namespace, dov);
                keySerializer.serialize(keyEntry.getKey(), dov);
                stateSerializer.serialize(keyEntry.getValue(), dov);
            }
        }
    }

    /**
     * Applies the optional transformer to every state value.
     *
     * @param keyGroupMap the namespace -> (key -> state) mappings to filter.
     * @param stateSnapshotTransformer transformer to apply, or {@code null}.
     * @return {@code keyGroupMap} itself when the transformer is {@code null}; otherwise a new map
     *     holding the transformed values, without entries for which the transformer returned
     *     {@code null} and without namespaces left with no entries.
     */
    private Map<N, Map<K, S>> filterMappingsIfNeeded(
            final Map<N, Map<K, S>> keyGroupMap,
            @Nullable StateSnapshotTransformer<S> stateSnapshotTransformer) {
        if (stateSnapshotTransformer == null) {
            return keyGroupMap;
        }

        final Map<N, Map<K, S>> filtered = new HashMap<>();
        for (Map.Entry<N, Map<K, S>> namespaceEntry : keyGroupMap.entrySet()) {
            final Map<K, S> filteredNamespaceMap = new HashMap<>();
            for (Map.Entry<K, S> keyEntry : namespaceEntry.getValue().entrySet()) {
                final S transformed =
                        stateSnapshotTransformer.filterOrTransform(keyEntry.getValue());
                if (transformed != null) {
                    filteredNamespaceMap.put(keyEntry.getKey(), transformed);
                }
            }
            // Only keep namespaces that still have at least one entry; inserting the inner map
            // only when non-empty avoids an insert-then-remove on the result map.
            if (!filteredNamespaceMap.isEmpty()) {
                filtered.put(namespaceEntry.getKey(), filteredNamespaceMap);
            }
        }
        return filtered;
    }

    /**
     * Counts all (namespace, key) entries across the given mappings.
     *
     * @param keyGroupMap the namespace -> (key -> state) mappings.
     * @return the sum of the sizes of all inner maps.
     */
    private int countMappingsInKeyGroup(final Map<N, Map<K, S>> keyGroupMap) {
        int count = 0;
        for (Map<K, S> namespaceMap : keyGroupMap.values()) {
            count += namespaceMap.size();
        }
        return count;
    }
}
