/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.common.score;

import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder;
import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunction;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Base scoring function that evaluates examples against a
 * {@link VariationalAutoencoder} using either its reconstruction probability
 * or its reconstruction log probability.
 *
 * <p>Which of the two is used is fixed at construction time via the
 * {@code useLogProbability} flag; the configured {@code numSamples} value is
 * forwarded unchanged to the autoencoder on every call.
 *
 * @param <K> key type, passed through to {@link BaseVaeScoreWithKeyFunction}
 */
public abstract class BaseVaeReconstructionProbWithKeyFunction<K>
extends BaseVaeScoreWithKeyFunction<K> {
    /** When {@code true}, scores are log probabilities; otherwise plain probabilities. */
    private final boolean useLogProbability;
    /** Sample count handed to the autoencoder's reconstruction (log) probability methods. */
    private final int numSamples;

    /**
     * Creates the scoring function.
     *
     * @param params            broadcast network parameters, forwarded to the superclass
     * @param jsonConfig        broadcast network configuration, forwarded to the superclass
     * @param useLogProbability {@code true} to score with
     *                          {@code reconstructionLogProbability}, {@code false} to score
     *                          with {@code reconstructionProbability}
     * @param batchSize         batch size, forwarded to the superclass
     * @param numSamples        sample count passed to the autoencoder when scoring
     *                          (presumably the number of Monte Carlo samples; confirm
     *                          against {@link VariationalAutoencoder})
     */
    public BaseVaeReconstructionProbWithKeyFunction(Broadcast<INDArray> params, Broadcast<String> jsonConfig, boolean useLogProbability, int batchSize, int numSamples) {
        super(params, jsonConfig, batchSize);
        this.numSamples = numSamples;
        this.useLogProbability = useLogProbability;
    }

    /**
     * Scores the given data with the supplied autoencoder.
     *
     * @param vae     autoencoder used to compute the score
     * @param toScore data to score
     * @return the result of {@code vae.reconstructionLogProbability(toScore, numSamples)}
     *         when log probability was requested at construction, otherwise the result of
     *         {@code vae.reconstructionProbability(toScore, numSamples)}
     */
    @Override
    public INDArray computeScore(VariationalAutoencoder vae, INDArray toScore) {
        return useLogProbability
                ? vae.reconstructionLogProbability(toScore, numSamples)
                : vae.reconstructionProbability(toScore, numSamples);
    }
}

