/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.naivebayes.gaussian;

import java.util.ArrayList;
import java.util.Arrays;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesSumsHolder;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

/**
 * Trainer for {@link GaussianNaiveBayesModel}.
 *
 * <p>Training makes a single pass over the upstream data and collects, for every label, the per-feature sum, the
 * per-feature sum of squares and the number of rows. Per-label means and variances are then derived from these
 * sums. Class probabilities are either empirical (row count of the label divided by the total row count), uniform
 * (see {@link #withEquiprobableClasses()}) or supplied by the caller (see {@link #setPriorProbabilities(double[])}).
 */
public class GaussianNaiveBayesTrainer
extends SingleLabelDatasetTrainer<GaussianNaiveBayesModel> {
    /** Caller-supplied class probabilities, indexed by label position in ascending label order; {@code null} if unset. */
    private double[] priorProbabilities;

    /** When {@code true}, every class gets the probability {@code 1 / labelCount}. */
    private boolean equiprobableClasses;

    /**
     * Trains a new model by delegating to {@link #updateModel} with no previous model.
     *
     * @param datasetBuilder Builder of the training dataset.
     * @param extractor Preprocessor turning an upstream key/value pair into a labeled vector.
     * @return Trained model.
     */
    @Override
    public <K, V> GaussianNaiveBayesModel fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
        return this.updateModel((GaussianNaiveBayesModel)null, datasetBuilder, extractor);
    }

    /**
     * Reports whether the given model can be updated by this trainer.
     *
     * @param mdl Model to check.
     * @return Always {@code true}.
     */
    @Override
    public boolean isUpdateable(GaussianNaiveBayesModel mdl) {
        return true;
    }

    /**
     * Sets the learning environment builder via the superclass and narrows the returned type to this trainer.
     *
     * @param envBuilder Learning environment builder.
     * @return The trainer returned by the superclass, cast to {@link GaussianNaiveBayesTrainer}.
     */
    public GaussianNaiveBayesTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
        return (GaussianNaiveBayesTrainer)super.withEnvironmentBuilder(envBuilder);
    }

    /**
     * Trains a model on the given data and, if {@code mdl} carries accumulated sums, merges them in so that the
     * result reflects both the previous and the new data.
     *
     * <p>If the new data yields no sums at all (the aggregated result is {@code null}), the sums of the previous
     * model are used on their own. If there is no previous model either, training fails.
     *
     * @param mdl Previously trained model to update, or {@code null} to train from scratch.
     * @param datasetBuilder Builder of the training dataset; must not be {@code null}.
     * @param extractor Preprocessor turning an upstream key/value pair into a labeled vector.
     * @return Model built from the merged sums.
     * @throws RuntimeException Wrapping any exception raised while building, processing or closing the dataset,
     *      including the {@link IllegalStateException} raised when there is neither new data nor a previous model.
     */
    @Override
    protected <K, V> GaussianNaiveBayesModel updateModel(GaussianNaiveBayesModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
        assert (datasetBuilder != null);
        try (Dataset<EmptyContext, GaussianNaiveBayesSumsHolder> dataset = datasetBuilder.build(this.envBuilder, (env, upstream, upstreamSize) -> new EmptyContext(), (env, upstream, upstreamSize, ctx) -> {
            GaussianNaiveBayesSumsHolder res = new GaussianNaiveBayesSumsHolder();
            while (upstream.hasNext()) {
                UpstreamEntry<K, V> entry = upstream.next();
                // The label is read as Double below, so the row is viewed as LabeledVector<Double>.
                @SuppressWarnings("unchecked")
                LabeledVector<Double> row = (LabeledVector<Double>)extractor.apply(entry.getKey(), entry.getValue());
                accumulate(res, row);
            }
            return res;
        }, this.learningEnvironment());){
            GaussianNaiveBayesSumsHolder sumsHolder = dataset.compute(t -> t, (a, b) -> {
                if (a == null) {
                    return b;
                }
                if (b == null) {
                    return a;
                }
                return a.merge(b);
            });
            GaussianNaiveBayesSumsHolder prevSums = mdl == null ? null : mdl.getSumsHolder();
            if (sumsHolder == null) {
                // The reducer tolerates null partial results, so the overall result may be null as well.
                // Fall back to what the previous model already knows instead of dereferencing null.
                sumsHolder = prevSums;
            } else if (prevSums != null) {
                sumsHolder = sumsHolder.merge(prevSums);
            }
            if (sumsHolder == null) {
                throw new IllegalStateException("Cannot train a Gaussian Naive Bayes model: the dataset is empty and there is no previous model to update");
            }
            return this.buildModel(sumsHolder);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Adds one labeled row to the running per-label sums, squared sums and row counter.
     *
     * <p>Kept {@code static} so that the data-builder lambda calling it does not capture the trainer instance.
     *
     * @param acc Accumulator to update in place.
     * @param row Row whose features and label are added.
     */
    private static void accumulate(GaussianNaiveBayesSumsHolder acc, LabeledVector<Double> row) {
        Double label = row.label();
        int featureCnt = row.features().size();

        // Freshly allocated double arrays are already zero-filled.
        double[] sums = acc.featureSumsPerLbl.get(label);
        if (sums == null) {
            sums = new double[featureCnt];
            acc.featureSumsPerLbl.put(label, sums);
        }
        double[] sqSums = acc.featureSquaredSumsPerLbl.get(label);
        if (sqSums == null) {
            sqSums = new double[featureCnt];
            acc.featureSquaredSumsPerLbl.put(label, sqSums);
        }

        Integer seen = acc.featureCountersPerLbl.get(label);
        acc.featureCountersPerLbl.put(label, seen == null ? 1 : seen + 1);

        for (int j = 0; j < featureCnt; j++) {
            double x = row.features().get(j);
            sums[j] += x;
            sqSums[j] += x * x;
        }
    }

    /**
     * Derives means, variances and class probabilities from the accumulated sums and wraps them into a model.
     *
     * <p>Labels are processed in ascending order; row {@code i} of every produced array belongs to the
     * {@code i}-th smallest label.
     *
     * @param sumsHolder Accumulated per-label sums; must not be {@code null}.
     * @return Model holding the derived statistics together with {@code sumsHolder}.
     */
    private GaussianNaiveBayesModel buildModel(GaussianNaiveBayesSumsHolder sumsHolder) {
        ArrayList<Double> sortedLabels = new ArrayList<Double>(sumsHolder.featureCountersPerLbl.keySet());
        sortedLabels.sort(Double::compareTo);
        assert (!sortedLabels.isEmpty()) : "The dataset should contain at least one feature";

        int labelCount = sortedLabels.size();
        int featureCount = sumsHolder.featureSumsPerLbl.get(sortedLabels.get(0)).length;
        double[][] means = new double[labelCount][featureCount];
        double[][] variances = new double[labelCount][featureCount];
        double[] classProbabilities = new double[labelCount];
        double[] labels = new double[labelCount];

        // Summed as long: the total across labels may exceed Integer.MAX_VALUE even if each counter fits.
        long datasetSize = sumsHolder.featureCountersPerLbl.values().stream().mapToLong(Integer::longValue).sum();

        int lbl = 0;
        for (Double label : sortedLabels) {
            int count = sumsHolder.featureCountersPerLbl.get(label);
            double[] sum = sumsHolder.featureSumsPerLbl.get(label);
            double[] sqSum = sumsHolder.featureSquaredSumsPerLbl.get(label);
            for (int i = 0; i < featureCount; ++i) {
                means[lbl][i] = sum[i] / (double)count;
                // Floating-point cancellation can push this difference slightly below zero; a variance never is.
                variances[lbl][i] = Math.max(0.0, (sqSum[i] - sum[i] * sum[i] / (double)count) / (double)count);
            }
            if (this.equiprobableClasses) {
                classProbabilities[lbl] = 1.0 / (double)labelCount;
            } else if (this.priorProbabilities != null) {
                assert (classProbabilities.length == this.priorProbabilities.length);
                classProbabilities[lbl] = this.priorProbabilities[lbl];
            } else {
                classProbabilities[lbl] = (double)count / (double)datasetSize;
            }
            labels[lbl] = label;
            ++lbl;
        }
        return new GaussianNaiveBayesModel(means, variances, classProbabilities, labels, sumsHolder);
    }

    /**
     * Makes all classes equally probable. Any previously set prior probabilities are discarded.
     *
     * @return This trainer.
     */
    public GaussianNaiveBayesTrainer withEquiprobableClasses() {
        this.resetSettings();
        this.equiprobableClasses = true;
        return this;
    }

    /**
     * Sets explicit class probabilities. The equiprobable-classes flag is cleared. The array is copied, so later
     * changes to the argument do not affect the trainer.
     *
     * @param priorProbabilities Class probabilities; element {@code i} is used for the {@code i}-th label in
     *      ascending label order.
     * @return This trainer.
     * @throws NullPointerException If {@code priorProbabilities} is {@code null}.
     */
    public GaussianNaiveBayesTrainer setPriorProbabilities(double[] priorProbabilities) {
        this.resetSettings();
        this.priorProbabilities = priorProbabilities.clone();
        return this;
    }

    /**
     * Clears both the equiprobable-classes flag and the prior probabilities, so that class probabilities are
     * computed from the row counts of the training data.
     *
     * @return This trainer.
     */
    public GaussianNaiveBayesTrainer resetSettings() {
        this.equiprobableClasses = false;
        this.priorProbabilities = null;
        return this;
    }
}

