/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.clustering.incremental;

import ch.akuhn.matrix.Matrix;
import ch.akuhn.matrix.SparseMatrix;
import gnu.trove.TIntCollection;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.log4j.Logger;
import org.openimaj.experiment.evaluation.cluster.analyser.FScoreAnalysis;
import org.openimaj.experiment.evaluation.cluster.analyser.FScoreClusterAnalyser;
import org.openimaj.math.matrix.MatlibMatrixUtils;
import org.openimaj.ml.clustering.IndexClusters;
import org.openimaj.ml.clustering.SparseMatrixClusterer;
import org.openimaj.util.pair.IntDoublePair;

/**
 * Clusters a sparse matrix incrementally by repeatedly running an underlying
 * {@link SparseMatrixClusterer} over a growing window of its leading rows and
 * columns.
 *
 * <p>After each round the clusters of the previous round are compared with the
 * clusters of the current round. A previous cluster whose best match score
 * reaches {@link #threshold} is considered complete: its rows are marked
 * inactive and left out of every later window. Rows and columns are selected
 * with the same indices, so the input is presumably a square (e.g. similarity)
 * matrix; confirm against callers.
 */
public class IncrementalSparseClusterer
implements SparseMatrixClusterer<IndexClusters> {
    /** Clusterer applied to every window. */
    private final SparseMatrixClusterer<? extends IndexClusters> clusterer;
    /** Number of new rows added to the window in each round. */
    private final int window;
    /** Minimum best-match score at which a cluster of the previous round is frozen. */
    protected double threshold;
    /** Maximum number of active rows per window; values {@code <= 0} disable the limit. */
    private final int maxwindow;
    private static final Logger logger = Logger.getLogger(IncrementalSparseClusterer.class);

    /**
     * Creates a clusterer with a threshold of {@code 1.0} and no limit on the
     * number of active rows.
     *
     * @param clusterer the clusterer applied to each window
     * @param window the number of rows added per round
     */
    public IncrementalSparseClusterer(SparseMatrixClusterer<? extends IndexClusters> clusterer, int window) {
        this.clusterer = clusterer;
        this.window = window;
        this.maxwindow = -1;
        this.threshold = 1.0;
    }

    /**
     * Creates a clusterer with an explicit stability threshold and no limit on
     * the number of active rows.
     *
     * @param clusterer the clusterer applied to each window
     * @param window the number of rows added per round
     * @param threshold the minimum best-match score at which a cluster is frozen
     */
    public IncrementalSparseClusterer(SparseMatrixClusterer<? extends IndexClusters> clusterer, int window, double threshold) {
        this.clusterer = clusterer;
        this.window = window;
        this.maxwindow = -1;
        this.threshold = threshold;
    }

    /**
     * Creates a clusterer with a threshold of {@code 1.0} and a cap on the
     * number of active rows per window.
     *
     * @param clusterer the clusterer applied to each window
     * @param window the number of rows added per round
     * @param maxwindow the requested cap; a positive value smaller than
     *        {@code window * 2} is raised to {@code window * 2}, and a
     *        non-positive value disables the cap
     */
    private IncrementalSparseClusterer(SparseMatrixClusterer<? extends IndexClusters> clusterer, int window, int maxwindow) {
        this.clusterer = clusterer;
        this.window = window;
        int cap = maxwindow;
        if (cap > 0 && cap < window * 2) {
            cap = window * 2;
        }
        if (cap <= 0) {
            cap = -1;
        }
        this.maxwindow = cap;
        this.threshold = 1.0;
    }

    /**
     * Clusters {@code data} window by window.
     *
     * <p>The configured window size is never modified by this call; when the
     * data has fewer rows than the window, a smaller step is used for this
     * invocation only.
     *
     * @param data the matrix to cluster
     * @return the frozen clusters followed by the non-empty clusters of the
     *         final round
     * @throws IllegalStateException if the configured window is not positive
     *         and {@code data} has at least one row (the window could never
     *         advance)
     */
    @Override
    public IndexClusters cluster(SparseMatrix data) {
        final int totalRows = data.rowCount();
        if (this.window <= 0 && totalRows > 0) {
            throw new IllegalStateException("window must be positive to cluster non-empty data, but was " + this.window);
        }
        // Per-call step: clamping the field itself would shrink the window for all later calls.
        final int step = Math.min(this.window, totalRows);

        // Explicit casts keep these calls bound to the same subMatrix overloads as before.
        SparseMatrix seen = (SparseMatrix)MatlibMatrixUtils.subMatrix((Matrix)data, 0, step, 0, step);
        int seenRows = step;
        TIntSet inactiveRows = new TIntHashSet(step);
        logger.debug("First clustering!: " + seen.rowCount() + "x" + seen.columnCount());
        IndexClusters oldClusters = this.clusterer.cluster(seen);
        logger.debug("First clusters:\n" + oldClusters);

        List<int[]> completedClusters = new ArrayList<int[]>();
        while (seenRows < totalRows) {
            int nextWindow = Math.min(seenRows + step, totalRows);
            int activeCount = nextWindow - inactiveRows.size();
            if (this.maxwindow > 0 && activeCount > this.maxwindow) {
                logger.debug(String.format("Window size (%d) without inactive (%d) = (%d), greater than maximum (%d)", nextWindow, inactiveRows.size(), activeCount, this.maxwindow));
                this.deactiveOldItemsAsNoise(nextWindow, inactiveRows, completedClusters);
            }
            WindowedSparseMatrix wsp = new WindowedSparseMatrix(data, nextWindow, inactiveRows);
            logger.debug("Clustering: " + wsp.window.rowCount() + "x" + wsp.window.columnCount());
            IndexClusters newClusters = this.clusterer.cluster(wsp.window);
            // The clusterer reports window-local indices; map them back to indices of data.
            wsp.correctClusters(newClusters);
            logger.debug("New clusters:\n" + newClusters);
            this.detectInactive(oldClusters, newClusters, inactiveRows, completedClusters);
            oldClusters = newClusters;
            seenRows += step;
            logger.debug("Seen rows: " + seenRows);
            logger.debug("Inactive rows: " + inactiveRows.size());
        }

        // Whatever is left in the final round is reported as-is, skipping emptied clusters.
        for (int[] cluster : oldClusters.clusters()) {
            if (cluster.length != 0) {
                completedClusters.add(cluster);
            }
        }
        return new IndexClusters(completedClusters);
    }

    /**
     * Forces rows to become inactive, lowest index first, until the number of
     * active rows in the next window no longer exceeds the cap. Every row
     * deactivated this way is recorded as a completed single-element cluster.
     *
     * @param nextwindow the size of the next window before inactive rows are removed
     * @param inactiveRows the inactive row set; extended in place
     * @param completedClusters the completed clusters; extended in place
     */
    private void deactiveOldItemsAsNoise(int nextwindow, TIntSet inactiveRows, List<int[]> completedClusters) {
        for (int row = 0; nextwindow - inactiveRows.size() > this.maxwindow; ++row) {
            if (inactiveRows.contains(row)) {
                continue;
            }
            logger.debug("Forcing the deactivation of: " + row);
            inactiveRows.add(row);
            completedClusters.add(new int[]{row});
        }
    }

    /**
     * Freezes every cluster of the previous round whose best match score in
     * the new round reaches {@link #threshold}.
     *
     * <p>The rows of a frozen cluster are added to {@code inactiveRows} and the
     * cluster is appended to {@code completedClusters}. When the threshold is
     * exactly {@code 1.0}, the matching cluster in {@code newClusters} is
     * replaced by an empty array; empty clusters are skipped both by
     * {@link #calculateStability} and when the final round is collected.
     *
     * @param oldClusters the clusters of the previous round
     * @param newClusters the clusters of the current round; may be modified in place
     * @param inactiveRows the inactive row set; extended in place
     * @param completedClusters the completed clusters; extended in place
     */
    protected void detectInactive(IndexClusters oldClusters, IndexClusters newClusters, TIntSet inactiveRows, List<int[]> completedClusters) {
        Map<Integer, IntDoublePair> stability = this.calculateStability(oldClusters, newClusters, inactiveRows);
        for (Map.Entry<Integer, IntDoublePair> entry : stability.entrySet()) {
            IntDoublePair bestMatch = entry.getValue();
            if (!(bestMatch.second >= this.threshold)) {
                continue;
            }
            int[] completedCluster = oldClusters.clusters()[entry.getKey()];
            inactiveRows.addAll(completedCluster);
            completedClusters.add(completedCluster);
            if (this.threshold == 1.0) {
                newClusters.clusters()[bestMatch.first] = new int[0];
            }
        }
    }

    /**
     * Scores how well each non-empty cluster of {@code c1} is reproduced by the
     * clusters of {@code c2}.
     *
     * <p>Rows listed in {@code inactiveRows} are removed from the {@code c1}
     * cluster before it is compared. Two single-element clusters score
     * {@code 1.0} when they hold the same row and {@code 0.0} otherwise; every
     * other pair is scored with {@link FScoreClusterAnalyser}. {@code NaN}
     * scores are ignored.
     *
     * @param c1 the reference clusters
     * @param c2 the clusters to match against
     * @param inactiveRows rows excluded from the reference clusters
     * @return for each index of a non-empty cluster in {@code c1}, the index of
     *         the best matching cluster in {@code c2} ({@code -1} if no cluster
     *         scored above zero) paired with that best score
     */
    protected Map<Integer, IntDoublePair> calculateStability(IndexClusters c1, IndexClusters c2, TIntSet inactiveRows) {
        Map<Integer, IntDoublePair> stability = new HashMap<Integer, IntDoublePair>();
        int[][] clusters1 = c1.clusters();
        int[][] clusters2 = c2.clusters();
        for (int i = 0; i < clusters1.length; ++i) {
            int[] reference = clusters1[i];
            if (reference.length == 0) {
                continue;
            }
            TIntArrayList stillActive = new TIntArrayList(reference.length);
            for (int row : reference) {
                if (!inactiveRows.contains(row)) {
                    stillActive.add(row);
                }
            }
            int[][] correct = new int[][]{stillActive.toArray()};

            double bestScore = 0.0;
            int bestIndex = -1;
            for (int j = 0; j < clusters2.length; ++j) {
                int[][] estimated = new int[][]{clusters2[j]};
                double score;
                if (correct[0].length == 1 && estimated[0].length == 1) {
                    score = correct[0][0] == estimated[0][0] ? 1.0 : 0.0;
                } else {
                    score = ((FScoreAnalysis)new FScoreClusterAnalyser().analyse(correct, estimated)).score();
                }
                if (!Double.isNaN(score) && score > bestScore) {
                    bestScore = score;
                    bestIndex = j;
                }
            }
            stability.put(i, IntDoublePair.pair(bestIndex, bestScore));
        }
        logger.debug(String.format("The stability is:\n%s", stability));
        return stability;
    }

    /**
     * Clusters {@code data} and returns the clusters as arrays of row indices.
     *
     * @param data the matrix to cluster
     * @return the clusters produced by {@link #cluster(SparseMatrix)}
     */
    public int[][] performClustering(SparseMatrix data) {
        return this.cluster(data).clusters();
    }

    /**
     * The sub-matrix formed by the active rows and columns among the first
     * {@code nextwindow} indices of a matrix, together with the mapping from
     * each window-local index back to its index in the full matrix.
     */
    class WindowedSparseMatrix {
        /** The sub-matrix restricted to active indices. */
        SparseMatrix window;
        /** Maps a window-local index to the corresponding index in the full matrix. */
        Map<Integer, Integer> indexCorrection;

        /**
         * @param sm the full matrix
         * @param nextwindow the number of leading indices to consider
         * @param inactive indices to leave out of the window
         */
        public WindowedSparseMatrix(SparseMatrix sm, int nextwindow, TIntSet inactive) {
            TIntArrayList active = new TIntArrayList(nextwindow);
            this.indexCorrection = new HashMap<Integer, Integer>();
            for (int original = 0; original < nextwindow; ++original) {
                if (inactive.contains(original)) {
                    continue;
                }
                // The position in the active list is the index the clusterer will see.
                this.indexCorrection.put(active.size(), original);
                active.add(original);
            }
            this.window = (SparseMatrix)MatlibMatrixUtils.subMatrix((Matrix)sm, (TIntCollection)active, (TIntCollection)active);
        }

        /**
         * Rewrites, in place, every window-local index in the given clusters to
         * the corresponding index in the full matrix.
         *
         * @param clstrs clusters expressed in window-local indices
         */
        public void correctClusters(IndexClusters clstrs) {
            for (int[] cluster : clstrs.clusters()) {
                for (int j = 0; j < cluster.length; ++j) {
                    cluster[j] = this.indexCorrection.get(cluster[j]);
                }
            }
        }
    }
}

