/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.optimization;

import java.io.Serializable;
import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.mllib.linalg.BLAS$;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.optimization.Gradient;
import org.apache.spark.mllib.util.MLUtils$;
import scala.Array$;
import scala.Function1;
import scala.Function2;
import scala.Predef$;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.DoubleRef;
import scala.runtime.IntRef;
import scala.runtime.RichInt$;
import scala.runtime.java8.JFunction1;
import scala.runtime.java8.JFunction2;

@DeveloperApi
@ScalaSignature(bytes="\u0006\u0001q2A!\u0002\u0004\u0001#!Aa\u0003\u0001B\u0001B\u0003%q\u0003C\u0003\u001e\u0001\u0011\u0005a\u0004C\u0003\u001e\u0001\u0011\u0005\u0011\u0005C\u0003#\u0001\u0011\u00053E\u0001\tM_\u001eL7\u000f^5d\u000fJ\fG-[3oi*\u0011q\u0001C\u0001\r_B$\u0018.\\5{CRLwN\u001c\u0006\u0003\u0013)\tQ!\u001c7mS\nT!a\u0003\u0007\u0002\u000bM\u0004\u0018M]6\u000b\u00055q\u0011AB1qC\u000eDWMC\u0001\u0010\u0003\ry'oZ\u0002\u0001'\t\u0001!\u0003\u0005\u0002\u0014)5\ta!\u0003\u0002\u0016\r\tAqI]1eS\u0016tG/\u0001\u0006ok6\u001cE.Y:tKN\u0004\"\u0001G\u000e\u000e\u0003eQ\u0011AG\u0001\u0006g\u000e\fG.Y\u0005\u00039e\u00111!\u00138u\u0003\u0019a\u0014N\\5u}Q\u0011q\u0004\t\t\u0003'\u0001AQA\u0006\u0002A\u0002]!\u0012aH\u0001\bG>l\u0007/\u001e;f)\u0015!seL\u00194!\tAR%\u0003\u0002'3\t1Ai\\;cY\u0016DQ\u0001\u000b\u0003A\u0002%\nA\u0001Z1uCB\u0011!&L\u0007\u0002W)\u0011A\u0006C\u0001\u0007Y&t\u0017\r\\4\n\u00059Z#A\u0002,fGR|'\u000fC\u00031\t\u0001\u0007A%A\u0003mC\n,G\u000eC\u00033\t\u0001\u0007\u0011&A\u0004xK&<\u0007\u000e^:\t\u000bQ\"\u0001\u0019A\u0015\u0002\u0017\r,Xn\u0012:bI&,g\u000e\u001e\u0015\u0003\u0001Y\u0002\"a\u000e\u001e\u000e\u0003aR!!\u000f\u0006\u0002\u0015\u0005tgn\u001c;bi&|g.\u0003\u0002<q\taA)\u001a<fY>\u0004XM]!qS\u0002")
public class LogisticGradient
extends Gradient {
    private final int numClasses;

    @Override
    public double compute(Vector data, double label, Vector weights, Vector cumGradient) {
        double d;
        int dataSize = data.size();
        Predef$.MODULE$.require(weights.size() % dataSize == 0 && this.numClasses == weights.size() / dataSize + 1);
        int n = this.numClasses;
        switch (n) {
            case 2: {
                double margin = -1.0 * BLAS$.MODULE$.dot(data, weights);
                double multiplier = 1.0 / (1.0 + package$.MODULE$.exp(margin)) - label;
                BLAS$.MODULE$.axpy(multiplier, data, cumGradient);
                if (label > 0.0) {
                    d = MLUtils$.MODULE$.log1pExp(margin);
                    break;
                }
                d = MLUtils$.MODULE$.log1pExp(margin) - margin;
                break;
            }
            default: {
                double loss;
                Vector vector = weights;
                if (!(vector instanceof DenseVector)) {
                    throw new IllegalArgumentException(new StringBuilder(49).append("weights only supports dense vector but got type ").append(weights.getClass()).append(".").toString());
                }
                DenseVector denseVector = (DenseVector)vector;
                double[] dArray = denseVector.values();
                double[] weightsArray = dArray;
                Vector vector2 = cumGradient;
                if (!(vector2 instanceof DenseVector)) {
                    throw new IllegalArgumentException(new StringBuilder(53).append("cumGradient only supports dense vector but got type ").append(cumGradient.getClass()).append(".").toString());
                }
                DenseVector denseVector2 = (DenseVector)vector2;
                double[] dArray2 = denseVector2.values();
                double[] cumGradientArray = dArray2;
                DoubleRef marginY = DoubleRef.create((double)0.0);
                DoubleRef maxMargin = DoubleRef.create((double)Double.NEGATIVE_INFINITY);
                IntRef maxMarginIndex = IntRef.create((int)0);
                double[] margins = (double[])Array$.MODULE$.tabulate(this.numClasses - 1, (Function1)(JFunction1.mcDI.sp & Serializable & scala.Serializable)i -> {
                    DoubleRef margin;
                    block1: {
                        margin = DoubleRef.create((double)0.0);
                        data.foreachActive((Function2<Object, Object, BoxedUnit>)(JFunction2.mcVID.sp & Serializable & scala.Serializable)(index, value) -> {
                            block0: {
                                if (value == 0.0) break block0;
                                margin$1.elem += value * weightsArray[i * dataSize + index];
                            }
                        });
                        if (i == (int)label - 1) {
                            marginY$1.elem = margin.elem;
                        }
                        if (!(margin.elem > maxMargin$1.elem)) break block1;
                        maxMargin$1.elem = margin.elem;
                        maxMarginIndex$1.elem = i;
                    }
                    return margin.elem;
                }, ClassTag$.MODULE$.Double());
                DoubleRef temp = DoubleRef.create((double)0.0);
                if (maxMargin.elem > 0.0) {
                    RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), this.numClasses - 1).foreach$mVc$sp((Function1)(JFunction1.mcVI.sp & Serializable & scala.Serializable)i -> {
                        margins$1[i] = margins[i] - maxMargin$1.elem;
                        temp$1.elem = i == maxMarginIndex$1.elem ? (temp$1.elem += package$.MODULE$.exp(-maxMargin$1.elem)) : (temp$1.elem += package$.MODULE$.exp(margins[i]));
                    });
                } else {
                    RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), this.numClasses - 1).foreach$mVc$sp((Function1)(JFunction1.mcVI.sp & Serializable & scala.Serializable)i -> temp$1.elem += package$.MODULE$.exp(margins[i]));
                }
                double sum = temp.elem;
                RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), this.numClasses - 1).foreach$mVc$sp((Function1)(JFunction1.mcVI.sp & Serializable & scala.Serializable)i -> {
                    double multiplier = package$.MODULE$.exp(margins[i]) / (sum + 1.0) - (label != 0.0 && label == (double)(i + 1) ? 1.0 : 0.0);
                    data.foreachActive((Function2<Object, Object, BoxedUnit>)(JFunction2.mcVID.sp & Serializable & scala.Serializable)(index, value) -> {
                        block0: {
                            if (value == 0.0) break block0;
                            int n = i * dataSize + index;
                            cumGradientArray$1[n] = cumGradientArray[n] + multiplier * value;
                        }
                    });
                });
                double d2 = loss = label > 0.0 ? package$.MODULE$.log1p(sum) - marginY.elem : package$.MODULE$.log1p(sum);
                if (maxMargin.elem > 0.0) {
                    d = loss + maxMargin.elem;
                    break;
                }
                d = loss;
                break;
            }
        }
        return d;
    }

    public LogisticGradient(int numClasses) {
        this.numClasses = numClasses;
    }

    public LogisticGradient() {
        this(2);
    }
}

