package org.nd4j.linalg.api.ops.impl.loss;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.class */
/**
 * SameDiff loss op registered as {@code softmax_cross_entropy_loss}.
 *
 * <p>On top of whatever arguments {@link BaseLoss} registers, this op appends one floating-point
 * (T) argument: the label-smoothing factor.
 */
public class SoftmaxCrossEntropyLoss extends BaseLoss {
    /** Label-smoothing factor used when the caller does not supply one. */
    public static final double DEFAULT_LABEL_SMOOTHING = 0.0d;

    /** Label-smoothing factor; appended as a T-argument by {@link #addArgs()}. */
    private double labelSmoothing = DEFAULT_LABEL_SMOOTHING;

    /**
     * No-argument constructor; leaves the label smoothing at {@link #DEFAULT_LABEL_SMOOTHING}.
     * NOTE(review): presumably used for reflective instantiation (e.g. graph import) — confirm.
     */
    public SoftmaxCrossEntropyLoss() {
    }

    /**
     * Creates the op with the default label smoothing ({@link #DEFAULT_LABEL_SMOOTHING}).
     *
     * @param sameDiff   owning SameDiff instance
     * @param lossReduce reduction mode, forwarded to {@link BaseLoss}
     * @param logits     first input variable (op arg 0) — presumably the unnormalized predictions
     * @param weights    second input variable (op arg 1)
     * @param labels     third input variable (op arg 2)
     */
    public SoftmaxCrossEntropyLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable logits, SDVariable weights, SDVariable labels) {
        this(sameDiff, lossReduce, logits, weights, labels, DEFAULT_LABEL_SMOOTHING);
    }

    /**
     * Creates the op with an explicit label-smoothing factor.
     *
     * @param sameDiff       owning SameDiff instance
     * @param lossReduce     reduction mode, forwarded to {@link BaseLoss}
     * @param logits         first input variable (op arg 0) — presumably the unnormalized predictions
     * @param weights        second input variable (op arg 1)
     * @param labels         third input variable (op arg 2)
     * @param labelSmoothing label-smoothing factor, stored and emitted as a T-argument
     */
    public SoftmaxCrossEntropyLoss(SameDiff sameDiff, LossReduce lossReduce, SDVariable logits, SDVariable weights, SDVariable labels, double labelSmoothing) {
        super(sameDiff, lossReduce, logits, weights, labels);
        this.labelSmoothing = labelSmoothing;
        // Register arguments again now that labelSmoothing holds the caller's value.
        // NOTE(review): addArgs() is public and overridable yet invoked from a constructor;
        // if BaseLoss's constructor also calls it, the earlier call saw labelSmoothing == 0.0.
        addArgs();
    }

    /**
     * Registers the op's arguments: everything {@link BaseLoss#addArgs()} adds, followed by the
     * label-smoothing factor as a T-argument.
     */
    @Override
    public void addArgs() {
        super.addArgs();
        addTArgument(labelSmoothing);
    }

    /**
     * Initializes this op from a TensorFlow node by delegating property mapping to
     * {@link TFGraphMapper}, then re-registers the op arguments.
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper mapper = TFGraphMapper.getInstance();
        mapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        addArgs();
    }

    /** @return the native op name, {@code softmax_cross_entropy_loss} */
    @Override
    public String opName() {
        return "softmax_cross_entropy_loss";
    }

    /**
     * No ONNX counterpart is mapped for this op.
     *
     * @throws NoOpNameFoundException always
     */
    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
    }

    /** @return the TensorFlow op name this op maps from, {@code SoftmaxCrossEntropy} */
    @Override
    public String tensorflowName() {
        return "SoftmaxCrossEntropy";
    }

    /**
     * Builds the backward-pass op. The incoming gradients are not read here; the backprop op
     * receives the forward inputs in the order (arg 2, arg 0, arg 1), plus the reduction mode
     * and the label-smoothing factor.
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> gradOutputs) {
        // Argument order matches lossSoftmaxCrossEntropyBp's expected parameter order.
        return Arrays.asList(
                f().lossSoftmaxCrossEntropyBp(arg(2), arg(0), arg(1), lossReduce, labelSmoothing));
    }
}
