/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.knn.regression;

import java.util.List;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.knn.KNNModel;
import org.apache.ignite.ml.knn.utils.indices.SpatialIndex;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;

/**
 * k-nearest-neighbours regression model.
 *
 * <p>A prediction for a point is computed from the labels of the (up to) {@code k} closest
 * training points returned by {@link #findKClosest}. In simple mode the prediction is the
 * arithmetic mean of those labels. In weighted mode it is an inverse-distance-weighted mean,
 * so closer neighbours contribute more.
 */
public class KNNRegressionModel
extends KNNModel<Double> {
    /** Strategy that turns the found neighbours into a prediction; chosen once at construction. */
    private final KNNRegressionPredictor predictor;

    /**
     * Creates a regression model over the given spatially-indexed dataset.
     *
     * @param dataset Dataset of spatial indices holding the labeled training points.
     * @param distanceMeasure Distance measure used for neighbour search and weighting.
     * @param k Number of nearest neighbours to consider.
     * @param weighted {@code true} for inverse-distance-weighted averaging, {@code false} for a plain mean.
     */
    KNNRegressionModel(Dataset<EmptyContext, SpatialIndex<Double>> dataset, DistanceMeasure distanceMeasure, int k, boolean weighted) {
        super(dataset, distanceMeasure, k, weighted);
        this.predictor = weighted ? new KNNRegressionWeightedPredictor() : new KNNRegressionSimplePredictor();
    }

    /**
     * Predicts the label of {@code input} from its nearest neighbours.
     *
     * @param input Point to predict a label for.
     * @return Predicted label, or {@code null} if no neighbours were found.
     */
    @Override
    public Double predict(Vector input) {
        List<LabeledVector<Double>> neighbors = this.findKClosest(this.k, input);
        return this.predictor.predict(neighbors, input);
    }

    /**
     * Inverse-distance-weighted mean: each neighbour's label is weighted by {@code 1 / distance}.
     *
     * <p>Neighbours at distance zero would receive infinite weight. If any exist, the prediction
     * is the plain mean of just those exact matches. If every weight underflows to zero (for
     * example, all distances are infinite), the unweighted mean of all neighbours is returned.
     */
    private class KNNRegressionWeightedPredictor
    extends KNNRegressionSimplePredictor {
        @Override
        public Double predict(List<LabeledVector<Double>> neighbors, Vector pnt) {
            if (neighbors.isEmpty()) {
                return null;
            }
            double exactSum = 0.0;
            int exactCnt = 0;
            double weightedSum = 0.0;
            double weightSum = 0.0;
            for (LabeledVector<Double> neighbor : neighbors) {
                double distance = KNNRegressionModel.this.distanceMeasure.compute(pnt, (Vector)neighbor.features());
                double label = neighbor.label();
                if (distance == 0.0) {
                    // An exact match would get weight 1/0 = +Inf; track such points separately.
                    exactSum += label;
                    exactCnt++;
                } else {
                    // Closer neighbours must have MORE influence, hence the reciprocal.
                    double weight = 1.0 / distance;
                    weightedSum += label * weight;
                    weightSum += weight;
                }
            }
            if (exactCnt > 0) {
                return exactSum / exactCnt;
            }
            if (weightSum == 0.0) {
                // All weights vanished (e.g. infinite distances): fall back to the plain mean.
                return super.predict(neighbors, pnt);
            }
            return weightedSum / weightSum;
        }
    }

    /**
     * Plain arithmetic mean of the neighbours' labels.
     *
     * <p>Divides by the number of neighbours actually found, not by {@code k}. The neighbour
     * search may return fewer than {@code k} points, e.g. when the dataset is smaller than {@code k}.
     */
    private class KNNRegressionSimplePredictor
    implements KNNRegressionPredictor {
        @Override
        public Double predict(List<LabeledVector<Double>> neighbors, Vector pnt) {
            if (neighbors.isEmpty()) {
                return null;
            }
            double sum = 0.0;
            for (LabeledVector<Double> neighbor : neighbors) {
                sum += neighbor.label().doubleValue();
            }
            return sum / neighbors.size();
        }
    }

    /** Strategy computing a regression prediction from a point's neighbours. */
    @FunctionalInterface
    private interface KNNRegressionPredictor {
        /**
         * @param neighbors Nearest neighbours of {@code pnt}; may be empty.
         * @param pnt Point being predicted.
         * @return Predicted label, or {@code null} when {@code neighbors} is empty.
         */
        Double predict(List<LabeledVector<Double>> neighbors, Vector pnt);
    }
}

