/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.statistics.distribution;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.math.MathUtil;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorInputEvaluator;
import gov.sandia.cognition.statistics.AbstractDistribution;
import gov.sandia.cognition.statistics.ClosedFormComputableDiscreteDistribution;
import gov.sandia.cognition.statistics.ProbabilityMassFunction;
import gov.sandia.cognition.statistics.ProbabilityMassFunctionUtil;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.AbstractCollection;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Random;

@PublicationReference(author={"Wikipedia"}, title="Multinomial distribution", type=PublicationType.WebPage, year=2009, url="http://en.wikipedia.org/wiki/Multinomial_distribution")
public class MultinomialDistribution
extends AbstractDistribution<Vector>
implements ClosedFormComputableDiscreteDistribution<Vector> {
    public static final int DEFAULT_NUM_CLASSES = 2;
    public static final int DEFAULT_NUM_TRIALS = 1;
    private int numTrials;
    private Vector parameters;

    public MultinomialDistribution() {
        this(2, 1);
    }

    public MultinomialDistribution(int numClasses, int numTrials) {
        this(VectorFactory.getDefault().createVector(numClasses, 1.0), numTrials);
    }

    public MultinomialDistribution(Vector parameters, int numTrials) {
        this.setNumTrials(numTrials);
        this.setParameters(parameters);
    }

    public MultinomialDistribution(MultinomialDistribution other) {
        this((Vector)ObjectUtil.cloneSafe((CloneableSerializable)other.getParameters()), other.getNumTrials());
    }

    public MultinomialDistribution clone() {
        MultinomialDistribution clone = (MultinomialDistribution)super.clone();
        clone.setParameters((Vector)ObjectUtil.cloneSafe((CloneableSerializable)this.getParameters()));
        return clone;
    }

    public Vector getParameters() {
        return this.parameters;
    }

    public void setParameters(Vector parameters) {
        int N = parameters.getDimensionality();
        if (N < 2) {
            throw new IllegalArgumentException("Dimensionality must be >= 2");
        }
        for (int i = 0; i < N; ++i) {
            if (!(parameters.getElement(i) < 0.0)) continue;
            throw new IllegalArgumentException("All parameter elements must be >= 0.0");
        }
        this.parameters = parameters;
    }

    public Vector convertToVector() {
        return (Vector)ObjectUtil.cloneSafe((CloneableSerializable)this.getParameters());
    }

    public void convertFromVector(Vector parameters) {
        parameters.assertSameDimensionality(this.getParameters());
        this.setParameters((Vector)ObjectUtil.cloneSafe((CloneableSerializable)parameters));
    }

    public int getNumTrials() {
        return this.numTrials;
    }

    public void setNumTrials(int numTrials) {
        if (numTrials <= 0) {
            throw new IllegalArgumentException("numTrials must be > 0");
        }
        this.numTrials = numTrials;
    }

    @Override
    public Vector getMean() {
        return (Vector)this.parameters.scale((double)this.numTrials / this.parameters.norm1());
    }

    @Override
    public ArrayList<Vector> sample(Random random, int numSamples) {
        int numClasses = this.parameters.getDimensionality();
        double[] probs = new double[numClasses];
        double probsum = this.parameters.norm1();
        for (int j = 0; j < numClasses; ++j) {
            probs[j] = this.parameters.getElement(j) / probsum;
        }
        ArrayList<Vector> samples = new ArrayList<Vector>(numSamples);
        for (int n = 0; n < numSamples; ++n) {
            double[] successes = new double[numClasses];
            block2: for (int i = 0; i < this.numTrials; ++i) {
                double p = random.nextDouble();
                for (int k = 0; k < numClasses; ++k) {
                    if (p <= probs[k]) {
                        int n2 = k;
                        successes[n2] = successes[n2] + 1.0;
                        continue block2;
                    }
                    p -= probs[k];
                }
            }
            samples.add(VectorFactory.getDefault().copyArray(successes));
        }
        return samples;
    }

    public Domain getDomain() {
        return new Domain(this.getParameters().getDimensionality(), this.getNumTrials());
    }

    @Override
    public int getDomainSize() {
        return this.getDomain().size();
    }

    public PMF getProbabilityFunction() {
        return new PMF(this);
    }

    public static class Domain
    extends AbstractCollection<Vector> {
        private int numClasses;
        private int numTrials;

        public Domain(int numClasses, int numTrials) {
            this.numClasses = numClasses;
            this.numTrials = numTrials;
        }

        @Override
        public Iterator<Vector> iterator() {
            return new MultinomialIterator(this.numClasses, this.numTrials);
        }

        @Override
        public int size() {
            return MathUtil.binomialCoefficient((int)(this.numClasses + this.numTrials - 1), (int)(this.numClasses - 1));
        }

        public double logSize() {
            return MathUtil.logBinomialCoefficient((int)(this.numClasses + this.numTrials - 1), (int)(this.numClasses - 1));
        }

        protected static class MultinomialIterator
        extends AbstractCloneableSerializable
        implements Iterator<Vector> {
            private int value;
            private int numClasses;
            private int numTrials;
            private MultinomialIterator child;

            public MultinomialIterator(int numClasses, int numTrials) {
                if (numClasses <= 0) {
                    throw new IllegalArgumentException("NumClasses <= 0");
                }
                this.numClasses = numClasses;
                if (numTrials < 0) {
                    throw new IllegalArgumentException("numTrials < 0");
                }
                this.numTrials = numTrials;
                this.value = this.numClasses <= 1 ? this.numTrials : 0;
                if (this.numClasses > 1) {
                    this.child = new MultinomialIterator(this.numClasses - 1, this.numTrials - this.value);
                }
            }

            @Override
            public boolean hasNext() {
                if (this.value < this.numTrials) {
                    return true;
                }
                if (this.child == null) {
                    return this.value <= this.numTrials;
                }
                return this.child.hasNext();
            }

            @Override
            public Vector next() {
                if (this.child == null) {
                    Vector subset = VectorFactory.getDefault().createVector(1, (double)this.value);
                    ++this.value;
                    return subset;
                }
                if (!this.child.hasNext()) {
                    ++this.value;
                    this.child = new MultinomialIterator(this.numClasses - 1, this.numTrials - this.value);
                }
                return VectorFactory.getDefault().createVector(1, (double)this.value).stack(this.child.next());
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException("Cannot remove from MultinomialDomain");
            }
        }
    }

    public static class PMF
    extends MultinomialDistribution
    implements ProbabilityMassFunction<Vector>,
    VectorInputEvaluator<Vector, Double> {
        public PMF() {
        }

        public PMF(int numClasses, int numTrials) {
            super(numClasses, numTrials);
        }

        public PMF(Vector parameters, int numTrials) {
            super(parameters, numTrials);
        }

        public PMF(MultinomialDistribution other) {
            super(other);
        }

        public int getInputDimensionality() {
            return this.getParameters().getDimensionality();
        }

        public Double evaluate(Vector input) {
            return Math.exp(this.logEvaluate(input));
        }

        @Override
        public double logEvaluate(Vector input) {
            int N = this.getInputDimensionality();
            input.assertDimensionalityEquals(N);
            Vector p = this.getParameters();
            double psum = p.norm1();
            double logCoeff = MathUtil.logFactorial((int)this.getNumTrials());
            double logProb = 0.0;
            int total = 0;
            for (int i = 0; i < N; ++i) {
                int xi = (int)input.getElement(i);
                total += xi;
                double pi = p.getElement(i) / psum;
                if (pi < 0.0) {
                    throw new IllegalArgumentException("pi < 0.0" + p);
                }
                if (pi == 0.0) {
                    if (xi == 0) continue;
                    return Math.log(0.0);
                }
                if (xi == 0) continue;
                logCoeff -= MathUtil.logFactorial((int)xi);
                logProb += (double)xi * Math.log(pi);
            }
            if (total != this.getNumTrials()) {
                throw new IllegalArgumentException("Integer input sum != num trials");
            }
            return logCoeff + logProb;
        }

        @Override
        public double getEntropy() {
            return ProbabilityMassFunctionUtil.getEntropy(this);
        }

        @Override
        public PMF getProbabilityFunction() {
            return this;
        }
    }
}

