/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.parameterserver.training;

import java.util.List;
import org.apache.spark.broadcast.Broadcast;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.spark.api.TrainingHook;
import org.deeplearning4j.spark.api.TrainingWorker;
import org.deeplearning4j.spark.api.WorkerConfiguration;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.worker.NetBroadcastTuple;
import org.deeplearning4j.spark.impl.paramavg.BaseTrainingWorker;
import org.deeplearning4j.spark.parameterserver.conf.SharedTrainingConfiguration;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingResult;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

/**
 * Training worker from the Spark parameter-server training package.
 *
 * <p>This class holds the broadcast model and training configuration, the training
 * listeners, the stats storage router and the worker garbage-collection settings, and
 * it can build the initial network from the broadcast {@link NetBroadcastTuple}.
 *
 * <p>Hook management, minibatch processing, final-result collection and
 * {@link #getDataConfiguration()} are not implemented here: each of those methods
 * throws {@link UnsupportedOperationException}.
 */
public class SharedTrainingWorker
        extends BaseTrainingWorker<SharedTrainingResult>
        implements TrainingWorker<SharedTrainingResult> {

    /** Identifier passed to the constructor and returned by {@link #getInstanceId()}. */
    private final long instanceId;

    /** Broadcast tuple the initial network is built from (configuration, parameters, updater state). */
    private final Broadcast<NetBroadcastTuple> broadcastModel;

    /** Broadcast shared-training configuration; only stored and exposed by this class. */
    private final Broadcast<SharedTrainingConfiguration> broadcastConfiguration;

    /** Training listeners; only stored and exposed by this class. */
    private final List<TrainingListener> listeners;

    /** Stats storage router; only stored and exposed by this class. */
    private final StatsStorageRouter router;

    /** Periodic-GC toggle for the ND4J memory manager; {@code null} leaves the current setting untouched. */
    private final Boolean workerTogglePeriodicGC;

    /** Auto-GC window for the ND4J memory manager; {@code null} leaves the current setting untouched. */
    private final Integer workerPeriodicGCFrequency;

    /**
     * Creates a worker that keeps references to the given values; nothing is copied or validated.
     *
     * @param instanceId                identifier of this worker instance
     * @param broadcastModel            broadcast tuple used to build the initial network
     * @param broadcastConfiguration    broadcast shared-training configuration
     * @param listeners                 training listeners
     * @param router                    stats storage router
     * @param workerTogglePeriodicGC    periodic-GC toggle, or {@code null} to leave it unchanged
     * @param workerPeriodicGCFrequency auto-GC window, or {@code null} to leave it unchanged
     */
    public SharedTrainingWorker(long instanceId, Broadcast<NetBroadcastTuple> broadcastModel, Broadcast<SharedTrainingConfiguration> broadcastConfiguration, List<TrainingListener> listeners, StatsStorageRouter router, Boolean workerTogglePeriodicGC, Integer workerPeriodicGCFrequency) {
        this.instanceId = instanceId;
        this.broadcastModel = broadcastModel;
        this.broadcastConfiguration = broadcastConfiguration;
        this.listeners = listeners;
        this.router = router;
        this.workerTogglePeriodicGC = workerTogglePeriodicGC;
        this.workerPeriodicGCFrequency = workerPeriodicGCFrequency;
    }

    /**
     * Builds the exception thrown by every operation this worker does not implement.
     *
     * @param operation name of the method that was invoked
     * @return the exception to throw; the type matches what callers already received
     */
    private static UnsupportedOperationException unsupported(String operation) {
        return new UnsupportedOperationException(operation + " is not supported by SharedTrainingWorker");
    }

    /**
     * Pushes the configured GC settings to the ND4J memory manager.
     * A {@code null} setting is skipped, leaving the memory manager's current value in place.
     */
    private void applyWorkerGcSettings() {
        if (this.workerTogglePeriodicGC != null) {
            Nd4j.getMemoryManager().togglePeriodicGc(this.workerTogglePeriodicGC.booleanValue());
        }
        if (this.workerPeriodicGCFrequency != null) {
            Nd4j.getMemoryManager().setAutoGcWindow(this.workerPeriodicGCFrequency.intValue());
        }
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public void removeHook(TrainingHook trainingHook) {
        throw unsupported("removeHook");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public void addHook(TrainingHook trainingHook) {
        throw unsupported("addHook");
    }

    /**
     * Builds the initial {@link MultiLayerNetwork} from the broadcast tuple.
     *
     * <p>The worker GC settings are applied first. The network is then initialized from
     * the tuple's configuration, and the tuple's parameters and updater state are copied
     * in when they are present.
     *
     * @return the initialized network, or {@code null} when the tuple carries no
     *         {@link MultiLayerConfiguration}
     */
    public MultiLayerNetwork getInitialModel() {
        applyWorkerGcSettings();
        // Device 0 is selected before the broadcast value is read.
        // NOTE(review): presumably so the broadcast arrays end up on device 0 - confirm.
        Nd4j.getAffinityManager().unsafeSetDevice(Integer.valueOf(0));
        NetBroadcastTuple tuple = this.broadcastModel.getValue();
        MultiLayerConfiguration conf = tuple.getConfiguration();
        if (conf == null) {
            return null;
        }
        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();
        if (tuple.getParameters() != null) {
            network.setParams(tuple.getParameters());
        }
        if (tuple.getUpdaterState() != null) {
            network.getUpdater().getStateViewArray().assign(tuple.getUpdaterState());
        }
        return network;
    }

    /**
     * Builds the initial {@link ComputationGraph} from the broadcast tuple.
     *
     * <p>The worker GC settings are applied first, exactly as in {@link #getInitialModel()},
     * so that the configured values take effect for graph models as well. The graph is then
     * initialized from the tuple's graph configuration, and the tuple's parameters and
     * updater state are copied in when they are present.
     *
     * @return the initialized graph, or {@code null} when the tuple carries no
     *         {@link ComputationGraphConfiguration}
     */
    public ComputationGraph getInitialModelGraph() {
        applyWorkerGcSettings();
        // Device 0 is selected before the broadcast value is read.
        // NOTE(review): presumably so the broadcast arrays end up on device 0 - confirm.
        Nd4j.getAffinityManager().unsafeSetDevice(Integer.valueOf(0));
        NetBroadcastTuple tuple = this.broadcastModel.getValue();
        ComputationGraphConfiguration conf = tuple.getGraphConfiguration();
        if (conf == null) {
            return null;
        }
        ComputationGraph network = new ComputationGraph(conf);
        network.init();
        if (tuple.getParameters() != null) {
            network.setParams(tuple.getParameters());
        }
        if (tuple.getUpdaterState() != null) {
            network.getUpdater().getUpdaterStateViewArray().assign(tuple.getUpdaterState());
        }
        return network;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public SharedTrainingResult processMinibatch(DataSet dataSet, MultiLayerNetwork network, boolean isLast) {
        throw unsupported("processMinibatch");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public SharedTrainingResult processMinibatch(DataSet dataSet, ComputationGraph graph, boolean isLast) {
        throw unsupported("processMinibatch");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public SharedTrainingResult processMinibatch(MultiDataSet dataSet, ComputationGraph graph, boolean isLast) {
        throw unsupported("processMinibatch");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public Pair<SharedTrainingResult, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, MultiLayerNetwork network, boolean isLast) {
        throw unsupported("processMinibatchWithStats");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public Pair<SharedTrainingResult, SparkTrainingStats> processMinibatchWithStats(DataSet dataSet, ComputationGraph graph, boolean isLast) {
        throw unsupported("processMinibatchWithStats");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public Pair<SharedTrainingResult, SparkTrainingStats> processMinibatchWithStats(MultiDataSet dataSet, ComputationGraph graph, boolean isLast) {
        throw unsupported("processMinibatchWithStats");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public SharedTrainingResult getFinalResult(MultiLayerNetwork network) {
        throw unsupported("getFinalResult");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public SharedTrainingResult getFinalResult(ComputationGraph network) {
        throw unsupported("getFinalResult");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public SharedTrainingResult getFinalResultNoData() {
        throw unsupported("getFinalResultNoData");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public Pair<SharedTrainingResult, SparkTrainingStats> getFinalResultNoDataWithStats() {
        throw unsupported("getFinalResultNoDataWithStats");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public Pair<SharedTrainingResult, SparkTrainingStats> getFinalResultWithStats(MultiLayerNetwork network) {
        throw unsupported("getFinalResultWithStats");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public Pair<SharedTrainingResult, SparkTrainingStats> getFinalResultWithStats(ComputationGraph graph) {
        throw unsupported("getFinalResultWithStats");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public WorkerConfiguration getDataConfiguration() {
        throw unsupported("getDataConfiguration");
    }

    /** @return the instance identifier supplied at construction */
    public long getInstanceId() {
        return this.instanceId;
    }

    /** @return the broadcast model tuple supplied at construction */
    public Broadcast<NetBroadcastTuple> getBroadcastModel() {
        return this.broadcastModel;
    }

    /** @return the broadcast training configuration supplied at construction */
    public Broadcast<SharedTrainingConfiguration> getBroadcastConfiguration() {
        return this.broadcastConfiguration;
    }

    /** @return the listener list supplied at construction (the same reference, not a copy) */
    public List<TrainingListener> getListeners() {
        return this.listeners;
    }

    /** @return the stats storage router supplied at construction */
    public StatsStorageRouter getRouter() {
        return this.router;
    }

    /** @return the periodic-GC toggle, or {@code null} when none was configured */
    public Boolean getWorkerTogglePeriodicGC() {
        return this.workerTogglePeriodicGC;
    }

    /** @return the auto-GC window, or {@code null} when none was configured */
    public Integer getWorkerPeriodicGCFrequency() {
        return this.workerPeriodicGCFrequency;
    }
}

