/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.statistics.bayesian;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.evaluator.Evaluator;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.function.scalar.PolynomialFunction;
import gov.sandia.cognition.math.AbstractUnivariateScalarFunction;
import gov.sandia.cognition.math.OperationNotConvergedException;
import gov.sandia.cognition.math.ProbabilityUtil;
import gov.sandia.cognition.statistics.ProbabilityFunction;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.ObjectUtil;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.Random;

@PublicationReference(author={"Christian P. Robert", "George Casella"}, title="Monte Carlo Statistical Methods, Seconds Edition", type=PublicationType.Book, pages={56, 58, 70, 71}, notes={"Algorithm A.7", "Algorithm A.17"}, year=2004)
public class AdaptiveRejectionSampling
extends AbstractCloneableSerializable {
    public static final int DEFAULT_MAX_NUM_POINTS = 50;
    LogEvaluator logFunction;
    private ArrayList<Point> points = new ArrayList(50);
    private int maxNumPoints = 50;
    private double minSupport;
    private double maxSupport;
    UpperEnvelope upperEnvelope = new UpperEnvelope();
    LowerEnvelope lowerEnvelope = new LowerEnvelope();

    public AdaptiveRejectionSampling clone() {
        AdaptiveRejectionSampling clone = (AdaptiveRejectionSampling)super.clone();
        clone.points = ObjectUtil.cloneSmartElementsAsArrayList(this.getPoints());
        clone.upperEnvelope = clone.new UpperEnvelope();
        clone.upperEnvelope.resetLines();
        clone.lowerEnvelope = clone.new LowerEnvelope();
        clone.lowerEnvelope.resetLines();
        clone.setLogFunction((LogEvaluator)ObjectUtil.cloneSafe((CloneableSerializable)this.getLogFunction()));
        return clone;
    }

    public void initialize(LogEvaluator logFunction, double minSupport, double maxSupport, double leftPoint, double midPoint, double rightPoint) {
        this.setLogFunction(logFunction);
        this.setMinSupport(minSupport);
        this.setMaxSupport(maxSupport);
        this.points = new ArrayList(50);
        this.upperEnvelope = new UpperEnvelope();
        this.lowerEnvelope = new LowerEnvelope();
        double y = this.logFunction.evaluate(leftPoint);
        this.addPoint(leftPoint, y);
        y = this.logFunction.evaluate(midPoint);
        this.addPoint(midPoint, y);
        y = this.logFunction.evaluate(rightPoint);
        this.addPoint(rightPoint, y);
    }

    public void addPoint(double x, double y) {
        if (this.getNumPoints() < this.getMaxNumPoints()) {
            this.points.add(new Point(x, y));
            Collections.sort(this.points);
            this.upperEnvelope.resetLines();
            this.lowerEnvelope.resetLines();
        }
    }

    public int getNumPoints() {
        return this.getPoints().size();
    }

    protected Collection<Point> getPoints() {
        return this.points;
    }

    public double sample(Random random) {
        int maxRejections = 100;
        for (int rejections = 0; rejections < 100; ++rejections) {
            double rejectionRatio;
            double logUpper;
            double logLower;
            double envelopeRatio;
            double x = this.upperEnvelope.sample(random);
            double u = random.nextDouble();
            if (u <= (envelopeRatio = Math.exp((logLower = this.lowerEnvelope.logEvaluate(x)) - (logUpper = this.upperEnvelope.logEvaluate(x))))) {
                return x;
            }
            double logFx = this.logFunction.evaluate(x);
            if (this.getNumPoints() < this.getMaxNumPoints()) {
                this.addPoint(x, logFx);
            }
            if (!(u <= (rejectionRatio = Math.exp(logFx - logUpper)))) continue;
            return x;
        }
        throw new OperationNotConvergedException("Maximum number of rejections exceeded for a single sample: 100");
    }

    public ArrayList<Double> sample(Random random, int numSamples) {
        ArrayList<Double> samples = new ArrayList<Double>(numSamples);
        for (int n = 0; n < numSamples; ++n) {
            samples.add(this.sample(random));
        }
        return samples;
    }

    public LogEvaluator getLogFunction() {
        return this.logFunction;
    }

    public void setLogFunction(LogEvaluator logFunction) {
        this.logFunction = logFunction;
    }

    public int getMaxNumPoints() {
        return this.maxNumPoints;
    }

    public void setMaxNumPoints(int maxNumPoints) {
        this.maxNumPoints = maxNumPoints;
    }

    public double getMinSupport() {
        return this.minSupport;
    }

    public void setMinSupport(double minSupport) {
        this.minSupport = minSupport;
    }

    public double getMaxSupport() {
        return this.maxSupport;
    }

    public void setMaxSupport(double maxSupport) {
        this.maxSupport = maxSupport;
    }

    public static class PDFLogEvaluator
    extends LogEvaluator<ProbabilityFunction<Double>> {
        public PDFLogEvaluator(ProbabilityFunction<Double> function) {
            super(function);
        }

        @Override
        public double evaluate(double input) {
            return ((ProbabilityFunction)this.function).logEvaluate(input);
        }
    }

    public static abstract class LogEvaluator<EvaluatorType extends Evaluator<Double, Double>>
    extends AbstractUnivariateScalarFunction {
        protected EvaluatorType function;

        public LogEvaluator(EvaluatorType function) {
            this.setFunction(function);
        }

        public LogEvaluator<EvaluatorType> clone() {
            LogEvaluator clone = (LogEvaluator)super.clone();
            clone.setFunction((Evaluator)ObjectUtil.cloneSmart(this.getFunction()));
            return clone;
        }

        public EvaluatorType getFunction() {
            return this.function;
        }

        public void setFunction(EvaluatorType function) {
            this.function = function;
        }

        public double evaluate(double input) {
            return Math.log((Double)this.function.evaluate((Object)input));
        }
    }

    public static class Point
    extends DefaultInputOutputPair<Double, Double>
    implements Comparable<Point> {
        public Point(double x, double y) {
            super(x, y);
        }

        @Override
        public int compareTo(Point o) {
            double x1;
            double x0 = (Double)this.getInput();
            if (x0 < (x1 = ((Double)o.getInput()).doubleValue())) {
                return -1;
            }
            if (x0 > x1) {
                return 1;
            }
            return 0;
        }

        public static PolynomialFunction.Linear line(int index, ArrayList<Point> points) {
            Point pi = points.get(index);
            Point pip1 = points.get(index + 1);
            return PolynomialFunction.Linear.fit(pi, pip1);
        }

        public static double intercept(PolynomialFunction.Linear line1, PolynomialFunction.Linear line2) {
            double a1 = line1.getQ1();
            double b1 = line1.getQ0();
            double a2 = line2.getQ1();
            double b2 = line2.getQ0();
            if (a1 == a2) {
                throw new IllegalArgumentException("Lines are collinear");
            }
            return (b2 - b1) / (a1 - a2);
        }
    }

    public static class LineSegment
    extends PolynomialFunction.Linear
    implements Comparable<Double> {
        double left;
        double right;

        public LineSegment(PolynomialFunction.Linear line, double left, double right) {
            super(line.getQ0(), line.getQ1());
            this.left = left;
            this.right = right;
        }

        public double sampleExp(double p) {
            ProbabilityUtil.assertIsProbability((double)p);
            double q1 = this.getQ1();
            if (Math.abs(q1) >= 0.0) {
                double l = Math.exp(q1 * this.left);
                double r = Math.exp(q1 * this.right);
                double delta = p * (r - l);
                double x = Math.log(l + delta) / q1;
                return x;
            }
            double l = this.left;
            double r = this.right;
            return l + p * (r - l);
        }

        public double integrateExp() {
            double q0 = this.getQ0();
            double q1 = this.getQ1();
            if (Math.abs(q1) >= 0.0) {
                double l = Math.exp(q1 * this.left);
                double r = Math.exp(q1 * this.right);
                double coeff = Math.exp(q0) / q1;
                return coeff * (r - l);
            }
            double l = this.left;
            double r = this.right;
            double coeff = Math.exp(q0);
            return coeff * (r - l);
        }

        @Override
        public int compareTo(Double o) {
            double x = o;
            if (x < this.left) {
                return 1;
            }
            if (x > this.right) {
                return -1;
            }
            return 0;
        }
    }

    public class LowerEnvelope
    extends AbstractEnvelope {
        @Override
        protected void computeLines() {
            int numPoints = AdaptiveRejectionSampling.this.points.size();
            int numLines = numPoints + 1;
            this.lines = new ArrayList(numLines);
            Iterator iterator = AdaptiveRejectionSampling.this.points.iterator();
            double left = AdaptiveRejectionSampling.this.minSupport;
            double right = (Double)((Point)iterator.next()).getInput();
            PolynomialFunction.Linear line = new PolynomialFunction.Linear(0.0, Double.NEGATIVE_INFINITY);
            this.lines.add(new LineSegment(line, left, right));
            for (int i = 0; i < numPoints - 1; ++i) {
                left = right;
                right = (Double)((Point)iterator.next()).getInput();
                line = Point.line(i, AdaptiveRejectionSampling.this.points);
                this.lines.add(new LineSegment(line, left, right));
            }
            left = right;
            right = AdaptiveRejectionSampling.this.maxSupport;
            line = new PolynomialFunction.Linear(0.0, Double.NEGATIVE_INFINITY);
            this.lines.add(new LineSegment(line, left, right));
        }
    }

    public class UpperEnvelope
    extends AbstractEnvelope
    implements ProbabilityFunction<Double> {
        protected double[] segmentCDF;

        public UpperEnvelope() {
            this.segmentCDF = null;
        }

        @Override
        public UpperEnvelope clone() {
            UpperEnvelope clone = (UpperEnvelope)super.clone();
            clone.segmentCDF = (double[])ObjectUtil.cloneSmart((Object)this.segmentCDF);
            return clone;
        }

        public UpperEnvelope getProbabilityFunction() {
            return this;
        }

        public Double getMean() {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        @Override
        public Double sample(Random random) {
            ArrayList<LineSegment> ls = this.getLines();
            double p1 = random.nextDouble();
            int index = Arrays.binarySearch(this.segmentCDF, p1);
            if (index < 0) {
                int insertionPoint;
                index = insertionPoint = -index - 1;
            }
            LineSegment segment = ls.get(index);
            double p2 = random.nextDouble();
            return segment.sampleExp(p2);
        }

        @Override
        public ArrayList<Double> sample(Random random, int numSamples) {
            ArrayList<Double> samples = new ArrayList<Double>(numSamples);
            for (int n = 0; n < numSamples; ++n) {
                samples.add(this.sample(random));
            }
            return samples;
        }

        @Override
        protected void computeLines() {
            int numLines = (AdaptiveRejectionSampling.this.points.size() - 1) * 2;
            this.lines = new ArrayList(numLines);
            this.segmentCDF = new double[numLines];
            double totalMass = 0.0;
            Iterator iterator = AdaptiveRejectionSampling.this.points.iterator();
            double left = AdaptiveRejectionSampling.this.getMinSupport();
            double right = (Double)((Point)iterator.next()).getInput();
            PolynomialFunction.Linear leftLine = Point.line(0, AdaptiveRejectionSampling.this.points);
            LineSegment leftMost = new LineSegment(leftLine, left, right);
            double weight = leftMost.integrateExp();
            this.lines.add(leftMost);
            this.segmentCDF[this.lines.size() - 1] = totalMass += weight;
            PolynomialFunction.Linear rightLine = Point.line(1, AdaptiveRejectionSampling.this.points);
            left = right;
            right = (Double)((Point)iterator.next()).getInput();
            LineSegment segment = new LineSegment(rightLine, left, right);
            weight = segment.integrateExp();
            this.lines.add(segment);
            this.segmentCDF[this.lines.size() - 1] = totalMass += weight;
            int N = AdaptiveRejectionSampling.this.points.size();
            for (int index = 1; index < N - 2; ++index) {
                left = right;
                leftLine = Point.line(index - 1, AdaptiveRejectionSampling.this.points);
                rightLine = Point.line(index + 1, AdaptiveRejectionSampling.this.points);
                right = Point.intercept(leftLine, rightLine);
                segment = new LineSegment(leftLine, left, right);
                weight = segment.integrateExp();
                this.lines.add(segment);
                this.segmentCDF[this.lines.size() - 1] = totalMass += weight;
                left = right;
                right = (Double)((Point)iterator.next()).getInput();
                segment = new LineSegment(rightLine, left, right);
                weight = segment.integrateExp();
                this.lines.add(segment);
                this.segmentCDF[this.lines.size() - 1] = totalMass += weight;
            }
            left = right;
            right = (Double)((Point)iterator.next()).getInput();
            segment = new LineSegment(Point.line(N - 3, AdaptiveRejectionSampling.this.points), left, right);
            weight = segment.integrateExp();
            this.lines.add(segment);
            this.segmentCDF[this.lines.size() - 1] = totalMass += weight;
            left = right;
            right = AdaptiveRejectionSampling.this.getMaxSupport();
            LineSegment rightMost = new LineSegment(Point.line(N - 2, AdaptiveRejectionSampling.this.points), left, right);
            weight = rightMost.integrateExp();
            this.lines.add(rightMost);
            this.segmentCDF[this.lines.size() - 1] = totalMass += weight;
            int i = 0;
            while (i < this.lines.size()) {
                int n = i++;
                this.segmentCDF[n] = this.segmentCDF[n] / totalMass;
            }
        }
    }

    public abstract class AbstractEnvelope
    extends AbstractUnivariateScalarFunction {
        protected ArrayList<LineSegment> lines = null;

        public AbstractEnvelope clone() {
            AbstractEnvelope clone = (AbstractEnvelope)super.clone();
            clone.lines = ObjectUtil.cloneSmartElementsAsArrayList(this.getLines());
            return clone;
        }

        protected ArrayList<LineSegment> getLines() {
            if (this.lines == null) {
                this.computeLines();
            }
            return this.lines;
        }

        public void resetLines() {
            this.lines = null;
        }

        protected abstract void computeLines();

        public double logEvaluate(Double input) {
            return this.findLineSegment(input).evaluate(input);
        }

        public double evaluate(double input) {
            return Math.exp(this.logEvaluate(input));
        }

        protected LineSegment findLineSegment(Double input) {
            ArrayList<LineSegment> ls = this.getLines();
            int index = Collections.binarySearch(ls, input);
            return ls.get(index);
        }
    }
}

