package org.deeplearning4j.ui;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import io.vertx.core.AbstractVerticle;
import io.vertx.core.Future;
import io.vertx.core.Promise;
import io.vertx.core.Vertx;
import io.vertx.core.http.HttpServer;
import io.vertx.core.http.impl.MimeMapping;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.RoutingContext;
import io.vertx.ext.web.handler.BodyHandler;
import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.io.FilenameUtils;
import org.deeplearning4j.common.util.ND4JFileUtils;
import org.deeplearning4j.config.DL4JClassLoading;
import org.deeplearning4j.core.storage.StatsStorage;
import org.deeplearning4j.core.storage.StatsStorageEvent;
import org.deeplearning4j.core.storage.StatsStorageListener;
import org.deeplearning4j.core.storage.StatsStorageRouter;
import org.deeplearning4j.exception.DL4JException;
import org.deeplearning4j.ui.api.Route;
import org.deeplearning4j.ui.api.UIModule;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.i18n.I18NProvider;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.deeplearning4j.ui.model.storage.impl.QueueStatsStorageListener;
import org.deeplearning4j.ui.module.SameDiffModule;
import org.deeplearning4j.ui.module.convolutional.ConvolutionalListenerModule;
import org.deeplearning4j.ui.module.defaultModule.DefaultModule;
import org.deeplearning4j.ui.module.remote.RemoteReceiverModule;
import org.deeplearning4j.ui.module.train.TrainModule;
import org.deeplearning4j.ui.module.tsne.TsneModule;
import org.nd4j.common.function.Function;
import org.nd4j.common.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Vert.x-based implementation of the Deeplearning4j UI server.
 *
 * <p>The current server is tracked in a static field: the constructor records itself as the instance,
 * and Vert.x constructs it when {@link #deploy(Promise)} deploys this verticle by class name.
 * StatsStorage instances attached via {@link #attach(StatsStorage)} push their events into a queue
 * through a {@link QueueStatsStorageListener}. A background daemon thread drains that queue and
 * forwards the events to every registered {@link UIModule} whose callback type IDs match.
 */
public class VertxUIServer extends AbstractVerticle implements UIServer {
    public static final int DEFAULT_UI_PORT = 9000;
    public static final String ASSETS_ROOT_DIRECTORY = "deeplearning4jUiAssets/";

    /** System property that, when non-empty, overrides the configured UI port. */
    private static final String UI_PORT_PROPERTY = "org.deeplearning4j.ui.port";
    /** URL prefix under which static assets are served. */
    private static final String ASSETS_URL_PREFIX = "/assets/";
    /** Resource root for assets served from webjars. */
    private static final String WEBJARS_RESOURCE_ROOT = "META-INF/resources/";

    private static final Logger log = LoggerFactory.getLogger(VertxUIServer.class);

    private static VertxUIServer instance;
    private static Function<String, StatsStorage> statsStorageProvider;
    private static Integer instancePort;
    private static Thread shutdownHook;
    private static final AtomicBoolean multiSession = new AtomicBoolean(false);

    private RemoteReceiverModule remoteReceiverModule;
    private Function<String, Boolean> statsStorageLoader;
    private HttpServer server;
    private Thread uiEventRoutingThread;
    private final List<UIModule> uiModules = new CopyOnWriteArrayList<>();
    private final Map<String, List<UIModule>> typeIDModuleMap = new ConcurrentHashMap<>();
    private final AtomicBoolean shutdown = new AtomicBoolean(false);
    /** Pause (ms) between event-routing passes; also the idle poll timeout of the routing thread. */
    private long uiProcessingDelay = 500;
    private final BlockingQueue<StatsStorageEvent> eventQueue = new LinkedBlockingQueue<>();
    private final List<Pair<StatsStorage, StatsStorageListener>> listeners = new CopyOnWriteArrayList<>();
    private final List<StatsStorage> statsStorageInstances = new CopyOnWriteArrayList<>();

    /** Command-line options parsed by {@link #main(String[])} via JCommander. */
    private static class CLIParams {

        @Parameter(names = {"-r", "--enableRemote"}, description = "Whether to enable remote or not", arity = 1)
        private boolean cliEnableRemote;

        @Parameter(names = {"-p", "--uiPort"}, description = "Custom HTTP port for UI", arity = 1)
        private int cliPort = VertxUIServer.DEFAULT_UI_PORT;

        @Parameter(names = {"-f", "--customStatsFile"}, description = "Path to create custom stats file (remote only)", arity = 1)
        private String cliCustomStatsFile;

        @Parameter(names = {"-m", "--multiSession"}, description = "Whether to enable multiple separate browser sessions or not", arity = 1)
        private boolean cliMultiSession;

        public boolean isCliEnableRemote() {
            return this.cliEnableRemote;
        }

        public int getCliPort() {
            return this.cliPort;
        }

        public String getCliCustomStatsFile() {
            return this.cliCustomStatsFile;
        }

        public boolean isCliMultiSession() {
            return this.cliMultiSession;
        }

        public void setCliEnableRemote(boolean enableRemote) {
            this.cliEnableRemote = enableRemote;
        }

        public void setCliPort(int port) {
            this.cliPort = port;
        }

        public void setCliCustomStatsFile(String customStatsFile) {
            this.cliCustomStatsFile = customStatsFile;
        }

        public void setCliMultiSession(boolean multiSessionEnabled) {
            this.cliMultiSession = multiSessionEnabled;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof CLIParams)) {
                return false;
            }
            CLIParams other = (CLIParams) obj;
            if (!other.canEqual(this)
                    || isCliEnableRemote() != other.isCliEnableRemote()
                    || getCliPort() != other.getCliPort()
                    || isCliMultiSession() != other.isCliMultiSession()) {
                return false;
            }
            String mine = getCliCustomStatsFile();
            String theirs = other.getCliCustomStatsFile();
            return mine == null ? theirs == null : mine.equals(theirs);
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof CLIParams;
        }

        @Override
        public int hashCode() {
            int result = 1;
            result = result * 59 + (isCliEnableRemote() ? 79 : 97);
            result = result * 59 + getCliPort();
            result = result * 59 + (isCliMultiSession() ? 79 : 97);
            String customStatsFile = getCliCustomStatsFile();
            return result * 59 + (customStatsFile == null ? 43 : customStatsFile.hashCode());
        }

        @Override
        public String toString() {
            return "VertxUIServer.CLIParams(cliEnableRemote=" + isCliEnableRemote() + ", cliPort=" + getCliPort() + ", cliCustomStatsFile=" + getCliCustomStatsFile() + ", cliMultiSession=" + isCliMultiSession() + ")";
        }
    }

    /**
     * Drains queued StatsStorage events and forwards them to the UI modules until the server is
     * shut down. The routing loop runs on the enclosing server's daemon thread.
     */
    private class StatsEventRouterRunnable implements Runnable {

        @Override
        public void run() {
            try {
                runHelper();
            } catch (Exception e) {
                log.error("Unexpected exception from Event routing runnable", e);
            }
        }

        private void runHelper() {
            log.trace("VertxUIServer.StatsEventRouterRunnable started");
            while (!shutdown.get()) {
                StatsStorageEvent first;
                try {
                    // Timed poll instead of take(): an idle thread re-checks the shutdown flag at least
                    // every uiProcessingDelay ms, so it terminates after stop() instead of blocking forever.
                    first = eventQueue.poll(uiProcessingDelay, TimeUnit.MILLISECONDS);
                } catch (InterruptedException e) {
                    onInterrupted(e);
                    continue;
                }
                if (first == null) {
                    continue;
                }
                List<StatsStorageEvent> events = new ArrayList<>();
                events.add(first);
                eventQueue.drainTo(events);
                dispatch(events);
                try {
                    Thread.sleep(uiProcessingDelay);
                } catch (InterruptedException e) {
                    onInterrupted(e);
                }
            }
        }

        /** Sends each module the subset of events matching its callback type IDs from attached storages. */
        private void dispatch(List<StatsStorageEvent> events) {
            for (UIModule module : uiModules) {
                List<String> callbackTypeIDs = module.getCallbackTypeIDs();
                List<StatsStorageEvent> relevant = new ArrayList<>();
                for (StatsStorageEvent event : events) {
                    if (callbackTypeIDs.contains(event.getTypeID())
                            && statsStorageInstances.contains(event.getStatsStorage())) {
                        relevant.add(event);
                    }
                }
                module.reportStorageEvents(relevant);
            }
        }

        /** Restores the interrupt flag; an interrupt is only expected once shutdown has been requested. */
        private void onInterrupted(InterruptedException e) {
            Thread.currentThread().interrupt();
            if (!shutdown.get()) {
                throw new RuntimeException("Unexpected interrupted exception", e);
            }
        }
    }

    /**
     * Returns the UI server, deploying a new one (and blocking until it has started) if none is running.
     *
     * @param num      port to use when a new server is deployed; {@code null} selects the default
     * @param z        whether the server runs in multi-session mode
     * @param function provider used to auto-attach StatsStorage by session ID, may be {@code null}
     * @return the server instance
     * @throws DL4JException if deployment fails, or a running server uses the other session mode
     */
    public static VertxUIServer getInstance(Integer num, boolean z, Function<String, StatsStorage> function) throws DL4JException {
        return getInstance(num, z, function, null);
    }

    /**
     * Returns the UI server, deploying a new one if there is none or the existing one was stopped.
     *
     * <p>When {@code promise} is non-null, deployment is asynchronous and the promise completes with the
     * deployment ID (or fails); otherwise this call blocks until the server has started.
     *
     * @throws DL4JException if blocking deployment fails, or a running server uses the other session mode
     */
    public static VertxUIServer getInstance(Integer num, boolean z, Function<String, StatsStorage> function, Promise<String> promise) throws DL4JException {
        if (instance == null || instance.isStopped()) {
            multiSession.set(z);
            setStatsStorageProvider(function);
            instancePort = num;
            if (promise != null) {
                deploy(promise);
            } else {
                deploy();
            }
        } else {
            if (z && !instance.isMultiSession()) {
                throw new DL4JException("Cannot return multi-session instance. UIServer has already started in single-session mode at " + instance.getAddress() + " You may stop the UI server instance, and start a new one.");
            }
            if (!z && instance.isMultiSession()) {
                throw new DL4JException("Cannot return single-session instance. UIServer has already started in multi-session mode at " + instance.getAddress() + " You may stop the UI server instance, and start a new one.");
            }
        }
        return instance;
    }

    /** Deploys the server and blocks until deployment has either succeeded or failed. */
    private static void deploy() throws DL4JException {
        CountDownLatch finished = new CountDownLatch(1);
        Promise<String> deployment = Promise.promise();
        deployment.future().compose(
                deploymentId -> Future.<Void>future(ignored -> finished.countDown()),
                failure -> Future.<Void>future(ignored -> finished.countDown()));
        deploy(deployment);
        try {
            finished.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DL4JException(e);
        }
        Future<String> result = deployment.future();
        if (result.failed()) {
            throw new DL4JException("Deeplearning4j UI server failed to start.", result.cause());
        }
    }

    /**
     * Deploys this verticle on a new Vert.x instance and installs a JVM shutdown hook that stops the server.
     *
     * @param promise completed with the deployment ID, or failed with the wrapped deployment error
     */
    private static void deploy(Promise<String> promise) {
        log.debug("Deeplearning4j UI server is starting.");
        Promise<String> deployment = Promise.promise();
        deployment.future().compose(
                deploymentId -> Future.<Void>future(ignored -> promise.complete(deploymentId)),
                failure -> Future.<Void>future(ignored -> promise.fail(new RuntimeException(failure))));
        Vertx.vertx().deployVerticle(VertxUIServer.class.getName(), deployment);
        registerShutdownHook();
    }

    /** Replaces any previously registered shutdown hook so repeated deployments do not accumulate hooks. */
    private static void registerShutdownHook() {
        if (shutdownHook != null) {
            Runtime.getRuntime().removeShutdownHook(shutdownHook);
        }
        shutdownHook = new Thread(() -> {
            if (instance == null || instance.isStopped()) {
                return;
            }
            log.info("Deeplearning4j UI server is auto-stopping in shutdown hook.");
            try {
                instance.stop();
            } catch (InterruptedException e) {
                log.error("Interrupted stopping of Deeplearning4j UI server in shutdown hook.", e);
            }
        });
        Runtime.getRuntime().addShutdownHook(shutdownHook);
    }

    /** Creates the server and records it as the current static instance. */
    public VertxUIServer() {
        instance = this;
    }

    /** Stops the current server (if running) and clears the static configuration. */
    public static void stopInstance() throws Exception {
        if (instance == null || instance.isStopped()) {
            return;
        }
        instance.stop();
        reset();
    }

    private static void reset() {
        instance = null;
        statsStorageProvider = null;
        instancePort = null;
        multiSession.set(false);
    }

    /**
     * Installs a loader that, given a session ID, obtains a StatsStorage from {@code function} and attaches it
     * if it contains that session. A {@code null} function leaves the current loader unchanged.
     */
    public void autoAttachStatsStorageBySessionId(Function<String, StatsStorage> function) {
        if (function == null) {
            return;
        }
        this.statsStorageLoader = sessionId -> {
            log.info("Loading StatsStorage via StatsStorageProvider for session ID (" + sessionId + ").");
            StatsStorage statsStorage = function.apply(sessionId);
            if (statsStorage == null) {
                log.info("Failed to load StatsStorage via StatsStorageProvider for session ID (" + sessionId + "). StatsStorageProvider returned null.");
                return false;
            }
            if (statsStorage.sessionExists(sessionId)) {
                attach(statsStorage);
                return true;
            }
            log.info("Failed to load StatsStorage via StatsStorageProvider for session ID. Session ID (" + sessionId + ") does not exist in StatsStorage.");
            return false;
        };
    }

    /**
     * Vert.x start callback. It does the following:
     * <ul>
     *   <li>registers routes for assets, language selection and all UI modules;</li>
     *   <li>starts the event routing thread;</li>
     *   <li>listens on the resolved port.</li>
     * </ul>
     * The promise completes once the server is listening, or fails if binding fails.
     *
     * @throws IllegalStateException if the resolved port is outside 0..65535 or a route uses an unsupported method
     */
    public void start(Promise<Void> promise) throws Exception {
        File uploadsDirectory = createUploadsDirectory();
        Router router = Router.router(this.vertx);
        router.route().handler(BodyHandler.create().setUploadsDirectory(uploadsDirectory.getAbsolutePath()));
        router.get(ASSETS_URL_PREFIX + "*").handler(VertxUIServer::serveAsset);
        registerLanguageRoute(router);
        if (statsStorageProvider != null) {
            autoAttachStatsStorageBySessionId(statsStorageProvider);
        }
        registerModules(router);

        int port = resolvePort();

        this.uiEventRoutingThread = new Thread(new StatsEventRouterRunnable());
        this.uiEventRoutingThread.setDaemon(true);
        this.uiEventRoutingThread.start();

        this.server = this.vertx.createHttpServer().requestHandler(router).listen(port, result -> {
            if (result.failed()) {
                // Report the requested port: a server that failed to bind has no meaningful actual port
                promise.fail(new RuntimeException("Deeplearning4j UI server failed to listen on port " + port, result.cause()));
                return;
            }
            log.info("Deeplearning4j UI server started at: {}", getAddress());
            promise.complete();
        });
    }

    /** Creates a fresh temp directory for uploaded request bodies, warning if it cannot be created. */
    private static File createUploadsDirectory() {
        File directory = new File(System.getProperty("java.io.tmpdir"), "DL4JUI_" + System.currentTimeMillis());
        if (!directory.mkdirs() && !directory.isDirectory()) {
            log.warn("Could not create upload directory for Deeplearning4j UI server: {}", directory.getAbsolutePath());
        }
        return directory;
    }

    /**
     * Serves a static asset. Paths containing a {@code webjars} segment map into
     * {@code META-INF/resources/}; everything else maps under {@link #ASSETS_ROOT_DIRECTORY}.
     * Requests with no file part or with {@code ..} segments are rejected with 404, so they cannot escape
     * those roots.
     */
    private static void serveAsset(RoutingContext routingContext) {
        String requestPath = routingContext.request().path();
        String relative = requestPath.length() > ASSETS_URL_PREFIX.length()
                ? requestPath.substring(ASSETS_URL_PREFIX.length())
                : "";
        if (relative.isEmpty() || hasParentDirectorySegment(relative)) {
            routingContext.response().setStatusCode(404).end();
            return;
        }
        String resource;
        int webjarsIndex = relative.indexOf("webjars");
        if (webjarsIndex >= 0) {
            resource = WEBJARS_RESOURCE_ROOT + relative.substring(webjarsIndex);
        } else {
            resource = ASSETS_ROOT_DIRECTORY + (relative.startsWith("/") ? relative.substring(1) : relative);
        }
        // MimeMapping returns null for unknown extensions; only set the header when a type is known
        String mimeType = MimeMapping.getMimeTypeForFilename(FilenameUtils.getName(resource));
        if (mimeType != null) {
            routingContext.response().putHeader("content-type", mimeType);
        }
        routingContext.response().sendFile(resource);
    }

    /** True if any '/'- or '\'-separated segment of {@code path} is "..". */
    private static boolean hasParentDirectorySegment(String path) {
        for (String segment : path.split("[/\\\\]")) {
            if ("..".equals(segment)) {
                return true;
            }
        }
        return false;
    }

    /** Registers the language-selection route; multi-session mode scopes the language to a session ID. */
    private void registerLanguageRoute(Router router) {
        if (isMultiSession()) {
            router.get("/setlang/:sessionId/:to").handler(routingContext -> {
                // Parameter name must match the ":sessionId" placeholder declared in the route
                String sessionId = routingContext.request().getParam("sessionId");
                I18NProvider.getInstance(sessionId).setDefaultLanguage(routingContext.request().getParam("to"));
                routingContext.response().end();
            });
        } else {
            router.get("/setlang/:to").handler(routingContext -> {
                I18NProvider.getInstance().setDefaultLanguage(routingContext.request().getParam("to"));
                routingContext.response().end();
            });
        }
    }

    /**
     * Adds the built-in modules plus any loaded via the service loader. It then registers each module's
     * HTTP routes and indexes the module by its callback type IDs.
     */
    private void registerModules(Router router) {
        this.uiModules.add(new DefaultModule(isMultiSession()));
        this.uiModules.add(new TrainModule());
        this.uiModules.add(new ConvolutionalListenerModule());
        this.uiModules.add(new TsneModule());
        this.uiModules.add(new SameDiffModule());
        this.remoteReceiverModule = new RemoteReceiverModule();
        this.uiModules.add(this.remoteReceiverModule);
        modulesViaServiceLoader(this.uiModules);
        for (UIModule module : this.uiModules) {
            for (Route route : module.getRoutes()) {
                registerRoute(router, route);
            }
            for (String typeId : module.getCallbackTypeIDs()) {
                this.typeIDModuleMap
                        .computeIfAbsent(typeId, key -> Collections.synchronizedList(new ArrayList<>()))
                        .add(module);
            }
        }
    }

    /** Binds one module route to the router for its HTTP method (GET, PUT or POST only). */
    private void registerRoute(Router router, Route route) {
        String path = route.getRoute();
        io.vertx.ext.web.Route vertxRoute;
        switch (route.getHttpMethod()) {
            case GET:
                vertxRoute = router.get(path);
                break;
            case PUT:
                vertxRoute = router.put(path);
                break;
            case POST:
                vertxRoute = router.post(path);
                break;
            default:
                throw new IllegalStateException("Unknown or not supported HTTP method: " + route.getHttpMethod());
        }
        vertxRoute.handler(routingContext ->
                route.getConsumer().accept(extractArgsFromRoute(path, routingContext), routingContext));
    }

    /**
     * Resolves the listening port. A non-empty {@code org.deeplearning4j.ui.port} system property overrides the
     * configured port, which defaults to {@link #DEFAULT_UI_PORT}; an unparseable property is logged and ignored.
     *
     * @throws IllegalStateException if the port is outside 0..65535
     */
    private static int resolvePort() {
        int port = instancePort == null ? DEFAULT_UI_PORT : instancePort;
        String property = System.getProperty(UI_PORT_PROPERTY);
        if (property != null && !property.isEmpty()) {
            try {
                port = Integer.parseInt(property);
            } catch (NumberFormatException e) {
                log.warn("Error parsing port property {}={}", UI_PORT_PROPERTY, property);
            }
        }
        if (port < 0 || port > 65535) {
            throw new IllegalStateException("Valid port range is 0 <= port <= 65535. The given port was " + port);
        }
        return port;
    }

    /**
     * Returns the values of the ':name' placeholders of {@code str}, in order, read from the request.
     * Returns an empty list when the route has no placeholders.
     */
    private List<String> extractArgsFromRoute(String str, RoutingContext routingContext) {
        if (!str.contains(":")) {
            return Collections.emptyList();
        }
        List<String> args = new ArrayList<>();
        for (String segment : str.split("/")) {
            if (segment.startsWith(":")) {
                args.add(routingContext.request().getParam(segment.substring(1)));
            }
        }
        return args;
    }

    /** Appends service-loaded UI modules whose class is not already present in {@code list}. */
    private void modulesViaServiceLoader(List<UIModule> list) {
        Iterator<?> loaded = DL4JClassLoading.loadService(UIModule.class).iterator();
        while (loaded.hasNext()) {
            UIModule module = (UIModule) loaded.next();
            Class<?> moduleClass = module.getClass();
            boolean alreadyPresent = list.stream().anyMatch(existing -> existing.getClass() == moduleClass);
            if (!alreadyPresent) {
                log.debug("Loaded UI module via service loader: {}", moduleClass);
                list.add(module);
            }
        }
    }

    /** Stops the server and blocks until the underlying Vert.x instance has closed (successfully or not). */
    @Override // org.deeplearning4j.ui.api.UIServer
    public void stop() throws InterruptedException {
        CountDownLatch closed = new CountDownLatch(1);
        Promise<Void> promise = Promise.promise();
        promise.future().compose(
                ignoredResult -> Future.<Void>future(ignored -> closed.countDown()),
                failure -> Future.<Void>future(ignored -> closed.countDown()));
        stopAsync(promise);
        closed.await();
    }

    /** Closes the Vert.x instance, completing {@code promise} with the close result. */
    @Override // org.deeplearning4j.ui.api.UIServer
    public void stopAsync(Promise<Void> promise) {
        this.vertx.close(result -> promise.handle(result));
    }

    /** Vert.x stop callback: flags shutdown so the event routing thread exits its loop. */
    public void stop(Promise<Void> promise) {
        this.shutdown.set(true);
        promise.complete();
        log.info("Deeplearning4j UI server stopped.");
    }

    @Override // org.deeplearning4j.ui.api.UIServer
    public boolean isStopped() {
        return this.shutdown.get();
    }

    @Override // org.deeplearning4j.ui.api.UIServer
    public boolean isMultiSession() {
        return multiSession.get();
    }

    @Override // org.deeplearning4j.ui.api.UIServer
    public String getAddress() {
        return "http://localhost:" + this.server.actualPort();
    }

    @Override // org.deeplearning4j.ui.api.UIServer
    public int getPort() {
        return this.server.actualPort();
    }

    /**
     * Attaches a StatsStorage: its events are queued for routing and every module is notified.
     * Attaching an already attached storage is a no-op.
     *
     * @throws IllegalArgumentException if {@code statsStorage} is null
     */
    @Override // org.deeplearning4j.ui.api.UIServer
    public void attach(StatsStorage statsStorage) {
        if (statsStorage == null) {
            throw new IllegalArgumentException("StatsStorage cannot be null");
        }
        if (this.statsStorageInstances.contains(statsStorage)) {
            return;
        }
        QueueStatsStorageListener listener = new QueueStatsStorageListener(this.eventQueue);
        this.listeners.add(new Pair<>(statsStorage, listener));
        statsStorage.registerStatsStorageListener(listener);
        this.statsStorageInstances.add(statsStorage);
        for (UIModule module : this.uiModules) {
            module.onAttach(statsStorage);
        }
        log.info("StatsStorage instance attached to UI: {}", statsStorage);
    }

    /**
     * Detaches a StatsStorage. This deregisters its listeners, notifies the modules and removes the
     * I18N instances of its sessions. Detaching a storage that is not attached is a no-op.
     *
     * @throws IllegalArgumentException if {@code statsStorage} is null
     */
    @Override // org.deeplearning4j.ui.api.UIServer
    public void detach(StatsStorage statsStorage) {
        if (statsStorage == null) {
            throw new IllegalArgumentException("StatsStorage cannot be null");
        }
        if (!this.statsStorageInstances.contains(statsStorage)) {
            return;
        }
        boolean deregistered = false;
        // CopyOnWriteArrayList iterates over a snapshot, so removing while iterating is safe
        for (Pair<StatsStorage, StatsStorageListener> pair : this.listeners) {
            if (pair.getFirst() == statsStorage) {
                statsStorage.deregisterStatsStorageListener(pair.getSecond());
                this.listeners.remove(pair);
                deregistered = true;
            }
        }
        this.statsStorageInstances.remove(statsStorage);
        for (UIModule module : this.uiModules) {
            module.onDetach(statsStorage);
        }
        for (Object sessionId : statsStorage.listSessionIDs()) {
            I18NProvider.removeInstance((String) sessionId);
        }
        if (deregistered) {
            log.info("StatsStorage instance detached from UI: {}", statsStorage);
        }
    }

    @Override // org.deeplearning4j.ui.api.UIServer
    public boolean isAttached(StatsStorage statsStorage) {
        return this.statsStorageInstances.contains(statsStorage);
    }

    /** Returns a snapshot copy of the attached StatsStorage instances. */
    @Override // org.deeplearning4j.ui.api.UIServer
    public List<StatsStorage> getStatsStorageInstances() {
        return new ArrayList<>(this.statsStorageInstances);
    }

    /** Enables the remote listener backed by a new in-memory storage, unless it is already enabled. */
    @Override // org.deeplearning4j.ui.api.UIServer
    public void enableRemoteListener() {
        if (this.remoteReceiverModule == null) {
            this.remoteReceiverModule = new RemoteReceiverModule();
        }
        if (this.remoteReceiverModule.isEnabled()) {
            return;
        }
        enableRemoteListener(new InMemoryStatsStorage(), true);
    }

    /**
     * Enables the remote listener, routing received stats to {@code statsStorageRouter}.
     *
     * @param z when true and the router is also a StatsStorage, it is attached to the UI
     */
    @Override // org.deeplearning4j.ui.api.UIServer
    public void enableRemoteListener(StatsStorageRouter statsStorageRouter, boolean z) {
        if (this.remoteReceiverModule == null) {
            this.remoteReceiverModule = new RemoteReceiverModule();
        }
        this.remoteReceiverModule.setEnabled(true);
        this.remoteReceiverModule.setStatsStorage(statsStorageRouter);
        if (z && (statsStorageRouter instanceof StatsStorage)) {
            attach((StatsStorage) statsStorageRouter);
        }
    }

    @Override // org.deeplearning4j.ui.api.UIServer
    public void disableRemoteListener() {
        if (this.remoteReceiverModule != null) {
            this.remoteReceiverModule.setEnabled(false);
        }
    }

    @Override // org.deeplearning4j.ui.api.UIServer
    public boolean isRemoteListenerEnabled() {
        return this.remoteReceiverModule != null && this.remoteReceiverModule.isEnabled();
    }

    /**
     * Command-line launcher. It starts (or reuses) the UI server on the requested port and session mode.
     * When remote support is requested, it enables the remote listener on that running server, backed by
     * a temporary file storage.
     */
    public void main(String[] strArr) {
        CLIParams params = new CLIParams();
        new JCommander(params).parse(strArr);
        // Pass the port explicitly so it is not overwritten by a port-less getInstance overload
        VertxUIServer uiServer = getInstance(params.getCliPort(), params.isCliMultiSession(), null);
        if (params.isCliEnableRemote()) {
            try {
                File tempStatsFile = ND4JFileUtils.createTempFile("dl4j", "UIstats");
                // Only the unique path is needed; the file is deleted so the storage can create it itself
                tempStatsFile.delete();
                tempStatsFile.deleteOnExit();
                uiServer.enableRemoteListener(new FileStatsStorage(tempStatsFile), true);
            } catch (Exception e) {
                log.error("Failed to create temporary file for stats storage", e);
                System.exit(1);
            }
        }
    }

    /** Returns the current server instance, or {@code null} if none has been created or it was reset. */
    public static VertxUIServer getInstance() {
        return instance;
    }

    public static AtomicBoolean getMultiSession() {
        return multiSession;
    }

    public static Function<String, StatsStorage> getStatsStorageProvider() {
        return statsStorageProvider;
    }

    public static void setStatsStorageProvider(Function<String, StatsStorage> function) {
        statsStorageProvider = function;
    }

    public static Thread getShutdownHook() {
        return shutdownHook;
    }

    public Function<String, Boolean> getStatsStorageLoader() {
        return this.statsStorageLoader;
    }
}
