/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.parameterserver.modelimport.elephas;

import java.io.IOException;
import java.util.Map;
import org.apache.spark.api.java.JavaSparkContext;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.FixedThresholdAlgorithm;
import org.deeplearning4j.spark.api.RDDTrainingApproach;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.RepartitionStrategy;
import org.deeplearning4j.spark.api.Repartitioner;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster;
import org.deeplearning4j.spark.impl.repartitioner.DefaultRepartitioner;
import org.deeplearning4j.spark.parameterserver.training.SharedTrainingMaster;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;

/**
 * Imports Keras models exported by Elephas into DL4J Spark training wrappers.
 *
 * <p>The network and its weights are loaded with {@link KerasModelImport}. The same HDF5 file must
 * also carry a JSON attribute named {@code distributed_config}. The {@code "config"} object inside
 * that attribute is translated into a DL4J {@link TrainingMaster}:
 *
 * <ul>
 *   <li>{@code "mode": "synchronous"} maps to a {@link ParameterAveragingTrainingMaster}</li>
 *   <li>{@code "mode": "asynchronous"} maps to a {@link SharedTrainingMaster}</li>
 * </ul>
 *
 * <p>Recognised {@code config} fields:
 * <ul>
 *   <li>{@code mode} (required)</li>
 *   <li>{@code batch_size} (required)</li>
 *   <li>{@code num_workers} (synchronous only; optional, null is forwarded to the builder when
 *       absent)</li>
 *   <li>{@code collect_stats} (default {@code false})</li>
 *   <li>{@code num_batches_prefetch} (default {@code 0})</li>
 *   <li>{@code averaging_frequency} (synchronous only; default {@code 5})</li>
 *   <li>{@code update_threshold} (asynchronous only; default {@code 0.001})</li>
 * </ul>
 */
public class ElephasModelImport {
    private static final String DISTRIBUTED_CONFIG = "distributed_config";
    private static final RDDTrainingApproach APPROACH = RDDTrainingApproach.Export;
    /** Value passed as the {@code enforceTrainingConfig} flag of {@link KerasModelImport}. */
    private static final boolean ENFORCE_TRAINING_CONFIG = true;
    private static final int DEFAULT_AVERAGING_FREQUENCY = 5;
    private static final double DEFAULT_UPDATE_THRESHOLD = 0.001;

    /**
     * Imports an Elephas functional (graph) model and wraps it for distributed Spark training.
     *
     * @param sparkContext      Spark context the returned network is bound to
     * @param modelHdf5Filename path to the HDF5 file holding the Keras model, its weights and the
     *                          {@code distributed_config} attribute
     * @return a {@link SparkComputationGraph} wrapping the imported model and its training master
     * @throws IOException                            if the HDF5 file cannot be read
     * @throws UnsupportedKerasConfigurationException if the Keras model uses unsupported features
     * @throws InvalidKerasConfigurationException     if the Keras or distributed configuration is invalid
     */
    public static SparkComputationGraph importElephasModelAndWeights(JavaSparkContext sparkContext,
            String modelHdf5Filename) throws IOException, UnsupportedKerasConfigurationException,
            InvalidKerasConfigurationException {
        ComputationGraph model =
                KerasModelImport.importKerasModelAndWeights(modelHdf5Filename, ENFORCE_TRAINING_CONFIG);
        TrainingMaster tm = loadTrainingMaster(modelHdf5Filename);
        return new SparkComputationGraph(sparkContext, model, tm);
    }

    /**
     * Imports an Elephas Sequential model and wraps it for distributed Spark training.
     *
     * @param sparkContext      Spark context the returned network is bound to
     * @param modelHdf5Filename path to the HDF5 file holding the Keras model, its weights and the
     *                          {@code distributed_config} attribute
     * @return a {@link SparkDl4jMultiLayer} wrapping the imported model and its training master
     * @throws IOException                            if the HDF5 file cannot be read
     * @throws UnsupportedKerasConfigurationException if the Keras model uses unsupported features
     * @throws InvalidKerasConfigurationException     if the Keras or distributed configuration is invalid
     */
    public static SparkDl4jMultiLayer importElephasSequentialModelAndWeights(JavaSparkContext sparkContext,
            String modelHdf5Filename) throws IOException, UnsupportedKerasConfigurationException,
            InvalidKerasConfigurationException {
        MultiLayerNetwork model = KerasModelImport.importKerasSequentialModelAndWeights(
                modelHdf5Filename, ENFORCE_TRAINING_CONFIG);
        TrainingMaster tm = loadTrainingMaster(modelHdf5Filename);
        return new SparkDl4jMultiLayer(sparkContext, model, tm);
    }

    /** Reads the distributed configuration of the given file and builds the matching TrainingMaster. */
    private static TrainingMaster loadTrainingMaster(String modelHdf5Filename) throws IOException,
            UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        return getTrainingMaster(distributedTrainingMap(modelHdf5Filename));
    }

    /** Reads the {@code distributed_config} attribute from the HDF5 file and parses it as JSON. */
    private static Map<String, Object> distributedTrainingMap(String modelHdf5Filename)
            throws UnsupportedKerasConfigurationException, IOException {
        // The archive keeps the HDF5 file open until closed; release it as soon as the attribute is read.
        try (Hdf5Archive archive = new Hdf5Archive(modelHdf5Filename)) {
            String distributedConfigJson = archive.readAttributeAsJson(DISTRIBUTED_CONFIG);
            return KerasModelUtils.parseJsonString(distributedConfigJson);
        }
    }

    /**
     * Translates the parsed distributed configuration into a TrainingMaster.
     *
     * @throws InvalidKerasConfigurationException if required fields are missing, fields have the
     *                                            wrong type, or the mode is unknown
     */
    private static TrainingMaster getTrainingMaster(Map<String, Object> distributedProperties)
            throws InvalidKerasConfigurationException {
        Map<?, ?> innerConfig = readInnerConfig(distributedProperties);
        int batchSize = requiredInt(innerConfig, "batch_size");
        if (!innerConfig.containsKey("mode")) {
            throw new InvalidKerasConfigurationException("Couldn't find mode field.");
        }
        // String.valueOf: a null or non-string mode falls through to the "Unknown mode" error
        // instead of failing with a NullPointerException / ClassCastException.
        String mode = String.valueOf(innerConfig.get("mode"));
        boolean collectStats = optionalBoolean(innerConfig, "collect_stats", false);
        int numBatchesPrefetch = optionalInt(innerConfig, "num_batches_prefetch", 0);

        if ("synchronous".equals(mode)) {
            return parameterAveragingMaster(innerConfig, batchSize, collectStats, numBatchesPrefetch);
        }
        if ("asynchronous".equals(mode)) {
            return sharedTrainingMaster(innerConfig, batchSize, collectStats, numBatchesPrefetch);
        }
        throw new InvalidKerasConfigurationException("Unknown mode " + mode);
    }

    /** Builds the synchronous (parameter averaging) training master. */
    private static TrainingMaster parameterAveragingMaster(Map<?, ?> innerConfig, int batchSize,
            boolean collectStats, int numBatchesPrefetch) throws InvalidKerasConfigurationException {
        Number workers = optionalNumber(innerConfig, "num_workers");
        // Left null when absent: the builder accepts a nullable worker count.
        Integer numWorkers = workers == null ? null : Integer.valueOf(workers.intValue());
        int averagingFrequency =
                optionalInt(innerConfig, "averaging_frequency", DEFAULT_AVERAGING_FREQUENCY);
        return new ParameterAveragingTrainingMaster.Builder(numWorkers, batchSize)
                .collectTrainingStats(collectStats)
                .batchSizePerWorker(batchSize)
                .averagingFrequency(averagingFrequency)
                .workerPrefetchNumBatches(numBatchesPrefetch)
                .aggregationDepth(2)
                .repartionData(Repartition.Always)
                .rddTrainingApproach(APPROACH)
                .repartitionStrategy(RepartitionStrategy.Balanced)
                .saveUpdater(false)
                .build();
    }

    /** Builds the asynchronous (shared, threshold-encoded updates) training master. */
    private static TrainingMaster sharedTrainingMaster(Map<?, ?> innerConfig, int batchSize,
            boolean collectStats, int numBatchesPrefetch) throws InvalidKerasConfigurationException {
        // NOTE(review): num_workers is not forwarded in this mode (same as before) — confirm intended.
        double updateThreshold = optionalDouble(innerConfig, "update_threshold", DEFAULT_UPDATE_THRESHOLD);
        ThresholdAlgorithm thresholdAlgorithm = new FixedThresholdAlgorithm(updateThreshold);
        VoidConfiguration voidConfiguration = VoidConfiguration.builder().build();
        Repartitioner repartitioner = new DefaultRepartitioner();
        return new SharedTrainingMaster.Builder(voidConfiguration, batchSize)
                .thresholdAlgorithm(thresholdAlgorithm)
                .batchSizePerWorker(batchSize)
                .collectTrainingStats(collectStats)
                .workerPrefetchNumBatches(numBatchesPrefetch)
                .rddTrainingApproach(APPROACH)
                .repartitioner(repartitioner)
                .build();
    }

    /** Extracts the nested {@code "config"} object from the parsed distributed configuration. */
    private static Map<?, ?> readInnerConfig(Map<String, Object> distributedProperties)
            throws InvalidKerasConfigurationException {
        Object config = distributedProperties == null ? null : distributedProperties.get("config");
        if (!(config instanceof Map)) {
            throw new InvalidKerasConfigurationException(
                    "Couldn't find config field in " + DISTRIBUTED_CONFIG + ".");
        }
        return (Map<?, ?>) config;
    }

    /**
     * Returns the number stored under {@code key}, or {@code null} when absent or JSON null.
     * Accepts any {@link Number} subtype, because JSON parsers may yield Integer, Long or Double.
     */
    private static Number optionalNumber(Map<?, ?> config, String key)
            throws InvalidKerasConfigurationException {
        Object value = config.get(key);
        if (value == null) {
            return null;
        }
        if (!(value instanceof Number)) {
            throw new InvalidKerasConfigurationException(
                    "Field " + key + " must be a number, but was: " + value);
        }
        return (Number) value;
    }

    /** Returns the required integer field {@code key}, failing with a descriptive error if absent. */
    private static int requiredInt(Map<?, ?> config, String key)
            throws InvalidKerasConfigurationException {
        Number value = optionalNumber(config, key);
        if (value == null) {
            throw new InvalidKerasConfigurationException("Couldn't find " + key + " field.");
        }
        return value.intValue();
    }

    /** Returns the integer field {@code key}, or {@code defaultValue} when absent. */
    private static int optionalInt(Map<?, ?> config, String key, int defaultValue)
            throws InvalidKerasConfigurationException {
        Number value = optionalNumber(config, key);
        return value == null ? defaultValue : value.intValue();
    }

    /** Returns the floating-point field {@code key}, or {@code defaultValue} when absent. */
    private static double optionalDouble(Map<?, ?> config, String key, double defaultValue)
            throws InvalidKerasConfigurationException {
        Number value = optionalNumber(config, key);
        return value == null ? defaultValue : value.doubleValue();
    }

    /** Returns the boolean field {@code key}, or {@code defaultValue} when absent or JSON null. */
    private static boolean optionalBoolean(Map<?, ?> config, String key, boolean defaultValue)
            throws InvalidKerasConfigurationException {
        Object value = config.get(key);
        if (value == null) {
            return defaultValue;
        }
        if (!(value instanceof Boolean)) {
            throw new InvalidKerasConfigurationException(
                    "Field " + key + " must be a boolean, but was: " + value);
        }
        return (Boolean) value;
    }
}

