package org.linqs.psl.reasoner.function;

import org.linqs.psl.model.atom.GroundAtom;

/* loaded from: input_file:org/linqs/psl/reasoner/function/GeneralFunction.class */
/**
 * A function of the form {@code constant + sum_i(coefficient_i * term_i)}, optionally clamped
 * at zero ({@code max(0, ...)}) and/or squared.
 *
 * <p>Terms that report themselves as constant when added are folded into the constant, so only
 * non-constant terms are stored. Term storage is pre-allocated for a fixed maximum size.
 */
public class GeneralFunction implements FunctionTerm {
    /** Coefficients of the stored terms; only indices {@code [0, size)} are meaningful. */
    private final float[] coefficients;
    /** Stored terms; only indices {@code [0, size)} are meaningful. */
    private final FunctionTerm[] terms;
    /** Number of terms currently stored. */
    private int size = 0;
    /** Sum of constants added directly and constant terms folded in by {@link #add(float, FunctionTerm)}. */
    private float constant = 0.0f;
    /** Whether every stored term reported itself as constant (vacuously true with no stored terms). */
    private boolean constantTerms = true;
    /** Whether every stored term reported itself as linear (vacuously true with no stored terms). */
    private boolean linearTerms = true;
    /** If set, a sum that is not {@code >= 0} evaluates to 0. */
    private boolean nonNegative;
    /** If set, the (possibly clamped) sum is squared. */
    private boolean squared;

    /**
     * Creates an empty function: constant 0 and no terms.
     *
     * @param nonNegative whether to clamp the sum at zero ({@code max(0, ...)})
     * @param squared whether to square the (possibly clamped) sum
     * @param maxSize maximum number of non-constant terms that can be added
     */
    public GeneralFunction(boolean nonNegative, boolean squared, int maxSize) {
        this.coefficients = new float[maxSize];
        this.terms = new FunctionTerm[maxSize];
        this.nonNegative = nonNegative;
        this.squared = squared;
    }

    /** Returns the accumulated constant part of the function. */
    public float getConstant() {
        return this.constant;
    }

    /** Returns whether the (possibly clamped) sum is squared. */
    public boolean isSquared() {
        return this.squared;
    }

    /** Returns whether the sum is clamped at zero. */
    public boolean isNonNegative() {
        return this.nonNegative;
    }

    /**
     * Returns true when the function is not squared and every stored term is linear.
     * The {@code nonNegative} clamp does not influence this result.
     */
    @Override
    public boolean isLinear() {
        return !this.squared && this.linearTerms;
    }

    /**
     * Returns whether all stored terms are constant. Because constant terms are folded into the
     * constant on {@link #add(float, FunctionTerm)}, this is effectively true only while no
     * term has been stored.
     */
    @Override
    public boolean isConstant() {
        return this.constantTerms;
    }

    /** Sets whether the (possibly clamped) sum is squared. */
    public void setSquared(boolean squared) {
        this.squared = squared;
    }

    /** Sets whether the sum is clamped at zero. */
    public void setNonNegative(boolean nonNegative) {
        this.nonNegative = nonNegative;
    }

    /**
     * Adds a value to the constant part of the function.
     *
     * @param value the amount to add
     */
    public void add(float value) {
        this.constant += value;
    }

    /**
     * Adds {@code coefficient * term} to the function.
     *
     * <p>If {@code term} reports itself as constant, its current value times {@code coefficient}
     * is folded into the constant and nothing is stored.
     *
     * @param coefficient multiplier applied to the term
     * @param term the term to add
     * @throws IllegalStateException if a non-constant term is added when the function already
     *     holds its maximum number of terms
     */
    public void add(float coefficient, FunctionTerm term) {
        if (term.isConstant()) {
            this.constant += coefficient * term.getValue();
            return;
        }

        if (this.size == this.terms.length) {
            throw new IllegalStateException("More than the max terms added to the function. Max: " + this.terms.length);
        }

        this.terms[this.size] = term;
        this.coefficients[this.size] = coefficient;
        this.size++;

        this.constantTerms = this.constantTerms && term.isConstant();
        this.linearTerms = this.linearTerms && term.isLinear();
    }

    /** Returns the number of stored (non-constant) terms. */
    public int size() {
        return this.size;
    }

    /**
     * Returns the coefficient of the stored term at {@code index}.
     *
     * @throws IndexOutOfBoundsException if {@code index} is not in {@code [0, size())}
     */
    public float getCoefficient(int index) {
        checkIndex(index);
        return this.coefficients[index];
    }

    /**
     * Returns the stored term at {@code index}.
     *
     * @throws IndexOutOfBoundsException if {@code index} is not in {@code [0, size())}
     */
    public FunctionTerm getTerm(int index) {
        checkIndex(index);
        return this.terms[index];
    }

    /** Evaluates the function using each stored term's current value. */
    @Override
    public float getValue() {
        float sum = this.constant;
        for (int i = 0; i < this.size; i++) {
            sum += this.terms[i].getValue() * this.coefficients[i];
        }
        return applyHingeAndSquare(sum);
    }

    /**
     * Evaluates the function using caller-supplied term values instead of the terms' own values.
     *
     * @param values value for each stored term, indexed like {@link #getTerm(int)}; must have at
     *     least {@link #size()} entries
     * @return the evaluated function
     */
    public float getValue(float[] values) {
        float sum = this.constant;
        for (int i = 0; i < this.size; i++) {
            sum += this.coefficients[i] * values[i];
        }
        return applyHingeAndSquare(sum);
    }

    /**
     * Evaluates the function with one atom's value replaced.
     *
     * @param groundAtom the atom to substitute; matched by reference identity, not {@code equals}
     * @param replacementValue the value to use for {@code groundAtom}
     * @return the evaluated function
     */
    public float getValue(GroundAtom groundAtom, float replacementValue) {
        float sum = this.constant;
        for (int i = 0; i < this.size; i++) {
            FunctionTerm term = this.terms[i];
            float termValue = (term == groundAtom) ? replacementValue : term.getValue();
            sum += this.coefficients[i] * termValue;
        }
        return applyHingeAndSquare(sum);
    }

    /**
     * Applies the optional hinge and square to a raw weighted sum. Under the hinge, any sum that
     * is not {@code >= 0} (including NaN) evaluates to 0.
     */
    private float applyHingeAndSquare(float sum) {
        if (!this.nonNegative || sum >= 0.0f) {
            return this.squared ? sum * sum : sum;
        }
        return 0.0f;
    }

    /** Rejects indices outside the populated range {@code [0, size)}. */
    private void checkIndex(int index) {
        if (index < 0 || index >= this.size) {
            throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + this.size);
        }
    }

    /**
     * Renders the function as, e.g., {@code max(0.0, c + w1 * t1 + w2 * t2)^2}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(this.nonNegative ? "max(0.0, " : "(");
        sb.append(this.constant);
        for (int i = 0; i < this.size; i++) {
            sb.append(" + ")
                    .append(this.coefficients[i])
                    .append(" * ")
                    .append(this.terms[i].toString());
        }
        sb.append(")");
        if (this.squared) {
            sb.append("^2");
        }
        return sb.toString();
    }
}
