package org.deeplearning4j.spark.models.embeddings.glove;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.math3.util.FastMath;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.models.glove.GloveWeightLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCalculator;
import org.deeplearning4j.spark.models.embeddings.glove.cooccurrences.CoOccurrenceCounts;
import org.deeplearning4j.spark.models.embeddings.word2vec.Word2VecVariables;
import org.deeplearning4j.spark.text.functions.TextPipeline;
import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.legacy.AdaGrad;
import org.nd4j.linalg.primitives.CounterMap;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/* loaded from: input_file:org/deeplearning4j/spark/models/embeddings/glove/Glove.class */
/**
 * Spark driver for GloVe training: builds a vocabulary from a text RDD, counts word co-occurrences,
 * and repeatedly applies per-pair updates to a {@link GloveWeightLookupTable}.
 *
 * <p>The class is {@link Serializable} because the anonymous Spark functions created in
 * {@link #train(JavaRDD)} reference the enclosing instance (they read {@code vocabCacheBroadcast}
 * and call instance methods).
 *
 * <p>Note: {@code tokenizerFactoryClazz} and {@code iterations} are stored by the constructors but
 * are not read anywhere in this class; {@link #train(JavaRDD)} takes its iteration count and
 * tokenizer settings from the Spark configuration instead.
 */
public class Glove implements Serializable {
    private static final Logger log = LoggerFactory.getLogger(Glove.class);

    private Broadcast<VocabCache<VocabWord>> vocabCacheBroadcast;
    private String tokenizerFactoryClazz;
    private boolean symmetric;
    private int windowSize;
    private int iterations;

    /**
     * Creates a trainer with an explicit tokenizer factory class name.
     *
     * @param tokenizerFactoryClazz fully qualified tokenizer factory class name (stored only)
     * @param symmetric passed to the co-occurrence calculator
     * @param windowSize co-occurrence window size passed to the co-occurrence calculator
     * @param iterations stored only; training reads its iteration count from the Spark configuration
     */
    public Glove(String tokenizerFactoryClazz, boolean symmetric, int windowSize, int iterations) {
        this.tokenizerFactoryClazz = tokenizerFactoryClazz;
        this.symmetric = symmetric;
        this.windowSize = windowSize;
        this.iterations = iterations;
    }

    /**
     * Creates a trainer that uses {@link DefaultTokenizerFactory} as the tokenizer factory class name.
     *
     * @param symmetric passed to the co-occurrence calculator
     * @param windowSize co-occurrence window size passed to the co-occurrence calculator
     * @param iterations stored only; training reads its iteration count from the Spark configuration
     */
    public Glove(boolean symmetric, int windowSize, int iterations) {
        this(DefaultTokenizerFactory.class.getName(), symmetric, windowSize, iterations);
    }

    /**
     * Applies one AdaGrad-scaled update for a single word of a co-occurrence pair.
     *
     * <p>{@code wordVector} is modified in place (the weight step is subtracted from it) and the
     * entry of {@code bias} at the word's index is overwritten. Visibility is public because the
     * decompiler reported widening it from private; it is kept as found.
     *
     * @param weightAdaGrad AdaGrad state used to scale the weight gradient
     * @param biasAdaGrad AdaGrad state used to scale the bias gradient
     * @param syn0 weight matrix; only its shape is read here
     * @param bias bias array; the entry at the word's index is rewritten
     * @param word the word whose row and bias are updated
     * @param wordVector the word's vector, updated in place
     * @param contextVector the other word's vector, multiplied by {@code gradient}
     * @param gradient scalar error term for this pair
     * @return the weight step that was subtracted from {@code wordVector}, and the bias value computed
     *     below (narrowed to float)
     */
    public Pair<INDArray, Float> update(AdaGrad weightAdaGrad, AdaGrad biasAdaGrad, INDArray syn0, INDArray bias, VocabWord word, INDArray wordVector, INDArray contextVector, double gradient) {
        int index = word.getIndex();

        INDArray weightStep = weightAdaGrad.getGradient(contextVector.mul(gradient), index, ArrayUtil.toInts(syn0.shape()));
        wordVector.subi(weightStep);

        // NOTE(review): biasUpdate is (bias - step), and the stored bias becomes (bias - biasUpdate),
        // i.e. the AdaGrad step itself rather than (bias - step). Behavior kept as found; confirm
        // against the intended update rule before changing it.
        double biasStep = biasAdaGrad.getGradient(gradient, index, ArrayUtil.toInts(bias.shape()));
        double biasUpdate = bias.getDouble(index) - biasStep;
        bias.putScalar(index, bias.getDouble(index) - biasUpdate);

        return new Pair<>(weightStep, Float.valueOf((float) biasUpdate));
    }

    /**
     * Trains a lookup table on the given corpus.
     *
     * <p>Steps: read settings from the Spark configuration, build the vocabulary through
     * {@link TextPipeline}, create and initialise the lookup table, fold all co-occurrence counts to
     * the driver (capped at the table's max count), then for each iteration compute a
     * {@link GloveChange} per pair on the cluster, apply the collected changes on the driver, and
     * reshuffle the pairs.
     *
     * <p>The number of iterations comes from the configuration key
     * {@code org.deeplearning4j.scaleout.perform.models.word2vec.iterations}, not from the
     * constructor argument.
     *
     * @param javaRDD the corpus, one text element per entry
     * @return the vocabulary cache and the trained lookup table
     * @throws Exception propagated from configuration lookup or the text pipeline
     */
    public Pair<VocabCache<VocabWord>, GloveWeightLookupTable> train(JavaRDD<String> javaRDD) throws Exception {
        JavaSparkContext sc = new JavaSparkContext(javaRDD.context());
        SparkConf conf = sc.getConf();

        // Several of these values are never used below. The lookups are kept, in their original
        // order, because assignVar is not visible from this file and may validate the configuration.
        Word2VecVariables.assignVar("org.deeplearning4j.scaleout.perform.models.word2vec.length", conf, Integer.class);
        Word2VecVariables.assignVar("org.deeplearning4j.scaleout.perform.models.word2vec.adagrad", conf, Boolean.class);
        Word2VecVariables.assignVar("org.deeplearning4j.scaleout.perform.models.word2vec.negative", conf, Double.class);
        int numWords = (Integer) Word2VecVariables.assignVar("org.deeplearning4j.scaleout.perform.models.word2vec.numwords", conf, Integer.class);
        Word2VecVariables.assignVar("org.deeplearning4j.scaleout.perform.models.word2vec.window", conf, Integer.class);
        Word2VecVariables.assignVar("org.deeplearning4j.scaleout.perform.models.word2vec.alpha", conf, Double.class);
        Word2VecVariables.assignVar("org.deeplearning4j.scaleout.perform.models.word2vec.minalpha", conf, Double.class);
        int numIterations = (Integer) Word2VecVariables.assignVar("org.deeplearning4j.scaleout.perform.models.word2vec.iterations", conf, Integer.class);
        int nGrams = (Integer) Word2VecVariables.assignVar(Word2VecVariables.N_GRAMS, conf, Integer.class);
        String tokenizer = (String) Word2VecVariables.assignVar(Word2VecVariables.TOKENIZER, conf, String.class);
        String tokenPreprocessor = (String) Word2VecVariables.assignVar(Word2VecVariables.TOKEN_PREPROCESSOR, conf, String.class);
        boolean removeStop = (Boolean) Word2VecVariables.assignVar(Word2VecVariables.REMOVE_STOPWORDS, conf, Boolean.class);

        // A plain HashMap rather than an anonymous subclass: the map is broadcast to executors, and
        // an anonymous subclass would carry a reference to this Glove instance with it.
        HashMap<String, Object> tokenizerConfig = new HashMap<>();
        tokenizerConfig.put("numWords", numWords);
        tokenizerConfig.put("nGrams", nGrams);
        tokenizerConfig.put("tokenizer", tokenizer);
        tokenizerConfig.put("tokenPreprocessor", tokenPreprocessor);
        tokenizerConfig.put("removeStop", removeStop);

        TextPipeline textPipeline = new TextPipeline(javaRDD, sc.broadcast(tokenizerConfig));
        textPipeline.buildVocabCache();
        textPipeline.buildVocabWordListRDD();
        // Result unused; the call is kept because the getter's behavior is not visible from here.
        textPipeline.getTotalWordCount();
        VocabCache<VocabWord> vocabCache = textPipeline.getVocabCache();
        JavaRDD<Pair<List<String>, AtomicLong>> sentences = textPipeline.getSentenceWordsCountRDD();

        this.vocabCacheBroadcast = sc.broadcast(vocabCache);

        // Must be final: it is captured by the anonymous Function in the training loop.
        final GloveWeightLookupTable table = createLookupTable(conf, vocabCache);
        log.info("Created lookup table of size " + Arrays.toString(table.getSyn0().shape()));

        List<Triple<String, String, Float>> coOccurrences = collectCoOccurrences(sentences, table.getMaxCount());
        log.info("Calculated co occurrences");

        JavaPairRDD<VocabWord, Tuple2<VocabWord, Float>> pairs = toVocabWordPairs(sc, coOccurrences);

        for (int iteration = 0; iteration < numIterations; iteration++) {
            List<GloveChange> changes = pairs.map(new Function<Tuple2<VocabWord, Tuple2<VocabWord, Float>>, GloveChange>() {
                @Override
                public GloveChange call(Tuple2<VocabWord, Tuple2<VocabWord, Float>> coOccurrence) throws Exception {
                    return computeChange(table, coOccurrence._1(), coOccurrence._2()._1(), coOccurrence._2()._2().floatValue());
                }
            }).collect();

            // The changes computed on the cluster are applied to the driver-side table here.
            double totalError = 0.0d;
            for (GloveChange change : changes) {
                change.apply(table);
                totalError += change.getError();
            }

            // Shuffle a copy so this does not depend on collect() returning a mutable list.
            List<Tuple2<VocabWord, Tuple2<VocabWord, Float>>> shuffled = new ArrayList<>(pairs.collect());
            Collections.shuffle(shuffled);
            pairs = sc.parallelizePairs(shuffled);
            log.info("Error at iteration " + iteration + " was " + totalError);
        }
        return new Pair<>(vocabCache, table);
    }

    /** Builds the lookup table from configuration, resets its weights, and seeds both AdaGrad histories with ones. */
    private GloveWeightLookupTable createLookupTable(SparkConf conf, VocabCache<VocabWord> vocabCache) {
        GloveWeightLookupTable table = new GloveWeightLookupTable.Builder()
                .cache(vocabCache)
                .lr(conf.getDouble(GlovePerformer.ALPHA, 0.01d))
                .maxCount(conf.getDouble(GlovePerformer.MAX_COUNT, 100.0d))
                .vectorLength(conf.getInt(GlovePerformer.VECTOR_LENGTH, 300))
                .xMax(conf.getDouble(GlovePerformer.X_MAX, 0.75d))
                .build();
        table.resetWeights();
        table.getBiasAdaGrad().historicalGradient = Nd4j.ones(new int[]{table.getSyn0().rows()});
        table.getWeightAdaGrad().historicalGradient = Nd4j.ones(table.getSyn0().shape());
        return table;
    }

    /** Folds per-sentence co-occurrence counts to the driver, caps each count at {@code maxCount}, and lists them as triples. */
    private List<Triple<String, String, Float>> collectCoOccurrences(JavaRDD<Pair<List<String>, AtomicLong>> sentences, double maxCount) {
        // The cast mirrors the original code; the generic signatures of CoOccurrenceCalculator and
        // CoOccurrenceCounts are not visible from this file.
        @SuppressWarnings("unchecked")
        CounterMap<String, String> counts = (CounterMap<String, String>) sentences
                .map(new CoOccurrenceCalculator(this.symmetric, this.vocabCacheBroadcast, this.windowSize))
                .fold(new CounterMap<String, String>(), new CoOccurrenceCounts());

        List<Triple<String, String, Float>> triples = new ArrayList<>();
        Iterator<Pair<String, String>> keys = counts.getIterator();
        while (keys.hasNext()) {
            Pair<String, String> key = keys.next();
            String first = key.getFirst();
            String second = key.getSecond();
            if (counts.getCount(first, second) > maxCount) {
                counts.setCount(first, second, (float) maxCount);
            }
            // Re-read after the cap so the stored value is what ends up in the triple.
            triples.add(new Triple<>(first, second, Float.valueOf((float) counts.getCount(first, second))));
        }
        return triples;
    }

    /** Distributes the co-occurrence triples and resolves both words of each pair through the broadcast vocabulary. */
    private JavaPairRDD<VocabWord, Tuple2<VocabWord, Float>> toVocabWordPairs(JavaSparkContext sc, List<Triple<String, String, Float>> coOccurrences) {
        return sc.parallelize(coOccurrences)
                .mapToPair(new PairFunction<Triple<String, String, Float>, String, Tuple2<String, Float>>() {
                    @Override
                    public Tuple2<String, Tuple2<String, Float>> call(Triple<String, String, Float> triple) throws Exception {
                        return new Tuple2<String, Tuple2<String, Float>>(triple.getFirst(), new Tuple2<String, Float>(triple.getSecond(), triple.getThird()));
                    }
                })
                .mapToPair(new PairFunction<Tuple2<String, Tuple2<String, Float>>, VocabWord, Tuple2<VocabWord, Float>>() {
                    @Override
                    public Tuple2<VocabWord, Tuple2<VocabWord, Float>> call(Tuple2<String, Tuple2<String, Float>> entry) throws Exception {
                        VocabCache<VocabWord> vocab = Glove.this.vocabCacheBroadcast.getValue();
                        VocabWord first = vocab.wordFor(entry._1());
                        VocabWord second = vocab.wordFor(entry._2()._1());
                        return new Tuple2<VocabWord, Tuple2<VocabWord, Float>>(first, new Tuple2<VocabWord, Float>(second, entry._2()._2()));
                    }
                });
    }

    /**
     * Computes the error term for one co-occurrence pair, applies {@link #update} to both words on
     * the given table, and packages the resulting steps and AdaGrad history into a {@link GloveChange}.
     */
    private GloveChange computeChange(GloveWeightLookupTable table, VocabWord w1, VocabWord w2, double score) {
        INDArray w1Vector = table.getSyn0().slice(w1.getIndex());
        INDArray w2Vector = table.getSyn0().slice(w2.getIndex());
        INDArray bias = table.getBias();
        double xMax = table.getxMax();
        double maxCount = table.getMaxCount();

        double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector) + bias.getDouble(w1.getIndex()) + bias.getDouble(w2.getIndex());
        double error = score > xMax
                ? prediction
                : FastMath.pow(Math.min(1.0d, score / maxCount), xMax) * (prediction - Math.log(score));
        if (Double.isNaN(error)) {
            error = Nd4j.EPS_THRESHOLD;
        }

        // Both updates must run, in this order, before the AdaGrad history is read below.
        // The decompiled source referenced an undeclared variable for the two bias values; each
        // word's own update result is used here.
        Pair<INDArray, Float> w1Update = update(table.getWeightAdaGrad(), table.getBiasAdaGrad(), table.getSyn0(), table.getBias(), w1, w1Vector, w2Vector, error);
        Pair<INDArray, Float> w2Update = update(table.getWeightAdaGrad(), table.getBiasAdaGrad(), table.getSyn0(), table.getBias(), w2, w2Vector, w1Vector, error);

        // NOTE(review): the bias history is passed w2 first, then w1, unlike the weight history.
        // Kept as found; confirm against GloveChange's constructor parameter order.
        return new GloveChange(
                w1,
                w2,
                w1Update.getFirst(),
                w2Update.getFirst(),
                w1Update.getSecond().floatValue(),
                w2Update.getSecond().floatValue(),
                error,
                table.getWeightAdaGrad().getHistoricalGradient().slice(w1.getIndex()),
                table.getWeightAdaGrad().getHistoricalGradient().slice(w2.getIndex()),
                table.getBiasAdaGrad().getHistoricalGradient().getDouble(w2.getIndex()),
                table.getBiasAdaGrad().getHistoricalGradient().getDouble(w1.getIndex()));
    }
}
