/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.optimization.updatecalculators;

import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.matrix.Matrix;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.optimization.SmoothParametrized;
import org.apache.ignite.ml.optimization.updatecalculators.NesterovParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;

/**
 * Parameter update calculator that blends the previous iteration's update (scaled by a
 * momentum constant) with a gradient step (scaled by a learning rate).
 *
 * <p>On every iteration after the first, the gradient is not evaluated at the model's current
 * parameters but at the parameters shifted by {@code momentum * previousUpdate}. The new update
 * is then {@code previousUpdate * momentum + gradient * learningRate}, and {@link #update}
 * subtracts that update from the model parameters.
 *
 * <p>The loss function is stored by {@link #init}; until {@code init} has been called the
 * {@code loss} field is {@code null} and is passed as such to the model.
 *
 * @param <M> Type of the model whose parameters are being optimized.
 */
public class NesterovUpdateCalculator<M extends SmoothParametrized<M>>
    implements ParameterUpdateCalculator<M, NesterovParameterUpdate> {
    /** Serialization version identifier. */
    private static final long serialVersionUID = 251066184668190622L;

    /** Multiplier applied to the gradient when a new update is built. */
    private final double learningRate;

    /** Loss function handed to the model when differentiating; assigned in {@link #init}. */
    private IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;

    /** Multiplier applied to the previous iteration's update. */
    protected double momentum;

    /**
     * Creates a calculator with the given step size and momentum constant.
     *
     * @param learningRate Multiplier applied to the gradient.
     * @param momentum Multiplier applied to the previous update.
     */
    public NesterovUpdateCalculator(double learningRate, double momentum) {
        this.learningRate = learningRate;
        this.momentum = momentum;
    }

    /**
     * Computes the update for the current iteration.
     *
     * @param mdl Model whose parameters are being optimized.
     * @param updaterParameters Holder of the previous iteration's update.
     * @param iteration Iteration number; the look-ahead shift is applied only when it is positive.
     * @param inputs Input batch forwarded to the model's differentiation routine.
     * @param groundTruth Expected outputs forwarded to the model's differentiation routine.
     * @return New update equal to {@code prevUpdates * momentum + gradient * learningRate}.
     */
    @Override
    public NesterovParameterUpdate calculateNewUpdate(M mdl, NesterovParameterUpdate updaterParameters, int iteration, Matrix inputs, Matrix groundTruth) {
        Vector prevUpdates = updaterParameters.prevIterationUpdates();

        // Typed as M (not Object) so that differentiateByParameters resolves at compile time.
        M evalMdl = mdl;

        if (iteration > 0) {
            // Evaluate the gradient at the parameters shifted by the momentum-scaled previous update.
            Vector curParams = mdl.parameters();
            evalMdl = mdl.withParameters(curParams.minus(prevUpdates.times(this.momentum)));
        }

        Vector gradient = evalMdl.differentiateByParameters(this.loss, inputs, groundTruth);

        return new NesterovParameterUpdate(prevUpdates.times(this.momentum).plus(gradient.times(this.learningRate)));
    }

    /**
     * Stores the loss function and creates the initial update state.
     *
     * @param mdl Model to optimize; only its parameter count is read here.
     * @param loss Loss function used by later calls to {@link #calculateNewUpdate}.
     * @return Update state constructed from the model's parameter count.
     */
    @Override
    public NesterovParameterUpdate init(M mdl, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) {
        this.loss = loss;

        return new NesterovParameterUpdate(mdl.parametersCount());
    }

    /**
     * Applies an update to a model by subtracting it from the model's parameters.
     *
     * @param obj Model to update.
     * @param update Update whose vector is subtracted from the current parameters.
     * @param <M1> Concrete model type.
     * @return The value returned by {@code obj.setParameters(...)}, viewed as {@code M1}.
     */
    @Override
    public <M1 extends M> M1 update(M1 obj, NesterovParameterUpdate update) {
        Vector parameters = obj.parameters();

        // setParameters is declared to return M; callers need the narrower M1.
        // NOTE(review): presumably setParameters returns the receiver itself, which would make
        // this cast always safe - confirm against the SmoothParametrized contract.
        @SuppressWarnings("unchecked")
        M1 updated = (M1)obj.setParameters(parameters.minus(update.prevIterationUpdates()));

        return updated;
    }
}

