/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.preprocessing.encoding;

import java.io.Serializable;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.exceptions.preprocessing.UndefinedLabelException;
import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.preprocessing.encoding.EncoderPartitionData;
import org.apache.ignite.ml.preprocessing.encoding.EncoderPreprocessor;
import org.apache.ignite.ml.preprocessing.encoding.EncoderSortingStrategy;
import org.apache.ignite.ml.preprocessing.encoding.EncoderType;
import org.apache.ignite.ml.preprocessing.encoding.frequency.FrequencyEncoderPreprocessor;
import org.apache.ignite.ml.preprocessing.encoding.label.LabelEncoderPreprocessor;
import org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor;
import org.apache.ignite.ml.preprocessing.encoding.stringencoder.StringEncoderPreprocessor;
import org.apache.ignite.ml.preprocessing.encoding.target.TargetCounter;
import org.apache.ignite.ml.preprocessing.encoding.target.TargetEncoderPreprocessor;
import org.apache.ignite.ml.preprocessing.encoding.target.TargetEncodingMeta;
import org.apache.ignite.ml.structures.LabeledVector;
import org.jetbrains.annotations.NotNull;

/**
 * Trainer of encoder preprocessors for categorical features (and, for the label encoder, for labels).
 *
 * <p>{@link #fit} scans every partition through the base preprocessor and collects statistics. Which statistics
 * depends on the configured {@link EncoderType}: category frequencies, label frequencies or target counters. It then
 * merges them across partitions and wraps the result into the matching {@link EncoderPreprocessor}.
 *
 * @param <K> Type of a key in {@code upstream} data.
 * @param <V> Type of a value in {@code upstream} data.
 */
public class EncoderTrainer<K, V> implements PreprocessingTrainer<K, V> {
    /** Indices of the features to encode (not consulted by the label encoder). */
    private final Set<Integer> handledIndices = new HashSet<>();

    /** Kind of encoder produced by {@link #fit}. */
    private EncoderType encoderType = EncoderType.ONE_HOT_ENCODER;

    /** Ordering used when frequency-based encoders (one-hot, string, label) assign integer codes to categories. */
    private EncoderSortingStrategy encoderSortingStgy = EncoderSortingStrategy.FREQUENCY_DESC;

    /** Index of the feature column whose value the target encoder uses as the target. */
    private Integer targetLabelIndex;

    /** Target encoder: divisor of the exponent in the sigmoid that blends the category mean with the global mean. */
    private Double smoothing = 1.0;

    /** Target encoder: category size at which the sigmoid blending weight equals 0.5. */
    private Integer minSamplesLeaf = 1;

    /** Target encoder: categories seen fewer times than this are encoded with the global mean. */
    private Long minCategorySize = 10L;

    /**
     * Collects encoding statistics over the dataset and builds the configured encoder preprocessor.
     *
     * @param envBuilder Learning environment builder.
     * @param datasetBuilder Dataset builder.
     * @param basePreprocessor Preprocessor producing the rows to encode.
     * @return Encoder preprocessor of the configured {@link EncoderType}.
     * @throws RuntimeException If no feature indices are set for a non-label encoder. Any failure while building
     *      the dataset or computing statistics is also wrapped in a {@code RuntimeException}.
     * @throws IllegalStateException If the target encoder is requested without a target column index.
     */
    @Override
    public EncoderPreprocessor<K, V> fit(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> basePreprocessor) {
        if (handledIndices.isEmpty() && encoderType != EncoderType.LABEL_ENCODER)
            throw new RuntimeException("Add indices of handled features");

        // The target encoder reads the target value from a feature column on every row. Fail fast with a clear
        // message instead of an unboxing NullPointerException deep inside partition building.
        if (encoderType == EncoderType.TARGET_ENCODER && targetLabelIndex == null)
            throw new IllegalStateException("Set the index of the target column (see labeled(...)) to use the target encoder");

        try (Dataset<EmptyContext, EncoderPartitionData> dataset = datasetBuilder.build(
            envBuilder,
            (env, upstream, upstreamSize) -> new EmptyContext(),
            (env, upstream, upstreamSize, ctx) -> {
                EncoderPartitionData partData = new EncoderPartitionData();

                // Statistics start as null and are created lazily from the first row, so an empty partition
                // contributes null to the cross-partition reduce.
                if (encoderType == EncoderType.LABEL_ENCODER) {
                    Map<String, Integer> lbFrequencies = null;

                    while (upstream.hasNext()) {
                        UpstreamEntry<K, V> entity = upstream.next();
                        LabeledVector row = basePreprocessor.apply(entity.getKey(), entity.getValue());
                        lbFrequencies = updateLabelFrequenciesForNextRow(row, lbFrequencies);
                    }

                    partData.withLabelFrequencies(lbFrequencies);
                }
                else if (encoderType == EncoderType.TARGET_ENCODER) {
                    TargetCounter[] targetCounters = null;

                    while (upstream.hasNext()) {
                        UpstreamEntry<K, V> entity = upstream.next();
                        LabeledVector row = basePreprocessor.apply(entity.getKey(), entity.getValue());
                        targetCounters = updateTargetCountersForNextRow(row, targetCounters);
                    }

                    partData.withTargetCounters(targetCounters);
                }
                else {
                    Map<String, Integer>[] categoryFrequencies = null;

                    while (upstream.hasNext()) {
                        UpstreamEntry<K, V> entity = upstream.next();
                        LabeledVector row = basePreprocessor.apply(entity.getKey(), entity.getValue());
                        categoryFrequencies = updateFeatureFrequenciesForNextRow(row, categoryFrequencies);
                    }

                    partData.withCategoryFrequencies(categoryFrequencies);
                }

                return partData;
            },
            learningEnvironment(basePreprocessor)
        )) {
            switch (encoderType) {
                case ONE_HOT_ENCODER:
                    return new OneHotEncoderPreprocessor<K, V>(calculateEncodingValuesByFrequencies(dataset),
                        basePreprocessor, handledIndices);

                case STRING_ENCODER:
                    return new StringEncoderPreprocessor<K, V>(calculateEncodingValuesByFrequencies(dataset),
                        basePreprocessor, handledIndices);

                case LABEL_ENCODER:
                    return new LabelEncoderPreprocessor<K, V>(calculateEncodingValuesForLabelsByFrequencies(dataset),
                        basePreprocessor);

                case FREQUENCY_ENCODER:
                    return new FrequencyEncoderPreprocessor<K, V>(calculateEncodingFrequencies(dataset),
                        basePreprocessor, handledIndices);

                case TARGET_ENCODER:
                    return new TargetEncoderPreprocessor<K, V>(calculateTargetEncodingFrequencies(dataset),
                        basePreprocessor, handledIndices);

                default:
                    throw new IllegalStateException("Define the type of the resulting preprocessor.");
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Merges the per-partition target counters and turns them into target-encoding metadata per handled feature.
     *
     * @param dataset Dataset whose partitions hold {@link TargetCounter} arrays.
     * @return Metadata indexed by feature; entries of non-handled features are {@code null}.
     */
    private TargetEncodingMeta[] calculateTargetEncodingFrequencies(Dataset<EmptyContext, EncoderPartitionData> dataset) {
        TargetCounter[] targetCounters = dataset.compute(EncoderPartitionData::targetCounters, (a, b) -> {
            if (a == null)
                return b;
            if (b == null)
                return a;

            assert a.length == b.length;

            // Fold partition "a" into partition "b" and return "b".
            for (int i = 0; i < a.length; i++) {
                if (!handledIndices.contains(i))
                    continue;

                TargetCounter from = a[i];
                TargetCounter into = b[i];

                into.setTargetSum(from.getTargetSum() + into.getTargetSum());
                into.setTargetCount(from.getTargetCount() + into.getTargetCount());
                from.getCategoryCounts().forEach((k, v) -> into.getCategoryCounts().merge(k, v, Long::sum));
                from.getCategoryTargetSum().forEach((k, v) -> into.getCategoryTargetSum().merge(k, v, Double::sum));
            }

            return b;
        });

        targetCounters = requireStatistics(targetCounters);

        TargetEncodingMeta[] targetEncodingMetas = new TargetEncodingMeta[targetCounters.length];

        for (int i = 0; i < targetCounters.length; i++) {
            if (!handledIndices.contains(i))
                continue;

            TargetCounter targetCounter = targetCounters[i];

            targetEncodingMetas[i] = new TargetEncodingMeta()
                .withGlobalMean(globalMean(targetCounter))
                .withCategoryMean(calculateCategoryTargetEncodingFrequency(targetCounter));
        }

        return targetEncodingMetas;
    }

    /**
     * Computes the smoothed target mean of every category of one feature.
     *
     * <p>A category seen fewer than {@code minCategorySize} times gets the global mean (the prior). Any other
     * category gets {@code prior * (1 - w) + categoryMean * w}, where
     * {@code w = 1 / (1 + exp(-(categorySize - minSamplesLeaf) / smoothing))}.
     *
     * @param targetCounter Merged counters of one feature.
     * @return Category to encoded value.
     */
    private Map<String, Double> calculateCategoryTargetEncodingFrequency(TargetCounter targetCounter) {
        double prior = globalMean(targetCounter);

        return targetCounter.getCategoryTargetSum().entrySet().stream()
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> {
                double targetSum = entry.getValue();
                long categorySize = targetCounter.getCategoryCounts().get(entry.getKey());

                if (categorySize < minCategorySize)
                    return prior;

                double categoryMean = targetSum / (double)categorySize;
                double smoove = 1.0 / (1.0 + Math.exp(-(double)(categorySize - minSamplesLeaf) / smoothing));

                return prior * (1.0 - smoove) + categoryMean * smoove;
            }));
    }

    /**
     * Mean target value over all rows counted by the given counter.
     *
     * @param counter Target counter.
     * @return {@code targetSum / targetCount}.
     */
    private static double globalMean(TargetCounter counter) {
        return counter.getTargetSum() / (double)counter.getTargetCount().longValue();
    }

    /**
     * Computes the relative frequency (share of rows) of every category of every handled feature.
     *
     * @param dataset Dataset whose partitions hold category frequencies.
     * @return Relative frequencies indexed by feature; entries of non-handled features are {@code null}.
     */
    @SuppressWarnings("unchecked")
    private Map<String, Double>[] calculateEncodingFrequencies(Dataset<EmptyContext, EncoderPartitionData> dataset) {
        Map<String, Integer>[] frequencies = calculateFrequencies(dataset);

        Map<String, Double>[] res = new Map[frequencies.length];

        for (int i = 0; i < frequencies.length; i++) {
            // Frequency maps exist only for handled features; the others are null and must be skipped.
            if (!handledIndices.contains(i))
                continue;

            Map<String, Integer> featureFrequencies = frequencies[i];

            int total = 0;
            for (int cnt : featureFrequencies.values())
                total += cnt;

            Map<String, Double> relative = new HashMap<>();
            for (Map.Entry<String, Integer> entry : featureFrequencies.entrySet())
                relative.put(entry.getKey(), entry.getValue() / (double)total);

            res[i] = relative;
        }

        return res;
    }

    /**
     * Merges the per-partition category frequencies of all handled features.
     *
     * @param dataset Dataset whose partitions hold category frequencies.
     * @return Absolute frequencies indexed by feature; entries of non-handled features are {@code null}.
     */
    private Map<String, Integer>[] calculateFrequencies(Dataset<EmptyContext, EncoderPartitionData> dataset) {
        Map<String, Integer>[] frequencies = dataset.compute(EncoderPartitionData::categoryFrequencies, (a, b) -> {
            if (a == null)
                return b;
            if (b == null)
                return a;

            assert a.length == b.length;

            // Fold partition "a" into partition "b" and return "b".
            for (int i = 0; i < a.length; i++) {
                if (!handledIndices.contains(i))
                    continue;

                Map<String, Integer> into = b[i];
                a[i].forEach((k, v) -> into.merge(k, v, Integer::sum));
            }

            return b;
        });

        return requireStatistics(frequencies);
    }

    /**
     * Merges the per-partition label frequencies.
     *
     * @param dataset Dataset whose partitions hold label frequencies.
     * @return Label to number of occurrences.
     */
    private Map<String, Integer> calculateFrequenciesForLabels(Dataset<EmptyContext, EncoderPartitionData> dataset) {
        Map<String, Integer> frequencies = dataset.compute(EncoderPartitionData::labelFrequencies, (a, b) -> {
            if (a == null)
                return b;
            if (b == null)
                return a;

            a.forEach((k, v) -> b.merge(k, v, Integer::sum));

            return b;
        });

        return requireStatistics(frequencies);
    }

    /**
     * Assigns integer codes to the categories of every handled feature based on their frequencies.
     *
     * @param dataset Dataset whose partitions hold category frequencies.
     * @return Category codes indexed by feature; entries of non-handled features are {@code null}.
     */
    @SuppressWarnings("unchecked")
    private Map<String, Integer>[] calculateEncodingValuesByFrequencies(Dataset<EmptyContext, EncoderPartitionData> dataset) {
        Map<String, Integer>[] frequencies = calculateFrequencies(dataset);

        Map<String, Integer>[] res = new Map[frequencies.length];

        for (int i = 0; i < frequencies.length; i++) {
            if (handledIndices.contains(i))
                res[i] = transformFrequenciesToEncodingValues(frequencies[i]);
        }

        return res;
    }

    /**
     * Assigns integer codes to labels based on their frequencies.
     *
     * @param dataset Dataset whose partitions hold label frequencies.
     * @return Label to code.
     */
    private Map<String, Integer> calculateEncodingValuesForLabelsByFrequencies(Dataset<EmptyContext, EncoderPartitionData> dataset) {
        return transformFrequenciesToEncodingValues(calculateFrequenciesForLabels(dataset));
    }

    /**
     * Converts category frequencies into dense integer codes {@code 0 .. size-1}.
     *
     * <p>Under {@link EncoderSortingStrategy#FREQUENCY_DESC}, entries are sorted by ascending frequency and numbered
     * from {@code size-1} downwards, so the most frequent category receives code {@code 0}. Any other strategy
     * reverses the sort, so the least frequent category receives code {@code 0}.
     *
     * @param frequencies Category to number of occurrences.
     * @return Category to code (insertion-ordered).
     */
    private Map<String, Integer> transformFrequenciesToEncodingValues(Map<String, Integer> frequencies) {
        Comparator<Map.Entry<String, Integer>> ascending = Map.Entry.comparingByValue();
        Comparator<Map.Entry<String, Integer>> comp = encoderSortingStgy == EncoderSortingStrategy.FREQUENCY_DESC
            ? ascending
            : Collections.reverseOrder(ascending);

        Map<String, Integer> resMap = frequencies.entrySet().stream()
            .sorted(comp)
            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (oldValue, newValue) -> oldValue,
                LinkedHashMap::new));

        int amountOfLabels = frequencies.size();

        for (Map.Entry<String, Integer> entry : resMap.entrySet())
            entry.setValue(--amountOfLabels);

        return resMap;
    }

    /**
     * Accounts the handled feature values of one row in the per-feature category frequencies.
     *
     * <p>NaN values are counted as the empty-string category and are also replaced by {@code ""} in the row itself.
     *
     * @param row Row produced by the base preprocessor.
     * @param categoryFrequencies Frequencies accumulated so far, or {@code null} for the first row of a partition.
     * @return Updated frequencies (the same array unless it had to be created).
     * @throws RuntimeException If a handled feature value is neither a {@code String} nor a {@code Double}.
     */
    private Map<String, Integer>[] updateFeatureFrequenciesForNextRow(LabeledVector row,
        Map<String, Integer>[] categoryFrequencies) {
        if (categoryFrequencies == null)
            categoryFrequencies = initializeCategoryFrequencies(row);
        else {
            assert categoryFrequencies.length == row.size() : "Base preprocessor must return exactly "
                + categoryFrequencies.length + " features";
        }

        for (int i = 0; i < categoryFrequencies.length; i++) {
            if (!handledIndices.contains(i))
                continue;

            Object featureVal = row.features().getRaw(i);
            String strVal;

            if (featureVal.equals(Double.NaN)) {
                strVal = "";
                row.features().setRaw(i, strVal);
            }
            else if (featureVal instanceof String)
                strVal = (String)featureVal;
            else if (featureVal instanceof Double)
                strVal = String.valueOf(featureVal);
            else
                throw new RuntimeException("The type " + featureVal.getClass() + " is not supported for the feature values.");

            categoryFrequencies[i].merge(strVal, 1, Integer::sum);
        }

        return categoryFrequencies;
    }

    /**
     * Accounts the label of one row in the label frequencies.
     *
     * @param row Row produced by the base preprocessor.
     * @param labelFrequencies Frequencies accumulated so far, or {@code null} for the first row of a partition.
     * @return Updated frequencies.
     * @throws UndefinedLabelException If the label is {@code null} or NaN.
     * @throws RuntimeException If the label is neither a {@code String} nor a {@code Double}.
     */
    private Map<String, Integer> updateLabelFrequenciesForNextRow(LabeledVector row, Map<String, Integer> labelFrequencies) {
        if (labelFrequencies == null)
            labelFrequencies = new HashMap<>();

        Object lbVal = row.label();

        // Null must be tested first: calling equals() on a null label would throw NullPointerException.
        if (lbVal == null || lbVal.equals(Double.NaN))
            throw new UndefinedLabelException(row);

        String strVal;

        if (lbVal instanceof String)
            strVal = (String)lbVal;
        else if (lbVal instanceof Double)
            strVal = String.valueOf(lbVal);
        else
            throw new RuntimeException("The type " + lbVal.getClass() + " is not supported for the label values.");

        labelFrequencies.merge(strVal, 1, Integer::sum);

        return labelFrequencies;
    }

    /**
     * Creates empty frequency maps for the handled features of a row.
     *
     * @param row Row used only for its size.
     * @return Array of length {@code row.size()} with a map per handled feature and {@code null} elsewhere.
     */
    @NotNull
    @SuppressWarnings("unchecked")
    private Map<String, Integer>[] initializeCategoryFrequencies(LabeledVector row) {
        Map<String, Integer>[] categoryFrequencies = new Map[row.size()];

        for (int i = 0; i < categoryFrequencies.length; i++) {
            if (handledIndices.contains(i))
                categoryFrequencies[i] = new HashMap<>();
        }

        return categoryFrequencies;
    }

    /**
     * Accounts one row in the per-feature target counters.
     *
     * <p>The target value is read from feature {@code targetLabelIndex}. NaN feature values are counted as the
     * empty-string category and are also replaced by {@code ""} in the row itself.
     *
     * @param row Row produced by the base preprocessor.
     * @param targetCounters Counters accumulated so far, or {@code null} for the first row of a partition.
     * @return Updated counters (the same array unless it had to be created).
     * @throws RuntimeException If a handled feature value is not a {@code String}, {@code Number} or {@code Boolean}.
     */
    private TargetCounter[] updateTargetCountersForNextRow(LabeledVector row, TargetCounter[] targetCounters) {
        if (targetCounters == null)
            targetCounters = initializeTargetCounters(row);
        else {
            assert targetCounters.length == row.size() : "Base preprocessor must return exactly "
                + targetCounters.length + " features";
        }

        double targetValue = row.features().get(targetLabelIndex);

        for (int i = 0; i < targetCounters.length; i++) {
            if (!handledIndices.contains(i))
                continue;

            Object featureVal = row.features().getRaw(i);
            String strVal;

            if (featureVal.equals(Double.NaN)) {
                strVal = "";
                row.features().setRaw(i, strVal);
            }
            else if (featureVal instanceof String)
                strVal = (String)featureVal;
            else if (featureVal instanceof Number || featureVal instanceof Boolean)
                strVal = String.valueOf(featureVal);
            else
                throw new RuntimeException("The type " + featureVal.getClass() + " is not supported for the feature values.");

            TargetCounter targetCounter = targetCounters[i];

            targetCounter.setTargetCount(targetCounter.getTargetCount() + 1L);
            targetCounter.setTargetSum(targetCounter.getTargetSum() + targetValue);
            targetCounter.getCategoryCounts().merge(strVal, 1L, Long::sum);
            targetCounter.getCategoryTargetSum().merge(strVal, targetValue, Double::sum);
        }

        return targetCounters;
    }

    /**
     * Creates empty target counters for the handled features of a row.
     *
     * @param row Row used only for its size.
     * @return Array of length {@code row.size()} with a counter per handled feature and {@code null} elsewhere.
     */
    private TargetCounter[] initializeTargetCounters(LabeledVector row) {
        TargetCounter[] targetCounter = new TargetCounter[row.size()];

        for (int i = 0; i < row.size(); i++) {
            if (handledIndices.contains(i))
                targetCounter[i] = new TargetCounter();
        }

        return targetCounter;
    }

    /**
     * Returns the merged statistics, failing with a descriptive message when no partition produced any.
     *
     * <p>Statistics are created lazily on the first row of a partition, so a {@code null} merge result means that
     * no partition contributed rows.
     *
     * @param stats Merged statistics, possibly {@code null}.
     * @param <T> Statistics type.
     * @return {@code stats} when it is not {@code null}.
     * @throws IllegalStateException If {@code stats} is {@code null}.
     */
    private static <T> T requireStatistics(T stats) {
        if (stats == null)
            throw new IllegalStateException("No encoding statistics were collected: the dataset appears to be empty.");

        return stats;
    }

    /**
     * Adds a feature index to encode.
     *
     * @param idx Feature index.
     * @return This trainer.
     */
    public EncoderTrainer<K, V> withEncodedFeature(int idx) {
        handledIndices.add(idx);
        return this;
    }

    /**
     * Sets the ordering used when frequency-based encoders assign codes.
     *
     * @param encoderSortingStgy Sorting strategy.
     * @return This trainer.
     */
    public EncoderTrainer<K, V> withEncoderIndexingStrategy(EncoderSortingStrategy encoderSortingStgy) {
        this.encoderSortingStgy = encoderSortingStgy;
        return this;
    }

    /**
     * Sets the kind of encoder to build.
     *
     * @param type Encoder type.
     * @return This trainer.
     */
    public EncoderTrainer<K, V> withEncoderType(EncoderType type) {
        encoderType = type;
        return this;
    }

    /**
     * Adds several feature indices to encode.
     *
     * @param handledIndices Feature indices.
     * @return This trainer.
     */
    public EncoderTrainer<K, V> withEncodedFeatures(Set<Integer> handledIndices) {
        this.handledIndices.addAll(handledIndices);
        return this;
    }

    /**
     * Sets the index of the feature column used as the target by the target encoder.
     *
     * @param targetLabelIndex Target column index.
     * @return This trainer.
     */
    public EncoderTrainer<K, V> labeled(Integer targetLabelIndex) {
        this.targetLabelIndex = targetLabelIndex;
        return this;
    }

    /**
     * Sets the target-encoder smoothing parameter.
     *
     * @param smoothing Smoothing value.
     * @return This trainer.
     */
    public EncoderTrainer<K, V> smoothing(Double smoothing) {
        this.smoothing = smoothing;
        return this;
    }

    /**
     * Sets the target-encoder category size at which category and global means are weighted equally.
     *
     * @param minSamplesLeaf Category size.
     * @return This trainer.
     */
    public EncoderTrainer<K, V> minSamplesLeaf(Integer minSamplesLeaf) {
        this.minSamplesLeaf = minSamplesLeaf;
        return this;
    }

    /**
     * Sets the target-encoder minimum category size below which the global mean is used.
     *
     * @param minCategorySize Minimum category size.
     * @return This trainer.
     */
    public EncoderTrainer<K, V> minCategorySize(Long minCategorySize) {
        this.minCategorySize = minCategorySize;
        return this;
    }
}

