package org.grouplens.lenskit.basic;

import it.unimi.dsi.fastutil.longs.LongArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import javax.annotation.Nullable;
import javax.inject.Inject;
import org.grouplens.lenskit.ItemRecommender;
import org.grouplens.lenskit.ItemScorer;
import org.grouplens.lenskit.scored.ScoredId;
import org.grouplens.lenskit.scored.ScoredIdListBuilder;
import org.grouplens.lenskit.scored.ScoredIds;
import org.grouplens.lenskit.symbols.Symbol;
import org.grouplens.lenskit.vectors.SparseVector;

/* JADX WARN: Classes with same name are omitted:
  
 */
/* loaded from: input_file:org/grouplens/lenskit/basic/RescoringItemRecommender.class */
public class RescoringItemRecommender implements ItemRecommender {
    public static final Symbol ORIGINAL_SCORE_SYMBOL = Symbol.of("org.grouplens.lenskit.basic.RescoringItemRecommender.ORIGINAL_SCORE");
    private final ItemRecommender delegate;
    private final ItemScorer scorer;

    @Inject
    public RescoringItemRecommender(ItemRecommender itemRecommender, ItemScorer itemScorer) {
        this.delegate = itemRecommender;
        this.scorer = itemScorer;
    }

    public List<ScoredId> recommend(long j) {
        return rescore(j, this.delegate.recommend(j));
    }

    public List<ScoredId> recommend(long j, int i) {
        return rescore(j, this.delegate.recommend(j, i));
    }

    public List<ScoredId> recommend(long j, @Nullable Set<Long> set) {
        return rescore(j, this.delegate.recommend(j, set));
    }

    public List<ScoredId> recommend(long j, int i, @Nullable Set<Long> set, @Nullable Set<Long> set2) {
        return rescore(j, this.delegate.recommend(j, i, set, set2));
    }

    private List<ScoredId> rescore(long j, List<ScoredId> list) {
        if (list.isEmpty()) {
            return Collections.emptyList();
        }
        LongArrayList longArrayList = new LongArrayList(list.size());
        Iterator<ScoredId> it = list.iterator();
        while (it.hasNext()) {
            longArrayList.add(it.next().getId());
        }
        SparseVector score = this.scorer.score(j, longArrayList);
        ScoredIdListBuilder newListBuilder = ScoredIds.newListBuilder(list.size());
        newListBuilder.addChannel(ORIGINAL_SCORE_SYMBOL);
        for (ScoredId scoredId : list) {
            newListBuilder.add(ScoredIds.copyBuilder(scoredId).setScore(score.get(scoredId.getId(), Double.NaN)).addChannel(ORIGINAL_SCORE_SYMBOL, scoredId.getScore()).build());
        }
        return newListBuilder.build();
    }
}
