/*
 * Decompiled with CFR 0.152.
 */
package gov.sandia.cognition.learning.algorithm.perceptron;

import gov.sandia.cognition.algorithm.MeasurablePerformanceAlgorithm;
import gov.sandia.cognition.annotation.CodeReview;
import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.DatasetUtil;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.LinearBinaryCategorizer;
import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.VectorFactoryContainer;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.util.CloneableSerializable;
import gov.sandia.cognition.util.DefaultNamedValue;
import gov.sandia.cognition.util.NamedValue;
import java.util.Collection;

@CodeReview(reviewer={"Kevin R. Dixon"}, date="2008-07-23", changesNeeded=false, comments={"Added PublicationReference to Wikiepedia article.", "Minor changes to javadoc.", "Looks fine."})
@PublicationReference(author={"Wikipedia"}, title="Perceptron Learning algorithm", type=PublicationType.WebPage, year=2008, url="http://en.wikipedia.org/wiki/Perceptron#Learning_algorithm")
public class Perceptron
extends AbstractAnytimeSupervisedBatchLearner<Vectorizable, Boolean, LinearBinaryCategorizer>
implements MeasurablePerformanceAlgorithm,
CloneableSerializable,
VectorFactoryContainer {
    public static final int DEFAULT_MAX_ITERATIONS = 100;
    public static final double DEFAULT_MARGIN_POSITIVE = 0.0;
    public static final double DEFAULT_MARGIN_NEGATIVE = 0.0;
    private double marginPositive;
    private double marginNegative;
    private VectorFactory<?> vectorFactory;
    private LinearBinaryCategorizer result;
    private int errorCount;

    public Perceptron() {
        this(100);
    }

    public Perceptron(int maxIterations) {
        this(maxIterations, 0.0, 0.0);
    }

    public Perceptron(int maxIterations, double marginPositive, double marginNegative) {
        this(maxIterations, marginPositive, marginNegative, VectorFactory.getDefault());
    }

    public Perceptron(int maxIterations, double marginPositive, double marginNegative, VectorFactory vectorFactory) {
        super(maxIterations);
        this.setMarginPositive(marginPositive);
        this.setMarginNegative(marginNegative);
        this.setVectorFactory(vectorFactory);
    }

    @Override
    public Perceptron clone() {
        Perceptron clone = (Perceptron)super.clone();
        clone.result = null;
        clone.errorCount = 0;
        return clone;
    }

    @Override
    protected boolean initializeAlgorithm() {
        if (this.getData() == null) {
            return false;
        }
        int dimensionality = DatasetUtil.getInputDimensionality((Iterable)this.getData());
        if (dimensionality < 0) {
            return false;
        }
        DatasetUtil.assertInputDimensionalitiesAllEqual((Iterable)this.getData());
        this.setResult(new LinearBinaryCategorizer(this.getVectorFactory().createVector(dimensionality), 0.0));
        return true;
    }

    @Override
    protected boolean step() {
        this.setErrorCount(0);
        for (InputOutputPair example : (Collection)this.getData()) {
            if (example == null) continue;
            Vector input = ((Vectorizable)example.getInput()).convertToVector();
            boolean actual = (Boolean)example.getOutput();
            double prediction = this.result.evaluateAsDouble(input);
            if (!(actual && prediction <= this.marginPositive) && (actual || !(prediction >= -this.marginNegative))) continue;
            this.setErrorCount(this.getErrorCount() + 1);
            Vector weights = this.result.getWeights();
            double bias = this.result.getBias();
            if (actual) {
                weights.plusEquals((Ring)input);
                bias += 1.0;
            } else {
                weights.minusEquals((Ring)input);
                bias -= 1.0;
            }
            this.result.setBias(bias);
        }
        return this.getErrorCount() > 0;
    }

    @Override
    protected void cleanupAlgorithm() {
    }

    public void setMargin(double margin) {
        this.setMarginPositive(margin);
        this.setMarginNegative(margin);
    }

    public double getMarginPositive() {
        return this.marginPositive;
    }

    public void setMarginPositive(double marginPositive) {
        this.marginPositive = marginPositive;
    }

    public double getMarginNegative() {
        return this.marginNegative;
    }

    public void setMarginNegative(double marginNegative) {
        this.marginNegative = marginNegative;
    }

    public VectorFactory<?> getVectorFactory() {
        return this.vectorFactory;
    }

    public void setVectorFactory(VectorFactory<?> vectorFactory) {
        this.vectorFactory = vectorFactory;
    }

    public LinearBinaryCategorizer getResult() {
        return this.result;
    }

    protected void setResult(LinearBinaryCategorizer result) {
        this.result = result;
    }

    public int getErrorCount() {
        return this.errorCount;
    }

    protected void setErrorCount(int errorCount) {
        this.errorCount = errorCount;
    }

    public NamedValue<Integer> getPerformance() {
        return new DefaultNamedValue("error count", (Object)this.getErrorCount());
    }
}

