package org.nd4j.linalg.dataset.api.preprocessor.stats;

import java.io.File;
import java.io.IOException;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/stats/DistributionStats.class */
public class DistributionStats implements NormalizerStats {
    /** Logger for this class. */
    private static final Logger logger = LoggerFactory.getLogger(DistributionStats.class);
    /** Per-column means. */
    private final INDArray mean;
    /** Per-column standard deviations, clamped below at {@link Nd4j#EPS_THRESHOLD}. */
    private final INDArray std;

    /**
     * Accumulates column-wise mean and population variance over one or more batches.
     *
     * <p>Each new batch is merged into the running statistics with the pairwise update of
     * Chan et al., so statistics can be built up batch by batch.
     */
    /* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/stats/DistributionStats$Builder.class */
    public static class Builder implements NormalizerStats.Builder<DistributionStats> {
        /** Number of rows merged so far. */
        private long runningCount = 0;
        /** Column means of all rows seen so far, shape [1, columns]; null before the first batch. */
        private INDArray runningMean;
        /** Column population variances of all rows seen so far, shape [1, columns]. */
        private INDArray runningVariance;

        /**
         * Adds the features of {@code dataSet}, together with its features mask, to the running
         * statistics.
         *
         * @param dataSet data set whose features are accumulated; must not be null
         * @return this builder
         * @throws NullPointerException if {@code dataSet} is null
         */
        @Override // org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats.Builder
        /* renamed from: addFeatures */
        public NormalizerStats.Builder<DistributionStats> addFeatures2(@NonNull DataSet dataSet) {
            if (dataSet == null) {
                throw new NullPointerException("dataSet is marked @NonNull but is null");
            }
            return add2(dataSet.getFeatures(), dataSet.getFeaturesMaskArray());
        }

        /**
         * Adds the labels of {@code dataSet}, together with its labels mask, to the running
         * statistics.
         *
         * @param dataSet data set whose labels are accumulated; must not be null
         * @return this builder
         * @throws NullPointerException if {@code dataSet} is null
         */
        @Override // org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats.Builder
        /* renamed from: addLabels */
        public NormalizerStats.Builder<DistributionStats> addLabels2(@NonNull DataSet dataSet) {
            if (dataSet == null) {
                throw new NullPointerException("dataSet is marked @NonNull but is null");
            }
            return add2(dataSet.getLabels(), dataSet.getLabelsMaskArray());
        }

        /**
         * Merges one batch of data into the running mean and variance.
         *
         * <p>The batch and its mask are first passed to {@link DataSetUtil#tailor2d}; the result is
         * treated as rows = examples, columns = features. If that call returns {@code null}, the
         * batch is ignored.
         *
         * @param data the batch to accumulate; must not be null
         * @param mask optional mask forwarded to {@code DataSetUtil.tailor2d}; may be null
         * @return this builder
         * @throws NullPointerException if {@code data} is null
         */
        @Override // org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats.Builder
        /* renamed from: add */
        public NormalizerStats.Builder<DistributionStats> add2(@NonNull INDArray data, INDArray mask) {
            if (data == null) {
                throw new NullPointerException("data is marked @NonNull but is null");
            }
            INDArray data2d = DataSetUtil.tailor2d(data, mask);
            if (data2d == null) {
                return this;
            }
            long columns = data2d.size(1);
            long batchCount = data2d.size(0);
            INDArray batchMean = data2d.mean(0).reshape(1L, columns);
            // biasCorrected=false: population variance, which is what the pairwise merge expects.
            INDArray batchVariance = data2d.var(false, 0).reshape(1L, columns);

            if (this.runningMean == null) {
                // First batch: its statistics become the running statistics.
                this.runningMean = batchMean;
                this.runningVariance = batchVariance;
                this.runningCount = batchCount;
                if (batchCount == 1) {
                    // NOTE(review): for a single-row batch the reductions may return arrays that
                    // share storage with the input; copy them because later batches update the
                    // running arrays in place.
                    this.runningMean = this.runningMean.dup();
                    this.runningVariance = this.runningVariance.dup();
                }
                return this;
            }

            long totalCount = this.runningCount + batchCount;
            // Difference of the batch mean from the (not yet updated) running mean.
            INDArray delta = batchMean.subRowVector(this.runningMean);

            // Pairwise merge of population variances (Chan et al.):
            //   var = (nA*varA + nB*varB + delta^2 * nA*nB/(nA+nB)) / (nA+nB)
            // The cross weight is computed in double: a float cast of the counts loses precision
            // once they exceed 2^24.
            double crossWeight = (double) this.runningCount * batchCount / totalCount;
            INDArray weightedDeltaSquared = Transforms.pow(delta, 2, true).muli(crossWeight);
            this.runningVariance.muli(this.runningCount)
                    .addiRowVector(batchVariance.muli(batchCount))
                    .addiRowVector(weightedDeltaSquared)
                    .divi(totalCount);

            // mean = meanA + delta * nB/(nA+nB); equivalent to adding sum(x - meanA)/(nA+nB) but
            // without materialising a batch-sized temporary.
            this.runningMean.addi(delta.muli((double) batchCount / totalCount));
            this.runningCount = totalCount;
            return this;
        }

        /**
         * Creates a {@link DistributionStats} from the accumulated statistics.
         *
         * @return stats holding a copy of the running mean and the element-wise square root of a
         *     copy of the running variance
         * @throws RuntimeException if no data was added before calling this method
         */
        @Override // org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats.Builder
        public DistributionStats build() {
            if (this.runningMean == null) {
                throw new RuntimeException("No data was added, statistics cannot be determined");
            }
            return new DistributionStats(this.runningMean.dup(), Transforms.sqrt(this.runningVariance, true));
        }
    }

    /**
     * Creates distribution statistics from explicit mean and standard deviation arrays.
     *
     * <p>{@code std} is clamped <em>in place</em> (dup=false) so that no entry is below
     * {@link Nd4j#EPS_THRESHOLD}, presumably so that later division by the standard deviation
     * cannot blow up on constant columns. The caller's {@code std} array is therefore modified.
     *
     * @param mean per-column means; must not be null
     * @param std per-column standard deviations; must not be null
     * @throws NullPointerException if either argument is null
     */
    public DistributionStats(@NonNull INDArray mean, @NonNull INDArray std) {
        if (mean == null) {
            throw new NullPointerException("mean is marked @NonNull but is null");
        }
        if (std == null) {
            throw new NullPointerException("std is marked @NonNull but is null");
        }
        Transforms.max(std, Nd4j.EPS_THRESHOLD, false);
        this.mean = mean;
        this.std = std;
    }

    /**
     * Loads statistics previously written by {@link #save(File, File)}.
     *
     * @param meanFile file containing the mean array in ND4J binary format
     * @param stdFile file containing the standard deviation array in ND4J binary format
     * @return the loaded statistics
     * @throws IOException if either file cannot be read
     * @throws NullPointerException if either file is null
     */
    public static DistributionStats load(@NonNull File meanFile, @NonNull File stdFile) throws IOException {
        if (meanFile == null) {
            throw new NullPointerException("meanFile is marked @NonNull but is null");
        }
        if (stdFile == null) {
            throw new NullPointerException("stdFile is marked @NonNull but is null");
        }
        return new DistributionStats(Nd4j.readBinary(meanFile), Nd4j.readBinary(stdFile));
    }

    /**
     * Writes the mean and standard deviation arrays in ND4J binary format.
     *
     * @param meanFile destination for the mean array
     * @param stdFile destination for the standard deviation array
     * @throws IOException if either file cannot be written
     * @throws NullPointerException if either file is null
     */
    public void save(@NonNull File meanFile, @NonNull File stdFile) throws IOException {
        if (meanFile == null) {
            throw new NullPointerException("meanFile is marked @NonNull but is null");
        }
        if (stdFile == null) {
            throw new NullPointerException("stdFile is marked @NonNull but is null");
        }
        Nd4j.saveBinary(getMean(), meanFile);
        Nd4j.saveBinary(getStd(), stdFile);
    }

    /** @return the per-column means */
    public INDArray getMean() {
        return this.mean;
    }

    /** @return the per-column standard deviations */
    public INDArray getStd() {
        return this.std;
    }

    /**
     * Two instances are equal when their mean and std arrays are equal per {@link INDArray#equals}.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof DistributionStats)) {
            return false;
        }
        DistributionStats other = (DistributionStats) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        INDArray thisMean = getMean();
        INDArray otherMean = other.getMean();
        if (thisMean == null ? otherMean != null : !thisMean.equals(otherMean)) {
            return false;
        }
        INDArray thisStd = getStd();
        INDArray otherStd = other.getStd();
        return thisStd == null ? otherStd == null : thisStd.equals(otherStd);
    }

    /** Symmetry hook for {@link #equals(Object)} in subclasses. */
    protected boolean canEqual(Object obj) {
        return obj instanceof DistributionStats;
    }

    /** Hash code consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        INDArray thisMean = getMean();
        result = result * prime + (thisMean == null ? 43 : thisMean.hashCode());
        INDArray thisStd = getStd();
        result = result * prime + (thisStd == null ? 43 : thisStd.hashCode());
        return result;
    }
}
