package org.deeplearning4j.spark.parameterserver.training;

import java.io.IOException;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.apache.commons.lang3.RandomUtils;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaRDDLike;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.storage.StorageLevel;
import org.datavec.spark.util.BroadcastHadoopConfigHolder;
import org.deeplearning4j.api.loader.DataSetLoader;
import org.deeplearning4j.api.loader.MultiDataSetLoader;
import org.deeplearning4j.api.loader.impl.SerializedDataSetLoader;
import org.deeplearning4j.api.loader.impl.SerializedMultiDataSetLoader;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.api.storage.StorageMetaData;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.residual.ResidualClippingPostProcessor;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm;
import org.deeplearning4j.spark.api.RDDTrainingApproach;
import org.deeplearning4j.spark.api.Repartition;
import org.deeplearning4j.spark.api.RepartitionStrategy;
import org.deeplearning4j.spark.api.Repartitioner;
import org.deeplearning4j.spark.api.TrainingHook;
import org.deeplearning4j.spark.api.TrainingMaster;
import org.deeplearning4j.spark.api.stats.SparkTrainingStats;
import org.deeplearning4j.spark.api.worker.NetBroadcastTuple;
import org.deeplearning4j.spark.impl.graph.SparkComputationGraph;
import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer;
import org.deeplearning4j.spark.impl.paramavg.BaseTrainingMaster;
import org.deeplearning4j.spark.impl.paramavg.stats.ParameterAveragingTrainingMasterStats;
import org.deeplearning4j.spark.impl.repartitioner.DefaultRepartitioner;
import org.deeplearning4j.spark.parameterserver.accumulation.SharedTrainingAccumulationFunction;
import org.deeplearning4j.spark.parameterserver.accumulation.SharedTrainingAccumulationTuple;
import org.deeplearning4j.spark.parameterserver.accumulation.SharedTrainingAggregateFunction;
import org.deeplearning4j.spark.parameterserver.conf.SharedTrainingConfiguration;
import org.deeplearning4j.spark.parameterserver.functions.SharedFlatMapDataSet;
import org.deeplearning4j.spark.parameterserver.functions.SharedFlatMapMultiDataSet;
import org.deeplearning4j.spark.parameterserver.functions.SharedFlatMapPaths;
import org.deeplearning4j.spark.parameterserver.functions.SharedFlatMapPathsMDS;
import org.deeplearning4j.spark.parameterserver.networking.v1.SilentTrainingDriver;
import org.deeplearning4j.spark.parameterserver.networking.v2.UpdatesConsumer;
import org.deeplearning4j.spark.util.SparkUtils;
import org.deeplearning4j.util.UIDProvider;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.enums.ExecutionMode;
import org.nd4j.parameterserver.distributed.enums.NodeRole;
import org.nd4j.parameterserver.distributed.enums.TransportType;
import org.nd4j.parameterserver.distributed.util.NetworkOrganizer;
import org.nd4j.parameterserver.distributed.v2.ModelParameterServer;
import org.nd4j.parameterserver.distributed.v2.transport.Transport;
import org.nd4j.parameterserver.distributed.v2.transport.impl.AeronUdpTransport;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.class */
public class SharedTrainingMaster extends BaseTrainingMaster<SharedTrainingResult, SharedTrainingWorker> implements TrainingMaster<SharedTrainingResult, SharedTrainingWorker> {
    private static final Logger log = LoggerFactory.getLogger(SharedTrainingMaster.class);
    // Source of the per-object instanceId; incremented once by each constructor
    protected static final AtomicInteger INSTANCE_COUNTER = new AtomicInteger();
    // instanceId of the master that last (re)configured ModelParameterServer.getInstance(); -1 until the first setup
    protected static final AtomicInteger LAST_TRAINING_INSTANCE = new AtomicInteger(-1);
    // Hooks registered through addHook/removeHook; the list is created lazily by addHook
    protected List<TrainingHook> trainingHooks;
    // Parameter-server configuration; mutated during setup (port, stream id, controller/shard addresses)
    protected VoidConfiguration voidConfiguration;
    // Worker count; when null it is filled in from the Spark context's defaultParallelism()
    protected Integer numWorkers;
    // Forwarded as numberOfWorkersPerNode to the worker SharedTrainingConfiguration
    protected Integer numWorkersPerNode;
    // Number of batches workers prefetch (Builder.workerPrefetchNumBatches)
    protected int workerPrefetchBatches;
    // Direct (train on the RDD as-is) or Export (serialize to paths first)
    protected RDDTrainingApproach rddTrainingApproach;
    // Storage level used to persist training RDDs in the Direct path; null means no persist
    protected StorageLevel storageLevel;
    protected Repartitioner repartitioner;
    // When true, timing is recorded through the 'stats' helper
    protected boolean collectTrainingStats;
    // Presumably the number of examples per DataSet object in the training RDD — passed to executeTrainingPathsHelper
    protected int rddDataSetNumExamples;
    protected long debugLongerIterations;
    protected boolean logMinibatchesPerWorker;
    protected boolean encodingDebugMode;
    protected ThresholdAlgorithm thresholdAlgorithm;
    protected ResidualPostProcessor residualPostProcessor;
    protected Repartition repartition;
    protected RepartitionStrategy repartitionStrategy;
    // Created by the main constructor only when collectTrainingStats is true
    protected ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper stats;
    // NOTE(review): not assigned anywhere in the visible part of this class — confirm whether it is used
    protected Random rng;
    // Flipped to true by the first parameter-server setup in prepareNetworkAndStuff
    protected AtomicBoolean isFirstRun;
    // Unique (per JVM) id of this master, taken from INSTANCE_COUNTER
    protected final transient int instanceId;
    // Broadcast once on the first getWorkerInstance call and reused afterwards
    protected transient Broadcast<NetBroadcastTuple> broadcastModel;
    // Broadcast once on the first getWorkerInstance call and reused afterwards
    protected transient Broadcast<SharedTrainingConfiguration> broadcastConfiguration;
    // Set by Builder.build() when a transport was supplied; the visible setup code builds its own AeronUdpTransport
    protected transient Transport transport;
    // Finished (if non-null) by finalizeTraining
    protected transient SilentTrainingDriver trainingDriver;
    // Subscribed to the parameter server in prepareNetworkAndStuff; flushed by finalizeTraining
    protected transient UpdatesConsumer updatesConsumer;
    // Set at the end of prepareNetworkAndStuff; checked before results are processed
    protected boolean setupDone;

    /* loaded from: input_file:org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster$Builder.class */
    /**
     * Fluent builder for {@link SharedTrainingMaster}.
     *
     * <p>Every constructor ends up in {@link #Builder(VoidConfiguration, ThresholdAlgorithm, int)}, which
     * forces {@link ExecutionMode#MANAGED} on the supplied {@link VoidConfiguration}. The caller's
     * configuration object is mutated in place.
     *
     * <p>Defaults (see the field initializers): {@link AdaptiveThresholdAlgorithm}, a
     * {@code ResidualClippingPostProcessor(5.0, 5)}, {@link Repartition#Always},
     * {@link RepartitionStrategy#Balanced}, {@code StorageLevel.MEMORY_ONLY_SER()},
     * {@link RDDTrainingApproach#Export}, 2 prefetched batches per worker, {@link DefaultRepartitioner},
     * periodic worker GC enabled with frequency 5000, and {@code workersPerNode = -1}.
     */
    public static class Builder {
        protected ThresholdAlgorithm thresholdAlgorithm = new AdaptiveThresholdAlgorithm();
        protected ResidualPostProcessor residualPostProcessor = new ResidualClippingPostProcessor(5.0d, 5);
        protected int rddDataSetNumExamples = 1;

        @Deprecated
        protected Repartition repartition = Repartition.Always;

        @Deprecated
        protected RepartitionStrategy repartitionStrategy = RepartitionStrategy.Balanced;
        protected StorageLevel storageLevel = StorageLevel.MEMORY_ONLY_SER();
        protected VoidConfiguration voidConfiguration;
        protected RDDTrainingApproach rddTrainingApproach = RDDTrainingApproach.Export;
        // Stored only: build() does not currently forward the seed to the master
        protected long rngSeed;
        protected String exportDirectory = null;
        protected Integer numWorkers;
        protected boolean collectTrainingStats;
        protected Transport transport;
        protected int batchSize;
        protected long debugLongerIterations = 0L;
        protected int numWorkersPerNode = -1;
        protected int workerPrefetchNumBatches = 2;
        protected Repartitioner repartitioner = new DefaultRepartitioner();
        protected Boolean workerTogglePeriodicGC = Boolean.TRUE;
        protected Integer workerPeriodicGCFrequency = Integer.valueOf(5000);
        protected boolean encodingDebugMode = false;

        /**
         * Creates a builder with an {@link AdaptiveThresholdAlgorithm} and a default MANAGED
         * {@link VoidConfiguration} whose controller address is taken from {@code SPARK_PUBLIC_DNS}.
         *
         * @param rddDataSetNumExamples number of examples in each DataSet object of the training RDD
         */
        public Builder(int rddDataSetNumExamples) {
            this((ThresholdAlgorithm) new AdaptiveThresholdAlgorithm(), rddDataSetNumExamples);
        }

        /**
         * Creates a builder for the given configuration with an {@link AdaptiveThresholdAlgorithm}.
         *
         * @param voidConfiguration     parameter-server configuration (must not be {@code null})
         * @param rddDataSetNumExamples number of examples in each DataSet object of the training RDD
         * @throws NullPointerException if {@code voidConfiguration} is {@code null}
         */
        public Builder(@NonNull VoidConfiguration voidConfiguration, int rddDataSetNumExamples) {
            // The delegated constructor performs the null check before anything else
            this(voidConfiguration, new AdaptiveThresholdAlgorithm(), rddDataSetNumExamples);
        }

        /**
         * Creates a builder with the given threshold algorithm and a default MANAGED, SHARD-role
         * {@link VoidConfiguration} whose controller address is read from {@code SPARK_PUBLIC_DNS}.
         *
         * @param thresholdAlgorithm    threshold algorithm used for update encoding
         * @param rddDataSetNumExamples number of examples in each DataSet object of the training RDD
         */
        public Builder(ThresholdAlgorithm thresholdAlgorithm, int rddDataSetNumExamples) {
            this(VoidConfiguration.builder().executionMode(ExecutionMode.MANAGED).forcedRole(NodeRole.SHARD)
                    .controllerAddress(System.getenv("SPARK_PUBLIC_DNS")).build(), thresholdAlgorithm, rddDataSetNumExamples);
        }

        /**
         * @param voidConfiguration     parameter-server configuration (must not be {@code null})
         * @param num                   ignored by this constructor
         * @param d                     initial threshold for an {@link AdaptiveThresholdAlgorithm}
         * @param rddDataSetNumExamples number of examples in each DataSet object of the training RDD
         * @throws NullPointerException if {@code voidConfiguration} is {@code null}
         * @deprecated use {@link #Builder(VoidConfiguration, ThresholdAlgorithm, int)}
         */
        @Deprecated
        public Builder(@NonNull VoidConfiguration voidConfiguration, Integer num, double d, int rddDataSetNumExamples) {
            this(voidConfiguration, new AdaptiveThresholdAlgorithm(d), rddDataSetNumExamples);
        }

        /**
         * Primary constructor: every other constructor delegates here.
         *
         * @param voidConfiguration     parameter-server configuration; its execution mode is forced to MANAGED
         * @param thresholdAlgorithm    threshold algorithm used for update encoding
         * @param rddDataSetNumExamples number of examples in each DataSet object of the training RDD
         * @throws NullPointerException if {@code voidConfiguration} is {@code null}
         */
        public Builder(@NonNull VoidConfiguration voidConfiguration, ThresholdAlgorithm thresholdAlgorithm, int rddDataSetNumExamples) {
            if (voidConfiguration == null) {
                throw new NullPointerException("voidConfiguration is marked @NonNull but is null");
            }
            this.thresholdAlgorithm = thresholdAlgorithm;
            this.voidConfiguration = voidConfiguration;
            this.rddDataSetNumExamples = rddDataSetNumExamples;
            this.voidConfiguration.setExecutionMode(ExecutionMode.MANAGED);
        }

        /**
         * Same as {@link #Builder(VoidConfiguration, ThresholdAlgorithm, int)} but also fixes the worker count.
         *
         * @param voidConfiguration     parameter-server configuration (must not be {@code null})
         * @param num                   number of workers, or {@code null} to derive it from the Spark context
         * @param thresholdAlgorithm    threshold algorithm used for update encoding
         * @param rddDataSetNumExamples number of examples in each DataSet object of the training RDD
         * @throws NullPointerException if {@code voidConfiguration} is {@code null}
         */
        public Builder(@NonNull VoidConfiguration voidConfiguration, Integer num, ThresholdAlgorithm thresholdAlgorithm, int rddDataSetNumExamples) {
            this(voidConfiguration, thresholdAlgorithm, rddDataSetNumExamples);
            this.numWorkers = num;
        }

        /** Enables or disables collection of training statistics. */
        public Builder collectTrainingStats(boolean z) {
            this.collectTrainingStats = z;
            return this;
        }

        /** @deprecated use {@link #repartitioner(Repartitioner)} */
        @Deprecated
        public Builder repartitionData(Repartition repartition) {
            this.repartition = repartition;
            return this;
        }

        /** @deprecated use {@link #repartitioner(Repartitioner)} */
        @Deprecated
        public Builder repartitionStrategy(RepartitionStrategy repartitionStrategy) {
            this.repartitionStrategy = repartitionStrategy;
            return this;
        }

        /** Storage level used to persist training RDDs; {@code null} disables persisting. */
        public Builder storageLevel(StorageLevel storageLevel) {
            this.storageLevel = storageLevel;
            return this;
        }

        /** Selects whether training runs directly on the RDD or on exported paths. */
        public Builder rddTrainingApproach(RDDTrainingApproach rDDTrainingApproach) {
            this.rddTrainingApproach = rDDTrainingApproach;
            return this;
        }

        /** Directory used when exporting the RDD; forwarded to the built master when non-null. */
        public Builder exportDirectory(String str) {
            this.exportDirectory = str;
            return this;
        }

        /** Stores an RNG seed. NOTE: {@link #build()} does not currently forward it to the master. */
        public Builder rngSeed(long j) {
            this.rngSeed = j;
            return this;
        }

        /** @deprecated use {@link #thresholdAlgorithm(ThresholdAlgorithm)} with an {@link AdaptiveThresholdAlgorithm} */
        @Deprecated
        public Builder updatesThreshold(double d) {
            return thresholdAlgorithm(new AdaptiveThresholdAlgorithm(d));
        }

        /** Threshold algorithm used for update encoding. */
        public Builder thresholdAlgorithm(ThresholdAlgorithm thresholdAlgorithm) {
            this.thresholdAlgorithm = thresholdAlgorithm;
            return this;
        }

        /** Post-processor applied to residuals. */
        public Builder residualPostProcessor(ResidualPostProcessor residualPostProcessor) {
            this.residualPostProcessor = residualPostProcessor;
            return this;
        }

        /** Minibatch size used by each worker. */
        public Builder batchSizePerWorker(int i) {
            this.batchSize = i;
            return this;
        }

        /** Workers per node; any value below 1 is stored as -1. */
        public Builder workersPerNode(int i) {
            this.numWorkersPerNode = i < 1 ? -1 : i;
            return this;
        }

        /** @deprecated debugging option; negative values are clamped to 0. */
        @Deprecated
        public Builder debugLongerIterations(long j) {
            this.debugLongerIterations = Math.max(0L, j);
            return this;
        }

        /** Transport assigned to the built master when non-null. */
        public Builder transport(Transport transport) {
            this.transport = transport;
            return this;
        }

        /** Number of batches each worker prefetches. */
        public Builder workerPrefetchNumBatches(int i) {
            this.workerPrefetchNumBatches = i;
            return this;
        }

        /** Repartitioner applied to training data. */
        public Builder repartitioner(Repartitioner repartitioner) {
            this.repartitioner = repartitioner;
            return this;
        }

        /** Whether workers toggle periodic garbage collection. */
        public Builder workerTogglePeriodicGC(boolean z) {
            this.workerTogglePeriodicGC = Boolean.valueOf(z);
            return this;
        }

        /** Periodic GC frequency passed to workers. */
        public Builder workerPeriodicGCFrequency(int i) {
            this.workerPeriodicGCFrequency = Integer.valueOf(i);
            return this;
        }

        /** Enables encoding debug mode. */
        public Builder encodingDebugMode(boolean z) {
            this.encodingDebugMode = z;
            return this;
        }

        /**
         * Builds the master from the current builder state.
         *
         * @return a new {@link SharedTrainingMaster}
         */
        public SharedTrainingMaster build() {
            SharedTrainingMaster master = new SharedTrainingMaster(this.voidConfiguration, this.numWorkers,
                    this.rddTrainingApproach, this.storageLevel, this.collectTrainingStats, this.repartitionStrategy,
                    this.repartition, this.thresholdAlgorithm, this.residualPostProcessor, this.rddDataSetNumExamples,
                    this.batchSize, this.debugLongerIterations, this.numWorkersPerNode, this.workerPrefetchNumBatches,
                    this.repartitioner, this.workerTogglePeriodicGC, this.workerPeriodicGCFrequency,
                    this.encodingDebugMode);
            if (this.transport != null) {
                master.transport = this.transport;
            }
            if (this.exportDirectory != null) {
                // Without this, exportDirectory(String) would have no effect on the built master
                master.exportDirectory = this.exportDirectory;
            }
            return master;
        }
    }

    /**
     * No-argument constructor for instances not created through the builder.
     * Presumably used by JSON/YAML deserialization in {@link #fromJson(String)} / {@link #fromYaml(String)};
     * TODO confirm.
     *
     * <p>{@code isFirstRun} is initialised here as well. {@code prepareNetworkAndStuff} calls
     * {@code compareAndSet} on it unconditionally, and would otherwise throw a NullPointerException
     * on instances built by this constructor.
     */
    protected SharedTrainingMaster() {
        this.debugLongerIterations = 0L;
        this.logMinibatchesPerWorker = false;
        this.encodingDebugMode = false;
        this.isFirstRun = new AtomicBoolean(false);
        this.instanceId = INSTANCE_COUNTER.getAndIncrement();
    }

    public SharedTrainingMaster(@NonNull VoidConfiguration voidConfiguration, Integer num, RDDTrainingApproach rDDTrainingApproach, StorageLevel storageLevel, boolean z, RepartitionStrategy repartitionStrategy, Repartition repartition, ThresholdAlgorithm thresholdAlgorithm, ResidualPostProcessor residualPostProcessor, int i, int i2, long j, int i3, int i4, Repartitioner repartitioner, Boolean bool, Integer num2, boolean z2) {
        this.debugLongerIterations = 0L;
        this.logMinibatchesPerWorker = false;
        this.encodingDebugMode = false;
        if (voidConfiguration == null) {
            throw new NullPointerException("voidConfiguration is marked @NonNull but is null");
        }
        this.voidConfiguration = voidConfiguration;
        this.numWorkers = num;
        this.thresholdAlgorithm = thresholdAlgorithm;
        this.residualPostProcessor = residualPostProcessor;
        this.rddTrainingApproach = rDDTrainingApproach;
        this.repartitionStrategy = repartitionStrategy;
        this.repartition = repartition;
        this.storageLevel = storageLevel;
        this.collectTrainingStats = z;
        this.isFirstRun = new AtomicBoolean(false);
        this.batchSizePerWorker = i2;
        this.rddDataSetNumExamples = i;
        this.debugLongerIterations = j;
        this.numWorkersPerNode = Integer.valueOf(i3);
        this.workerPrefetchBatches = i4;
        this.repartitioner = repartitioner;
        this.workerTogglePeriodicGC = bool;
        this.workerPeriodicGCFrequency = num2;
        this.encodingDebugMode = z2;
        if (z) {
            this.stats = new ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper();
        }
        String jvmuid = UIDProvider.getJVMUID();
        this.trainingMasterUID = System.currentTimeMillis() + "_" + (jvmuid.length() <= 8 ? jvmuid : jvmuid.substring(0, 8));
        this.instanceId = INSTANCE_COUNTER.getAndIncrement();
    }

    /**
     * Unregisters a training hook. Does nothing if no hooks were ever added.
     *
     * @param trainingHook hook to remove
     */
    public void removeHook(TrainingHook trainingHook) {
        List<TrainingHook> hooks = this.trainingHooks;
        if (hooks == null) {
            return;
        }
        hooks.remove(trainingHook);
    }

    /**
     * Registers a training hook, creating the hook list on first use.
     *
     * @param trainingHook hook to add (must not be {@code null})
     * @throws NullPointerException if {@code trainingHook} is {@code null}
     */
    public void addHook(@NonNull TrainingHook trainingHook) {
        if (trainingHook == null) {
            throw new NullPointerException("trainingHook is marked @NonNull but is null");
        }
        List<TrainingHook> hooks = this.trainingHooks;
        if (hooks == null) {
            hooks = new ArrayList<>();
            this.trainingHooks = hooks;
        }
        hooks.add(trainingHook);
    }

    /**
     * Serializes this training master to JSON using the mapper returned by {@code getJsonMapper()}.
     *
     * @return JSON representation of this instance
     * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialization fails
     */
    public String toJson() {
        try {
            return getJsonMapper().writeValueAsString(this);
        } catch (JsonProcessingException e) {
            // Message previously named ParameterAveragingTrainingMaster (copy-paste error)
            throw new RuntimeException("Error producing JSON representation for SharedTrainingMaster", e);
        }
    }

    /**
     * Serializes this training master to YAML using the mapper returned by {@code getYamlMapper()}.
     *
     * @return YAML representation of this instance
     * @throws RuntimeException wrapping the {@link JsonProcessingException} if serialization fails
     */
    public String toYaml() {
        try {
            return getYamlMapper().writeValueAsString(this);
        } catch (JsonProcessingException e) {
            // Message previously named ParameterAveragingTrainingMaster (copy-paste error)
            throw new RuntimeException("Error producing YAML representation for SharedTrainingMaster", e);
        }
    }

    /**
     * Deserializes a master previously produced by {@link #toJson()}.
     *
     * @param str JSON text
     * @return the parsed master
     * @throws RuntimeException wrapping the {@link IOException} if parsing fails
     */
    public static SharedTrainingMaster fromJson(String str) {
        SharedTrainingMaster parsed;
        try {
            parsed = (SharedTrainingMaster) getJsonMapper().readValue(str, SharedTrainingMaster.class);
        } catch (IOException e) {
            throw new RuntimeException("Could not parse JSON", e);
        }
        return parsed;
    }

    /**
     * Deserializes a master previously produced by {@link #toYaml()}.
     *
     * @param str YAML text
     * @return the parsed master
     * @throws RuntimeException wrapping the {@link IOException} if parsing fails
     */
    public static SharedTrainingMaster fromYaml(String str) {
        SharedTrainingMaster parsed;
        try {
            parsed = (SharedTrainingMaster) getYamlMapper().readValue(str, SharedTrainingMaster.class);
        } catch (IOException e) {
            throw new RuntimeException("Could not parse YAML", e);
        }
        return parsed;
    }

    /* renamed from: getWorkerInstance, reason: merged with bridge method [inline-methods] */
    /**
     * Creates the worker used to train a {@link SparkDl4jMultiLayer} on the executors.
     *
     * <p>The model (configuration, parameters, updater state) and the worker configuration are broadcast
     * only on the first call. Later calls reuse the existing broadcasts, even if the network has changed
     * since.
     *
     * @param sparkDl4jMultiLayer network being trained
     * @return a worker referencing the broadcast model and configuration
     */
    public SharedTrainingWorker m18getWorkerInstance(SparkDl4jMultiLayer sparkDl4jMultiLayer) {
        NetBroadcastTuple netBroadcastTuple = new NetBroadcastTuple(sparkDl4jMultiLayer.getNetwork().getLayerWiseConfigurations(), sparkDl4jMultiLayer.getNetwork().params(), sparkDl4jMultiLayer.getNetwork().getUpdater().getStateViewArray());
        // NOTE(review): the ComputationGraph variant does not re-read the port here. prepareNetworkAndStuff has
        // already taken a port for the parameter server, so confirm the port supplier returns a stable value.
        this.voidConfiguration.setUnicastControllerPort(this.voidConfiguration.getPortSupplier().getPort());
        SharedTrainingConfiguration build = SharedTrainingConfiguration.builder()
                .thresholdAlgorithm(this.thresholdAlgorithm)
                .residualPostProcessor(this.residualPostProcessor)
                .voidConfiguration(this.voidConfiguration)
                .debugLongerIterations(this.debugLongerIterations)
                .numberOfWorkersPerNode(this.numWorkersPerNode.intValue())
                // Must match the ComputationGraph variant so workerPrefetchNumBatches also applies to MLN training
                .prefetchSize(this.workerPrefetchBatches)
                .encodingDebugMode(this.encodingDebugMode)
                .build();
        if (this.collectTrainingStats) {
            this.stats.logBroadcastStart();
        }
        if (this.broadcastModel == null) {
            this.broadcastModel = sparkDl4jMultiLayer.getSparkContext().broadcast(netBroadcastTuple);
        }
        if (this.broadcastConfiguration == null) {
            this.broadcastConfiguration = sparkDl4jMultiLayer.getSparkContext().broadcast(build);
        }
        if (this.collectTrainingStats) {
            this.stats.logBroadcastEnd();
        }
        return new SharedTrainingWorker(this.instanceId, this.broadcastModel, this.broadcastConfiguration, this.listeners, this.statsStorage, this.workerTogglePeriodicGC, this.workerPeriodicGCFrequency);
    }

    /* renamed from: getWorkerInstance, reason: merged with bridge method [inline-methods] */
    /**
     * Creates the worker used to train a {@link SparkComputationGraph} on the executors.
     *
     * <p>Model and worker configuration are broadcast on the first call only; subsequent calls reuse the
     * cached broadcasts.
     *
     * @param sparkComputationGraph graph being trained
     * @return a worker referencing the broadcast model and configuration
     */
    public SharedTrainingWorker m17getWorkerInstance(SparkComputationGraph sparkComputationGraph) {
        NetBroadcastTuple modelTuple = new NetBroadcastTuple(
                sparkComputationGraph.getNetwork().getConfiguration(),
                sparkComputationGraph.getNetwork().params(),
                sparkComputationGraph.getNetwork().getUpdater().getStateViewArray());

        SharedTrainingConfiguration workerConfig = SharedTrainingConfiguration.builder()
                .thresholdAlgorithm(this.thresholdAlgorithm)
                .residualPostProcessor(this.residualPostProcessor)
                .voidConfiguration(this.voidConfiguration)
                .debugLongerIterations(this.debugLongerIterations)
                .numberOfWorkersPerNode(this.numWorkersPerNode.intValue())
                .prefetchSize(this.workerPrefetchBatches)
                .encodingDebugMode(this.encodingDebugMode)
                .build();

        if (this.collectTrainingStats) {
            this.stats.logBroadcastStart();
        }
        // Broadcast lazily: only the first call creates the broadcast variables
        if (this.broadcastModel == null) {
            JavaSparkContext context = sparkComputationGraph.getSparkContext();
            this.broadcastModel = context.broadcast(modelTuple);
        }
        if (this.broadcastConfiguration == null) {
            JavaSparkContext context = sparkComputationGraph.getSparkContext();
            this.broadcastConfiguration = context.broadcast(workerConfig);
        }
        if (this.collectTrainingStats) {
            this.stats.logBroadcastEnd();
        }

        SharedTrainingWorker worker = new SharedTrainingWorker(this.instanceId, this.broadcastModel,
                this.broadcastConfiguration, this.listeners, this.statsStorage, this.workerTogglePeriodicGC,
                this.workerPeriodicGCFrequency);
        return worker;
    }

    /**
     * Integer-divides the per-worker batch size by {@code i}.
     *
     * @param i divisor; must be non-zero (integer division by zero throws {@link ArithmeticException})
     * @return {@code batchSizePerWorker / i}, truncated toward zero
     */
    protected int numObjectsEachWorker(int i) {
        final int perWorkerBatch = this.batchSizePerWorker;
        return perWorkerBatch / i;
    }

    /**
     * Counts the elements of an RDD (a Spark action). When stats collection is enabled, the count is
     * timed through the stats helper.
     *
     * @param javaRDDLike RDD to count
     * @return number of elements in the RDD
     */
    protected <T, Repr extends JavaRDDLike<T, Repr>> long getTotalDataSetObjectCount(JavaRDDLike<T, Repr> javaRDDLike) {
        if (this.collectTrainingStats) {
            this.stats.logCountStart();
        }
        final long elementCount = javaRDDLike.count();
        if (!this.collectTrainingStats) {
            return elementCount;
        }
        this.stats.logCountEnd();
        return elementCount;
    }

    /**
     * Trains a MultiLayerNetwork directly on an RDD of DataSets in a single iteration/split.
     * The RDD is persisted first when a storage level is configured.
     *
     * @param sparkDl4jMultiLayer network to train
     * @param javaRDD             training data
     */
    protected void executeTrainingDirect(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<DataSet> javaRDD) {
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        StorageLevel level = this.storageLevel;
        if (level != null) {
            javaRDD.persist(level);
        }
        long objectCount = getTotalDataSetObjectCount(javaRDD);
        doIteration(sparkDl4jMultiLayer, javaRDD, 1, 1);
        if (!this.collectTrainingStats) {
            return;
        }
        this.stats.logFitEnd((int) objectCount);
    }

    /**
     * Trains a ComputationGraph directly on an RDD of MultiDataSets in a single iteration/split.
     * The RDD is persisted first when a storage level is configured.
     *
     * @param sparkComputationGraph graph to train
     * @param javaRDD               training data
     */
    protected void executeTrainingDirectMDS(SparkComputationGraph sparkComputationGraph, JavaRDD<MultiDataSet> javaRDD) {
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        StorageLevel level = this.storageLevel;
        if (level != null) {
            javaRDD.persist(level);
        }
        long objectCount = getTotalDataSetObjectCount(javaRDD);
        doIterationMDS(sparkComputationGraph, javaRDD, 1, 1);
        if (!this.collectTrainingStats) {
            return;
        }
        this.stats.logFitEnd((int) objectCount);
    }

    /**
     * Trains a ComputationGraph directly on an RDD of DataSets in a single iteration/split.
     * The RDD is persisted first when a storage level is configured.
     *
     * @param sparkComputationGraph graph to train
     * @param javaRDD               training data
     */
    protected void executeTrainingDirect(SparkComputationGraph sparkComputationGraph, JavaRDD<DataSet> javaRDD) {
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        StorageLevel level = this.storageLevel;
        if (level != null) {
            javaRDD.persist(level);
        }
        long objectCount = getTotalDataSetObjectCount(javaRDD);
        doIteration(sparkComputationGraph, javaRDD, 1, 1);
        if (!this.collectTrainingStats) {
            return;
        }
        this.stats.logFitEnd((int) objectCount);
    }

    /**
     * Trains from an RDD of paths to serialized (Multi)DataSets. Parameter-server setup is performed first.
     *
     * @param sparkDl4jMultiLayer   network to train, or {@code null} when training a graph
     * @param sparkComputationGraph graph to train, or {@code null} when training a network
     * @param javaRDD               paths to the serialized data
     * @param dataSetLoader         loader for DataSet files (may be {@code null} for MultiDataSet training)
     * @param multiDataSetLoader    loader for MultiDataSet files (may be {@code null} for DataSet training)
     */
    public void executeTrainingPaths(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD, DataSetLoader dataSetLoader, MultiDataSetLoader multiDataSetLoader) {
        prepareNetworkAndStuff(sparkDl4jMultiLayer, sparkComputationGraph);
        final int examplesPerObject = this.rddDataSetNumExamples;
        executeTrainingPathsHelper(sparkDl4jMultiLayer, sparkComputationGraph, javaRDD, dataSetLoader, multiDataSetLoader, examplesPerObject);
    }

    /**
     * Runs a single training iteration over an RDD of data paths.
     *
     * <p>Defaults {@code numWorkers} to the Spark context's default parallelism when unset. The path RDD is
     * persisted with {@code storageLevelStreams} when one is configured.
     *
     * @param sparkDl4jMultiLayer   network to train, or {@code null}
     * @param sparkComputationGraph graph to train (used when {@code sparkDl4jMultiLayer} is {@code null})
     * @param javaRDD               paths to serialized data
     * @param dataSetLoader         loader for DataSet files
     * @param multiDataSetLoader    loader for MultiDataSet files
     * @param i                     number of examples per serialized object
     */
    protected void executeTrainingPathsHelper(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD, DataSetLoader dataSetLoader, MultiDataSetLoader multiDataSetLoader, int i) {
        if (this.numWorkers == null) {
            JavaSparkContext context = sparkDl4jMultiLayer != null
                    ? sparkDl4jMultiLayer.getSparkContext()
                    : sparkComputationGraph.getSparkContext();
            this.numWorkers = context.defaultParallelism();
        }
        if (this.collectTrainingStats) {
            this.stats.logFitStart();
        }
        StorageLevel streamsLevel = this.storageLevelStreams;
        if (streamsLevel != null) {
            javaRDD.persist(streamsLevel);
        }
        long pathCount = getTotalDataSetObjectCount(javaRDD);
        doIterationPaths(sparkDl4jMultiLayer, sparkComputationGraph, javaRDD, 1, 1, dataSetLoader, multiDataSetLoader, i);
        if (!this.collectTrainingStats) {
            return;
        }
        this.stats.logFitEnd((int) pathCount);
    }

    /**
     * Prepares the driver side of a training session before any Spark work is submitted.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Validate that a network was supplied.</li>
     *   <li>Assign a unicast port from the configured port supplier.</li>
     *   <li>Pick a random stream id if the configured one is below 1.</li>
     *   <li>Default {@code numWorkers} to the Spark context's {@code defaultParallelism()}.</li>
     *   <li>Resolve the controller address and register it as the single shard.</li>
     *   <li>Initialise the network.</li>
     *   <li>On the first run, or when a different master instance trained last, (re)configure and
     *       launch the {@link ModelParameterServer} with an {@link AeronUdpTransport} and a fresh
     *       {@link UpdatesConsumer}.</li>
     * </ol>
     *
     * <p>The controller address is taken from the first available source, in this order:
     * <ol>
     *   <li>the {@link VoidConfiguration};</li>
     *   <li>the resolved {@code SPARK_PUBLIC_DNS} environment variable;</li>
     *   <li>the configured network mask;</li>
     *   <li>the {@code DL4J_VOID_IP} environment variable.</li>
     * </ol>
     *
     * @param sparkDl4jMultiLayer   network being trained, or {@code null} when training a graph
     * @param sparkComputationGraph graph being trained, or {@code null} when training a network
     * @throws IllegalStateException      if both arguments are {@code null}
     * @throws DL4JInvalidConfigException if no controller address can be determined, or the transport type is
     *                                    not {@link TransportType#ROUTED_UDP}
     */
    protected void prepareNetworkAndStuff(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph) {
        if (sparkDl4jMultiLayer == null && sparkComputationGraph == null) {
            throw new IllegalStateException("Both MLN & CG are undefined");
        }
        this.voidConfiguration.setUnicastControllerPort(this.voidConfiguration.getPortSupplier().getPort());
        if (this.voidConfiguration.getStreamId() < 1) {
            this.voidConfiguration.setStreamId(RandomUtils.nextInt(119, 2147483646));
        }
        if (this.numWorkers == null) {
            this.numWorkers = sparkDl4jMultiLayer != null ? sparkDl4jMultiLayer.getSparkContext().defaultParallelism() : sparkComputationGraph.getSparkContext().defaultParallelism();
        }

        // Controller address source 1: SPARK_PUBLIC_DNS, resolved to an IP address
        if (this.voidConfiguration.getControllerAddress() == null) {
            String sparkPublicDns = System.getenv("SPARK_PUBLIC_DNS");
            log.info("Trying {SPARK_PUBLIC_DNS}: [{}]", sparkPublicDns);
            if (sparkPublicDns != null) {
                try {
                    this.voidConfiguration.setControllerAddress(InetAddress.getByName(sparkPublicDns).getHostAddress());
                } catch (UnknownHostException e) {
                    // Best effort: fall through to the remaining address sources, but don't hide the failure
                    log.warn("Could not resolve SPARK_PUBLIC_DNS [{}]; trying other controller address sources", sparkPublicDns, e);
                }
            }
        }
        // Controller address source 2: first local address matching the configured network mask
        if (this.voidConfiguration.getControllerAddress() == null && this.voidConfiguration.getNetworkMask() != null) {
            String matchingAddress = new NetworkOrganizer(this.voidConfiguration.getNetworkMask()).getMatchingAddress();
            log.info("Trying auto-detected address: [{}]", matchingAddress);
            this.voidConfiguration.setControllerAddress(matchingAddress);
        }
        // Controller address source 3: explicit DL4J_VOID_IP environment variable
        if (this.voidConfiguration.getControllerAddress() == null) {
            String voidIp = System.getenv("DL4J_VOID_IP");
            if (voidIp != null && !voidIp.isEmpty()) {
                this.voidConfiguration.setControllerAddress(voidIp);
            }
        }
        if (this.voidConfiguration.getControllerAddress() == null) {
            throw new DL4JInvalidConfigException("Can't get Spark Master local address. Please specify it manually using VoidConfiguration.setControllerAddress(String) method or VoidConfiguration.setNetworkMask(String) method");
        }
        log.info("Setting controller address to {}:{}", this.voidConfiguration.getControllerAddress(), Integer.valueOf(this.voidConfiguration.getUnicastControllerPort()));
        this.voidConfiguration.setShardAddresses(new String[]{this.voidConfiguration.getControllerAddress()});
        this.voidConfiguration.setNumberOfShards(1);

        if (sparkDl4jMultiLayer != null) {
            sparkDl4jMultiLayer.getNetwork().init();
        } else {
            sparkComputationGraph.getNetwork().init();
        }

        if (this.isFirstRun.compareAndSet(false, true) || LAST_TRAINING_INSTANCE.get() != this.instanceId) {
            if (LAST_TRAINING_INSTANCE.get() >= 0 && LAST_TRAINING_INSTANCE.get() != this.instanceId) {
                // Another master configured the parameter server last: shut it down and give it time to stop
                log.debug("Detected changed training instance - setting up new parameter server - old instance {}, new instance {}", LAST_TRAINING_INSTANCE, Integer.valueOf(this.instanceId));
                ModelParameterServer.getInstance().shutdown();
                try {
                    Thread.sleep(3000L);
                } catch (InterruptedException e2) {
                    // Preserve the interrupt status for callers before aborting setup
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e2);
                }
            }
            if (this.voidConfiguration.getTransportType() != TransportType.ROUTED_UDP) {
                throw new DL4JInvalidConfigException("No Transport implementation was defined for this training session!");
            }
            AeronUdpTransport aeronUdpTransport = new AeronUdpTransport(this.voidConfiguration.getControllerAddress(), this.voidConfiguration.getUnicastControllerPort(), this.voidConfiguration);
            INDArray params = sparkDl4jMultiLayer != null ? sparkDl4jMultiLayer.getNetwork().params() : sparkComputationGraph.getNetwork().params();
            this.updatesConsumer = UpdatesConsumer.builder().params(params).updates(Nd4j.create(params.shape(), params.ordering())).stepFunction(sparkDl4jMultiLayer != null ? sparkDl4jMultiLayer.getNetwork().getOptimizer().getStepFunction() : sparkComputationGraph.getNetwork().getOptimizer().getStepFunction()).build();
            ModelParameterServer.getInstance().configure(this.voidConfiguration, aeronUdpTransport, true);
            ModelParameterServer.getInstance().addUpdatesSubscriber(this.updatesConsumer);
            if (!ModelParameterServer.getInstance().isInitialized()) {
                ModelParameterServer.getInstance().launch();
            }
            LAST_TRAINING_INSTANCE.set(this.instanceId);
        }
        this.setupDone = true;
    }

    /**
     * Finishes training on the driver: completes the training driver (if any), then flushes the updates
     * consumer (if any).
     */
    protected void finalizeTraining() {
        SilentTrainingDriver driver = this.trainingDriver;
        if (driver != null) {
            driver.finishTraining(0L, 0L);
        }
        UpdatesConsumer consumer = this.updatesConsumer;
        if (consumer != null) {
            consumer.flush();
        }
    }

    /**
     * Trains a MultiLayerNetwork on an RDD of DataSets, either directly or via exported files.
     *
     * @param sparkDl4jMultiLayer network to train
     * @param javaRDD             training data
     * @throws DL4JInvalidConfigException if the configured RDD training approach is neither Direct nor Export
     */
    public void executeTraining(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<DataSet> javaRDD) {
        prepareNetworkAndStuff(sparkDl4jMultiLayer, null);
        RDDTrainingApproach approach = this.rddTrainingApproach;
        if (approach == RDDTrainingApproach.Direct) {
            executeTrainingDirect(sparkDl4jMultiLayer, javaRDD);
            return;
        }
        if (approach == RDDTrainingApproach.Export) {
            // Exported objects hold batchSizePerWorker examples each
            executeTrainingPathsHelper(sparkDl4jMultiLayer, null,
                    exportIfRequired(sparkDl4jMultiLayer.getSparkContext(), javaRDD),
                    new SerializedDataSetLoader(), null, this.batchSizePerWorker);
            return;
        }
        throw new DL4JInvalidConfigException("Unknown RDDtrainingApproach [" + approach + "] was specified!");
    }

    /**
     * Trains a ComputationGraph on an RDD of DataSets, either directly or via exported files.
     *
     * @param sparkComputationGraph graph to train
     * @param javaRDD               training data
     * @throws DL4JInvalidConfigException if the configured RDD training approach is neither Direct nor Export
     */
    public void executeTraining(SparkComputationGraph sparkComputationGraph, JavaRDD<DataSet> javaRDD) {
        prepareNetworkAndStuff(null, sparkComputationGraph);
        RDDTrainingApproach approach = this.rddTrainingApproach;
        if (approach == RDDTrainingApproach.Direct) {
            executeTrainingDirect(sparkComputationGraph, javaRDD);
            return;
        }
        if (approach == RDDTrainingApproach.Export) {
            // Exported objects hold batchSizePerWorker examples each
            executeTrainingPathsHelper(null, sparkComputationGraph,
                    exportIfRequired(sparkComputationGraph.getSparkContext(), javaRDD),
                    new SerializedDataSetLoader(), null, this.batchSizePerWorker);
            return;
        }
        throw new DL4JInvalidConfigException("Unknown RDDtrainingApproach [" + approach + "] was specified!");
    }

    /**
     * Trains a {@link SparkComputationGraph} on an RDD of {@link MultiDataSet}s.
     *
     * <p>The graph and parameter-server plumbing is prepared first. With
     * {@link RDDTrainingApproach#Direct} the RDD is trained on directly. With {@link RDDTrainingApproach#Export}
     * the RDD is passed through {@code exportIfRequiredMDS(...)} and the result is trained via
     * {@code executeTrainingPathsHelper(...)} using a {@link SerializedMultiDataSetLoader}.
     *
     * <p>NOTE: the approach is validated only after {@code prepareNetworkAndStuff} has run.
     *
     * @param sparkComputationGraph the graph to train
     * @param javaRDD               the training data
     * @throws DL4JInvalidConfigException if the configured approach is neither Direct nor Export
     */
    public void executeTrainingMDS(SparkComputationGraph sparkComputationGraph, JavaRDD<MultiDataSet> javaRDD) {
        prepareNetworkAndStuff(null, sparkComputationGraph);

        RDDTrainingApproach approach = this.rddTrainingApproach;
        boolean direct = approach == RDDTrainingApproach.Direct;
        if (!direct && approach != RDDTrainingApproach.Export) {
            throw new DL4JInvalidConfigException("Unknown RDDtrainingApproach [" + approach + "] was specified!");
        }

        if (direct) {
            executeTrainingDirectMDS(sparkComputationGraph, javaRDD);
            return;
        }
        executeTrainingPathsHelper(null, sparkComputationGraph,
                exportIfRequiredMDS(sparkComputationGraph.getSparkContext(), javaRDD),
                null, new SerializedMultiDataSetLoader(), this.batchSizePerWorker);
    }

    /**
     * Turns collection of training statistics on or off.
     *
     * @param z {@code true} to collect statistics during subsequent training
     */
    public void setCollectTrainingStats(boolean z) {
        boolean enabled = z;
        collectTrainingStats = enabled;
    }

    /**
     * @return whether training statistics are being collected
     */
    public boolean getIsCollectTrainingStats() {
        return collectTrainingStats;
    }

    /**
     * Returns the statistics gathered while training.
     *
     * <p>Statistics are recorded into {@code stats} only when {@code collectTrainingStats} is enabled (see
     * {@link #setCollectTrainingStats(boolean)}). Previously this method always returned {@code null}, so
     * collected statistics could never be retrieved.
     *
     * @return the built statistics, or {@code null} if collection is disabled or no statistics helper exists
     */
    public SparkTrainingStats getTrainingStats() {
        if (!this.collectTrainingStats || this.stats == null) {
            return null;
        }
        return this.stats.build();
    }

    /**
     * Sets the training listeners without any stats storage router.
     *
     * <p>Delegates to {@link #setListeners(StatsStorageRouter, Collection)} with a {@code null} router, which
     * clears any previously configured router.
     *
     * @param collection the listeners to use; may be {@code null}
     */
    public void setListeners(Collection<TrainingListener> collection) {
        StatsStorageRouter noRouter = null;
        setListeners(noRouter, collection);
    }

    /**
     * Sets the stats storage router and the training listeners.
     *
     * <p>The listener collection is copied into a new list, so later changes to the caller's collection are not
     * reflected here. A {@code null} collection clears the listeners.
     *
     * @param statsStorageRouter where listener metadata/static info/updates are routed; may be {@code null}
     * @param collection         the listeners to use; may be {@code null}
     */
    public void setListeners(StatsStorageRouter statsStorageRouter, Collection<TrainingListener> collection) {
        this.statsStorage = statsStorageRouter;
        if (collection == null) {
            this.listeners = null;
        } else {
            // Typed defensive copy (the previous raw `new ArrayList(collection)` was an unchecked raw-type use).
            this.listeners = new ArrayList<>(collection);
        }
    }

    /**
     * Aggregates the per-partition results of one training split and applies them on the driver.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>{@code treeAggregate} (depth 4) the {@link SharedTrainingResult}s into a single
     *       {@code SharedTrainingAccumulationTuple}.</li>
     *   <li>{@link #finalizeTraining()}.</li>
     *   <li>If an aggregated updater state is present, divide it by the aggregation count (when &gt; 1) and assign
     *       it into the network's updater state view, if that exists.</li>
     *   <li>Set the network score to {@code scoreSum / max(1, aggregationsCount)}.</li>
     *   <li>Forward listener metadata, static info and updates to {@code statsStorage} when configured.</li>
     *   <li>Optionally log the number of minibatches per executor, sorted by executor key.</li>
     *   <li>Replace {@code thresholdAlgorithm} with the reducer's final result, if a reducer is present.</li>
     * </ol>
     *
     * @param sparkDl4jMultiLayer   the multi-layer network, or {@code null} when training a graph
     * @param sparkComputationGraph the computation graph, used only when {@code sparkDl4jMultiLayer} is {@code null}
     * @param javaRDD               per-partition training results
     * @throws IllegalStateException (via {@code Preconditions.checkState}) if both networks are null or setup is not done
     */
    protected void processResults(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<SharedTrainingResult> javaRDD) {
        Preconditions.checkState(sparkDl4jMultiLayer != null || sparkComputationGraph != null, "Both MLN & CG are null");
        Preconditions.checkState(this.setupDone, "Setup was not completed before trying to process results");
        if (this.collectTrainingStats) {
            this.stats.logAggregateStartTime();
        }
        SharedTrainingAccumulationTuple sharedTrainingAccumulationTuple = (SharedTrainingAccumulationTuple) javaRDD.treeAggregate((Object) null, new SharedTrainingAggregateFunction(), new SharedTrainingAccumulationFunction(), 4);
        SparkTrainingStats sparkTrainingStats = sharedTrainingAccumulationTuple.getSparkTrainingStats();
        if (this.collectTrainingStats) {
            this.stats.logAggregationEndTime();
        }

        finalizeTraining();

        if (this.collectTrainingStats) {
            this.stats.logProcessParamsUpdaterStart();
        }
        if (sharedTrainingAccumulationTuple.getUpdaterStateArray() != null) {
            // Presumably the aggregated updater state is a sum over workers; dividing turns it into a mean.
            if (sharedTrainingAccumulationTuple.getAggregationsCount() > 1) {
                sharedTrainingAccumulationTuple.getUpdaterStateArray().divi(Integer.valueOf(sharedTrainingAccumulationTuple.getAggregationsCount()));
            }
            if (sparkDl4jMultiLayer != null) {
                if (sparkDl4jMultiLayer.getNetwork().getUpdater() != null && sparkDl4jMultiLayer.getNetwork().getUpdater().getStateViewArray() != null) {
                    sparkDl4jMultiLayer.getNetwork().getUpdater().getStateViewArray().assign(sharedTrainingAccumulationTuple.getUpdaterStateArray());
                }
            } else if (sparkComputationGraph.getNetwork().getUpdater() != null && sparkComputationGraph.getNetwork().getUpdater().getStateViewArray() != null) {
                sparkComputationGraph.getNetwork().getUpdater().getStateViewArray().assign(sharedTrainingAccumulationTuple.getUpdaterStateArray());
            }
        }
        double scoreSum = sharedTrainingAccumulationTuple.getScoreSum() / Math.max(1, sharedTrainingAccumulationTuple.getAggregationsCount());
        if (sparkDl4jMultiLayer != null) {
            sparkDl4jMultiLayer.getNetwork().setScore(scoreSum);
        } else {
            sparkComputationGraph.getNetwork().setScore(scoreSum);
        }
        if (this.collectTrainingStats) {
            // Log the end of the params/updater phase exactly once, pairing the single Start call above
            // (it was previously logged twice in a row).
            this.stats.logProcessParamsUpdaterEnd();
            this.stats.addWorkerStats(sparkTrainingStats);
        }

        if (this.statsStorage != null) {
            Collection<StorageMetaData> listenerMetaData = sharedTrainingAccumulationTuple.getListenerMetaData();
            if (listenerMetaData != null && !listenerMetaData.isEmpty()) {
                this.statsStorage.putStorageMetaData(listenerMetaData);
            }
            Collection<Persistable> listenerStaticInfo = sharedTrainingAccumulationTuple.getListenerStaticInfo();
            if (listenerStaticInfo != null && !listenerStaticInfo.isEmpty()) {
                this.statsStorage.putStaticInfo(listenerStaticInfo);
            }
            Collection<Persistable> listenerUpdates = sharedTrainingAccumulationTuple.getListenerUpdates();
            if (listenerUpdates != null && !listenerUpdates.isEmpty()) {
                this.statsStorage.putUpdate(listenerUpdates);
            }
        }

        if (this.logMinibatchesPerWorker && sharedTrainingAccumulationTuple.getMinibatchesPerExecutor() != null) {
            // Sort executor keys so the log line is stable and easy to scan.
            List<String> executorKeys = new ArrayList<>(sharedTrainingAccumulationTuple.getMinibatchesPerExecutor().keySet());
            Collections.sort(executorKeys);
            LinkedHashMap<String, Object> sortedCounts = new LinkedHashMap<>();
            for (String executorKey : executorKeys) {
                sortedCounts.put(executorKey, sharedTrainingAccumulationTuple.getMinibatchesPerExecutor().get(executorKey));
            }
            log.info("Number of minibatches processed per JVM/executor: {}", sortedCounts);
        }

        if (sharedTrainingAccumulationTuple.getThresholdAlgorithmReducer() != null) {
            this.thresholdAlgorithm = sharedTrainingAccumulationTuple.getThresholdAlgorithmReducer().getFinalResult();
        }
        Nd4j.getExecutioner().commit();
    }

    /**
     * Trains one split of the data for a {@link SparkDl4jMultiLayer}.
     *
     * <p>The split is repartitioned, using the configured {@link Repartitioner} if one is set and otherwise
     * {@code SparkUtils.repartitionEqually}. Each partition is trained by a {@code SharedFlatMapDataSet} wrapping
     * this network's worker instance. The per-partition results are then passed to
     * {@link #processResults(SparkDl4jMultiLayer, SparkComputationGraph, JavaRDD)}.
     *
     * @param sparkDl4jMultiLayer the network being trained
     * @param javaRDD             the data of this split
     * @param i                   index of this split (used only for logging)
     * @param i2                  total number of splits (used only for logging)
     */
    protected void doIteration(SparkDl4jMultiLayer sparkDl4jMultiLayer, JavaRDD<DataSet> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, thresholdAlgorithm={}, Configured for {} workers", new Object[]{Integer.valueOf(i), Integer.valueOf(i2), Integer.valueOf(this.batchSizePerWorker), this.thresholdAlgorithm, this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }
        JavaRDD<DataSet> repartitioned;
        if (this.repartitioner != null) {
            log.info("Repartitioning training data using repartitioner: {}", this.repartitioner);
            repartitioned = this.repartitioner.repartition(javaRDD, Math.max(1, this.batchSizePerWorker / this.rddDataSetNumExamples), this.numWorkers.intValue());
        } else {
            log.info("Repartitioning training data using SparkUtils repartitioner");
            repartitioned = SparkUtils.repartitionEqually(javaRDD, this.repartition, this.numWorkers.intValue());
        }
        int numPartitions = repartitioned.partitions().size();
        // Pair with logRepartitionStart() above. The Repartitioner branch never consults this.repartition, so
        // gating on Repartition.Never would drop the end timing for data that was in fact repartitioned.
        if (this.collectTrainingStats) {
            this.stats.logRepartitionEnd();
        }
        processResults(sparkDl4jMultiLayer, null, repartitioned.mapPartitions(new SharedFlatMapDataSet(m18getWorkerInstance(sparkDl4jMultiLayer))));
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /**
     * Trains one split of {@link MultiDataSet} data for a {@link SparkComputationGraph}.
     *
     * <p>The split is repartitioned, using the configured {@link Repartitioner} if one is set and otherwise
     * {@code SparkUtils.repartitionEqually}. Each partition is trained by a {@code SharedFlatMapMultiDataSet}
     * wrapping this graph's worker instance. The per-partition results are then passed to
     * {@link #processResults(SparkDl4jMultiLayer, SparkComputationGraph, JavaRDD)}.
     *
     * @param sparkComputationGraph the graph being trained
     * @param javaRDD               the data of this split
     * @param i                     index of this split (used only for logging)
     * @param i2                    total number of splits (used only for logging)
     */
    protected void doIterationMDS(SparkComputationGraph sparkComputationGraph, JavaRDD<MultiDataSet> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, thresholdAlgorithm={}, Configured for {} workers", new Object[]{Integer.valueOf(i), Integer.valueOf(i2), Integer.valueOf(this.batchSizePerWorker), this.thresholdAlgorithm, this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }
        JavaRDD<MultiDataSet> repartitioned;
        if (this.repartitioner != null) {
            log.info("Repartitioning training data using repartitioner: {}", this.repartitioner);
            repartitioned = this.repartitioner.repartition(javaRDD, Math.max(1, this.batchSizePerWorker / this.rddDataSetNumExamples), this.numWorkers.intValue());
        } else {
            log.info("Repartitioning training data using SparkUtils repartitioner");
            repartitioned = SparkUtils.repartitionEqually(javaRDD, this.repartition, this.numWorkers.intValue());
        }
        int numPartitions = repartitioned.partitions().size();
        // Pair with logRepartitionStart() above. The Repartitioner branch never consults this.repartition, so
        // gating on Repartition.Never would drop the end timing for data that was in fact repartitioned.
        if (this.collectTrainingStats) {
            this.stats.logRepartitionEnd();
        }
        processResults(null, sparkComputationGraph, repartitioned.mapPartitions(new SharedFlatMapMultiDataSet(m17getWorkerInstance(sparkComputationGraph))));
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /**
     * Trains one split of {@link DataSet} data for a {@link SparkComputationGraph}.
     *
     * <p>The split is repartitioned, using the configured {@link Repartitioner} if one is set and otherwise
     * {@code SparkUtils.repartitionEqually}. Each partition is trained by a {@code SharedFlatMapDataSet} wrapping
     * this graph's worker instance. The per-partition results are then passed to
     * {@link #processResults(SparkDl4jMultiLayer, SparkComputationGraph, JavaRDD)}.
     *
     * @param sparkComputationGraph the graph being trained
     * @param javaRDD               the data of this split
     * @param i                     index of this split (used only for logging)
     * @param i2                    total number of splits (used only for logging)
     */
    protected void doIteration(SparkComputationGraph sparkComputationGraph, JavaRDD<DataSet> javaRDD, int i, int i2) {
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, thresholdAlgorithm={}, Configured for {} workers", new Object[]{Integer.valueOf(i), Integer.valueOf(i2), Integer.valueOf(this.batchSizePerWorker), this.thresholdAlgorithm, this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }
        JavaRDD<DataSet> repartitioned;
        if (this.repartitioner != null) {
            log.info("Repartitioning training data using repartitioner: {}", this.repartitioner);
            repartitioned = this.repartitioner.repartition(javaRDD, Math.max(1, this.batchSizePerWorker / this.rddDataSetNumExamples), this.numWorkers.intValue());
        } else {
            log.info("Repartitioning training data using SparkUtils repartitioner");
            repartitioned = SparkUtils.repartitionEqually(javaRDD, this.repartition, this.numWorkers.intValue());
        }
        int numPartitions = repartitioned.partitions().size();
        // Pair with logRepartitionStart() above. The Repartitioner branch never consults this.repartition, so
        // gating on Repartition.Never would drop the end timing for data that was in fact repartitioned.
        if (this.collectTrainingStats) {
            this.stats.logRepartitionEnd();
        }
        processResults(null, sparkComputationGraph, repartitioned.mapPartitions(new SharedFlatMapDataSet(m17getWorkerInstance(sparkComputationGraph))));
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /**
     * Trains one split of exported data for either a {@link SparkDl4jMultiLayer} or a {@link SparkComputationGraph}.
     *
     * <p>The split is an RDD of strings (paths, per this method's name). The multi-layer network takes precedence
     * when both networks are given. The split is repartitioned, using the configured {@link Repartitioner} if one
     * is set and otherwise {@code SparkUtils.repartitionEqually}.
     *
     * <p>When {@code dataSetLoader} is non-null each partition is trained by {@code SharedFlatMapPaths}; otherwise
     * {@code SharedFlatMapPathsMDS} with {@code multiDataSetLoader} is used. Both receive the Hadoop
     * configuration obtained from {@code BroadcastHadoopConfigHolder} for the network's Spark context.
     *
     * @param sparkDl4jMultiLayer   multi-layer network, or {@code null}
     * @param sparkComputationGraph computation graph, or {@code null}
     * @param javaRDD               the data of this split
     * @param i                     index of this split (used only for logging)
     * @param i2                    total number of splits (used only for logging)
     * @param dataSetLoader         loader for single DataSets; when {@code null} the MultiDataSet path is used
     * @param multiDataSetLoader    loader for MultiDataSets
     * @param i3                    divisor of {@code batchSizePerWorker} when sizing partitions for the Repartitioner;
     *                              presumably examples per exported object (TODO confirm against callers)
     * @throws DL4JInvalidConfigException if both networks are {@code null}
     */
    protected void doIterationPaths(SparkDl4jMultiLayer sparkDl4jMultiLayer, SparkComputationGraph sparkComputationGraph, JavaRDD<String> javaRDD, int i, int i2, DataSetLoader dataSetLoader, MultiDataSetLoader multiDataSetLoader, int i3) {
        if (sparkDl4jMultiLayer == null && sparkComputationGraph == null) {
            throw new DL4JInvalidConfigException("Both MLN & CompGraph are NULL");
        }
        log.info("Starting training of split {} of {}. workerMiniBatchSize={}, thresholdAlgorithm={}, Configured for {} workers", new Object[]{Integer.valueOf(i), Integer.valueOf(i2), Integer.valueOf(this.batchSizePerWorker), this.thresholdAlgorithm, this.numWorkers});
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsStart();
            this.stats.logRepartitionStart();
        }
        JavaRDD<String> repartitioned;
        if (this.repartitioner != null) {
            log.info("Repartitioning training data using repartitioner: {}", this.repartitioner);
            repartitioned = this.repartitioner.repartition(javaRDD, Math.max(1, this.batchSizePerWorker / i3), this.numWorkers.intValue());
        } else {
            log.info("Repartitioning training data using SparkUtils repartitioner");
            repartitioned = SparkUtils.repartitionEqually(javaRDD, this.repartition, this.numWorkers.intValue());
        }
        int numPartitions = repartitioned.partitions().size();
        // Pair with logRepartitionStart() above. The Repartitioner branch never consults this.repartition, so
        // gating on Repartition.Never would drop the end timing for data that was in fact repartitioned.
        if (this.collectTrainingStats) {
            this.stats.logRepartitionEnd();
        }
        JavaSparkContext sparkContext = sparkDl4jMultiLayer != null ? sparkDl4jMultiLayer.getSparkContext() : sparkComputationGraph.getSparkContext();
        FlatMapFunction trainingFunction;
        if (dataSetLoader != null) {
            trainingFunction = new SharedFlatMapPaths(sparkDl4jMultiLayer != null ? m18getWorkerInstance(sparkDl4jMultiLayer) : m17getWorkerInstance(sparkComputationGraph), dataSetLoader, BroadcastHadoopConfigHolder.get(sparkContext));
        } else {
            trainingFunction = new SharedFlatMapPathsMDS(sparkDl4jMultiLayer != null ? m18getWorkerInstance(sparkDl4jMultiLayer) : m17getWorkerInstance(sparkComputationGraph), multiDataSetLoader, BroadcastHadoopConfigHolder.get(sparkContext));
        }
        processResults(sparkDl4jMultiLayer, sparkComputationGraph, repartitioned.mapPartitions(trainingFunction));
        if (this.collectTrainingStats) {
            this.stats.logMapPartitionsEnd(numPartitions);
        }
    }

    /** @return the configured training hooks (the stored list itself, not a copy) */
    public List<TrainingHook> getTrainingHooks() {
        return trainingHooks;
    }

    /** @return the parameter-server (void) configuration */
    public VoidConfiguration getVoidConfiguration() {
        return voidConfiguration;
    }

    /** @return the configured number of workers, possibly {@code null} */
    public Integer getNumWorkers() {
        return numWorkers;
    }

    /** @return the configured number of workers per node, possibly {@code null} */
    public Integer getNumWorkersPerNode() {
        return numWorkersPerNode;
    }

    /** @return the number of batches each worker prefetches */
    public int getWorkerPrefetchBatches() {
        return workerPrefetchBatches;
    }

    /** @return how RDDs are fed to training (Direct or Export) */
    public RDDTrainingApproach getRddTrainingApproach() {
        return rddTrainingApproach;
    }

    /** @return the Spark storage level setting */
    public StorageLevel getStorageLevel() {
        return storageLevel;
    }

    /** @return the custom repartitioner, or {@code null} to use SparkUtils repartitioning */
    public Repartitioner getRepartitioner() {
        return repartitioner;
    }

    /** @return whether training statistics are collected */
    public boolean isCollectTrainingStats() {
        return collectTrainingStats;
    }

    /** @return the configured number of examples per RDD DataSet object */
    public int getRddDataSetNumExamples() {
        return rddDataSetNumExamples;
    }

    /** @return the debugLongerIterations setting */
    public long getDebugLongerIterations() {
        return debugLongerIterations;
    }

    /** @return whether minibatch counts per executor are logged after each split */
    public boolean isLogMinibatchesPerWorker() {
        return logMinibatchesPerWorker;
    }

    /** @return the encodingDebugMode flag */
    public boolean isEncodingDebugMode() {
        return encodingDebugMode;
    }

    /** @return the current threshold algorithm (may be replaced after each processed split) */
    public ThresholdAlgorithm getThresholdAlgorithm() {
        return thresholdAlgorithm;
    }

    /** @return the residual post-processor */
    public ResidualPostProcessor getResidualPostProcessor() {
        return residualPostProcessor;
    }

    /** @return the repartition setting used by SparkUtils repartitioning */
    public Repartition getRepartition() {
        return repartition;
    }

    /** @return the repartition strategy */
    public RepartitionStrategy getRepartitionStrategy() {
        return repartitionStrategy;
    }

    /** @return the statistics helper, possibly {@code null} */
    public ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper getStats() {
        return stats;
    }

    /** @return the random number generator instance */
    public Random getRng() {
        return rng;
    }

    /** @return the live first-run flag object (not a copy) */
    public AtomicBoolean getIsFirstRun() {
        return isFirstRun;
    }

    /** @return this master's instance id */
    public int getInstanceId() {
        return instanceId;
    }

    /** @return the broadcast network tuple */
    public Broadcast<NetBroadcastTuple> getBroadcastModel() {
        return broadcastModel;
    }

    /** @return the broadcast shared-training configuration */
    public Broadcast<SharedTrainingConfiguration> getBroadcastConfiguration() {
        return broadcastConfiguration;
    }

    /** @return the transport */
    public Transport getTransport() {
        return transport;
    }

    /** @return the silent training driver, possibly {@code null} */
    public SilentTrainingDriver getTrainingDriver() {
        return trainingDriver;
    }

    /** @return the updates consumer, possibly {@code null} */
    public UpdatesConsumer getUpdatesConsumer() {
        return updatesConsumer;
    }

    /** @return whether setup has completed */
    public boolean isSetupDone() {
        return setupDone;
    }

    /** @param trainingHooks training hooks to store (kept by reference, not copied) */
    public void setTrainingHooks(List<TrainingHook> trainingHooks) {
        this.trainingHooks = trainingHooks;
    }

    /** @param voidConfiguration parameter-server (void) configuration */
    public void setVoidConfiguration(VoidConfiguration voidConfiguration) {
        this.voidConfiguration = voidConfiguration;
    }

    /** @param numWorkers number of workers */
    public void setNumWorkers(Integer numWorkers) {
        this.numWorkers = numWorkers;
    }

    /** @param numWorkersPerNode number of workers per node */
    public void setNumWorkersPerNode(Integer numWorkersPerNode) {
        this.numWorkersPerNode = numWorkersPerNode;
    }

    /** @param workerPrefetchBatches number of batches each worker prefetches */
    public void setWorkerPrefetchBatches(int workerPrefetchBatches) {
        this.workerPrefetchBatches = workerPrefetchBatches;
    }

    /** @param rddTrainingApproach how RDDs are fed to training */
    public void setRddTrainingApproach(RDDTrainingApproach rddTrainingApproach) {
        this.rddTrainingApproach = rddTrainingApproach;
    }

    /** @param storageLevel Spark storage level setting */
    public void setStorageLevel(StorageLevel storageLevel) {
        this.storageLevel = storageLevel;
    }

    /** @param repartitioner custom repartitioner, or {@code null} to use SparkUtils repartitioning */
    public void setRepartitioner(Repartitioner repartitioner) {
        this.repartitioner = repartitioner;
    }

    /** @param rddDataSetNumExamples number of examples per RDD DataSet object */
    public void setRddDataSetNumExamples(int rddDataSetNumExamples) {
        this.rddDataSetNumExamples = rddDataSetNumExamples;
    }

    /** @param debugLongerIterations debugLongerIterations setting */
    public void setDebugLongerIterations(long debugLongerIterations) {
        this.debugLongerIterations = debugLongerIterations;
    }

    /** @param logMinibatchesPerWorker whether to log minibatch counts per executor */
    public void setLogMinibatchesPerWorker(boolean logMinibatchesPerWorker) {
        this.logMinibatchesPerWorker = logMinibatchesPerWorker;
    }

    /** @param encodingDebugMode encodingDebugMode flag */
    public void setEncodingDebugMode(boolean encodingDebugMode) {
        this.encodingDebugMode = encodingDebugMode;
    }

    /** @param thresholdAlgorithm threshold algorithm to use */
    public void setThresholdAlgorithm(ThresholdAlgorithm thresholdAlgorithm) {
        this.thresholdAlgorithm = thresholdAlgorithm;
    }

    /** @param residualPostProcessor residual post-processor to use */
    public void setResidualPostProcessor(ResidualPostProcessor residualPostProcessor) {
        this.residualPostProcessor = residualPostProcessor;
    }

    /** @param repartition repartition setting used by SparkUtils repartitioning */
    public void setRepartition(Repartition repartition) {
        this.repartition = repartition;
    }

    /** @param repartitionStrategy repartition strategy */
    public void setRepartitionStrategy(RepartitionStrategy repartitionStrategy) {
        this.repartitionStrategy = repartitionStrategy;
    }

    /** @param stats statistics helper */
    public void setStats(ParameterAveragingTrainingMasterStats.ParameterAveragingTrainingMasterStatsHelper stats) {
        this.stats = stats;
    }

    /** @param rng random number generator */
    public void setRng(Random rng) {
        this.rng = rng;
    }

    /** @param isFirstRun first-run flag object (kept by reference) */
    public void setIsFirstRun(AtomicBoolean isFirstRun) {
        this.isFirstRun = isFirstRun;
    }

    /** @param broadcastModel broadcast network tuple */
    public void setBroadcastModel(Broadcast<NetBroadcastTuple> broadcastModel) {
        this.broadcastModel = broadcastModel;
    }

    /** @param broadcastConfiguration broadcast shared-training configuration */
    public void setBroadcastConfiguration(Broadcast<SharedTrainingConfiguration> broadcastConfiguration) {
        this.broadcastConfiguration = broadcastConfiguration;
    }

    /** @param transport transport to use */
    public void setTransport(Transport transport) {
        this.transport = transport;
    }

    /** @param trainingDriver silent training driver */
    public void setTrainingDriver(SilentTrainingDriver trainingDriver) {
        this.trainingDriver = trainingDriver;
    }

    /** @param updatesConsumer updates consumer */
    public void setUpdatesConsumer(UpdatesConsumer updatesConsumer) {
        this.updatesConsumer = updatesConsumer;
    }

    /** @param setupDone whether setup has completed */
    public void setSetupDone(boolean setupDone) {
        this.setupDone = setupDone;
    }

    /**
     * Lombok-style structural equality over the configuration and state fields.
     *
     * <p>Fields are compared in a fixed order and the comparison stops at the first mismatch. Object-typed fields
     * use null-safe {@code equals}. {@code instanceId}, the broadcasts, transport, training driver and updates
     * consumer are NOT compared. NOTE(review): {@code rng}, {@code isFirstRun} and {@code stats} are compared with
     * their own {@code equals}, which for {@link Random}/{@link AtomicBoolean} is identity-based.
     */
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof SharedTrainingMaster)) {
            return false;
        }
        SharedTrainingMaster that = (SharedTrainingMaster) obj;
        return that.canEqual(this)
                && fieldsEqual(getTrainingHooks(), that.getTrainingHooks())
                && fieldsEqual(getVoidConfiguration(), that.getVoidConfiguration())
                && fieldsEqual(getNumWorkers(), that.getNumWorkers())
                && fieldsEqual(getNumWorkersPerNode(), that.getNumWorkersPerNode())
                && getWorkerPrefetchBatches() == that.getWorkerPrefetchBatches()
                && fieldsEqual(getRddTrainingApproach(), that.getRddTrainingApproach())
                && fieldsEqual(getStorageLevel(), that.getStorageLevel())
                && fieldsEqual(getRepartitioner(), that.getRepartitioner())
                && isCollectTrainingStats() == that.isCollectTrainingStats()
                && getRddDataSetNumExamples() == that.getRddDataSetNumExamples()
                && getDebugLongerIterations() == that.getDebugLongerIterations()
                && isLogMinibatchesPerWorker() == that.isLogMinibatchesPerWorker()
                && isEncodingDebugMode() == that.isEncodingDebugMode()
                && fieldsEqual(getThresholdAlgorithm(), that.getThresholdAlgorithm())
                && fieldsEqual(getResidualPostProcessor(), that.getResidualPostProcessor())
                && fieldsEqual(getRepartition(), that.getRepartition())
                && fieldsEqual(getRepartitionStrategy(), that.getRepartitionStrategy())
                && fieldsEqual(getStats(), that.getStats())
                && fieldsEqual(getRng(), that.getRng())
                && fieldsEqual(getIsFirstRun(), that.getIsFirstRun())
                && isSetupDone() == that.isSetupDone();
    }

    /** Null-safe equality helper for {@link #equals(Object)}: both null, or {@code mine.equals(theirs)}. */
    private static boolean fieldsEqual(Object mine, Object theirs) {
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    /**
     * Symmetry hook used by {@code equals}: only other {@code SharedTrainingMaster} instances (including
     * subclasses) are eligible for equality. Returns {@code false} for {@code null}.
     */
    protected boolean canEqual(Object obj) {
        return SharedTrainingMaster.class.isInstance(obj);
    }

    /**
     * Lombok-style hash over the same fields, in the same order, as {@code equals}.
     *
     * <p>The multiplier is 59. A {@code null} object field contributes 43, and a boolean contributes 79 (true)
     * or 97 (false). The {@code long} field is folded as {@code (int) (v ^ (v >>> 32))}.
     */
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + hashOrMarker(getTrainingHooks());
        result = result * prime + hashOrMarker(getVoidConfiguration());
        result = result * prime + hashOrMarker(getNumWorkers());
        result = result * prime + hashOrMarker(getNumWorkersPerNode());
        result = result * prime + getWorkerPrefetchBatches();
        result = result * prime + hashOrMarker(getRddTrainingApproach());
        result = result * prime + hashOrMarker(getStorageLevel());
        result = result * prime + hashOrMarker(getRepartitioner());
        result = result * prime + flagHash(isCollectTrainingStats());
        result = result * prime + getRddDataSetNumExamples();
        long longerIterations = getDebugLongerIterations();
        result = result * prime + (int) (longerIterations ^ (longerIterations >>> 32));
        result = result * prime + flagHash(isLogMinibatchesPerWorker());
        result = result * prime + flagHash(isEncodingDebugMode());
        result = result * prime + hashOrMarker(getThresholdAlgorithm());
        result = result * prime + hashOrMarker(getResidualPostProcessor());
        result = result * prime + hashOrMarker(getRepartition());
        result = result * prime + hashOrMarker(getRepartitionStrategy());
        result = result * prime + hashOrMarker(getStats());
        result = result * prime + hashOrMarker(getRng());
        result = result * prime + hashOrMarker(getIsFirstRun());
        result = result * prime + flagHash(isSetupDone());
        return result;
    }

    /** Hash of {@code value}, or the marker 43 when it is {@code null}. */
    private static int hashOrMarker(Object value) {
        return value == null ? 43 : value.hashCode();
    }

    /** Boolean contribution used by {@link #hashCode()}: 79 for true, 97 for false. */
    private static int flagHash(boolean flag) {
        return flag ? 79 : 97;
    }

    /**
     * Renders every configuration and state field in Lombok {@code ClassName(field=value, ...)} form, including
     * the runtime-only fields (instance id, broadcasts, transport, driver, consumer).
     */
    public String toString() {
        StringBuilder text = new StringBuilder("SharedTrainingMaster(trainingHooks=");
        text.append(getTrainingHooks())
                .append(", voidConfiguration=").append(getVoidConfiguration())
                .append(", numWorkers=").append(getNumWorkers())
                .append(", numWorkersPerNode=").append(getNumWorkersPerNode())
                .append(", workerPrefetchBatches=").append(getWorkerPrefetchBatches())
                .append(", rddTrainingApproach=").append(getRddTrainingApproach())
                .append(", storageLevel=").append(getStorageLevel())
                .append(", repartitioner=").append(getRepartitioner())
                .append(", collectTrainingStats=").append(isCollectTrainingStats())
                .append(", rddDataSetNumExamples=").append(getRddDataSetNumExamples())
                .append(", debugLongerIterations=").append(getDebugLongerIterations())
                .append(", logMinibatchesPerWorker=").append(isLogMinibatchesPerWorker())
                .append(", encodingDebugMode=").append(isEncodingDebugMode())
                .append(", thresholdAlgorithm=").append(getThresholdAlgorithm())
                .append(", residualPostProcessor=").append(getResidualPostProcessor())
                .append(", repartition=").append(getRepartition())
                .append(", repartitionStrategy=").append(getRepartitionStrategy())
                .append(", stats=").append(getStats())
                .append(", rng=").append(getRng())
                .append(", isFirstRun=").append(getIsFirstRun())
                .append(", instanceId=").append(getInstanceId())
                .append(", broadcastModel=").append(getBroadcastModel())
                .append(", broadcastConfiguration=").append(getBroadcastConfiguration())
                .append(", transport=").append(getTransport())
                .append(", trainingDriver=").append(getTrainingDriver())
                .append(", updatesConsumer=").append(getUpdatesConsumer())
                .append(", setupDone=").append(isSetupDone())
                .append(")");
        return text.toString();
    }
}
