/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.confidence;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractSupervisedBatchAndIncrementalLearner;
import gov.sandia.cognition.learning.function.categorization.DiagonalConfidenceWeightedBinaryCategorizer;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.util.ArgumentChecker;

/**
 * An incremental learner for a binary linear categorizer using the
 * Confidence-Weighted (CW) update with a diagonal variance. The learned
 * {@link DiagonalConfidenceWeightedBinaryCategorizer} holds a mean vector and a
 * per-dimension variance vector for the weights. Each training example may move
 * the mean toward classifying it correctly and shrink the variances of the
 * dimensions it touches.
 */
@PublicationReference(title="Confidence-Weighted Linear Classification", author={"Mark Dredze", "Koby Crammer", "Fernando Pereira"}, year=2008, type=PublicationType.Conference, publication="International Conference on Machine Learning", url="http://portal.acm.org/citation.cfm?id=1390190")
public class ConfidenceWeightedDiagonalVariance
extends AbstractSupervisedBatchAndIncrementalLearner<Vectorizable, Boolean, DiagonalConfidenceWeightedBinaryCategorizer> {
    /** The default confidence, {@value}. */
    public static final double DEFAULT_CONFIDENCE = 0.85;

    /** The default initial value of every variance entry, {@value}. */
    public static final double DEFAULT_DEFAULT_VARIANCE = 1.0;

    /** The confidence level, in [0, 1]; {@link #phi} is derived from it. */
    protected double confidence;

    /** The initial value of every variance entry of a newly initialized categorizer; positive. */
    protected double defaultVariance;

    /**
     * Cached standard normal quantile derived from the confidence,
     * {@code -InverseCDF(1 - confidence)}; recomputed by {@link #setConfidence}.
     */
    protected double phi;

    /**
     * Creates a learner with {@link #DEFAULT_CONFIDENCE} and
     * {@link #DEFAULT_DEFAULT_VARIANCE}.
     */
    public ConfidenceWeightedDiagonalVariance() {
        this(DEFAULT_CONFIDENCE, DEFAULT_DEFAULT_VARIANCE);
    }

    /**
     * Creates a learner with the given parameters.
     *
     * @param confidence the confidence level; must be in [0, 1].
     * @param defaultVariance the initial variance of each weight; must be positive.
     */
    public ConfidenceWeightedDiagonalVariance(final double confidence, final double defaultVariance) {
        this.setConfidence(confidence);
        this.setDefaultVariance(defaultVariance);
    }

    /**
     * Creates a new, uninitialized categorizer. Its mean and variance are
     * allocated on the first call to {@code update}.
     *
     * @return a new categorizer.
     */
    @Override
    public DiagonalConfidenceWeightedBinaryCategorizer createInitialLearnedObject() {
        return new DiagonalConfidenceWeightedBinaryCategorizer();
    }

    /**
     * Updates the categorizer with one labeled example. Examples with a
     * {@code null} input or {@code null} label are skipped.
     *
     * @param target the categorizer to update.
     * @param input the example input; converted to a vector.
     * @param output the example label.
     */
    @Override
    public void update(
        final DiagonalConfidenceWeightedBinaryCategorizer target,
        final Vectorizable input,
        final Boolean output) {
        if (input != null && output != null) {
            this.update(target, input.convertToVector(), output.booleanValue());
        }
    }

    /**
     * Performs one Confidence-Weighted update of the categorizer's mean and
     * variance for a single labeled example.
     * <p>
     * If the categorizer is not yet initialized, it first gets a zero mean and
     * a variance of {@link #getDefaultVariance()} in every dimension, sized to
     * {@code input}. No change is made when any of the following holds:
     * <ul>
     * <li>the margin variance is zero;</li>
     * <li>the margin already exceeds {@code phi} times the margin variance;</li>
     * <li>the computed step size is not a finite positive number.</li>
     * </ul>
     *
     * @param target the categorizer to update.
     * @param input the example input vector.
     * @param label the example label; {@code true} maps to +1 and
     *        {@code false} to -1.
     */
    @Override
    public void update(
        final DiagonalConfidenceWeightedBinaryCategorizer target,
        final Vector input,
        final boolean label) {
        final Vector mean;
        final Vector variance;
        if (!target.isInitialized()) {
            // First example: zero mean and a constant variance, both sized to this input.
            final int dimensionality = input.getDimensionality();
            mean = VectorFactory.getDenseDefault().createVector(dimensionality);
            variance = VectorFactory.getDenseDefault().createVector(
                dimensionality, this.getDefaultVariance());
            target.setMean(mean);
            target.setVariance(variance);
        } else {
            mean = target.getMean();
            variance = target.getVariance();
        }

        // Signed margin M = y * (x . mu), with y in {-1, +1}.
        final double predicted = input.dotProduct(mean);
        final double actual = label ? 1.0 : -1.0;
        final double margin = actual * predicted;

        // Sigma x, element-wise because the variance is diagonal. It is reused
        // by both the mean update and the variance update.
        final Vector varianceTimesInput = input.dotTimes(variance);

        // Margin variance V = x^T Sigma x.
        final double marginVariance = input.dotProduct(varianceTimesInput);

        if (marginVariance == 0.0 || margin > this.phi * marginVariance) {
            // Nothing to learn: either the example carries no variance, or the
            // margin already strictly exceeds phi * V.
            return;
        }

        // Closed-form step size:
        // alpha = (-(1 + 2 phi M) + sqrt((1 + 2 phi M)^2 - 8 phi (M - phi V)))
        //         / (4 phi V)
        final double meanPart = 1.0 + 2.0 * this.phi * margin;
        final double variancePart = margin - this.phi * marginVariance;
        final double numerator = -meanPart
            + Math.sqrt(meanPart * meanPart - 8.0 * this.phi * variancePart);
        final double denominator = 4.0 * this.phi * marginVariance;
        final double alpha = numerator / denominator;

        // Apply only a finite, positive step. A bare "alpha <= 0.0" test would
        // let NaN through, because every comparison with NaN is false. A NaN
        // step would then be written into the mean and variance for good.
        // NaN arises, for example, when phi is zero (alpha = 0/0), or when phi
        // is negative and the square root's argument goes negative.
        if (!(alpha > 0.0) || Double.isInfinite(alpha)) {
            return;
        }

        // Mean update: mu += alpha * y * Sigma x.
        final Vector meanUpdate = varianceTimesInput.scale(actual * alpha);
        mean.plusEquals(meanUpdate);

        // Variance update, per entry j:
        //   sigma_j -= 2 alpha phi (sigma_j x_j)^2 / (1 + 2 alpha phi V)
        // This is the diagonal of the full-covariance Sherman-Morrison update.
        // NOTE(review): suppose the intended diagonal approximation is instead
        // Sigma^-1 += 2 alpha phi diag(x_j^2). Its per-entry form would be
        // sigma_j / (1 + 2 alpha phi sigma_j x_j^2), which differs from this
        // whenever x has more than one nonzero entry. Confirm against the paper
        // before changing, since it alters learned models.
        final double twoAlphaPhi = 2.0 * alpha * this.phi;
        final Vector varianceUpdate = varianceTimesInput.dotTimes(varianceTimesInput);
        varianceUpdate.scaleEquals(-twoAlphaPhi / (1.0 + twoAlphaPhi * marginVariance));
        variance.plusEquals(varianceUpdate);

        target.setMean(mean);
        target.setVariance(variance);
    }

    /**
     * Gets the confidence level.
     *
     * @return the confidence, in [0, 1].
     */
    public double getConfidence() {
        return this.confidence;
    }

    /**
     * Sets the confidence level and recomputes the cached {@code phi}.
     *
     * @param confidence the confidence; rejected by {@link ArgumentChecker}
     *        if outside [0, 1].
     */
    public void setConfidence(final double confidence) {
        ArgumentChecker.assertIsInRangeInclusive("confidence", confidence, 0.0, 1.0);
        this.confidence = confidence;
        // Standard normal quantile, computed once here rather than on every update.
        this.phi = -UnivariateGaussian.CDF.Inverse.evaluate(1.0 - confidence, 0.0, 1.0);
    }

    /**
     * Gets the initial variance used for every dimension of a newly
     * initialized categorizer.
     *
     * @return the default variance.
     */
    public double getDefaultVariance() {
        return this.defaultVariance;
    }

    /**
     * Sets the initial variance used for every dimension of a newly
     * initialized categorizer.
     *
     * @param defaultVariance the default variance; rejected by
     *        {@link ArgumentChecker} if not positive.
     */
    public void setDefaultVariance(final double defaultVariance) {
        ArgumentChecker.assertIsPositive("defaultVariance", defaultVariance);
        this.defaultVariance = defaultVariance;
    }
}

