package org.deeplearning4j.ui.module.train;

import java.text.DateFormat;
import java.text.DecimalFormat;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Supplier;
import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageEvent;
import org.deeplearning4j.api.storage.StatsStorageListener;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.ui.api.FunctionType;
import org.deeplearning4j.ui.api.HttpMethod;
import org.deeplearning4j.ui.api.I18N;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.i18n.I18NProvider;
import org.deeplearning4j.ui.i18n.I18NResource;
import org.deeplearning4j.ui.module.train.TrainModuleUtils;
import org.deeplearning4j.ui.stats.api.Histogram;
import org.deeplearning4j.ui.stats.api.StatsInitializationReport;
import org.deeplearning4j.ui.stats.api.StatsReport;
import org.deeplearning4j.ui.stats.api.StatsType;
import org.deeplearning4j.ui.views.html.training.TrainingModel;
import org.deeplearning4j.ui.views.html.training.TrainingOverview;
import org.deeplearning4j.ui.views.html.training.TrainingSystem;
import org.eclipse.collections.impl.list.mutable.primitive.LongArrayList;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import play.mvc.Result;
import play.mvc.Results;

/* loaded from: input_file:org/deeplearning4j/ui/module/train/TrainModule.class */
public class TrainModule implements UIModule {
    /** Value substituted for non-finite numbers (see {@code fixNaN}) before they are placed in chart series. */
    public static final double NAN_REPLACEMENT_VALUE = 0.0d;
    /** Chart point limit used when the system property is absent, unparseable, or below 10. */
    public static final int DEFAULT_MAX_CHART_POINTS = 512;

    /** Name of the system property the constructor reads to override the chart point limit. */
    @Deprecated
    public static final String CHART_MAX_POINTS_PROPERTY = "org.deeplearning4j.ui.maxChartPoints";
    /** Threshold above which update histories are subsampled before being charted; fixed at construction. */
    private final int maxChartPoints;
    /** Session whose data the UI currently serves; null until a session becomes known or is selected. */
    private String currentSessionID;
    /** Index (not ID) of the selected worker within the current session; reset to 0 by setSession. */
    private int currentWorkerIdx;
    private static final Logger log = LoggerFactory.getLogger(TrainModule.class);
    /** Two-decimal formatter used for the per-second figures in the overview performance table. */
    // NOTE(review): DecimalFormat is not thread-safe and this instance is shared statically - confirm handler threading.
    private static final DecimalFormat df2 = new DecimalFormat("#.00");
    /** Formatter for the "last update" cell of the overview performance table. */
    // NOTE(review): SimpleDateFormat is not thread-safe and this instance is shared statically - confirm handler threading.
    private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
    // Not referenced in the visible part of the file; presumably backs the asJson helper - TODO confirm.
    private static final ObjectMapper JSON = new ObjectMapper();
    // Empty placeholder result; its users are outside the visible part of the file.
    private static Triple<int[], float[], float[]> EMPTY_TRIPLE = new Triple<>(new int[0], new float[0], new float[0]);
    // Empty placeholder map; its users are outside the visible part of the file.
    private static final Map<String, Object> EMPTY_LR_MAP = new HashMap();
    /** Session ID -> storage holding that session, in insertion order. */
    private Map<String, StatsStorage> knownSessionIDs = Collections.synchronizedMap(new LinkedHashMap());
    /** Session ID -> counter supplying the next worker index to hand out (see getWorkerIdForIndex). */
    private Map<String, AtomicInteger> workerIdxCount = Collections.synchronizedMap(new HashMap());
    /** Session ID -> (worker index -> worker ID), filled lazily by getWorkerIdForIndex. */
    private Map<String, Map<Integer, String>> workerIdxToName = Collections.synchronizedMap(new HashMap());
    /** Session ID -> largest event timestamp seen by reportStorageEvents. */
    private Map<String, Long> lastUpdateForSession = Collections.synchronizedMap(new HashMap());

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/ui/module/train/TrainModule$MeanMagnitudes.class */
    /**
     * Mutable value holder bundling four pieces of per-iteration chart data: the iteration numbers
     * and three name-keyed series maps (ratios, parameter mean magnitudes, update mean magnitudes).
     * Equality, hashing and the string form cover all four properties and read them through the getters.
     */
    public static class MeanMagnitudes {
        private List<Integer> iterations;
        private Map<String, List<Double>> ratios;
        private Map<String, List<Double>> paramMM;
        private Map<String, List<Double>> updateMM;

        /**
         * @param list iteration numbers
         * @param map  ratio series keyed by name
         * @param map2 parameter mean-magnitude series keyed by name
         * @param map3 update mean-magnitude series keyed by name
         */
        public MeanMagnitudes(List<Integer> list, Map<String, List<Double>> map, Map<String, List<Double>> map2, Map<String, List<Double>> map3) {
            this.iterations = list;
            this.ratios = map;
            this.paramMM = map2;
            this.updateMM = map3;
        }

        // ---- accessors ----

        public List<Integer> getIterations() {
            return this.iterations;
        }

        public void setIterations(List<Integer> list) {
            this.iterations = list;
        }

        public Map<String, List<Double>> getRatios() {
            return this.ratios;
        }

        public void setRatios(Map<String, List<Double>> map) {
            this.ratios = map;
        }

        public Map<String, List<Double>> getParamMM() {
            return this.paramMM;
        }

        public void setParamMM(Map<String, List<Double>> map) {
            this.paramMM = map;
        }

        public Map<String, List<Double>> getUpdateMM() {
            return this.updateMM;
        }

        public void setUpdateMM(Map<String, List<Double>> map) {
            this.updateMM = map;
        }

        // ---- value semantics ----

        /** Null-tolerant equality: two nulls match, otherwise delegates to {@code left.equals(right)}. */
        private static boolean sameValue(Object left, Object right) {
            return left == null ? right == null : left.equals(right);
        }

        /**
         * Two instances are equal when the other side accepts this one via {@link #canEqual(Object)}
         * and all four properties match pairwise.
         */
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof MeanMagnitudes)) {
                return false;
            }
            final MeanMagnitudes that = (MeanMagnitudes) obj;
            // Short-circuiting keeps the property comparison order: iterations, ratios, paramMM, updateMM.
            return that.canEqual(this)
                    && sameValue(getIterations(), that.getIterations())
                    && sameValue(getRatios(), that.getRatios())
                    && sameValue(getParamMM(), that.getParamMM())
                    && sameValue(getUpdateMM(), that.getUpdateMM());
        }

        /** Hook letting a subclass refuse equality with instances of other types. */
        protected boolean canEqual(Object obj) {
            return obj instanceof MeanMagnitudes;
        }

        /** Folds the four properties with multiplier 59, starting from 1 and using 43 for a null property. */
        public int hashCode() {
            final Object[] properties = {getIterations(), getRatios(), getParamMM(), getUpdateMM()};
            int result = 1;
            for (Object property : properties) {
                result = (result * 59) + (property == null ? 43 : property.hashCode());
            }
            return result;
        }

        public String toString() {
            final StringBuilder text = new StringBuilder("TrainModule.MeanMagnitudes(iterations=");
            text.append(getIterations());
            text.append(", ratios=").append(getRatios());
            text.append(", paramMM=").append(getParamMM());
            text.append(", updateMM=").append(getUpdateMM());
            return text.append(")").toString();
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/deeplearning4j/ui/module/train/TrainModule$ModelType.class */
    /**
     * Kinds of model this module distinguishes. No use of this enum is visible in this part of the file.
     */
    public enum ModelType {
        // presumably a MultiLayerNetwork model - TODO confirm against usages
        MLN,
        // presumably a ComputationGraph model - TODO confirm against usages
        CG,
        // presumably a single-layer model described by a lone NeuralNetConfiguration - TODO confirm
        Layer
    }

    /**
     * Creates the module and fixes the chart point limit.
     *
     * <p>The limit is read from the system property named by {@link #CHART_MAX_POINTS_PROPERTY}.
     * A missing property, an unparseable value (logged as a warning) or a value below 10 all
     * result in {@link #DEFAULT_MAX_CHART_POINTS}.
     */
    public TrainModule() {
        final String configured = System.getProperty(CHART_MAX_POINTS_PROPERTY);
        int requested = DEFAULT_MAX_CHART_POINTS;
        if (configured != null) {
            try {
                requested = Integer.parseInt(configured);
            } catch (NumberFormatException e) {
                log.warn("Invalid system property: {} = {}", CHART_MAX_POINTS_PROPERTY, configured);
            }
        }
        // Anything under 10 points is treated as unusable and replaced by the default.
        this.maxChartPoints = requested >= 10 ? requested : DEFAULT_MAX_CHART_POINTS;
    }

    /**
     * Lists the storage type IDs this module handles.
     *
     * @return an immutable one-element list containing {@code "StatsListener"}
     */
    @Override // org.deeplearning4j.ui.api.UIModule
    public List<String> getCallbackTypeIDs() {
        final String statsListenerTypeId = "StatsListener";
        return Collections.singletonList(statsListenerTypeId);
    }

    /**
     * Builds the HTTP routes served by this module. All routes are GET. Page routes render a
     * template, data routes delegate to the private handlers of this class, and routes with a
     * {@code :name} path segment pass that segment to the handler as a string.
     *
     * @return a fixed-size list of the routes, in the order declared below
     */
    @Override // org.deeplearning4j.ui.api.UIModule
    public List<Route> getRoutes() {
        // Pages and their data endpoints.
        final Route root = new Route("/train", HttpMethod.GET, FunctionType.Supplier,
                (Supplier<Result>) () -> Results.redirect("/train/overview"));
        final Route overviewPage = new Route("/train/overview", HttpMethod.GET, FunctionType.Supplier,
                (Supplier<Result>) () -> Results.ok(TrainingOverview.apply(I18NProvider.getInstance())));
        final Route overviewData = new Route("/train/overview/data", HttpMethod.GET, FunctionType.Supplier,
                (Supplier<Result>) this::getOverviewData);
        final Route modelPage = new Route("/train/model", HttpMethod.GET, FunctionType.Supplier,
                (Supplier<Result>) () -> Results.ok(TrainingModel.apply(I18NProvider.getInstance())));
        final Route modelGraph = new Route("/train/model/graph", HttpMethod.GET, FunctionType.Supplier,
                (Supplier<Result>) this::getModelGraph);
        final Route modelData = new Route("/train/model/data/:layerId", HttpMethod.GET, FunctionType.Function,
                (Function<String, Result>) this::getModelData);
        final Route systemPage = new Route("/train/system", HttpMethod.GET, FunctionType.Supplier,
                (Supplier<Result>) () -> Results.ok(TrainingSystem.apply(I18NProvider.getInstance())));
        final Route systemData = new Route("/train/system/data", HttpMethod.GET, FunctionType.Supplier,
                (Supplier<Result>) this::getSystemData);

        // Session selection and inspection.
        final Route currentSession = new Route("/train/sessions/current", HttpMethod.GET, FunctionType.Supplier,
                (Supplier<Result>) () -> Results.ok(this.currentSessionID == null ? "" : this.currentSessionID));
        final Route allSessions = new Route("/train/sessions/all", HttpMethod.GET, FunctionType.Supplier,
                (Supplier<Result>) this::listSessions);
        final Route sessionDetails = new Route("/train/sessions/info", HttpMethod.GET, FunctionType.Supplier,
                (Supplier<Result>) this::sessionInfo);
        final Route selectSession = new Route("/train/sessions/set/:to", HttpMethod.GET, FunctionType.Function,
                (Function<String, Result>) this::setSession);
        final Route sessionLastUpdate = new Route("/train/sessions/lastUpdate/:sessionId", HttpMethod.GET, FunctionType.Function,
                (Function<String, Result>) this::getLastUpdateForSession);

        // Worker selection.
        final Route currentWorker = new Route("/train/workers/currentByIdx", HttpMethod.GET, FunctionType.Supplier,
                (Supplier<Result>) () -> Results.ok(String.valueOf(this.currentWorkerIdx)));
        final Route selectWorker = new Route("/train/workers/setByIdx/:to", HttpMethod.GET, FunctionType.Function,
                (Function<String, Result>) this::setWorkerByIdx);

        return Arrays.asList(
                root, overviewPage, overviewData,
                modelPage, modelGraph, modelData,
                systemPage, systemData,
                currentSession, allSessions, sessionDetails, selectSession, sessionLastUpdate,
                currentWorker, selectWorker);
    }

    /**
     * Processes a batch of storage events. Only events whose type ID is {@code "StatsListener"}
     * are considered: a {@code PostStaticInfo} event registers its session and storage, and every
     * such event advances the session's last-update timestamp if it is newer than the stored one.
     * Afterwards a default session is chosen if none is selected yet.
     *
     * @param collection the events to process
     */
    @Override // org.deeplearning4j.ui.api.UIModule
    public synchronized void reportStorageEvents(Collection<StatsStorageEvent> collection) {
        for (StatsStorageEvent event : collection) {
            if (!"StatsListener".equals(event.getTypeID())) {
                continue;
            }
            final String sessionId = event.getSessionID();
            if (event.getEventType() == StatsStorageListener.EventType.PostStaticInfo) {
                this.knownSessionIDs.put(sessionId, event.getStatsStorage());
            }
            // Keep only the most recent timestamp per session.
            final long timestamp = event.getTimestamp();
            final Long previous = this.lastUpdateForSession.get(sessionId);
            if (previous == null || timestamp > previous.longValue()) {
                this.lastUpdateForSession.put(sessionId, Long.valueOf(timestamp));
            }
        }
        if (this.currentSessionID == null) {
            getDefaultSession();
        }
    }

    /**
     * Registers every session of the attached storage that contains {@code "StatsListener"} data,
     * then picks a default session if none is selected yet.
     *
     * @param statsStorage the storage being attached
     */
    @Override // org.deeplearning4j.ui.api.UIModule
    public synchronized void onAttach(StatsStorage statsStorage) {
        for (String sessionId : statsStorage.listSessionIDs()) {
            for (Object typeId : statsStorage.listTypeIDsForSession(sessionId)) {
                if ("StatsListener".equals(typeId)) {
                    this.knownSessionIDs.put(sessionId, statsStorage);
                    // Registering once is enough; further matches would store the same mapping again.
                    break;
                }
            }
        }
        if (this.currentSessionID == null) {
            getDefaultSession();
        }
    }

    /**
     * Forgets every session that is backed by the given storage instance (identity comparison).
     *
     * <p>The removal goes through {@code values().removeIf(...)} rather than removing from the map
     * while looping over its key set: the latter throws {@link java.util.ConcurrentModificationException}
     * as soon as a matching entry is removed and iteration continues. On a
     * {@link Collections#synchronizedMap} view, {@code removeIf} also runs under the map's own
     * lock, so the scan and removal are atomic with respect to other map operations.
     *
     * @param statsStorage the storage being detached
     */
    @Override // org.deeplearning4j.ui.api.UIModule
    public void onDetach(StatsStorage statsStorage) {
        this.knownSessionIDs.values().removeIf(attached -> attached == statsStorage);
    }

    /**
     * Selects a session when none is selected: among the known sessions that have static info,
     * the one whose first static-info record carries the largest timestamp wins. Leaves the
     * selection untouched if a session is already selected or no candidate exists.
     */
    private void getDefaultSession() {
        if (this.currentSessionID != null) {
            return;
        }
        String newestSession = null;
        long newestTimestamp = Long.MIN_VALUE;
        for (Map.Entry<String, StatsStorage> known : this.knownSessionIDs.entrySet()) {
            final List<?> staticInfos = known.getValue().getAllStaticInfos(known.getKey(), "StatsListener");
            if (staticInfos == null || staticInfos.isEmpty()) {
                continue;
            }
            final long timestamp = ((Persistable) staticInfos.get(0)).getTimeStamp();
            if (timestamp > newestTimestamp) {
                newestTimestamp = timestamp;
                newestSession = known.getKey();
            }
        }
        if (newestSession != null) {
            this.currentSessionID = newestSession;
        }
    }

    /**
     * Resolves a worker index to a worker ID within the current session.
     *
     * <p>Indices are handed out lazily: on a miss, the session's worker IDs are fetched from its
     * storage, sorted, and every ID that has no index yet receives the next free one.
     *
     * @param i the worker index to resolve
     * @return the worker ID mapped to {@code i}, or {@code null} when no session is selected, the
     *         session's storage is no longer registered (for example after a detach), the storage
     *         reports no worker list, or no worker has that index
     */
    private synchronized String getWorkerIdForIndex(int i) {
        final String sessionId = this.currentSessionID;
        if (sessionId == null) {
            return null;
        }
        final Map<Integer, String> idxToWorker = this.workerIdxToName.computeIfAbsent(
                sessionId, key -> Collections.synchronizedMap(new HashMap<Integer, String>()));
        if (idxToWorker.containsKey(i)) {
            return idxToWorker.get(i);
        }

        // The selected session may have been removed by onDetach; treat that as "no worker".
        final StatsStorage storage = this.knownSessionIDs.get(sessionId);
        if (storage == null) {
            return null;
        }
        final List<String> workerIds = storage.listWorkerIDsForSessionAndType(sessionId, "StatsListener");
        if (workerIds == null) {
            return null;
        }

        final AtomicInteger nextIdx = this.workerIdxCount.computeIfAbsent(sessionId, key -> new AtomicInteger(0));
        // Sort a copy so that workers discovered together get indices in a stable order.
        final List<String> sortedIds = new ArrayList<>(workerIds);
        Collections.sort(sortedIds);
        for (String workerId : sortedIds) {
            if (!idxToWorker.containsValue(workerId)) {
                idxToWorker.put(nextIdx.getAndIncrement(), workerId);
            }
        }
        return idxToWorker.get(i);
    }

    /**
     * Handler for {@code /train/sessions/all}.
     *
     * @return a JSON response containing the IDs of all known sessions
     */
    private Result listSessions() {
        final Result response = Results.ok(asJson(this.knownSessionIDs.keySet())).as("application/json");
        return response;
    }

    /**
     * Handler for {@code /train/sessions/info}. For every known session it reports the worker
     * count (and worker IDs when there are any), the earliest static-info timestamp, the latest
     * update timestamp across workers, and the model type, layer count and parameter count taken
     * from the first static-info record. Values that are unavailable are reported as empty strings.
     *
     * @return a JSON response mapping session ID to its summary
     */
    private Result sessionInfo() {
        final Map<String, Object> summaryBySession = new HashMap<>();
        for (Map.Entry<String, StatsStorage> known : this.knownSessionIDs.entrySet()) {
            final String sessionId = known.getKey();
            final StatsStorage storage = known.getValue();
            final Map<String, Object> summary = new HashMap<>();

            final List<?> workerIds = storage.listWorkerIDsForSessionAndType(sessionId, "StatsListener");
            final int workerCount = workerIds == null ? 0 : workerIds.size();

            // Earliest static-info timestamp; stays at MAX_VALUE when there is none.
            final List<?> staticInfos = storage.getAllStaticInfos(sessionId, "StatsListener");
            long initTime = Long.MAX_VALUE;
            if (staticInfos != null) {
                for (Object staticInfo : staticInfos) {
                    initTime = Math.min(((Persistable) staticInfo).getTimeStamp(), initTime);
                }
            }

            // Latest update timestamp over all workers; stays at MIN_VALUE when there is none.
            long lastUpdate = Long.MIN_VALUE;
            for (Object update : storage.getLatestUpdateAllWorkers(sessionId, "StatsListener")) {
                lastUpdate = Math.max(lastUpdate, ((Persistable) update).getTimeStamp());
            }

            summary.put("numWorkers", Integer.valueOf(workerCount));
            summary.put("initTime", initTime == Long.MAX_VALUE ? "" : Long.valueOf(initTime));
            summary.put("lastUpdate", lastUpdate == Long.MIN_VALUE ? "" : Long.valueOf(lastUpdate));
            if (workerCount > 0) {
                summary.put("workers", workerIds);
            }

            final boolean hasStaticInfo = staticInfos != null && !staticInfos.isEmpty();
            if (hasStaticInfo) {
                final StatsInitializationReport initReport = (StatsInitializationReport) staticInfos.get(0);
                String modelType = initReport.getModelClassName();
                // Collapse the two well-known model classes to their simple names.
                if (modelType.endsWith("MultiLayerNetwork")) {
                    modelType = "MultiLayerNetwork";
                } else if (modelType.endsWith("ComputationGraph")) {
                    modelType = "ComputationGraph";
                }
                final int layerCount = initReport.getModelNumLayers();
                final long paramCount = initReport.getModelNumParams();
                summary.put("modelType", modelType);
                summary.put("numLayers", Integer.valueOf(layerCount));
                summary.put("numParams", Long.valueOf(paramCount));
            } else {
                summary.put("modelType", "");
                summary.put("numLayers", "");
                summary.put("numParams", "");
            }
            summaryBySession.put(sessionId, summary);
        }
        return Results.ok(asJson(summaryBySession)).as("application/json");
    }

    /**
     * Handler for {@code /train/sessions/set/:to}. Switches to the given session and resets the
     * selected worker index to 0.
     *
     * @param str the session ID to select
     * @return an empty OK response, or a bad-request response when the session ID is not known
     */
    private Result setSession(String str) {
        if (this.knownSessionIDs.containsKey(str)) {
            this.currentSessionID = str;
            this.currentWorkerIdx = 0;
            return Results.ok();
        }
        return Results.badRequest("Unknown session ID: " + str);
    }

    /**
     * Handler for {@code /train/sessions/lastUpdate/:sessionId}.
     *
     * @param str the session ID to look up
     * @return an OK response whose body is the recorded last-update timestamp, or {@code "-1"}
     *         when nothing has been recorded for that session
     */
    private Result getLastUpdateForSession(String str) {
        final Long lastUpdate = this.lastUpdateForSession.get(str);
        final String body = lastUpdate == null ? "-1" : String.valueOf(lastUpdate);
        return Results.ok(body);
    }

    /**
     * Handler for {@code /train/workers/setByIdx/:to}. Selects the worker with the given index.
     * A non-numeric index is logged at debug level and otherwise ignored; the response is OK
     * either way.
     *
     * @param str the worker index as decimal text
     * @return an empty OK response
     */
    private Result setWorkerByIdx(String str) {
        final int requestedIdx;
        try {
            requestedIdx = Integer.parseInt(str);
        } catch (NumberFormatException e) {
            log.debug("Invalid call to setWorkerByIdx", e);
            return Results.ok();
        }
        this.currentWorkerIdx = requestedIdx;
        return Results.ok();
    }

    /**
     * Replaces NaN and both infinities with {@link #NAN_REPLACEMENT_VALUE}.
     *
     * @param d the value to sanitise
     * @return {@code d} itself when it is finite, otherwise the replacement value
     */
    private static double fixNaN(double d) {
        if (Double.isNaN(d) || Double.isInfinite(d)) {
            return NAN_REPLACEMENT_VALUE;
        }
        return d;
    }

    /**
     * Rewrites the list in place as an arithmetic sequence that starts at the first element.
     * The step is the largest drop between two consecutive elements (previous minus next), but
     * never less than 1. An empty list is left untouched.
     *
     * @param list iteration counts to overwrite; must support {@code set}
     */
    private static void cleanLegacyIterationCounts(List<Integer> list) {
        final int count = list.size();
        if (count == 0) {
            return;
        }
        final int start = list.get(0).intValue();

        // Largest backwards jump between neighbours. When all values are equal every difference
        // is 0, so the step stays at its floor of 1 without needing a separate all-equal check.
        int step = 1;
        int previous = start;
        for (Integer boxed : list.subList(1, count)) {
            final int current = boxed.intValue();
            step = Math.max(step, previous - current);
            previous = current;
        }

        int value = start;
        for (int position = 0; position < count; position++) {
            list.set(position, Integer.valueOf(value));
            value += step;
        }
    }

    /**
     * Handler for {@code /train/overview/data}. Builds the JSON payload for the overview page of
     * the current session and worker:
     * <ul>
     *   <li>{@code updateTimestamp} - last recorded update time of the session, or -1</li>
     *   <li>{@code scoresIter} / {@code scores} - iteration numbers and scores</li>
     *   <li>{@code updateRatios} - update/parameter mean-magnitude ratio per weight parameter</li>
     *   <li>{@code stdevActivations} / {@code stdevGradients} / {@code stdevUpdates} - standard deviations</li>
     *   <li>{@code perf} / {@code model} - label/value tables</li>
     * </ul>
     * Non-finite numbers are replaced by {@link #NAN_REPLACEMENT_VALUE}. When there is no selected
     * session, no registered storage for it, no worker or no update, the series stay empty and the
     * table values stay blank.
     *
     * @return the JSON response
     */
    private Result getOverviewData() {
        final String sessionId = this.currentSessionID;
        Long lastUpdate = this.lastUpdateForSession.get(sessionId);
        if (lastUpdate == null) {
            lastUpdate = -1L;
        }
        final I18N i18n = I18NProvider.getInstance();

        // A session whose storage was detached is treated like "no data" instead of failing.
        final StatsStorage storage = sessionId == null ? null : this.knownSessionIDs.get(sessionId);
        final String workerId = storage == null ? null : getWorkerIdForIndex(this.currentWorkerIdx);
        boolean noData = storage == null || workerId == null;

        final List<Integer> scoresIter = new ArrayList<>();
        final List<Double> scores = new ArrayList<>();
        final Map<String, Object> result = new HashMap<>();
        result.put("updateTimestamp", lastUpdate);
        result.put("scores", scores);
        result.put("scoresIter", scoresIter);

        // Load the updates, subsampling by timestamp first when there are more than maxChartPoints.
        List<Persistable> updates = null;
        if (!noData) {
            final long[] updateTimes = storage.getAllUpdateTimes(sessionId, "StatsListener", workerId);
            if (updateTimes != null && updateTimes.length > this.maxChartPoints) {
                final int step = updateTimes.length / this.maxChartPoints;
                final LongArrayList sampledTimes = new LongArrayList(this.maxChartPoints + 2);
                int lastSampled = -1;
                for (int idx = 0; idx < updateTimes.length; idx += step) {
                    sampledTimes.add(updateTimes[idx]);
                    lastSampled = idx;
                }
                // Always include the most recent update.
                if (lastSampled != updateTimes.length - 1) {
                    sampledTimes.add(updateTimes[updateTimes.length - 1]);
                }
                updates = storage.getUpdates(sessionId, "StatsListener", workerId, sampledTimes.toArray());
            } else if (updateTimes != null) {
                updates = storage.getAllUpdatesAfter(sessionId, "StatsListener", workerId, 0L);
            }
        }
        if (updates == null || updates.isEmpty()) {
            noData = true;
        }

        final Map<String, List<Double>> updateRatios = new HashMap<>();
        final Map<String, List<Double>> stdevActivations = new HashMap<>();
        final Map<String, List<Double>> stdevGradients = new HashMap<>();
        final Map<String, List<Double>> stdevUpdates = new HashMap<>();
        result.put("updateRatios", updateRatios);
        result.put("stdevActivations", stdevActivations);
        result.put("stdevGradients", stdevGradients);
        result.put("stdevUpdates", stdevUpdates);

        // The first report decides which series exist.
        if (!noData && updates.get(0) instanceof StatsReport) {
            final StatsReport firstReport = (StatsReport) updates.get(0);
            overviewInitWeightSeries(firstReport.getMeanMagnitudes(StatsType.Parameters), updateRatios);
            overviewInitWeightSeries(firstReport.getStdev(StatsType.Gradients), stdevGradients);
            overviewInitWeightSeries(firstReport.getStdev(StatsType.Updates), stdevUpdates);
            final Map<String, Double> activationStdev = firstReport.getStdev(StatsType.Activations);
            if (activationStdev != null) {
                for (String name : activationStdev.keySet()) {
                    stdevActivations.put(name, new ArrayList<>());
                }
            }
        }

        StatsReport lastReport = null;
        boolean iterationsNeedCleaning = false;
        if (!noData) {
            final int updateCount = updates.size();
            final int stride = updateCount > this.maxChartPoints ? updateCount / this.maxChartPoints : 1;
            final int lastIdx = updateCount - 1;
            int previousIteration = -1;
            int idx = -1;
            for (Persistable update : updates) {
                idx++;
                if (!(update instanceof StatsReport)) {
                    continue;
                }
                lastReport = (StatsReport) update;
                final int iteration = lastReport.getIterationCount();
                // Iteration counts that do not strictly increase are rewritten after the loop.
                if (iteration <= previousIteration) {
                    iterationsNeedCleaning = true;
                }
                previousIteration = iteration;

                final boolean charted = stride <= 1 || idx % stride == 0 || idx == lastIdx;
                if (!charted) {
                    continue;
                }
                scoresIter.add(iteration);
                scores.add(fixNaN(lastReport.getScore()));

                final Map<String, Double> updateMM = lastReport.getMeanMagnitudes(StatsType.Updates);
                final Map<String, Double> paramMM = lastReport.getMeanMagnitudes(StatsType.Parameters);
                if (updateMM != null && paramMM != null && !updateMM.isEmpty() && !paramMM.isEmpty()) {
                    for (Map.Entry<String, List<Double>> series : updateRatios.entrySet()) {
                        final double updateMagnitude = updateMM.getOrDefault(series.getKey(), NAN_REPLACEMENT_VALUE);
                        final double paramMagnitude = paramMM.getOrDefault(series.getKey(), NAN_REPLACEMENT_VALUE);
                        // A 0/0 or x/0 ratio is non-finite and ends up as the replacement value.
                        series.getValue().add(fixNaN(updateMagnitude / paramMagnitude));
                    }
                }
                overviewAppendStdev(lastReport.getStdev(StatsType.Gradients), stdevGradients);
                overviewAppendStdev(lastReport.getStdev(StatsType.Updates), stdevUpdates);
                overviewAppendStdev(lastReport.getStdev(StatsType.Activations), stdevActivations);
            }
        }
        if (iterationsNeedCleaning) {
            cleanLegacyIterationCounts(scoresIter);
        }

        // Performance table: label in column 0, value in column 1.
        // The start-time and total-runtime rows are not populated by this method.
        final String[][] perf = {
            {i18n.getMessage("train.overview.perftable.startTime"), ""},
            {i18n.getMessage("train.overview.perftable.totalRuntime"), ""},
            {i18n.getMessage("train.overview.perftable.lastUpdate"), ""},
            {i18n.getMessage("train.overview.perftable.totalParamUpdates"), ""},
            {i18n.getMessage("train.overview.perftable.updatesPerSec"), ""},
            {i18n.getMessage("train.overview.perftable.examplesPerSec"), ""}
        };
        if (lastReport != null) {
            // NOTE(review): dateFormat and df2 are shared static formatters that are not
            // thread-safe - confirm whether handlers can run concurrently.
            perf[2][1] = dateFormat.format(new Date(lastReport.getTimeStamp()));
            perf[3][1] = String.valueOf(lastReport.getTotalMinibatches());
            perf[4][1] = df2.format(lastReport.getMinibatchesPerSecond());
            perf[5][1] = df2.format(lastReport.getExamplesPerSecond());
        }
        result.put("perf", perf);

        // Model table, same layout.
        final String[][] model = {
            {i18n.getMessage("train.overview.modeltable.modeltype"), ""},
            {i18n.getMessage("train.overview.modeltable.nLayers"), ""},
            {i18n.getMessage("train.overview.modeltable.nParams"), ""}
        };
        if (!noData) {
            final Persistable staticInfo = storage.getStaticInfo(sessionId, "StatsListener", workerId);
            if (staticInfo instanceof StatsInitializationReport) {
                final StatsInitializationReport initReport = (StatsInitializationReport) staticInfo;
                final String className = initReport.getModelClassName();
                final String modelType;
                if (className.endsWith("MultiLayerNetwork")) {
                    modelType = "MultiLayerNetwork";
                } else if (className.endsWith("ComputationGraph")) {
                    modelType = "ComputationGraph";
                } else {
                    // Any other model: show the simple class name.
                    final int lastDot = className.lastIndexOf('.');
                    modelType = lastDot > 0 ? className.substring(lastDot + 1) : className;
                }
                model[0][1] = modelType;
                model[1][1] = String.valueOf(initReport.getModelNumLayers());
                model[2][1] = String.valueOf(initReport.getModelNumParams());
            }
        }
        result.put("model", model);
        return Results.ok(asJson(result)).as("application/json");
    }

    /** Adds an empty series for every key of {@code stats} that ends with "w" (case-insensitive); null-safe. */
    private static void overviewInitWeightSeries(Map<String, Double> stats, Map<String, List<Double>> series) {
        if (stats == null) {
            return;
        }
        for (String name : stats.keySet()) {
            if (name.toLowerCase().endsWith("w")) {
                series.put(name, new ArrayList<>());
            }
        }
    }

    /** Appends one sanitised point per existing series, taking a missing key as the replacement value; null-safe. */
    private static void overviewAppendStdev(Map<String, Double> stdev, Map<String, List<Double>> series) {
        if (stdev == null) {
            return;
        }
        for (Map.Entry<String, List<Double>> entry : series.entrySet()) {
            entry.getValue().add(fixNaN(stdev.getOrDefault(entry.getKey(), NAN_REPLACEMENT_VALUE)));
        }
    }

    /**
     * Handler for {@code /train/model/graph}.
     *
     * @return a JSON response describing the model graph of the current session, or an empty OK
     *         response when no session is selected, its storage is no longer registered, it has
     *         no static info, or no graph description can be built
     */
    private Result getModelGraph() {
        final String sessionId = this.currentSessionID;
        // The storage is missing when nothing is selected or the session was detached meanwhile.
        final StatsStorage storage = sessionId == null ? null : this.knownSessionIDs.get(sessionId);
        if (storage == null) {
            return Results.ok();
        }
        final List<?> staticInfos = storage.getAllStaticInfos(sessionId, "StatsListener");
        if (staticInfos == null || staticInfos.isEmpty()) {
            return Results.ok();
        }
        final TrainModuleUtils.GraphInfo graphInfo = getGraphInfo();
        if (graphInfo == null) {
            return Results.ok();
        }
        return Results.ok(asJson(graphInfo)).as("application/json");
    }

    /**
     * Builds the graph description for the current session's model from whichever configuration
     * {@link #getConfig()} supplies, checking multi-layer, then computation-graph, then
     * single-layer configuration.
     *
     * @return the graph description, or {@code null} when no configuration is available
     */
    private TrainModuleUtils.GraphInfo getGraphInfo() {
        final Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> config = getConfig();
        if (config == null) {
            return null;
        }
        final MultiLayerConfiguration multiLayer = (MultiLayerConfiguration) config.getFirst();
        if (multiLayer != null) {
            return TrainModuleUtils.buildGraphInfo(multiLayer);
        }
        final ComputationGraphConfiguration graph = (ComputationGraphConfiguration) config.getSecond();
        if (graph != null) {
            return TrainModuleUtils.buildGraphInfo(graph);
        }
        final NeuralNetConfiguration singleLayer = (NeuralNetConfiguration) config.getThird();
        if (singleLayer != null) {
            return TrainModuleUtils.buildGraphInfo(singleLayer);
        }
        return null;
    }

    /**
     * Parses the model configuration of the current session from its first static-info record.
     * Exactly one slot of the returned triple is populated, chosen by the recorded model class
     * name: first for names ending in "MultiLayerNetwork", second for names ending in
     * "ComputationGraph", third (a single {@link NeuralNetConfiguration}) for anything else.
     *
     * @return the populated triple, or {@code null} when no session is selected, its storage is
     *         no longer registered, it has no static info, or the single-layer JSON cannot be parsed
     */
    private Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> getConfig() {
        final String sessionId = this.currentSessionID;
        if (sessionId == null) {
            return null;
        }
        // The selected session may have been removed by onDetach.
        final StatsStorage storage = this.knownSessionIDs.get(sessionId);
        if (storage == null) {
            return null;
        }
        final List<?> staticInfos = storage.getAllStaticInfos(sessionId, "StatsListener");
        if (staticInfos == null || staticInfos.isEmpty()) {
            return null;
        }

        final StatsInitializationReport initReport = (StatsInitializationReport) staticInfos.get(0);
        final String modelClass = initReport.getModelClassName();
        final String configJson = initReport.getModelConfigJson();
        if (modelClass.endsWith("MultiLayerNetwork")) {
            return new Triple<>(MultiLayerConfiguration.fromJson(configJson), null, null);
        }
        if (modelClass.endsWith("ComputationGraph")) {
            return new Triple<>(null, ComputationGraphConfiguration.fromJson(configJson), null);
        }
        try {
            final NeuralNetConfiguration layerConfig =
                    NeuralNetConfiguration.mapper().readValue(configJson, NeuralNetConfiguration.class);
            return new Triple<>(null, null, layerConfig);
        } catch (Exception e) {
            // Best effort: an unreadable configuration means "no config", but leave a trace in the log.
            log.warn("Could not parse model configuration for model class {}", modelClass, e);
            return null;
        }
    }

    /**
     * Builds the JSON payload for the model page for one layer/vertex of the current session.
     *
     * <p>The payload always contains {@code "updateTimestamp"}. When a model configuration and
     * graph info are available it additionally contains {@code "layerInfo"}, {@code "meanMag"},
     * {@code "activations"}, {@code "learningRates"}, {@code "paramHist"} and
     * {@code "updateHist"}. When more update times exist than {@code maxChartPoints}, the
     * updates are subsampled at a fixed stride and the most recent update is always included.
     *
     * @param str index of the selected layer/vertex as a decimal string
     * @return an OK result carrying the JSON payload with content type "application/json"
     * @throws NumberFormatException if {@code str} is not a parsable integer
     */
    private Result getModelData(String str) {
        Long lastUpdate = this.lastUpdateForSession.get(this.currentSessionID);
        if (lastUpdate == null) {
            lastUpdate = -1L;
        }
        int layerIdx = Integer.parseInt(str);
        I18N i18n = I18NProvider.getInstance();
        StatsStorage storage = this.currentSessionID == null ? null : this.knownSessionIDs.get(this.currentSessionID);
        String workerId = getWorkerIdForIndex(this.currentWorkerIdx);
        // Without a storage or a worker there is nothing to query; charts are rendered empty.
        boolean noData = storage == null || workerId == null;

        Map<String, Object> result = new HashMap<>();
        result.put("updateTimestamp", lastUpdate);

        Triple<MultiLayerConfiguration, ComputationGraphConfiguration, NeuralNetConfiguration> config = getConfig();
        TrainModuleUtils.GraphInfo graphInfo = config == null ? null : getGraphInfo();
        if (graphInfo == null) {
            return Results.ok(asJson(result)).as("application/json");
        }

        result.put("layerInfo", getLayerInfoTable(layerIdx, graphInfo, i18n, noData, storage, workerId));

        long[] allTimes = noData ? null : storage.getAllUpdateTimes(this.currentSessionID, "StatsListener", workerId);
        List<Persistable> updates = null;
        if (allTimes != null && this.maxChartPoints > 0 && allTimes.length > this.maxChartPoints) {
            // Subsample: take every step-th timestamp (step >= 1 here) and always keep the latest.
            int step = allTimes.length / this.maxChartPoints;
            LongArrayList sampledTimes = new LongArrayList(this.maxChartPoints + 2);
            int lastAdded = -1;
            for (int k = 0; k < allTimes.length; k += step) {
                sampledTimes.add(allTimes[k]);
                lastAdded = k;
            }
            if (lastAdded != allTimes.length - 1) {
                sampledTimes.add(allTimes[allTimes.length - 1]);
            }
            updates = storage.getUpdates(this.currentSessionID, "StatsListener", workerId, sampledTimes.toArray());
        } else if (allTimes != null) {
            updates = storage.getAllUpdatesAfter(this.currentSessionID, "StatsListener", workerId, 0L);
        }
        if (updates == null) {
            // No worker, no update times, or the storage returned nothing: proceed with an empty
            // list so the remaining sections are emitted empty instead of failing with an NPE.
            updates = Collections.emptyList();
        }

        // One entry per StatsReport in 'updates', in encounter order.
        List<Integer> iterationCounts = new ArrayList<>(updates.size());
        boolean hasLegacyCounts = false;
        for (Persistable update : updates) {
            if (update instanceof StatsReport) {
                int iterationCount = ((StatsReport) update).getIterationCount();
                if (iterationCount <= -1) {
                    hasLegacyCounts = true;
                }
                iterationCounts.add(Integer.valueOf(iterationCount));
            }
        }
        if (hasLegacyCounts) {
            cleanLegacyIterationCounts(iterationCounts);
        }

        ModelType modelType;
        if (config.getFirst() != null) {
            modelType = ModelType.MLN;
        } else if (config.getSecond() != null) {
            modelType = ModelType.CG;
        } else {
            modelType = ModelType.Layer;
        }

        MeanMagnitudes magnitudes = getLayerMeanMagnitudes(layerIdx, graphInfo, updates, iterationCounts, modelType);
        Map<String, Object> meanMag = new HashMap<>();
        meanMag.put("layerParamNames", magnitudes.getRatios().keySet());
        meanMag.put("iterCounts", magnitudes.getIterations());
        meanMag.put("ratios", magnitudes.getRatios());
        meanMag.put("paramMM", magnitudes.getParamMM());
        meanMag.put("updateMM", magnitudes.getUpdateMM());
        result.put("meanMag", meanMag);

        Triple<int[], float[], float[]> activationStats = getLayerActivations(layerIdx, graphInfo, updates, iterationCounts);
        Map<String, Object> activations = new HashMap<>();
        activations.put("iterCount", activationStats.getFirst());
        activations.put("mean", activationStats.getSecond());
        activations.put("stdev", activationStats.getThird());
        result.put("activations", activations);

        result.put("learningRates", getLayerLearningRates(layerIdx, graphInfo, updates, iterationCounts, modelType));

        // Histograms are taken from the most recent update only.
        Persistable latest = updates.isEmpty() ? null : updates.get(updates.size() - 1);
        result.put("paramHist", getHistograms(layerIdx, graphInfo, StatsType.Parameters, latest));
        result.put("updateHist", getHistograms(layerIdx, graphInfo, StatsType.Updates, latest));
        return Results.ok(asJson(result)).as("application/json");
    }

    /**
     * Builds the JSON payload for the system page of the current session.
     *
     * <p>The payload contains {@code "updateTimestamp"}, {@code "memory"} (from
     * {@code getMemory}) and {@code "hardware"} / {@code "software"} (first and second element
     * of the pair returned by {@code getHardwareSoftwareInfo}). Memory charts are built from the
     * updates newer than the latest known update timestamp minus 300000.
     *
     * @return an OK result carrying the JSON payload with content type "application/json"
     */
    public Result getSystemData() {
        Long lastUpdate = this.lastUpdateForSession.get(this.currentSessionID);
        if (lastUpdate == null) {
            lastUpdate = -1L;
        }
        I18N i18n = I18NProvider.getInstance();
        StatsStorage storage = this.currentSessionID == null ? null : this.knownSessionIDs.get(this.currentSessionID);

        List<Persistable> staticInfos = null;
        List<Persistable> latestUpdates = null;
        if (storage != null) {
            staticInfos = storage.getAllStaticInfos(this.currentSessionID, "StatsListener");
            latestUpdates = storage.getLatestUpdateAllWorkers(this.currentSessionID, "StatsListener");
        }
        if (staticInfos == null) {
            staticInfos = Collections.emptyList();
        }

        // Always hand getMemory a list: it iterates the updates for every known JVM, so a null
        // here would fail as soon as static info exists but no update has been posted yet.
        List<Persistable> recentUpdates = null;
        if (latestUpdates != null && !latestUpdates.isEmpty()) {
            long newestTimestamp = -1;
            for (Persistable update : latestUpdates) {
                newestTimestamp = Math.max(newestTimestamp, update.getTimeStamp());
            }
            // NOTE(review): 300000 is presumably milliseconds (a 5 minute window) - confirm the
            // unit of Persistable.getTimeStamp().
            final long chartWindow = 300000L;
            recentUpdates = storage.getAllUpdatesAfter(this.currentSessionID, "StatsListener", newestTimestamp - chartWindow);
        }
        if (recentUpdates == null) {
            recentUpdates = Collections.emptyList();
        }

        Map<String, Object> memory = getMemory(staticInfos, recentUpdates, i18n);
        Pair<Map<String, Object>, Map<String, Object>> hardwareSoftwareInfo = getHardwareSoftwareInfo(staticInfos, i18n);

        Map<String, Object> result = new HashMap<>();
        result.put("updateTimestamp", lastUpdate);
        result.put("memory", memory);
        result.put("hardware", hardwareSoftwareInfo.getFirst());
        result.put("software", hardwareSoftwareInfo.getSecond());
        return Results.ok(asJson(result)).as("application/json");
    }

    /**
     * Derives a display name for a layer from its class name.
     *
     * @param layer the layer configuration, may be {@code null}
     * @return the simple class name with a trailing "Layer" removed, or "n/a" when
     *         {@code layer} is {@code null} or the name cannot be determined
     */
    private static String getLayerType(Layer layer) {
        if (layer == null) {
            return "n/a";
        }
        try {
            String simpleName = layer.getClass().getSimpleName();
            return simpleName.replaceAll("Layer$", "");
        } catch (Exception ignored) {
            // Best effort only: fall back to the placeholder rather than failing the page.
            return "n/a";
        }
    }

    /**
     * Builds the two-column (label, value) table describing one layer/vertex.
     *
     * <p>The first two rows are always the vertex name and the layer type. When static info is
     * available for the worker, further rows are appended depending on the model class and the
     * layer class: VAE vertex info, nIn/size for feed-forward layers, parameter count, weight
     * init and updater for base layers, pooling type, kernel/stride/padding for convolution and
     * subsampling layers, and finally the activation function.
     *
     * @param i           index of the vertex in {@code graphInfo}
     * @param graphInfo   graph description supplying vertex names, types and info
     * @param i18n        source of the localized row labels
     * @param z           when {@code true}, the storage is not consulted and only the two
     *                    base rows are returned
     * @param statsStorage storage queried for the worker's static info (unused when {@code z})
     * @param str         worker id whose static info is read
     * @return the table rows, each a two-element array
     */
    private String[][] getLayerInfoTable(int i, TrainModuleUtils.GraphInfo graphInfo, I18N i18n, boolean z, StatsStorage statsStorage, String str) {
        List<String[]> rows = new ArrayList<>();
        rows.add(new String[]{i18n.getMessage("train.model.layerinfotable.layerName"), graphInfo.getVertexNames().get(i)});
        rows.add(new String[]{i18n.getMessage("train.model.layerinfotable.layerType"), ""});

        StatsInitializationReport report = null;
        if (!z) {
            report = (StatsInitializationReport) statsStorage.getStaticInfo(this.currentSessionID, "StatsListener", str);
        }
        if (report == null) {
            return rows.toArray(new String[rows.size()][0]);
        }

        String configJson = report.getModelConfigJson();
        String modelClass = report.getModelClassName();
        String layerType = "";
        Layer layer = null;
        NeuralNetConfiguration layerConf = null;

        if (modelClass.endsWith("MultiLayerNetwork")) {
            MultiLayerConfiguration mlnConf = MultiLayerConfiguration.fromJson(configJson);
            // Vertex 0 is the network input; layer configurations start at vertex 1.
            int confIdx = i - 1;
            if (confIdx < 0) {
                layerType = "Input";
            } else {
                layerConf = mlnConf.getConf(confIdx);
                layer = layerConf.getLayer();
            }
        } else if (modelClass.endsWith("ComputationGraph")) {
            ComputationGraphConfiguration cgConf = ComputationGraphConfiguration.fromJson(configJson);
            String vertexName = graphInfo.getVertexNames().get(i);
            Map<?, ?> vertices = cgConf.getVertices();
            Object vertex = vertices.get(vertexName);
            if (vertex instanceof LayerVertex) {
                layerConf = ((LayerVertex) vertex).getLayerConf();
                layer = layerConf.getLayer();
            } else if (cgConf.getNetworkInputs().contains(vertexName)) {
                layerType = "Input";
            } else if (vertex != null) {
                layerType = vertex.getClass().getSimpleName();
            }
        } else if (modelClass.endsWith("VariationalAutoencoder")) {
            layerType = graphInfo.getVertexTypes().get(i);
            Map<String, String> vertexInfo = graphInfo.getVertexInfo().get(i);
            for (Map.Entry<String, String> info : vertexInfo.entrySet()) {
                rows.add(new String[]{info.getKey(), info.getValue()});
            }
        }

        if (layer != null) {
            layerType = getLayerType(layer);
            String activationFn = null;

            if (layer instanceof FeedForwardLayer) {
                FeedForwardLayer ffLayer = (FeedForwardLayer) layer;
                rows.add(new String[]{i18n.getMessage("train.model.layerinfotable.layerNIn"), String.valueOf(ffLayer.getNIn())});
                rows.add(new String[]{i18n.getMessage("train.model.layerinfotable.layerSize"), String.valueOf(ffLayer.getNOut())});
            }

            if (layer instanceof BaseLayer) {
                BaseLayer baseLayer = (BaseLayer) layer;
                activationFn = baseLayer.getActivationFn().toString();
                long paramCount = layer.initializer().numParams(layerConf);
                rows.add(new String[]{i18n.getMessage("train.model.layerinfotable.layerNParams"), String.valueOf(paramCount)});
                if (paramCount > 0) {
                    WeightInit weightInit = baseLayer.getWeightInit();
                    String weightInitText = weightInit.toString();
                    if (weightInit == WeightInit.DISTRIBUTION) {
                        weightInitText = weightInitText + baseLayer.getDist();
                    }
                    rows.add(new String[]{i18n.getMessage("train.model.layerinfotable.layerWeightInit"), weightInitText});
                    IUpdater updater = baseLayer.getIUpdater();
                    String updaterText = updater == null ? "" : updater.getClass().getSimpleName();
                    rows.add(new String[]{i18n.getMessage("train.model.layerinfotable.layerUpdater"), updaterText});
                }
            }

            boolean isConvolution = layer instanceof ConvolutionLayer;
            if (isConvolution || layer instanceof SubsamplingLayer) {
                int[] kernel;
                int[] strides;
                int[] pad;
                if (isConvolution) {
                    ConvolutionLayer conv = (ConvolutionLayer) layer;
                    kernel = conv.getKernelSize();
                    strides = conv.getStride();
                    pad = conv.getPadding();
                } else {
                    SubsamplingLayer pooling = (SubsamplingLayer) layer;
                    kernel = pooling.getKernelSize();
                    strides = pooling.getStride();
                    pad = pooling.getPadding();
                    // No activation row is shown for subsampling layers.
                    activationFn = null;
                    rows.add(new String[]{i18n.getMessage("train.model.layerinfotable.layerSubsamplingPoolingType"), pooling.getPoolingType().toString()});
                }
                rows.add(new String[]{i18n.getMessage("train.model.layerinfotable.layerCnnKernel"), Arrays.toString(kernel)});
                rows.add(new String[]{i18n.getMessage("train.model.layerinfotable.layerCnnStride"), Arrays.toString(strides)});
                rows.add(new String[]{i18n.getMessage("train.model.layerinfotable.layerCnnPadding"), Arrays.toString(pad)});
            }

            if (activationFn != null) {
                rows.add(new String[]{i18n.getMessage("train.model.layerinfotable.layerActivationFn"), activationFn});
            }
        }

        // Fill in the layer-type value of the second row now that it is known.
        rows.get(1)[1] = layerType;
        return rows.toArray(new String[rows.size()][0]);
    }

    /**
     * Collects, per parameter of one layer, the mean magnitudes of parameters and updates and
     * their ratio (update / parameter) over the given updates.
     *
     * <p>Parameter keys are matched by prefix: the layer name for {@code ModelType.Layer},
     * otherwise the layer name followed by "_". The layer name is the vertex name for
     * {@code ModelType.CG} and the original vertex name for every other model type.
     * Non-finite magnitudes and non-finite ratios are reported as 0.
     *
     * @param i         index of the layer/vertex in {@code graphInfo}
     * @param graphInfo graph description; {@code null} yields an empty result
     * @param list      updates to scan; entries that are not {@link StatsReport}s are skipped
     * @param list2     optional iteration counts, one per {@link StatsReport} in {@code list}
     *                  (this is how {@code getModelData} builds it); when {@code null} or too
     *                  short the report's own iteration count is used
     * @param modelType kind of model the statistics belong to
     * @return iterations plus ratio / parameter / update magnitude series keyed by parameter
     *         name; empty for input vertices and out-of-range indices
     */
    private MeanMagnitudes getLayerMeanMagnitudes(int i, TrainModuleUtils.GraphInfo graphInfo, List<Persistable> list, List<Integer> list2, ModelType modelType) {
        List<Integer> iterations = new ArrayList<>();
        Map<String, List<Double>> ratios = new HashMap<>();
        Map<String, List<Double>> paramMagnitudes = new HashMap<>();
        Map<String, List<Double>> updateMagnitudes = new HashMap<>();
        if (graphInfo == null) {
            return new MeanMagnitudes(iterations, ratios, paramMagnitudes, updateMagnitudes);
        }

        List<String> layerNames = modelType == ModelType.CG ? graphInfo.getVertexNames() : graphInfo.getOriginalVertexName();
        List<String> vertexTypes = graphInfo.getVertexTypes();
        boolean outOfRange = i < 0 || i >= layerNames.size() || i >= vertexTypes.size();
        if (outOfRange || "input".equalsIgnoreCase(vertexTypes.get(i)) || list == null) {
            return new MeanMagnitudes(iterations, ratios, paramMagnitudes, updateMagnitudes);
        }

        String layerName = layerNames.get(i);
        String keyPrefix = modelType == ModelType.Layer ? layerName : layerName + "_";

        // Index among StatsReports only: list2 holds one entry per StatsReport, so counting
        // skipped non-report entries would read the wrong (or a non-existent) slot.
        int reportIdx = -1;
        for (Persistable update : list) {
            if (!(update instanceof StatsReport)) {
                continue;
            }
            reportIdx++;
            StatsReport report = (StatsReport) update;

            Integer iteration = (list2 != null && reportIdx < list2.size()) ? list2.get(reportIdx) : null;
            if (iteration == null) {
                iteration = Integer.valueOf(report.getIterationCount());
            }
            iterations.add(iteration);

            Map<String, Double> paramMM = report.getMeanMagnitudes(StatsType.Parameters);
            Map<String, Double> updateMM = report.getMeanMagnitudes(StatsType.Updates);
            if (paramMM == null) {
                // This report carries no parameter magnitudes; nothing to chart for it.
                continue;
            }
            for (Map.Entry<String, Double> entry : paramMM.entrySet()) {
                String key = entry.getKey();
                if (key == null || !key.startsWith(keyPrefix)) {
                    continue;
                }
                String paramName = key.substring(keyPrefix.length());

                Double rawParam = entry.getValue();
                Double rawUpdate = updateMM == null ? null : updateMM.get(key);
                double paramValue = rawParam == null ? NAN_REPLACEMENT_VALUE : rawParam.doubleValue();
                double updateValue = rawUpdate == null ? NAN_REPLACEMENT_VALUE : rawUpdate.doubleValue();
                if (!Double.isFinite(paramValue)) {
                    paramValue = 0.0d;
                }
                if (!Double.isFinite(updateValue)) {
                    updateValue = 0.0d;
                }

                double ratio = (updateValue == NAN_REPLACEMENT_VALUE && paramValue == NAN_REPLACEMENT_VALUE)
                        ? 0.0d
                        : updateValue / paramValue;
                if (!Double.isFinite(ratio)) {
                    // Division by a zero parameter magnitude: report 0 like the other
                    // non-finite values instead of emitting Infinity/NaN.
                    ratio = 0.0d;
                }

                ratios.computeIfAbsent(paramName, k -> new ArrayList<>()).add(Double.valueOf(ratio));
                paramMagnitudes.computeIfAbsent(paramName, k -> new ArrayList<>()).add(Double.valueOf(paramValue));
                updateMagnitudes.computeIfAbsent(paramName, k -> new ArrayList<>()).add(Double.valueOf(updateValue));
            }
        }
        return new MeanMagnitudes(iterations, ratios, paramMagnitudes, updateMagnitudes);
    }

    /**
     * Collects the activation mean and standard deviation of one layer over the given updates.
     *
     * <p>Only reports that contain an activation mean for the layer contribute a data point;
     * the returned arrays are trimmed to the number of data points. Non-finite or missing
     * values are reported as 0.
     *
     * @param i         index of the layer/vertex in {@code graphInfo}
     * @param graphInfo graph description; {@code null} yields {@code EMPTY_TRIPLE}
     * @param list      updates to scan; entries that are not {@link StatsReport}s are skipped
     * @param list2     optional iteration counts, one per {@link StatsReport} in {@code list}
     *                  (this is how {@code getModelData} builds it); when {@code null} or too
     *                  short the report's own iteration count is used
     * @return (iteration counts, means, standard deviations), or {@code EMPTY_TRIPLE} for input
     *         vertices and out-of-range indices
     */
    private Triple<int[], float[], float[]> getLayerActivations(int i, TrainModuleUtils.GraphInfo graphInfo, List<Persistable> list, List<Integer> list2) {
        if (graphInfo == null) {
            return EMPTY_TRIPLE;
        }
        List<String> layerNames = graphInfo.getOriginalVertexName();
        List<String> vertexTypes = graphInfo.getVertexTypes();
        // Validate the index before any lookup so a bad index yields the empty result.
        if (i < 0 || i >= layerNames.size() || i >= vertexTypes.size()) {
            return EMPTY_TRIPLE;
        }
        if ("input".equalsIgnoreCase(vertexTypes.get(i))) {
            return EMPTY_TRIPLE;
        }

        String layerName = layerNames.get(i);
        int capacity = list == null ? 0 : list.size();
        int[] iterCounts = new int[capacity];
        float[] means = new float[capacity];
        float[] stdevs = new float[capacity];
        int pointCount = 0;

        if (list != null) {
            // Index among StatsReports only, matching how list2 is populated.
            int reportIdx = -1;
            for (Persistable update : list) {
                if (!(update instanceof StatsReport)) {
                    continue;
                }
                reportIdx++;
                StatsReport report = (StatsReport) update;

                Map<String, Double> meanByLayer = report.getMean(StatsType.Activations);
                Map<String, Double> stdevByLayer = report.getStdev(StatsType.Activations);
                Double mean = meanByLayer == null ? null : meanByLayer.get(layerName);
                if (mean == null) {
                    // No activation statistics for this layer in this report.
                    continue;
                }
                Double stdev = stdevByLayer == null ? null : stdevByLayer.get(layerName);

                Integer iteration = (list2 != null && reportIdx < list2.size()) ? list2.get(reportIdx) : null;
                iterCounts[pointCount] = iteration != null ? iteration.intValue() : report.getIterationCount();

                float meanValue = mean.floatValue();
                float stdevValue = stdev == null ? 0.0f : stdev.floatValue();
                means[pointCount] = Float.isFinite(meanValue) ? meanValue : 0.0f;
                stdevs[pointCount] = Float.isFinite(stdevValue) ? stdevValue : 0.0f;
                pointCount++;
            }
        }

        if (pointCount != capacity) {
            iterCounts = Arrays.copyOf(iterCounts, pointCount);
            means = Arrays.copyOf(means, pointCount);
            stdevs = Arrays.copyOf(stdevs, pointCount);
        }
        return new Triple<>(iterCounts, means, stdevs);
    }

    /**
     * Collects the learning-rate series of every parameter of one layer over the given updates.
     *
     * <p>Learning-rate keys are matched by prefix: the original vertex name for
     * {@code ModelType.Layer}, otherwise that name followed by "_". The result maps
     * {@code "iterCounts"} to an {@code int[]}, {@code "paramNames"} to the sorted parameter
     * names and {@code "lrs"} to a map of parameter name to {@code float[]}. All arrays have one
     * slot per {@link StatsReport} in {@code list}.
     *
     * @param i         index of the layer/vertex in {@code graphInfo}
     * @param graphInfo graph description; {@code null} yields an empty map
     * @param list      updates to scan; entries that are not {@link StatsReport}s are skipped
     * @param list2     optional iteration counts, one per {@link StatsReport} in {@code list}
     *                  (this is how {@code getModelData} builds it); when {@code null} or too
     *                  short the report's own iteration count is used
     * @param modelType kind of model the statistics belong to
     * @return the learning-rate payload, or {@code EMPTY_LR_MAP} for input vertices and
     *         out-of-range indices
     */
    private Map<String, Object> getLayerLearningRates(int i, TrainModuleUtils.GraphInfo graphInfo, List<Persistable> list, List<Integer> list2, ModelType modelType) {
        if (graphInfo == null) {
            return Collections.emptyMap();
        }
        List<String> layerNames = graphInfo.getOriginalVertexName();
        List<String> vertexTypes = graphInfo.getVertexTypes();
        // Validate the index before any lookup so a bad index yields the empty payload.
        if (i < 0 || i >= layerNames.size() || i >= vertexTypes.size()) {
            return EMPTY_LR_MAP;
        }
        if ("input".equalsIgnoreCase(vertexTypes.get(i))) {
            return EMPTY_LR_MAP;
        }

        String layerName = layerNames.get(i);
        String keyPrefix = modelType == ModelType.Layer ? layerName : layerName + "_";
        final int capacity = list == null ? 0 : list.size();
        int[] iterCounts = new int[capacity];
        Map<String, float[]> ratesByParam = new HashMap<>();
        // Number of StatsReports seen so far; also the slot written for the current report and
        // the index into list2, which holds one entry per StatsReport.
        int reportCount = 0;

        if (list != null) {
            for (Persistable update : list) {
                if (!(update instanceof StatsReport)) {
                    continue;
                }
                StatsReport report = (StatsReport) update;

                Integer iteration = (list2 != null && reportCount < list2.size()) ? list2.get(reportCount) : null;
                iterCounts[reportCount] = iteration != null ? iteration.intValue() : report.getIterationCount();

                Map<String, Double> learningRates = report.getLearningRates();
                if (learningRates != null) {
                    for (Map.Entry<String, Double> entry : learningRates.entrySet()) {
                        String key = entry.getKey();
                        Double rate = entry.getValue();
                        if (key == null || rate == null || !key.startsWith(keyPrefix)) {
                            continue;
                        }
                        String paramName = key.substring(keyPrefix.length());
                        float[] series = ratesByParam.computeIfAbsent(paramName, k -> new float[capacity]);
                        series[reportCount] = rate.floatValue();
                    }
                }
                reportCount++;
            }
        }

        if (reportCount != capacity) {
            // Some entries were not StatsReports: drop the unused trailing slots so every series
            // has exactly one value per iteration count.
            iterCounts = Arrays.copyOf(iterCounts, reportCount);
            for (Map.Entry<String, float[]> entry : ratesByParam.entrySet()) {
                entry.setValue(Arrays.copyOf(entry.getValue(), reportCount));
            }
        }

        List<String> paramNames = new ArrayList<>(ratesByParam.keySet());
        Collections.sort(paramNames);

        Map<String, Object> result = new HashMap<>();
        result.put("iterCounts", iterCounts);
        result.put("paramNames", paramNames);
        result.put("lrs", ratesByParam);
        return result;
    }

    /**
     * Extracts the histograms of one layer's parameters from a single report.
     *
     * <p>Histogram keys starting with the layer's original vertex name are selected; the name
     * (and a directly following "_", if present) is stripped to obtain the parameter name.
     * Each parameter name maps to a map with {@code "min"}, {@code "max"}, {@code "bins"} and
     * {@code "counts"}; {@code "paramNames"} lists the parameter names in encounter order.
     * A NaN minimum causes both min and max to be reported as 0.
     *
     * @param i           index of the layer/vertex in {@code graphInfo}
     * @param graphInfo   graph description supplying the original vertex names
     * @param statsType   which histograms to read from the report
     * @param persistable the report to read; must be a {@link StatsReport} to produce output
     * @return the histogram payload, or {@code null} when {@code persistable} is {@code null}
     *         or not a {@link StatsReport}
     */
    private static Map<String, Object> getHistograms(int i, TrainModuleUtils.GraphInfo graphInfo, StatsType statsType, Persistable persistable) {
        if (!(persistable instanceof StatsReport)) {
            return null;
        }
        List<String> layerNames = graphInfo.getOriginalVertexName();
        // An out-of-range index is treated like a missing name: only "paramNames" is returned.
        String layerName = (i >= 0 && i < layerNames.size()) ? layerNames.get(i) : null;
        Map<String, Histogram> histograms = ((StatsReport) persistable).getHistograms(statsType);

        List<String> paramNames = new ArrayList<>();
        Map<String, Object> result = new HashMap<>();
        if (layerName != null && histograms != null) {
            int prefixLength = layerName.length();
            // NOTE(review): the match is a plain prefix test without the "_" separator, so a
            // layer named "1" would presumably also match keys of a layer named "10" - confirm
            // the key format before tightening this.
            for (Map.Entry<String, Histogram> entry : histograms.entrySet()) {
                String key = entry.getKey();
                Histogram histogram = entry.getValue();
                if (key == null || histogram == null || !key.startsWith(layerName)) {
                    continue;
                }
                // Guard the separator probe: a key equal to the layer name has no character at
                // prefixLength.
                boolean hasSeparator = key.length() > prefixLength && key.charAt(prefixLength) == '_';
                String paramName = key.substring(hasSeparator ? prefixLength + 1 : prefixLength);
                paramNames.add(paramName);

                double min = histogram.getMin();
                double max = histogram.getMax();
                if (Double.isNaN(min)) {
                    min = 0.0d;
                    max = 0.0d;
                }
                Map<String, Object> histogramData = new HashMap<>();
                histogramData.put("min", Double.valueOf(min));
                histogramData.put("max", Double.valueOf(max));
                histogramData.put("bins", Integer.valueOf(histogram.getNBins()));
                histogramData.put("counts", histogram.getBinCounts());
                result.put(paramName, histogramData);
            }
        }
        result.put("paramNames", paramNames);
        return result;
    }

    /**
     * Builds the memory-utilization chart data, one entry per JVM.
     *
     * <p>JVM ids are taken from the static info and sorted; for each JVM the alphabetically
     * first worker id represents it. The result maps the JVM's position ("0", "1", ...) to a
     * map with {@code "times"}, {@code "isDevice"}, {@code "seriesNames"}, {@code "values"}
     * (fraction used: JVM heap, off-heap, then one series per device), {@code "currentBytes"}
     * and {@code "maxBytes"} (values of the last matching report).
     *
     * @param list  static info of all workers; every entry is expected to be a
     *              {@link StatsInitializationReport}; {@code null} is treated as empty
     * @param list2 updates to chart; entries that are not {@link StatsReport}s or belong to
     *              other workers are skipped; {@code null} is treated as empty
     * @param i18n  source of the localized series names
     * @return the per-JVM chart data; empty when there is no static info
     */
    private static Map<String, Object> getMemory(List<Persistable> list, List<Persistable> list2, I18N i18n) {
        List<Persistable> staticInfos = list == null ? Collections.<Persistable>emptyList() : list;
        // A null update list simply means "no data points yet".
        List<Persistable> updates = list2 == null ? Collections.<Persistable>emptyList() : list2;

        HashSet<String> jvmIds = new HashSet<>();
        Map<String, String> jvmByWorker = new HashMap<>();
        Map<String, Integer> deviceCountByWorker = new HashMap<>();
        Map<String, String[]> deviceNamesByWorker = new HashMap<>();
        for (Persistable staticInfo : staticInfos) {
            StatsInitializationReport init = (StatsInitializationReport) staticInfo;
            String workerId = staticInfo.getWorkerID();
            String jvmId = init.getSwJvmUID();
            jvmByWorker.put(workerId, jvmId);
            jvmIds.add(jvmId);
            int deviceCount = init.getHwNumDevices();
            deviceCountByWorker.put(workerId, Integer.valueOf(deviceCount));
            if (deviceCount > 0) {
                deviceNamesByWorker.put(workerId, init.getHwDeviceDescription());
            }
        }

        List<String> sortedJvmIds = new ArrayList<>(jvmIds);
        Collections.sort(sortedJvmIds);

        Map<String, Object> result = new HashMap<>();
        int jvmIndex = 0;
        for (String jvmId : sortedJvmIds) {
            // All workers of one JVM share its memory; chart the alphabetically first one.
            List<String> workersOfJvm = new ArrayList<>();
            for (Map.Entry<String, String> entry : jvmByWorker.entrySet()) {
                if (jvmId.equals(entry.getValue())) {
                    workersOfJvm.add(entry.getKey());
                }
            }
            Collections.sort(workersOfJvm);
            String workerId = workersOfJvm.get(0);
            int deviceCount = deviceCountByWorker.get(workerId).intValue();

            List<Long> times = new ArrayList<>();
            List<Float> jvmFractions = new ArrayList<>();
            List<Float> offHeapFractions = new ArrayList<>();
            List<List<Float>> deviceFractions = new ArrayList<>(deviceCount);
            for (int d = 0; d < deviceCount; d++) {
                deviceFractions.add(new ArrayList<Float>());
            }
            // Slots 0 and 1 are JVM heap and off-heap; devices follow from slot 2.
            long[] currentBytes = new long[2 + deviceCount];
            long[] maxBytes = new long[2 + deviceCount];

            for (Persistable update : updates) {
                if (!(update instanceof StatsReport) || !workerId.equals(update.getWorkerID())) {
                    continue;
                }
                StatsReport report = (StatsReport) update;
                times.add(Long.valueOf(report.getTimeStamp()));

                long jvmCurrent = report.getJvmCurrentBytes();
                long jvmMax = report.getJvmMaxBytes();
                long offHeapCurrent = report.getOffHeapCurrentBytes();
                long offHeapMax = report.getOffHeapMaxBytes();
                jvmFractions.add(Float.valueOf((float) memoryFraction(jvmCurrent, jvmMax)));
                offHeapFractions.add(Float.valueOf((float) memoryFraction(offHeapCurrent, offHeapMax)));
                currentBytes[0] = jvmCurrent;
                currentBytes[1] = offHeapCurrent;
                maxBytes[0] = jvmMax;
                maxBytes[1] = offHeapMax;

                if (deviceCount > 0) {
                    long[] deviceCurrent = report.getDeviceCurrentBytes();
                    long[] deviceMax = report.getDeviceMaxBytes();
                    for (int d = 0; d < deviceCount; d++) {
                        // Missing device figures are charted as 0 so every series keeps one
                        // value per timestamp.
                        long current = (deviceCurrent != null && d < deviceCurrent.length) ? deviceCurrent[d] : 0L;
                        long max = (deviceMax != null && d < deviceMax.length) ? deviceMax[d] : 0L;
                        deviceFractions.get(d).add(Float.valueOf((float) memoryFraction(current, max)));
                        currentBytes[2 + d] = current;
                        maxBytes[2 + d] = max;
                    }
                }
            }

            List<List<Float>> values = new ArrayList<>();
            values.add(jvmFractions);
            values.add(offHeapFractions);
            values.addAll(deviceFractions);

            String[] seriesNames = new String[2 + deviceCount];
            seriesNames[0] = i18n.getMessage("train.system.hwTable.jvmCurrent");
            seriesNames[1] = i18n.getMessage("train.system.hwTable.offHeapCurrent");
            boolean[] isDevice = new boolean[2 + deviceCount];
            String[] deviceNames = deviceNamesByWorker.get(workerId);
            for (int d = 0; d < deviceCount; d++) {
                seriesNames[2 + d] = (deviceNames == null || deviceNames.length <= d) ? "" : deviceNames[d];
                isDevice[2 + d] = true;
            }

            Map<String, Object> jvmData = new HashMap<>();
            jvmData.put("times", times);
            jvmData.put("isDevice", isDevice);
            jvmData.put("seriesNames", seriesNames);
            jvmData.put("values", values);
            jvmData.put("currentBytes", currentBytes);
            jvmData.put("maxBytes", maxBytes);
            result.put(String.valueOf(jvmIndex), jvmData);
            jvmIndex++;
        }
        return result;
    }

    /**
     * Returns {@code current / max} as a floating-point fraction, or 0 when the quotient is not
     * finite (for example when {@code max} is 0).
     */
    private static double memoryFraction(long current, long max) {
        // Divide in double precision: long / long would truncate every fraction below 1 to 0
        // and throw ArithmeticException for a zero maximum.
        double fraction = current / (double) max;
        return Double.isFinite(fraction) ? fraction : 0.0d;
    }

    /* JADX WARN: Can't fix incorrect switch cases order, some code will duplicate */
    /* JADX WARN: Code restructure failed: missing block: B:47:0x02b5, code lost:
    
        switch(r26) {
            case 0: goto L45;
            case 1: goto L46;
            default: goto L47;
        };
     */
    /* JADX WARN: Code restructure failed: missing block: B:48:0x02d0, code lost:
    
        r24 = "CPU";
     */
    /* JADX WARN: Code restructure failed: missing block: B:49:0x02e4, code lost:
    
        r22 = r24;
     */
    /* JADX WARN: Code restructure failed: missing block: B:50:0x02d8, code lost:
    
        r24 = "CUDA";
     */
    /* JADX WARN: Code restructure failed: missing block: B:51:0x02e0, code lost:
    
        r24 = r0;
     */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    /**
     * Placeholder for the hardware/software info builder: the decompiler could not recover the
     * original body, so this version unconditionally throws
     * {@link UnsupportedOperationException}.
     *
     * <p>{@code getSystemData()} calls this method with the session's static infos and puts the
     * first element of the returned pair under "hardware" and the second under "software", so
     * that endpoint fails in this decompiled form until the real body is restored.
     *
     * <p>NOTE(review): the decompiler warnings above show fragments of a switch that assigns
     * "CPU" or "CUDA"; presumably the original mapped a backend identifier to a display name -
     * confirm against the original source or the bytecode dump.
     *
     * @param r7 static info records of all workers
     * @param r8 internationalization provider
     * @return never returns normally in this decompiled form
     * @throws UnsupportedOperationException always
     */
    private static org.nd4j.linalg.primitives.Pair<java.util.Map<java.lang.String, java.lang.Object>, java.util.Map<java.lang.String, java.lang.Object>> getHardwareSoftwareInfo(java.util.List<org.deeplearning4j.api.storage.Persistable> r7, org.deeplearning4j.ui.api.I18N r8) {
        /*
            Method dump skipped, instructions count: 1039
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: org.deeplearning4j.ui.module.train.TrainModule.getHardwareSoftwareInfo(java.util.List, org.deeplearning4j.ui.api.I18N):org.nd4j.linalg.primitives.Pair");
    }

    /**
     * Serializes an object to a JSON string using the class-level {@code JSON} serializer.
     *
     * @param obj the value to serialize
     * @return the JSON text
     * @throws RuntimeException wrapping any exception raised during serialization
     */
    private static final String asJson(Object obj) {
        final String serialized;
        try {
            serialized = JSON.writeValueAsString(obj);
        } catch (Exception cause) {
            throw new RuntimeException(cause);
        }
        return serialized;
    }

    /**
     * Lists the i18n resources of the training module: one resource per bundle ("train",
     * "train.model", "train.overview", "train.system") and language, grouped by bundle.
     *
     * @return a new mutable list of the resources
     */
    @Override // org.deeplearning4j.ui.api.UIModule
    public List<I18NResource> getInternationalizationResources() {
        final String[] languages = {"de", "en", "ja", "ko", "ru", "zh"};
        final String[] bundles = {"train", "train.model", "train.overview", "train.system"};
        List<I18NResource> resources = new ArrayList<>();
        for (String bundle : bundles) {
            addAll(resources, bundle, languages);
        }
        return resources;
    }

    /**
     * Appends one {@link I18NResource} per language to {@code list}, each pointing at
     * {@code "dl4j_i18n/<str>.<language>"}.
     *
     * @param list   destination list
     * @param str    bundle base name
     * @param strArr language codes, appended in the given order
     */
    private static void addAll(List<I18NResource> list, String str, String... strArr) {
        final String resourcePrefix = "dl4j_i18n/" + str + ".";
        for (int idx = 0; idx < strArr.length; idx++) {
            list.add(new I18NResource(resourcePrefix + strArr[idx]));
        }
    }

    static {
        EMPTY_LR_MAP.put("iterCounts", new int[0]);
        EMPTY_LR_MAP.put("paramNames", Collections.EMPTY_LIST);
        EMPTY_LR_MAP.put("lrs", Collections.EMPTY_MAP);
    }
}
