/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.statistics.bayesian;

import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.statistics.bayesian.DirichletProcessMixtureModel;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * A {@link DirichletProcessMixtureModel} that runs its per-iteration
 * observation-assignment and cluster-update steps on a thread pool.
 * <p>
 * Observation assignment splits {@code data} into contiguous chunks, with one
 * {@link ObservationAssignmentTask} per chunk. Cluster updating uses one
 * {@link ClusterUpdaterTask} per entry of the cluster-assignment list. Both
 * task lists are cached across iterations and rebuilt only when needed.
 * <p>
 * NOTE(review): the cached tasks are inner-class instances bound to the model
 * that created them. If this model is ever shallow-cloned, confirm that the
 * clone does not keep running the original model's tasks.
 *
 * @param <ObservationType> type of the observations being clustered
 */
public class ParallelDirichletProcessMixtureModel<ObservationType>
    extends DirichletProcessMixtureModel<ObservationType>
    implements ParallelAlgorithm {

    /** Thread pool for all parallel work; created lazily by {@link #getThreadPool()}. */
    private transient ThreadPoolExecutor threadPool;

    /** Cached observation-assignment tasks, one per contiguous chunk of {@code data}. */
    protected transient ArrayList<ObservationAssignmentTask> assignmentTasks;

    /** Cached cluster-update tasks, one per entry of the cluster-assignment list. */
    protected transient ArrayList<ClusterUpdaterTask> clusterUpdaterTasks;

    /**
     * The {@code data} collection that {@link #assignmentTasks} was built from.
     * The tasks are rebuilt whenever {@code data} refers to a different collection.
     */
    private transient Object assignmentTasksSource;

    /**
     * Returns the number of threads this algorithm uses, as computed by
     * {@code ParallelUtil.getNumThreads}.
     *
     * @return the number of threads
     */
    public int getNumThreads() {
        return ParallelUtil.getNumThreads(this);
    }

    /**
     * Returns the thread pool. If none has been set, a default pool is created
     * with {@code ParallelUtil.createThreadPool()} and stored first.
     *
     * @return the thread pool used for parallel work
     */
    public ThreadPoolExecutor getThreadPool() {
        if (this.threadPool == null) {
            this.setThreadPool(ParallelUtil.createThreadPool());
        }
        return this.threadPool;
    }

    /**
     * Sets the thread pool used for parallel work.
     *
     * @param threadPool the new thread pool
     */
    public void setThreadPool(ThreadPoolExecutor threadPool) {
        this.threadPool = threadPool;
    }

    /**
     * Assigns every observation to a cluster by running the cached
     * {@link ObservationAssignmentTask}s on the thread pool.
     * <p>
     * The task list is (re)built when it does not exist yet or when
     * {@code data} now refers to a different collection than the one the tasks
     * were built from.
     *
     * @param K number of existing clusters; the result holds {@code K + 1} collections
     * @param logConditional accumulator; each task's log-conditional value is added
     *     to its {@code logConditional} field
     * @return {@code K + 1} collections, where index {@code k} holds the observations
     *     assigned to cluster {@code k}
     * @throws RuntimeException wrapping any failure from the parallel execution
     */
    @Override
    protected ArrayList<Collection<ObservationType>> assignObservationsToClusters(
        final int K,
        final DirichletProcessMixtureModel.DPMMLogConditional logConditional) {
        if (this.assignmentTasks == null || this.assignmentTasksSource != this.data) {
            // The tasks hold sublists of a snapshot of data, so they are stale
            // once data has been replaced.
            this.assignmentTasks = this.createAssignmentTasks();
            this.assignmentTasksSource = this.data;
        }

        ArrayList<DPMMAssignments> results;
        try {
            results = ParallelUtil.executeInParallel(this.assignmentTasks, this.getThreadPool());
        } catch (Exception ex) {
            if (ex instanceof InterruptedException) {
                // Preserve the interrupt for callers further up the stack.
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(ex);
        }

        final ArrayList<Collection<ObservationType>> clusterAssignments =
            new ArrayList<Collection<ObservationType>>(K + 1);
        for (int k = 0; k <= K; ++k) {
            clusterAssignments.add(new ArrayList<ObservationType>());
        }

        // Assumes executeInParallel returns results in the same order as the
        // tasks, so results.get(n) belongs to assignmentTasks.get(n).
        for (int n = 0; n < results.size(); ++n) {
            final DPMMAssignments result = results.get(n);
            logConditional.logConditional += result.logConditional.logConditional;
            final ArrayList<Integer> assignments = result.assignments;
            int index = 0;
            for (final ObservationType observation : this.assignmentTasks.get(n).observations) {
                final int assignment = assignments.get(index);
                clusterAssignments.get(assignment).add(observation);
                ++index;
            }
        }
        return clusterAssignments;
    }

    /**
     * Splits {@code data} into contiguous chunks of nearly equal size and wraps
     * each chunk in an {@link ObservationAssignmentTask}.
     * <p>
     * At most {@link #getNumThreads()} tasks are created, never more tasks than
     * observations, and always at least one task.
     *
     * @return the new task list
     */
    private ArrayList<ObservationAssignmentTask> createAssignmentTasks() {
        final ArrayList<? extends ObservationType> dataArray = CollectionUtil.asArrayList(this.data);
        final int N = dataArray.size();
        // Clamp the task count so that empty data, or a non-positive thread
        // count, still yields one task and never divides by zero.
        final int numTasks = Math.max(1, Math.min(this.getNumThreads(), N));
        final int baseSize = N / numTasks;
        final int remainder = N % numTasks;
        final ArrayList<ObservationAssignmentTask> tasks =
            new ArrayList<ObservationAssignmentTask>(numTasks);
        int endIndex = 0;
        for (int t = 0; t < numTasks; ++t) {
            final int startIndex = endIndex;
            // The first `remainder` chunks take one extra observation each.
            endIndex += baseSize + (t < remainder ? 1 : 0);
            tasks.add(new ObservationAssignmentTask(dataArray.subList(startIndex, endIndex)));
        }
        return tasks;
    }

    /**
     * Rebuilds the clusters in parallel, with one {@link ClusterUpdaterTask}
     * per entry of {@code clusterAssignments}.
     * <p>
     * An entry with at most one observation is handed to its task as
     * {@code null}, and any {@code null} cluster produced by a task is dropped
     * from the result.
     *
     * @param clusterAssignments observations assigned to each cluster slot
     * @return the non-null clusters, in slot order
     * @throws RuntimeException wrapping any failure from the parallel execution
     */
    @Override
    protected ArrayList<DirichletProcessMixtureModel.DPMMCluster<ObservationType>> updateClusters(
        final ArrayList<Collection<ObservationType>> clusterAssignments) {
        final int Kp1 = clusterAssignments.size();
        if (this.clusterUpdaterTasks == null || this.clusterUpdaterTasks.size() != Kp1) {
            this.clusterUpdaterTasks = new ArrayList<ClusterUpdaterTask>(Kp1);
            for (int k = 0; k < Kp1; ++k) {
                this.clusterUpdaterTasks.add(new ClusterUpdaterTask());
            }
        }
        for (int k = 0; k < Kp1; ++k) {
            final Collection<ObservationType> observations = clusterAssignments.get(k);
            this.clusterUpdaterTasks.get(k).observations =
                (observations.size() <= 1) ? null : observations;
        }

        ArrayList<DirichletProcessMixtureModel.DPMMCluster<ObservationType>> clusters;
        try {
            clusters = ParallelUtil.executeInParallel(this.clusterUpdaterTasks, this.getThreadPool());
        } catch (Exception ex) {
            if (ex instanceof InterruptedException) {
                // Preserve the interrupt for callers further up the stack.
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(ex);
        }

        final ArrayList<DirichletProcessMixtureModel.DPMMCluster<ObservationType>> results =
            new ArrayList<DirichletProcessMixtureModel.DPMMCluster<ObservationType>>(Kp1);
        for (final DirichletProcessMixtureModel.DPMMCluster<ObservationType> cluster : clusters) {
            if (cluster != null) {
                results.add(cluster);
            }
        }
        return results;
    }

    /**
     * Builds one cluster from a set of observations, using a task-private
     * clone of the enclosing model's updater.
     */
    protected class ClusterUpdaterTask
        extends AbstractCloneableSerializable
        implements Callable<DirichletProcessMixtureModel.DPMMCluster<ObservationType>> {

        /** Observations for this cluster slot; {@code null} when the slot has at most one. */
        Collection<ObservationType> observations;

        /** Task-private clone of the enclosing model's updater. */
        DirichletProcessMixtureModel.Updater<ObservationType> localUpdater;

        /** Creates a task that holds its own clone of the enclosing model's updater. */
        @SuppressWarnings("unchecked")
        public ClusterUpdaterTask() {
            // Each task clones the updater instead of sharing the model's instance.
            this.localUpdater = (DirichletProcessMixtureModel.Updater<ObservationType>) ObjectUtil.cloneSafe(
                (CloneableSerializable) ParallelDirichletProcessMixtureModel.this.updater);
        }

        /**
         * Delegates to the enclosing model's {@code createCluster}.
         * <p>
         * NOTE(review): presumably {@code createCluster} returns {@code null} for
         * {@code null} observations; {@code updateClusters} drops null results.
         *
         * @return the cluster built from {@link #observations}
         */
        @Override
        public DirichletProcessMixtureModel.DPMMCluster<ObservationType> call() {
            return ParallelDirichletProcessMixtureModel.this.createCluster(this.observations, this.localUpdater);
        }
    }

    /**
     * Assigns each observation in one contiguous chunk of the data to a cluster.
     * The {@code weights} and {@code assignments} buffers are reused across calls.
     */
    protected class ObservationAssignmentTask
        extends AbstractCloneableSerializable
        implements Callable<DPMMAssignments> {

        /** The chunk of observations this task is responsible for. */
        private Collection<? extends ObservationType> observations;

        /** Scratch buffer of length {@code K + 1}, passed to {@code assignObservationToCluster}. */
        private double[] weights = null;

        /** Reusable buffer: the cluster index chosen for each observation, in iteration order. */
        private ArrayList<Integer> assignments;

        /** Log-conditional accumulated during the most recent {@link #call()}. */
        private DirichletProcessMixtureModel.DPMMLogConditional logConditional;

        /**
         * Creates a task for the given chunk of observations.
         *
         * @param observations the observations this task assigns
         */
        public ObservationAssignmentTask(Collection<? extends ObservationType> observations) {
            this.observations = observations;
        }

        /**
         * Assigns every observation in this chunk to a cluster, based on the
         * enclosing model's current sample.
         *
         * @return the assignments (backed by this task's reusable buffer) together
         *     with a fresh log-conditional accumulator for this call
         * @throws Exception if the enclosing model's assignment step throws
         */
        @Override
        public DPMMAssignments call() throws Exception {
            final int K = ParallelDirichletProcessMixtureModel.this.currentParameter.getNumClusters();
            if (this.weights == null || this.weights.length != K + 1) {
                this.weights = new double[K + 1];
            }
            final int size = this.observations.size();
            if (this.assignments == null || this.assignments.size() != size) {
                this.assignments = new ArrayList<Integer>(size);
                for (int n = 0; n < size; ++n) {
                    this.assignments.add(null);
                }
            }
            this.logConditional = new DirichletProcessMixtureModel.DPMMLogConditional();
            int index = 0;
            for (final ObservationType observation : this.observations) {
                final int clusterAssignment = ParallelDirichletProcessMixtureModel.this.assignObservationToCluster(
                    observation, this.weights, this.logConditional);
                this.assignments.set(index, clusterAssignment);
                ++index;
            }
            return new DPMMAssignments(this.assignments, this.logConditional);
        }
    }

    /**
     * Result of one {@link ObservationAssignmentTask}: the cluster index for
     * each observation in the task's chunk, plus the task's log-conditional.
     */
    public static class DPMMAssignments {
        /** Cluster index per observation, in the task's iteration order. */
        protected ArrayList<Integer> assignments;

        /** Log-conditional accumulated while computing {@link #assignments}. */
        protected DirichletProcessMixtureModel.DPMMLogConditional logConditional;

        /**
         * Creates a result holder.
         *
         * @param assignments cluster index per observation
         * @param logConditional the accompanying log-conditional
         */
        public DPMMAssignments(ArrayList<Integer> assignments, DirichletProcessMixtureModel.DPMMLogConditional logConditional) {
            this.assignments = assignments;
            this.logConditional = logConditional;
        }
    }
}

