package marytts.machinelearning;

import java.awt.Color;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import javax.swing.JFrame;
import marytts.signalproc.adaptation.codebook.WeightedCodebookMapperParams;
import marytts.signalproc.display.FunctionGraph;
import marytts.util.math.Polynomial;

/* loaded from: input_file:marytts/machinelearning/PolynomialHierarchicalClusteringTrainer.class */
/**
 * Agglomerative (bottom-up) hierarchical clustering of {@link Polynomial}s.
 *
 * <p>Every input polynomial starts in its own singleton cluster. Clusters are then merged one pair
 * at a time, using the selected linkage criterion ("Short", "Complete" or "Average"), until the
 * requested number of clusters remains. Pairwise sample scores are computed once in the constructor
 * with {@code Polynomial.polynomialPearsonProductMomentCorr} and cached in a lookup table.
 *
 * <p>Data points are identified by the decimal string of their index in {@link #polynomials}.
 */
public class PolynomialHierarchicalClusteringTrainer {
    /** Magnitude of the sentinel used as the initial best score when searching for a pair to merge. */
    private static final double INFINITE = 1.0E7d;
    /** Target cluster count used by the no-argument {@code clustering()} overload and by {@link #main}. */
    private static final int CLUSTER_DEFAULT_SIZE = 5;
    /** IDs ("0", "1", ...) of all data points. */
    private HashSet<String> dataPointSet;
    /** Current clusters; shrinks by one with every merge step. */
    private ArrayList<Cluster> clusterList;
    /** Cached pairwise scores keyed by "i_j". */
    private HashMap<String, Double> distanceTableMap;
    /** When true, the merge step picks the pair with the smallest score (see setSimilarityMeasure). */
    private boolean isSimilarityMeasure;
    /** Initial bound for the best-pair search: +INFINITE or -INFINITE depending on isSimilarityMeasure. */
    private double MINDISTANCE;
    Polynomial[] polynomials;

    /**
     * A mutable group of data points, each identified by the string form of the index of a
     * polynomial in the enclosing trainer.
     */
    public class Cluster {
        private final ArrayList<String> dataPoints;

        /**
         * Creates a cluster backed directly (not copied) by the given list.
         *
         * @throws NullPointerException if {@code arrayList} is null
         */
        public Cluster(ArrayList<String> arrayList) {
            if (arrayList == null) {
                throw new NullPointerException("Input dataset for a cluster should not be null");
            }
            this.dataPoints = arrayList;
        }

        /** Returns the live (not copied) list of data point IDs in this cluster. */
        public ArrayList<String> getAllDataPoints() {
            return this.dataPoints;
        }

        /**
         * Appends all data points of {@code cluster} to this cluster. The argument itself is not
         * modified.
         *
         * @throws NullPointerException if {@code cluster} is null
         * @throws IllegalArgumentException if {@code cluster} is this cluster
         */
        public void mergeCluster(Cluster cluster) {
            if (cluster == null) {
                throw new NullPointerException("Input cluster should not be null");
            }
            if (cluster == this) {
                // Previously this iterated over the list while appending to it (ConcurrentModificationException).
                throw new IllegalArgumentException("A cluster cannot be merged with itself");
            }
            this.dataPoints.addAll(cluster.getAllDataPoints());
        }
    }

    /**
     * Computes all pairwise scores and puts every sample into its own cluster.
     *
     * @param polynomialArr samples to cluster; must contain more than two elements
     * @throws NullPointerException if {@code polynomialArr} is null
     * @throws IllegalArgumentException if {@code polynomialArr} has two or fewer elements
     */
    public PolynomialHierarchicalClusteringTrainer(Polynomial[] polynomialArr) {
        if (polynomialArr == null) {
            throw new NullPointerException("Input polynomial array should not be null");
        }
        if (polynomialArr.length <= 2) {
            throw new IllegalArgumentException("Number of samples for clustering should be more than two.");
        }
        this.dataPointSet = new HashSet<>();
        this.distanceTableMap = new HashMap<>();
        this.clusterList = new ArrayList<>();
        this.polynomials = polynomialArr;
        setSimilarityMeasure(true);
        computeSampleDistances();
        initializeClustering();
    }

    /**
     * Throws unless {@code linkage} is one of "Short", "Complete" or "Average".
     */
    private static void checkLinkage(String linkage) {
        if (!"Short".equals(linkage) && !"Complete".equals(linkage) && !"Average".equals(linkage)) {
            throw new IllegalArgumentException("Only Short, Complete, or Average linkage clustering supported");
        }
    }

    /**
     * Returns the cached score for the pair (a, b), trying key "a_b" first and then "b_a"; null if
     * neither key is present.
     */
    private Double lookupDistance(String a, String b) {
        Double value = this.distanceTableMap.get(a + "_" + b);
        return value != null ? value : this.distanceTableMap.get(b + "_" + a);
    }

    /**
     * Computes the linkage score between two clusters from the cached pairwise scores.
     *
     * <p>NaN pairwise scores and pairs missing from the table are ignored for every linkage type.
     *
     * @param cluster first cluster
     * @param cluster2 second cluster
     * @param str "Short" (minimum pairwise score), "Complete" (maximum) or "Average" (mean)
     * @return the linkage score, or NaN if no usable pairwise score exists
     * @throws NullPointerException if either cluster is null
     * @throws IllegalArgumentException if {@code str} is not a supported linkage
     */
    private double getClusterDistance(Cluster cluster, Cluster cluster2, String str) {
        if (cluster == null || cluster2 == null) {
            throw new NullPointerException("Input clusters should not be null");
        }
        checkLinkage(str);
        double min = Double.NaN;
        double max = Double.NaN;
        double sum = 0.0d;
        int count = 0;
        for (String a : cluster.getAllDataPoints()) {
            for (String b : cluster2.getAllDataPoints()) {
                Double score = lookupDistance(a, b);
                if (score == null || Double.isNaN(score)) {
                    // Skip NaN consistently for all linkages (Average used to propagate it).
                    continue;
                }
                double value = score;
                if (count == 0 || value < min) {
                    min = value;
                }
                if (count == 0 || value > max) {
                    max = value;
                }
                sum += value;
                count++;
            }
        }
        if ("Short".equals(str)) {
            return min;
        }
        if ("Complete".equals(str)) {
            return max;
        }
        return count == 0 ? Double.NaN : sum / count;
    }

    /** Creates one singleton cluster per data point ID. */
    private void initializeClustering() {
        assert this.dataPointSet != null;
        assert this.clusterList != null;
        for (String point : this.dataPointSet) {
            ArrayList<String> members = new ArrayList<>();
            members.add(point);
            this.clusterList.add(new Cluster(members));
        }
    }

    /**
     * Registers every sample ID and caches the score for every ordered pair (i, j), including
     * i == j, under the key "i_j".
     */
    private void computeSampleDistances() {
        assert this.polynomials != null;
        assert this.polynomials.length > 2;
        assert this.dataPointSet != null;
        assert this.distanceTableMap != null;
        int n = this.polynomials.length;
        for (int i = 0; i < n; i++) {
            this.dataPointSet.add(String.valueOf(i));
            for (int j = 0; j < n; j++) {
                double score = Polynomial.polynomialPearsonProductMomentCorr(
                        this.polynomials[i].coeffs, this.polynomials[j].coeffs);
                this.distanceTableMap.put(i + "_" + j, score);
            }
        }
    }

    private boolean hasSimilarityMeasure() {
        return this.isSimilarityMeasure;
    }

    private double getClusterDistance(Cluster cluster, Cluster cluster2) {
        return getClusterDistance(cluster, cluster2, "Average");
    }

    private void clustering() {
        clustering(CLUSTER_DEFAULT_SIZE, "Average");
    }

    private void clustering(int i) {
        clustering(i, "Average");
    }

    /**
     * Repeatedly merges the best-scoring pair of clusters until at most {@code i} clusters remain,
     * then prints the resulting clusters to stdout.
     *
     * @param i target number of clusters
     * @param str linkage type passed to getClusterDistance
     * @throws IllegalStateException if no pair of clusters has a usable (non-NaN, within-bound) score
     */
    private void clustering(int i, String str) {
        assert this.clusterList != null;
        while (this.clusterList.size() > i) {
            int size = this.clusterList.size();
            // Reset per pass: stale indices from a previous pass could point at the wrong
            // clusters or make a cluster merge with itself.
            int bestFirst = -1;
            int bestSecond = -1;
            double bestScore = this.MINDISTANCE;
            for (int a = 0; a < size; a++) {
                Cluster first = this.clusterList.get(a);
                for (int b = a + 1; b < size; b++) {
                    double score = getClusterDistance(first, this.clusterList.get(b), str);
                    // NOTE(review): with a similarity measure (e.g. a correlation) one would normally
                    // merge the MOST similar pair, i.e. the largest score; this picks the smallest.
                    // Behavior kept as-is; confirm the intended semantics of the score.
                    boolean better = hasSimilarityMeasure() ? score < bestScore : score > bestScore;
                    if (better) {
                        bestScore = score;
                        bestFirst = a;
                        bestSecond = b;
                    }
                }
            }
            if (bestFirst < 0) {
                throw new IllegalStateException(
                        "No mergeable cluster pair found among " + size + " clusters; pairwise scores may all be NaN");
            }
            this.clusterList.get(bestFirst).mergeCluster(this.clusterList.get(bestSecond));
            // bestSecond > bestFirst, so removing by index does not shift bestFirst.
            this.clusterList.remove(bestSecond);
        }
        printClusterData();
    }

    /** Prints the number of clusters and the member IDs of each cluster to stdout. */
    private void printClusterData() {
        assert this.clusterList != null;
        System.out.println("Total No of Clusters: " + this.clusterList.size());
        int number = 1;
        for (Cluster cluster : this.clusterList) {
            System.out.println("Cluster Number : " + number);
            for (String point : cluster.getAllDataPoints()) {
                System.out.print(point + " ");
            }
            System.out.println();
            number++;
        }
    }

    /**
     * Sets the comparison direction for merging. {@code true}: pick the smallest score (bound starts
     * at +INFINITE); {@code false}: pick the largest score (bound starts at -INFINITE).
     */
    private void setSimilarityMeasure(boolean z) {
        this.isSimilarityMeasure = z;
        this.MINDISTANCE = this.isSimilarityMeasure ? INFINITE : -INFINITE;
    }

    /**
     * Clusters the samples into {@code i} clusters and returns each cluster with its mean polynomial.
     *
     * <p>This mutates the trainer's cluster state, so a later call can only reduce the cluster count
     * further.
     *
     * @param i target number of clusters; must be at least 1 and less than the current cluster count
     * @param str linkage type: "Short", "Complete" or "Average"
     * @return one {@link PolynomialCluster} per resulting cluster
     * @throws IllegalArgumentException if {@code i} is out of range or {@code str} is unsupported
     */
    public PolynomialCluster[] train(int i, String str) {
        if (i < 1) {
            throw new IllegalArgumentException("target cluster size should be at least 1, but is " + i);
        }
        if (this.clusterList.size() <= i) {
            throw new IllegalArgumentException("taget cluster size should be less than number of samples");
        }
        checkLinkage(str);
        clustering(i, str);
        assert this.clusterList.size() == i
                : "After clustering, number of clusters and the target cluster size should be same, but now the number of clusters are "
                        + this.clusterList.size();
        PolynomialCluster[] result = new PolynomialCluster[i];
        for (int c = 0; c < i; c++) {
            ArrayList<String> ids = this.clusterList.get(c).getAllDataPoints();
            Polynomial[] members = new Polynomial[ids.size()];
            for (int m = 0; m < members.length; m++) {
                members[m] = this.polynomials[Integer.parseInt(ids.get(m))];
            }
            result[c] = new PolynomialCluster(Polynomial.mean(members), members);
        }
        return result;
    }

    /**
     * Demo: clusters 100 random cubic polynomials and shows each cluster (mean curve plus gray
     * member curves) in a window, one cluster every five seconds.
     */
    public static void main(String[] strArr) {
        final int numSamples = 100;
        final int order = 3;
        Polynomial[] polynomialArr = new Polynomial[numSamples];
        for (int i = 0; i < numSamples; i++) {
            double[] coeffs = new double[order + 1];
            for (int k = 0; k < coeffs.length; k++) {
                coeffs[k] = Math.random();
            }
            polynomialArr[i] = new Polynomial(coeffs);
        }
        PolynomialCluster[] train = new PolynomialHierarchicalClusteringTrainer(polynomialArr)
                .train(CLUSTER_DEFAULT_SIZE, "Average");
        // NOTE(review): DEFAULT_DISTANCE_MEAN appears to be used purely as a numeric constant here
        // (presumably 0.0, the left end of the x-range); likely a decompiler substitution — confirm.
        FunctionGraph functionGraph = new FunctionGraph(WeightedCodebookMapperParams.DEFAULT_DISTANCE_MEAN, 1.0d, new double[1]);
        functionGraph.setYMinMax(WeightedCodebookMapperParams.DEFAULT_DISTANCE_MEAN, 5.0d);
        functionGraph.setPrimaryDataSeriesStyle(Color.BLUE, 2, 1);
        JFrame showInJFrame = functionGraph.showInJFrame("", false, true);
        for (int c = 0; c < train.length; c++) {
            double[] meanValues = train[c].getMeanPolynomial()
                    .generatePolynomialValues(100, WeightedCodebookMapperParams.DEFAULT_DISTANCE_MEAN, 1.0d);
            functionGraph.updateData(WeightedCodebookMapperParams.DEFAULT_DISTANCE_MEAN, 1.0d / meanValues.length, meanValues);
            Polynomial[] clusterMembers = train[c].getClusterMembers();
            for (Polynomial polynomial : clusterMembers) {
                functionGraph.addDataSeries(
                        polynomial.generatePolynomialValues(meanValues.length, WeightedCodebookMapperParams.DEFAULT_DISTANCE_MEAN, 1.0d),
                        Color.GRAY, 1, -1);
                showInJFrame.repaint();
            }
            showInJFrame.setTitle("Cluster " + (c + 1) + " of " + train.length + ": " + clusterMembers.length + " members");
            showInJFrame.repaint();
            try {
                Thread.sleep(5000L);
            } catch (InterruptedException e) {
                // Preserve the interrupt status and stop the demo instead of silently continuing.
                Thread.currentThread().interrupt();
                return;
            }
        }
    }
}
