package org.renjin.stats.internals.models;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.renjin.eval.EvalException;
import org.renjin.repackaged.guava.collect.Lists;
import org.renjin.sexp.FunctionCall;
import org.renjin.sexp.ListVector;
import org.renjin.sexp.PairList;
import org.renjin.sexp.SEXP;
import org.renjin.sexp.Symbol;
import org.renjin.sexp.Vector;

/* loaded from: input_file:WEB-INF/lib/renjin-core-0.8.2415.jar:org/renjin/stats/internals/models/FormulaInterpreter.class */
/**
 * Interprets an R model formula call ({@code ~ rhs} or {@code lhs ~ rhs}) into a {@link Formula}.
 *
 * <p>Interpretation happens in two passes:
 * <ol>
 *   <li>{@code .} in the predictor is expanded to the sum of the data columns that are not
 *       referenced by the response (unless {@link #allowDotAsName(boolean)} is set).</li>
 *   <li>The expanded predictor is walked to collect terms. {@code +} unions terms, {@code *}
 *       adds both sides plus their pairwise interactions, {@code -} removes terms, and
 *       {@code (} groups. Numeric literals 0/1 set the intercept. Any other call is handed
 *       to {@code TermBuilder}.</li>
 * </ol>
 *
 * <p>Instances keep mutable per-formula state in unsynchronized fields.
 */
public class FormulaInterpreter {
    private static final Symbol TILDE = Symbol.get("~");
    private static final Symbol UNION = Symbol.get("+");
    private static final Symbol EXPAND_TERMS = Symbol.get("*");
    private static final Symbol DIFFERENCE = Symbol.get("-");
    private static final Symbol GROUP = Symbol.get("(");
    private static final Symbol DOT = Symbol.get(".");

    /** Left-hand side of the formula being interpreted; null for one-sided formulas. */
    private SEXP response;
    /** 1 if the model has an intercept, 0 otherwise; reset at the start of each interpret(). */
    private int intercept = 1;
    /** Data whose column names are used to expand '.'; null if no list data was supplied. */
    private ListVector dataFrame = null;
    /** When true, '.' is kept as an ordinary symbol instead of being expanded. */
    private boolean allowDotAsName = false;

    /**
     * Interprets a call to {@code ~}.
     *
     * @param functionCall a {@code ~} call with one argument (predictor) or two (response, predictor)
     * @return the formula holding the rebuilt {@code ~} call (with '.' expanded), the intercept
     *     flag, and the sorted term list
     * @throws EvalException if the call is not to {@code ~} or has neither one nor two arguments,
     *     or if '.' must be expanded but no data was supplied
     */
    public Formula interpret(FunctionCall functionCall) {
        if (functionCall.getFunction() != TILDE) {
            throw new EvalException("expected model formula (~)", new Object[0]);
        }
        // Reset per-formula state so a reused interpreter does not carry an intercept
        // removal (e.g. "- 1") over from a previously interpreted formula.
        this.intercept = 1;

        SEXP predictor;
        FunctionCall expandedCall;
        int argCount = functionCall.getArguments().length();
        if (argCount == 1) {
            this.response = null;
            predictor = expandPredictor(functionCall.getArgument(0), null);
            expandedCall = FunctionCall.newCall(TILDE, predictor);
        } else if (argCount == 2) {
            this.response = functionCall.getArgument(0);
            predictor = expandPredictor(functionCall.getArgument(1), null);
            expandedCall = FunctionCall.newCall(TILDE, this.response, predictor);
        } else {
            throw new EvalException("Expected at most two arguments to `~` operator", new Object[0]);
        }

        TermList terms = new TermList();
        add(terms, predictor);
        return new Formula(expandedCall, this.intercept, terms.sorted());
    }

    /**
     * Returns a copy of {@code predictor} in which every {@code .} is replaced by the remaining
     * data variables (see {@link #expandRemainingVariables(SEXP)}), unless '.' is allowed as a name.
     *
     * @param predictor expression to expand
     * @param parent the call directly containing {@code predictor}, or null at the top level
     */
    private SEXP expandPredictor(SEXP predictor, SEXP parent) {
        if (!this.allowDotAsName && predictor == DOT) {
            return expandRemainingVariables(parent);
        }
        if (!(predictor instanceof FunctionCall)) {
            return predictor;
        }
        FunctionCall call = (FunctionCall) predictor;
        FunctionCall.Builder builder = new FunctionCall.Builder();
        // NOTE(review): mo9016add looks like a decompiler-renamed add(SEXP); it appends the
        // function position of the rebuilt call. Left as-is to match the decompiled Builder.
        builder.mo9016add(call.getFunction());
        for (PairList.Node node : call.getArguments().nodes()) {
            builder.add(node.getName(), expandPredictor(node.getValue(), call));
        }
        return builder.build();
    }

    /**
     * Expands '.' to the sum of data column names that do not occur as symbols in the response.
     *
     * @param parent the call containing the '.', or null if '.' is the whole predictor
     * @return {@code DOT} itself if no columns remain; a single symbol; or a left-nested
     *     {@code +} call, wrapped in {@code (} when nested inside another call
     * @throws EvalException if no data was supplied
     */
    private SEXP expandRemainingVariables(SEXP parent) {
        if (this.dataFrame == null) {
            throw new EvalException("'.' in formula and no 'data' argument", new Object[0]);
        }
        Set<String> responseVariables = new HashSet<String>();
        findResponseVariables(responseVariables, this.response);

        List<String> remaining = Lists.newArrayList();
        for (int i = 0; i < this.dataFrame.length(); i++) {
            String name = this.dataFrame.getName(i);
            if (!responseVariables.contains(name)) {
                remaining.add(name);
            }
        }
        if (remaining.isEmpty()) {
            // Nothing to substitute: leave the literal '.' in place.
            return DOT;
        }
        SEXP sum = expandRemainingVariables(remaining);
        // A multi-variable sum nested inside another call is wrapped in '(' — presumably so the
        // rebuilt formula shows explicit parentheses, e.g. "a * (x + z)". add() unwraps GROUP.
        if (parent != null && remaining.size() > 1) {
            return FunctionCall.newCall(GROUP, sum);
        }
        return sum;
    }

    /**
     * Builds the left-nested sum {@code ((n1 + n2) + n3) + ...} of the given names.
     * The list must be non-empty (the only caller checks this).
     */
    private SEXP expandRemainingVariables(List<String> names) {
        SEXP sum = Symbol.get(names.get(0));
        for (int i = 1; i < names.size(); i++) {
            sum = FunctionCall.newCall(UNION, sum, Symbol.get(names.get(i)));
        }
        return sum;
    }

    /**
     * Collects the print names of all symbols in {@code sexp} into {@code set}, recursing into
     * call arguments. The function position of a call (e.g. {@code log} in {@code log(y)}) is
     * not collected. A null {@code sexp} adds nothing.
     */
    private void findResponseVariables(Set<String> set, SEXP sexp) {
        if (sexp instanceof Symbol) {
            set.add(((Symbol) sexp).getPrintName());
        } else if (sexp instanceof FunctionCall) {
            for (SEXP argument : ((FunctionCall) sexp).getArguments().values()) {
                findResponseVariables(set, argument);
            }
        }
    }

    /**
     * Builds a fresh term list from {@code sexp}.
     *
     * @param negated true when {@code sexp} is being subtracted (inverts intercept literals)
     */
    private TermList buildTermList(SEXP sexp, boolean negated) {
        TermList termList = new TermList();
        add(termList, sexp, negated);
        return termList;
    }

    /**
     * Adds the terms described by {@code sexp} to {@code termList}.
     *
     * <p>{@code negated} tracks whether {@code sexp} sits on the subtracted side of an odd
     * number of {@code -} operators. It is threaded through {@code +}, {@code *}, {@code (} and
     * {@code -} so that, as in R, {@code x - (1 + z)} removes the intercept.
     */
    private void add(TermList termList, SEXP sexp, boolean negated) {
        if (sexp instanceof Symbol) {
            termList.add(new Term(sexp));
        } else if (sexp instanceof Vector) {
            intercept((Vector) sexp, negated);
        } else if (sexp instanceof FunctionCall) {
            FunctionCall call = (FunctionCall) sexp;
            SEXP function = call.getFunction();
            if (function == UNION) {
                unionTerms(termList, call, negated);
            } else if (function == EXPAND_TERMS) {
                multiply(termList, call, negated);
            } else if (function == DIFFERENCE) {
                difference(termList, call, negated);
            } else if (function == GROUP) {
                add(termList, call.getArgument(0), negated);
            } else {
                termList.add(new TermBuilder().build(call));
            }
        }
    }

    /** Adds the terms of a non-subtracted expression. */
    private void add(TermList termList, SEXP sexp) {
        add(termList, sexp, false);
    }

    /** {@code a * b}: adds the terms of a, the terms of b, and every pairwise interaction. */
    private void multiply(TermList termList, FunctionCall call, boolean negated) {
        TermList left = buildTermList(call.getArgument(0), negated);
        TermList right = buildTermList(call.getArgument(1), negated);
        termList.add(left);
        termList.add(right);
        for (Iterator<Term> outer = left.iterator(); outer.hasNext(); ) {
            Term a = outer.next();
            for (Iterator<Term> inner = right.iterator(); inner.hasNext(); ) {
                termList.add(new Term(a, inner.next()));
            }
        }
    }

    /** {@code a + b}: adds the terms of every operand. */
    private void unionTerms(TermList termList, FunctionCall call, boolean negated) {
        for (SEXP operand : call.getArguments().values()) {
            add(termList, operand, negated);
        }
    }

    /**
     * {@code a - b}: adds the terms of a minus those of b; b is interpreted with the negation
     * flipped. Unary {@code -b} removes terms from nothing, so only its intercept effect
     * (e.g. {@code -1}) is kept.
     */
    private void difference(TermList termList, FunctionCall call, boolean negated) {
        if (call.getArguments().length() == 1) {
            buildTermList(call.getArgument(0), !negated);
            return;
        }
        TermList left = buildTermList(call.getArgument(0), negated);
        left.subtract(buildTermList(call.getArgument(1), !negated));
        termList.add(left);
    }

    /**
     * Sets the intercept from a length-1 literal whose integer value must be 0 or 1. When
     * {@code negated}, the value is inverted (so "- 1" yields 0 and "- 0" yields 1).
     *
     * @throws EvalException if the vector is not length 1 or its value is not 0 or 1
     */
    private void intercept(Vector vector, boolean negated) {
        if (vector.length() != 1) {
            throw new EvalException("Invalid intercept: " + vector.toString() + ", expected 0 or 1", new Object[0]);
        }
        this.intercept = vector.getElementAsInt(0);
        if (this.intercept != 0 && this.intercept != 1) {
            throw new EvalException("Invalid intercept: " + this.intercept + ", expected 0 or 1", new Object[0]);
        }
        if (negated) {
            this.intercept = this.intercept == 0 ? 1 : 0;
        }
    }

    /**
     * Supplies the data used to expand '.'; ignored unless {@code sexp} is a {@link ListVector}.
     *
     * @return this interpreter, for chaining
     */
    public FormulaInterpreter withData(SEXP sexp) {
        if (sexp instanceof ListVector) {
            this.dataFrame = (ListVector) sexp;
        }
        return this;
    }

    /**
     * Controls whether '.' is treated as an ordinary variable name rather than being expanded.
     *
     * @return this interpreter, for chaining
     */
    public FormulaInterpreter allowDotAsName(boolean z) {
        this.allowDotAsName = z;
        return this;
    }
}
