/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.common.tests;

import java.lang.management.ManagementFactory;
import java.lang.reflect.Method;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import org.bytedeco.javacpp.Pointer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.TestInfo;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig;
import org.slf4j.ILoggerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for ND4J JUnit 5 tests.
 *
 * <p>Before each test it configures ND4J from the overridable hooks below: profiling mode and
 * config, default data types, debug/verbose flags and the number of master threads. After each
 * test it does three things:
 * <ul>
 *   <li>destroys the current thread's workspaces;</li>
 *   <li>terminates the JVM if a workspace was left open;</li>
 *   <li>logs a one-line summary of duration, thread count and memory usage.</li>
 * </ul>
 */
public abstract class BaseND4JTest {
    private static final Logger log = LoggerFactory.getLogger(BaseND4JTest.class);

    /** Wall-clock time (ms since epoch) recorded at the end of {@link #beforeTest(TestInfo)}. */
    protected long startTime;
    /** Live JVM thread count recorded at the end of {@link #beforeTest(TestInfo)}. */
    protected int threadCountBefore;
    /** Default for {@link #numThreads()}: the number of processors available to the JVM. */
    private final int defaultThreads = Runtime.getRuntime().availableProcessors();
    /** Lazily initialised cache for {@link #isIntegrationTests()}; {@code null} until first queried. */
    protected Boolean integrationTest;

    /**
     * Returns the per-test timeout in milliseconds (default 90 seconds).
     * NOTE(review): not read within this class; presumably consumed by subclasses or test tooling.
     */
    public long getTimeoutMilliseconds() {
        return 90000L;
    }

    /** Returns the ND4J profiling mode applied before each test (default {@code SCOPE_PANIC}). */
    public OpExecutioner.ProfilingMode getProfilingMode() {
        return OpExecutioner.ProfilingMode.SCOPE_PANIC;
    }

    /** Returns the ND4J default data type set before each test (default {@code DOUBLE}). */
    public DataType getDataType() {
        return DataType.DOUBLE;
    }

    /** Returns the default floating-point data type set before each test; defaults to {@link #getDataType()}. */
    public DataType getDefaultFPDataType() {
        return getDataType();
    }

    /** Returns the number of ND4J master threads to configure; must be strictly positive. */
    public int numThreads() {
        return defaultThreads;
    }

    /**
     * Returns whether integration tests are enabled.
     *
     * <p>They are enabled when the {@code DL4J_INTEGRATION_TESTS} environment variable parses as
     * {@code true}. An unset variable means {@code false}. The result is cached in
     * {@link #integrationTest} after the first call.
     */
    public boolean isIntegrationTests() {
        if (integrationTest == null) {
            // Boolean.parseBoolean(null) is false, so an unset variable disables integration tests.
            integrationTest = Boolean.parseBoolean(System.getenv("DL4J_INTEGRATION_TESTS"));
        }
        return integrationTest;
    }

    /** Aborts (skips) the current test via a JUnit assumption unless {@link #isIntegrationTests()} is true. */
    public void skipUnlessIntegrationTests() {
        Assumptions.assumeTrue(isIntegrationTests(),
                "Skipping integration test - integration profile is not enabled");
    }

    /**
     * Configures ND4J for the upcoming test and records the start time and thread count.
     *
     * @param testInfo JUnit-supplied information about the test about to run
     * @throws IllegalStateException (via {@code Preconditions}) if {@link #numThreads()} is not positive
     */
    @BeforeEach
    public void beforeTest(TestInfo testInfo) {
        log.info("{}.{}", getClass().getSimpleName(), testMethodName(testInfo));
        System.setProperty("org.nd4j.log.initialization", "false");
        System.setProperty("org.nd4j.avx.ignore", "true");

        Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
        // Reset to a default profiler config (previously done twice; once is sufficient).
        Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
        Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType());
        Nd4j.getExecutioner().enableDebugMode(false);
        Nd4j.getExecutioner().enableVerboseMode(false);

        int numThreads = numThreads();
        Preconditions.checkState(numThreads > 0, "Number of threads must be > 0");
        if (numThreads != Nd4j.getEnvironment().maxMasterThreads()) {
            Nd4j.getEnvironment().setMaxMasterThreads(numThreads);
        }

        startTime = System.currentTimeMillis();
        threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
    }

    /**
     * Cleans up workspaces after a test and logs a resource-usage summary.
     *
     * <p>If a workspace is still current after all of this thread's workspaces have been
     * destroyed, it is treated as a leak and the JVM is terminated with exit status 1.
     *
     * @param testInfo JUnit-supplied information about the test that just ran
     */
    @AfterEach
    public void afterTest(TestInfo testInfo) {
        Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
        MemoryWorkspace currWS = Nd4j.getMemoryManager().getCurrentWorkspace();
        Nd4j.getMemoryManager().setCurrentWorkspace(null);
        if (currWS != null) {
            exitOnLeakedWorkspace(currWS);
        }

        long maxPhys = Pointer.maxPhysicalBytes();
        long maxBytes = Pointer.maxBytes();
        long currPhys = Pointer.physicalBytes();
        long currBytes = Pointer.totalBytes();
        long jvmTotal = Runtime.getRuntime().totalMemory();
        long jvmMax = Runtime.getRuntime().maxMemory();
        int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount();
        long duration = System.currentTimeMillis() - startTime;

        StringBuilder sb = new StringBuilder();
        sb.append(getClass().getSimpleName()).append(".").append(testMethodName(testInfo))
                .append(": ").append(duration).append(" ms")
                .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")")
                .append(", jvmTotal=").append(jvmTotal)
                .append(", jvmMax=").append(jvmMax)
                .append(", totalBytes=").append(currBytes)
                .append(", maxBytes=").append(maxBytes)
                .append(", currPhys=").append(currPhys)
                .append(", maxPhys=").append(maxPhys);
        appendWorkspaceInfo(sb);
        appendGpuInfo(sb);
        log.info(sb.toString());
    }

    /** Returns the test method's name, falling back to the display name when no method is present. */
    private static String testMethodName(TestInfo testInfo) {
        return testInfo.getTestMethod().map(Method::getName).orElse(testInfo.getDisplayName());
    }

    /**
     * Reports a workspace left open by a test, then terminates the JVM with status 1.
     * Before exiting it stops logback (if it is the SLF4J backend) so pending log output is flushed.
     */
    private static void exitOnLeakedWorkspace(MemoryWorkspace currWS) {
        log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}",
                currWS.getId(), currWS.isScopeActive(), currWS);
        System.out.println("Open workspace leaked from test! Exiting - " + currWS.getId()
                + ", isOpen = " + currWS.isScopeActive() + " - " + currWS);
        System.out.flush();
        // Short pauses give asynchronous log appenders a chance to write before the JVM exits.
        sleepQuietly(1000L);
        stopLogbackIfPresent();
        sleepQuietly(1000L);
        System.exit(1);
    }

    /**
     * Stops the logback {@code LoggerContext} reflectively when logback is the SLF4J backend.
     * Reflection avoids a compile-time dependency on logback. Failures are reported to
     * stderr and otherwise ignored so the caller still reaches {@code System.exit}.
     */
    private static void stopLogbackIfPresent() {
        ILoggerFactory lf = LoggerFactory.getILoggerFactory();
        if (!"ch.qos.logback.classic.LoggerContext".equals(lf.getClass().getName())) {
            return;
        }
        try {
            Method stop = lf.getClass().getMethod("stop");
            stop.setAccessible(true);
            stop.invoke(lf);
        } catch (ReflectiveOperationException | RuntimeException e) {
            System.err.println("Failed to stop logback LoggerContext: " + e);
        }
    }

    /** Sleeps for {@code millis} ms; if interrupted, returns early and restores the interrupt flag. */
    private static void sleepQuietly(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Appends the total current size of this thread's workspaces to {@code sb}.
     * Nothing is appended when there are no workspaces or their combined size is zero.
     */
    private static void appendWorkspaceInfo(StringBuilder sb) {
        List<MemoryWorkspace> workspaces = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
        if (workspaces == null || workspaces.isEmpty()) {
            return;
        }
        long currSize = 0L;
        for (MemoryWorkspace w : workspaces) {
            currSize += w.getCurrentSize();
        }
        if (currSize > 0L) {
            sb.append(", threadWSSize=").append(currSize).append(" (").append(workspaces.size()).append(" WSs)");
        }
    }

    /**
     * Appends free/total memory for each CUDA device to {@code sb}, when the executioner reports any.
     * The data is read from the {@code cuda.devicesInformation} environment entry.
     */
    private static void appendGpuInfo(StringBuilder sb) {
        Properties env = Nd4j.getExecutioner().getEnvironmentInformation();
        Object devices = env.get("cuda.devicesInformation");
        if (!(devices instanceof List)) {
            return;
        }
        List<?> gpus = (List<?>) devices;
        if (gpus.isEmpty()) {
            return;
        }
        sb.append(" [").append(gpus.size()).append(" GPUs: ");
        for (int i = 0; i < gpus.size(); i++) {
            // Assumes each entry is a Map keyed by "cuda.freeMemory"/"cuda.totalMemory" (as the original did).
            Map<?, ?> m = (Map<?, ?>) gpus.get(i);
            if (i > 0) {
                sb.append(",");
            }
            sb.append("(").append(m.get("cuda.freeMemory")).append(" free, ")
                    .append(m.get("cuda.totalMemory")).append(" total)");
        }
        sb.append("]");
    }
}

