package org.tribuo.common.tree;

import com.google.protobuf.Any;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import org.tribuo.Example;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.common.tree.TreeModel;
import org.tribuo.common.tree.protos.LeafNodeProto;
import org.tribuo.common.tree.protos.TreeNodeProto;
import org.tribuo.math.la.SparseVector;
import org.tribuo.protos.core.OutputProto;

/* loaded from: input_file:org/tribuo/common/tree/LeafNode.class */
public class LeafNode<T extends Output<T>> implements Node<T> {
    private static final long serialVersionUID = 4;
    public static final int CURRENT_VERSION = 0;
    private final double impurity;
    private final T output;
    private final Map<String, T> scores;
    private final boolean generatesProbabilities;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/tribuo/common/tree/LeafNode$LeafNodeBuilder.class */
    public static final class LeafNodeBuilder<T extends Output<T>> extends TreeModel.NodeBuilder implements Node<T> {
        private final int parentIdx;
        private final int curIdx;
        private final double impurity;
        private final T output;
        private final Map<String, T> scores;
        private final boolean generatesProbabilities;

        /* JADX INFO: Access modifiers changed from: package-private */
        public LeafNodeBuilder(LeafNodeProto leafNodeProto) {
            this.parentIdx = leafNodeProto.getParentIdx();
            this.curIdx = leafNodeProto.getCurIdx();
            this.impurity = leafNodeProto.getImpurity();
            this.output = (T) Output.deserialize(leafNodeProto.getOutput());
            this.scores = new HashMap();
            for (Map.Entry<String, OutputProto> entry : leafNodeProto.getScoreMap().entrySet()) {
                Output deserialize = Output.deserialize(entry.getValue());
                if (!deserialize.getClass().equals(this.output.getClass())) {
                    throw new IllegalStateException("Invalid protobuf, scores were not the same type as the most likely output, found " + deserialize.getClass() + ", expected " + this.output.getClass());
                }
                this.scores.put(entry.getKey(), deserialize);
            }
            this.generatesProbabilities = leafNodeProto.getGeneratesProbabilities();
        }

        LeafNodeBuilder(int i, int i2, double d, T t, Map<String, T> map, boolean z) {
            this.parentIdx = i;
            this.curIdx = i2;
            this.impurity = d;
            this.output = t;
            this.scores = map;
            this.generatesProbabilities = z;
        }

        @Override // org.tribuo.common.tree.Node
        public boolean isLeaf() {
            return true;
        }

        @Override // org.tribuo.common.tree.Node
        public Node<T> getNextNode(SparseVector sparseVector) {
            return null;
        }

        @Override // org.tribuo.common.tree.Node
        public double getImpurity() {
            return this.impurity;
        }

        @Override // org.tribuo.common.tree.Node
        public LeafNodeBuilder<T> copy() {
            return new LeafNodeBuilder<>(this.parentIdx, this.curIdx, this.impurity, this.output.copy(), new HashMap(this.scores), this.generatesProbabilities);
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        @Override // org.tribuo.common.tree.TreeModel.NodeBuilder
        public int getParentIdx() {
            return this.parentIdx;
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        @Override // org.tribuo.common.tree.TreeModel.NodeBuilder
        public int getCurIdx() {
            return this.curIdx;
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        @Override // org.tribuo.common.tree.TreeModel.NodeBuilder
        public LeafNode<T> build() {
            return new LeafNode<>(this.impurity, this.output.copy(), new HashMap(this.scores), this.generatesProbabilities);
        }
    }

    public LeafNode(double d, T t, Map<String, T> map, boolean z) {
        this.impurity = d;
        this.output = t;
        this.scores = Collections.unmodifiableMap(map);
        this.generatesProbabilities = z;
    }

    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        LeafNode leafNode = (LeafNode) obj;
        if (!this.output.getClass().equals(leafNode.output.getClass()) || !this.scores.keySet().equals(leafNode.scores.keySet())) {
            return false;
        }
        boolean z = true;
        for (Map.Entry<String, T> entry : this.scores.entrySet()) {
            z &= entry.getValue().fullEquals(leafNode.scores.get(entry.getKey()));
        }
        return z && Double.compare(leafNode.impurity, this.impurity) == 0 && this.generatesProbabilities == leafNode.generatesProbabilities && this.output.fullEquals(leafNode.output);
    }

    public int hashCode() {
        return Objects.hash(Double.valueOf(this.impurity), this.output, this.scores, Boolean.valueOf(this.generatesProbabilities));
    }

    @Override // org.tribuo.common.tree.Node
    public Node<T> getNextNode(SparseVector sparseVector) {
        return null;
    }

    @Override // org.tribuo.common.tree.Node
    public boolean isLeaf() {
        return true;
    }

    @Override // org.tribuo.common.tree.Node
    public double getImpurity() {
        return this.impurity;
    }

    @Override // org.tribuo.common.tree.Node
    public LeafNode<T> copy() {
        return new LeafNode<>(this.impurity, this.output.copy(), new HashMap(this.scores), this.generatesProbabilities);
    }

    public T getOutput() {
        return this.output;
    }

    public Map<String, T> getDistribution() {
        return this.scores;
    }

    public Prediction<T> getPrediction(int i, Example<T> example) {
        return new Prediction<>(this.output, this.scores, i, example, this.generatesProbabilities);
    }

    public String toString() {
        return "LeafNode(impurity=" + this.impurity + ",output=" + this.output.toString() + ",scores=" + this.scores.toString() + ",probability=" + this.generatesProbabilities + ")";
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public TreeNodeProto serialize(int i, int i2) {
        LeafNodeProto.Builder newBuilder = LeafNodeProto.newBuilder();
        newBuilder.setParentIdx(i);
        newBuilder.setCurIdx(i2);
        newBuilder.setOutput((OutputProto) this.output.serialize());
        for (Map.Entry<String, T> entry : this.scores.entrySet()) {
            newBuilder.putScore(entry.getKey(), (OutputProto) entry.getValue().serialize());
        }
        newBuilder.setGeneratesProbabilities(this.generatesProbabilities);
        newBuilder.setImpurity(this.impurity);
        TreeNodeProto.Builder newBuilder2 = TreeNodeProto.newBuilder();
        newBuilder2.setVersion(0);
        newBuilder2.setClassName(LeafNode.class.getName());
        newBuilder2.setSerializedData(Any.pack(newBuilder.m51build()));
        return newBuilder2.m193build();
    }
}
