package ai.libs.jaicore.ml.dyadranking.activelearning;

import ai.libs.jaicore.math.linearalgebra.Vector;
import ai.libs.jaicore.ml.dyadranking.Dyad;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.dataset.SparseDyadRankingInstance;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.nd4j.linalg.primitives.Pair;

/**
 * Pool provider for active dyad ranking that answers queries from a fully labelled
 * {@link DyadRankingDataset}.
 *
 * <p>Every ranking of the dataset is indexed by the instance features of its first dyad. A query
 * for a {@link SparseDyadRankingInstance} is answered by ordering the queried dyads according to
 * their position in that indexed ground-truth ranking. Dyads are additionally indexed by their
 * instance features and by their alternative features; when {@link #setRemoveDyadsWhenQueried}
 * is enabled, queried dyads are removed from these indices.
 *
 * <p>NOTE(review): all state lives in unsynchronized collections, so this class is not safe for
 * concurrent use without external synchronization.
 */
public class DyadDatasetPoolProvider implements IDyadRankingPoolProvider {

    /** The ground-truth rankings, in dataset iteration order. */
    private final List<IDyadRankingInstance> pool;

    /** Rankings produced by {@link #query}, kept in the order they were first produced. */
    private final Set<IDyadRankingInstance> queriedRankings = new LinkedHashSet<>();

    /** Number of successfully answered queries. */
    private int numberQueries = 0;

    private boolean removeDyadsWhenQueried = false;

    /** Dyads grouped by their instance features. */
    private final HashMap<Vector, Set<Dyad>> dyadsByInstances = new HashMap<>();

    /** Dyads grouped by their alternative features. */
    private final HashMap<Vector, Set<Dyad>> dyadsByAlternatives = new HashMap<>();

    /** Ground-truth rankings keyed by the instance features of their first dyad. */
    private final HashMap<Vector, IDyadRankingInstance> dyadRankingsByInstances = new HashMap<>();

    /**
     * Creates a pool provider backed by the given dataset.
     *
     * @param dyadRankingDataset the labelled rankings used to answer queries
     */
    public DyadDatasetPoolProvider(DyadRankingDataset dyadRankingDataset) {
        this.pool = new ArrayList<>(dyadRankingDataset.size());
        Iterator<IDyadRankingInstance> rankings = dyadRankingDataset.iterator();
        while (rankings.hasNext()) {
            addDyadRankingInstance(rankings.next());
        }
    }

    /**
     * Returns the ground-truth rankings. The returned collection is the internal list, not a copy.
     */
    @Override
    public Collection<IDyadRankingInstance> getPool() {
        return this.pool;
    }

    /**
     * Answers a query by ranking the given dyads according to the ground truth.
     *
     * <p>Dyads whose instance has no indexed ranking get position -1 and therefore come first.
     * Dyads that are not contained in their instance's ranking are placed after all contained
     * ones. The sort is stable, so ties keep their query order.
     *
     * @param queryInstance the dyads to rank; must be a {@link SparseDyadRankingInstance}
     * @return a new {@link DyadRankingInstance} holding the queried dyads in ground-truth order
     * @throws IllegalArgumentException if {@code queryInstance} is not sparse
     */
    @Override
    public IDyadRankingInstance query(IDyadRankingInstance queryInstance) {
        if (!(queryInstance instanceof SparseDyadRankingInstance)) {
            throw new IllegalArgumentException("Currently only supports SparseDyadRankingInstances!");
        }
        // Only count queries that are actually answered.
        this.numberQueries++;
        SparseDyadRankingInstance sparseInstance = (SparseDyadRankingInstance) queryInstance;

        // Look up each dyad's ground-truth position once, then sort by it.
        List<Pair<Dyad, Integer>> dyadsWithPositions = new ArrayList<>(sparseInstance.length());
        for (Dyad dyad : sparseInstance) {
            dyadsWithPositions.add(new Pair<>(dyad, getPositionInRankingByInstanceFeatures(dyad)));
        }
        dyadsWithPositions.sort(Comparator.comparingInt(entry -> entry.getRight()));

        List<Dyad> orderedDyads = new ArrayList<>(dyadsWithPositions.size());
        for (Pair<Dyad, Integer> entry : dyadsWithPositions) {
            orderedDyads.add(entry.getFirst());
        }
        DyadRankingInstance ranking = new DyadRankingInstance(orderedDyads);

        if (this.removeDyadsWhenQueried) {
            for (Dyad dyad : orderedDyads) {
                removeDyadFromPool(dyad);
            }
        }
        this.queriedRankings.add(ranking);
        return ranking;
    }

    /**
     * Returns the pooled dyads with the given instance features.
     *
     * <p>If such dyads exist, the internal set is returned; otherwise a new empty set is returned.
     */
    @Override
    public Set<Dyad> getDyadsByInstance(Vector vector) {
        Set<Dyad> dyads = this.dyadsByInstances.get(vector);
        return dyads != null ? dyads : new HashSet<>();
    }

    /**
     * Returns the pooled dyads with the given alternative features.
     *
     * <p>If such dyads exist, the internal set is returned; otherwise a new empty set is returned.
     */
    @Override
    public Set<Dyad> getDyadsByAlternative(Vector vector) {
        Set<Dyad> dyads = this.dyadsByAlternatives.get(vector);
        return dyads != null ? dyads : new HashSet<>();
    }

    /** Adds a ranking to the pool and registers its dyads in the instance/alternative indices. */
    private void addDyadRankingInstance(IDyadRankingInstance ranking) {
        this.pool.add(ranking);
        if (ranking.length() > 0) {
            // Indexed by the first dyad's instance. NOTE(review): this assumes all dyads of one
            // ranking share the same instance features; confirm against how datasets are built.
            this.dyadRankingsByInstances.put(ranking.getDyadAtPosition(0).getInstance(), ranking);
        }
        for (Dyad dyad : ranking) {
            this.dyadsByInstances.computeIfAbsent(dyad.getInstance(), key -> new HashSet<>()).add(dyad);
            this.dyadsByAlternatives.computeIfAbsent(dyad.getAlternative(), key -> new HashSet<>()).add(dyad);
        }
    }

    /**
     * Returns the position of {@code dyad} in the ground-truth ranking indexed under its
     * instance features.
     *
     * @return -1 if no ranking is indexed for the dyad's instance; otherwise the zero-based
     *     position of the dyad, or the ranking's length if the dyad is not contained in it
     */
    private int getPositionInRankingByInstanceFeatures(Dyad dyad) {
        IDyadRankingInstance ranking = this.dyadRankingsByInstances.get(dyad.getInstance());
        if (ranking == null) {
            return -1;
        }
        int length = ranking.length();
        for (int position = 0; position < length; position++) {
            if (ranking.getDyadAtPosition(position).equals(dyad)) {
                return position;
            }
        }
        return length;
    }

    /** Returns a live view of the instance features that currently have pooled dyads. */
    @Override
    public Collection<Vector> getInstanceFeatures() {
        return this.dyadsByInstances.keySet();
    }

    /** Removes a dyad from both the instance index and the alternative index. */
    private void removeDyadFromPool(Dyad dyad) {
        removeFromIndex(this.dyadsByInstances, dyad.getInstance(), dyad);
        removeFromIndex(this.dyadsByAlternatives, dyad.getAlternative(), dyad);
    }

    /**
     * Removes {@code dyad} from the set stored under {@code key} and drops the whole entry once
     * fewer than two dyads remain. Presumably this is because a single dyad cannot form a
     * pairwise ranking (TODO confirm intent).
     */
    private static void removeFromIndex(Map<Vector, Set<Dyad>> index, Vector key, Dyad dyad) {
        index.computeIfPresent(key, (k, dyads) -> {
            dyads.remove(dyad);
            // Returning null removes the mapping.
            return dyads.size() < 2 ? null : dyads;
        });
    }

    @Override
    public void setRemoveDyadsWhenQueried(boolean z) {
        this.removeDyadsWhenQueried = z;
    }

    /**
     * Returns the total number of dyads in the instance index.
     *
     * <p>Entries dropped by query-time removal are not counted.
     */
    @Override
    public int getPoolSize() {
        int size = 0;
        for (Set<Dyad> dyads : this.dyadsByInstances.values()) {
            size += dyads.size();
        }
        return size;
    }

    /** Returns the number of successfully answered queries. */
    public int getNumberQueries() {
        return this.numberQueries;
    }

    /** Returns a new dataset holding the queried rankings, in the order they were first produced. */
    @Override
    public DyadRankingDataset getQueriedRankings() {
        List<IDyadRankingInstance> rankings = new ArrayList<>(this.queriedRankings);
        return new DyadRankingDataset(rankings);
    }
}
