/*
 * Decompiled with CFR 0.152.
 */
package cc.factorie.app.nlp.parse;

import cc.factorie.app.classify.backend.LinearMulticlassClassifier;
import cc.factorie.app.classify.backend.MulticlassClassifierTrainer;
import cc.factorie.app.classify.backend.OnlineLinearMulticlassTrainer;
import cc.factorie.app.classify.backend.OnlineLinearMulticlassTrainer$;
import cc.factorie.app.classify.backend.SVMMulticlassTrainer;
import cc.factorie.app.classify.backend.SVMMulticlassTrainer$;
import cc.factorie.app.nlp.Document;
import cc.factorie.app.nlp.Sentence;
import cc.factorie.app.nlp.load.AnnotationTypes$;
import cc.factorie.app.nlp.load.LoadConll2008$;
import cc.factorie.app.nlp.load.LoadOntonotes5$;
import cc.factorie.app.nlp.load.LoadWSJMalt$;
import cc.factorie.app.nlp.parse.ParseTree;
import cc.factorie.app.nlp.parse.ParserEval$;
import cc.factorie.app.nlp.parse.TransitionBasedParser;
import cc.factorie.app.nlp.parse.TransitionBasedParserArgs;
import cc.factorie.app.nlp.parse.TransitionBasedParserTrainer$;
import cc.factorie.la.Tensor2;
import cc.factorie.optimize.AdaGradRDA;
import cc.factorie.optimize.AdaGradRDA$;
import cc.factorie.optimize.OptimizableObjectives;
import cc.factorie.optimize.OptimizableObjectives$;
import cc.factorie.package$;
import cc.factorie.util.BoxedDouble;
import cc.factorie.util.CmdOptions;
import cc.factorie.util.FileUtils$;
import cc.factorie.util.HyperparameterMain;
import cc.factorie.util.HyperparameterMain$class;
import cc.factorie.variable.CategoricalDomain;
import cc.factorie.variable.DiscreteDomain;
import java.io.File;
import scala.Function1;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.collection.Iterable;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichDouble$;
import scala.runtime.RichInt$;
import scala.util.Random;

public final class TransitionBasedParserTrainer$
implements HyperparameterMain {
    /**
     * The instance of this class created by the static initializer below.
     * It is assigned inside the private constructor, not at the declaration.
     */
    public static final TransitionBasedParserTrainer$ MODULE$;

    static {
        // The constructed object is not stored here; the private constructor
        // assigns itself to MODULE$ as its first statement.
        new TransitionBasedParserTrainer$();
    }

    /**
     * Program entry point required by {@code HyperparameterMain}.
     * Forwards to {@code HyperparameterMain$class.main}, passing this instance
     * and the unmodified argument array.
     *
     * @param args command-line arguments
     */
    @Override
    public final void main(String[] args) {
        HyperparameterMain$class.main(this, args);
    }

    /**
     * Forwards to {@code HyperparameterMain$class.actualMain}, passing this
     * instance and the unmodified argument array.
     *
     * @param args command-line arguments
     * @return whatever {@code HyperparameterMain$class.actualMain} returns
     */
    @Override
    public final BoxedDouble actualMain(String[] args) {
        return HyperparameterMain$class.actualMain(this, args);
    }

    /**
     * Trains a {@code TransitionBasedParser} according to the command-line
     * options and returns the score computed by {@code ParserEval$.calcLas}
     * over the test sentences.
     *
     * <p>Steps performed, in order:
     * <ol>
     *   <li>parse {@code args} into a {@code TransitionBasedParserArgs};</li>
     *   <li>load train, dev and test sentences and truncate them to the
     *       requested portions;</li>
     *   <li>build the classifier trainer (SVM or online linear with
     *       {@code AdaGradRDA});</li>
     *   <li>run decision generation once with feature counting enabled, trim
     *       the feature domain and freeze it;</li>
     *   <li>generate decisions again and train from them;</li>
     *   <li>run the configured number of bootstrapping iterations;</li>
     *   <li>optionally serialize the model, reload it into a fresh parser and
     *       report on the test sentences;</li>
     *   <li>compute the final score and, if a target accuracy was supplied,
     *       hand both to {@code assertMinimalAccuracy}.</li>
     * </ol>
     *
     * @param args command-line arguments understood by
     *             {@code TransitionBasedParserArgs}
     * @return the value of {@code ParserEval$.calcLas} for the
     *         {@code ParseTree} attributes of the test sentences
     *         (presumably the labeled attachment score; confirm against
     *         {@code ParserEval})
     */
    @Override
    public double evaluateParameters(String[] args) {
        MulticlassClassifierTrainer<LinearMulticlassClassifier> multiclassClassifierTrainer;
        TransitionBasedParserArgs opts = new TransitionBasedParserArgs();
        // Fixed seed (0); this Random is handed to whichever trainer is built below.
        Random random = new Random(0);
        opts.parse((Seq<String>)Predef$.MODULE$.wrapRefArray((Object[])args));
        // At least one of the trainFiles / trainDir options must have been supplied.
        Predef$.MODULE$.assert(opts.trainFiles().wasInvoked() || opts.trainDir().wasInvoked());
        Seq sentencesFull = this.loadSentences$1(opts.trainFiles(), opts.trainDir(), opts);
        Seq devSentencesFull = this.loadSentences$1(opts.devFiles(), opts.devDir(), opts);
        Seq testSentencesFull = this.loadSentences$1(opts.testFiles(), opts.testDir(), opts);
        // Portions fall back to 1.0 (keep everything) when the option was not supplied.
        double trainPortionToTake = opts.trainPortion().wasInvoked() ? BoxesRunTime.unboxToDouble((Object)opts.trainPortion().value()) : 1.0;
        double testPortionToTake = opts.testPortion().wasInvoked() ? BoxesRunTime.unboxToDouble((Object)opts.testPortion().value()) : 1.0;
        // Keep the first floor(portion * length) sentences of each set.
        Seq sentences = (Seq)sentencesFull.take((int)RichDouble$.MODULE$.floor$extension(Predef$.MODULE$.doubleWrapper(trainPortionToTake * (double)sentencesFull.length())));
        Seq testSentences = (Seq)testSentencesFull.take((int)RichDouble$.MODULE$.floor$extension(Predef$.MODULE$.doubleWrapper(testPortionToTake * (double)testSentencesFull.length())));
        // NOTE(review): the dev set is truncated with testPortionToTake; no separate
        // dev portion option is read in this method.
        Seq devSentences = (Seq)devSentencesFull.take((int)RichDouble$.MODULE$.floor$extension(Predef$.MODULE$.doubleWrapper(testPortionToTake * (double)devSentencesFull.length())));
        Predef$.MODULE$.println((Object)new StringBuilder().append((Object)"Total train sentences: ").append((Object)BoxesRunTime.boxToInteger((int)sentences.size())).toString());
        Predef$.MODULE$.println((Object)new StringBuilder().append((Object)"Total test sentences: ").append((Object)BoxesRunTime.boxToInteger((int)testSentences.size())).toString());
        int numBootstrappingIterations = BoxesRunTime.unboxToInt((Object)opts.bootstrapping().value());
        TransitionBasedParser c = new TransitionBasedParser();
        // Both regularisation strengths are scaled by 2 / (number of training sentences).
        double l1 = (double)2 * BoxesRunTime.unboxToDouble((Object)opts.l1().value()) / (double)sentences.length();
        double l2 = (double)2 * BoxesRunTime.unboxToDouble((Object)opts.l2().value()) / (double)sentences.length();
        AdaGradRDA optimizer = new AdaGradRDA(BoxesRunTime.unboxToDouble((Object)opts.rate().value()), BoxesRunTime.unboxToDouble((Object)opts.delta().value()), l1, l2, AdaGradRDA$.MODULE$.$lessinit$greater$default$5());
        Predef$.MODULE$.println((Object)new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"Initializing trainer (", " threads)"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{opts.nThreads().value()})));
        // Trainer selection: SVM when useSVM is set, otherwise an online linear
        // trainer driven by the AdaGradRDA optimizer built above.
        if (BoxesRunTime.unboxToBoolean((Object)opts.useSVM().value())) {
            multiclassClassifierTrainer = new SVMMulticlassTrainer(BoxesRunTime.unboxToInt((Object)opts.nThreads().value()), SVMMulticlassTrainer$.MODULE$.$lessinit$greater$default$2(), random);
        } else {
            AdaGradRDA x$35 = optimizer;
            // First constructor argument is true only when more than one thread is requested.
            boolean x$36 = BoxesRunTime.unboxToInt((Object)opts.nThreads().value()) > 1;
            int x$37 = BoxesRunTime.unboxToInt((Object)opts.nThreads().value());
            OptimizableObjectives.HingeMulticlass x$38 = OptimizableObjectives$.MODULE$.hingeMulticlass();
            int x$39 = BoxesRunTime.unboxToInt((Object)opts.maxIters().value());
            int x$40 = OnlineLinearMulticlassTrainer$.MODULE$.$lessinit$greater$default$5();
            multiclassClassifierTrainer = new OnlineLinearMulticlassTrainer(x$36, x$35, x$38, x$39, x$40, x$37, random);
        }
        // NOTE(review): the declared type below looks like a decompiler artifact;
        // the value can be either trainer assigned in the branches above.
        SVMMulticlassTrainer trainer = multiclassClassifierTrainer;
        // First pass: counting is switched on and the generated decisions are
        // discarded (presumably the pass only populates the feature counts).
        ((CategoricalDomain)c.featuresDomain().dimensionDomain()).gatherCounts_$eq(true);
        Predef$.MODULE$.println((Object)"Generating decisions...");
        c.generateDecisions((Iterable<Sentence>)sentences, c.ParserConstants().TRAINING(), BoxesRunTime.unboxToInt((Object)opts.nThreads().value()));
        Predef$.MODULE$.println((Object)new StringBuilder().append((Object)"Before pruning # features ").append((Object)BoxesRunTime.boxToInteger((int)c.featuresDomain().dimensionDomain().size())).toString());
        DiscreteDomain qual$2 = c.featuresDomain().dimensionDomain();
        // The trimming threshold is twice the cutoff option.
        int x$41 = 2 * BoxesRunTime.unboxToInt((Object)opts.cutoff().value());
        boolean x$42 = ((CategoricalDomain)qual$2).trimBelowCount$default$2();
        ((CategoricalDomain)qual$2).trimBelowCount(x$41, x$42);
        // Freeze the domain and switch counting off before the second pass.
        c.featuresDomain().freeze();
        ((CategoricalDomain)c.featuresDomain().dimensionDomain()).gatherCounts_$eq(false);
        Predef$.MODULE$.println((Object)new StringBuilder().append((Object)"After pruning # features ").append((Object)BoxesRunTime.boxToInteger((int)c.featuresDomain().dimensionDomain().size())).toString());
        Predef$.MODULE$.println((Object)"Training...");
        // Second pass: this time the generated decision variables are kept for training.
        Iterable<TransitionBasedParser.ParseDecisionVariable> trainingVs = c.generateDecisions((Iterable<Sentence>)sentences, c.ParserConstants().TRAINING(), BoxesRunTime.unboxToInt((Object)opts.nThreads().value()));
        // The callback passed to trainFromVariables forwards the classifier to the
        // evaluate helper together with the three sentence sets and the parser.
        c.trainFromVariables(trainingVs, trainer, (Function1<LinearMulticlassClassifier, BoxedUnit>)new Serializable(sentences, testSentences, devSentences, c){
            private final Seq sentences$1;
            private final Seq testSentences$3;
            private final Seq devSentences$1;
            private final TransitionBasedParser c$1;

            public final void apply(LinearMulticlassClassifier cls) {
                TransitionBasedParserTrainer$.MODULE$.cc$factorie$app$nlp$parse$TransitionBasedParserTrainer$$evaluate$2(cls, this.sentences$1, this.testSentences$3, this.devSentences$1, this.c$1);
            }
            {
                this.sentences$1 = sentences$1;
                this.testSentences$3 = testSentences$3;
                this.devSentences$1 = devSentences$1;
                this.c$1 = c$1;
            }
        });
        // Drop the local reference to the training variables (presumably to let
        // them be garbage-collected before bootstrapping).
        trainingVs = null;
        // For each i from 0 until numBootstrappingIterations: announce the
        // iteration and call c.boosting with the same evaluation callback.
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), numBootstrappingIterations).foreach$mVc$sp((Function1)new Serializable(opts, sentences, testSentences, devSentences, c, trainer){
            private final TransitionBasedParserArgs opts$1;
            public final Seq sentences$1;
            public final Seq testSentences$3;
            public final Seq devSentences$1;
            public final TransitionBasedParser c$1;
            private final MulticlassClassifierTrainer trainer$1;

            public final void apply(int i) {
                this.apply$mcVI$sp(i);
            }

            public void apply$mcVI$sp(int i) {
                Predef$.MODULE$.println((Object)new StringBuilder().append((Object)"Boosting iteration ").append((Object)BoxesRunTime.boxToInteger((int)i)).toString());
                this.c$1.boosting((Iterable<Sentence>)this.sentences$1, BoxesRunTime.unboxToInt((Object)this.opts$1.nThreads().value()), this.trainer$1, (Function1<LinearMulticlassClassifier, BoxedUnit>)new Serializable(this){
                    private final /* synthetic */ anonfun.evaluateParameters.1 $outer;

                    public final void apply(LinearMulticlassClassifier cls) {
                        TransitionBasedParserTrainer$.MODULE$.cc$factorie$app$nlp$parse$TransitionBasedParserTrainer$$evaluate$2(cls, this.$outer.sentences$1, this.$outer.testSentences$3, this.$outer.devSentences$1, this.$outer.c$1);
                    }
                    {
                        if ($outer == null) {
                            throw null;
                        }
                        this.$outer = $outer;
                    }
                });
            }
            {
                this.opts$1 = opts$1;
                this.sentences$1 = sentences$1;
                this.testSentences$3 = testSentences$3;
                this.devSentences$1 = devSentences$1;
                this.c$1 = c$1;
                this.trainer$1 = trainer$1;
            }
        });
        if (BoxesRunTime.unboxToBoolean((Object)opts.saveModel().value())) {
            // Use the modelDir value when supplied; otherwise build a path from its
            // default value, the current time in milliseconds and ".factorie".
            String modelUrl = opts.modelDir().wasInvoked() ? opts.modelDir().value() : new StringBuilder().append((Object)opts.modelDir().defaultValue()).append((Object)((Object)BoxesRunTime.boxToLong((long)System.currentTimeMillis())).toString()).append((Object)".factorie").toString();
            c.serialize(new File(modelUrl));
            // Reload the file just written into a fresh parser and report on the test set.
            TransitionBasedParser d = new TransitionBasedParser();
            d.deserialize(new File(modelUrl));
            this.testSingle$1(d, testSentences, "Post serialization accuracy ");
        }
        // Final score: calcLas over the ParseTree attribute of every test sentence.
        double testLAS = ParserEval$.MODULE$.calcLas((Iterable<ParseTree>)((Iterable)testSentences.map((Function1)new Serializable(){

            public final ParseTree apply(Sentence x$18) {
                return (ParseTree)x$18.attr().apply(ClassTag$.MODULE$.apply(ParseTree.class));
            }
        }, Seq$.MODULE$.canBuildFrom())), ParserEval$.MODULE$.calcLas$default$2());
        if (opts.targetAccuracy().wasInvoked()) {
            // The targetAccuracy option value is a string; it is parsed to a double here.
            package$.MODULE$.assertMinimalAccuracy(testLAS, new StringOps(Predef$.MODULE$.augmentString(opts.targetAccuracy().value())).toDouble());
        }
        return testLAS;
    }

    /**
     * Collects file names from the two options and loads the sentences of each
     * file.
     *
     * <p>The file list starts empty, is replaced by the value of
     * {@code listOpt} when that option was invoked, and is then extended with
     * the result of {@code FileUtils$.getFileListFromDir} when {@code dirOpt}
     * was invoked. Each file name is loaded with a loader chosen from
     * {@code opts$1}: {@code LoadWSJMalt$} when {@code wsj} is set, otherwise
     * {@code LoadOntonotes5$} when {@code ontonotes} is set, otherwise
     * {@code LoadConll2008$}. Only the head of each loader result is used.
     *
     * @param listOpt option holding an explicit sequence of file names
     * @param dirOpt  option holding a directory whose files are added
     * @param opts$1  parsed options; only {@code wsj} and {@code ontonotes} are read
     * @return the sentences of the first document of every file, concatenated
     *         in file-list order
     */
    private final Seq loadSentences$1(CmdOptions.CmdOption listOpt, CmdOptions.CmdOption dirOpt, TransitionBasedParserArgs opts$1) {
        Seq fileList = (Seq)Seq$.MODULE$.empty();
        if (listOpt.wasInvoked()) {
            fileList = ((scala.collection.immutable.Seq)listOpt.value()).toSeq();
        }
        if (dirOpt.wasInvoked()) {
            // Directory entries are appended after any explicitly listed files.
            fileList = (Seq)fileList.$plus$plus(FileUtils$.MODULE$.getFileListFromDir((String)dirOpt.value(), FileUtils$.MODULE$.getFileListFromDir$default$2()), Seq$.MODULE$.canBuildFrom());
        }
        return (Seq)fileList.flatMap((Function1)new Serializable(opts$1){
            private final TransitionBasedParserArgs opts$1;

            public final Seq<Sentence> apply(String fname) {
                Seq seq;
                if (BoxesRunTime.unboxToBoolean((Object)this.opts$1.wsj().value())) {
                    // Only the third argument (AUTO) is given explicitly; the
                    // others are the loader's synthetic default values.
                    String x$29 = fname;
                    int x$30 = AnnotationTypes$.MODULE$.AUTO();
                    int x$31 = LoadWSJMalt$.MODULE$.fromFilename$default$2();
                    int x$32 = LoadWSJMalt$.MODULE$.fromFilename$default$4();
                    boolean x$33 = LoadWSJMalt$.MODULE$.fromFilename$default$5();
                    boolean x$34 = LoadWSJMalt$.MODULE$.fromFilename$default$6();
                    seq = ((Document)LoadWSJMalt$.MODULE$.fromFilename(x$29, x$31, x$30, x$32, x$33, x$34).head()).sentences().toSeq();
                } else {
                    // OntoNotes loader when the ontonotes flag is set, CoNLL-2008 loader otherwise.
                    seq = BoxesRunTime.unboxToBoolean((Object)this.opts$1.ontonotes().value()) ? ((Document)LoadOntonotes5$.MODULE$.fromFilename(fname, AnnotationTypes$.MODULE$.AUTO(), AnnotationTypes$.MODULE$.AUTO(), LoadOntonotes5$.MODULE$.fromFilename$default$4(), LoadOntonotes5$.MODULE$.fromFilename$default$5(), LoadOntonotes5$.MODULE$.fromFilename$default$6()).head()).sentences().toSeq() : ((Document)LoadConll2008$.MODULE$.fromFilename(fname).head()).sentences().toSeq();
                }
                return seq;
            }
            {
                this.opts$1 = opts$1;
            }
        }, Seq$.MODULE$.canBuildFrom());
    }

    /**
     * Prints one report line for a sentence set: {@code extraText}, a single
     * space, then the result of {@code c.testString(ss)}. Nothing is printed
     * when {@code ss} is empty.
     *
     * @param c         parser used to produce the report text
     * @param ss        sentences to report on
     * @param extraText label placed at the start of the printed line
     */
    private final void testSingle$1(TransitionBasedParser c, Seq ss, String extraText) {
        // Guard clause: an empty sentence set produces no output at all.
        if (!ss.nonEmpty()) {
            return;
        }
        StringBuilder line = new StringBuilder();
        line.append((Object)extraText);
        line.append((Object)" ");
        line.append((Object)c.testString((Seq<Sentence>)ss));
        Predef$.MODULE$.println((Object)line.toString());
    }

    /**
     * Synthetic default-value getter (by its name, for the third parameter of
     * {@code testSingle}).
     *
     * @return the empty string
     */
    private final String testSingle$default$3$1() {
        return "";
    }

    /**
     * Prints a newline string and then reports on the train, dev and test
     * sentence sets, in that order, through {@code testSingle$1}. Each label is
     * {@code "Train "}, {@code "Dev "} or {@code "Test "} followed by
     * {@code extraText}.
     *
     * @param c               parser to evaluate
     * @param extraText       text appended to each set label
     * @param sentences$1     training sentences
     * @param testSentences$3 test sentences
     * @param devSentences$1  development sentences
     */
    private final void testAll$1(TransitionBasedParser c, String extraText, Seq sentences$1, Seq testSentences$3, Seq devSentences$1) {
        Predef$.MODULE$.println((Object)"\n");
        // Parallel tables keep the reporting order explicit: train, dev, test.
        String[] prefixes = new String[]{"Train ", "Dev ", "Test "};
        Seq[] groups = new Seq[]{sentences$1, devSentences$1, testSentences$3};
        for (int k = 0; k < prefixes.length; ++k) {
            String label = new StringBuilder().append((Object)prefixes[k]).append((Object)extraText).toString();
            this.testSingle$1(c, groups[k], label);
        }
    }

    /**
     * Synthetic default-value getter (by its name, for the second parameter of
     * {@code testAll}).
     *
     * @return the empty string
     */
    private final String testAll$default$2$1() {
        return "";
    }

    /**
     * Evaluation callback used during training. Prints the fraction of the
     * classifier's weights that are exactly {@code 0.0}, followed by
     * {@code " sparsity"}, then reports on the train, dev and test sets with
     * the label {@code "iteration "}.
     *
     * @param cls             classifier whose weights are inspected
     * @param sentences$1     training sentences
     * @param testSentences$3 test sentences
     * @param devSentences$1  development sentences
     * @param c$1             parser passed on to the reporting helper
     */
    public final void cc$factorie$app$nlp$parse$TransitionBasedParserTrainer$$evaluate$2(LinearMulticlassClassifier cls, Seq sentences$1, Seq testSentences$3, Seq devSentences$1, TransitionBasedParser c$1) {
        // Number of weight entries that compare equal to 0.0.
        int zeroWeights = cls.weights().value().toSeq().count((Function1)new Serializable(){

            public final boolean apply(double x) {
                return this.apply$mcZD$sp(x);
            }

            public boolean apply$mcZD$sp(double x) {
                return x == 0.0;
            }
        });
        int totalWeights = ((Tensor2)cls.weights().value()).length();
        // The ratio is computed in single precision, as a float division.
        float zeroFraction = (float)zeroWeights / (float)totalWeights;
        String report = new StringBuilder().append(zeroFraction).append((Object)" sparsity").toString();
        Predef$.MODULE$.println((Object)report);
        this.testAll$1(c$1, "iteration ", sentences$1, testSentences$3, devSentences$1);
    }

    /**
     * Private constructor, invoked from the static initializer. It assigns the
     * new instance to {@code MODULE$} and then calls
     * {@code HyperparameterMain$class.$init$} on it.
     */
    private TransitionBasedParserTrainer$() {
        MODULE$ = this;
        HyperparameterMain$class.$init$(this);
    }
}

