/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.parameterserver.functions;

import java.util.Collections;
import java.util.Iterator;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.deeplearning4j.spark.api.TrainingResult;
import org.deeplearning4j.spark.api.TrainingWorker;
import org.deeplearning4j.spark.parameterserver.pw.SharedTrainingWrapper;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingResult;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingWorker;
import org.nd4j.linalg.dataset.api.MultiDataSet;

public class SharedFlatMapMultiDataSet<R extends TrainingResult>
implements FlatMapFunction<Iterator<MultiDataSet>, R> {
    private final SharedTrainingWorker worker;

    public SharedFlatMapMultiDataSet(TrainingWorker<R> worker) {
        this.worker = (SharedTrainingWorker)worker;
    }

    public Iterator<R> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
        if (!dataSetIterator.hasNext()) {
            return Collections.emptyIterator();
        }
        SharedTrainingWrapper.getInstance(this.worker.getInstanceId()).attachMDS(dataSetIterator);
        SharedTrainingResult result = SharedTrainingWrapper.getInstance(this.worker.getInstanceId()).run(this.worker);
        return Collections.singletonList(result).iterator();
    }
}

