/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.impl.paramavg;

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
import java.util.Random;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.storage.StorageLevel;
import org.datavec.spark.util.SerializableHadoopConfig;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.spark.api.RDDTrainingApproach;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.RepartitionStrategy;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.api.TrainingResult;
import org.deeplearning4j.spark.api.TrainingWorker;
import org.deeplearning4j.spark.data.BatchAndExportDataSetsFunction;
import org.deeplearning4j.spark.data.BatchAndExportMultiDataSetsFunction;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingMasterStats;
import org.deeplearning4j.spark.impl.paramavg.util.ExportSupport;
import org.deeplearning4j.spark.util.serde.StorageLevelDeserializer;
import org.deeplearning4j.spark.util.serde.StorageLevelSerializer;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
import org.nd4j.shade.jackson.annotation.PropertyAccessor;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.MapperFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for {@link TrainingMaster} implementations.
 *
 * <p>Holds configuration shared by training masters (batch size, export directory, GC settings, listeners,
 * repartitioning, storage levels) and provides utilities to export a training RDD to a (Hadoop-compatible)
 * file system. The exported data can then be consumed as an {@code RDD<String>} of file paths. Only the most
 * recently exported RDD is tracked: exporting a different RDD first deletes the previous export.</p>
 *
 * <p>The JSON/YAML mappers returned by {@link #getJsonMapper()} / {@link #getYamlMapper()} use field-based
 * visibility, so every non-static, non-transient instance field is part of the serialized form. Keep that in
 * mind before adding instance fields.</p>
 *
 * <p>NOTE(review): instances hold mutable export-tracking state and are not synchronized; presumably they
 * are used from a single driver thread — confirm.</p>
 *
 * @param <R> training result type produced by workers
 * @param <W> worker type
 */
public abstract class BaseTrainingMaster<R extends TrainingResult, W extends TrainingWorker<R>>
implements TrainingMaster<R, W> {
    private static final Logger log = LoggerFactory.getLogger(BaseTrainingMaster.class);

    /** Sub-directory (relative to an RDD's base export directory) holding the exported data files. */
    private static final String DATA_SUBDIR = "data/";
    /** Sub-directory (relative to an RDD's base export directory) holding the text files listing data paths. */
    private static final String PATHS_SUBDIR = "paths/";
    /** Sentinel meaning "no RDD has been exported (or the last export was deleted)". */
    private static final int NO_EXPORTED_RDD = Integer.MIN_VALUE;

    protected static ObjectMapper jsonMapper;
    protected static ObjectMapper yamlMapper;
    protected boolean collectTrainingStats;
    protected ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper stats;
    protected int lastExportedRDDId = Integer.MIN_VALUE;
    protected String lastRDDExportPath;
    protected int batchSizePerWorker;
    protected String exportDirectory = null;
    protected Random rng;
    protected String trainingMasterUID;
    protected Boolean workerTogglePeriodicGC;
    protected Integer workerPeriodicGCFrequency;
    protected StatsStorageRouter statsStorage;
    protected List<TrainingListener> listeners;
    protected Repartition repartition;
    protected RepartitionStrategy repartitionStrategy;
    @JsonSerialize(using=StorageLevelSerializer.class)
    @JsonDeserialize(using=StorageLevelDeserializer.class)
    protected StorageLevel storageLevel;
    @JsonSerialize(using=StorageLevelSerializer.class)
    @JsonDeserialize(using=StorageLevelDeserializer.class)
    protected StorageLevel storageLevelStreams = StorageLevel.MEMORY_ONLY();
    protected RDDTrainingApproach rddTrainingApproach = RDDTrainingApproach.Export;
    protected Broadcast<SerializableHadoopConfig> broadcastHadoopConfig;

    protected BaseTrainingMaster() {
    }

    /**
     * Returns the shared JSON mapper, creating it on first use.
     *
     * @return lazily-initialised JSON {@link ObjectMapper} configured by {@link #getNewMapper(JsonFactory)}
     */
    protected static synchronized ObjectMapper getJsonMapper() {
        if (jsonMapper == null) {
            jsonMapper = getNewMapper(new JsonFactory());
        }
        return jsonMapper;
    }

    /**
     * Returns the shared YAML mapper, creating it on first use.
     *
     * @return lazily-initialised YAML {@link ObjectMapper} configured by {@link #getNewMapper(JsonFactory)}
     */
    protected static synchronized ObjectMapper getYamlMapper() {
        if (yamlMapper == null) {
            yamlMapper = getNewMapper(new YAMLFactory());
        }
        return yamlMapper;
    }

    /**
     * Creates a new mapper for the given factory. The mapper ignores unknown properties on read, tolerates
     * empty beans on write, sorts properties alphabetically, indents output, and serializes via fields only
     * (getter/setter auto-detection is disabled).
     *
     * @param jsonFactory factory determining the output format (JSON, YAML, ...)
     * @return a newly configured mapper
     */
    protected static ObjectMapper getNewMapper(JsonFactory jsonFactory) {
        ObjectMapper om = new ObjectMapper(jsonFactory);
        om.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
        om.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
        om.configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, true);
        om.enable(SerializationFeature.INDENT_OUTPUT);
        om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
        om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
        return om;
    }

    /**
     * Exports the given {@code RDD<DataSet>} unless it is the RDD that was exported most recently. If a
     * different RDD was exported before, its export directory is deleted first.
     *
     * @param sc           Spark context
     * @param trainingData data to export
     * @return an RDD of the exported data file paths
     */
    protected JavaRDD<String> exportIfRequired(JavaSparkContext sc, JavaRDD<DataSet> trainingData) {
        ExportSupport.assertExportSupported(sc);
        if (this.collectTrainingStats) {
            this.stats.logExportStart();
        }
        String baseDir;
        if (isCurrentExport(trainingData)) {
            // Same RDD as last time: reuse the existing export.
            baseDir = this.getBaseDirForRDD(trainingData);
        } else {
            deletePreviousExport(sc);
            baseDir = this.export(trainingData);
        }
        if (this.collectTrainingStats) {
            this.stats.logExportEnd();
        }
        return sc.textFile(baseDir + PATHS_SUBDIR);
    }

    /**
     * Exports the given {@code RDD<MultiDataSet>} unless it is the RDD that was exported most recently. If a
     * different RDD was exported before, its export directory is deleted first.
     *
     * @param sc           Spark context
     * @param trainingData data to export
     * @return an RDD of the exported data file paths
     */
    protected JavaRDD<String> exportIfRequiredMDS(JavaSparkContext sc, JavaRDD<MultiDataSet> trainingData) {
        ExportSupport.assertExportSupported(sc);
        if (this.collectTrainingStats) {
            this.stats.logExportStart();
        }
        String baseDir;
        if (isCurrentExport(trainingData)) {
            // Same RDD as last time: reuse the existing export.
            baseDir = this.getBaseDirForRDD(trainingData);
        } else {
            deletePreviousExport(sc);
            baseDir = this.exportMDS(trainingData);
        }
        if (this.collectTrainingStats) {
            this.stats.logExportEnd();
        }
        return sc.textFile(baseDir + PATHS_SUBDIR);
    }

    /** Returns true if {@code rdd} is the RDD whose export is currently on disk. */
    private boolean isCurrentExport(JavaRDD<?> rdd) {
        return this.lastExportedRDDId != NO_EXPORTED_RDD && this.lastExportedRDDId == rdd.id();
    }

    /** Deletes the directory of the previous export, if any (best effort; failures are logged by deleteTempDir). */
    private void deletePreviousExport(JavaSparkContext sc) {
        if (this.lastExportedRDDId != NO_EXPORTED_RDD && this.lastRDDExportPath != null) {
            this.deleteTempDir(sc, this.lastRDDExportPath);
        }
    }

    /**
     * Batches and writes an {@code RDD<DataSet>} under {@link #getBaseDirForRDD(JavaRDD)}, and records it as
     * the most recent export.
     *
     * @param trainingData data to export
     * @return base directory of the export (ends with {@code /})
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // decompiled raw Function2 cast retained as-is
    protected String export(JavaRDD<DataSet> trainingData) {
        String baseDir = this.getBaseDirForRDD(trainingData);
        String dataDir = baseDir + DATA_SUBDIR;
        String pathsDir = baseDir + PATHS_SUBDIR;
        log.info("Initiating RDD<DataSet> export at {}", baseDir);
        JavaRDD paths = trainingData.mapPartitionsWithIndex((Function2)new BatchAndExportDataSetsFunction(this.batchSizePerWorker, dataDir), true);
        paths.saveAsTextFile(pathsDir);
        log.info("RDD<DataSet> export complete at {}", baseDir);
        this.lastExportedRDDId = trainingData.id();
        this.lastRDDExportPath = baseDir;
        return baseDir;
    }

    /**
     * Batches and writes an {@code RDD<MultiDataSet>} under {@link #getBaseDirForRDD(JavaRDD)}, and records it
     * as the most recent export.
     *
     * @param trainingData data to export
     * @return base directory of the export (ends with {@code /})
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // decompiled raw Function2 cast retained as-is
    protected String exportMDS(JavaRDD<MultiDataSet> trainingData) {
        String baseDir = this.getBaseDirForRDD(trainingData);
        String dataDir = baseDir + DATA_SUBDIR;
        String pathsDir = baseDir + PATHS_SUBDIR;
        log.info("Initiating RDD<MultiDataSet> export at {}", baseDir);
        JavaRDD paths = trainingData.mapPartitionsWithIndex((Function2)new BatchAndExportMultiDataSetsFunction(this.batchSizePerWorker, dataDir), true);
        paths.saveAsTextFile(pathsDir);
        log.info("RDD<MultiDataSet> export complete at {}", baseDir);
        this.lastExportedRDDId = trainingData.id();
        this.lastRDDExportPath = baseDir;
        return baseDir;
    }

    /**
     * Computes the base export directory for an RDD: {@code <exportDirectory>/<trainingMasterUID>/<rddId>/}.
     * If no export directory is configured, the default is resolved (and cached in {@link #exportDirectory}).
     *
     * @param rdd the RDD being exported
     * @return base directory, always ending with {@code /}
     */
    protected String getBaseDirForRDD(JavaRDD<?> rdd) {
        if (this.exportDirectory == null) {
            this.exportDirectory = this.getDefaultExportDirectory(rdd.context());
        }
        String separator = this.exportDirectory.endsWith("/") ? "" : "/";
        return this.exportDirectory + separator + this.trainingMasterUID + "/" + rdd.id() + "/";
    }

    /**
     * Recursively deletes a directory on the file system that owns {@code tempDirPath}.
     *
     * @param sc          Spark context supplying the Hadoop configuration
     * @param tempDirPath directory to delete (local or any Hadoop-supported scheme)
     * @return {@code true} if the directory was deleted or did not exist; {@code false} if deletion failed
     * @throws RuntimeException if the file system for the path cannot be obtained
     */
    protected boolean deleteTempDir(JavaSparkContext sc, String tempDirPath) {
        log.info("Attempting to delete temporary directory: {}", tempDirPath);
        Configuration hadoopConfiguration = sc.hadoopConfiguration();
        // Path (unlike java.net.URI) accepts Windows-style and otherwise non-URI-safe path strings.
        Path path = new Path(tempDirPath);
        FileSystem fileSystem;
        try {
            fileSystem = path.getFileSystem(hadoopConfiguration);
        }
        catch (IOException e) {
            throw new RuntimeException("Could not get file system for temporary directory: " + tempDirPath, e);
        }
        try {
            // delete() also returns false when the path is absent; only treat it as failure if it still exists.
            if (!fileSystem.delete(path, true) && fileSystem.exists(path)) {
                log.warn("Could not delete temporary directory: {}", tempDirPath);
                return false;
            }
            log.info("Deleted temporary directory: {}", tempDirPath);
            return true;
        }
        catch (IOException e) {
            log.warn("Could not delete temporary directory: {}", tempDirPath, e);
            return false;
        }
    }

    /**
     * Default export directory: {@code <hadoop.tmp.dir>/dl4j/}.
     *
     * @param sc Spark context supplying the Hadoop configuration
     * @return the default export directory, ending with {@code /}
     * @throws IllegalStateException if {@code hadoop.tmp.dir} is not set
     */
    protected String getDefaultExportDirectory(SparkContext sc) {
        String hadoopTmpDir = sc.hadoopConfiguration().get("hadoop.tmp.dir");
        if (hadoopTmpDir == null || hadoopTmpDir.isEmpty()) {
            throw new IllegalStateException("Cannot determine default export directory: Hadoop configuration "
                    + "property \"hadoop.tmp.dir\" is not set. Configure an export directory explicitly.");
        }
        if (!hadoopTmpDir.endsWith("/") && !hadoopTmpDir.endsWith("\\")) {
            hadoopTmpDir = hadoopTmpDir + "/";
        }
        return hadoopTmpDir + "dl4j/";
    }

    /**
     * Deletes the most recent export, if any. On success the export is forgotten, so a later export request
     * for the same RDD re-exports instead of reading from the deleted directory.
     *
     * @param sc Spark context
     * @return {@code true} if there was nothing to delete or deletion succeeded
     */
    @Override
    public boolean deleteTempFiles(JavaSparkContext sc) {
        if (this.lastRDDExportPath == null) {
            return true;
        }
        boolean deleted = this.deleteTempDir(sc, this.lastRDDExportPath);
        if (deleted) {
            this.lastExportedRDDId = NO_EXPORTED_RDD;
            this.lastRDDExportPath = null;
        }
        return deleted;
    }

    /**
     * Equivalent to {@link #deleteTempFiles(JavaSparkContext)}.
     *
     * @param sc Spark context
     * @return {@code true} if there was nothing to delete or deletion succeeded
     */
    @Override
    public boolean deleteTempFiles(SparkContext sc) {
        return this.deleteTempFiles(new JavaSparkContext(sc));
    }

    /** @param workerTogglePeriodicGC worker periodic-GC toggle (may be {@code null}) */
    public void setWorkerTogglePeriodicGC(Boolean workerTogglePeriodicGC) {
        this.workerTogglePeriodicGC = workerTogglePeriodicGC;
    }

    /** @return worker periodic-GC toggle (may be {@code null}) */
    public Boolean getWorkerTogglePeriodicGC() {
        return this.workerTogglePeriodicGC;
    }

    /** @param workerPeriodicGCFrequency worker periodic-GC frequency (may be {@code null}) */
    public void setWorkerPeriodicGCFrequency(Integer workerPeriodicGCFrequency) {
        this.workerPeriodicGCFrequency = workerPeriodicGCFrequency;
    }

    /** @return worker periodic-GC frequency (may be {@code null}) */
    public Integer getWorkerPeriodicGCFrequency() {
        return this.workerPeriodicGCFrequency;
    }
}

