package com.johnsnowlabs.nlp.annotators.parser.typdep;

import com.johnsnowlabs.nlp.annotators.parser.typdep.io.ConllWriter;
import com.johnsnowlabs.nlp.annotators.parser.typdep.util.DependencyLabel;
import java.io.Serializable;
import java.util.ArrayList;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/johnsnowlabs/nlp/annotators/parser/typdep/TypedDependencyParser.class */
/**
 * Typed dependency parser. Holds the training {@link Options}, the {@link DependencyPipe} that turns
 * CoNLL rows into {@link DependencyInstance}s, and the learned {@link Parameters}; trains those
 * parameters and predicts dependency labels with them.
 */
public class TypedDependencyParser implements Serializable {
    private static final long serialVersionUID = 1;

    // Static instead of a transient instance field: field initializers are not run when a
    // Serializable object is deserialized, so a transient logger would be null on a deserialized
    // parser and train() would throw NullPointerException. Static fields are not part of the
    // serialized form either, so this does not change the wire format.
    private static final Logger logger = LoggerFactory.getLogger("TypedDependencyParser");

    private Options options;
    private DependencyPipe dependencyPipe;
    private Parameters parameters;

    /** Returns the pipe used to build dependency instances (and its feature factory). */
    public DependencyPipe getDependencyPipe() {
        return dependencyPipe;
    }

    /** Returns the learned model parameters. */
    public Parameters getParameters() {
        return parameters;
    }

    /** Returns the options currently in effect. */
    public Options getOptions() {
        return options;
    }

    /** Sets the pipe used to build dependency instances. */
    public void setDependencyPipe(DependencyPipe dependencyPipe) {
        this.dependencyPipe = dependencyPipe;
    }

    /** Sets the model parameters to train or predict with. */
    public void setParameters(Parameters parameters) {
        this.parameters = parameters;
    }

    /** Sets the options to train with. */
    public void setOptions(Options options) {
        this.options = options;
    }

    /**
     * Trains the parser on the given instances.
     *
     * <p>Pre-training happens only when all three hold: at least one tensor rank is positive,
     * {@code gammaLabel < 1}, and {@code initTensorWithPretrain} is set. In that case a tensor-free
     * pre-training pass runs first and is used to initialise the low-rank tensor factors. Otherwise
     * the parameters are randomly initialised. In both cases the main training pass follows.
     *
     * @param dependencyInstances training sentences
     */
    public void train(DependencyInstance[] dependencyInstances) {
        boolean usesTensors = options.rankFirstOrderTensor > 0 || options.rankSecondOrderTensor > 0;
        if (usesTensors && options.gammaLabel < 1.0f && options.initTensorWithPretrain) {
            pretrainAndInitTensors(dependencyInstances);
        } else {
            parameters.randomlyInit();
        }
        // SLF4J does not interpret printf escapes, so no "%n" here.
        logger.debug(" Training:");
        long start = System.currentTimeMillis();
        logger.debug("Running MIRA ... ");
        trainIterations(dependencyInstances);
        long end = System.currentTimeMillis();
        if (logger.isDebugEnabled()) {
            logger.debug(String.format("Training took %d ms.%n", end - start));
        }
    }

    /**
     * Runs a pre-training pass with tensors disabled, restores the configured options, and seeds
     * the tensor factor matrices from the pre-trained weights.
     */
    private void pretrainAndInitTensors(DependencyInstance[] dependencyInstances) {
        // Snapshot of the configured options, restored after pre-training.
        // NOTE(review): assumes Options.newInstance returns an independent copy - confirm.
        Options trainingOptions = Options.newInstance(options);

        options.rankFirstOrderTensor = 0;
        options.rankSecondOrderTensor = 0;
        options.gammaLabel = 1.0f;
        // The pre-training pass runs numberOfPreTrainingIterations. The snapshot keeps the configured
        // count for the main pass. (Setting this on the snapshot would swap the two counts.)
        options.setNumberOfTrainingIterations(options.numberOfPreTrainingIterations);
        syncTensorSettingsToParameters();

        logger.debug("Pre-training:");
        long start = System.currentTimeMillis();
        logger.debug("Running MIRA ... ");
        trainIterations(dependencyInstances);

        options = trainingOptions;
        syncTensorSettingsToParameters();

        logger.debug("Init tensor ... ");
        initTensorsFromPretrainedWeights();
        parameters.assignTotal();
        parameters.printStat();
        long end = System.currentTimeMillis();
        if (logger.isDebugEnabled()) {
            logger.debug(String.format("Pre-training took %d ms.%n", end - start));
        }
    }

    /** Copies the tensor ranks and gammaLabel from the current options into the parameters. */
    private void syncTensorSettingsToParameters() {
        parameters.setRankFirstOrderTensor(options.rankFirstOrderTensor);
        parameters.setRankSecondOrderTensor(options.rankSecondOrderTensor);
        parameters.setGammaLabel(options.gammaLabel);
    }

    /**
     * Builds a first-order (3-way) and a second-order (5-way) low-rank tensor and decomposes them
     * into the parameters' factor matrices.
     */
    private void initTensorsFromPretrainedWeights() {
        int wordFeatures = parameters.getNumberWordFeatures();
        int dl = parameters.getDL();
        LowRankTensor firstOrderTensor = new LowRankTensor(
                new int[]{wordFeatures, wordFeatures, dl}, options.rankFirstOrderTensor);
        LowRankTensor secondOrderTensor = new LowRankTensor(
                new int[]{wordFeatures, wordFeatures, wordFeatures, dl, dl}, options.rankSecondOrderTensor);
        // Presumably fills both tensors from the current (pre-trained) parameters - confirm.
        dependencyPipe.getSynFactory().fillParameters(firstOrderTensor, secondOrderTensor, parameters);

        // One factor matrix per tensor dimension, presumably in the dimension order declared above.
        ArrayList<float[][]> firstOrderFactors = new ArrayList<>();
        firstOrderFactors.add(parameters.getU());
        firstOrderFactors.add(parameters.getV());
        firstOrderFactors.add(parameters.getWL());
        firstOrderTensor.decompose(firstOrderFactors);

        ArrayList<float[][]> secondOrderFactors = new ArrayList<>();
        secondOrderFactors.add(parameters.getU2());
        secondOrderFactors.add(parameters.getV2());
        secondOrderFactors.add(parameters.getW2());
        secondOrderFactors.add(parameters.getX2L());
        secondOrderFactors.add(parameters.getY2L());
        secondOrderTensor.decompose(secondOrderFactors);
    }

    /**
     * Runs {@code options.getNumberOfTrainingIterations()} passes over the data.
     *
     * <p>Each pass predicts labels for every instance, given the instance's own heads. Whenever the
     * prediction is not fully correct, the parameters are updated via {@link Parameters#updateLabel}.
     */
    private void trainIterations(DependencyInstance[] dependencyInstances) {
        int instanceCount = dependencyInstances.length;
        // Log progress every 10% on corpora above 10,000 sentences, otherwise every 1000 sentences.
        int progressInterval = instanceCount > 10000 ? instanceCount / 10 : 1000;
        if (logger.isDebugEnabled()) {
            logger.debug(String.format("Number of Training Iterations: %d",
                    options.getNumberOfTrainingIterations()));
        }
        for (int iteration = 0; iteration < options.getNumberOfTrainingIterations(); iteration++) {
            double loss = 0.0d;
            int correct = 0;
            int total = 0;
            long start = System.currentTimeMillis();
            for (int index = 0; index < instanceCount; index++) {
                if ((index + 1) % progressInterval == 0 && logger.isDebugEnabled()) {
                    logger.debug(String.format("  %d (time=%ds)", index + 1,
                            (System.currentTimeMillis() - start) / 1000));
                }
                DependencyInstance instance = dependencyInstances[index];
                LocalFeatureData featureData = new LocalFeatureData(instance, this);
                int length = instance.getLength();
                int[] heads = instance.getHeads();
                int[] predictedLabels = new int[length];
                featureData.predictLabels(heads, predictedLabels, true);
                int matches = getNumberCorrectMatches(
                        instance.getHeads(), instance.getDependencyLabelIds(), heads, predictedLabels);
                // Position 0 is not scored, so (assuming heads has `length` entries) a perfect
                // prediction yields length - 1 matches.
                if (matches != length - 1) {
                    // Last argument: a 1-based update counter that runs across all iterations.
                    loss += parameters.updateLabel(instance, heads, predictedLabels, featureData,
                            (iteration * instanceCount) + index + 1);
                }
                correct += matches;
                total += length - 1;
            }
            int denominator = total == 0 ? 1 : total;
            if (logger.isDebugEnabled()) {
                logger.debug(String.format("%n Iter %d loss=%.4f totalNUmberCorrectMatches=%.4f [%ds]%n",
                        iteration + 1, loss, correct / (denominator + 0.0d),
                        (System.currentTimeMillis() - start) / 1000));
            }
            parameters.printStat();
        }
    }

    /**
     * Counts positions 1..n-1 where both the head and the label match.
     * Position 0 is skipped, presumably the artificial root.
     */
    private int getNumberCorrectMatches(int[] goldHeads, int[] goldLabels, int[] predictedHeads,
                                        int[] predictedLabels) {
        int matches = 0;
        for (int position = 1; position < goldHeads.length; position++) {
            if (goldHeads[position] == predictedHeads[position]
                    && goldLabels[position] == predictedLabels[position]) {
                matches++;
            }
        }
        return matches;
    }

    /**
     * Predicts dependency labels for the sentences of {@code document}.
     *
     * <p>Sentences are processed in order. Processing stops at the first sentence for which the pipe
     * returns {@code null}. Each instance's own heads are passed to {@code predictLabels}, which
     * fills in the labels.
     *
     * <p>NOTE(review): each processed sentence overwrites the result, so only the labels of the
     * last processed sentence are returned. That is fine for single-sentence documents; confirm for
     * multi-sentence callers.
     *
     * @param document sentences, each an array of CoNLL rows
     * @param conllFormat passed through to {@code DependencyPipe.nextSentence}; presumably selects
     *     the CoNLL variant - confirm
     * @return labels for the last processed sentence. If no sentence was processed, an array of
     *     {@code null}s sized to the first sentence. An empty array if {@code document} is empty.
     */
    public DependencyLabel[] predictDependency(ConllData[][] document, String conllFormat) {
        if (document.length == 0) {
            // Without this guard, document[0] below throws ArrayIndexOutOfBoundsException.
            return new DependencyLabel[0];
        }
        ConllWriter conllWriter = new ConllWriter(options, dependencyPipe);
        DependencyLabel[] dependencyLabels = new DependencyLabel[document[0].length];
        for (ConllData[] sentence : document) {
            DependencyInstance instance = dependencyPipe.nextSentence(sentence, conllFormat);
            if (instance == null) {
                break;
            }
            LocalFeatureData featureData = new LocalFeatureData(instance, this);
            int length = instance.getLength();
            int[] heads = instance.getHeads();
            int[] predictedLabels = new int[length];
            featureData.predictLabels(heads, predictedLabels, true);
            dependencyLabels = conllWriter.getDependencyLabels(instance, heads, predictedLabels);
        }
        return dependencyLabels;
    }
}
