package ai.libs.jaicore.ml.dyadranking.loss;

import ai.libs.jaicore.ml.core.exception.PredictionException;
import ai.libs.jaicore.ml.dyadranking.algorithm.IDyadRanker;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingDataset;
import ai.libs.jaicore.ml.dyadranking.dataset.DyadRankingInstance;
import ai.libs.jaicore.ml.dyadranking.dataset.IDyadRankingInstance;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

/* loaded from: input_file:ai/libs/jaicore/ml/dyadranking/loss/DyadRankingLossUtil.class */
/**
 * Static helpers that average a {@link DyadRankingLossFunction} over all
 * instances of a {@link DyadRankingDataset}.
 *
 * <p>The class only exposes static methods and cannot be instantiated.
 */
public final class DyadRankingLossUtil {
    private DyadRankingLossUtil() {
        // Utility class: no instances.
    }

    /**
     * Averages the loss over two datasets that are compared position by position.
     *
     * <p>The instance at index {@code i} of the first dataset is passed as the first
     * argument of the loss function and the instance at index {@code i} of the second
     * dataset as the second argument. NOTE(review): presumably the first dataset holds
     * the ground truth and the second the predictions, matching the argument order
     * used by the ranker-based overloads; confirm against the loss function contract.
     *
     * @param dyadRankingLossFunction the loss applied to each pair of instances
     * @param dyadRankingDataset the dataset supplying the first loss argument
     * @param dyadRankingDataset2 the dataset supplying the second loss argument
     * @return the arithmetic mean of the per-instance losses; {@code NaN} if both
     *         datasets are empty (floating-point division of zero by zero)
     * @throws IllegalArgumentException if the two datasets differ in size
     */
    public static double computeAverageLoss(DyadRankingLossFunction dyadRankingLossFunction, DyadRankingDataset dyadRankingDataset, DyadRankingDataset dyadRankingDataset2) {
        int instanceCount = dyadRankingDataset.size();
        if (instanceCount != dyadRankingDataset2.size()) {
            throw new IllegalArgumentException("The list of predictions and the list of ground truth dyad rankings need to have the same length!");
        }
        double lossSum = 0.0d;
        for (int i = 0; i < instanceCount; i++) {
            lossSum += dyadRankingLossFunction.loss(dyadRankingDataset.get(i), dyadRankingDataset2.get(i));
        }
        return lossSum / instanceCount;
    }

    /**
     * Averages the loss between each instance of the dataset and the ranker's
     * prediction for a shuffled copy of that instance.
     *
     * <p>For every instance the contained elements are copied into a new list, the
     * copy is shuffled with {@code random}, wrapped in a new
     * {@link DyadRankingInstance} and handed to the ranker. The dataset instance
     * itself is passed unmodified as the first argument of the loss function.
     *
     * @param dyadRankingLossFunction the loss applied to (dataset instance, prediction)
     * @param dyadRankingDataset the dataset whose instances are evaluated
     * @param iDyadRanker the ranker asked to predict a ranking for each shuffled instance
     * @param random the source of randomness used for shuffling
     * @return the arithmetic mean of the per-instance losses; {@code NaN} if the
     *         dataset is empty (floating-point division of zero by zero)
     * @throws PredictionException if it is raised while obtaining a prediction
     */
    public static double computeAverageLoss(DyadRankingLossFunction dyadRankingLossFunction, DyadRankingDataset dyadRankingDataset, IDyadRanker iDyadRanker, Random random) throws PredictionException {
        int instanceCount = dyadRankingDataset.size();
        double lossSum = 0.0d;
        for (int i = 0; i < instanceCount; i++) {
            IDyadRankingInstance actual = dyadRankingDataset.get(i);
            // The shuffled instance is the ranker's INPUT; it must not be cast to the
            // ranker type (such a cast would fail with a ClassCastException).
            IDyadRankingInstance shuffled = new DyadRankingInstance(shuffledCopy(actual.iterator(), random));
            lossSum += dyadRankingLossFunction.loss(actual, iDyadRanker.predict(shuffled));
        }
        return lossSum / instanceCount;
    }

    /**
     * Same as {@link #computeAverageLoss(DyadRankingLossFunction, DyadRankingDataset, IDyadRanker, Random)}
     * but shuffles with a {@link Random} seeded with {@code 0}, so repeated calls
     * use the same shuffle sequence.
     *
     * @param dyadRankingLossFunction the loss applied to (dataset instance, prediction)
     * @param dyadRankingDataset the dataset whose instances are evaluated
     * @param iDyadRanker the ranker asked to predict a ranking for each shuffled instance
     * @return the arithmetic mean of the per-instance losses
     * @throws PredictionException if it is raised while obtaining a prediction
     */
    public static double computeAverageLoss(DyadRankingLossFunction dyadRankingLossFunction, DyadRankingDataset dyadRankingDataset, IDyadRanker iDyadRanker) throws PredictionException {
        return computeAverageLoss(dyadRankingLossFunction, dyadRankingDataset, iDyadRanker, new Random(0L));
    }

    /** Copies the remaining elements of the iterator into a new list and shuffles that list. */
    private static <T> List<T> shuffledCopy(Iterator<T> elements, Random random) {
        List<T> copy = Lists.newArrayList(elements);
        Collections.shuffle(copy, random);
        return copy;
    }
}
