/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.function.cost;

import gov.sandia.cognition.algorithm.ParallelAlgorithm;
import gov.sandia.cognition.algorithm.ParallelUtil;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.learning.function.cost.NegativeLogLikelihood;
import gov.sandia.cognition.math.UnivariateStatisticsUtil;
import gov.sandia.cognition.statistics.ComputableDistribution;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.Callable;
import java.util.concurrent.ThreadPoolExecutor;

/**
 * A {@link NegativeLogLikelihood} whose evaluation is split across the
 * threads of a {@link ThreadPoolExecutor}.
 * <p>
 * The cost parameters are copied into a list and divided into one contiguous
 * partition per thread. Each partition is wrapped in a
 * {@link NegativeLogLikelihoodTask}. The tasks are cached in {@link #tasks}
 * and rebuilt whenever the thread count, the cost-parameter collection
 * instance, or its size changes.
 * <p>
 * NOTE: a collection that is mutated in place without changing its size is
 * not detected. Set {@link #tasks} to {@code null} to force a rebuild in that
 * case.
 * <p>
 * Each call to {@link #evaluate} writes the target's probability function
 * into the shared cached tasks. Concurrent calls on the same instance are
 * therefore not safe.
 *
 * @param <DataType> type of the data points whose likelihood is evaluated
 */
public class ParallelNegativeLogLikelihood<DataType>
extends NegativeLogLikelihood<DataType>
implements ParallelAlgorithm {
    /** Thread pool used for evaluation; created lazily by {@link #getThreadPool()}. */
    protected transient ThreadPoolExecutor threadPool;

    /** Cached per-thread tasks, each owning one partition of the cost parameters. */
    protected transient ArrayList<NegativeLogLikelihoodTask<DataType>> tasks;

    /** The cost-parameter collection that {@link #tasks} was partitioned from. */
    private transient Collection<? extends DataType> tasksSource;

    /** Size of {@link #tasksSource} at the time {@link #tasks} was built. */
    private transient int tasksSourceSize;

    /**
     * Creates a new instance with {@code null} cost parameters.
     */
    public ParallelNegativeLogLikelihood() {
        this(null);
    }

    /**
     * Creates a new instance over the given data.
     *
     * @param costParameters data points whose negative log-likelihood is evaluated
     */
    public ParallelNegativeLogLikelihood(Collection<? extends DataType> costParameters) {
        super(costParameters);
    }

    /**
     * Evaluates the negative log-likelihood of the cost parameters under the
     * target's probability function, in parallel.
     * <p>
     * The per-partition totals are summed and divided by the number of data
     * points. An empty cost-parameter collection yields {@code 0.0 / 0},
     * that is NaN.
     *
     * @param target distribution whose probability function is evaluated
     * @return the summed per-partition results divided by the number of data points
     * @throws RuntimeException wrapping any failure of the parallel execution
     */
    @Override
    public Double evaluate(ComputableDistribution<DataType> target) {
        final ProbabilityFunction<DataType> probabilityFunction = target.getProbabilityFunction();
        final Collection<? extends DataType> data = this.costParameters;
        final int N = data.size();
        // Clamp to at least one task so every data point is always assigned
        // and the partition arithmetic never divides by zero.
        final int numThreads = Math.max(1, this.getNumThreads());

        // Rebuild the cached partitions when they no longer describe the
        // current data; otherwise stale sublists would be evaluated while the
        // result is divided by the new N.
        if (this.tasks == null
            || this.tasks.size() != numThreads
            || this.tasksSource != data
            || this.tasksSourceSize != N) {
            this.tasks = partition(data, numThreads);
            this.tasksSource = data;
            this.tasksSourceSize = N;
        }

        for (NegativeLogLikelihoodTask<DataType> task : this.tasks) {
            task.probabilityFunction = probabilityFunction;
        }

        ArrayList<Double> results;
        try {
            results = ParallelUtil.executeInParallel(this.tasks, this.getThreadPool());
        }
        catch (Exception ex) {
            // Preserve the caller's interrupt status before converting to an
            // unchecked exception.
            if (ex instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(ex);
        }
        return UnivariateStatisticsUtil.computeSum(results) / (double) N;
    }

    /**
     * Splits the data into {@code numThreads} contiguous partitions whose
     * sizes differ by at most one, wrapping each partition in a task.
     *
     * @param data data to partition (copied, so the tasks own a snapshot)
     * @param numThreads number of partitions to create; must be positive
     * @return one task per partition, in order
     */
    private static <T> ArrayList<NegativeLogLikelihoodTask<T>> partition(
        final Collection<? extends T> data,
        final int numThreads) {
        final ArrayList<T> dataArray = new ArrayList<T>(data);
        final int n = dataArray.size();
        final ArrayList<NegativeLogLikelihoodTask<T>> result =
            new ArrayList<NegativeLogLikelihoodTask<T>>(numThreads);
        for (int i = 0; i < numThreads; ++i) {
            // long arithmetic avoids int overflow of i * n on large inputs
            final int beginIndex = (int) ((long) i * n / numThreads);
            final int endIndex = (int) ((long) (i + 1) * n / numThreads);
            result.add(new NegativeLogLikelihoodTask<T>(dataArray.subList(beginIndex, endIndex)));
        }
        return result;
    }

    /**
     * Returns the thread pool, creating one via
     * {@link ParallelUtil#createThreadPool()} on first use.
     *
     * @return the thread pool used for evaluation
     */
    public ThreadPoolExecutor getThreadPool() {
        if (this.threadPool == null) {
            this.threadPool = ParallelUtil.createThreadPool();
        }
        return this.threadPool;
    }

    /**
     * Sets the thread pool used for evaluation.
     *
     * @param threadPool the pool to use; {@code null} causes a new pool to be
     *        created lazily by {@link #getThreadPool()}
     */
    public void setThreadPool(ThreadPoolExecutor threadPool) {
        this.threadPool = threadPool;
    }

    /**
     * Returns the number of threads as reported by
     * {@link ParallelUtil#getNumThreads(ParallelAlgorithm)}.
     *
     * @return the thread count
     */
    public int getNumThreads() {
        return ParallelUtil.getNumThreads(this);
    }

    /**
     * Computes one partition's contribution to the negative log-likelihood.
     *
     * @param <DataType> type of the data points
     */
    protected static class NegativeLogLikelihoodTask<DataType>
    implements Callable<Double> {
        /** The partition of data this task evaluates. */
        private final Collection<? extends DataType> data;

        /** Probability function to evaluate; assigned before each run. */
        protected ProbabilityFunction<DataType> probabilityFunction;

        /**
         * Creates a task over the given partition.
         *
         * @param data the partition of data points to evaluate
         */
        public NegativeLogLikelihoodTask(Collection<? extends DataType> data) {
            this.data = data;
        }

        /**
         * Returns the partition size multiplied by
         * {@code NegativeLogLikelihood.evaluate} over the partition. This is
         * presumably the partition's total negative log-likelihood, if that
         * helper returns a per-datum average (TODO confirm).
         *
         * @return this partition's contribution; {@code 0.0} for an empty partition
         */
        @Override
        public Double call() throws Exception {
            // An empty partition contributes nothing. Short-circuit so a
            // per-datum average over zero points cannot turn into 0 * NaN
            // and poison the overall sum.
            if (this.data.isEmpty()) {
                return 0.0;
            }
            return (double) this.data.size() * NegativeLogLikelihood.evaluate(this.probabilityFunction, this.data);
        }
    }
}

