/*
 * Decompiled with CFR 0.152.
 */
package smile.stat.distribution;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.math.matrix.Matrix;
import smile.stat.distribution.MultivariateExponentialFamilyMixture;
import smile.stat.distribution.MultivariateGaussianDistribution;
import smile.stat.distribution.MultivariateMixture;

/**
 * Finite mixture of multivariate Gaussian distributions.
 *
 * <p>Every component handed to this class must wrap a
 * {@link MultivariateGaussianDistribution}; the constructors reject anything else.
 * The static {@code fit} factories seed a set of Gaussian components and delegate
 * the actual parameter estimation to the inherited
 * {@code fit(double[][], Component...)} of {@link MultivariateExponentialFamilyMixture}.
 */
public class MultivariateGaussianMixture
extends MultivariateExponentialFamilyMixture {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(MultivariateGaussianMixture.class);

    /**
     * Creates a mixture from explicitly supplied components.
     *
     * <p>Delegates to the private constructor with {@code L = 0.0} and {@code n = 1}.
     *
     * @param components the mixture components; each must be Gaussian.
     * @throws IllegalArgumentException if any component is not a
     *         {@link MultivariateGaussianDistribution}.
     */
    public MultivariateGaussianMixture(MultivariateMixture.Component ... components) {
        this(0.0, 1, components);
    }

    /**
     * Creates a mixture and verifies that every component is Gaussian.
     *
     * @param L forwarded to the superclass; the {@code fit} factories pass the fitted
     *          model's {@code L} here (presumably the log-likelihood -- confirm in superclass).
     * @param n forwarded to the superclass; the {@code fit} factories pass the number
     *          of samples here.
     * @param components the mixture components.
     * @throws IllegalArgumentException if any component is not a
     *         {@link MultivariateGaussianDistribution}.
     */
    private MultivariateGaussianMixture(double L, int n, MultivariateMixture.Component ... components) {
        super(L, n, components);
        for (MultivariateMixture.Component component : components) {
            if (!(component.distribution instanceof MultivariateGaussianDistribution)) {
                throw new IllegalArgumentException("Component " + component + " is not of Gaussian distribution.");
            }
        }
    }

    /**
     * Fits a mixture of {@code k} Gaussians with full covariance matrices.
     *
     * @param k the number of components; must be at least 2.
     * @param data the training samples, one row per sample.
     * @return the fitted mixture.
     * @throws IllegalArgumentException if {@code k < 2} or {@code data} is empty.
     */
    public static MultivariateGaussianMixture fit(int k, double[][] data) {
        return MultivariateGaussianMixture.fit(k, data, false);
    }

    /**
     * Fits a mixture of {@code k} Gaussians.
     *
     * <p>Initial components are seeded in the manner of k-means++: the first centre is
     * a uniformly random sample, and each further centre is drawn with probability
     * proportional to its squared distance from the nearest centre chosen so far.
     * All initial components share the prior {@code 1/k} and the global sample
     * (co)variance of {@code data}. The seeded components are then passed to the
     * inherited {@code fit(data, components)} for estimation.
     *
     * @param k the number of components; must be at least 2.
     * @param data the training samples, one row per sample.
     * @param diagonal if true, components use a diagonal covariance (per-dimension
     *                 variance) instead of a full covariance matrix.
     * @return the fitted mixture.
     * @throws IllegalArgumentException if {@code k < 2} or {@code data} is empty.
     */
    public static MultivariateGaussianMixture fit(int k, double[][] data, boolean diagonal) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of components in the mixture.");
        }
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data.");
        }

        int n = data.length;
        double[] mu = MathEx.colMeans(data);
        double[] centroid = data[MathEx.randomInt(n)];

        // Only the representation selected by 'diagonal' is computed; the other stays null.
        double[] variance = null;
        Matrix cov = null;
        if (diagonal) {
            variance = diagonalVariance(data, mu);
        } else {
            cov = new Matrix(MathEx.cov(data, mu));
        }

        double priori = 1.0 / (double)k;
        MultivariateMixture.Component[] components = new MultivariateMixture.Component[k];
        components[0] = new MultivariateMixture.Component(priori, seedGaussian(centroid, variance, cov, diagonal));

        // D[j] holds the squared distance from sample j to its nearest centre so far.
        double[] D = new double[n];
        for (int i = 0; i < n; ++i) {
            D[i] = Double.MAX_VALUE;
        }

        for (int i = 1; i < k; ++i) {
            // Only the most recently chosen centre can shrink a sample's nearest distance.
            for (int j = 0; j < n; ++j) {
                double dist = MathEx.squaredDistance(data[j], centroid);
                if (dist < D[j]) {
                    D[j] = dist;
                }
            }

            // Walk the cumulative distance until it reaches a random cutoff. The walk is
            // clamped to the last sample so that rounding error or non-finite distances
            // can never push the index past the end of the data.
            double cutoff = MathEx.random() * MathEx.sum(D);
            double cost = 0.0;
            int index = 0;
            while (index < n - 1) {
                cost += D[index];
                if (cost >= cutoff) {
                    break;
                }
                ++index;
            }

            centroid = data[index];
            components[i] = new MultivariateMixture.Component(priori, seedGaussian(centroid, variance, cov, diagonal));
        }

        MultivariateExponentialFamilyMixture model = MultivariateGaussianMixture.fit(data, components);
        return new MultivariateGaussianMixture(model.L, data.length, model.components);
    }

    /**
     * Fits a mixture with full covariance matrices, choosing the number of
     * components automatically by BIC.
     *
     * @param data the training samples, one row per sample.
     * @return the selected mixture.
     * @throws IllegalArgumentException if fewer than 20 samples are supplied.
     */
    public static MultivariateGaussianMixture fit(double[][] data) {
        return MultivariateGaussianMixture.fit(data, false);
    }

    /**
     * Fits a mixture, choosing the number of components automatically by BIC.
     *
     * <p>Starts from a single Gaussian and tries {@code k = 2, 3, ...} while
     * {@code k < data.length / 20}, keeping a candidate only when its BIC is strictly
     * larger than the best seen so far and stopping at the first candidate that fails
     * to improve.
     *
     * @param data the training samples, one row per sample.
     * @param diagonal if true, components use a diagonal covariance instead of a full
     *                 covariance matrix.
     * @return the mixture with the best BIC found.
     * @throws IllegalArgumentException if fewer than 20 samples are supplied.
     */
    public static MultivariateGaussianMixture fit(double[][] data, boolean diagonal) {
        if (data.length < 20) {
            throw new IllegalArgumentException("Too few samples.");
        }

        MultivariateGaussianMixture mixture = new MultivariateGaussianMixture(new MultivariateMixture.Component(1.0, MultivariateGaussianDistribution.fit(data, diagonal)));
        double bic = mixture.bic(data);
        logger.info(String.format("The BIC of %s = %.4f", mixture, bic));

        for (int k = 2; k < data.length / 20; ++k) {
            // Propagate 'diagonal' so every candidate uses the same covariance structure
            // as the single-Gaussian baseline it is compared against.
            MultivariateGaussianMixture model = MultivariateGaussianMixture.fit(k, data, diagonal);
            logger.info(String.format("The BIC of %s = %.4f", model, model.bic));
            if (model.bic <= bic) {
                break;
            }
            mixture = model;
            bic = model.bic;
        }
        return mixture;
    }

    /**
     * Splits the component with the largest scatter into two.
     *
     * <p>Both halves receive half of the original prior and the original covariance;
     * their means are the original mean shifted by plus and minus half of the
     * per-dimension standard deviation (square root of the covariance diagonal).
     * This helper is not called from within this class.
     *
     * @param components the current components; all must be Gaussian.
     * @return a new array with one more component than the input.
     * @throws IllegalArgumentException if no component can be selected for splitting.
     */
    private static MultivariateMixture.Component[] split(MultivariateMixture.Component[] components) {
        int k = components.length;

        int index = -1;
        double maxSigma = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < k; ++i) {
            double sigma = ((MultivariateGaussianDistribution)components[i].distribution).scatter();
            if (sigma > maxSigma) {
                maxSigma = sigma;
                index = i;
            }
        }
        // Empty input, or every scatter is NaN / negative infinity.
        if (index < 0) {
            throw new IllegalArgumentException("No component available to split.");
        }

        MultivariateMixture.Component component = components[index];
        double priori = component.priori / 2.0;
        Matrix delta = component.distribution.cov();
        double[] mu = component.distribution.mean();

        double[] mu1 = new double[mu.length];
        double[] mu2 = new double[mu.length];
        for (int i = 0; i < mu.length; ++i) {
            double shift = Math.sqrt(delta.get(i, i)) / 2.0;
            mu1[i] = mu[i] + shift;
            mu2[i] = mu[i] - shift;
        }

        MultivariateMixture.Component[] mixture = new MultivariateMixture.Component[k + 1];
        System.arraycopy(components, 0, mixture, 0, k);
        mixture[index] = new MultivariateMixture.Component(priori, new MultivariateGaussianDistribution(mu1, delta));
        mixture[k] = new MultivariateMixture.Component(priori, new MultivariateGaussianDistribution(mu2, delta));
        return mixture;
    }

    /** Computes the per-dimension sample variance (divisor {@code n - 1}) of {@code data} about {@code mu}. */
    private static double[] diagonalVariance(double[][] data, double[] mu) {
        int d = data[0].length;
        double[] variance = new double[d];
        for (double[] x : data) {
            for (int j = 0; j < d; ++j) {
                variance[j] += (x[j] - mu[j]) * (x[j] - mu[j]);
            }
        }
        int n1 = data.length - 1;
        for (int j = 0; j < d; ++j) {
            variance[j] /= (double)n1;
        }
        return variance;
    }

    /** Builds an initial Gaussian at {@code centroid} using the variance vector or the covariance matrix, as selected by {@code diagonal}. */
    private static MultivariateGaussianDistribution seedGaussian(double[] centroid, double[] variance, Matrix cov, boolean diagonal) {
        if (diagonal) {
            return new MultivariateGaussianDistribution(centroid, variance);
        }
        return new MultivariateGaussianDistribution(centroid, cov);
    }
}

