package dk.alexandra.fresco.lib.lp;

import dk.alexandra.fresco.framework.DRes;
import dk.alexandra.fresco.framework.builder.Computation;
import dk.alexandra.fresco.framework.builder.numeric.Numeric;
import dk.alexandra.fresco.framework.builder.numeric.ProtocolBuilderNumeric;
import dk.alexandra.fresco.framework.util.Pair;
import dk.alexandra.fresco.framework.value.SInt;
import dk.alexandra.fresco.lib.collections.Matrix;
import dk.alexandra.fresco.lib.conditional.ConditionalSelect;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:dk/alexandra/fresco/lib/lp/UpdateMatrix.class */
/**
 * Computation that derives a new matrix from {@code oldUpdateMatrix} in four stages.
 *
 * <p>Writing {@code U} for the old matrix, {@code inv = invert(p_prime)} and
 * {@code ratio = p * inv}, the stages build (all values are secret-shared {@link SInt}s):
 * <ol>
 *   <li>{@code selected[r][c] = L[r] * U[r][c]} for {@code r < width - 1}, and a known zero
 *       otherwise; in parallel, {@code scales[k] = C[k] * ConditionalSelect(L[k], 1, inv)} for
 *       every {@code k} but the last, and {@code scales[last] = C[last] * inv};</li>
 *   <li>{@code differences[r][c] = U[r][c] - selected[r][c]} and
 *       {@code columnSums[c] = sum of column c of selected};</li>
 *   <li>{@code outer[r][c] = scales[r] * columnSums[c]} and
 *       {@code scaledDifferences[r][c] = differences[r][c] * ratio};</li>
 *   <li>{@code result[r][c] = outer[r][c] + scaledDifferences[r][c]}.</li>
 * </ol>
 *
 * <p>NOTE(review): this class was recovered from decompiled bytecode in which row and column
 * indices shared one name. The row/column roles used in stages 1 and 3 are a reconstruction;
 * see the inline notes and confirm against the original source.
 *
 * <p>Presumably part of the LP solver (judging by the package name only) - the callers are not
 * visible here.
 */
public class UpdateMatrix implements Computation<Matrix<DRes<SInt>>, ProtocolBuilderNumeric> {
    /** Matrix the new matrix is derived from; its height and width fix the output dimensions. */
    private final Matrix<DRes<SInt>> oldUpdateMatrix;
    /** Per-row multipliers; also used as the selector passed to {@link ConditionalSelect}. */
    private final List<DRes<SInt>> L;
    /** Values that are scaled in stage 1; the last entry is treated separately from the others. */
    private final List<DRes<SInt>> C;
    /** Numerator of the ratio {@code p / p_prime}. */
    private final DRes<SInt> p;
    /** Value that is inverted to form the ratio and the scaling factors. */
    private final DRes<SInt> p_prime;

    /**
     * Stores the inputs of the computation; nothing is computed until
     * {@link #buildComputation(ProtocolBuilderNumeric)} is called.
     *
     * @param matrix the old update matrix
     * @param list the list stored as {@code L}
     * @param list2 the list stored as {@code C}
     * @param dRes the value stored as {@code p}
     * @param dRes2 the value stored as {@code p_prime}
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public UpdateMatrix(Matrix<DRes<SInt>> matrix, List<DRes<SInt>> list, List<DRes<SInt>> list2, DRes<SInt> dRes, DRes<SInt> dRes2) {
        this.oldUpdateMatrix = matrix;
        this.L = list;
        this.C = list2;
        this.p = dRes;
        this.p_prime = dRes2;
    }

    /**
     * Builds the four-stage computation described on the class.
     *
     * @param protocolBuilderNumeric builder the protocols are appended to
     * @return deferred result holding the new matrix, with the same height and width as the old one
     */
    @Override // dk.alexandra.fresco.framework.builder.Computation
    public DRes<Matrix<DRes<SInt>>> buildComputation(ProtocolBuilderNumeric protocolBuilderNumeric) {
        int height = this.oldUpdateMatrix.getHeight();
        int width = this.oldUpdateMatrix.getWidth();
        DRes<SInt> one = protocolBuilderNumeric.numeric().known(BigInteger.ONE);
        return protocolBuilderNumeric.seq(seq -> seq.par(selectStage -> {
            // Stage 1a: multiply every row of the old matrix by its entry of L.
            Numeric numeric = selectStage.numeric();
            Matrix<DRes<SInt>> selected = new Matrix<>(height, width, r -> {
                ArrayList<DRes<SInt>> cells = new ArrayList<>(width);
                ArrayList<DRes<SInt>> oldRow = this.oldUpdateMatrix.getRow(r);
                for (int c = 0; c < width; c++) {
                    // NOTE(review): the decompiled code used one name for the row and the column
                    // index here. The guard and L are read as row-indexed because L is only ever
                    // indexed below C.size() - 1 in stage 1b - confirm against the original.
                    if (r < width - 1) {
                        cells.add(numeric.mult(this.L.get(r), oldRow.get(c)));
                    } else {
                        cells.add(numeric.known(BigInteger.ZERO));
                    }
                }
                return cells;
            });
            // Stage 1b: the inverse of p_prime, the ratio p / p_prime and the scaled entries of C.
            DRes<Pair<List<DRes<SInt>>, DRes<SInt>>> scaling = selectStage.seq(scalingSeq -> {
                DRes<SInt> pPrimeInverse = scalingSeq.advancedNumeric().invert(this.p_prime);
                DRes<SInt> ratio = scalingSeq.numeric().mult(this.p, pPrimeInverse);
                return scalingSeq.par(scalingPar -> {
                    List<DRes<SInt>> scales = new ArrayList<>(this.C.size());
                    for (int k = 0; k < this.C.size() - 1; k++) {
                        int index = k; // effectively final copy for the lambda below
                        scales.add(scalingPar.seq(selectSeq -> selectSeq.numeric().mult(
                                this.C.get(index),
                                selectSeq.seq(new ConditionalSelect(this.L.get(index), one, pPrimeInverse)))));
                    }
                    // The last entry of C has no matching selector in L and is always scaled by
                    // the inverse.
                    scales.add(scalingPar.numeric().mult(this.C.get(this.C.size() - 1), pPrimeInverse));
                    return Pair.lazy(scales, ratio);
                });
            });
            // NOTE(review): out2() is the accessor name emitted by the decompiler; it is kept
            // as-is because the real method name cannot be verified from this file.
            return () -> new Pair<>(selected, scaling.out2());
        }).par((reduceStage, input) -> {
            // Stage 2: differences to the old matrix and the column sums of the stage 1 products.
            Matrix<DRes<SInt>> products = input.getFirst();
            List<DRes<SInt>> scales = input.getSecond().getFirst();
            DRes<SInt> ratio = input.getSecond().getSecond();
            Numeric numeric = reduceStage.numeric();
            List<DRes<SInt>> columnSums = new ArrayList<>(width);
            Matrix<DRes<SInt>> differences = new Matrix<>(height, width, r -> {
                ArrayList<DRes<SInt>> cells = new ArrayList<>(width);
                ArrayList<DRes<SInt>> oldRow = this.oldUpdateMatrix.getRow(r);
                ArrayList<DRes<SInt>> productRow = products.getRow(r);
                for (int c = 0; c < width; c++) {
                    cells.add(numeric.sub(oldRow.get(c), productRow.get(c)));
                }
                return cells;
            });
            for (int c = 0; c < width; c++) {
                columnSums.add(reduceStage.advancedNumeric().sum(products.getColumn(c)));
            }
            return () -> new Pair<>(new Pair<>(scales, ratio), new Pair<>(differences, columnSums));
        })).par((productStage, input) -> {
            // Stage 3: product of scales and column sums, and the differences times the ratio.
            List<DRes<SInt>> scales = input.getFirst().getFirst();
            DRes<SInt> ratio = input.getFirst().getSecond();
            Matrix<DRes<SInt>> differences = input.getSecond().getFirst();
            List<DRes<SInt>> columnSums = input.getSecond().getSecond();
            Numeric numeric = productStage.numeric();
            Matrix<DRes<SInt>> outer = new Matrix<>(height, width, r -> {
                ArrayList<DRes<SInt>> cells = new ArrayList<>(width);
                for (int c = 0; c < width; c++) {
                    // NOTE(review): read as scales[row] * columnSums[column]; the decompiled code
                    // used the same index name for both operands - confirm against the original.
                    cells.add(numeric.mult(scales.get(r), columnSums.get(c)));
                }
                return cells;
            });
            Matrix<DRes<SInt>> scaledDifferences = new Matrix<>(height, width, r -> {
                ArrayList<DRes<SInt>> cells = new ArrayList<>(width);
                ArrayList<DRes<SInt>> differenceRow = differences.getRow(r);
                for (int c = 0; c < width; c++) {
                    cells.add(numeric.mult(differenceRow.get(c), ratio));
                }
                return cells;
            });
            return Pair.lazy(outer, scaledDifferences);
        }).par((sumStage, input) -> {
            // Stage 4: element-wise sum of the two matrices from stage 3.
            Matrix<DRes<SInt>> outer = input.getFirst();
            Matrix<DRes<SInt>> scaledDifferences = input.getSecond();
            Numeric numeric = sumStage.numeric();
            Matrix<DRes<SInt>> updated = new Matrix<>(height, width, r -> {
                ArrayList<DRes<SInt>> cells = new ArrayList<>(width);
                ArrayList<DRes<SInt>> outerRow = outer.getRow(r);
                ArrayList<DRes<SInt>> scaledRow = scaledDifferences.getRow(r);
                for (int c = 0; c < width; c++) {
                    cells.add(numeric.add(outerRow.get(c), scaledRow.get(c)));
                }
                return cells;
            });
            return () -> updated;
        });
    }
}
