/*
 * Decompiled with CFR 0.152.
 */
package org.graalvm.compiler.phases.common;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import org.graalvm.collections.EconomicMap;
import org.graalvm.collections.EconomicSet;
import org.graalvm.collections.Equivalence;
import org.graalvm.collections.MapCursor;
import org.graalvm.compiler.core.common.memory.MemoryExtendKind;
import org.graalvm.compiler.core.common.type.IntegerStamp;
import org.graalvm.compiler.debug.Assertions;
import org.graalvm.compiler.debug.GraalError;
import org.graalvm.compiler.graph.Node;
import org.graalvm.compiler.nodeinfo.InputType;
import org.graalvm.compiler.nodes.FixedWithNextNode;
import org.graalvm.compiler.nodes.GraphState;
import org.graalvm.compiler.nodes.NodeView;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.calc.IntegerConvertNode;
import org.graalvm.compiler.nodes.calc.NarrowNode;
import org.graalvm.compiler.nodes.calc.SignExtendNode;
import org.graalvm.compiler.nodes.calc.ZeroExtendNode;
import org.graalvm.compiler.nodes.memory.ExtendableMemoryAccess;
import org.graalvm.compiler.nodes.spi.LoweringProvider;
import org.graalvm.compiler.phases.BasePhase;
import org.graalvm.compiler.phases.tiers.LowTierContext;

/**
 * Low-tier phase that merges redundant integer zero/sign extends of the same value and, where the
 * lowerer supports it, folds an extend directly into the memory access that produces the value.
 *
 * <p>
 * For every value that feeds at least one extend node the phase keeps at most one
 * {@link ZeroExtendNode} and one {@link SignExtendNode} (each at the widest requested result
 * width); narrower requests are served by a {@link NarrowNode} placed on top of the widest extend.
 */
public class OptimizeExtendsPhase extends BasePhase<LowTierContext> {
    /** Sentinel for "no bit width has been recorded yet". */
    private static final int UNSET = -1;

    /**
     * Delegates the applicability decision to {@code NotApplicable.mustRunAfter} with
     * {@code FINAL_CANONICALIZATION} as the required stage.
     */
    @Override
    public Optional<BasePhase.NotApplicable> canApply(GraphState graphState) {
        return BasePhase.NotApplicable.mustRunAfter(this, GraphState.StageFlag.FINAL_CANONICALIZATION, graphState);
    }

    /**
     * Rewrites the extends in {@code graph}.
     *
     * @param graph the graph to optimize; it is modified in place
     * @param context supplies the lowerer that is queried for {@code narrowsUseCastValue()} and
     *            {@code supportsFoldingExtendIntoAccess(...)}
     */
    @Override
    protected void run(StructuredGraph graph, LowTierContext context) {
        // The phase does nothing unless the lowerer reports that narrows use the cast value.
        if (!context.getLowerer().narrowsUseCastValue()) {
            return;
        }

        // Collect every value that is the input of at least one extend node.
        int origNumExtends = 0;
        EconomicSet<ValueNode> defsWithExtends = EconomicSet.create(Equivalence.DEFAULT);
        for (Node node : graph.getNodes().filter(OptimizeExtendsPhase::isExtendNode)) {
            IntegerConvertNode extend = (IntegerConvertNode) node;
            ++origNumExtends;
            assert (extend.getInputBits() < extend.getResultBits());
            defsWithExtends.add(extend.getValue());
        }

        // An extend that is itself the input of another extend may get replaced (and deleted)
        // below; this map lets a later iteration find the node that took its place.
        EconomicMap<ValueNode, ValueNode> extendReplacements = EconomicMap.create(Equivalence.DEFAULT);
        // Only tracked (and only read) when assertions are enabled; null otherwise.
        EconomicSet<ValueNode> addedNarrows = Assertions.assertionsEnabled() ? EconomicSet.create(Equivalence.DEFAULT) : null;

        for (ValueNode origDef : defsWithExtends) {
            int inputBitsSize = UNSET;
            int maxZeroExtend = UNSET;
            int maxSignExtend = UNSET;

            ValueNode def = origDef;
            if (def instanceof IntegerConvertNode && extendReplacements.containsKey(def)) {
                def = extendReplacements.get(def);
            }

            // Snapshot the extend usages because the usage list is mutated further down.
            List<Node> uses = def.usages().filter(OptimizeExtendsPhase::isExtendNode).snapshot();

            // Find the widest zero and sign extend; a second extend of the same kind on the same
            // def means at least one of them is redundant.
            boolean hasRedundantExtends = false;
            for (Node n : uses) {
                IntegerConvertNode use = (IntegerConvertNode) n;
                int inputBits = use.getInputBits();
                int resultBits = use.getResultBits();
                if (inputBitsSize == UNSET) {
                    inputBitsSize = inputBits;
                } else {
                    GraalError.guarantee(inputBitsSize == inputBits, "Unexpected input bits size: %s. Expected size: %s", inputBits, inputBitsSize);
                }
                if (use instanceof ZeroExtendNode) {
                    hasRedundantExtends |= maxZeroExtend != UNSET;
                    maxZeroExtend = Integer.max(maxZeroExtend, resultBits);
                    continue;
                }
                assert (use instanceof SignExtendNode);
                hasRedundantExtends |= maxSignExtend != UNSET;
                maxSignExtend = Integer.max(maxSignExtend, resultBits);
            }

            // Nothing to do unless extends can be merged or possibly folded into a memory access.
            if (!(def instanceof ExtendableMemoryAccess) && !hasRedundantExtends) {
                continue;
            }

            ValueNode newZeroExtend = null;
            ValueNode newSignExtend = null;
            ValueNode extendInput = def;
            if (def instanceof ExtendableMemoryAccess) {
                ExtendableMemoryAccess access = (ExtendableMemoryAccess) def;
                FixedWithNextNode extendedDef = null;

                // Folding the zero extend is tried first; the sign extend is only a fallback.
                MemoryExtendKind extendKind = maxZeroExtend == UNSET ? MemoryExtendKind.DEFAULT : MemoryExtendKind.getZeroExtendKind(maxZeroExtend);
                if (extendKind.isExtended() && context.getLowerer().supportsFoldingExtendIntoAccess(access, extendKind)) {
                    extendedDef = graph.add(access.copyWithExtendKind(extendKind));
                    newZeroExtend = extendedDef;
                } else {
                    extendKind = maxSignExtend == UNSET ? MemoryExtendKind.DEFAULT : MemoryExtendKind.getSignExtendKind(maxSignExtend);
                    if (extendKind.isExtended() && context.getLowerer().supportsFoldingExtendIntoAccess(access, extendKind)) {
                        extendedDef = graph.add(access.copyWithExtendKind(extendKind));
                        newSignExtend = extendedDef;
                    }
                }

                if (extendedDef != null) {
                    // The extending access now yields a wider value, so all value usages of the
                    // old access are redirected to a narrow back to the original width.
                    extendInput = graph.addOrUnique(new NarrowNode(extendedDef, inputBitsSize));
                    def.replaceAtUsages(extendInput, InputType.Value);
                    graph.replaceFixedWithFixed(access.asFixedWithNextNode(), extendedDef);
                }
            }

            // No fold happened and there is nothing to merge: leave this def untouched.
            if (extendInput == def && !hasRedundantExtends) {
                continue;
            }

            // Materialize the widest extend of each kind unless the memory access already
            // provides it.
            if (maxZeroExtend != UNSET && newZeroExtend == null) {
                newZeroExtend = graph.addOrUnique(new ZeroExtendNode(extendInput, inputBitsSize, maxZeroExtend, false));
            }
            if (maxSignExtend != UNSET && newSignExtend == null) {
                newSignExtend = graph.addOrUnique(new SignExtendNode(extendInput, inputBitsSize, maxSignExtend));
            }

            // Redirect every original extend to the widest extend of its kind, narrowing when the
            // original asked for fewer bits.
            for (Node n : uses) {
                IntegerConvertNode use = (IntegerConvertNode) n;
                ValueNode replacement;
                int replacementBits;
                if (use instanceof ZeroExtendNode) {
                    assert (newZeroExtend != null);
                    replacementBits = maxZeroExtend;
                    replacement = newZeroExtend;
                } else {
                    assert (newSignExtend != null);
                    replacementBits = maxSignExtend;
                    replacement = newSignExtend;
                }
                int resultBits = use.getResultBits();
                if (resultBits != replacementBits) {
                    assert (replacementBits > resultBits);
                    replacement = graph.addOrUnique(new NarrowNode(replacement, replacementBits, resultBits));
                    if (Assertions.assertionsEnabled()) {
                        addedNarrows.add(replacement);
                    }
                }
                // addOrUnique may hand back the existing use itself; then there is nothing to swap.
                if (use == replacement) {
                    continue;
                }
                use.replaceAtUsagesAndDelete(replacement);
                if (defsWithExtends.contains(use)) {
                    // The deleted extend is also a pending def; remember what replaced it.
                    extendReplacements.put(use, replacement);
                }
            }

            // Drop the narrow created for a folded access if nothing ended up using it.
            if (extendInput != def && extendInput.hasNoUsages()) {
                extendInput.safeDelete();
            }
        }
        assert (OptimizeExtendsPhase.validateOptimization(graph, context.getLowerer(), origNumExtends, addedNarrows));
    }

    /**
     * Assertion-only sanity check of the graph after the phase ran.
     *
     * <p>
     * Checks that the number of extends did not grow, that each def (looking through narrows added
     * by this phase) has at most one zero extend and one sign extend, and that no extend left on
     * an {@link ExtendableMemoryAccess} could have been folded into that access.
     *
     * @param graph the optimized graph
     * @param lowerer used to query {@code supportsFoldingExtendIntoAccess(...)}
     * @param origNumExtends number of extend nodes counted before the phase modified the graph
     * @param addedNarrows narrows inserted by the phase; must be non-null here
     * @return always {@code true} so the method can be used inside an {@code assert}; violations
     *         surface through the nested assertions
     */
    private static boolean validateOptimization(StructuredGraph graph, LoweringProvider lowerer, int origNumExtends, EconomicSet<ValueNode> addedNarrows) {
        int numExtends = graph.getNodes().filter(OptimizeExtendsPhase::isExtendNode).count();
        assert (numExtends <= origNumExtends);

        // Group the remaining extends by their underlying def.
        EconomicMap<ValueNode, List<IntegerConvertNode>> extendMap = EconomicMap.create(Equivalence.DEFAULT);
        for (Node node : graph.getNodes().filter(OptimizeExtendsPhase::isExtendNode)) {
            IntegerConvertNode extend = (IntegerConvertNode) node;
            ValueNode def = extend.getValue();
            // Look through narrows this phase inserted to reach the real def.
            while (def instanceof NarrowNode && addedNarrows.contains(def)) {
                def = ((NarrowNode) def).getValue();
            }
            // Extends sitting on an access that already extends are not checked.
            if (def instanceof ExtendableMemoryAccess && ((ExtendableMemoryAccess) def).extendsAccess()) {
                continue;
            }
            List<IntegerConvertNode> value = extendMap.get(def);
            if (value == null) {
                value = new ArrayList<>();
                extendMap.put(def, value);
            }
            value.add(extend);
        }

        MapCursor<ValueNode, List<IntegerConvertNode>> entries = extendMap.getEntries();
        while (entries.advance()) {
            ValueNode def = entries.getKey();
            List<IntegerConvertNode> extendNodes = entries.getValue();
            assert (extendNodes.size() <= 2);
            if (extendNodes.size() == 2) {
                // Two extends on one def are only acceptable as one zero plus one sign extend.
                boolean firstIsZeroExtend = extendNodes.get(0) instanceof ZeroExtendNode;
                boolean secondIsZeroExtend = extendNodes.get(1) instanceof ZeroExtendNode;
                assert (firstIsZeroExtend ^ secondIsZeroExtend);
            }
            if (!(def instanceof ExtendableMemoryAccess)) {
                continue;
            }
            ExtendableMemoryAccess access = (ExtendableMemoryAccess) def;
            for (IntegerConvertNode extend : extendNodes) {
                MemoryExtendKind extendKind;
                int extendSize = extend.getResultBits();
                if (extend instanceof ZeroExtendNode) {
                    extendKind = MemoryExtendKind.getZeroExtendKind(extendSize);
                } else {
                    assert (extend instanceof SignExtendNode);
                    extendKind = MemoryExtendKind.getSignExtendKind(extendSize);
                }
                // A foldable extend left on an access means the phase missed an opportunity.
                assert (!lowerer.supportsFoldingExtendIntoAccess(access, extendKind));
            }
        }
        return true;
    }

    /**
     * Returns whether {@code node} is a {@link ZeroExtendNode} or {@link SignExtendNode} whose
     * stamp is an {@link IntegerStamp}.
     */
    private static boolean isExtendNode(Node node) {
        return (node instanceof ZeroExtendNode || node instanceof SignExtendNode) && ((IntegerConvertNode) node).stamp(NodeView.DEFAULT) instanceof IntegerStamp;
    }
}

