/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.knn.ann;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.TreeMap;
import java.util.UUID;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.environment.deploy.DeployableObject;
import org.apache.ignite.ml.inference.json.JSONModel;
import org.apache.ignite.ml.inference.json.JSONWritable;
import org.apache.ignite.ml.knn.NNClassificationModel;
import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer;
import org.apache.ignite.ml.knn.ann.ANNModelFormat;
import org.apache.ignite.ml.knn.ann.KNNModelFormat;
import org.apache.ignite.ml.knn.ann.ProbableLabel;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.structures.DatasetRow;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.util.ModelTrace;

/**
 * Approximate nearest-neighbour (ANN) classification model.
 *
 * <p>The model keeps a set of candidate vectors (the constructor calls them "centers"). Each
 * candidate is labelled with a {@link ProbableLabel} that maps class labels to probabilities.
 * Prediction selects the {@code k} candidates closest to the query under {@code distanceMeasure}.
 * Those candidates then vote for classes, weighted by the label probabilities.
 */
public final class ANNClassificationModel
extends NNClassificationModel
implements JSONWritable,
DeployableObject {
    /** Serialization version. */
    private static final long serialVersionUID = -127312378991350345L;

    /** Candidate vectors; their labels are cast to {@link ProbableLabel} during prediction and export. */
    private LabeledVectorSet<LabeledVector> candidates;

    /**
     * Centroid statistics supplied by the caller. They are exported with the model but are not used
     * by prediction, {@code equals} or {@code hashCode} in this class.
     */
    private ANNClassificationTrainer.CentroidStat centroindsStat;

    /**
     * Creates a model over the given candidates.
     *
     * @param centers Candidate vectors used as neighbours during prediction.
     * @param centroindsStat Centroid statistics stored alongside the model.
     */
    public ANNClassificationModel(LabeledVectorSet<LabeledVector> centers, ANNClassificationTrainer.CentroidStat centroindsStat) {
        this.candidates = centers;
        this.centroindsStat = centroindsStat;
    }

    /** No-arg constructor that leaves all fields unset. NOTE(review): presumably kept for serialization — confirm. */
    private ANNClassificationModel() {
    }

    /** @return Candidate vectors backing this model (the internal instance, not a copy). */
    public LabeledVectorSet<LabeledVector> getCandidates() {
        return this.candidates;
    }

    /** @return Centroid statistics passed at construction / restored from JSON. */
    public ANNClassificationTrainer.CentroidStat getCentroindsStat() {
        return this.centroindsStat;
    }

    /**
     * Predicts the class of {@code v} from its nearest candidates.
     *
     * @param v Feature vector to classify.
     * @return Class label selected by the inherited {@code getClassWithMaxVotes} from the
     *     accumulated votes.
     */
    @Override
    public Double predict(Vector v) {
        List<LabeledVector> neighbors = this.findKNearestNeighbors(v);
        return this.classify(neighbors, v, this.weighted);
    }

    /**
     * Exports {@code k}, the distance measure, the weighting flag, the candidates and the centroid
     * statistics as an {@link ANNModelFormat}.
     *
     * @param exporter Exporter that writes the model data.
     * @param path Destination understood by the exporter.
     */
    @Override
    public <P> void saveModel(Exporter<KNNModelFormat, P> exporter, P path) {
        ANNModelFormat mdlData = new ANNModelFormat(this.k, this.distanceMeasure, this.weighted, this.candidates, this.centroindsStat);
        exporter.save(mdlData, path);
    }

    /**
     * Finds the neighbours of {@code v} that take part in the vote.
     *
     * @param v Query vector.
     * @return Fixed-size list of at most {@code k} non-null candidate rows.
     */
    private List<LabeledVector> findKNearestNeighbors(Vector v) {
        return Arrays.asList(this.getKClosestVectors(this.getDistances(v)));
    }

    /**
     * Selects the neighbours used for voting.
     *
     * <p>Case 1: there are at most {@code k} candidate rows. Every non-null row is returned, in
     * index order.
     *
     * <p>Case 2: there are more rows. Up to {@code k} rows are taken from
     * {@code distanceIdxPairs} in ascending-distance order. Fewer than {@code k} are returned when
     * the map runs out of indices. This happens when some candidate rows are {@code null}, because
     * {@link #getDistances} skips those rows.
     *
     * @param distanceIdxPairs Candidate indices grouped by their distance to the query.
     * @return Selected candidate rows; never contains {@code null}.
     */
    private LabeledVector[] getKClosestVectors(TreeMap<Double, Set<Integer>> distanceIdxPairs) {
        int rowCnt = this.candidates.rowSize();
        if (rowCnt <= this.k) {
            List<LabeledVector> all = new ArrayList<LabeledVector>(rowCnt);
            for (int i = 0; i < rowCnt; ++i) {
                LabeledVector row = (LabeledVector)this.candidates.getRow(i);
                // A null row would make classify() fail on neighbor.label(); getDistances() skips them as well.
                if (row != null) {
                    all.add(row);
                }
            }
            return all.toArray(new LabeledVector[0]);
        }
        LabeledVector[] res = new LabeledVector[this.k];
        int found = 0;
        // TreeMap.values() iterates in ascending key order, i.e. nearest distances first.
        collect:
        for (Set<Integer> idxs : distanceIdxPairs.values()) {
            for (Integer idx : idxs) {
                if (found >= this.k) {
                    break collect;
                }
                res[found++] = (LabeledVector)this.candidates.getRow(idx);
            }
        }
        // The map may hold fewer than k indices when rows were null; don't hand null slots to classify().
        return found == res.length ? res : Arrays.copyOf(res, found);
    }

    /**
     * Computes the distance from {@code v} to every non-null candidate row.
     *
     * @param v Query vector.
     * @return Map from distance to candidate indices, filled through the inherited
     *     {@code putDistanceIdxPair} (presumably grouping indices that share a distance — confirm).
     */
    private TreeMap<Double, Set<Integer>> getDistances(Vector v) {
        TreeMap<Double, Set<Integer>> distanceIdxPairs = new TreeMap<Double, Set<Integer>>();
        for (int i = 0; i < this.candidates.rowSize(); ++i) {
            LabeledVector labeledVector = (LabeledVector)this.candidates.getRow(i);
            if (labeledVector == null) {
                continue;
            }
            double distance = this.distanceMeasure.compute(v, (Vector)labeledVector.features());
            this.putDistanceIdxPair(distanceIdxPairs, i, distance);
        }
        return distanceIdxPairs;
    }

    /**
     * Accumulates class votes from the neighbours and picks the winning class.
     *
     * <p>Each neighbour contributes {@code probability * vote} to every class in its
     * {@link ProbableLabel}. Here {@code vote} comes from the inherited
     * {@code getClassVoteForVector(weighted, distance)}.
     *
     * @param neighbors Neighbours selected for the query; labels must be {@link ProbableLabel}.
     * @param v Query vector.
     * @param weighted Whether votes are weighted by distance (interpreted by the inherited helper).
     * @return Class label chosen by the inherited {@code getClassWithMaxVotes}.
     */
    private double classify(List<LabeledVector> neighbors, Vector v, boolean weighted) {
        HashMap<Double, Double> clsVotes = new HashMap<Double, Double>();
        for (LabeledVector neighbor : neighbors) {
            TreeMap<Double, Double> probableClsLb = ((ProbableLabel)neighbor.label()).clsLbls;
            double distance = this.distanceMeasure.compute(v, (Vector)neighbor.features());
            // The vote weight depends only on the neighbour, so compute it once rather than per label.
            double vote = this.getClassVoteForVector(weighted, distance);
            probableClsLb.forEach((label, probability) -> clsVotes.merge(label, probability * vote, Double::sum));
        }
        return this.getClassWithMaxVotes(clsVotes);
    }

    /** Hash over {@code k}, the distance measure, the weighting flag and the candidates (not the centroid stats). */
    @Override
    public int hashCode() {
        int res = 1;
        res = res * 37 + this.k;
        res = res * 37 + this.distanceMeasure.hashCode();
        res = res * 37 + Boolean.hashCode(this.weighted);
        res = res * 37 + this.candidates.hashCode();
        return res;
    }

    /** Equality over the same fields as {@link #hashCode()}; the centroid statistics are not compared. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || this.getClass() != obj.getClass()) {
            return false;
        }
        ANNClassificationModel that = (ANNClassificationModel)obj;
        return this.k == that.k && this.distanceMeasure.equals(that.distanceMeasure) && this.weighted == that.weighted && this.candidates.equals(that.candidates);
    }

    /** @return Compact (non-pretty) trace of the model parameters. */
    @Override
    public String toString() {
        return this.toString(false);
    }

    /**
     * Builds a trace with {@code k}, the distance measure's simple class name, the weighting flag and
     * the number of candidate rows.
     *
     * @param pretty Whether to format the trace for readability.
     * @return Model trace.
     */
    @Override
    public String toString(boolean pretty) {
        // NOTE(review): the trace is titled "KNNClassificationModel" for an ANN model; kept unchanged
        // because external code may match on it — consider renaming.
        return ModelTrace.builder("KNNClassificationModel", pretty).addField("k", String.valueOf(this.k)).addField("measure", this.distanceMeasure.getClass().getSimpleName()).addField("weighted", String.valueOf(this.weighted)).addField("amount of candidates", String.valueOf(this.candidates.rowSize())).toString();
    }

    /** @return Always an empty list: this model declares no deployable dependencies. */
    @Override
    @JsonIgnore
    public List<Object> getDependencies() {
        return Collections.emptyList();
    }

    /**
     * Loads a model from a JSON file in the format written by {@link #toJSON(Path)}.
     *
     * @param path Path of the JSON file.
     * @return Restored model, or {@code null} if the file could not be read or parsed (the stack
     *     trace is printed).
     * @throws IllegalArgumentException If the file contains no candidates or too few candidate labels.
     */
    public static ANNClassificationModel fromJSON(Path path) {
        ObjectMapper mapper = new ObjectMapper().configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
        try {
            ANNJSONExportModel exportModel = mapper.readValue(new File(path.toAbsolutePath().toString()), ANNJSONExportModel.class);
            return exportModel.convert();
        }
        catch (IOException e) {
            // Existing best-effort contract: report and return null instead of propagating.
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Writes the model to a JSON file.
     *
     * <p>The file contains candidate features and labels, the distance measure, {@code k}, the
     * weighting flag and the centroid statistics. I/O failures are reported by printing the stack
     * trace; nothing is thrown.
     *
     * @param path Destination file.
     */
    @Override
    public void toJSON(Path path) {
        ObjectMapper mapper = new ObjectMapper().configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
        try {
            ANNJSONExportModel exportModel = new ANNJSONExportModel(System.currentTimeMillis(), "ann_" + UUID.randomUUID(), ANNClassificationModel.class.getSimpleName());
            int rowCnt = this.candidates.rowSize();
            ArrayList<double[]> listOfCandidates = new ArrayList<double[]>(rowCnt);
            ProbableLabel[] labels = new ProbableLabel[rowCnt];
            for (int i = 0; i < rowCnt; ++i) {
                labels[i] = (ProbableLabel)((LabeledVector)this.candidates.getRow(i)).getLb();
                listOfCandidates.add(this.candidates.features(i).asArray());
            }
            exportModel.candidateFeatures = listOfCandidates;
            exportModel.distanceMeasure = this.distanceMeasure;
            exportModel.k = this.k;
            exportModel.weighted = this.weighted;
            exportModel.candidateLabels = labels;
            exportModel.centroindsStat = this.centroindsStat;
            File file = new File(path.toAbsolutePath().toString());
            mapper.writeValue(file, exportModel);
        }
        catch (IOException e) {
            // Existing best-effort contract: the interface method cannot throw checked exceptions.
            e.printStackTrace();
        }
    }

    /** JSON representation of {@link ANNClassificationModel}. */
    public static class ANNJSONExportModel
    extends JSONModel {
        /** Feature arrays of the candidates, in row order. */
        public List<double[]> candidateFeatures;

        /** Labels of the candidates, parallel to {@link #candidateFeatures}. */
        public ProbableLabel[] candidateLabels;

        /** Distance measure of the model. */
        public DistanceMeasure distanceMeasure;

        /** Number of neighbours used for prediction. */
        public int k;

        /** Weighting flag of the model. */
        public boolean weighted;

        /** Centroid statistics of the model. */
        public ANNClassificationTrainer.CentroidStat centroindsStat;

        /**
         * @param timestamp Export timestamp.
         * @param uid Model identifier.
         * @param modelClass Simple name of the exported model class.
         */
        public ANNJSONExportModel(Long timestamp, String uid, String modelClass) {
            super(timestamp, uid, modelClass);
        }

        /** Constructor used by Jackson. */
        @JsonCreator
        public ANNJSONExportModel() {
        }

        /**
         * Rebuilds the model from the loaded JSON fields.
         *
         * <p>The candidate set's column count is taken from the first feature row.
         *
         * @return Restored model.
         * @throws IllegalArgumentException If there are no candidates, or fewer labels than candidates.
         */
        @Override
        public ANNClassificationModel convert() {
            if (this.candidateFeatures == null || this.candidateFeatures.isEmpty()) {
                throw new IllegalArgumentException("Loaded list of candidates is empty. It should be not empty.");
            }
            int size = this.candidateFeatures.size();
            // Fail with a descriptive message instead of an NPE / ArrayIndexOutOfBoundsException below.
            if (this.candidateLabels == null || this.candidateLabels.length < size) {
                int lblCnt = this.candidateLabels == null ? 0 : this.candidateLabels.length;
                throw new IllegalArgumentException("Loaded list of candidate labels is shorter than the list of candidates [candidates=" + size + ", labels=" + lblCnt + "].");
            }
            double[] firstRow = this.candidateFeatures.get(0);
            LabeledVectorSet<LabeledVector> candidatesForANN = new LabeledVectorSet<LabeledVector>(size, firstRow.length);
            // Typed as LabeledVector[] so it matches the row type of LabeledVectorSet<LabeledVector>.
            LabeledVector[] data = new LabeledVector[size];
            for (int i = 0; i < size; ++i) {
                data[i] = new LabeledVector<ProbableLabel>(VectorUtils.of(this.candidateFeatures.get(i)), this.candidateLabels[i]);
            }
            candidatesForANN.setData(data);
            ANNClassificationModel mdl = new ANNClassificationModel(candidatesForANN, this.centroindsStat);
            mdl.withDistanceMeasure(this.distanceMeasure);
            mdl.withK(this.k);
            mdl.withWeighted(this.weighted);
            return mdl;
        }
    }
}

