/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.ensemble;

import gov.sandia.cognition.algorithm.IterativeAlgorithm;
import gov.sandia.cognition.algorithm.IterativeAlgorithmListener;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.CollectionUtil;
import gov.sandia.cognition.collection.FiniteCapacityBuffer;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.factory.Factory;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.algorithm.BatchLearner;
import gov.sandia.cognition.learning.algorithm.BatchLearnerContainer;
import gov.sandia.cognition.learning.algorithm.ensemble.WeightedVotingCategorizerEnsemble;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.UnivariateStatisticsUtil;
import gov.sandia.cognition.statistics.DataHistogram;
import gov.sandia.cognition.statistics.distribution.MapBasedDataHistogram;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import gov.sandia.cognition.util.Randomized;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * Learns a {@link WeightedVotingCategorizerEnsemble} using the IVoting
 * procedure from Breiman's "Pasting small votes for classification in large
 * databases and on-line".
 * <p>
 * Each iteration:
 * <ol>
 * <li>Draws a bag of examples, with replacement, from the training data. About
 * {@link #getProportionIncorrectInSample()} of the bag comes from examples the
 * current ensemble gets wrong; the rest comes from examples it gets right.</li>
 * <li>Trains the wrapped {@link #getLearner() learner} on that bag.</li>
 * <li>Adds the resulting member to the ensemble with weight 1.0.</li>
 * </ol>
 * Each example's ensemble vote is tracked in a per-example histogram created by
 * {@link #getCounterFactory()}. A second histogram counts only votes from
 * members that did not see the example in their bag (out-of-bag votes).
 *
 * @param <InputType> the type of input the learned categorizer takes
 * @param <CategoryType> the type of category the learned categorizer outputs
 */
@PublicationReference(author={"Leo Breiman"}, title="Pasting small votes for classification in large databases and on-line", year=1999, type=PublicationType.Journal, publication="Machine Learning", pages={85, 103}, url="http://www.springerlink.com/content/mnu2r28218651707/fulltext.pdf")
public class IVotingCategorizerLearner<InputType, CategoryType>
extends AbstractAnytimeSupervisedBatchLearner<InputType, CategoryType, WeightedVotingCategorizerEnsemble<InputType, CategoryType, Evaluator<? super InputType, ? extends CategoryType>>>
implements Randomized,
BatchLearnerContainer<BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>>> {
    /** The default maximum number of iterations, {@value}. */
    public static final int DEFAULT_MAX_ITERATIONS = 100;

    /** The default fraction of the data size used as the bag size, {@value}. */
    public static final double DEFAULT_PERCENT_TO_SAMPLE = 0.1;

    /** The default fraction of each bag drawn from incorrect examples, {@value}. */
    public static final double DEFAULT_PROPORTION_INCORRECT_IN_SAMPLE = 0.5;

    /** By default, correctness is judged using out-of-bag votes only: {@value}. */
    public static final boolean DEFAULT_VOTE_OUT_OF_BAG_ONLY = true;

    /** The learner trained on each bag to create a new ensemble member. */
    protected BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner;

    /** Bag size as a fraction of the training data size; must be positive. */
    protected double percentToSample;

    /** Fraction of each bag drawn from incorrectly-categorized examples, in [0, 1]. */
    protected double proportionIncorrectInSample;

    /** When true, correctness uses out-of-bag votes whenever any exist. */
    protected boolean voteOutOfBagOnly;

    /** Creates the per-example vote histograms. */
    protected Factory<? extends DataHistogram<CategoryType>> counterFactory;

    /** Source of randomness for sampling bags. */
    protected Random random;

    /** The ensemble being built. */
    protected transient WeightedVotingCategorizerEnsemble<InputType, CategoryType, Evaluator<? super InputType, ? extends CategoryType>> ensemble;

    /** Random-access copy of the training data. */
    protected transient ArrayList<? extends InputOutputPair<? extends InputType, CategoryType>> dataList;

    /** Per-example histogram of votes from all members; index-aligned with dataList. */
    protected transient ArrayList<DataHistogram<CategoryType>> dataFullEstimates;

    /** Per-example histogram of out-of-bag votes; index-aligned with dataList. */
    protected transient ArrayList<DataHistogram<CategoryType>> dataOutOfBagEstimates;

    /** Whether the current ensemble is considered correct on each example. */
    protected transient boolean[] currentEnsembleCorrect;

    /** Indices of examples the current ensemble gets right. */
    protected transient ArrayList<Integer> currentCorrectIndices;

    /** Indices of examples the current ensemble gets wrong. */
    protected transient ArrayList<Integer> currentIncorrectIndices;

    /** Number of examples drawn into each bag. */
    protected transient int sampleSize;

    /** Number of bag draws taken from the correct pool. */
    protected transient int numCorrectToSample;

    /** Number of bag draws taken from the incorrect pool. */
    protected transient int numIncorrectToSample;

    /** The bag used to train the most recent member. */
    protected transient ArrayList<InputOutputPair<? extends InputType, CategoryType>> currentBag;

    /** How many times each example was drawn into the current bag. */
    protected transient int[] dataInBag;

    /** The most recently learned member. */
    protected transient Evaluator<? super InputType, ? extends CategoryType> currentMember;

    /** The most recent member's output on each example (may contain nulls). */
    protected transient ArrayList<CategoryType> currentMemberEstimates;

    /**
     * Creates a learner with no wrapped learner, the default iteration count
     * and sampling percentage, and a new {@link Random}.
     */
    public IVotingCategorizerLearner() {
        this(null, DEFAULT_MAX_ITERATIONS, DEFAULT_PERCENT_TO_SAMPLE, new Random());
    }

    /**
     * Creates a learner using the default incorrect-proportion, out-of-bag
     * setting, and a {@link MapBasedDataHistogram} counter factory.
     *
     * @param learner the learner trained on each bag
     * @param maxIterations the maximum number of iterations
     * @param percentToSample bag size as a fraction of the data size; must be positive
     * @param random the source of randomness
     * @throws IllegalArgumentException if percentToSample is not positive
     */
    public IVotingCategorizerLearner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner, int maxIterations, double percentToSample, Random random) {
        this(learner, maxIterations, percentToSample, DEFAULT_PROPORTION_INCORRECT_IN_SAMPLE, DEFAULT_VOTE_OUT_OF_BAG_ONLY, new MapBasedDataHistogram.DefaultFactory<CategoryType>(2), random);
    }

    /**
     * Creates a fully-configured learner.
     *
     * @param learner the learner trained on each bag
     * @param maxIterations the maximum number of iterations
     * @param percentToSample bag size as a fraction of the data size; must be positive
     * @param proportionIncorrectInSample fraction of each bag drawn from
     *        incorrectly-categorized examples; must be in [0, 1]
     * @param voteOutOfBagOnly whether correctness is judged using out-of-bag votes
     * @param counterFactory creates the per-example vote histograms
     * @param random the source of randomness
     * @throws IllegalArgumentException if percentToSample or
     *         proportionIncorrectInSample is out of range
     */
    public IVotingCategorizerLearner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner, int maxIterations, double percentToSample, double proportionIncorrectInSample, boolean voteOutOfBagOnly, Factory<? extends DataHistogram<CategoryType>> counterFactory, Random random) {
        super(maxIterations);
        this.setLearner(learner);
        this.setPercentToSample(percentToSample);
        this.setProportionIncorrectInSample(proportionIncorrectInSample);
        this.setVoteOutOfBagOnly(voteOutOfBagOnly);
        this.setCounterFactory(counterFactory);
        this.setRandom(random);
    }

    /**
     * Allocates the per-run state. Every example starts out in the incorrect
     * pool, because the empty ensemble is treated as wrong everywhere.
     *
     * @return false if there is no data to learn from, true otherwise
     */
    @Override
    protected boolean initializeAlgorithm() {
        if (this.data == null || this.data.isEmpty()) {
            return false;
        }
        final int dataSize = this.data.size();

        if (this.random == null) {
            this.random = new Random();
        }

        this.ensemble = new WeightedVotingCategorizerEnsemble<InputType, CategoryType, Evaluator<? super InputType, ? extends CategoryType>>(
            DatasetUtil.findUniqueOutputs(this.data));
        this.dataList = CollectionUtil.asArrayList(this.data);
        this.dataFullEstimates = new ArrayList<DataHistogram<CategoryType>>(dataSize);
        this.dataOutOfBagEstimates = new ArrayList<DataHistogram<CategoryType>>(dataSize);
        this.currentEnsembleCorrect = new boolean[dataSize];
        this.currentCorrectIndices = new ArrayList<Integer>(dataSize);
        this.currentIncorrectIndices = new ArrayList<Integer>(dataSize);
        for (int i = 0; i < dataSize; i++) {
            // Exactly one full and one out-of-bag histogram per example, both
            // from the configured factory, so that index i in each list
            // refers to the same example as dataList.get(i).
            this.dataFullEstimates.add(this.counterFactory.create());
            this.dataOutOfBagEstimates.add(this.counterFactory.create());
            this.currentIncorrectIndices.add(i);
        }

        this.sampleSize = Math.max(1, (int) (this.percentToSample * dataSize));
        this.numIncorrectToSample = (int) (this.proportionIncorrectInSample * this.sampleSize);
        this.numCorrectToSample = this.sampleSize - this.numIncorrectToSample;
        this.currentBag = new ArrayList<InputOutputPair<? extends InputType, CategoryType>>(
            this.numCorrectToSample + this.numIncorrectToSample);
        this.dataInBag = new int[dataSize];
        this.currentMember = null;
        this.currentMemberEstimates = new ArrayList<CategoryType>(
            Collections.<CategoryType>nCopies(dataSize, null));
        return true;
    }

    /**
     * Performs one IVoting iteration:
     * <ol>
     * <li>Samples a bag.</li>
     * <li>Trains and adds a new member.</li>
     * <li>Updates every example's vote histograms.</li>
     * <li>Rebuilds the correct and incorrect pools.</li>
     * </ol>
     *
     * @return always true; stopping is governed by the iteration limit or a listener
     */
    @Override
    protected boolean step() {
        final int dataSize = this.dataList.size();

        // Forget the previous iteration's bag.
        this.currentBag.clear();
        for (int i = 0; i < dataSize; i++) {
            this.dataInBag[i] = 0;
        }

        // If one pool is empty, draw from the other for both parts of the bag.
        // Since dataSize > 0, at least one pool is non-empty.
        ArrayList<Integer> correctIndices = this.currentCorrectIndices;
        ArrayList<Integer> incorrectIndices = this.currentIncorrectIndices;
        if (incorrectIndices.isEmpty()) {
            incorrectIndices = correctIndices;
        } else if (correctIndices.isEmpty()) {
            correctIndices = incorrectIndices;
        }
        this.createBag(correctIndices, incorrectIndices);

        this.currentMember = this.learner.learn(this.currentBag);
        this.ensemble.add(this.currentMember, 1.0);

        // The pools are rebuilt below, after sampling no longer needs them.
        this.currentCorrectIndices.clear();
        this.currentIncorrectIndices.clear();
        for (int i = 0; i < dataSize; i++) {
            final InputOutputPair<? extends InputType, CategoryType> example = this.dataList.get(i);
            final CategoryType actual = example.getOutput();
            final CategoryType memberGuess = this.currentMember.evaluate(example.getInput());
            this.currentMemberEstimates.set(i, memberGuess);

            final DataHistogram<CategoryType> fullEstimate = this.dataFullEstimates.get(i);
            final DataHistogram<CategoryType> outOfBagEstimate = this.dataOutOfBagEstimates.get(i);
            if (memberGuess != null) {
                fullEstimate.add(memberGuess);
                // Only members that never saw this example contribute out-of-bag votes.
                if (this.dataInBag[i] <= 0) {
                    outOfBagEstimate.add(memberGuess);
                }
            }

            final CategoryType ensembleGuess;
            if (this.voteOutOfBagOnly && outOfBagEstimate.getTotalCount() > 0) {
                ensembleGuess = outOfBagEstimate.getMaximumValue();
            } else {
                ensembleGuess = fullEstimate.getMaximumValue();
            }

            // An example with no ensemble guess at all is counted as correct.
            final boolean ensembleCorrect = ensembleGuess == null
                || ObjectUtil.equalsSafe(actual, ensembleGuess);
            this.currentEnsembleCorrect[i] = ensembleCorrect;
            if (ensembleCorrect) {
                this.currentCorrectIndices.add(i);
            } else {
                this.currentIncorrectIndices.add(i);
            }
        }
        return true;
    }

    /**
     * Fills {@link #currentBag}. It first draws numCorrectToSample examples
     * from the correct pool, then numIncorrectToSample examples from the
     * incorrect pool.
     *
     * @param correctIndices the pool used for the "correct" draws
     * @param incorrectIndices the pool used for the "incorrect" draws
     */
    protected void createBag(ArrayList<Integer> correctIndices, ArrayList<Integer> incorrectIndices) {
        sampleIndicesWithReplacementInto(correctIndices, this.dataList, this.numCorrectToSample, this.random, this.currentBag, this.dataInBag);
        sampleIndicesWithReplacementInto(incorrectIndices, this.dataList, this.numIncorrectToSample, this.random, this.currentBag, this.dataInBag);
    }

    /**
     * Draws numToSample indices uniformly, with replacement, from fromIndices.
     * Each corresponding element of baseData is appended to output, and that
     * index's count in dataInBag is incremented.
     *
     * @param fromIndices the indices to choose from
     * @param baseData the data the indices refer to
     * @param numToSample the number of draws
     * @param random the source of randomness
     * @param output receives the sampled elements
     * @param dataInBag per-index draw counts, incremented in place
     * @throws IllegalArgumentException (from {@link Random#nextInt(int)}) if
     *         fromIndices is empty and numToSample is positive
     */
    protected static <DataType> void sampleIndicesWithReplacementInto(ArrayList<Integer> fromIndices, ArrayList<? extends DataType> baseData, int numToSample, Random random, ArrayList<DataType> output, int[] dataInBag) {
        final int fromSize = fromIndices.size();
        for (int i = 0; i < numToSample; i++) {
            final int index = fromIndices.get(random.nextInt(fromSize));
            output.add(baseData.get(index));
            dataInBag[index]++;
        }
    }

    /**
     * Releases the per-run state. The learned ensemble is kept, so
     * {@link #getResult()} still returns it.
     */
    @Override
    protected void cleanupAlgorithm() {
        this.dataList = null;
        this.dataFullEstimates = null;
        this.dataOutOfBagEstimates = null;
        this.dataInBag = null;
        this.currentMember = null;
        this.currentCorrectIndices = null;
        this.currentIncorrectIndices = null;
        this.currentBag = null;
        this.currentEnsembleCorrect = null;
        this.currentMemberEstimates = null;
    }

    /**
     * Gets the ensemble learned so far.
     *
     * @return the ensemble, or null before learning starts
     */
    @Override
    public WeightedVotingCategorizerEnsemble<InputType, CategoryType, Evaluator<? super InputType, ? extends CategoryType>> getResult() {
        return this.ensemble;
    }

    /**
     * Gets the learner trained on each bag.
     *
     * @return the wrapped learner
     */
    @Override
    public BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> getLearner() {
        return this.learner;
    }

    /**
     * Sets the learner trained on each bag.
     *
     * @param learner the wrapped learner
     */
    public void setLearner(BatchLearner<? super Collection<? extends InputOutputPair<? extends InputType, CategoryType>>, ? extends Evaluator<? super InputType, ? extends CategoryType>> learner) {
        this.learner = learner;
    }

    /**
     * Gets the bag size as a fraction of the data size.
     *
     * @return the sampling percentage
     */
    public double getPercentToSample() {
        return this.percentToSample;
    }

    /**
     * Sets the bag size as a fraction of the data size.
     *
     * @param percentToSample the sampling percentage; must be positive
     * @throws IllegalArgumentException if percentToSample is not positive
     */
    public void setPercentToSample(double percentToSample) {
        if (percentToSample <= 0.0) {
            throw new IllegalArgumentException("percentToSample must be greater than zero.");
        }
        this.percentToSample = percentToSample;
    }

    /**
     * Gets the fraction of each bag drawn from incorrectly-categorized examples.
     *
     * @return the proportion, in [0, 1]
     */
    public double getProportionIncorrectInSample() {
        return this.proportionIncorrectInSample;
    }

    /**
     * Sets the fraction of each bag drawn from incorrectly-categorized examples.
     *
     * @param proportionIncorrectInSample the proportion, in [0, 1]
     * @throws IllegalArgumentException if the value is outside [0, 1]
     */
    public void setProportionIncorrectInSample(double proportionIncorrectInSample) {
        if (proportionIncorrectInSample < 0.0 || proportionIncorrectInSample > 1.0) {
            throw new IllegalArgumentException("proportionIncorrectInSample must be between 0.0 and 1.0 (inclusive).");
        }
        this.proportionIncorrectInSample = proportionIncorrectInSample;
    }

    /**
     * Gets whether correctness is judged using out-of-bag votes when available.
     *
     * @return true to use out-of-bag votes only
     */
    public boolean isVoteOutOfBagOnly() {
        return this.voteOutOfBagOnly;
    }

    /**
     * Sets whether correctness is judged using out-of-bag votes when available.
     *
     * @param voteOutOfBagOnly true to use out-of-bag votes only
     */
    public void setVoteOutOfBagOnly(boolean voteOutOfBagOnly) {
        this.voteOutOfBagOnly = voteOutOfBagOnly;
    }

    /**
     * Gets the factory that creates the per-example vote histograms.
     *
     * @return the counter factory
     */
    public Factory<? extends DataHistogram<CategoryType>> getCounterFactory() {
        return this.counterFactory;
    }

    /**
     * Sets the factory that creates the per-example vote histograms.
     *
     * @param counterFactory the counter factory
     */
    public void setCounterFactory(Factory<? extends DataHistogram<CategoryType>> counterFactory) {
        this.counterFactory = counterFactory;
    }

    /**
     * Gets the source of randomness.
     *
     * @return the random number generator
     */
    @Override
    public Random getRandom() {
        return this.random;
    }

    /**
     * Sets the source of randomness. A null value is replaced by a new
     * {@link Random} when learning starts.
     *
     * @param random the random number generator
     */
    @Override
    public void setRandom(Random random) {
        this.random = random;
    }

    /**
     * Gets a read-only view of the per-example histograms of votes from all
     * ensemble members.
     *
     * @return the histograms, or an empty list outside of a learning run
     */
    public List<DataHistogram<CategoryType>> getDataFullEstimates() {
        if (this.dataFullEstimates == null) {
            return Collections.<DataHistogram<CategoryType>>emptyList();
        }
        return Collections.unmodifiableList(this.dataFullEstimates);
    }

    /**
     * Gets a read-only view of the per-example histograms of out-of-bag votes.
     *
     * @return the histograms, or an empty list outside of a learning run
     */
    public List<DataHistogram<CategoryType>> getDataOutOfBagEstimates() {
        if (this.dataOutOfBagEstimates == null) {
            return Collections.<DataHistogram<CategoryType>>emptyList();
        }
        return Collections.unmodifiableList(this.dataOutOfBagEstimates);
    }

    /**
     * Gets whether the current ensemble is considered correct on each example.
     * The internal array is returned directly, not a copy.
     *
     * @return the per-example correctness flags, or null outside of a learning run
     */
    public boolean[] getCurrentEnsembleCorrect() {
        return this.currentEnsembleCorrect;
    }

    /**
     * An iteration listener that stops an {@link IVotingCategorizerLearner}
     * once its smoothed out-of-bag error rate stops decreasing.
     * <p>
     * After every step it updates the count of examples whose out-of-bag vote
     * is wrong. It then averages the resulting error rate over the last
     * smoothingWindowSize steps. When that average is no lower than on the
     * previous step, it stops the learner. It also trims the ensemble back to
     * the size with the lowest raw error rate within that window.
     *
     * @param <InputType> the learner's input type
     * @param <CategoryType> the learner's category type
     */
    public static class OutOfBagErrorStoppingCriteria<InputType, CategoryType>
    extends AbstractCloneableSerializable
    implements IterativeAlgorithmListener {
        /** The default smoothing window size, {@value}. */
        public static final int DEFAULT_SMOOTHING_WINDOW_SIZE = 25;

        /** Number of recent error rates averaged together; must be positive. */
        protected int smoothingWindowSize;

        /** The learner being monitored. */
        protected transient IVotingCategorizerLearner<InputType, CategoryType> learner;

        /** Whether each example's out-of-bag vote is currently correct. */
        protected transient boolean[] outOfBagCorrect;

        /** Number of examples whose out-of-bag vote is currently wrong. */
        protected transient int outOfBagErrorCount;

        /** Out-of-bag error rate after each step. */
        protected transient ArrayList<Double> rawErrorRates;

        /** Smoothed out-of-bag error rate after each step. */
        protected transient ArrayList<Double> smoothedErrorRates;

        /** The most recent raw error rates, up to smoothingWindowSize of them. */
        protected transient FiniteCapacityBuffer<Double> smoothingBuffer;

        /** The smoothed error rate from the previous step. */
        protected transient double previousSmoothedErrorRate;

        /**
         * Creates a stopping criterion with the default smoothing window size.
         */
        public OutOfBagErrorStoppingCriteria() {
            this(DEFAULT_SMOOTHING_WINDOW_SIZE);
        }

        /**
         * Creates a stopping criterion.
         *
         * @param smoothingWindowSize the smoothing window size; must be positive
         * @throws IllegalArgumentException if smoothingWindowSize is not positive
         */
        public OutOfBagErrorStoppingCriteria(int smoothingWindowSize) {
            this.setSmoothingWindowSize(smoothingWindowSize);
        }

        /**
         * Binds to the learner and resets all tracking state. Every example
         * starts out counted as an out-of-bag error.
         *
         * @param algorithm the learner; must be an IVotingCategorizerLearner
         */
        @Override
        @SuppressWarnings("unchecked")
        public void algorithmStarted(IterativeAlgorithm algorithm) {
            this.learner = (IVotingCategorizerLearner<InputType, CategoryType>) algorithm;
            final int dataSize = this.learner.data.size();
            this.outOfBagCorrect = new boolean[dataSize];
            this.outOfBagErrorCount = dataSize;
            this.rawErrorRates = new ArrayList<Double>();
            this.smoothedErrorRates = new ArrayList<Double>();
            this.smoothingBuffer = new FiniteCapacityBuffer<Double>(this.smoothingWindowSize);
            this.previousSmoothedErrorRate = Double.MAX_VALUE;
        }

        /**
         * Releases the tracking state.
         *
         * @param algorithm the learner
         */
        @Override
        public void algorithmEnded(IterativeAlgorithm algorithm) {
            this.learner = null;
            this.outOfBagCorrect = null;
            this.rawErrorRates = null;
            this.smoothedErrorRates = null;
            this.smoothingBuffer = null;
        }

        /**
         * Does nothing.
         *
         * @param algorithm the learner
         */
        @Override
        public void stepStarted(IterativeAlgorithm algorithm) {
        }

        /**
         * Updates the out-of-bag error rate and, if the smoothed rate did not
         * improve, stops the learner and trims the ensemble.
         *
         * @param algorithm the learner
         */
        @Override
        public void stepEnded(IterativeAlgorithm algorithm) {
            final int dataSize = this.learner.data.size();
            for (int i = 0; i < dataSize; i++) {
                if (this.learner.dataInBag[i] > 0) {
                    // The newest member trained on this example, so its
                    // out-of-bag votes did not change during this step.
                    continue;
                }
                final CategoryType actual = this.learner.dataList.get(i).getOutput();
                final CategoryType ensembleGuess = this.learner.dataOutOfBagEstimates.get(i).getMaximumValue();
                final boolean newEnsembleCorrect = ObjectUtil.equalsSafe(actual, ensembleGuess);
                if (newEnsembleCorrect != this.outOfBagCorrect[i]) {
                    this.outOfBagCorrect[i] = newEnsembleCorrect;
                    if (newEnsembleCorrect) {
                        this.outOfBagErrorCount--;
                    } else {
                        this.outOfBagErrorCount++;
                    }
                }
            }

            final double outOfBagEnsembleErrorRate = (double) this.outOfBagErrorCount / (double) dataSize;
            this.rawErrorRates.add(outOfBagEnsembleErrorRate);
            this.smoothingBuffer.add(outOfBagEnsembleErrorRate);
            final double smoothedErrorRate = UnivariateStatisticsUtil.computeMean(this.smoothingBuffer);
            this.smoothedErrorRates.add(smoothedErrorRate);

            if (smoothedErrorRate >= this.previousSmoothedErrorRate) {
                this.learner.stop();

                // Within the smoothing window, find the step with the lowest
                // raw error rate. Walking backwards with "<=" prefers the
                // earliest such step on ties, i.e. the smaller ensemble.
                final int ensembleSize = this.rawErrorRates.size();
                int bestIndex = 0;
                double bestRawErrorRate = Double.MAX_VALUE;
                for (int offset = 0; offset < this.smoothingBuffer.size(); offset++) {
                    final int index = ensembleSize - offset - 1;
                    final double rawErrorRate = this.rawErrorRates.get(index);
                    if (rawErrorRate <= bestRawErrorRate) {
                        bestIndex = index;
                        bestRawErrorRate = rawErrorRate;
                    }
                }

                // Drop the members added after the best step.
                for (int index = ensembleSize - 1; index > bestIndex; index--) {
                    this.learner.ensemble.members.remove(index);
                }
            }
            this.previousSmoothedErrorRate = smoothedErrorRate;
        }

        /**
         * Gets the number of recent error rates averaged together.
         *
         * @return the smoothing window size
         */
        public int getSmoothingWindowSize() {
            return this.smoothingWindowSize;
        }

        /**
         * Sets the number of recent error rates averaged together.
         *
         * @param smoothingWindowSize the smoothing window size; must be positive
         * @throws IllegalArgumentException if smoothingWindowSize is not positive
         */
        public void setSmoothingWindowSize(int smoothingWindowSize) {
            // A zero-sized window has nothing to average, and it would leave
            // the best-member search in stepEnded with nothing to scan.
            if (smoothingWindowSize <= 0) {
                throw new IllegalArgumentException("smoothingWindowSize must be positive.");
            }
            this.smoothingWindowSize = smoothingWindowSize;
        }
    }
}

