/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.statistics.distribution;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationReferences;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.collection.DefaultMultiCollection;
import gov.sandia.cognition.collection.MultiCollection;
import gov.sandia.cognition.math.MathUtil;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.statistics.ClosedFormComputableDiscreteDistribution;
import gov.sandia.cognition.statistics.ProbabilityMassFunction;
import gov.sandia.cognition.statistics.ProbabilityMassFunctionUtil;
import gov.sandia.cognition.statistics.distribution.MultinomialDistribution;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Random;

/**
 * The Chinese Restaurant Process: customers arrive one at a time and either
 * join an already-occupied table or open a new one. After the first
 * {@code n} customers have been seated, the next customer opens a new table
 * with probability {@code alpha / (n + alpha)}, and joins an existing table
 * that holds {@code c} customers with probability {@code c / (n + alpha)}.
 * <p>
 * A sample is a {@link Vector} whose element {@code k} is the number of
 * customers seated at table {@code k}. Tables are numbered in the order in
 * which they were opened, and the elements always sum to
 * {@link #getNumCustomers()}.
 */
@PublicationReferences(references={@PublicationReference(author={"Michael I. Jordan"}, title="Dirichlet Processes, Chinese Restaurant Processes and All That", year=2005, type=PublicationType.Conference, publication="NIPS Tutorial", url="http://www.cs.berkeley.edu/~jordan/nips-tutorial05.ps"), @PublicationReference(author={"Wikipedia"}, title="http://en.wikipedia.org/wiki/Chinese_restaurant_process", year=2010, type=PublicationType.WebPage, url="http://en.wikipedia.org/wiki/Chinese_restaurant_process", notes={"Very poor, unclear description."})})
public class ChineseRestaurantProcess
extends AbstractCloneableSerializable
implements ClosedFormComputableDiscreteDistribution<Vector> {
    /** Default value of {@code alpha} used by the no-argument constructor. */
    public static final double DEFAULT_ALPHA = 1.0;
    /** Default number of customers used by the no-argument constructor. */
    public static final int DEFAULT_NUM_CUSTOMERS = 2;
    /** Weight given to opening a new table; the setter requires it to be > 0. */
    protected double alpha;
    /** Number of customers seated in each sample; the setter requires it to be > 0. */
    protected int numCustomers;

    /**
     * Creates a process with {@link #DEFAULT_ALPHA} and
     * {@link #DEFAULT_NUM_CUSTOMERS}.
     */
    public ChineseRestaurantProcess() {
        this(DEFAULT_ALPHA, DEFAULT_NUM_CUSTOMERS);
    }

    /**
     * Creates a process with the given parameters.
     *
     * @param alpha weight given to opening a new table; must be > 0
     * @param numCustomers number of customers seated per sample; must be > 0
     * @throws IllegalArgumentException if either parameter is out of range
     */
    public ChineseRestaurantProcess(double alpha, int numCustomers) {
        this.setAlpha(alpha);
        this.setNumCustomers(numCustomers);
    }

    /**
     * Copy constructor.
     *
     * @param other process whose parameters are copied
     */
    public ChineseRestaurantProcess(ChineseRestaurantProcess other) {
        this(other.getAlpha(), other.getNumCustomers());
    }

    @Override
    public ChineseRestaurantProcess clone() {
        // Both fields are primitives, so the shallow copy from super.clone() suffices.
        return (ChineseRestaurantProcess) super.clone();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Vector getMean() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

    /**
     * Seats {@link #getNumCustomers()} customers one at a time and returns
     * the resulting table occupancies.
     *
     * @param random source of randomness
     * @return vector whose element {@code k} is the number of customers at
     *     table {@code k}, with tables in the order they were opened
     */
    @Override
    public Vector sample(Random random) {
        final ArrayList<Integer> tables = new ArrayList<Integer>(this.numCustomers);
        for (int seated = 0; seated < this.getNumCustomers(); ++seated) {
            // "seated" equals the sum of the current table counts, as required
            // by sampleNextCustomer.
            final int tableIndex = ChineseRestaurantProcess.sampleNextCustomer(
                tables, seated, this.alpha, random);
            if (tableIndex >= tables.size()) {
                // The customer opened a new table.
                tables.add(0);
            }
            tables.set(tableIndex, tables.get(tableIndex) + 1);
        }
        return VectorFactory.getDefault().copyValues(tables);
    }

    /**
     * Chooses the table for the next arriving customer.
     *
     * @param tables number of customers at each occupied table; the counts
     *     should sum to {@code numCustomers}
     * @param numCustomers number of customers already seated
     * @param alpha weight given to opening a new table
     * @param random source of randomness
     * @return index of the chosen existing table, or {@code tables.size()}
     *     when the customer opens a new table
     * @throws IllegalArgumentException if the probabilities do not cover the
     *     drawn value, which can only happen when the table counts sum to
     *     less than {@code numCustomers}
     */
    public static int sampleNextCustomer(Collection<Integer> tables, int numCustomers, double alpha, Random random) {
        final double normalizer = (double) numCustomers + alpha;
        double p = random.nextDouble();
        p -= alpha / normalizer;
        if (p <= 0.0) {
            return tables.size();
        }
        int tableIndex = 0;
        long seatedCustomers = 0;
        for (Integer customersAtTable : tables) {
            final int count = customersAtTable.intValue();
            seatedCustomers += count;
            p -= (double) count / normalizer;
            if (p <= 0.0) {
                return tableIndex;
            }
            ++tableIndex;
        }
        if (tableIndex > 0 && seatedCustomers == numCustomers) {
            // The probabilities sum to exactly one mathematically. However,
            // floating-point roundoff can leave p barely positive when
            // nextDouble() returns a value very close to 1. Assign such draws
            // to the last table instead of failing.
            return tableIndex - 1;
        }
        throw new IllegalArgumentException("Bad computation in sampleNextcustomer!!!");
    }

    /**
     * Draws {@code numSamples} independent samples.
     *
     * @param random source of randomness
     * @param numSamples number of samples to draw
     * @return list of samples, each as produced by {@link #sample(Random)}
     */
    @Override
    public ArrayList<Vector> sample(Random random, int numSamples) {
        final ArrayList<Vector> samples = new ArrayList<Vector>(numSamples);
        for (int n = 0; n < numSamples; ++n) {
            samples.add(this.sample(random));
        }
        return samples;
    }

    /**
     * Gets the weight given to opening a new table.
     *
     * @return alpha
     */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Sets the weight given to opening a new table.
     *
     * @param alpha new value; must be > 0
     * @throws IllegalArgumentException if {@code alpha <= 0}
     */
    public void setAlpha(double alpha) {
        if (alpha <= 0.0) {
            throw new IllegalArgumentException("Alpha must be > 0.0");
        }
        this.alpha = alpha;
    }

    /**
     * Gets the number of customers seated in each sample.
     *
     * @return number of customers
     */
    public int getNumCustomers() {
        return this.numCustomers;
    }

    /**
     * Sets the number of customers seated in each sample.
     *
     * @param numCustomers new value; must be > 0
     * @throws IllegalArgumentException if {@code numCustomers <= 0}
     */
    public void setNumCustomers(int numCustomers) {
        if (numCustomers <= 0) {
            throw new IllegalArgumentException("numCustomers must be > 0");
        }
        this.numCustomers = numCustomers;
    }

    /**
     * Builds the domain as the union of one {@code MultinomialDistribution.Domain}
     * per possible table count {@code i = 1 .. numCustomers}. Each is
     * constructed with {@code (i, numCustomers)}.
     * <p>
     * NOTE(review): presumably each part enumerates the length-{@code i}
     * count vectors summing to {@code numCustomers}, possibly including
     * zero entries, which this distribution assigns probability 0. Confirm
     * against MultinomialDistribution.Domain.
     *
     * @return the domain
     */
    @Override
    public MultiCollection<Vector> getDomain() {
        final ArrayList<MultinomialDistribution.Domain> domain =
            new ArrayList<MultinomialDistribution.Domain>(this.getNumCustomers());
        for (int i = 1; i <= this.getNumCustomers(); ++i) {
            domain.add(new MultinomialDistribution.Domain(i, this.getNumCustomers()));
        }
        return new DefaultMultiCollection(domain);
    }

    /**
     * Gets the domain size by building {@link #getDomain()} and asking for
     * its size.
     *
     * @return number of elements in the domain
     */
    @Override
    public int getDomainSize() {
        return this.getDomain().size();
    }

    /**
     * Creates a probability mass function with this process's parameters.
     *
     * @return a new PMF copying alpha and numCustomers
     */
    public PMF getProbabilityFunction() {
        return new PMF(this);
    }

    /**
     * Packs the parameters into a vector.
     *
     * @return 2-element vector {@code [alpha, numCustomers]}
     */
    public Vector convertToVector() {
        return VectorFactory.getDefault().copyValues(new double[]{this.getAlpha(), this.getNumCustomers()});
    }

    /**
     * Unpacks parameters from a vector laid out as by {@link #convertToVector()}.
     * The second element is rounded to the nearest integer.
     *
     * @param parameters 2-element vector {@code [alpha, numCustomers]}
     * @throws IllegalArgumentException if either parameter is out of range
     */
    public void convertFromVector(Vector parameters) {
        parameters.assertDimensionalityEquals(2);
        this.setAlpha(parameters.getElement(0));
        this.setNumCustomers((int) Math.round(parameters.getElement(1)));
    }

    /**
     * Probability mass function of the Chinese Restaurant Process.
     */
    public static class PMF
    extends ChineseRestaurantProcess
    implements ProbabilityMassFunction<Vector> {
        /** Creates a PMF with the default parameters. */
        public PMF() {
        }

        /**
         * Creates a PMF with the given parameters.
         *
         * @param alpha weight given to opening a new table; must be > 0
         * @param numCustomers number of customers; must be > 0
         */
        public PMF(double alpha, int numCustomers) {
            super(alpha, numCustomers);
        }

        /**
         * Creates a PMF copying the parameters of another process.
         *
         * @param other process whose parameters are copied
         */
        public PMF(ChineseRestaurantProcess other) {
            super(other);
        }

        @Override
        public PMF getProbabilityFunction() {
            return this;
        }

        /**
         * Computes the entropy by delegating to
         * {@code ProbabilityMassFunctionUtil.getEntropy}.
         *
         * @return entropy of this PMF
         */
        @Override
        public double getEntropy() {
            return ProbabilityMassFunctionUtil.getEntropy(this);
        }

        /**
         * Computes
         * {@code log(alpha^K * prod_k (n_k - 1)! * Gamma(alpha) / Gamma(alpha + N))},
         * where {@code K} is the number of tables, {@code n_k} the customers at
         * table {@code k}, and {@code N} the total.
         * <p>
         * Returns negative infinity (probability 0) in two cases: any table
         * holds fewer than 1 or more than {@code numCustomers} customers, or
         * the counts do not sum to {@code numCustomers}.
         * <p>
         * NOTE(review): this is the probability of one particular partition of
         * customers with these table sizes. {@link #sample(Random)} instead
         * returns counts in table-opening order, and the probability of such
         * an ordered vector carries an extra combinatorial factor. For example,
         * with N = 3, sampling yields [2, 1] with probability
         * 2*alpha/((1+alpha)(2+alpha)), while this method gives
         * alpha/((1+alpha)(2+alpha)). Confirm the intended semantics.
         *
         * @param input table occupancy counts
         * @return log-probability of {@code input}
         * @throws IllegalArgumentException if an in-range count is not an integer
         */
        @Override
        public double logEvaluate(Vector input) {
            final int numTables = input.getDimensionality();
            double logSum = numTables * Math.log(this.alpha);
            // long guards against overflow when summing many large counts.
            long totalCustomers = 0;
            for (int table = 0; table < numTables; ++table) {
                final double value = input.getElement(table);
                if (value < 1.0 || value > (double) this.numCustomers) {
                    return Double.NEGATIVE_INFINITY;
                }
                final double floor = Math.floor(value);
                if (floor != Math.ceil(value)) {
                    throw new IllegalArgumentException("Customers at each table must be an integer: " + input);
                }
                final int customersAtTable = (int) floor;
                logSum += MathUtil.logFactorial(customersAtTable - 1);
                totalCustomers += customersAtTable;
            }
            if (totalCustomers != this.numCustomers) {
                // Every sample seats exactly numCustomers customers. Any other
                // total, including an empty vector, is impossible.
                return Double.NEGATIVE_INFINITY;
            }
            logSum += MathUtil.logGammaFunction(this.alpha);
            logSum -= MathUtil.logGammaFunction(this.alpha + (double) this.numCustomers);
            return logSum;
        }

        /**
         * Computes the probability of {@code input}.
         *
         * @param input table occupancy counts
         * @return {@code exp(logEvaluate(input))}
         */
        public Double evaluate(Vector input) {
            return Math.exp(this.logEvaluate(input));
        }
    }
}

