/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDList;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.loss.HingeLoss;
import ai.djl.training.loss.L1Loss;
import ai.djl.training.loss.L2Loss;
import ai.djl.training.loss.SigmoidBinaryCrossEntropyLoss;
import ai.djl.training.loss.SoftmaxCrossEntropyLoss;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Base class for loss functions.
 *
 * <p>In addition to being an {@link Evaluator}, this class offers static factory methods for the
 * built-in losses and keeps, per accumulator key, a running total of the loss values reported via
 * {@link #updateAccumulator(String, NDList, NDList)} together with the number of such updates.
 */
public abstract class Loss
extends Evaluator {
    /** Running sum of reported loss values, keyed the same way as {@code totalInstances}. */
    private final Map<String, Float> totalLoss = new ConcurrentHashMap<String, Float>();

    /**
     * Creates a loss with the given name.
     *
     * @param name the name of this loss, forwarded to the {@link Evaluator} constructor
     */
    public Loss(String name) {
        super(name);
    }

    /**
     * Returns a new {@link L1Loss} created with its no-argument constructor.
     *
     * @return a new {@link L1Loss}
     */
    public static L1Loss l1Loss() {
        return new L1Loss();
    }

    /**
     * Returns a new {@link L1Loss} with the given name.
     *
     * @param name the name of the loss
     * @return a new {@link L1Loss}
     */
    public static L1Loss l1Loss(String name) {
        return new L1Loss(name);
    }

    /**
     * Returns a new {@link L1Loss} with the given name and weight.
     *
     * @param name the name of the loss
     * @param weight the weight forwarded to the {@link L1Loss} constructor
     * @return a new {@link L1Loss}
     */
    public static L1Loss l1Loss(String name, float weight) {
        return new L1Loss(name, weight);
    }

    /**
     * Returns a new {@link L2Loss} created with its no-argument constructor.
     *
     * @return a new {@link L2Loss}
     */
    public static L2Loss l2Loss() {
        return new L2Loss();
    }

    /**
     * Returns a new {@link L2Loss} with the given name.
     *
     * @param name the name of the loss
     * @return a new {@link L2Loss}
     */
    public static L2Loss l2Loss(String name) {
        return new L2Loss(name);
    }

    /**
     * Returns a new {@link L2Loss} with the given name and weight.
     *
     * @param name the name of the loss
     * @param weight the weight forwarded to the {@link L2Loss} constructor
     * @return a new {@link L2Loss}
     */
    public static L2Loss l2Loss(String name, float weight) {
        return new L2Loss(name, weight);
    }

    /**
     * Returns a new {@link SigmoidBinaryCrossEntropyLoss} created with its no-argument constructor.
     *
     * @return a new {@link SigmoidBinaryCrossEntropyLoss}
     */
    public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss() {
        return new SigmoidBinaryCrossEntropyLoss();
    }

    /**
     * Returns a new {@link SigmoidBinaryCrossEntropyLoss} with the given name.
     *
     * @param name the name of the loss
     * @return a new {@link SigmoidBinaryCrossEntropyLoss}
     */
    public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(String name) {
        return new SigmoidBinaryCrossEntropyLoss(name);
    }

    /**
     * Returns a new {@link SigmoidBinaryCrossEntropyLoss} with the given settings.
     *
     * @param name the name of the loss
     * @param weight the weight forwarded to the constructor
     * @param fromSigmoid the {@code fromSigmoid} flag forwarded to the constructor
     * @return a new {@link SigmoidBinaryCrossEntropyLoss}
     */
    public static SigmoidBinaryCrossEntropyLoss sigmoidBinaryCrossEntropyLoss(String name, float weight, boolean fromSigmoid) {
        return new SigmoidBinaryCrossEntropyLoss(name, weight, fromSigmoid);
    }

    /**
     * Returns a new {@link SoftmaxCrossEntropyLoss} created with its no-argument constructor.
     *
     * @return a new {@link SoftmaxCrossEntropyLoss}
     */
    public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss() {
        return new SoftmaxCrossEntropyLoss();
    }

    /**
     * Returns a new {@link SoftmaxCrossEntropyLoss} with the given name.
     *
     * @param name the name of the loss
     * @return a new {@link SoftmaxCrossEntropyLoss}
     */
    public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(String name) {
        return new SoftmaxCrossEntropyLoss(name);
    }

    /**
     * Returns a new {@link SoftmaxCrossEntropyLoss} with the given settings.
     *
     * @param name the name of the loss
     * @param weight the weight forwarded to the constructor
     * @param classAxis the class axis forwarded to the constructor
     * @param sparseLabel the {@code sparseLabel} flag forwarded to the constructor
     * @param fromLogit the {@code fromLogit} flag forwarded to the constructor
     * @return a new {@link SoftmaxCrossEntropyLoss}
     */
    public static SoftmaxCrossEntropyLoss softmaxCrossEntropyLoss(String name, float weight, int classAxis, boolean sparseLabel, boolean fromLogit) {
        return new SoftmaxCrossEntropyLoss(name, weight, classAxis, sparseLabel, fromLogit);
    }

    /**
     * Returns a new {@link HingeLoss} created with its no-argument constructor.
     *
     * @return a new {@link HingeLoss}
     */
    public static HingeLoss hingeLoss() {
        return new HingeLoss();
    }

    /**
     * Returns a new {@link HingeLoss} with the given name.
     *
     * @param name the name of the loss
     * @return a new {@link HingeLoss}
     */
    public static HingeLoss hingeLoss(String name) {
        return new HingeLoss(name);
    }

    /**
     * Returns a new {@link HingeLoss} with the given settings.
     *
     * @param name the name of the loss
     * @param margin the margin forwarded to the constructor
     * @param weight the weight forwarded to the constructor
     * @return a new {@link HingeLoss}
     */
    public static HingeLoss hingeLoss(String name, int margin, float weight) {
        return new HingeLoss(name, margin, weight);
    }

    /**
     * Registers an accumulator under {@code key}, starting both the update count and the loss
     * total at zero. Any existing values for the key are overwritten.
     *
     * @param key the accumulator key
     */
    @Override
    public void addAccumulator(String key) {
        totalInstances.put(key, 0L);
        totalLoss.put(key, 0.0f);
    }

    /**
     * Evaluates the loss on the given batch, sums the result into a single float, and adds it to
     * the running total for {@code key}. The instance count grows by one per call, not per sample.
     *
     * <p>NOTE(review): averaging per call assumes {@code evaluate} already reduces over the batch
     * (e.g. returns a mean) — confirm against the concrete loss implementations.
     *
     * @param key an accumulator key previously registered with {@link #addAccumulator(String)};
     *     an unregistered key makes the unboxing in the update lambda throw {@link
     *     NullPointerException}
     * @param labels the labels of the batch
     * @param predictions the predictions of the batch
     */
    @Override
    public void updateAccumulator(String key, NDList labels, NDList predictions) {
        // Reading the scalar forces the computation to complete before the value is accumulated.
        float update = evaluate(labels, predictions).sum().getFloat();
        totalInstances.compute(key, (k, v) -> v + 1L);
        totalLoss.compute(key, (k, v) -> v + update);
    }

    /**
     * Resets both the update count and the loss total for {@code key} to zero (registering the
     * key if it was absent).
     *
     * @param key the accumulator key
     */
    @Override
    public void resetAccumulator(String key) {
        totalInstances.put(key, 0L);
        totalLoss.put(key, 0.0f);
    }

    /**
     * Returns the accumulated loss for {@code key} divided by the number of {@link
     * #updateAccumulator(String, NDList, NDList)} calls since it was added or last reset.
     *
     * <p>The count and the total are read from two separate maps, so the result is not an atomic
     * snapshot if updates run concurrently.
     *
     * @param key the accumulator key
     * @return the average loss per update, or {@link Float#NaN} if no updates were recorded
     * @throws IllegalArgumentException if no accumulator is registered under {@code key}
     */
    @Override
    public float getAccumulator(String key) {
        // Read the count once so the zero check and the division use the same value.
        Long total = totalInstances.get(key);
        if (total == null) {
            throw new IllegalArgumentException("No loss found at that path");
        }
        if (total == 0L) {
            return Float.NaN;
        }
        Float loss = totalLoss.get(key);
        if (loss == null) {
            throw new IllegalArgumentException("No loss found at that path");
        }
        return loss / (float) total.longValue();
    }
}

