package org.nd4j.linalg.api.activation;

import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.ArrayOps;
import org.nd4j.linalg.ops.factory.ElementWiseOpFactories;
import org.nd4j.linalg.ops.factory.ElementWiseOpFactory;

/* loaded from: input_file:org/nd4j/linalg/api/activation/SoftMax.class */
public class SoftMax extends BaseActivationFunction {
    /**
     * Selects the reduction dimension used by {@link #softmax(INDArray, boolean)}:
     * {@code true} reduces along dimension 1 (max/sum per row, broadcast as a column vector),
     * {@code false} reduces along dimension 0 (max/sum per column, broadcast as a row vector).
     */
    private final boolean rows;
    private static final long serialVersionUID = -3407472284248637360L;

    /**
     * Creates a softmax activation.
     *
     * @param rows {@code true} to normalize along dimension 1 (per row),
     *             {@code false} to normalize along dimension 0 (per column)
     */
    public SoftMax(boolean rows) {
        this.rows = rows;
    }

    /** Creates a softmax activation that normalizes along dimension 0 (per column). */
    public SoftMax() {
        this(false);
    }

    /**
     * Computes a numerically stabilized softmax: each slice is shifted by its maximum,
     * exponentiated, and divided by its sum. Subtracting the maximum does not change the
     * mathematical result but keeps {@code exp} from overflowing on large inputs.
     *
     * <p>The computation is the same regardless of the array's ordering ('c' or 'f').
     * The result is built from {@code subRowVector}/{@code subColumnVector} (the non-"i"
     * variants), so it is presumably a new array and the input is left untouched — confirm
     * against the INDArray contract.
     *
     * @param input  the array to normalize
     * @param perRow {@code true} to reduce along dimension 1 (per row),
     *               {@code false} to reduce along dimension 0 (per column)
     * @return the softmax-normalized array
     */
    public static INDArray softmax(INDArray input, boolean perRow) {
        return perRow ? softmaxAlongDimension1(input) : softmaxAlongDimension0(input);
    }

    /** Softmax reducing along dimension 0: shift by max(0), exponentiate, divide by sum(0). */
    private static INDArray softmaxAlongDimension0(INDArray input) {
        // NOTE(review): max(0) is transposed before being passed to subRowVector; preserved
        // exactly as found — confirm the orientation nd4j expects here.
        INDArray shifted = input.subRowVector(input.max(0).transpose());
        expInPlace(shifted);
        shifted.diviRowVector(shifted.sum(0));
        return shifted;
    }

    /** Softmax reducing along dimension 1: shift by max(1), exponentiate, divide by sum(1). */
    private static INDArray softmaxAlongDimension1(INDArray input) {
        INDArray max = input.max(1);
        // Make sure the per-row maxima are broadcast as a column vector.
        if (!max.isColumnVector()) {
            max = max.transpose();
        }
        INDArray shifted = input.subColumnVector(max);
        expInPlace(shifted);
        // NOTE(review): unlike max above, sum(1) is transposed unconditionally before
        // diviColumnVector; preserved exactly as found — confirm this orientation is intended.
        shifted.diviColumnVector(shifted.sum(1).transpose());
        return shifted;
    }

    /**
     * Applies the element-wise exp op to {@code array}. The exec() result is not used; the
     * callers rely on {@code array} itself being updated (presumably in place — confirm).
     */
    private static void expInPlace(INDArray array) {
        new ArrayOps().from(array).op(ElementWiseOpFactories.exp()).build().exec();
    }

    /**
     * Applies softmax using the reduction dimension chosen at construction time.
     *
     * @param input the array to activate
     * @return {@code softmax(input, rows)}
     */
    @Override
    public INDArray apply(INDArray input) {
        return softmax(input, this.rows);
    }

    /**
     * Softmax is not expressed as a single element-wise op, so no factory is provided.
     *
     * @return always {@code null}
     */
    @Override
    public ElementWiseOpFactory transformFactory() {
        return null;
    }

    /**
     * Computes the element-wise derivative {@code s * (1 - s)} where {@code s} is the softmax
     * of the input. This covers only the diagonal terms of the softmax Jacobian.
     *
     * @param input the pre-activation array; complex arrays get a complex array of ones
     * @return a new array holding {@code s * (1 - s)}
     */
    @Override
    public INDArray applyDerivative(INDArray input) {
        // Compute softmax once and reuse it (it was previously computed twice per call).
        INDArray activated = softmax(input, this.rows);
        INDArray ones;
        if (input instanceof IComplexNDArray) {
            ones = Nd4j.complexOnes(input.shape());
        } else {
            ones = Nd4j.ones(input.shape());
        }
        // subi mutates only the scratch array of ones, so `activated` is intact for mul.
        return activated.mul(ones.subi(activated));
    }
}
