/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.composition.boosting.convergence.mean;

import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.preprocessing.Preprocessor;

/**
 * Convergence checker whose error metric is the mean of the absolute per-sample errors
 * ({@code sum(|error_i|) / n}) over all rows of a dataset.
 *
 * @param <K> Type of a key in upstream data.
 * @param <V> Type of a value in upstream data.
 */
public class MeanAbsValueConvergenceChecker<K, V>
extends ConvergenceChecker<K, V> {
    private static final long serialVersionUID = 8534776439755210864L;

    /**
     * Creates a new instance; all arguments are forwarded unchanged to {@link ConvergenceChecker}.
     *
     * @param sampleSize Sample size.
     * @param externalLbToInternalMapping Mapping from external labels to internal ones.
     * @param loss Loss function.
     * @param datasetBuilder Dataset builder.
     * @param preprocessor Upstream preprocessor.
     * @param precision Precision threshold.
     */
    public MeanAbsValueConvergenceChecker(long sampleSize, IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor, double precision) {
        super(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, preprocessor, precision);
    }

    /**
     * Computes the mean absolute error of {@code mdl} over every partition of {@code dataset}.
     *
     * @param dataset Dataset whose partitions hold feature rows and labels.
     * @param mdl Model composition being evaluated.
     * @return Mean absolute error, or {@link Double#NaN} when the computation yields no result
     *     or the dataset contains no rows.
     */
    @Override
    public Double computeMeanErrorOnDataset(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset, ModelsComposition mdl) {
        IgniteBiTuple<Double, Long> sumAndCnt = dataset.compute(
            partition -> computeStatisticOnPartition(mdl, partition),
            this::reduce
        );

        if (sumAndCnt == null || sumAndCnt.getValue() == 0L)
            return Double.NaN;

        return sumAndCnt.getKey() / sumAndCnt.getValue();
    }

    /**
     * Sums the absolute per-row errors of one partition.
     *
     * @param mdl Model composition being evaluated.
     * @param part Partition data.
     * @return Tuple of (sum of absolute errors, number of rows visited).
     */
    private IgniteBiTuple<Double, Long> computeStatisticOnPartition(ModelsComposition mdl, FeatureMatrixWithLabelsOnHeapData part) {
        double[][] features = part.getFeatures();
        double[] labels = part.getLabels();

        double sum = 0.0;
        for (int i = 0; i < features.length; i++) {
            double error = computeError(VectorUtils.of(features[i]), labels[i], mdl);
            sum += Math.abs(error);
        }

        // The count must be boxed as a Long: callers unbox it as Long, and an autoboxed
        // int (Integer) would fail with ClassCastException. Count the rows actually
        // visited so the numerator and denominator of the mean always agree.
        return new IgniteBiTuple<>(sum, (long)features.length);
    }

    /**
     * Merges two partial (sum, count) statistics; a {@code null} side is treated as absent.
     *
     * @param left Left partial statistic, possibly {@code null}.
     * @param right Right partial statistic, possibly {@code null}.
     * @return Element-wise sum of both sides, the non-null side if only one is present,
     *     or (0.0, 0) if both are {@code null}.
     */
    private IgniteBiTuple<Double, Long> reduce(IgniteBiTuple<Double, Long> left, IgniteBiTuple<Double, Long> right) {
        if (left == null)
            return right != null ? right : new IgniteBiTuple<>(0.0, 0L);

        if (right == null)
            return left;

        return new IgniteBiTuple<>(left.getKey() + right.getKey(), left.getValue() + right.getValue());
    }
}

