package org.openimaj.ml.clustering.kmeans;

import com.rits.cloning.Cloner;
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import org.openimaj.data.ArrayBackedDataSource;
import org.openimaj.data.DataSource;
import org.openimaj.feature.FeatureVector;
import org.openimaj.knn.ObjectNearestNeighbours;
import org.openimaj.knn.ObjectNearestNeighboursExact;
import org.openimaj.knn.ObjectNearestNeighboursProvider;
import org.openimaj.ml.clustering.FeatureVectorCentroidsResult;
import org.openimaj.ml.clustering.IndexClusters;
import org.openimaj.ml.clustering.SpatialClusterer;
import org.openimaj.ml.clustering.assignment.HardAssigner;
import org.openimaj.ml.clustering.assignment.hard.ExactFeatureVectorAssigner;
import org.openimaj.ml.clustering.kmeans.FeatureVectorKMeansInit;
import org.openimaj.util.comparator.DistanceComparator;
import org.openimaj.util.pair.IntFloatPair;

/* loaded from: input_file:org/openimaj/ml/clustering/kmeans/FeatureVectorKMeans.class */
public class FeatureVectorKMeans<T extends FeatureVector> implements SpatialClusterer<FeatureVectorCentroidsResult<T>, T> {
    private FeatureVectorKMeansInit<T> init;
    private KMeansConfiguration<ObjectNearestNeighbours<T>, T> conf;
    private Random rng;

    /**
     * Task that assigns one contiguous block of rows of a data source to their
     * nearest centroids and adds the rows into accumulators that are shared
     * between all jobs handed the same arrays.
     *
     * @param <T> type of feature vector being processed
     */
    public static class CentroidAssignmentJob<T extends FeatureVector> implements Callable<Boolean> {
        private final DataSource<T> ds;
        private final int startRow;
        private final int stopRow;
        private final ObjectNearestNeighbours<T> nno;
        private final double[][] centroids_accum;
        private final int[] counts;

        /**
         * @param dataSource              source the rows are read from
         * @param i                       first row of the block (passed as the start of the range read)
         * @param i2                      end of the range read; the block holds {@code i2 - i} rows
         * @param objectNearestNeighbours nearest-neighbour object used to look up the index for each row
         * @param dArr                    per-index running sums, updated in place
         * @param iArr                    per-index row counts, updated in place
         */
        public CentroidAssignmentJob(DataSource<T> dataSource, int i, int i2, ObjectNearestNeighbours<T> objectNearestNeighbours, double[][] dArr, int[] iArr) {
            this.ds = dataSource;
            this.startRow = i;
            this.stopRow = i2;
            this.nno = objectNearestNeighbours;
            this.centroids_accum = dArr;
            this.counts = iArr;
        }

        /**
         * Reads the block, finds the nearest index for every row and adds each
         * row to the matching accumulator and count.
         *
         * @return {@code true} when the block was processed, {@code false}
         *         when an exception interrupted processing (the stack trace is
         *         printed and the accumulators may be partially updated)
         */
        @Override // java.util.concurrent.Callable
        public Boolean call() {
            try {
                final int dims = ds.getData(0).length();
                final T[] rows = ds.createTemporaryArray(stopRow - startRow);
                ds.getData(startRow, stopRow, rows);

                final int[] nearest = new int[rows.length];
                final float[] distances = new float[rows.length];
                nno.searchNN(rows, nearest, distances);

                // The accumulators are shared with other jobs, so all updates
                // happen while holding the monitor of the sums array.
                synchronized (centroids_accum) {
                    for (int r = 0; r < rows.length; r++) {
                        final int k = nearest[r];
                        final double[] values = rows[r].asDoubleVector();
                        final double[] sum = centroids_accum[k];
                        for (int d = 0; d < dims; d++) {
                            sum[d] += values[d];
                        }
                        counts[k]++;
                    }
                }
                return Boolean.TRUE;
            } catch (Exception e) {
                // Best effort: keep reporting the problem on stderr as before,
                // but no longer claim success for a block that failed.
                e.printStackTrace();
                return Boolean.FALSE;
            }
        }
    }

    /**
     * Centroid result that also exposes a nearest-neighbour object and uses
     * that object's distance comparator for its default hard assigner.
     *
     * @param <T> type of feature vector held as centroids
     */
    public static class Result<T extends FeatureVector> extends FeatureVectorCentroidsResult<T> implements ObjectNearestNeighboursProvider<T> {
        /** Nearest-neighbour object; left unset by the constructor. */
        protected ObjectNearestNeighbours<T> nn;

        /** Not instantiable from outside the enclosing class. */
        private Result() {
        }

        /**
         * @return the nearest-neighbour object currently stored in {@link #nn}
         */
        public ObjectNearestNeighbours<T> getNearestNeighbours() {
            return nn;
        }

        /**
         * Builds an exact assigner over this result using the distance
         * comparator of the stored nearest-neighbour object.
         *
         * @return a new {@code ExactFeatureVectorAssigner}
         */
        @Override // org.openimaj.ml.clustering.FeatureVectorCentroidsResult, org.openimaj.ml.clustering.SpatialClusters
        /* renamed from: defaultHardAssigner */
        public HardAssigner<T, float[], IntFloatPair> defaultHardAssigner2() {
            final ObjectNearestNeighbours<T> neighbours = nn;
            return new ExactFeatureVectorAssigner(this, neighbours.distanceComparator());
        }
    }

    /**
     * Creates a clusterer with the given configuration, an unseeded random
     * number generator and the {@code FeatureVectorKMeansInit.RANDOM}
     * initialisation strategy.
     *
     * @param kMeansConfiguration configuration to use
     */
    public FeatureVectorKMeans(KMeansConfiguration<ObjectNearestNeighbours<T>, T> kMeansConfiguration) {
        conf = kMeansConfiguration;
        rng = new Random();
        init = new FeatureVectorKMeansInit.RANDOM();
    }

    /**
     * Creates a clusterer backed by a default-constructed
     * {@code KMeansConfiguration}.
     */
    protected FeatureVectorKMeans() {
        this(new KMeansConfiguration<ObjectNearestNeighbours<T>, T>());
    }

    /**
     * @return the initialisation strategy currently in use
     */
    public FeatureVectorKMeansInit<T> getInit() {
        return init;
    }

    /**
     * Replaces the initialisation strategy.
     *
     * @param featureVectorKMeansInit the strategy to use from now on
     */
    public void setInit(FeatureVectorKMeansInit<T> featureVectorKMeansInit) {
        init = featureVectorKMeansInit;
    }

    /**
     * Replaces the random number generator.
     *
     * @param j seed for the new generator; a negative value selects an
     *          unseeded {@link Random} instead
     */
    public void seed(long j) {
        rng = (j >= 0) ? new Random(j) : new Random();
    }

    /**
     * Clusters the items of a list by copying them into an array and
     * delegating to the array-based overload.
     *
     * <p>The array's component type is the runtime class of the first
     * element, so the list must not be empty.
     *
     * @param list items to cluster
     * @return the clustering result
     */
    @SuppressWarnings("unchecked")
    public FeatureVectorCentroidsResult<T> cluster(List<T> list) {
        final Class<?> elementType = list.get(0).getClass();
        final T[] target = (T[]) Array.newInstance(elementType, list.size());
        return cluster(list.toArray(target));
    }

    /**
     * Clusters an array of items. The array is wrapped in an
     * {@code ArrayBackedDataSource} that shares this clusterer's random
     * number generator and reports the length of its first element as the
     * number of dimensions.
     *
     * @param tArr items to cluster
     * @return the result, with its nearest-neighbour object built from the
     *         final centroids
     * @throws RuntimeException wrapping any exception raised while clustering
     */
    @Override // org.openimaj.ml.clustering.SpatialClusterer
    public FeatureVectorCentroidsResult<T> cluster(T[] tArr) {
        try {
            final DataSource<T> source = new ArrayBackedDataSource<T>(tArr, this.rng) { // from class: org.openimaj.ml.clustering.kmeans.FeatureVectorKMeans.1
                public int numDimensions() {
                    return data[0].length();
                }
            };
            final Result<T> clustered = cluster(source, conf.K);
            clustered.nn = conf.factory.create(clustered.centroids);
            return clustered;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Clusters the given items and then assigns each of them with the
     * result's default hard assigner.
     *
     * @param tArr items to cluster
     * @return the clusters produced by {@code IndexClusters} from the
     *         per-item assignments
     */
    public int[][] performClustering(T[] tArr) {
        final FeatureVectorCentroidsResult<T> clustered = cluster(tArr);
        final int[] assignments = clustered.defaultHardAssigner2().assign(tArr);
        final IndexClusters grouped = new IndexClusters(assignments);
        return grouped.clusters();
    }

    /**
     * List variant of {@code performClustering}: copies the list into an
     * array whose component type is the runtime class of the first element
     * (so the list must not be empty), clusters that array and assigns each
     * item with the result's default hard assigner.
     *
     * @param list items to cluster
     * @return the clusters produced by {@code IndexClusters} from the
     *         per-item assignments
     */
    @SuppressWarnings("unchecked")
    public int[][] performClustering(List<T> list) {
        final Class<?> elementType = list.get(0).getClass();
        final T[] items = list.toArray((T[]) Array.newInstance(elementType, list.size()));
        final FeatureVectorCentroidsResult<T> clustered = cluster(items);
        final int[] assignments = clustered.defaultHardAssigner2().assign(items);
        return new IndexClusters(assignments).clusters();
    }

    /**
     * Allocates {@code i} centroids from the data source, lets the
     * initialisation strategy fill them and then runs the iterative
     * clustering on them.
     *
     * @param dataSource source of the items to cluster
     * @param i          number of centroids to allocate
     * @return the result holding the final centroids
     * @throws Exception if initialisation or clustering fails
     */
    protected Result<T> cluster(DataSource<T> dataSource, int i) throws Exception {
        final Result<T> outcome = new Result<>();
        final T[] seeds = dataSource.createTemporaryArray(i);
        outcome.centroids = seeds;
        init.initKMeans(dataSource, seeds);
        cluster(dataSource, outcome);
        return outcome;
    }

    /**
     * Runs {@code conf.niters} update iterations over the centroids already
     * stored in {@code result}, modifying them in place.
     *
     * <p>Each iteration builds a nearest-neighbour object over the current
     * centroids, splits the data source into blocks of {@code conf.blockSize}
     * rows, runs one {@link CentroidAssignmentJob} per block on the
     * configured thread pool and finally replaces every centroid by the mean
     * of the rows assigned to it. A centroid that received no rows is
     * replaced by a deep clone of a random row of the data source.
     *
     * @param dataSource source of the items to cluster
     * @param result     holder of the centroids to refine
     * @throws IllegalArgumentException if {@code conf.blockSize} is not positive
     * @throws InterruptedException     if interrupted while waiting for the jobs;
     *                                  the thread's interrupt flag is restored
     * @throws Exception                on any other failure
     */
    protected void cluster(DataSource<T> dataSource, Result<T> result) throws Exception {
        final int blockSize = this.conf.blockSize;
        if (blockSize <= 0) {
            // A non-positive block size would never advance the block loop below.
            throw new IllegalArgumentException("blockSize must be positive, but was " + blockSize);
        }

        final T[] centroids = result.centroids;
        final int numClusters = centroids.length;
        final int numDims = centroids[0].length();
        final int numRows = dataSource.size();
        final double[][] sums = new double[numClusters][numDims];
        final int[] counts = new int[numClusters];
        final ExecutorService pool = this.conf.threadpool;

        for (int iter = 0; iter < this.conf.niters; iter++) {
            System.err.println("Iteration " + iter);

            // Reset the accumulators shared by this iteration's jobs.
            for (final double[] sum : sums) {
                Arrays.fill(sum, 0.0d);
            }
            Arrays.fill(counts, 0);

            final ObjectNearestNeighbours<T> nearest = this.conf.factory.create(centroids);

            final List<CentroidAssignmentJob<T>> jobs = new ArrayList<>();
            int start = 0;
            while (start < numRows) {
                // Computed from the remaining row count so the end index cannot overflow.
                final int stop = start + Math.min(blockSize, numRows - start);
                jobs.add(new CentroidAssignmentJob<>(dataSource, start, stop, nearest, sums, counts));
                start = stop;
            }

            try {
                pool.invokeAll(jobs);
            } catch (InterruptedException e) {
                // Keep the interrupt visible to callers that wrap or swallow the exception.
                Thread.currentThread().interrupt();
                throw e;
            }

            for (int k = 0; k < numClusters; k++) {
                if (counts[k] == 0) {
                    // Empty cluster: restart it from a random row of the data.
                    counts[k] = 1;
                    final T[] random = dataSource.createTemporaryArray(1);
                    dataSource.getRandomRows(random);
                    centroids[k] = new Cloner().deepClone(random[0]);
                } else {
                    for (int d = 0; d < numDims; d++) {
                        centroids[k].setFromDouble(d, sums[k][d] / counts[k]);
                    }
                }
            }
        }
    }

    /**
     * Narrows a double to a float with a plain cast.
     *
     * @param d value to convert
     * @return {@code d} cast to {@code float}
     */
    protected float roundFloat(double d) {
        final float narrowed = (float) d;
        return narrowed;
    }

    /**
     * Identity conversion: a double needs no rounding to be stored as a double.
     *
     * @param d value to convert
     * @return {@code d} unchanged
     */
    protected double roundDouble(double d) {
        return d;
    }

    /**
     * Rounds a double to the nearest long using {@link Math#round(double)}.
     *
     * @param d value to round
     * @return the rounded value
     */
    protected long roundLong(double d) {
        final long rounded = Math.round(d);
        return rounded;
    }

    /**
     * Rounds a double with {@link Math#round(double)} and narrows the
     * resulting long to an int with a plain cast (no range check).
     *
     * @param d value to round
     * @return the rounded value cast to {@code int}
     */
    protected int roundInt(double d) {
        final long rounded = Math.round(d);
        return (int) rounded;
    }

    /**
     * Clusters the items of a data source into {@code conf.K} centroids.
     *
     * @param dataSource source of the items to cluster
     * @return the result, with its nearest-neighbour object built from the
     *         final centroids
     * @throws RuntimeException wrapping any exception raised while clustering
     */
    @Override // org.openimaj.ml.clustering.SpatialClusterer
    public FeatureVectorCentroidsResult<T> cluster(DataSource<T> dataSource) {
        final Result<T> clustered;
        try {
            clustered = cluster(dataSource, conf.K);
            clustered.nn = conf.factory.create(clustered.centroids);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        return clustered;
    }

    /**
     * @return the configuration currently in use
     */
    public KMeansConfiguration<ObjectNearestNeighbours<T>, T> getConfiguration() {
        return conf;
    }

    /**
     * Replaces the configuration.
     *
     * @param kMeansConfiguration the configuration to use from now on
     */
    public void setConfiguration(KMeansConfiguration<ObjectNearestNeighbours<T>, T> kMeansConfiguration) {
        conf = kMeansConfiguration;
    }

    /**
     * Convenience factory for a clusterer whose configuration uses an
     * {@code ObjectNearestNeighboursExact.Factory} built on the given
     * distance comparator.
     *
     * @param i                  first argument of the configuration
     *                           (presumably the number of clusters; confirm
     *                           against {@code KMeansConfiguration})
     * @param distanceComparator comparator used by the nearest-neighbour factory
     * @param <T>                type of feature vector to cluster
     * @return a new clusterer
     */
    public static <T extends FeatureVector> FeatureVectorKMeans<T> createExact(int i, DistanceComparator<? super T> distanceComparator) {
        final ObjectNearestNeighboursExact.Factory<T> exactFactory =
                new ObjectNearestNeighboursExact.Factory<T>(distanceComparator);
        final KMeansConfiguration<ObjectNearestNeighbours<T>, T> configuration =
                new KMeansConfiguration<ObjectNearestNeighbours<T>, T>(i, exactFactory);
        return new FeatureVectorKMeans<T>(configuration);
    }

    /**
     * Convenience factory for a clusterer whose configuration uses an
     * {@code ObjectNearestNeighboursExact.Factory} built on the given
     * distance comparator, with one extra integer setting.
     *
     * @param i                  first argument of the configuration
     *                           (presumably the number of clusters; confirm
     *                           against {@code KMeansConfiguration})
     * @param distanceComparator comparator used by the nearest-neighbour factory
     * @param i2                 third argument of the configuration
     *                           (presumably the iteration count; confirm
     *                           against {@code KMeansConfiguration})
     * @param <T>                type of feature vector to cluster
     * @return a new clusterer
     */
    public static <T extends FeatureVector> FeatureVectorKMeans<T> createExact(int i, DistanceComparator<? super T> distanceComparator, int i2) {
        final ObjectNearestNeighboursExact.Factory<T> exactFactory =
                new ObjectNearestNeighboursExact.Factory<T>(distanceComparator);
        final KMeansConfiguration<ObjectNearestNeighbours<T>, T> configuration =
                new KMeansConfiguration<ObjectNearestNeighbours<T>, T>(i, exactFactory, i2);
        return new FeatureVectorKMeans<T>(configuration);
    }

    /**
     * @return a short description made of the simple class name, the
     *         configured {@code K} and the simple class name of the
     *         nearest-neighbour factory
     */
    @Override
    public String toString() {
        final String selfName = getClass().getSimpleName();
        final int k = conf.K;
        final String factoryName = conf.getNearestNeighbourFactory().getClass().getSimpleName();
        return String.format("%s: {K=%d, NN=%s}", selfName, k, factoryName);
    }
}
