/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.hmm;

import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.hmm.BaumWelchAlgorithm;
import gov.sandia.cognition.learning.algorithm.hmm.HiddenMarkovModel;
import gov.sandia.cognition.learning.algorithm.hmm.MarkovChain;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.WeightedValue;
import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * A variant of {@link BaumWelchAlgorithm} that re-estimates the per-state
 * output distributions concurrently. One {@link DistributionEstimatorTask} is
 * created per hidden state, and the tasks are run on a {@link ThreadPoolExecutor}.
 *
 * @param <ObservationType> Type of the observations emitted by the HMM.
 */
@PublicationReference(author={"William Turin"}, title="Unidirectional and Parallel Baum\u2013Welch Algorithms", type=PublicationType.Journal, publication="IEEE Transactions on Speech and Audio Processing", year=1998, url="http://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=00725318")
public class ParallelBaumWelchAlgorithm<ObservationType>
extends BaumWelchAlgorithm<ObservationType>
implements ParallelAlgorithm {
    /** Pool the estimator tasks run on; created lazily by {@link #getThreadPool()}. */
    private transient ThreadPoolExecutor threadPool;

    /** One estimator task per hidden state; rebuilt in {@link #initializeAlgorithm()}. */
    protected transient ArrayList<DistributionEstimatorTask<ObservationType>> distributionEstimatorTasks;

    /**
     * Creates a new instance using the superclass's no-argument constructor.
     */
    public ParallelBaumWelchAlgorithm() {
        super();
    }

    /**
     * Creates a new instance.
     *
     * @param initialGuess Initial HMM from which learning starts.
     * @param distributionLearner Learner used to fit each state's output
     *     distribution to the gamma-weighted observations.
     * @param reestimateInitialProbabilities Passed through to the superclass;
     *     presumably controls whether the initial state probabilities are
     *     re-estimated (see {@link BaumWelchAlgorithm}).
     */
    public ParallelBaumWelchAlgorithm(HiddenMarkovModel<ObservationType> initialGuess, BatchLearner<Collection<? extends WeightedValue<? extends ObservationType>>, ? extends ComputableDistribution<ObservationType>> distributionLearner, boolean reestimateInitialProbabilities) {
        super(initialGuess, distributionLearner, reestimateInitialProbabilities);
    }

    /**
     * Gets the thread pool. If none has been set, one is created with
     * {@link ParallelUtil#createThreadPool()} and stored.
     *
     * @return The thread pool used to run the estimator tasks.
     */
    public ThreadPoolExecutor getThreadPool() {
        if (this.threadPool == null) {
            this.setThreadPool(ParallelUtil.createThreadPool());
        }
        return this.threadPool;
    }

    /**
     * Sets the thread pool used to run the estimator tasks.
     *
     * @param threadPool The thread pool; if {@code null}, a new pool is created
     *     on the next call to {@link #getThreadPool()}.
     */
    public void setThreadPool(ThreadPoolExecutor threadPool) {
        this.threadPool = threadPool;
    }

    /**
     * Gets the number of threads available to this algorithm.
     *
     * @return The thread count, as reported by {@link ParallelUtil}.
     */
    public int getNumThreads() {
        return ParallelUtil.getNumThreads(this);
    }

    /**
     * Builds the per-state estimator tasks from the current data, then defers
     * to the superclass initialization.
     *
     * @return The result of the superclass initialization.
     */
    @Override
    protected boolean initializeAlgorithm() {
        this.distributionEstimatorTasks = this.createDistributionEstimatorTasks();
        return super.initializeAlgorithm();
    }

    /**
     * Re-estimates every state's output distribution in parallel.
     *
     * @param sequenceGammas One vector per observation. Element {@code i} of
     *     each vector is used as that observation's weight when re-fitting
     *     state {@code i}.
     * @return The probability functions produced by the tasks, as returned by
     *     {@link ParallelUtil#executeInParallel} (presumably in task order,
     *     i.e. by state index).
     * @throws RuntimeException Wrapping any failure raised while running the
     *     tasks.
     */
    @Override
    protected ArrayList<ProbabilityFunction<ObservationType>> updateProbabilityFunctions(ArrayList<Vector> sequenceGammas) {
        // Give the gammas to exactly the tasks that are about to be executed.
        // This avoids relying on a separately-queried state count that must
        // match the task list.
        for (DistributionEstimatorTask<ObservationType> task : this.distributionEstimatorTasks) {
            task.setGammas(sequenceGammas);
        }
        try {
            return ParallelUtil.executeInParallel(this.distributionEstimatorTasks, this.getThreadPool());
        }
        catch (Exception e) {
            // Wrapping an InterruptedException must not lose the interrupt.
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates one estimator task per state of the initial guess. Each task
     * holds its own weighted copy of the data.
     *
     * @return A new list of tasks, indexed by state.
     */
    @SuppressWarnings("unchecked")
    protected ArrayList<DistributionEstimatorTask<ObservationType>> createDistributionEstimatorTasks() {
        final int numStates = this.initialGuess.getNumStates();
        final ArrayList<DistributionEstimatorTask<ObservationType>> tasks = new ArrayList<DistributionEstimatorTask<ObservationType>>(numStates);
        for (int stateIndex = 0; stateIndex < numStates; ++stateIndex) {
            // NOTE(review): the decompiled source cast this.data to a raw
            // Collection; the element type is assumed to be ObservationType.
            tasks.add(new DistributionEstimatorTask<ObservationType>((Collection<? extends ObservationType>) this.data, this.distributionLearner, stateIndex));
        }
        return tasks;
    }

    /**
     * Re-estimates a single state's output distribution. It does this by running
     * the distribution learner on the observations, each weighted by that
     * state's gamma value.
     *
     * @param <ObservationType> Type of the observations.
     */
    protected static class DistributionEstimatorTask<ObservationType>
    extends AbstractCloneableSerializable
    implements Callable<ProbabilityFunction<ObservationType>> {
        /** The observations, wrapped so their weights can be overwritten on each call. */
        protected ArrayList<DefaultWeightedValue<ObservationType>> weightedValues;

        /** Learner used to fit this state's distribution to the weighted observations. */
        protected BatchLearner<Collection<? extends WeightedValue<? extends ObservationType>>, ? extends ComputableDistribution<ObservationType>> distributionLearner;

        /** Per-observation gamma vectors; must be set via {@link #setGammas} before {@link #call()}. */
        private ArrayList<Vector> gammas;

        /** Index of the state whose distribution this task estimates. */
        protected int index;

        /**
         * Creates a new task.
         *
         * @param data Observations to weight and learn from; copied into
         *     {@link #weightedValues}.
         * @param distributionLearner Learner for the state's distribution.
         * @param index Index of the state this task estimates.
         */
        public DistributionEstimatorTask(Collection<? extends ObservationType> data, BatchLearner<Collection<? extends WeightedValue<? extends ObservationType>>, ? extends ComputableDistribution<ObservationType>> distributionLearner, int index) {
            this.index = index;
            this.distributionLearner = distributionLearner;
            this.weightedValues = new ArrayList<DefaultWeightedValue<ObservationType>>(data.size());
            for (ObservationType value : data) {
                this.weightedValues.add(new DefaultWeightedValue<ObservationType>(value));
            }
        }

        /**
         * Sets the gamma vectors used to weight the observations on the next
         * call.
         *
         * @param gammas One vector per observation.
         */
        public void setGammas(ArrayList<Vector> gammas) {
            this.gammas = gammas;
        }

        /**
         * Weights each observation by element {@link #index} of its gamma
         * vector, then learns this state's distribution.
         *
         * @return The probability function of the learned distribution.
         * @throws IllegalStateException If {@link #setGammas} was never called.
         */
        @Override
        public ProbabilityFunction<ObservationType> call() {
            final ArrayList<Vector> currentGammas = this.gammas;
            if (currentGammas == null) {
                throw new IllegalStateException("Gammas must be set before the estimator task is called");
            }
            // NOTE(review): assumes currentGammas.size() equals the number of
            // observations. A shorter list leaves stale weights from a previous
            // call on the trailing observations.
            final int numObservations = currentGammas.size();
            for (int n = 0; n < numObservations; ++n) {
                this.weightedValues.get(n).setWeight(currentGammas.get(n).getElement(this.index));
            }
            return this.distributionLearner.learn(this.weightedValues).getProbabilityFunction();
        }
    }
}

