package org.deeplearning4j.spark.models.embeddings.word2vec;

import java.io.Serializable;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Singleton (one instance per class loader, via {@link #getInstance()}) that lazily creates and
 * caches word2vec weight vectors.
 *
 * <p>syn0 vectors are keyed by {@link VocabWord} and initialised with seeded random values (see
 * {@link #setSeed(long, int)}); syn1 vectors are keyed by an integer point index and initialised
 * to zeros. Both caches are {@link ConcurrentHashMap}s and each vector is created at most once per
 * key, so concurrent callers for the same key receive the same {@link INDArray} instance.
 *
 * <p>NOTE(review): the class is {@link Serializable}, but deserialising it yields a new instance
 * distinct from {@link #getInstance()}; confirm whether callers rely on that.
 */
public class VocabHolder implements Serializable {
    /** Maximum number of entries returned by {@link #getSplit(VocabCache)}. */
    private static final int MAX_SPLIT_ENTRIES = 11;

    private static final VocabHolder ourInstance = new VocabHolder();

    /** Lazily-initialised syn0 vectors, keyed by vocabulary word. */
    private Map<VocabWord, INDArray> indexSyn0VecMap = new ConcurrentHashMap<>();
    /** Lazily-initialised (zero) syn1 vectors, keyed by point index. */
    private Map<Integer, INDArray> pointSyn1VecMap = new ConcurrentHashMap<>();
    /**
     * Ids of threads that have requested a syn0 vector. LinkedHashSet is not thread-safe, so every
     * access must synchronize on this set. The declared type is kept as-is so the serialized form
     * of this class does not change.
     */
    private HashSet<Long> workers = new LinkedHashSet<>();
    /** Base seed multiplied into each syn0 vector's RNG seed. */
    private AtomicLong seed = new AtomicLong(0);
    /** Length of every syn0/syn1 vector created by this holder. */
    private AtomicInteger vectorLength = new AtomicInteger(0);

    /**
     * Returns the shared instance.
     *
     * @return the singleton {@code VocabHolder}
     */
    public static VocabHolder getInstance() {
        return ourInstance;
    }

    private VocabHolder() {
    }

    /**
     * Sets the base seed and the vector length used for vectors created after this call. Vectors
     * already cached are not affected.
     *
     * @param j base seed for syn0 random initialisation
     * @param i vector length
     */
    public void setSeed(long j, int i) {
        this.seed.set(j);
        this.vectorLength.set(i);
    }

    /**
     * Returns the syn0 vector for the word at {@code num} in {@code vocabCache}. The vector is
     * randomly initialised on first request.
     *
     * <p>If {@code vocabCache.elementAtIndex(num)} returns {@code null}, the underlying
     * {@link ConcurrentHashMap} rejects the null key with a {@link NullPointerException}.
     *
     * @param num index of the word in {@code vocabCache}; also used in the RNG seed
     * @param vocabCache vocabulary used to resolve the index to a word
     * @return the cached (shared, mutable) syn0 vector for that word
     */
    public INDArray getSyn0Vector(Integer num, VocabCache<VocabWord> vocabCache) {
        // Record the calling thread. The set is not thread-safe, so the add is guarded;
        // add() is a no-op for ids already present.
        synchronized (workers) {
            workers.add(Thread.currentThread().getId());
        }
        VocabWord elementAtIndex = vocabCache.elementAtIndex(num.intValue());
        // computeIfAbsent runs the initialiser at most once per key, atomically.
        return indexSyn0VecMap.computeIfAbsent(
                elementAtIndex, word -> getRandomSyn0Vec(vectorLength.get(), num.intValue()));
    }

    /**
     * Returns the syn1 vector for point {@code num}. It is created as a {@code 1 x vectorLength}
     * zero vector on first request.
     *
     * @param num point index
     * @return the cached (shared, mutable) syn1 vector for that point
     */
    public INDArray getSyn1Vector(Integer num) {
        return pointSyn1VecMap.computeIfAbsent(num, point -> Nd4j.zeros(1L, vectorLength.get()));
    }

    /**
     * Builds a {@code 1 x i} random vector from {@link Nd4j#rand}, shifted by -0.5 and divided by
     * {@code i}.
     *
     * <p>NOTE(review): the RNG seed is {@code j * seed}. It is therefore 0 for index 0, and for
     * every index when {@link #setSeed(long, int)} was never called, so those vectors are all
     * identical. This is kept unchanged to preserve reproducibility; confirm it is intended.
     *
     * @param i vector length
     * @param j word index, multiplied with the base seed
     * @return a new random vector
     */
    private INDArray getRandomSyn0Vec(int i, long j) {
        return Nd4j.rand(new int[]{1, i}, j * this.seed.get()).subi(0.5).divi(i);
    }

    /**
     * Returns a snapshot of at most {@value #MAX_SPLIT_ENTRIES} syn0 entries and prints the
     * snapshot size to stdout.
     *
     * <p>NOTE(review): the cap looks like a debugging leftover, and {@code vocabCache} is unused.
     * Both are kept for compatibility.
     *
     * @param vocabCache currently unused
     * @return a new set holding up to {@value #MAX_SPLIT_ENTRIES} entries of the syn0 cache
     */
    public Iterable<Map.Entry<VocabWord, INDArray>> getSplit(VocabCache<VocabWord> vocabCache) {
        HashSet<Map.Entry<VocabWord, INDArray>> hashSet = new HashSet<>();
        int taken = 0;
        for (Map.Entry<VocabWord, INDArray> entry : this.indexSyn0VecMap.entrySet()) {
            hashSet.add(entry);
            taken++;
            if (taken >= MAX_SPLIT_ENTRIES) {
                break;
            }
        }
        System.out.println("Returning set: " + hashSet.size());
        return hashSet;
    }
}
