/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.tree.randomforest.data.statistics;

import java.util.Comparator;
import java.util.Map;
import java.util.NoSuchElementException;
import org.apache.ignite.ml.dataset.feature.ObjectHistogram;
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
import org.apache.ignite.ml.tree.randomforest.data.statistics.LeafValuesComputer;

/**
 * {@link LeafValuesComputer} for classification: leaf statistics are histograms of
 * {@link BootstrappedVector}s bucketed by label, and the value of a leaf is the label
 * whose bucket holds the largest accumulated counter.
 */
public class ClassifierLeafValuesComputer
extends LeafValuesComputer<ObjectHistogram<BootstrappedVector>> {
    /** Serial version uid. */
    private static final long serialVersionUID = 420416095877577599L;

    /** Mapping from label value to histogram bucket id. */
    private final Map<Double, Integer> lblMapping;

    /**
     * Creates the computer.
     *
     * @param lblMapping Mapping from label value to histogram bucket id. The reference is
     *                   stored as is (no defensive copy).
     */
    public ClassifierLeafValuesComputer(Map<Double, Integer> lblMapping) {
        this.lblMapping = lblMapping;
    }

    /**
     * Adds the vector to the leaf histogram.
     *
     * @param leafStatAggr Histogram collecting statistics for the leaf.
     * @param vec Vector to add.
     * @param sampleId Not used here; the histogram created by
     *                 {@link #createLeafStatsAggregator(int)} already carries its sample id.
     */
    @Override
    protected void addElementToLeafStatistic(ObjectHistogram<BootstrappedVector> leafStatAggr, BootstrappedVector vec, int sampleId) {
        leafStatAggr.addElement(vec);
    }

    /**
     * Merges statistics of two leaves.
     *
     * @param leftStats First histogram.
     * @param rightStats Second histogram.
     * @return Result of {@code leftStats.plus(rightStats)}.
     */
    @Override
    protected ObjectHistogram<BootstrappedVector> mergeLeafStats(ObjectHistogram<BootstrappedVector> leftStats, ObjectHistogram<BootstrappedVector> rightStats) {
        return leftStats.plus(rightStats);
    }

    /**
     * Creates an empty leaf histogram bound to this computer's label mapping.
     *
     * @param sampleId Index into {@link BootstrappedVector#counters()} that the histogram reads.
     * @return New histogram instance.
     */
    @Override
    protected ObjectHistogram<BootstrappedVector> createLeafStatsAggregator(int sampleId) {
        return new LeafStatsHistogram(this.lblMapping, sampleId);
    }

    /**
     * Returns the label whose bucket has the largest value in the histogram.
     *
     * @param stat Leaf histogram.
     * @return Label mapped to the bucket with the maximal value, or {@link Double#NaN} when the
     *         histogram has no buckets.
     * @throws NoSuchElementException If no label in the mapping corresponds to the selected bucket id.
     */
    @Override
    protected double computeLeafValue(ObjectHistogram<BootstrappedVector> stat) {
        // A bucket without a stored value is ranked as 0.0; -1 is the "no buckets" sentinel.
        Integer bucketId = stat.buckets().stream()
            .max(Comparator.comparing(b -> stat.getValue(b).orElse(0.0)))
            .orElse(-1);

        if (bucketId == -1) {
            return Double.NaN;
        }

        // Reverse lookup: the mapping is label -> bucket id, so scan entries for the bucket id.
        return this.lblMapping.entrySet().stream()
            .filter(e -> e.getValue().equals(bucketId))
            .findFirst()
            .orElseThrow(() -> new NoSuchElementException("No label is mapped to bucket id " + bucketId))
            .getKey();
    }

    /**
     * Histogram that buckets vectors by label and weights each vector by its counter for a
     * fixed sample.
     */
    private static class LeafStatsHistogram
    extends ObjectHistogram<BootstrappedVector> {
        /** Serial version uid. */
        private static final long serialVersionUID = 4017587488421003308L;

        /** Mapping from label value to histogram bucket id. */
        private final Map<Double, Integer> lblMapping;

        /** Index into {@link BootstrappedVector#counters()} used as the counter source. */
        private final int sampleId;

        /**
         * Creates the histogram.
         *
         * @param lblMapping Mapping from label value to histogram bucket id.
         * @param sampleId Index into {@link BootstrappedVector#counters()}.
         */
        public LeafStatsHistogram(Map<Double, Integer> lblMapping, int sampleId) {
            this.lblMapping = lblMapping;
            this.sampleId = sampleId;
        }

        /**
         * Maps the vector to a bucket by its label.
         *
         * @param x Vector.
         * @return Bucket id for the vector's label, or {@code null} when the label is not in the mapping.
         */
        @Override
        public Integer mapToBucket(BootstrappedVector x) {
            return this.lblMapping.get(x.label());
        }

        /**
         * Returns the vector's counter for this histogram's sample.
         *
         * @param x Vector.
         * @return {@code x.counters()[sampleId]}.
         */
        @Override
        public Double mapToCounter(BootstrappedVector x) {
            return x.counters()[this.sampleId];
        }

        /**
         * Creates an empty histogram with the same label mapping and sample id.
         *
         * @return New histogram instance.
         */
        @Override
        public ObjectHistogram<BootstrappedVector> newInstance() {
            return new LeafStatsHistogram(this.lblMapping, this.sampleId);
        }
    }
}

