/*
 * Decompiled with CFR 0.152.
 */
package eu.fbk.utils.eval;

import com.google.common.base.Preconditions;
import com.google.common.collect.HashBasedTable;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
import eu.fbk.utils.eval.PrecisionRecall;
import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiFunction;
import javax.annotation.Nullable;

/**
 * Immutable bundle of four {@link PrecisionRecall} scores obtained by comparing a collection of
 * gold sets with a collection of test sets: exact, overlap, intersection and aligned.
 *
 * <p>Instances are either built directly from four scores or accumulated through an
 * {@link Evaluator} obtained from {@link #evaluator()}. The class also hosts a few static helpers
 * (reflective array creation, greedy alignment and set matchers) used by the evaluator.
 */
public final class SetPrecisionRecall
implements Serializable {
    private static final long serialVersionUID = 1L;
    /** Score where a test set counts only if an equal gold set carries an equal label. */
    private final PrecisionRecall exactPR;
    /** Score where a set counts as soon as it shares an element with a same-labelled counterpart. */
    private final PrecisionRecall overlapPR;
    /** Score based on the fraction of each set's elements covered by same-labelled counterparts. */
    private final PrecisionRecall intersectionPR;
    /** Score based on element overlap of gold/test sets paired by a one-to-one alignment. */
    private final PrecisionRecall alignedPR;

    /**
     * Creates a score bundle from its four component measures.
     *
     * @param exactPR the exact-match score, not null
     * @param overlapPR the overlap score, not null
     * @param intersectionPR the intersection score, not null
     * @param alignedPR the aligned score, not null
     * @throws NullPointerException if any argument is null
     */
    public SetPrecisionRecall(PrecisionRecall exactPR, PrecisionRecall overlapPR, PrecisionRecall intersectionPR, PrecisionRecall alignedPR) {
        // checkNotNull hands back its argument, so validation and assignment are done in one step;
        // a failure aborts construction, hence no partially initialised instance can escape.
        this.exactPR = Preconditions.checkNotNull(exactPR);
        this.overlapPR = Preconditions.checkNotNull(overlapPR);
        this.intersectionPR = Preconditions.checkNotNull(intersectionPR);
        this.alignedPR = Preconditions.checkNotNull(alignedPR);
    }

    /**
     * Returns the exact-match score.
     *
     * @return the exact score, never null
     */
    public PrecisionRecall getExactPR() {
        return exactPR;
    }

    /**
     * Returns the overlap score.
     *
     * @return the overlap score, never null
     */
    public PrecisionRecall getOverlapPR() {
        return overlapPR;
    }

    /**
     * Returns the intersection score.
     *
     * @return the intersection score, never null
     */
    public PrecisionRecall getIntersectionPR() {
        return intersectionPR;
    }

    /**
     * Returns the aligned score.
     *
     * @return the aligned score, never null
     */
    public PrecisionRecall getAlignedPR() {
        return alignedPR;
    }

    /**
     * Creates an array whose runtime component type is {@code clazz} and fills it with the
     * supplied elements, in order.
     *
     * @param clazz the component type of the array to create
     * @param elements the elements to store; individual elements may be null
     * @param <T> the component type
     * @return a new array of length {@code elements.length} holding the elements
     * @throws ArrayStoreException if an element is not an instance of {@code clazz}
     */
    @SuppressWarnings("unchecked")
    public static <T> T[] newArray(Class<T> clazz, Object ... elements) {
        // Array.newInstance is statically typed as Object; the cast is safe because the array is
        // created with exactly the requested component type.
        T[] array = (T[]) Array.newInstance(clazz, elements.length);
        // arraycopy performs the same per-element store check as an explicit assignment loop.
        System.arraycopy(elements, 0, array, 0, elements.length);
        return array;
    }

    /**
     * Creates an array of the given length whose runtime component type is {@code clazz}; all
     * slots are initially null.
     *
     * @param clazz the component type of the array to create
     * @param length the array length
     * @param <T> the component type
     * @return a new array of the requested type and length
     * @throws NegativeArraySizeException if {@code length} is negative
     */
    @SuppressWarnings("unchecked")
    public static <T> T[] newArray(Class<T> clazz, int length) {
        // Array.newInstance is statically typed as Object; the cast is safe because the array is
        // created with exactly the requested component type.
        return (T[]) Array.newInstance(clazz, length);
    }

    /**
     * Creates a two-dimensional array whose innermost component type is {@code clazz}.
     *
     * <p>The outer array always has {@code length1} slots. Rows are allocated only when
     * {@code length2} is strictly positive; for zero or negative {@code length2} every row is left
     * null, which lets callers obtain just the outer array (e.g. as a {@code toArray} target).
     *
     * @param clazz the innermost component type
     * @param length1 the number of rows
     * @param length2 the length of each row, or a non-positive value to leave the rows null
     * @param <T> the innermost component type
     * @return a new two-dimensional array
     * @throws NegativeArraySizeException if {@code length1} is negative
     */
    @SuppressWarnings("unchecked")
    public static <T> T[][] newArray(Class<T> clazz, int length1, int length2) {
        // The class object for T[] is obtained from a throw-away zero-length array.
        Class<?> rowClass = Array.newInstance(clazz, 0).getClass();
        T[][] array = (T[][]) Array.newInstance(rowClass, length1);
        if (length2 > 0) {
            for (int i = 0; i < length1; ++i) {
                array[i] = (T[]) Array.newInstance(clazz, length2);
            }
        }
        return array;
    }

    /**
     * Greedily aligns the objects of two collections based on a pairwise similarity.
     *
     * <p>The matcher is applied to every pair in the cross product of {@code objects1} and
     * {@code objects2}. Pairs are then emitted in decreasing order of similarity until no candidate
     * with a non-null similarity is left; a null similarity means "not alignable". Similarities
     * must be mutually comparable: either {@link Comparable} values or {@link Iterable}s of
     * comparable values, the latter being compared lexicographically.
     *
     * @param clazz the component type of the returned pairs
     * @param objects1 the first collection; iterated more than once
     * @param objects2 the second collection; iterated more than once
     * @param functional if true, an object of {@code objects1} is aligned at most once
     * @param invFunctional if true, an object of {@code objects2} is aligned at most once
     * @param emitUnaligned if true, objects never aligned are also emitted, paired with null
     * @param matcher computes the similarity of a pair, or null if the pair cannot be aligned
     * @param <T> the component type of the returned pairs
     * @param <E> the type of the aligned objects
     * @return an array of two-element arrays {@code [object1, object2]}; either element is null
     *         for unaligned objects emitted because of {@code emitUnaligned}
     */
    public static <T, E extends T> T[][] align(Class<T> clazz, Iterable<E> objects1, Iterable<E> objects2, boolean functional, boolean invFunctional, boolean emitUnaligned, BiFunction<? super E, ? super E, ?> matcher) {
        // Candidate pairs, indexed by both members so a whole row/column can be dropped at once.
        HashBasedTable<E, E, AlignPair<E>> table = HashBasedTable.create();
        for (E object1 : objects1) {
            for (E object2 : objects2) {
                Object similarity = matcher.apply(object1, object2);
                table.put(object1, object2, new AlignPair<E>(object1, object2, similarity));
            }
        }

        // Objects still waiting for a partner; tracked only when they have to be reported.
        Set<E> unaligned1 = emitUnaligned ? Sets.newHashSet(objects1) : null;
        Set<E> unaligned2 = emitUnaligned ? Sets.newHashSet(objects2) : null;

        List<T[]> pairs = Lists.newArrayList();
        while (!table.isEmpty()) {
            AlignPair<E> bestPair = Ordering.<AlignPair<?>>natural().max(table.values());
            if (bestPair.similarity == null) {
                // Null similarities sort lowest: only non-alignable candidates remain.
                break;
            }
            pairs.add(SetPrecisionRecall.newArray(clazz, bestPair.object1, bestPair.object2));

            // Discard the candidates made invalid by the chosen pair.
            if (functional) {
                table.rowKeySet().remove(bestPair.object1);
            }
            if (invFunctional) {
                table.columnKeySet().remove(bestPair.object2);
            }
            if (!functional && !invFunctional) {
                table.remove(bestPair.object1, bestPair.object2);
            }

            if (emitUnaligned) {
                unaligned1.remove(bestPair.object1);
                unaligned2.remove(bestPair.object2);
            }
        }

        if (emitUnaligned) {
            for (E object1 : unaligned1) {
                pairs.add(SetPrecisionRecall.newArray(clazz, object1, null));
            }
            for (E object2 : unaligned2) {
                pairs.add(SetPrecisionRecall.newArray(clazz, null, object2));
            }
        }

        // A negative inner length yields just the outer T[][] array, which toArray then fills.
        return pairs.toArray(SetPrecisionRecall.newArray(clazz, pairs.size(), -1));
    }

    /**
     * Compares this bundle with another object: two bundles are equal when all four of their
     * component scores are equal.
     *
     * @param object the object to compare with, possibly null
     * @return true if {@code object} is a {@code SetPrecisionRecall} with equal scores
     */
    @Override
    public boolean equals(Object object) {
        if (this == object) {
            return true;
        }
        if (object instanceof SetPrecisionRecall) {
            SetPrecisionRecall that = (SetPrecisionRecall) object;
            return exactPR.equals(that.exactPR)
                    && overlapPR.equals(that.overlapPR)
                    && intersectionPR.equals(that.intersectionPR)
                    && alignedPR.equals(that.alignedPR);
        }
        return false;
    }

    /**
     * Returns a hash code combining the four component scores, consistent with
     * {@link #equals(Object)}.
     *
     * @return the hash code
     */
    @Override
    public int hashCode() {
        // Same 31-based combination that Objects.hash applies to its arguments.
        int hash = 1;
        hash = 31 * hash + Objects.hashCode(exactPR);
        hash = 31 * hash + Objects.hashCode(overlapPR);
        hash = 31 * hash + Objects.hashCode(intersectionPR);
        hash = 31 * hash + Objects.hashCode(alignedPR);
        return hash;
    }

    /**
     * Returns a four-line report, one labelled line per component score, with the labels padded
     * to a common width.
     *
     * @return the textual report
     */
    @Override
    public String toString() {
        StringBuilder report = new StringBuilder();
        report.append("exact:        ").append(exactPR);
        report.append("\noverlap:      ").append(overlapPR);
        report.append("\nintersection: ").append(intersectionPR);
        report.append("\naligned:      ").append(alignedPR);
        return report.toString();
    }

    /**
     * Returns a similarity function for unlabelled sets, suitable for use with {@code align}.
     *
     * <p>For a gold set {@code g} and a test set {@code t}, the function returns null when the
     * two sets are disjoint; otherwise it returns the two-element list
     * {@code [|g ∩ t| / |g|, |g ∩ t| / |t|]}.
     *
     * @param <T> the type of set elements
     * @return the similarity function
     */
    public static <T> BiFunction<Set<T>, Set<T>, List<Double>> matcher() {
        return (g, t) -> {
            Sets.SetView<T> common = Sets.intersection(g, t);
            if (common.isEmpty()) {
                return null;
            }
            double shared = common.size();
            List<Double> scores = Lists.newArrayListWithCapacity(2);
            scores.add(shared / g.size());
            scores.add(shared / t.size());
            return scores;
        };
    }

    /**
     * Returns a similarity function for labelled sets, each given as a (set, label) entry.
     *
     * <p>The function returns null when the two labels differ or when the two sets are disjoint;
     * otherwise, for gold set {@code g} and test set {@code t}, it returns the two-element list
     * {@code [|g ∩ t| / |g|, |g ∩ t| / |t|]}.
     *
     * @param <T> the type of set elements
     * @param <L> the type of labels
     * @return the similarity function
     */
    public static <T, L> BiFunction<Map.Entry<Set<T>, L>, Map.Entry<Set<T>, L>, List<Double>> matcherLabelled() {
        return (ge, te) -> {
            boolean sameLabel = Objects.equals(ge.getValue(), te.getValue());
            if (!sameLabel) {
                return null;
            }
            Set<T> g = ge.getKey();
            Set<T> t = te.getKey();
            Sets.SetView<T> common = Sets.intersection(g, t);
            if (common.isEmpty()) {
                return null;
            }
            double shared = common.size();
            List<Double> scores = Lists.newArrayListWithCapacity(2);
            scores.add(shared / g.size());
            scores.add(shared / t.size());
            return scores;
        };
    }

    /**
     * Creates a new, empty {@link Evaluator} for accumulating set comparison results.
     *
     * @return a fresh evaluator
     */
    public static Evaluator evaluator() {
        return new Evaluator();
    }

    /**
     * Mutable accumulator that builds a {@link SetPrecisionRecall} from gold/test comparisons,
     * previously computed scores, or other evaluators.
     *
     * <p>All updates to the accumulated state and the computation of the result are performed
     * while holding this evaluator's monitor.
     */
    public static final class Evaluator {
        /** Shared label assigned to every set when unlabelled sets are compared. */
        private static final Object DUMMY_LABEL = new Object();
        /** Accumulates the TP/FP/FN counts of the exact-match measure. */
        private final PrecisionRecall.Evaluator exactEvaluator = PrecisionRecall.evaluator();
        /** Running numerator of overlap precision (summed over test sets). */
        private double overlapP = 0.0;
        /** Running numerator of overlap recall (summed over gold sets). */
        private double overlapR = 0.0;
        /** Running numerator of intersection precision (summed over test sets). */
        private double intersectionP = 0.0;
        /** Running numerator of intersection recall (summed over gold sets). */
        private double intersectionR = 0.0;
        /** Running numerator of aligned precision (summed over aligned pairs). */
        private double alignedP = 0.0;
        /** Running numerator of aligned recall (summed over aligned pairs). */
        private double alignedR = 0.0;
        /** Cached result; reset to null by every update and recomputed lazily by getResult(). */
        @Nullable
        private SetPrecisionRecall score = null;

        /** Creates an empty evaluator; private, instances are obtained via the enclosing class's evaluator() factory. */
        private Evaluator() {
        }

        /**
         * Adds the comparison of unlabelled gold and test sets.
         *
         * <p>Every set is given the same dummy label and the call is delegated to the labelled
         * variant taking two maps.
         *
         * <p>NOTE(review): the sets are collected with {@code ImmutableMap.Builder}, which is
         * documented to reject duplicate keys at {@code build()}; presumably two equal sets in the
         * same iterable make this method throw {@code IllegalArgumentException} - confirm against
         * the Guava version in use.
         *
         * @param goldSets the gold (reference) sets
         * @param testSets the test (predicted) sets
         * @param <T> the type of set elements
         * @return this evaluator, for call chaining
         */
        public <T> Evaluator add(Iterable<Set<T>> goldSets, Iterable<Set<T>> testSets) {
            ImmutableMap.Builder<Set<T>, Object> gold = ImmutableMap.builder();
            goldSets.forEach(goldSet -> gold.put(goldSet, DUMMY_LABEL));

            ImmutableMap.Builder<Set<T>, Object> test = ImmutableMap.builder();
            testSets.forEach(testSet -> test.put(testSet, DUMMY_LABEL));

            return add(gold.build(), test.build());
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        /**
         * Adds the comparison of labelled gold and test sets, each given as a set-to-label map.
         *
         * <p>Four measures are updated:
         * <ul>
         * <li><b>exact</b>: a gold set is a true positive if the test map contains an equal set
         * with an equal label;</li>
         * <li><b>overlap</b>: a set counts if it shares at least one element with a same-labelled
         * set on the other side;</li>
         * <li><b>intersection</b>: each set contributes the fraction of its elements found in
         * same-labelled sets on the other side;</li>
         * <li><b>aligned</b>: gold and test entries are paired one-to-one by greedy alignment and
         * each pair contributes the size of its intersection divided by the test set size
         * (precision) and by the gold set size (recall).</li>
         * </ul>
         *
         * <p>Empty sets contribute zero to the intersection measure rather than the {@code NaN}
         * that a 0/0 division would otherwise inject into the running totals.
         *
         * <p>The comparison is computed without holding any lock; only the final update of the
         * accumulated state is performed under this evaluator's monitor.
         *
         * @param goldMap the gold (reference) sets with their labels
         * @param testMap the test (predicted) sets with their labels
         * @param <T> the type of set elements
         * @param <L> the type of labels
         * @return this evaluator, for call chaining
         */
        @SuppressWarnings({"unchecked", "rawtypes"})
        public <T, L> Evaluator add(Map<Set<T>, L> goldMap, Map<Set<T>, L> testMap) {
            int goldSize = goldMap.size();
            int testSize = testMap.size();

            // Exact measure: same set, same label.
            int exactTP = 0;
            for (Map.Entry<Set<T>, L> goldEntry : goldMap.entrySet()) {
                Set<T> goldSet = goldEntry.getKey();
                if (testMap.containsKey(goldSet) && Objects.equals(goldEntry.getValue(), testMap.get(goldSet))) {
                    ++exactTP;
                }
            }

            // Overlap and intersection measures are symmetric: precision looks at test sets
            // against gold sets, recall at gold sets against test sets.
            double overlapP = Evaluator.countOverlapping(testMap, goldMap);
            double overlapR = Evaluator.countOverlapping(goldMap, testMap);
            double intersectionP = Evaluator.sumCoverage(testMap, goldMap);
            double intersectionR = Evaluator.sumCoverage(goldMap, testMap);

            // Aligned measure: one-to-one pairing; unaligned entries come back paired with null
            // and are skipped. Aligned sets always intersect, so neither of them is empty.
            double alignedP = 0.0;
            double alignedR = 0.0;
            Map.Entry[][] pairs = SetPrecisionRecall.align(Map.Entry.class, goldMap.entrySet(), testMap.entrySet(), true, true, true, SetPrecisionRecall.<T, L>matcherLabelled());
            for (Map.Entry[] pair : pairs) {
                Map.Entry<Set<T>, L> goldEntry = pair[0];
                Map.Entry<Set<T>, L> testEntry = pair[1];
                if (goldEntry == null || testEntry == null) {
                    continue;
                }
                Set<T> g = goldEntry.getKey();
                Set<T> t = testEntry.getKey();
                double shared = Sets.intersection(g, t).size();
                alignedP += shared / (double) t.size();
                alignedR += shared / (double) g.size();
            }

            synchronized (this) {
                this.score = null;
                this.exactEvaluator.add(exactTP, testSize - exactTP, goldSize - exactTP);
                this.overlapP += overlapP;
                this.overlapR += overlapR;
                this.intersectionP += intersectionP;
                this.intersectionR += intersectionR;
                this.alignedP += alignedP;
                this.alignedR += alignedR;
            }
            return this;
        }

        /** Counts the entries of {@code from} sharing an element with a same-labelled entry of {@code against}. */
        private static <T, L> double countOverlapping(Map<Set<T>, L> from, Map<Set<T>, L> against) {
            double count = 0.0;
            for (Map.Entry<Set<T>, L> entry : from.entrySet()) {
                for (Map.Entry<Set<T>, L> candidate : against.entrySet()) {
                    if (Objects.equals(entry.getValue(), candidate.getValue()) && !Sets.intersection(entry.getKey(), candidate.getKey()).isEmpty()) {
                        count += 1.0;
                        break;
                    }
                }
            }
            return count;
        }

        /** Sums, over the entries of {@code from}, the fraction of elements covered by same-labelled entries of {@code against}. */
        private static <T, L> double sumCoverage(Map<Set<T>, L> from, Map<Set<T>, L> against) {
            double total = 0.0;
            for (Map.Entry<Set<T>, L> entry : from.entrySet()) {
                Set<T> set = entry.getKey();
                if (set.isEmpty()) {
                    // Nothing can be covered; skipping also avoids a 0/0 = NaN contribution.
                    continue;
                }
                Set<T> covered = Sets.newHashSet();
                for (Map.Entry<Set<T>, L> candidate : against.entrySet()) {
                    if (Objects.equals(entry.getValue(), candidate.getValue())) {
                        covered.addAll(Sets.intersection(set, candidate.getKey()));
                    }
                }
                total += (double) covered.size() / (double) set.size();
            }
            return total;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        /**
         * Adds a previously computed score to this evaluator.
         *
         * <p>The exact score is merged through the exact evaluator. The other three scores are
         * converted back to sums by multiplying precision by the number of test sets (TP + FP of
         * the exact score) and recall by the number of gold sets (TP + FN of the exact score).
         *
         * @param spr the score to add, not null
         * @return this evaluator, for call chaining
         */
        public synchronized Evaluator add(SetPrecisionRecall spr) {
            synchronized (spr) {
                PrecisionRecall exact = spr.getExactPR();
                PrecisionRecall overlap = spr.getOverlapPR();
                PrecisionRecall intersection = spr.getIntersectionPR();
                PrecisionRecall aligned = spr.getAlignedPR();

                // Number of test and gold sets behind the supplied score.
                double testCount = exact.getTP() + exact.getFP();
                double goldCount = exact.getTP() + exact.getFN();

                this.score = null;
                this.exactEvaluator.add(exact);
                this.overlapP += overlap.getPrecision() * testCount;
                this.overlapR += overlap.getRecall() * goldCount;
                this.intersectionP += intersection.getPrecision() * testCount;
                this.intersectionR += intersection.getRecall() * goldCount;
                this.alignedP += aligned.getPrecision() * testCount;
                this.alignedR += aligned.getRecall() * goldCount;
            }
            return this;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        /**
         * Adds the state accumulated by another evaluator to this one.
         *
         * <p>The monitor of this evaluator is acquired first, then the monitor of the argument.
         * NOTE(review): two evaluators added to each other from different threads would take the
         * monitors in opposite order - confirm callers never do that.
         *
         * @param evaluator the evaluator whose state is added, not null
         * @return this evaluator, for call chaining
         */
        public synchronized Evaluator add(Evaluator evaluator) {
            final Evaluator source = evaluator;
            synchronized (source) {
                this.score = null;
                this.exactEvaluator.add(source.exactEvaluator);
                this.overlapP += source.overlapP;
                this.overlapR += source.overlapR;
                this.intersectionP += source.intersectionP;
                this.intersectionR += source.intersectionR;
                this.alignedP += source.alignedP;
                this.alignedR += source.alignedR;
            }
            return this;
        }

        /**
         * Returns the score for the data accumulated so far. The result is cached and recomputed
         * only after the evaluator has been updated.
         *
         * @return the resulting score
         */
        public synchronized SetPrecisionRecall getResult() {
            if (this.score != null) {
                return this.score;
            }
            PrecisionRecall exactPR = this.exactEvaluator.getResult();
            // Number of test and gold sets seen so far, derived from the exact counts.
            double testCount = exactPR.getTP() + exactPR.getFP();
            double goldCount = exactPR.getTP() + exactPR.getFN();
            PrecisionRecall overlapPR = Evaluator.toMeasures(this.overlapP, this.overlapR, testCount, goldCount);
            PrecisionRecall intersectionPR = Evaluator.toMeasures(this.intersectionP, this.intersectionR, testCount, goldCount);
            PrecisionRecall alignedPR = Evaluator.toMeasures(this.alignedP, this.alignedR, testCount, goldCount);
            this.score = new SetPrecisionRecall(exactPR, overlapPR, intersectionPR, alignedPR);
            return this.score;
        }

        /**
         * Turns a pair of accumulated precision/recall sums into a PrecisionRecall object.
         * NOTE(review): the third measure passed to forMeasures reduces to TP / (TP + FP + FN)
         * when both sums equal TP, so it is presumably an accuracy value - confirm against
         * PrecisionRecall.forMeasures.
         */
        private static PrecisionRecall toMeasures(double sumP, double sumR, double testCount, double goldCount) {
            double precision = sumP / testCount;
            double recall = sumR / goldCount;
            double third = 1.0 / (2.0 * (testCount + goldCount) / (sumP + sumR) - 1.0);
            return PrecisionRecall.forMeasures(precision, recall, third);
        }

        /**
         * Returns the textual form of the current result.
         *
         * @return the string representation of {@code getResult()}
         */
        @Override
        public String toString() {
            SetPrecisionRecall current = getResult();
            return current.toString();
        }
    }

    /**
     * Candidate pair considered during alignment, ordered by its similarity.
     *
     * <p>A null similarity sorts below any non-null one. Non-null similarities are compared
     * through {@link Comparable} if implemented, or lexicographically if they are
     * {@link Iterable}s.
     *
     * @param <T> the type of the paired objects
     */
    private static final class AlignPair<T>
    implements Comparable<AlignPair<?>> {
        final T object1;
        final T object2;
        @Nullable
        Object similarity;

        public AlignPair(T object1, T object2, Object similarity) {
            this.object1 = object1;
            this.object2 = object2;
            this.similarity = similarity;
        }

        /**
         * Orders pairs by similarity.
         *
         * @throws IllegalArgumentException if this pair's similarity is neither a
         *         {@code Comparable} nor an {@code Iterable}
         */
        @Override
        @SuppressWarnings({"unchecked", "rawtypes"})
        public int compareTo(AlignPair<?> other) {
            final Object mine = this.similarity;
            final Object theirs = other.similarity;
            if (mine == null || theirs == null) {
                // At least one side is null: equal only if both are, otherwise null is smaller.
                if (mine == theirs) {
                    return 0;
                }
                return mine == null ? -1 : 1;
            }
            if (mine instanceof Comparable) {
                return ((Comparable) mine).compareTo(theirs);
            }
            if (mine instanceof Iterable) {
                return Ordering.natural().lexicographical().compare((Iterable) mine, (Iterable) theirs);
            }
            throw new IllegalArgumentException("Could not compare similarities " + mine + ", " + theirs);
        }
    }
}

