/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.clustering.gmm;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.Stream;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.clustering.gmm.CovarianceMatricesAggregator;
import org.apache.ignite.ml.clustering.gmm.GmmModel;
import org.apache.ignite.ml.clustering.gmm.GmmPartitionData;
import org.apache.ignite.ml.clustering.gmm.MeanWithClusterProbAggregator;
import org.apache.ignite.ml.clustering.gmm.NewComponentStatisticsAggregator;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.math.exceptions.math.SingularMatrixException;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.matrix.Matrix;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.structures.DatasetRow;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.jetbrains.annotations.NotNull;

/**
 * Trainer that fits a {@link GmmModel} (a mixture of {@link MultivariateGaussianDistribution} components)
 * to a dataset.
 *
 * <p>Training starts from an initial set of component means (either supplied through
 * {@link #withInitialMeans(List)} or sampled from the dataset), repeatedly re-estimates the component
 * statistics until the means stop moving or the iteration limit is hit, and may add new components
 * one at a time until {@code maxCountOfClusters} is reached. Components whose probability does not
 * exceed {@code minClusterProbability} are dropped from the final model.
 *
 * <p>NOTE: a trainer instance is stateful. Fitting mutates {@code countOfComponents} and
 * {@code initialMeans}, so the same instance should not be shared between concurrent trainings.
 */
public class GmmTrainer
extends SingleLabelDatasetTrainer<GmmModel> {
    /** Minimal Euclidean distance between old and new component means that still counts as "not converged". */
    private double eps = 0.001;

    /** Current (initial) number of mixture components; may grow while fitting. */
    private int countOfComponents = 2;

    /** Iteration limit for one run of the means/covariances update loop. */
    private int maxCountOfIterations = 10;

    /** Means used to initialize the components; {@code null} means "pick them randomly from the dataset". */
    private Vector[] initialMeans;

    /** How many times initialization may fail on a singular covariance matrix before giving up. */
    private int maxCountOfInitTries = 3;

    /** Upper bound for the number of components; also passed to the partition data builder. */
    private int maxCountOfClusters = 2;

    /** Threshold passed to {@link NewComponentStatisticsAggregator} when looking for a new component. */
    private double maxLikelihoodDivergence = 5.0;

    /** Minimal number of rows the new-component aggregator must report before a component is added. */
    private double minElementsForNewCluster = 300.0;

    /** Components whose probability does not exceed this value are removed from the resulting model. */
    private double minClusterProbability = 0.05;

    /**
     * Creates a trainer with default settings.
     */
    public GmmTrainer() {
    }

    /**
     * Creates a trainer with the given number of components and iteration limit.
     *
     * @param countOfComponents Initial count of components.
     * @param maxCountOfIterations Max count of iterations.
     */
    public GmmTrainer(int countOfComponents, int maxCountOfIterations) {
        this.countOfComponents = countOfComponents;
        this.maxCountOfIterations = maxCountOfIterations;
        // Keep the invariant maxCountOfClusters >= countOfComponents, exactly as the with* setters do.
        this.maxCountOfClusters = Math.max(this.maxCountOfClusters, countOfComponents);
    }

    /**
     * Creates a trainer with the given number of components.
     *
     * @param countOfComponents Initial count of components.
     */
    public GmmTrainer(int countOfComponents) {
        this.countOfComponents = countOfComponents;
        // Keep the invariant maxCountOfClusters >= countOfComponents, exactly as the with* setters do.
        this.maxCountOfClusters = Math.max(this.maxCountOfClusters, countOfComponents);
    }

    /**
     * Trains a new model; delegates to {@link #updateModel(GmmModel, DatasetBuilder, Preprocessor)}
     * with no previous model.
     *
     * @param datasetBuilder Dataset builder.
     * @param extractor Preprocessor that extracts features from upstream entries.
     * @return Trained model.
     */
    @Override
    public <K, V> GmmModel fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
        return updateModel(null, datasetBuilder, extractor);
    }

    /**
     * Creates a per-partition mapper that picks up to {@code n} feature vectors from the partition.
     * The picked vectors are wrapped into a one-element outer array so that results of several
     * partitions can be concatenated by {@link #selectNRandomXsReducer(Vector[][], Vector[][])}.
     *
     * @param n Number of vectors to pick from each partition.
     * @return Mapper function.
     */
    private static IgniteBiFunction<GmmPartitionData, LearningEnvironment, Vector[][]> selectNRandomXsMapper(int n) {
        return (data, env) -> {
            Vector[] picked;
            if (data.size() <= n) {
                // Not more than n rows in the partition: take all of them.
                picked = data.getAllXs().stream().map(DatasetRow::features).toArray(Vector[]::new);
            }
            else {
                // Draw distinct random row indices until n rows are collected.
                picked = env.randomNumbersGenerator()
                    .ints(0, data.size())
                    .distinct()
                    .mapToObj(data::getX)
                    .limit(n)
                    .toArray(Vector[]::new);
            }
            return new Vector[][] {picked};
        };
    }

    /**
     * Concatenates the per-partition selections produced by {@link #selectNRandomXsMapper(int)}.
     *
     * @param l Left selection, may be {@code null}.
     * @param r Right selection, may be {@code null}.
     * @return The non-null argument if only one is present, otherwise a new array holding both.
     */
    private static Vector[][] selectNRandomXsReducer(Vector[][] l, Vector[][] r) {
        A.ensure(l != null || r != null, "l != null || r != null");
        if (l == null) {
            return r;
        }
        if (r == null) {
            return l;
        }
        Vector[][] res = new Vector[l.length + r.length][];
        System.arraycopy(l, 0, res, 0, l.length);
        System.arraycopy(r, 0, res, l.length, r.length);
        return res;
    }

    /**
     * Sets the initial number of components. Resets previously configured initial means and raises
     * the maximum count of clusters if it is smaller than the new value.
     *
     * @param numberOfComponents Number of components, must be positive.
     * @return This trainer.
     */
    public GmmTrainer withInitialCountOfComponents(int numberOfComponents) {
        A.ensure(numberOfComponents > 0, "Number of components in GMM cannot equal 0");
        this.countOfComponents = numberOfComponents;
        this.initialMeans = null;
        if (this.countOfComponents > this.maxCountOfClusters) {
            this.maxCountOfClusters = this.countOfComponents;
        }
        return this;
    }

    /**
     * Sets the initial means; the number of components becomes the size of the list and the maximum
     * count of clusters is raised if it is smaller than that.
     *
     * @param means Non-empty list of initial component means.
     * @return This trainer.
     */
    public GmmTrainer withInitialMeans(List<Vector> means) {
        A.notEmpty(means, "GMM should start with non empty initial components list");
        this.initialMeans = means.toArray(new Vector[means.size()]);
        this.countOfComponents = means.size();
        if (this.countOfComponents > this.maxCountOfClusters) {
            this.maxCountOfClusters = this.countOfComponents;
        }
        return this;
    }

    /**
     * Sets the iteration limit of the update loop.
     *
     * @param maxCountOfIterations Max count of iterations, must be positive.
     * @return This trainer.
     */
    public GmmTrainer withMaxCountIterations(int maxCountOfIterations) {
        A.ensure(maxCountOfIterations > 0, "Max count iterations cannot be less or equal zero or negative");
        this.maxCountOfIterations = maxCountOfIterations;
        return this;
    }

    /**
     * Sets the convergence threshold (minimal distance between component means of two consecutive iterations).
     *
     * @param eps Threshold, must be strictly between 0.0 and 1.0.
     * @return This trainer.
     */
    public GmmTrainer withEps(double eps) {
        A.ensure(eps > 0.0 && eps < 1.0, "Min divergence beween iterations should be between 0.0 and 1.0");
        this.eps = eps;
        return this;
    }

    /**
     * Sets how many initialization attempts are allowed before training fails.
     *
     * @param maxCountOfInitTries Max count of tries, must be positive.
     * @return This trainer.
     */
    public GmmTrainer withMaxCountOfInitTries(int maxCountOfInitTries) {
        A.ensure(maxCountOfInitTries > 0, "Max initialization count should be great than zero.");
        this.maxCountOfInitTries = maxCountOfInitTries;
        return this;
    }

    /**
     * Sets the upper bound for the number of components.
     *
     * @param maxCountOfClusters Max count of clusters, must not be less than the current count of components.
     * @return This trainer.
     */
    public GmmTrainer withMaxCountOfClusters(int maxCountOfClusters) {
        A.ensure(maxCountOfClusters >= this.countOfComponents, "Max count of components should be greater than initial count of components or equal to it");
        this.maxCountOfClusters = maxCountOfClusters;
        return this;
    }

    /**
     * Sets the likelihood divergence threshold used when searching for a new component.
     *
     * @param maxLikelihoodDivergence Threshold, must be positive.
     * @return This trainer.
     */
    public GmmTrainer withMaxLikelihoodDivergence(double maxLikelihoodDivergence) {
        A.ensure(maxLikelihoodDivergence > 0.0, "Max likelihood divergence should be > 0");
        this.maxLikelihoodDivergence = maxLikelihoodDivergence;
        return this;
    }

    /**
     * Fits a model on the dataset: initializes components, runs the update loop and, while allowed,
     * tries to add one more component and repeat.
     *
     * @param dataset Dataset.
     * @return Fitted and filtered model, or empty if initialization produced no covariances.
     */
    private Optional<GmmModel> fit(Dataset<EmptyContext, GmmPartitionData> dataset) {
        return init(dataset).map(model -> {
            GmmModel currentModel = model;
            while (true) {
                UpdateResult updateResult = updateModel(dataset, currentModel);
                currentModel = updateResult.model;

                // Stop growing when the component limit is reached or some component became too unlikely.
                double minCompProb = currentModel.componentsProbs().minElement().get();
                if (this.countOfComponents >= this.maxCountOfClusters || minCompProb < this.minClusterProbability) {
                    break;
                }

                double maxXProb = updateResult.maxProbInDataset;
                NewComponentStatisticsAggregator newMeanAdder = NewComponentStatisticsAggregator.computeNewMean(
                    dataset, maxXProb, this.maxLikelihoodDivergence, currentModel);
                Vector newMean = newMeanAdder.mean();

                // Too few rows would belong to the new component: keep the current model.
                if ((double)newMeanAdder.rowCountForNewCluster() < this.minElementsForNewCluster) {
                    break;
                }

                // Re-initialize with the current means plus the newly found one.
                ++this.countOfComponents;
                Vector[] newMeans = new Vector[this.countOfComponents];
                for (int i = 0; i < currentModel.countOfComponents(); ++i) {
                    newMeans[i] = currentModel.distributions().get(i).mean();
                }
                newMeans[this.countOfComponents - 1] = newMean;
                this.initialMeans = newMeans;

                Optional<GmmModel> newModelOpt = init(dataset);
                if (!newModelOpt.isPresent()) {
                    break;
                }
                currentModel = newModelOpt.get();
            }
            return filterModel(currentModel);
        });
    }

    /**
     * Sets the minimal number of rows required to create a new component.
     *
     * @param minElementsForNewCluster Min elements for a new cluster, must be positive.
     * @return This trainer.
     */
    public GmmTrainer withMinElementsForNewCluster(int minElementsForNewCluster) {
        A.ensure(minElementsForNewCluster > 0, "Min elements for new cluster should be > 0");
        this.minElementsForNewCluster = minElementsForNewCluster;
        return this;
    }

    /**
     * Sets the minimal component probability; components that do not exceed it are removed from the
     * final model and a component falling below it stops the search for new components.
     * The value is stored as is, no range check is performed.
     *
     * @param minClusterProbability Min cluster probability.
     * @return This trainer.
     */
    public GmmTrainer withMinClusterProbability(double minClusterProbability) {
        this.minClusterProbability = minClusterProbability;
        return this;
    }

    /**
     * Builds a copy of the model that keeps only the components whose probability is strictly greater
     * than {@code minClusterProbability}. Kept probabilities are copied unchanged.
     *
     * @param model Model to filter.
     * @return Filtered model.
     */
    private GmmModel filterModel(GmmModel model) {
        List<Double> componentProbs = new ArrayList<>();
        List<MultivariateGaussianDistribution> distributions = new ArrayList<>();
        Vector originalComponentProbs = model.componentsProbs();
        List<MultivariateGaussianDistribution> originalDistr = model.distributions();
        for (int i = 0; i < model.countOfComponents(); ++i) {
            double prob = originalComponentProbs.get(i);
            if (prob > this.minClusterProbability) {
                componentProbs.add(prob);
                distributions.add(originalDistr.get(i));
            }
        }
        return new GmmModel(VectorUtils.of(componentProbs.toArray(new Double[0])), distributions);
    }

    /**
     * Runs the update loop: recomputes cluster probabilities, means and covariances from the dataset
     * and rebuilds the model until the means converge, the iteration counter exceeds
     * {@code maxCountOfIterations}, or a component cannot be built from the computed covariances.
     *
     * @param dataset Dataset.
     * @param model Model to start from.
     * @return Last successfully built model together with the last value returned by
     *      {@link GmmPartitionData#updatePcxiAndComputeLikelihood} (negative infinity if no iteration succeeded).
     */
    @NotNull
    private UpdateResult updateModel(Dataset<EmptyContext, GmmPartitionData> dataset, GmmModel model) {
        boolean converged = false;
        int countOfIterations = 0;
        double maxProbInDataset = Double.NEGATIVE_INFINITY;
        while (!converged) {
            MeanWithClusterProbAggregator.AggregatedStats stats = MeanWithClusterProbAggregator.aggreateStats(dataset, this.countOfComponents);
            Vector clusterProbs = stats.clusterProbabilities();
            Vector[] newMeans = stats.means().toArray(new Vector[this.countOfComponents]);
            A.ensure(newMeans.length == model.countOfComponents(), "newMeans.size() == count of components");
            A.ensure(newMeans[0].size() == this.initialMeans[0].size(), "newMeans[0].size() == initialMeans[0].size()");
            List<Matrix> newCovs = CovarianceMatricesAggregator.computeCovariances(dataset, clusterProbs, newMeans);
            try {
                List<MultivariateGaussianDistribution> components = buildComponents(newMeans, newCovs);
                GmmModel newModel = new GmmModel(clusterProbs, components);
                // The counter is only advanced when the means have not converged yet; the loop stops
                // once it becomes greater than the configured limit.
                converged = isConverged(model, newModel) || ++countOfIterations > this.maxCountOfIterations;
                model = newModel;
                maxProbInDataset = GmmPartitionData.updatePcxiAndComputeLikelihood(dataset, clusterProbs, components);
            }
            catch (IllegalArgumentException | SingularMatrixException e) {
                // Keep the model from the previous iteration and stop.
                String msg = "Cannot construct non-singular covariance matrix by data. Try to select other initial means or other model trainer. Iterations will stop.";
                this.environment.logger().log(MLLogger.VerboseLevel.HIGH, msg, new Object[0]);
                converged = true;
            }
        }
        return new UpdateResult(model, maxProbInDataset);
    }

    /**
     * Builds the initial model. If no initial means are configured, they are chosen randomly from
     * vectors sampled out of the dataset partitions. When a component cannot be constructed the
     * initial means are discarded and the attempt is repeated, up to {@code maxCountOfInitTries} times.
     *
     * @param dataset Dataset.
     * @return Initial model with equal component probabilities, or empty if no covariances were computed.
     * @throws RuntimeException If all initialization attempts failed.
     */
    private Optional<GmmModel> init(Dataset<EmptyContext, GmmPartitionData> dataset) {
        int cntOfTries = 0;
        while (true) {
            try {
                if (this.initialMeans == null) {
                    // Flatten the per-partition samples; sorting before the shuffle makes the result
                    // depend only on the random generator, not on the order partitions were reduced in.
                    List<Vector> randomMeansSets = Stream.of(dataset.compute(
                            GmmTrainer.selectNRandomXsMapper(this.countOfComponents),
                            GmmTrainer::selectNRandomXsReducer))
                        .flatMap(Stream::of)
                        .sorted(Comparator.comparingDouble(Vector::getLengthSquared))
                        .collect(Collectors.toList());
                    Collections.shuffle(randomMeansSets, this.environment.randomNumbersGenerator());
                    A.ensure(randomMeansSets.size() >= this.countOfComponents, "There is not enough data in dataset for select N random means");
                    this.initialMeans = randomMeansSets.subList(0, this.countOfComponents).toArray(new Vector[this.countOfComponents]);
                }

                dataset.compute(data -> GmmPartitionData.estimateLikelihoodClusters(data, this.initialMeans));

                List<Matrix> initialCovs = CovarianceMatricesAggregator.computeCovariances(
                    dataset,
                    VectorUtils.fill(1.0 / (double)this.countOfComponents, this.countOfComponents),
                    this.initialMeans);
                if (initialCovs.isEmpty()) {
                    return Optional.empty();
                }

                List<MultivariateGaussianDistribution> distributions = new ArrayList<>();
                for (int i = 0; i < this.countOfComponents; ++i) {
                    distributions.add(new MultivariateGaussianDistribution(this.initialMeans[i], initialCovs.get(i)));
                }

                // Every component starts with the same probability 1 / countOfComponents.
                double[] uniformProbs = DoubleStream.generate(() -> 1.0 / (double)this.countOfComponents)
                    .limit(this.countOfComponents)
                    .toArray();
                return Optional.of(new GmmModel(VectorUtils.of(uniformProbs), distributions));
            }
            catch (IllegalArgumentException | SingularMatrixException e) {
                String msg = "Cannot construct non-singular covariance matrix by data. Try to select other initial means or other model trainer [number of tries = " + cntOfTries + "]";
                this.environment.logger().log(MLLogger.VerboseLevel.HIGH, msg, new Object[0]);
                // Forget the failed means so that the next attempt samples new ones.
                this.initialMeans = null;
                if (++cntOfTries >= this.maxCountOfInitTries) {
                    throw new RuntimeException(msg, e);
                }
            }
        }
    }

    /**
     * A model can be updated by this trainer only if it has the same number of components as the trainer.
     *
     * @param mdl Model.
     * @return {@code true} if the component counts match.
     */
    @Override
    public boolean isUpdateable(GmmModel mdl) {
        return mdl.countOfComponents() == this.countOfComponents;
    }

    /**
     * Builds the dataset and fits a model on it. If a previous model is given, its component means
     * replace the configured initial means.
     *
     * @param mdl Previous model or {@code null}.
     * @param datasetBuilder Dataset builder.
     * @param extractor Preprocessor that extracts features from upstream entries.
     * @return Fitted model, or {@code mdl} if fitting produced nothing and {@code mdl} is not {@code null}.
     * @throws RuntimeException Wrapping any exception raised while building, fitting or closing the
     *      dataset, including the {@link IllegalArgumentException} raised when nothing could be learned
     *      and there is no previous model.
     */
    @Override
    protected <K, V> GmmModel updateModel(GmmModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
        try (Dataset<EmptyContext, GmmPartitionData> dataset = datasetBuilder.build(
            this.envBuilder,
            new EmptyContextBuilder<>(),
            new GmmPartitionData.Builder<K, V>(extractor, this.maxCountOfClusters),
            this.learningEnvironment())) {
            if (mdl != null) {
                if (this.initialMeans != null) {
                    this.environment.logger().log(MLLogger.VerboseLevel.HIGH, "Initial means will be replaced by model from update", new Object[0]);
                }
                this.initialMeans = mdl.distributions().stream()
                    .map(MultivariateGaussianDistribution::mean)
                    .toArray(Vector[]::new);
            }

            Optional<GmmModel> model = fit(dataset);
            if (model.isPresent()) {
                return model.get();
            }
            if (mdl != null) {
                return mdl;
            }
            throw new IllegalArgumentException("Cannot learn model on empty dataset.");
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Pairs every mean with the covariance matrix at the same index and creates a distribution for each pair.
     *
     * @param means Component means.
     * @param covs Component covariance matrices, same size as {@code means}.
     * @return Distributions in the order of {@code means}.
     */
    private List<MultivariateGaussianDistribution> buildComponents(Vector[] means, List<Matrix> covs) {
        A.ensure(means.length == covs.size(), "means.size() == covs.size()");
        List<MultivariateGaussianDistribution> res = new ArrayList<>();
        for (int i = 0; i < means.length; ++i) {
            res.add(new MultivariateGaussianDistribution(means[i], covs.get(i)));
        }
        return res;
    }

    /**
     * Checks whether every component mean moved by less than {@code eps} (Euclidean distance)
     * between the two models.
     *
     * @param oldModel Model of the previous iteration.
     * @param newModel Model of the current iteration.
     * @return {@code true} if no mean moved by {@code eps} or more.
     */
    private boolean isConverged(GmmModel oldModel, GmmModel newModel) {
        A.ensure(oldModel.countOfComponents() == newModel.countOfComponents(), "oldModel.countOfComponents() == newModel.countOfComponents()");
        for (int i = 0; i < oldModel.countOfComponents(); ++i) {
            MultivariateGaussianDistribution d1 = oldModel.distributions().get(i);
            MultivariateGaussianDistribution d2 = newModel.distributions().get(i);
            if (Math.sqrt(d1.mean().getDistanceSquared(d2.mean())) >= this.eps) {
                return false;
            }
        }
        return true;
    }

    /**
     * Result of one run of the update loop: the resulting model and the likelihood value computed
     * for it on the dataset.
     */
    private static class UpdateResult {
        /** Model produced by the update loop. */
        private final GmmModel model;

        /** Value returned by the last likelihood computation of the update loop. */
        private final double maxProbInDataset;

        /**
         * @param model Model.
         * @param maxProbInDataset Likelihood value computed for the model.
         */
        public UpdateResult(GmmModel model, double maxProbInDataset) {
            this.model = model;
            this.maxProbInDataset = maxProbInDataset;
        }
    }
}

