package org.nd4j.autodiff.validation.listeners;

import java.security.MessageDigest;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.dataset.api.MultiDataSet;

/**
 * Listener that validates that non-inplace ops do not modify their input arrays.
 *
 * <p>{@link #preOpExecution} stores a {@code dup()} copy of every input of a non-inplace op.
 * {@link #opExecution} then compares the raw buffer bytes of each non-empty copy with the op's
 * current input, and throws an {@link IllegalStateException} if they differ. The listener skips
 * in-place ops, and also skips legacy {@link Op}s whose {@code x()} is null.
 *
 * <p>Static counters shared by all instances record how many listeners were created and how many
 * input checks passed or failed.
 *
 * <p>NOTE(review): the snapshot is kept in an instance field between the two callbacks. An
 * instance therefore presumably must not observe two ops executing concurrently; confirm this
 * against how listeners are invoked.
 */
public class NonInplaceValidationListener extends BaseListener {
    /** Number of listener instances created so far. */
    private static final AtomicInteger useCounter = new AtomicInteger();
    /** Number of input arrays confirmed to be unmodified by their op. */
    private static final AtomicInteger passCounter = new AtomicInteger();
    /** Number of input arrays found to have been modified by a non-inplace op. */
    private static final AtomicInteger failCounter = new AtomicInteger();

    /**
     * Copies of the current op's inputs taken in {@link #preOpExecution}. They are released and
     * this field is reset to {@code null} at the end of {@link #opExecution}.
     */
    protected INDArray[] opInputs;

    /** Creates a listener and increments the global use counter. */
    public NonInplaceValidationListener() {
        useCounter.incrementAndGet();
    }

    /**
     * Snapshots the inputs of a non-inplace op, using {@code dup()}, before the op runs.
     *
     * @throws IllegalStateException if the op is neither an {@link Op} nor a {@link DynamicCustomOp}
     */
    @Override
    public void preOpExecution(SameDiff sameDiff, At at, SameDiffOp sameDiffOp) {
        if (sameDiffOp.getOp().isInPlace()) {
            return;
        }
        INDArray[] inputs = currentInputs(sameDiffOp);
        if (inputs == null) {
            return;
        }
        INDArray[] copies = new INDArray[inputs.length];
        for (int i = 0; i < inputs.length; i++) {
            copies[i] = inputs[i].dup();
        }
        this.opInputs = copies;
    }

    /**
     * Verifies that the op did not modify any of the inputs captured in {@link #preOpExecution}.
     *
     * <p>The raw bytes of each non-empty snapshot are compared with the bytes of the op's current
     * input. Snapshots, and any temporary copies, are released afterwards, even when validation
     * fails.
     *
     * @throws IllegalStateException if an input was modified, if no snapshot was taken, if the op
     *     now has fewer inputs than were captured, or if the op type is unknown
     */
    @Override
    public void opExecution(SameDiff sameDiff, At at, MultiDataSet multiDataSet, SameDiffOp sameDiffOp, INDArray[] outputs) {
        if (sameDiffOp.getOp().isInPlace()) {
            return;
        }
        INDArray[] inputs = currentInputs(sameDiffOp);
        if (inputs == null) {
            return;
        }
        Preconditions.checkState(this.opInputs != null,
                "No input snapshot for op %s: preOpExecution was not called before opExecution",
                sameDiffOp.getOp().getClass());
        try {
            Preconditions.checkState(inputs.length >= this.opInputs.length,
                    "Op %s has %s inputs after execution but %s were captured before execution",
                    sameDiffOp.getOp().getClass(), Integer.valueOf(inputs.length), Integer.valueOf(this.opInputs.length));
            for (int i = 0; i < this.opInputs.length; i++) {
                if (!this.opInputs[i].isEmpty()) {
                    verifyUnmodified(sameDiffOp, i, this.opInputs[i], inputs[i]);
                }
            }
        } finally {
            releaseSnapshot();
        }
    }

    /**
     * Returns the arrays currently bound as the op's inputs.
     *
     * <p>For a legacy {@link Op} this is {@code x} (and {@code y} when present), or {@code null}
     * when {@code x()} is null, meaning there is nothing to check. For a {@link DynamicCustomOp}
     * it is {@code inputArguments()}.
     */
    private static INDArray[] currentInputs(SameDiffOp sameDiffOp) {
        Object op = sameDiffOp.getOp();
        if (op instanceof Op) {
            Op legacy = (Op) op;
            if (legacy.x() == null) {
                return null;
            }
            return legacy.y() == null
                    ? new INDArray[]{legacy.x()}
                    : new INDArray[]{legacy.x(), legacy.y()};
        }
        if (op instanceof DynamicCustomOp) {
            return ((DynamicCustomOp) op).inputArguments();
        }
        throw new IllegalStateException("Unknown op type: " + op.getClass());
    }

    /**
     * Compares one snapshot with the op's current input byte by byte, and updates the pass/fail
     * counters.
     *
     * @throws IllegalStateException if the bytes differ
     */
    private static void verifyUnmodified(SameDiffOp sameDiffOp, int index, INDArray snapshot, INDArray current) {
        byte[] before = snapshot.data().asBytes();
        // The snapshot is a compact dup(). If the current input's layout may differ, comparing its
        // raw buffer directly would be meaningless, so compare against a dup() of it instead.
        boolean duped = layoutMayDiffer(snapshot, current);
        INDArray comparable = duped ? current.dup() : current;
        try {
            boolean unchanged = Arrays.equals(before, comparable.data().asBytes());
            (unchanged ? passCounter : failCounter).incrementAndGet();
            Preconditions.checkState(unchanged,
                    "Input array for non-inplace op was modified during execution for op %s - input %s",
                    sameDiffOp.getOp().getClass(), Integer.valueOf(index));
        } finally {
            if (duped && comparable.closeable()) {
                comparable.close();
            }
        }
    }

    /**
     * Returns true when the current input's raw buffer cannot be compared directly with the
     * snapshot's buffer.
     *
     * <p>This is the case when the current input is a view, or when its ordering, strides or
     * element-wise stride differ from the snapshot's.
     */
    private static boolean layoutMayDiffer(INDArray snapshot, INDArray current) {
        return current.isView()
                || snapshot.ordering() != current.ordering()
                || !Arrays.equals(snapshot.stride(), current.stride())
                || snapshot.elementWiseStride() != current.elementWiseStride();
    }

    /** Closes every closeable snapshot array and clears {@link #opInputs}. */
    private void releaseSnapshot() {
        INDArray[] snapshot = this.opInputs;
        this.opInputs = null;
        if (snapshot == null) {
            return;
        }
        for (INDArray arr : snapshot) {
            if (arr != null && arr.closeable()) {
                arr.close();
            }
        }
    }

    /** This listener is active for every {@link Operation}. */
    @Override
    public boolean isActive(Operation operation) {
        return true;
    }

    /** Returns the global counter of listener instances created. */
    public static AtomicInteger getUseCounter() {
        return useCounter;
    }

    /** Returns the global counter of input arrays verified as unmodified. */
    public static AtomicInteger getPassCounter() {
        return passCounter;
    }

    /** Returns the global counter of input arrays found to have been modified. */
    public static AtomicInteger getFailCounter() {
        return failCounter;
    }
}
