/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.clustering.gmm;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.clustering.gmm.GmmPartitionData;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.primitives.matrix.Matrix;
import org.apache.ignite.ml.math.primitives.vector.Vector;

/**
 * Map-reduce aggregator of the statistics needed to estimate the covariance matrix of a single
 * Gaussian mixture component.
 * <p>
 * For every row {@code x} it accumulates {@code pcxi * (x - mean) * (x - mean)^T} together with the
 * number of rows seen. {@link #covariance(double)} then divides that sum by
 * {@code rowCnt * clusterProb}.
 */
public class CovarianceMatricesAggregator implements Serializable {
    private static final long serialVersionUID = 4163253784526780812L;

    /** Component mean that every row is centered on before taking the outer product. */
    private final Vector mean;

    /** Weighted sum of outer products; {@code null} while no row has been added. */
    private Matrix weightedSum;

    /** Number of rows accumulated so far. */
    private int rowCnt;

    /**
     * Creates an empty aggregator.
     *
     * @param mean Component mean used to center rows.
     */
    CovarianceMatricesAggregator(Vector mean) {
        this.mean = mean;
    }

    /**
     * Creates an aggregator with pre-computed state.
     *
     * @param mean Component mean used to center rows.
     * @param weightedSum Weighted sum of outer products, or {@code null} if no rows were seen.
     * @param rowCnt Number of rows that contributed to {@code weightedSum}.
     */
    CovarianceMatricesAggregator(Vector mean, Matrix weightedSum, int rowCnt) {
        this.mean = mean;
        this.weightedSum = weightedSum;
        this.rowCnt = rowCnt;
    }

    /**
     * Computes one covariance matrix per component over the whole dataset.
     *
     * @param dataset Dataset holding the rows and their per-component probabilities.
     * @param clusterProbs Per-component values; entry {@code i} scales the normalizer of
     *     component {@code i} (presumably the mixing weight of the component; confirm against caller).
     * @param means Component means, one per component.
     * @return Covariance matrices in the same order as {@code means}, or an empty list if the
     *     dataset computation produced no result.
     * @throws IllegalStateException If some component received no rows at all.
     */
    static List<Matrix> computeCovariances(Dataset<EmptyContext, GmmPartitionData> dataset,
        Vector clusterProbs, Vector[] means) {
        List<CovarianceMatricesAggregator> aggregators = dataset.compute(
            data -> map(data, means),
            CovarianceMatricesAggregator::reduce
        );

        if (aggregators == null) {
            return Collections.emptyList();
        }

        List<Matrix> res = new ArrayList<>(aggregators.size());
        for (int i = 0; i < aggregators.size(); i++) {
            res.add(aggregators.get(i).covariance(clusterProbs.get(i)));
        }
        return res;
    }

    /**
     * Adds one row to the aggregate.
     *
     * @param x Row vector.
     * @param pcxi Weight of this row for the component (the value {@code map} reads from
     *     {@code data.pcxi(c, i)}).
     */
    void add(Vector x, double pcxi) {
        // Centered row as a matrix; deltaCol * deltaCol^T yields the outer product.
        Matrix deltaCol = x.minus(mean).toMatrix(false);
        Matrix weightedCovComponent = deltaCol.times(deltaCol.transpose()).times(pcxi);

        weightedSum = weightedSum == null ? weightedCovComponent : weightedSum.plus(weightedCovComponent);
        rowCnt++;
    }

    /**
     * Merges this aggregator with another one built for the same mean.
     * <p>
     * Either side may be empty (no rows added, {@code weightedSum == null}). That happens for
     * partitions without data, so a missing sum is treated as the neutral element rather than
     * dereferenced.
     *
     * @param other Aggregator to merge with; must have an equal mean.
     * @return New aggregator holding the combined sum and row count.
     */
    CovarianceMatricesAggregator plus(CovarianceMatricesAggregator other) {
        A.ensure(mean.equals(other.mean), "this.mean == other.mean");

        Matrix sum;
        if (weightedSum == null) {
            sum = other.weightedSum;
        }
        else if (other.weightedSum == null) {
            sum = weightedSum;
        }
        else {
            sum = weightedSum.plus(other.weightedSum);
        }

        return new CovarianceMatricesAggregator(mean, sum, rowCnt + other.rowCnt);
    }

    /**
     * Builds per-component aggregators for one partition.
     *
     * @param data Partition data.
     * @param means Component means, one per component.
     * @return One aggregator per component, in the order of {@code means}.
     */
    static List<CovarianceMatricesAggregator> map(GmmPartitionData data, Vector[] means) {
        int cntOfComponents = means.length;

        List<CovarianceMatricesAggregator> aggregators = new ArrayList<>(cntOfComponents);
        for (Vector componentMean : means) {
            aggregators.add(new CovarianceMatricesAggregator(componentMean));
        }

        for (int i = 0; i < data.size(); i++) {
            // Fetch the row once and reuse it for all components.
            Vector x = data.getX(i);
            for (int c = 0; c < cntOfComponents; c++) {
                aggregators.get(c).add(x, data.pcxi(c, i));
            }
        }

        return aggregators;
    }

    /**
     * Normalizes the accumulated sum into a covariance matrix.
     *
     * @param clusterProb Component probability used together with the row count as the divisor.
     * @return {@code weightedSum / (rowCnt * clusterProb)}.
     * @throws IllegalStateException If no rows were aggregated, so there is no sum to normalize.
     */
    private Matrix covariance(double clusterProb) {
        if (weightedSum == null) {
            throw new IllegalStateException("Cannot compute covariance: no rows were aggregated");
        }
        return weightedSum.divide(rowCnt * clusterProb);
    }

    /**
     * Reduces partition results by element-wise merging of aggregators.
     *
     * @param l Left partial result, may be {@code null} or empty.
     * @param r Right partial result, may be {@code null} or empty.
     * @return The non-empty side if the other one is {@code null}/empty, otherwise a new list of
     *     merged aggregators.
     */
    static List<CovarianceMatricesAggregator> reduce(List<CovarianceMatricesAggregator> l,
        List<CovarianceMatricesAggregator> r) {
        A.ensure(l != null || r != null, "Both partitions cannot equal to null");

        if (l == null || l.isEmpty()) {
            return r;
        }
        if (r == null || r.isEmpty()) {
            return l;
        }

        A.ensure(l.size() == r.size(), "l.size() == r.size()");

        List<CovarianceMatricesAggregator> res = new ArrayList<>(l.size());
        for (int i = 0; i < l.size(); i++) {
            res.add(l.get(i).plus(r.get(i)));
        }
        return res;
    }

    /** @return Copy of the component mean. */
    Vector mean() {
        return mean.copy();
    }

    /** @return Copy of the weighted sum of outer products, or {@code null} if no row was added. */
    Matrix weightedSum() {
        return weightedSum == null ? null : weightedSum.copy();
    }

    /** @return Number of rows aggregated. */
    public int rowCount() {
        return rowCnt;
    }
}

