/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.naivebayes.discrete;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Optional;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesSumsHolder;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;

public class DiscreteNaiveBayesTrainer
extends SingleLabelDatasetTrainer<DiscreteNaiveBayesModel> {
    /** Absolute tolerance used when comparing bucket thresholds of this trainer and an existing model. */
    private static final double PRECISION = 1.0E-10;

    /** Explicit class prior probabilities, indexed like the sorted labels; {@code null} when not set. */
    private double[] priorProbabilities;

    /** When {@code true}, every label gets the same class probability {@code 1 / labelCount}. */
    private boolean equiprobableClasses;

    /**
     * Per-feature bucket boundaries: {@code bucketThresholds[f]} splits feature {@code f} into
     * {@code bucketThresholds[f].length + 1} buckets. Presumably each row is sorted ascending,
     * since {@link #toBucketNumber} returns the first threshold the value is below — TODO confirm.
     */
    private double[][] bucketThresholds;

    @Override
    public <K, V> DiscreteNaiveBayesModel fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
        return this.updateModel((DiscreteNaiveBayesModel)null, datasetBuilder, extractor);
    }

    /**
     * Checks whether {@code mdl} was trained with the same bucket thresholds as this trainer,
     * i.e. whether its accumulated bucket counters can be merged with freshly computed ones.
     *
     * @param mdl previously trained model (must not be {@code null}).
     * @return {@code true} if both threshold matrices have identical shape and all entries
     *     differ by at most {@link #PRECISION}.
     */
    @Override
    public boolean isUpdateable(DiscreteNaiveBayesModel mdl) {
        double[][] mdlThresholds = mdl.getBucketThresholds();
        if (mdlThresholds.length != this.bucketThresholds.length) {
            return false;
        }
        for (int i = 0; i < this.bucketThresholds.length; ++i) {
            double[] expected = this.bucketThresholds[i];
            double[] actual = mdlThresholds[i];
            // Different threshold counts mean different bucket layouts for this feature.
            if (actual.length != expected.length) {
                return false;
            }
            for (int j = 0; j < expected.length; ++j) {
                if (Math.abs(actual[j] - expected[j]) > PRECISION) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Trains (or incrementally updates) a discrete naive Bayes model.
     *
     * <p>Each partition counts, per label, how many rows fall into each bucket of each feature.
     * The partition results are merged; if {@code mdl} is compatible (same thresholds and the same
     * number of features), its stored counters are merged in as well. The counters are then
     * converted into conditional probabilities.
     *
     * @param mdl previous model to update, or {@code null} to train from scratch.
     * @param datasetBuilder builder of the training dataset.
     * @param extractor maps upstream entries to labeled vectors; labels are expected to be {@code Double}.
     * @return the trained model.
     * @throws RuntimeException wrapping any failure during dataset building or computation.
     */
    @Override
    protected <K, V> DiscreteNaiveBayesModel updateModel(DiscreteNaiveBayesModel mdl, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
        try (Dataset<EmptyContext, DiscreteNaiveBayesSumsHolder> dataset = datasetBuilder.build(
            this.envBuilder,
            (env, upstream, upstreamSize) -> new EmptyContext(),
            (env, upstream, upstreamSize, ctx) -> {
                DiscreteNaiveBayesSumsHolder res = new DiscreteNaiveBayesSumsHolder();
                while (upstream.hasNext()) {
                    UpstreamEntry<K, V> entity = upstream.next();
                    LabeledVector lv = extractor.apply(entity.getKey(), entity.getValue());
                    Double lb = (Double)lv.label();
                    int size = lv.features().size();

                    // Counters for a label are allocated the first time that label is seen.
                    long[][] valuesInBucket = res.valuesInBucketPerLbl.computeIfAbsent(lb, ignored -> this.newBucketCounters(size));
                    res.featureCountersPerLbl.merge(lb, 1, Integer::sum);

                    for (int j = 0; j < size; ++j) {
                        int bucketNum = this.toBucketNumber(lv.features().get(j), this.bucketThresholds[j]);
                        valuesInBucket[j][bucketNum]++;
                    }
                }
                return res;
            },
            this.learningEnvironment())) {
            DiscreteNaiveBayesSumsHolder sumsHolder = dataset.compute(t -> t, (a, b) -> {
                if (a == null) {
                    return b;
                }
                if (b == null) {
                    return a;
                }
                return a.merge(b);
            });
            if (mdl != null && this.isUpdateable(mdl) && this.checkSumsHolder(sumsHolder, mdl.getSumsHolder())) {
                sumsHolder = sumsHolder.merge(mdl.getSumsHolder());
            }
            return this.createModelFromSums(sumsHolder);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Allocates zero-initialised bucket counters for one label.
     *
     * @param featureCnt number of features in the row.
     * @return array whose {@code f}-th row has {@code bucketThresholds[f].length + 1} counters.
     */
    private long[][] newBucketCounters(int featureCnt) {
        long[][] counters = new long[featureCnt][];
        for (int f = 0; f < featureCnt; ++f) {
            // Java zero-initialises primitive arrays, so no explicit fill is needed.
            counters[f] = new long[this.bucketThresholds[f].length + 1];
        }
        return counters;
    }

    /**
     * Converts merged bucket counters into a model.
     *
     * <p>For every label (in ascending order) the conditional probability of bucket {@code j} of
     * feature {@code i} is {@code count[i][j] / rowsWithLabel}. The class probability is
     * {@code 1 / labelCount} when equiprobable classes are requested, the configured prior when one
     * was set, and otherwise the label's share of all rows.
     *
     * @param sumsHolder merged per-label counters.
     * @return the resulting model, which keeps {@code sumsHolder} for future incremental updates.
     */
    private DiscreteNaiveBayesModel createModelFromSums(DiscreteNaiveBayesSumsHolder sumsHolder) {
        ArrayList<Double> sortedLabels = new ArrayList<>(sumsHolder.featureCountersPerLbl.keySet());
        sortedLabels.sort(Double::compareTo);
        assert (!sortedLabels.isEmpty()) : "The dataset should contain at least one feature";

        int lbCnt = sortedLabels.size();
        int featureCnt = sumsHolder.valuesInBucketPerLbl.get(sortedLabels.get(0)).length;
        double[][][] probabilities = new double[lbCnt][featureCnt][];
        double[] classProbabilities = new double[lbCnt];
        double[] labels = new double[lbCnt];
        // Summed as long: summing Integer counters as int could overflow on very large datasets.
        long datasetSize = sumsHolder.featureCountersPerLbl.values().stream().mapToLong(Integer::longValue).sum();

        // Priors are indexed by the position of the label in ascending label order.
        assert this.equiprobableClasses || this.priorProbabilities == null
            || classProbabilities.length == this.priorProbabilities.length;

        int lbl = 0;
        for (Double label : sortedLabels) {
            int cnt = sumsHolder.featureCountersPerLbl.get(label);
            long[][] sum = sumsHolder.valuesInBucketPerLbl.get(label);
            for (int i = 0; i < featureCnt; ++i) {
                int bucketsCnt = sum[i].length;
                probabilities[lbl][i] = new double[bucketsCnt];
                for (int j = 0; j < bucketsCnt; ++j) {
                    probabilities[lbl][i][j] = (double)sum[i][j] / (double)cnt;
                }
            }
            if (this.equiprobableClasses) {
                classProbabilities[lbl] = 1.0 / (double)lbCnt;
            } else if (this.priorProbabilities != null) {
                classProbabilities[lbl] = this.priorProbabilities[lbl];
            } else {
                classProbabilities[lbl] = (double)cnt / (double)datasetSize;
            }
            labels[lbl] = label;
            ++lbl;
        }
        return new DiscreteNaiveBayesModel(probabilities, classProbabilities, labels, this.bucketThresholds, sumsHolder);
    }

    /**
     * Checks that two sums holders describe rows with the same number of features, so they can
     * be merged. Only the first label's counters of each holder are inspected.
     *
     * @return {@code false} if either holder is {@code null}; otherwise {@code true} when both are
     *     empty, or both are non-empty with equal feature counts.
     */
    private boolean checkSumsHolder(DiscreteNaiveBayesSumsHolder holder1, DiscreteNaiveBayesSumsHolder holder2) {
        if (holder1 == null || holder2 == null) {
            return false;
        }
        Optional<long[][]> first = holder1.valuesInBucketPerLbl.values().stream().findFirst();
        Optional<long[][]> second = holder2.valuesInBucketPerLbl.values().stream().findFirst();
        if (!first.isPresent()) {
            return !second.isPresent();
        }
        return second.isPresent() && first.get().length == second.get().length;
    }

    /**
     * Makes all classes equally probable. Clears any previously set prior probabilities.
     *
     * @return this trainer.
     */
    public DiscreteNaiveBayesTrainer withEquiprobableClasses() {
        this.resetProbabilitiesSettings();
        this.equiprobableClasses = true;
        return this;
    }

    /**
     * Sets explicit class prior probabilities. The array is indexed by the position of each label
     * in ascending label order and is copied. Disables equiprobable classes.
     *
     * @param priorProbabilities prior probability per label (must not be {@code null}).
     * @return this trainer.
     */
    public DiscreteNaiveBayesTrainer setPriorProbabilities(double[] priorProbabilities) {
        this.resetProbabilitiesSettings();
        this.priorProbabilities = priorProbabilities.clone();
        return this;
    }

    /**
     * Sets per-feature bucket thresholds. The matrix is deep-copied so that later changes by the
     * caller cannot alter bucketing, or the thresholds stored in models trained by this trainer.
     *
     * @param bucketThresholds thresholds per feature; {@code null} clears the setting.
     * @return this trainer.
     */
    public DiscreteNaiveBayesTrainer setBucketThresholds(double[][] bucketThresholds) {
        if (bucketThresholds == null) {
            this.bucketThresholds = null;
            return this;
        }
        double[][] copy = new double[bucketThresholds.length][];
        for (int i = 0; i < bucketThresholds.length; ++i) {
            copy[i] = bucketThresholds[i] == null ? null : bucketThresholds[i].clone();
        }
        this.bucketThresholds = copy;
        return this;
    }

    /**
     * Reverts to empirical class probabilities: clears both the equiprobable flag and the priors.
     *
     * @return this trainer.
     */
    public DiscreteNaiveBayesTrainer resetProbabilitiesSettings() {
        this.equiprobableClasses = false;
        this.priorProbabilities = null;
        return this;
    }

    /**
     * Maps a value to its bucket index.
     *
     * @param val feature value.
     * @param thresholds bucket boundaries for that feature.
     * @return index of the first threshold strictly greater than {@code val}, or
     *     {@code thresholds.length} if there is none (including when {@code val} is NaN).
     */
    private int toBucketNumber(double val, double[] thresholds) {
        for (int i = 0; i < thresholds.length; ++i) {
            if (val < thresholds[i]) {
                return i;
            }
        }
        return thresholds.length;
    }
}

