package edu.columbia.tjw.item.fit.curve;

import edu.columbia.tjw.item.ItemCurveFactory;
import edu.columbia.tjw.item.ItemCurveParams;
import edu.columbia.tjw.item.ItemCurveType;
import edu.columbia.tjw.item.ItemModel;
import edu.columbia.tjw.item.ItemParameters;
import edu.columbia.tjw.item.ItemRegressor;
import edu.columbia.tjw.item.ItemSettings;
import edu.columbia.tjw.item.ItemStatus;
import edu.columbia.tjw.item.algo.QuantileDistribution;
import edu.columbia.tjw.item.data.ItemStatusGrid;
import edu.columbia.tjw.item.fit.ParamFittingGrid;
import edu.columbia.tjw.item.optimize.ConvergenceException;
import edu.columbia.tjw.item.optimize.EvaluationResult;
import edu.columbia.tjw.item.optimize.MultivariateOptimizer;
import edu.columbia.tjw.item.optimize.MultivariatePoint;
import edu.columbia.tjw.item.optimize.OptimizationResult;
import edu.columbia.tjw.item.util.LogUtil;
import edu.columbia.tjw.item.util.MultiLogistic;
import edu.columbia.tjw.item.util.RectangularDoubleArray;
import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;

/* loaded from: input_file:edu/columbia/tjw/item/fit/curve/CurveParamsFitter.class */
/**
 * Fits curve parameters (either a brand-new curve or a re-fit of an existing one) for the
 * transitions out of a single "from" status of an {@link ItemModel}.
 *
 * <p>On construction the fitter selects every grid row whose status is the model's from-status
 * and that has a next status, records the observed next status of each such row, and caches a
 * per-row table of "power scores" derived from the model's current transition probabilities.
 *
 * @param <S> status type
 * @param <R> regressor type
 * @param <T> curve type
 */
public final class CurveParamsFitter<S extends ItemStatus<S>, R extends ItemRegressor<R>, T extends ItemCurveType<T>> {
    private static final Logger LOG = LogUtil.getLogger(CurveParamsFitter.class);

    /** Factory used to create starting curve parameters and to rebuild curves from optimizer output. */
    private final ItemCurveFactory<R, T> _factory;
    /** Fitting settings (block size, AIC cutoff, randomness, polish flag). */
    private final ItemSettings _settings;
    /** The full status grid; only rows listed in {@link #_indexList} are used. */
    private final ItemStatusGrid<S, R> _grid;
    /** One row per entry of {@link #_indexList}, one column per status reachable from {@link #_fromStatus}. */
    private final RectangularDoubleArray _powerScores;
    /** Regressor view of {@link #_grid} for the model's parameters. */
    private final ParamFittingGrid<S, R, T> _paramGrid;
    /** The model whose parameters are being extended or re-fit. */
    private final ItemModel<S, R, T> _model;
    /** Observed next status of each row in {@link #_indexList}, in the same order. */
    private final int[] _actualOutcomes;
    /** Optimizer shared by all fits performed through this instance (and its copies). */
    private final MultivariateOptimizer _optimizer;
    /** Grid row indices whose status is {@link #_fromStatus} and which have a next status. */
    private final int[] _indexList;
    /** The status all transitions fitted here start from. */
    private final S _fromStatus;

    /**
     * Creates a fitter for new parameters that reuses the data selection of an existing fitter.
     *
     * <p>Settings, factory, grid, optimizer, row selection and observed outcomes are shared with
     * {@code source}; the model, parameter grid and power scores are rebuilt from {@code params}.
     *
     * @param source an existing fitter for the same from-status
     * @param params the parameters to build the new model from
     * @throws IllegalArgumentException if {@code params} has a different status than {@code source}
     */
    public CurveParamsFitter(CurveParamsFitter<S, R, T> source, ItemParameters<S, R, T> params) {
        if (source._fromStatus != params.getStatus()) {
            throw new IllegalArgumentException("From status mismatch.");
        }
        synchronized (this) {
            synchronized (source) {
                this._settings = source._settings;
                this._factory = source._factory;
                this._grid = source._grid;
                this._optimizer = source._optimizer;
                this._fromStatus = source._fromStatus;
                this._indexList = source._indexList;
                this._actualOutcomes = source._actualOutcomes;
            }
            final int rowCount = this._indexList.length;
            final int reachableCount = this._fromStatus.getReachableCount();
            this._model = new ItemModel<>(params);
            this._paramGrid = new ParamFittingGrid<>(this._model.getParams(), this._grid);
            this._powerScores = new RectangularDoubleArray(rowCount, reachableCount);
            fillPowerScores();
        }
    }

    /**
     * Creates a fitter for {@code model}, selecting the relevant rows of {@code grid}.
     *
     * @param factory  curve factory
     * @param model    the model to fit against; its parameters' status is the from-status
     * @param grid     the data grid
     * @param settings fitting settings
     */
    public CurveParamsFitter(ItemCurveFactory<R, T> factory, ItemModel<S, R, T> model, ItemStatusGrid<S, R> grid, ItemSettings settings) {
        synchronized (this) {
            this._settings = settings;
            this._factory = factory;
            this._model = model;
            this._grid = grid;
            // The meaning of the numeric arguments is defined by MultivariateOptimizer's constructor.
            this._optimizer = new MultivariateOptimizer(settings.getBlockSize(), 300, 20, 0.1d);
            this._fromStatus = model.getParams().getStatus();
            final int reachableCount = this._fromStatus.getReachableCount();
            this._indexList = generateIndexList(this._grid, this._fromStatus);
            final int rowCount = this._indexList.length;
            this._actualOutcomes = new int[rowCount];
            for (int row = 0; row < rowCount; row++) {
                this._actualOutcomes[row] = this._grid.getNextStatus(this._indexList[row]);
            }
            this._paramGrid = new ParamFittingGrid<>(this._model.getParams(), this._grid);
            this._powerScores = new RectangularDoubleArray(rowCount, reachableCount);
            fillPowerScores();
        }
    }

    /**
     * Returns the cached power-score table.
     *
     * <p>This is the internal array itself, not a copy; it is also the array used when building
     * quantile distributions, so callers must not modify it.
     *
     * @return the power scores (rows follow the selected grid rows, columns the reachable statuses)
     */
    public synchronized RectangularDoubleArray getPowerScores() {
        return this._powerScores;
    }

    /**
     * Re-fits the curve at entry {@code entryIndex} of this fitter's parameters.
     *
     * <p>The entry is dropped from the parameters and its existing curve parameters are used as
     * the optimizer's starting point; the re-fit curve is then added back for {@code toStatus}.
     *
     * @param entryIndex            index of the entry to re-calibrate
     * @param toStatus              status the re-fit curve's beta is attached to
     * @param startingLogLikelihood baseline objective value recorded in the returned result
     * @return the fit, or {@code null} if its AIC difference exceeds {@link ItemSettings#getAicCutoff()}
     * @throws ConvergenceException propagated from the optimizer
     */
    public FitResult<S, R, T> calibrateExistingCurve(int entryIndex, S toStatus, double startingLogLikelihood) throws ConvergenceException {
        final ItemParameters<S, R, T> params = this._model.getParams();
        final FitResult<S, R, T> result = expandParameters(params.dropIndex(entryIndex), params.getEntryCurveParams(entryIndex), toStatus, true, startingLogLikelihood);
        if (result.calculateAicDifference() > this._settings.getAicCutoff()) {
            return null;
        }
        return result;
    }

    /**
     * Fits a new curve of type {@code curveType} on {@code regressor} for transitions to
     * {@code toStatus}, added to the current model parameters.
     *
     * <p>When {@link ItemSettings#getPolishStartingParams()} is set, the starting parameters are
     * also polished and re-fit; the polished fit is returned only if its AIC difference is
     * strictly smaller. Polishing is best-effort: any exception during it is logged (with stack
     * trace) and the unpolished fit is returned.
     *
     * @param curveType the curve type to add
     * @param regressor the regressor the curve is applied to
     * @param toStatus  the status the new curve's beta is attached to
     * @return the best fit found
     * @throws ConvergenceException propagated from the optimizer during the initial (unpolished) fit
     */
    public FitResult<S, R, T> calibrateCurveAddition(T curveType, R regressor, S toStatus) throws ConvergenceException {
        LOG.info("\nCalculating Curve[" + curveType + ", " + regressor + ", " + toStatus + "]");
        final QuantileDistribution distribution = generateDistribution(regressor, toStatus);
        final ItemCurveParams<R, T> startingParams = this._factory.generateStartingParameters(curveType, regressor, distribution, this._settings.getRandom());
        final CurveOptimizerFunction<S, R, T> function = generateFunction(startingParams, toStatus, false);
        final double startingLogLikelihood = computeStartingLogLikelihood(function);
        final FitResult<S, R, T> baseFit = generateFit(toStatus, this._model.getParams(), function, startingLogLikelihood, startingParams);

        if (this._settings.getPolishStartingParams()) {
            try {
                final ItemCurveParams<R, T> polishedParams = RawCurveCalibrator.polishCurveParameters(this._factory, this._settings, distribution, regressor, startingParams);
                // Identity check: a re-fit is only attempted when polishing produced a different object.
                if (polishedParams != startingParams) {
                    final FitResult<S, R, T> polishedFit = generateFit(toStatus, this._model.getParams(), function, startingLogLikelihood, polishedParams);
                    final double baseAic = baseFit.calculateAicDifference();
                    final double polishedAic = polishedFit.calculateAicDifference();
                    LOG.info("Polished params[" + baseAic + " <> " + polishedAic + "]: " + (baseAic > polishedAic ? "BETTER" : polishedAic > baseAic ? "WORSE" : "SAME"));
                    if (baseAic > polishedAic) {
                        return polishedFit;
                    }
                }
            } catch (Exception e) {
                // Deliberate best-effort: fall back to the unpolished fit, but keep the stack
                // trace so the failure can actually be diagnosed.
                LOG.log(Level.INFO, "Exception during polish: " + e.toString(), e);
            }
        }
        return baseFit;
    }

    /** Builds the adjusted quantile distribution of {@code regressor} over the selected rows. */
    private QuantileDistribution generateDistribution(R regressor, S toStatus) {
        return new ItemQuantileDistribution(this._paramGrid, this._powerScores, this._fromStatus, regressor, toStatus, this._indexList).getAdjusted();
    }

    /**
     * Fits {@code curveParams} on top of {@code params}, computing the starting objective value
     * from the optimizer function first.
     *
     * @param params        parameters the fitted curve is added to
     * @param curveParams   starting curve parameters
     * @param toStatus      status the fitted curve's beta is attached to
     * @param optimizerFlag forwarded unchanged to the {@link CurveOptimizerFunction} constructor
     * @return the fit result
     * @throws ConvergenceException propagated from the optimizer
     */
    public FitResult<S, R, T> expandParameters(ItemParameters<S, R, T> params, ItemCurveParams<R, T> curveParams, S toStatus, boolean optimizerFlag) throws ConvergenceException {
        final CurveOptimizerFunction<S, R, T> function = generateFunction(curveParams, toStatus, optimizerFlag);
        return generateFit(toStatus, params, function, computeStartingLogLikelihood(function), curveParams);
    }

    /**
     * Fits {@code curveParams} on top of {@code params} using a caller-supplied starting
     * objective value.
     *
     * @param params                parameters the fitted curve is added to
     * @param curveParams           starting curve parameters
     * @param toStatus              status the fitted curve's beta is attached to
     * @param optimizerFlag         forwarded unchanged to the {@link CurveOptimizerFunction} constructor
     * @param startingLogLikelihood baseline objective value recorded in the result
     * @return the fit result
     * @throws ConvergenceException propagated from the optimizer
     */
    public FitResult<S, R, T> expandParameters(ItemParameters<S, R, T> params, ItemCurveParams<R, T> curveParams, S toStatus, boolean optimizerFlag, double startingLogLikelihood) throws ConvergenceException {
        return generateFit(toStatus, params, generateFunction(curveParams, toStatus, optimizerFlag), startingLogLikelihood, curveParams);
    }

    /** Creates the optimizer objective for {@code curveParams} over this fitter's selected rows. */
    private CurveOptimizerFunction<S, R, T> generateFunction(ItemCurveParams<R, T> curveParams, S toStatus, boolean optimizerFlag) {
        return new CurveOptimizerFunction<>(curveParams, this._factory, this._fromStatus, toStatus, this, this._actualOutcomes, this._paramGrid, this._indexList, this._settings, optimizerFlag);
    }

    /**
     * Optimizes {@code function} starting from {@code curveParams} and packages the result.
     *
     * @param toStatus              status the fitted curve's beta is attached to
     * @param params                parameters the fitted curve is added to
     * @param function              objective to minimize
     * @param startingLogLikelihood baseline objective value recorded in the result
     * @param curveParams           starting curve parameters (also the template for the fitted ones)
     * @return a {@link FitResult} holding the expanded parameters, fitted curve, minimum value and
     *         the optimizer's data-element count
     * @throws ConvergenceException propagated from the optimizer
     */
    public FitResult<S, R, T> generateFit(S toStatus, ItemParameters<S, R, T> params, CurveOptimizerFunction<S, R, T> function, double startingLogLikelihood, ItemCurveParams<R, T> curveParams) throws ConvergenceException {
        final OptimizationResult<MultivariatePoint> result = this._optimizer.optimize(function, new MultivariatePoint(curveParams.generatePoint()));
        final ItemCurveParams<R, T> fitted = new ItemCurveParams<>(curveParams, this._factory, result.getOptimum().getElements());
        return new FitResult<>(params.addBeta(fitted, toStatus), fitted, toStatus, result.minValue(), startingLogLikelihood, result.dataElementCount());
    }

    /**
     * Evaluates {@code function} over all of its rows at a freshly constructed point of its
     * dimension (presumably the all-zero point — confirm in {@link MultivariatePoint}) and
     * returns the mean result.
     */
    private double computeStartingLogLikelihood(CurveOptimizerFunction<S, R, T> function) {
        final MultivariatePoint origin = new MultivariatePoint(function.dimension());
        final EvaluationResult evaluation = function.generateResult();
        function.value(origin, 0, function.numRows(), evaluation);
        return evaluation.getMean();
    }

    /** Returns the parameters of the model this fitter was built for. */
    public ItemParameters<S, R, T> getParams() {
        return this._model.getParams();
    }

    /**
     * Returns, in ascending order, the grid rows whose status ordinal matches {@code fromStatus}
     * and which have a next status.
     */
    private static <S extends ItemStatus<S>, R extends ItemRegressor<R>> int[] generateIndexList(ItemStatusGrid<S, R> grid, S fromStatus) {
        final int size = grid.size();
        final int ordinal = fromStatus.ordinal();
        final int[] candidates = new int[size];
        int count = 0;
        for (int row = 0; row < size; row++) {
            if (grid.getStatus(row) == ordinal && grid.hasNextStatus(row)) {
                candidates[count++] = row;
            }
        }
        return Arrays.copyOf(candidates, count);
    }

    /**
     * Fills {@link #_powerScores}: for each selected row, computes the model's transition
     * probabilities and stores the output of {@link MultiLogistic#multiLogitFunction} taken
     * relative to the from-status's position in its reachable list.
     */
    private void fillPowerScores() {
        final int reachableCount = this._fromStatus.getReachableCount();
        final double[] scratch = new double[reachableCount];
        final int baseIndex = this._fromStatus.getReachable().indexOf(this._fromStatus);
        final int rowCount = this._actualOutcomes.length;
        for (int row = 0; row < rowCount; row++) {
            this._model.transitionProbability(this._paramGrid, this._indexList[row], scratch);
            // Transformed in place: input and output are the same array.
            MultiLogistic.multiLogitFunction(baseIndex, scratch, scratch);
            for (int col = 0; col < reachableCount; col++) {
                this._powerScores.set(row, col, scratch[col]);
            }
        }
    }
}
