/*
 * Decompiled with CFR 0.152.
 */
package net.automatalib.util.automaton.ads;

import com.google.common.collect.Maps;
import java.util.ArrayDeque;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import net.automatalib.alphabet.Alphabet;
import net.automatalib.automaton.concept.StateIDs;
import net.automatalib.automaton.transducer.MealyMachine;
import net.automatalib.common.smartcollection.ReflexiveMapView;
import net.automatalib.common.util.Pair;
import net.automatalib.graph.ads.ADSLeafNode;
import net.automatalib.graph.ads.ADSNode;
import net.automatalib.graph.ads.ADSSymbolNode;
import net.automatalib.util.automaton.ads.ADS;
import net.automatalib.util.automaton.ads.ADSUtil;
import net.automatalib.util.automaton.ads.SplitTree;
import net.automatalib.word.Word;
import org.checkerframework.checker.nullness.qual.NonNull;

/**
 * Backtracking-based search for adaptive distinguishing sequences (ADSs) of Mealy machines.
 * <p>
 * Two strategies are offered:
 * <ul>
 *   <li>{@link #compute(MealyMachine, Alphabet, Set)} grows splitting words breadth-first and returns the first ADS
 *   it can assemble.</li>
 *   <li>{@link #computeOptimal(MealyMachine, Alphabet, Set, CostAggregator)} explores the search space and returns
 *   an ADS with minimal costs w.r.t. a given {@link CostAggregator}.</li>
 * </ul>
 * Both strategies reject partial automata with an {@link IllegalArgumentException} when they hit an undefined
 * transition.
 */
public final class BacktrackingSearch {

    private BacktrackingSearch() {
        // utility class, prevent instantiation
    }

    /**
     * Computes an ADS by growing splitting words for the given set of states and recursively computing sub-ADSs for
     * the partitions these words induce. The result is not necessarily optimal.
     *
     * @param automaton
     *         the automaton for which an ADS should be computed
     * @param input
     *         the input alphabet of the automaton
     * @param states
     *         the states that the ADS should distinguish
     * @param <S>
     *         state type
     * @param <I>
     *         input symbol type
     * @param <O>
     *         output symbol type
     *
     * @return a valid ADS for {@code states}, or {@link Optional#empty()} if the search did not find one
     *
     * @throws IllegalArgumentException
     *         if (for more than one state) the search encounters an undefined transition
     */
    public static <S, I, O> Optional<ADSNode<S, I, O>> compute(MealyMachine<S, I, ?, O> automaton,
                                                                Alphabet<I> input,
                                                                Set<S> states) {
        if (states.size() == 1) {
            return ADS.compute(automaton, input, states);
        }

        // initially, every state is mapped to itself
        final SplitTree<S, I, O> node = new SplitTree<>(states, new ReflexiveMapView<>(states));
        return compute(automaton, input, node);
    }

    /**
     * Computes an ADS for the partition of the given split-tree node, using the node's partition size as the
     * original partition size for bounding splitting-word lengths.
     *
     * @param automaton
     *         the automaton for which an ADS should be computed
     * @param input
     *         the input alphabet of the automaton
     * @param node
     *         the node whose partition should be distinguished
     * @param <S>
     *         state type
     * @param <I>
     *         input symbol type
     * @param <O>
     *         output symbol type
     *
     * @return a valid ADS for the node's partition, or {@link Optional#empty()} if the search did not find one
     */
    static <S, I, O> Optional<ADSNode<S, I, O>> compute(MealyMachine<S, I, ?, O> automaton,
                                                         Alphabet<I> input,
                                                         SplitTree<S, I, O> node) {
        return compute(automaton, input, node, node.getPartition().size());
    }

    /**
     * Core of the greedy search. Prefixes are examined in breadth-first order. For each prefix, the method tries
     * every input symbol as a final letter. If the symbol splits the current states by output without merging any
     * of them, sub-ADSs are computed for the resulting blocks and combined with the trace of {@code prefix + symbol}.
     *
     * @param originalPartitionSize
     *         size of the partition of the top-level call, forwarded to recursive calls to bound the splitting-word
     *         length via {@link ADSUtil#computeMaximumSplittingWordLength(int, int, int)}
     */
    private static <S, I, T, O> Optional<ADSNode<S, I, O>> compute(MealyMachine<S, I, T, O> automaton,
                                                                   Alphabet<I> input,
                                                                   SplitTree<S, I, O> node,
                                                                   int originalPartitionSize) {

        final long maximumSplittingWordLength = ADSUtil.computeMaximumSplittingWordLength(automaton.size(),
                                                                                         node.getPartition().size(),
                                                                                         originalPartitionSize);
        // FIFO queue, so shorter prefixes are always examined before longer ones
        final Queue<Word<I>> splittingWordCandidates = new ArrayDeque<>();
        final StateIDs<S> stateIds = automaton.stateIDs();
        // sets of states (encoded by their state IDs) that were already examined without success
        final Set<BitSet> cache = new HashSet<>();

        splittingWordCandidates.add(Word.epsilon());

        while (!splittingWordCandidates.isEmpty()) {

            @SuppressWarnings("nullness") // the queue is non-empty, hence poll() does not return null
            final @NonNull Word<I> prefix = splittingWordCandidates.poll();

            // maps every state reached via 'prefix' back to the partition state it started from
            final Map<S, S> currentToInitialMapping = node.getPartition()
                                                          .stream()
                                                          .collect(Collectors.toMap(s -> automaton.getSuccessor(s,
                                                                                                                prefix),
                                                                                    Function.identity()));

            final BitSet currentSetAsBitSet = toBitSet(currentToInitialMapping.keySet(), stateIds);

            if (cache.contains(currentSetAsBitSet)) {
                continue;
            }

            symbolLoop:
            for (I i : input) {
                // group the successors under 'i' by the output they produce
                final Map<O, SplitTree<S, I, O>> successors = new HashMap<>();

                for (Map.Entry<S, S> entry : currentToInitialMapping.entrySet()) {
                    final T trans = automaton.getTransition(entry.getKey(), i);
                    if (trans == null) {
                        throw new IllegalArgumentException("Partial automata are not supported");
                    }
                    final S nextState = automaton.getSuccessor(trans);
                    final O nextOutput = automaton.getTransitionOutput(trans);

                    final SplitTree<S, I, O> child =
                            successors.computeIfAbsent(nextOutput, o -> new SplitTree<>(new HashSet<>()));

                    // two states with the same output merge: 'i' loses information and cannot be used
                    if (!child.getPartition().add(nextState)) {
                        continue symbolLoop;
                    }
                    child.getMapping().put(nextState, node.getMapping().get(entry.getValue()));
                }

                if (successors.size() > 1) {
                    // 'prefix + i' is a splitting word; try to distinguish every resulting block recursively
                    final Map<O, ADSNode<S, I, O>> results = new HashMap<>();

                    for (Map.Entry<O, SplitTree<S, I, O>> entry : successors.entrySet()) {
                        final SplitTree<S, I, O> currentNode = entry.getValue();
                        final BitSet currentNodeAsBitSet = toBitSet(currentNode.getPartition(), stateIds);

                        if (cache.contains(currentNodeAsBitSet)) {
                            continue symbolLoop;
                        }

                        final Optional<ADSNode<S, I, O>> succ;
                        if (currentNode.getPartition().size() > 2) {
                            succ = compute(automaton, input, currentNode, originalPartitionSize);
                        } else {
                            succ = ADS.compute(automaton, input, currentNode);
                        }

                        if (!succ.isPresent()) {
                            cache.add(currentNodeAsBitSet);
                            continue symbolLoop;
                        }

                        results.put(entry.getKey(), succ.get());
                    }

                    // all blocks could be distinguished: prepend the trace of 'prefix + i' to the sub-ADSs
                    final Pair<ADSNode<S, I, O>, ADSNode<S, I, O>> ads =
                            ADSUtil.buildFromTrace(automaton, prefix.append(i), node.getPartition().iterator().next());
                    final ADSNode<S, I, O> head = ads.getFirst();
                    final ADSNode<S, I, O> tail = ads.getSecond();

                    for (Map.Entry<O, ADSNode<S, I, O>> entry : results.entrySet()) {
                        entry.getValue().setParent(tail);
                        tail.getChildren().put(entry.getKey(), entry.getValue());
                    }

                    return Optional.of(head);
                }

                // no split yet: extend the prefix, as long as the length bound permits
                if (prefix.length() < maximumSplittingWordLength) {
                    splittingWordCandidates.add(prefix.append(i));
                }
            }

            cache.add(currentSetAsBitSet);
        }

        return Optional.empty();
    }

    /**
     * Encodes the given states as a bit set of their state IDs, used as a compact, hashable cache key.
     */
    private static <S> BitSet toBitSet(Iterable<? extends S> states, StateIDs<S> stateIds) {
        final BitSet result = new BitSet();
        for (S s : states) {
            result.set(stateIds.getStateId(s));
        }
        return result;
    }

    /**
     * Computes an ADS that is optimal with respect to the given cost aggregator by exploring the search space of
     * input symbols (with memoization and cost-bound pruning).
     *
     * @param automaton
     *         the automaton for which an ADS should be computed
     * @param input
     *         the input alphabet of the automaton
     * @param states
     *         the (non-empty) set of states that the ADS should distinguish
     * @param costAggregator
     *         how the costs of sub-ADSs are combined, see {@link CostAggregator}
     * @param <S>
     *         state type
     * @param <I>
     *         input symbol type
     * @param <O>
     *         output symbol type
     *
     * @return an optimal ADS for {@code states}, or {@link Optional#empty()} if none was found
     *
     * @throws IllegalArgumentException
     *         if {@code states} is empty, or if (for more than one state) the search encounters an undefined
     *         transition
     */
    public static <S, I, O> Optional<ADSNode<S, I, O>> computeOptimal(MealyMachine<S, I, ?, O> automaton,
                                                                       Alphabet<I> input,
                                                                       Set<S> states,
                                                                       CostAggregator costAggregator) {
        // without this guard an empty set would fail later with an accidental NoSuchElementException
        if (states.isEmpty()) {
            throw new IllegalArgumentException("At least one state is required to compute an ADS");
        }

        if (states.size() == 1) {
            return ADS.compute(automaton, input, states);
        }

        final Optional<SearchState<S, I, O>> searchState = exploreSearchSpace(automaton,
                                                                               input,
                                                                               states,
                                                                               costAggregator,
                                                                               new HashMap<>(),
                                                                               new HashSet<>(),
                                                                               Integer.MAX_VALUE);

        return searchState.map(s -> constructADS(automaton, new ReflexiveMapView<>(states), s));
    }

    /**
     * Recursively determines the cheapest way to distinguish {@code targets}.
     *
     * @param stateCache
     *         memoization of already-computed results per target set (shared across the whole search)
     * @param currentTraceCache
     *         target sets on the current chain of non-splitting symbols; revisiting one of them would cycle
     * @param costsBound
     *         pruning bound: symbols whose aggregated sub-costs reach the current best (initially this bound) are
     *         discarded
     *
     * @return the best search state found, or {@link Optional#empty()} if none exists within the bound
     */
    private static <S, I, T, O> Optional<SearchState<S, I, O>> exploreSearchSpace(MealyMachine<S, I, T, O> automaton,
                                                                                   Alphabet<I> alphabet,
                                                                                   Set<S> targets,
                                                                                   CostAggregator costAggregator,
                                                                                   Map<Set<S>, Optional<SearchState<S, I, O>>> stateCache,
                                                                                   Set<Set<S>> currentTraceCache,
                                                                                   int costsBound) {

        final Optional<SearchState<S, I, O>> cachedValue = stateCache.get(targets);
        if (cachedValue != null) {
            return cachedValue;
        }

        if (currentTraceCache.contains(targets)) {
            return Optional.empty();
        }

        if (targets.size() == 1) {
            // a single state needs no further distinction (symbol/successors unset, costs 0)
            final Optional<SearchState<S, I, O>> result = Optional.of(new SearchState<>());
            stateCache.put(targets, result);
            return result;
        }

        if (costsBound == 0) {
            return Optional.empty();
        }

        boolean foundValidSuccessor = false;
        boolean convergingStates = true;
        int bestCosts = costsBound;
        Map<O, SearchState<S, I, O>> bestSuccessor = null;
        I bestInputSymbol = null;

        inputLoop:
        for (I i : alphabet) {
            // group the successors under 'i' by the output they produce
            final Map<O, Set<S>> successors = new HashMap<>();

            for (S s : targets) {
                final T trans = automaton.getTransition(s, i);
                if (trans == null) {
                    throw new IllegalArgumentException("Partial automata are not supported");
                }
                final S nextState = automaton.getSuccessor(trans);
                final O o = automaton.getTransitionOutput(trans);

                // two states with the same output merge: 'i' loses information and is skipped
                if (!successors.computeIfAbsent(o, k -> new HashSet<>()).add(nextState)) {
                    continue inputLoop;
                }
            }

            convergingStates = false;

            final int costsForInputSymbol;
            final Map<O, SearchState<S, I, O>> successorsForInputSymbol;

            if (successors.size() > 1) {
                // 'i' splits the targets: every block has to be distinguished independently
                successorsForInputSymbol = Maps.newHashMapWithExpectedSize(successors.size());
                int partitionCosts = 0;

                for (Map.Entry<O, Set<S>> entry : successors.entrySet()) {
                    final Optional<SearchState<S, I, O>> potentialResult = exploreSearchSpace(automaton,
                                                                                              alphabet,
                                                                                              entry.getValue(),
                                                                                              costAggregator,
                                                                                              stateCache,
                                                                                              new HashSet<>(),
                                                                                              bestCosts);
                    if (!potentialResult.isPresent()) {
                        continue inputLoop;
                    }

                    final SearchState<S, I, O> subResult = potentialResult.get();
                    successorsForInputSymbol.put(entry.getKey(), subResult);

                    partitionCosts = costAggregator.apply(partitionCosts, subResult.costs);
                    if (partitionCosts >= bestCosts) {
                        continue inputLoop;
                    }
                }
                costsForInputSymbol = partitionCosts;
            } else {
                // 'i' does not split: continue with the single successor set, remembering the trace to avoid cycles
                final Map.Entry<O, Set<S>> entry = successors.entrySet().iterator().next();
                final Set<Set<S>> nextTraceCache = new HashSet<>(currentTraceCache);
                nextTraceCache.add(targets);

                final Optional<SearchState<S, I, O>> potentialResult = exploreSearchSpace(automaton,
                                                                                          alphabet,
                                                                                          entry.getValue(),
                                                                                          costAggregator,
                                                                                          stateCache,
                                                                                          nextTraceCache,
                                                                                          bestCosts);
                if (!potentialResult.isPresent()) {
                    continue;
                }

                final SearchState<S, I, O> subResult = potentialResult.get();
                costsForInputSymbol = subResult.costs;
                successorsForInputSymbol = Collections.singletonMap(entry.getKey(), subResult);
            }

            // only keep strict improvements over the best symbol so far
            if (costsForInputSymbol < bestCosts) {
                foundValidSuccessor = true;
                bestCosts = costsForInputSymbol;
                bestSuccessor = successorsForInputSymbol;
                bestInputSymbol = i;
            }
        }

        if (convergingStates) {
            // every symbol merges some states: this set can never be distinguished, regardless of the bound
            stateCache.put(targets, Optional.empty());
            return Optional.empty();
        }

        if (!foundValidSuccessor) {
            // not cached: the failure may be caused by the bound or the current trace only
            return Optional.empty();
        }

        final SearchState<S, I, O> resultSS = new SearchState<>();
        resultSS.costs = bestCosts + 1;
        resultSS.successors = bestSuccessor;
        resultSS.symbol = bestInputSymbol;

        final Optional<SearchState<S, I, O>> result = Optional.of(resultSS);
        stateCache.put(targets, result);
        return result;
    }

    /**
     * Materializes the ADS described by a search state.
     *
     * @param currentToInitialMapping
     *         maps the states currently reached back to the initial states they originated from
     * @param searchState
     *         the search state describing which symbol to apply at this node and how to continue per output
     *
     * @return the root of the constructed (sub-)ADS
     *
     * @throws IllegalStateException
     *         if two states converge under the chosen symbol (i.e. the search state is inconsistent)
     */
    private static <S, I, O> ADSNode<S, I, O> constructADS(MealyMachine<S, I, ?, O> automaton,
                                                          Map<S, S> currentToInitialMapping,
                                                          SearchState<S, I, O> searchState) {

        if (currentToInitialMapping.size() == 1) {
            return new ADSLeafNode<>(null, currentToInitialMapping.values().iterator().next());
        }

        final I i = searchState.symbol;
        final Map<O, Map<S, S>> successors = new HashMap<>();

        for (Map.Entry<S, S> entry : currentToInitialMapping.entrySet()) {
            final S current = entry.getKey();
            final S nextState = automaton.getSuccessor(current, i);
            final O nextOutput = automaton.getOutput(current, i);

            final Map<S, S> nextMapping = successors.computeIfAbsent(nextOutput, o -> new HashMap<>());

            if (nextMapping.put(nextState, entry.getValue()) != null) {
                throw new IllegalStateException("States converge under the input symbol chosen by the search");
            }
        }

        final ADSNode<S, I, O> result = new ADSSymbolNode<>(null, i);

        for (Map.Entry<O, Map<S, S>> entry : successors.entrySet()) {
            final O output = entry.getKey();
            final ADSNode<S, I, O> successor =
                    constructADS(automaton, entry.getValue(), searchState.successors.get(output));

            result.getChildren().put(output, successor);
            successor.setParent(result);
        }

        return result;
    }

    /**
     * Node of the optimal-search result. For a singleton set all fields keep their defaults ({@code null},
     * {@code null}, {@code 0}).
     */
    private static class SearchState<S, I, O> {

        /** The input symbol applied at this node. */
        private I symbol;
        /** The search state to continue with, per observed output. */
        private Map<O, SearchState<S, I, O>> successors;
        /** Aggregated costs of the sub-ADS rooted at this node. */
        private int costs;

        private SearchState() {
            // all fields are assigned by the enclosing class
        }
    }

    /**
     * Strategies for combining the costs of sibling sub-ADSs in
     * {@link #computeOptimal(MealyMachine, Alphabet, Set, CostAggregator)}. Every symbol node adds one to the
     * combined costs of its children.
     */
    public enum CostAggregator implements BiFunction<Integer, Integer, Integer> {

        /**
         * Combines costs by their maximum, i.e. minimizes the depth (longest symbol path) of the ADS.
         */
        MIN_LENGTH {
            @Override
            public Integer apply(Integer oldValue, Integer newValue) {
                return Math.max(oldValue, newValue);
            }
        },

        /**
         * Combines costs by their sum, i.e. minimizes the number of symbol nodes of the ADS.
         */
        MIN_SIZE {
            @Override
            public Integer apply(Integer oldValue, Integer newValue) {
                return oldValue + newValue;
            }
        }
    }
}

