/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.common.score;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

public abstract class BaseVaeScoreWithKeyFunction<K>
implements PairFlatMapFunction<Iterator<Tuple2<K, INDArray>>, K, Double> {
    private static final Logger log = LoggerFactory.getLogger(BaseVaeScoreWithKeyFunction.class);
    protected final Broadcast<INDArray> params;
    protected final Broadcast<String> jsonConfig;
    private final int batchSize;

    public BaseVaeScoreWithKeyFunction(Broadcast<INDArray> params, Broadcast<String> jsonConfig, int batchSize) {
        this.params = params;
        this.jsonConfig = jsonConfig;
        this.batchSize = batchSize;
    }

    public abstract VariationalAutoencoder getVaeLayer();

    public abstract INDArray computeScore(VariationalAutoencoder var1, INDArray var2);

    public Iterator<Tuple2<K, Double>> call(Iterator<Tuple2<K, INDArray>> iterator) throws Exception {
        if (!iterator.hasNext()) {
            return Collections.emptyIterator();
        }
        VariationalAutoencoder vae = this.getVaeLayer();
        ArrayList<Tuple2> ret = new ArrayList<Tuple2>();
        ArrayList<INDArray> collect = new ArrayList<INDArray>(this.batchSize);
        ArrayList<Object> collectKey = new ArrayList<Object>(this.batchSize);
        int totalCount = 0;
        while (iterator.hasNext()) {
            collect.clear();
            collectKey.clear();
            int nExamples = 0;
            while (iterator.hasNext() && nExamples < this.batchSize) {
                Tuple2<K, INDArray> t2 = iterator.next();
                INDArray features = (INDArray)t2._2();
                long n = features.size(0);
                if (n != 1L) {
                    throw new IllegalStateException("Cannot score examples with one key per data set if data set contains more than 1 example (numExamples: " + n + ")");
                }
                collect.add(features);
                collectKey.add(t2._1());
                nExamples = (int)((long)nExamples + n);
            }
            totalCount += nExamples;
            INDArray toScore = Nd4j.vstack(collect);
            INDArray scores = this.computeScore(vae, toScore);
            double[] doubleScores = scores.data().asDouble();
            for (int i = 0; i < doubleScores.length; ++i) {
                ret.add(new Tuple2(collectKey.get(i), (Object)doubleScores[i]));
            }
        }
        Nd4j.getExecutioner().commit();
        if (log.isDebugEnabled()) {
            log.debug("Scored {} examples ", (Object)totalCount);
        }
        return ret.iterator();
    }
}

