/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.tree.randomforest;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.feature.FeatureMeta;
import org.apache.ignite.ml.dataset.feature.ObjectHistogram;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartition;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.tree.randomforest.RandomForestModel;
import org.apache.ignite.ml.tree.randomforest.RandomForestTrainer;
import org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel;
import org.apache.ignite.ml.tree.randomforest.data.impurity.GiniHistogram;
import org.apache.ignite.ml.tree.randomforest.data.impurity.GiniHistogramsComputer;
import org.apache.ignite.ml.tree.randomforest.data.impurity.ImpurityHistogramsComputer;
import org.apache.ignite.ml.tree.randomforest.data.statistics.ClassifierLeafValuesComputer;
import org.apache.ignite.ml.tree.randomforest.data.statistics.LeafValuesComputer;

/**
 * Random forest trainer for classification tasks.
 *
 * <p>Before the trees are built, {@link #init} assigns every distinct label found in the training
 * dataset a dense integer index. That label-to-index mapping is passed to the Gini impurity
 * histogram computer and to the leaf-value computer. The resulting per-tree predictions are
 * combined with an {@link OnMajorityPredictionsAggregator}.
 */
public class RandomForestClassifierTrainer
extends RandomForestTrainer<ObjectHistogram<BootstrappedVector>, GiniHistogram, RandomForestClassifierTrainer> {
    /**
     * Maps each distinct label value to an index in {@code [0, number of distinct labels)}.
     * Rebuilt from scratch on every {@link #init} call.
     */
    private final Map<Double, Integer> lblMapping = new HashMap<>();

    /**
     * Creates a classifier trainer.
     *
     * @param meta Feature metadata, passed unchanged to the parent trainer.
     */
    public RandomForestClassifierTrainer(List<FeatureMeta> meta) {
        super(meta);
    }

    /**
     * Returns this trainer instance (the self type used by the parent trainer).
     *
     * @return {@code this}.
     */
    @Override
    protected RandomForestClassifierTrainer instance() {
        return this;
    }

    /**
     * Collects the distinct labels of {@code dataset}, assigns each one an index, and then delegates
     * to the parent's initialization.
     *
     * <p>The mapping is cleared first. Without that, a trainer reused on a different dataset would
     * keep labels and index assignments from a previous run, and two different labels could end up
     * sharing the same index.
     *
     * @param dataset Bootstrapped training dataset.
     * @return {@code false} if the reduced label set is {@code null}; otherwise the result of
     *     {@code super.init(dataset)}.
     */
    @Override
    protected boolean init(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
        lblMapping.clear();

        // Map each partition to its set of labels, then union the sets.
        // The lambdas deliberately capture no trainer state.
        // NOTE(review): they are presumably executed on remote partitions, so capturing `this`
        // would require serializing the trainer.
        Set<Double> uniqLabels = dataset.compute(
            part -> {
                Set<Double> labels = new HashSet<>();
                for (int row = 0; row < part.getRowsCount(); row++) {
                    labels.add(part.getRow(row).label());
                }
                return labels;
            },
            (l, r) -> {
                // Either side may be null; the non-null side wins.
                if (l == null) {
                    return r;
                }
                if (r == null) {
                    return l;
                }
                Set<Double> lbls = new HashSet<>();
                lbls.addAll(l);
                lbls.addAll(r);
                return lbls;
            }
        );

        if (uniqLabels == null) {
            return false;
        }

        // Assign dense indices 0..n-1 in the set's iteration order.
        int idx = 0;
        for (Double label : uniqLabels) {
            lblMapping.put(label, idx++);
        }

        return super.init(dataset);
    }

    /**
     * Wraps the trained trees into a model that combines their predictions by majority vote.
     *
     * @param models Trained trees.
     * @return Random forest model using {@link OnMajorityPredictionsAggregator}.
     */
    @Override
    protected RandomForestModel buildComposition(List<RandomForestTreeModel> models) {
        return new RandomForestModel(models, new OnMajorityPredictionsAggregator());
    }

    /**
     * Creates the Gini impurity histogram computer for the current label mapping.
     *
     * @return Gini histograms computer built over the label-to-index mapping.
     */
    @Override
    protected ImpurityHistogramsComputer<GiniHistogram> createImpurityHistogramsComputer() {
        return new GiniHistogramsComputer(lblMapping);
    }

    /**
     * Creates the leaf-value computer for the current label mapping.
     *
     * @return Classifier leaf values computer built over the label-to-index mapping.
     */
    @Override
    protected LeafValuesComputer<ObjectHistogram<BootstrappedVector>> createLeafStatisticsAggregator() {
        return new ClassifierLeafValuesComputer(lblMapping);
    }

    /**
     * Sets the learning environment builder, narrowing the return type to this trainer class.
     *
     * @param envBuilder Learning environment builder.
     * @return This trainer, as returned by the parent implementation.
     */
    @Override
    public RandomForestClassifierTrainer withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
        return (RandomForestClassifierTrainer)super.withEnvironmentBuilder(envBuilder);
    }
}

