package org.apache.spark.ml.regression;

import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorImplicits$;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.mllib.optimization.Gradient;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.Tuple2;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;

/* compiled from: FMRegressor.scala */
@ScalaSignature(bytes = "\u0006\u0005\u00154a\u0001D\u0007\u0002\u0002=9\u0002\u0002\u0003\u0011\u0001\u0005\u0003\u0005\u000b\u0011\u0002\u0012\t\u0011!\u0002!\u0011!Q\u0001\n%B\u0001\u0002\f\u0001\u0003\u0002\u0003\u0006I!\u000b\u0005\t[\u0001\u0011\t\u0011)A\u0005E!)a\u0006\u0001C\u0001_!)a\u0007\u0001C!o!)\u0011\n\u0001D\u0001\u0015\")Q\n\u0001D\t\u001d\")\u0011\u000b\u0001D\t%\")Q\u000b\u0001C\u0001-\")q\f\u0001C\u0005A\n\t#)Y:f\r\u0006\u001cGo\u001c:ju\u0006$\u0018n\u001c8NC\u000eD\u0017N\\3t\u000fJ\fG-[3oi*\u0011abD\u0001\u000be\u0016<'/Z:tS>t'B\u0001\t\u0012\u0003\tiGN\u0003\u0002\u0013'\u0005)1\u000f]1sW*\u0011A#F\u0001\u0007CB\f7\r[3\u000b\u0003Y\t1a\u001c:h'\t\u0001\u0001\u0004\u0005\u0002\u001a=5\t!D\u0003\u0002\u001c9\u0005aq\u000e\u001d;j[&T\u0018\r^5p]*\u0011Q$E\u0001\u0006[2d\u0017NY\u0005\u0003?i\u0011\u0001b\u0012:bI&,g\u000e^\u0001\u000bM\u0006\u001cGo\u001c:TSj,7\u0001\u0001\t\u0003G\u0019j\u0011\u0001\n\u0006\u0002K\u0005)1oY1mC&\u0011q\u0005\n\u0002\u0004\u0013:$\u0018\u0001\u00044ji&sG/\u001a:dKB$\bCA\u0012+\u0013\tYCEA\u0004C_>dW-\u00198\u0002\u0013\u0019LG\u000fT5oK\u0006\u0014\u0018a\u00038v[\u001a+\u0017\r^;sKN\fa\u0001P5oSRtD#\u0002\u00193gQ*\u0004CA\u0019\u0001\u001b\u0005i\u0001\"\u0002\u0011\u0006\u0001\u0004\u0011\u0003\"\u0002\u0015\u0006\u0001\u0004I\u0003\"\u0002\u0017\u0006\u0001\u0004I\u0003\"B\u0017\u0006\u0001\u0004\u0011\u0013aB2p[B,H/\u001a\u000b\u0006qm\u001aUi\u0012\t\u0003GeJ!A\u000f\u0013\u0003\r\u0011{WO\u00197f\u0011\u0015ad\u00011\u0001>\u0003\u0011!\u0017\r^1\u0011\u0005y\nU\"A 
\u000b\u0005\u0001c\u0012A\u00027j]\u0006dw-\u0003\u0002C\u007f\t1a+Z2u_JDQ\u0001\u0012\u0004A\u0002a\nQ\u0001\\1cK2DQA\u0012\u0004A\u0002u\nqa^3jO\"$8\u000fC\u0003I\r\u0001\u0007Q(A\u0006dk6<%/\u00193jK:$\u0018!D4fiB\u0013X\rZ5di&|g\u000e\u0006\u00029\u0017\")Aj\u0002a\u0001q\u0005i!/Y<Qe\u0016$\u0017n\u0019;j_:\fQbZ3u\u001bVdG/\u001b9mS\u0016\u0014Hc\u0001\u001dP!\")A\n\u0003a\u0001q!)A\t\u0003a\u0001q\u00059q-\u001a;M_N\u001cHc\u0001\u001dT)\")A*\u0003a\u0001q!)A)\u0003a\u0001q\u0005\u0001r-\u001a;SC^\u0004&/\u001a3jGRLwN\u001c\u000b\u0004/vs\u0006\u0003B\u0012YqiK!!\u0017\u0013\u0003\rQ+\b\u000f\\33!\r\u00193\fO\u0005\u00039\u0012\u0012Q!\u0011:sCfDQ\u0001\u0010\u0006A\u0002uBQA\u0012\u0006A\u0002u\nabZ3u%\u0006<xI]1eS\u0016tG\u000f\u0006\u0003>C\n\u001c\u0007\"\u0002\u001f\f\u0001\u0004i\u0004\"\u0002$\f\u0001\u0004i\u0004\"\u00023\f\u0001\u0004Q\u0016!B:v[ZC\u0006")
/* loaded from: input_file:org/apache/spark/ml/regression/BaseFactorizationMachinesGradient.class */
/*
 * Gradient of a factorization-machines model over a single flat weight vector laid out as:
 *   j * factorSize + f               latent factor f of feature j, for j in [0, numFeatures)
 *   numFeatures * factorSize + j     linear weight of feature j (used only when fitLinear)
 *   weights.size() - 1               intercept (used only when fitIntercept)
 * The output transform, the gradient multiplier and the loss are supplied by subclasses.
 */
public abstract class BaseFactorizationMachinesGradient extends Gradient {
    /** Number of latent factors per feature. */
    private final int factorSize;
    /** Whether the last weight is treated as an intercept. */
    private final boolean fitIntercept;
    /** Whether per-feature linear weights (after the factor block) are used. */
    private final boolean fitLinear;
    /** Number of input features; determines where the linear block starts. */
    private final int numFeatures;

    /**
     * Computes the loss for one example and adds its scaled gradient into {@code vector3} in place.
     *
     * @param vector  feature vector of the example
     * @param d       label of the example
     * @param vector2 current model weights (layout described on the class)
     * @param vector3 cumulative gradient; receives {@code getMultiplier(raw, d) * rawGradient}
     * @return {@code getLoss(raw, d)} for this example's raw prediction
     */
    @Override // org.apache.spark.mllib.optimization.Gradient
    public double compute(Vector vector, double d, Vector vector2, Vector vector3) {
        Tuple2<Object, double[]> prediction = getRawPrediction(vector, vector2);
        if (prediction == null) {
            // getRawPrediction is overridable; keep the original failure mode for a null result.
            throw new MatchError(prediction);
        }
        double rawPrediction = BoxesRunTime.unboxToDouble(prediction._1());
        Vector rawGradient = getRawGradient(vector, vector2, prediction._2());
        BLAS$.MODULE$.axpy(
            getMultiplier(rawPrediction, d),
            VectorImplicits$.MODULE$.mllibVectorToMLVector(rawGradient),
            VectorImplicits$.MODULE$.mllibVectorToMLVector(vector3));
        return getLoss(rawPrediction, d);
    }

    /** Subclass hook applied to a raw prediction value (not used within this class). */
    public abstract double getPrediction(double d);

    /**
     * Returns the factor by which the raw gradient is scaled before being added to the cumulative
     * gradient in {@link #compute}.
     *
     * @param d  raw prediction
     * @param d2 label
     */
    public abstract double getMultiplier(double d, double d2);

    /**
     * Returns the loss reported by {@link #compute}.
     *
     * @param d  raw prediction
     * @param d2 label
     */
    public abstract double getLoss(double d, double d2);

    /**
     * Computes the raw (untransformed) factorization-machines prediction.
     *
     * <p>raw = intercept + sum_j w_j x_j + 0.5 * sum_f ((sum_j v_jf x_j)^2 - sum_j (v_jf x_j)^2),
     * where the intercept and linear terms are included only when enabled and x_j ranges over the
     * non-zero entries of {@code vector}.
     *
     * @param vector  feature vector
     * @param vector2 model weights
     * @return (raw prediction boxed as a Double, sumVX) where
     *     {@code sumVX[f] = sum_j v_jf * x_j}; sumVX is reused by the gradient computation
     */
    public Tuple2<Object, double[]> getRawPrediction(Vector vector, Vector vector2) {
        double[] sumVX = new double[this.factorSize];
        double[] sumSquare = new double[this.factorSize];
        int linearOffset = this.numFeatures * this.factorSize;

        double rawPrediction = 0.0d;
        if (this.fitIntercept) {
            rawPrediction += vector2.apply(vector2.size() - 1);
        }

        // Sparse vectors carry explicit indices; dense vectors are indexed by position.
        int[] indices = vector instanceof SparseVector ? ((SparseVector) vector).indices() : null;
        double[] values = vector instanceof SparseVector ? ((SparseVector) vector).values() : vector.toArray();
        for (int k = 0; k < values.length; k++) {
            double value = values[k];
            if (value == 0.0d) {
                continue; // skip zeros (including -0.0), mirroring Vector.foreachNonZero
            }
            int index = indices == null ? k : indices[k];
            if (this.fitLinear) {
                rawPrediction += vector2.apply(linearOffset + index) * value;
            }
            // One pass over the features: per-factor sums accumulate in the same feature order as
            // a factor-by-factor scan would, so the result is bit-identical, with contiguous reads.
            int base = index * this.factorSize;
            for (int f = 0; f < this.factorSize; f++) {
                double vx = vector2.apply(base + f) * value;
                sumSquare[f] += vx * vx;
                sumVX[f] += vx;
            }
        }
        for (int f = 0; f < this.factorSize; f++) {
            rawPrediction += 0.5d * ((sumVX[f] * sumVX[f]) - sumSquare[f]);
        }
        return new Tuple2<>(BoxesRunTime.boxToDouble(rawPrediction), sumVX);
    }

    /**
     * Returns the partial derivatives of the raw prediction with respect to every weight.
     *
     * @param vector  feature vector (sparse input yields a sparse gradient, dense yields dense)
     * @param vector2 model weights
     * @param dArr    sumVX as returned by {@link #getRawPrediction}
     * @throws MatchError if {@code vector} is neither a SparseVector nor a DenseVector
     */
    private Vector getRawGradient(Vector vector, Vector vector2, double[] dArr) {
        if (vector instanceof SparseVector) {
            return sparseRawGradient((SparseVector) vector, vector2, dArr);
        }
        if (vector instanceof DenseVector) {
            return denseRawGradient((DenseVector) vector, vector2, dArr);
        }
        throw new MatchError(vector);
    }

    /** Builds the gradient as a sparse vector touching only weights of non-zero features. */
    private Vector sparseRawGradient(SparseVector data, Vector weights, double[] sumVX) {
        int[] indices = data.indices();
        double[] values = data.values();

        // Size by the entries actually emitted (explicit zeros are skipped below). Sizing by
        // indices().length would leave trailing (0, 0.0) slots after larger indices, breaking the
        // strictly increasing index order of the resulting sparse vector.
        int nnz = data.numNonzeros();
        int factorEntries = nnz * this.factorSize;
        int length = factorEntries + (this.fitLinear ? nnz : 0) + (this.fitIntercept ? 1 : 0);
        int[] gradIndices = new int[length];
        double[] gradValues = new double[length];

        int linearOffset = this.numFeatures * this.factorSize;
        int factorPos = 0;             // next free slot in the factor block
        int linearPos = factorEntries; // next free slot in the linear block (follows all factors)
        for (int k = 0; k < values.length; k++) {
            double value = values[k];
            if (value == 0.0d) {
                continue;
            }
            int index = indices[k];
            int base = index * this.factorSize;
            for (int f = 0; f < this.factorSize; f++) {
                gradIndices[factorPos] = base + f;
                // d raw / d v_jf = x_j * sumVX_f - v_jf * x_j^2
                gradValues[factorPos] = (value * sumVX[f]) - ((weights.apply(base + f) * value) * value);
                factorPos++;
            }
            if (this.fitLinear) {
                gradIndices[linearPos] = linearOffset + index;
                gradValues[linearPos] = value;
                linearPos++;
            }
        }
        if (this.fitIntercept) {
            // The intercept is the last weight, so it also takes the last (largest-index) slot.
            gradIndices[length - 1] = weights.size() - 1;
            gradValues[length - 1] = 1.0d;
        }
        return Vectors$.MODULE$.sparse(weights.size(), gradIndices, gradValues);
    }

    /** Builds the gradient as a dense vector the size of the weight vector. */
    private Vector denseRawGradient(DenseVector data, Vector weights, double[] sumVX) {
        double[] values = data.values();
        double[] grad = new double[weights.size()];
        int linearOffset = this.numFeatures * this.factorSize;

        if (this.fitIntercept) {
            grad[grad.length - 1] += 1.0d;
        }
        for (int index = 0; index < values.length; index++) {
            double value = values[index];
            if (value == 0.0d) {
                continue;
            }
            if (this.fitLinear) {
                grad[linearOffset + index] += value;
            }
            int base = index * this.factorSize;
            for (int f = 0; f < this.factorSize; f++) {
                grad[base + f] += (value * sumVX[f]) - ((weights.apply(base + f) * value) * value);
            }
        }
        return Vectors$.MODULE$.dense(grad);
    }

    /**
     * @param i  factorSize: number of latent factors per feature
     * @param z  fitIntercept: whether the last weight is an intercept
     * @param z2 fitLinear: whether per-feature linear weights are used
     * @param i2 numFeatures: number of input features
     */
    public BaseFactorizationMachinesGradient(int i, boolean z, boolean z2, int i2) {
        this.factorSize = i;
        this.fitIntercept = z;
        this.fitLinear = z2;
        this.numFeatures = i2;
    }
}
