/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.allocator.context.impl;

import java.lang.ref.ReferenceQueue;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.LockSupport;
import lombok.NonNull;
import org.apache.commons.lang3.RandomUtils;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.context.ContextPack;
import org.nd4j.jita.allocator.context.impl.BasicContextPool;
import org.nd4j.jita.allocator.garbage.GarbageResourceReference;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.CudaPointer;
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
import org.nd4j.jita.allocator.pointers.cuda.cusolverDnHandle_t;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.context.CudaContext;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Context pool that keeps a bounded set of pre-built {@link CudaContext}s per device.
 *
 * <p>A context handed out by {@link #acquireContextForDevice(Integer)} is bound to the calling
 * thread (tracked in {@link #acquired}, keyed by thread id). A {@link GarbageResourceReference}
 * to that thread is registered on one of {@link #NUM_COLLECTORS} reference queues.
 * {@link ResourceGarbageCollectorThread}s drain those queues and put each context back into its
 * device's pool. (GarbageResourceReference is presumably a weak/phantom reference to the owning
 * thread, so recycling happens once that thread becomes unreachable — confirm.)
 *
 * <p>When a device's pool is empty, acquisition triggers a memory-manager GC and waits. It grows
 * the pool in steps of {@link #POOL_GROWTH_STEP} while {@link #currentPoolSize} is below
 * {@link #POOL_SIZE_LIMIT_FACTOR} times the configured pool size, and otherwise sleeps and retries.
 */
public class LimitedContextPool
extends BasicContextPool {
    private static final Logger log = LoggerFactory.getLogger(LimitedContextPool.class);

    /** Number of recycler threads (and matching reference queues) started by the constructor. */
    protected static final int NUM_COLLECTORS = 4;

    /** Number of contexts added to a device's pool each time acquisition has to grow it. */
    protected static final int POOL_GROWTH_STEP = 16;

    /** Growth stops once {@link #currentPoolSize} reaches this multiple of the configured pool size. */
    protected static final int POOL_SIZE_LIMIT_FACTOR = 3;

    /** Idle contexts, keyed by device id. */
    protected Map<Integer, LinkedBlockingQueue<CudaContext>> pool = new HashMap<Integer, LinkedBlockingQueue<CudaContext>>();
    /** Context currently bound to each thread, keyed by {@link Thread#getId()}. */
    protected Map<Long, CudaContext> acquired = new ConcurrentHashMap<Long, CudaContext>();
    /**
     * Counter checked against the growth cap.
     *
     * <p>It starts at the configured per-device pool size. Every growth step increments it,
     * whichever device was grown.
     */
    protected AtomicInteger currentPoolSize = new AtomicInteger(0);
    /** Recycler threads, keyed by index 0..NUM_COLLECTORS-1. */
    protected Map<Integer, ResourceGarbageCollectorThread> collectors = new HashMap<Integer, ResourceGarbageCollectorThread>();
    /** Reference queue drained by the recycler with the same index. */
    protected Map<Integer, ReferenceQueue<Thread>> queueMap = new HashMap<Integer, ReferenceQueue<Thread>>();

    /**
     * Starts the recycler threads and pre-fills every available device with the configured number
     * of contexts.
     */
    public LimitedContextPool() {
        int perDevicePool = CudaEnvironment.getInstance().getConfiguration().getPoolSize();
        for (int i = 0; i < NUM_COLLECTORS; ++i) {
            ReferenceQueue<Thread> queue = new ReferenceQueue<Thread>();
            ResourceGarbageCollectorThread collector = new ResourceGarbageCollectorThread(i, queue);
            collector.start();
            this.collectors.put(i, collector);
            this.queueMap.put(i, queue);
        }
        this.fillPoolWithResources(perDevicePool, false);
        this.currentPoolSize.set(perDevicePool);
    }

    /**
     * Adds {@code numResources} new contexts to the pool of the device reported by
     * {@link AtomicAllocator#getDeviceId()}.
     *
     * @param numResources number of contexts to create
     */
    protected void addResourcesToPool(int numResources) {
        this.addResourcesToPool(AtomicAllocator.getInstance().getDeviceId(), numResources);
    }

    /**
     * Adds {@code numResources} new contexts to the pool of {@code deviceId}.
     *
     * <p>The new contexts share one freshly created cuBLAS handle and one cuSOLVER handle,
     * mirroring {@link #fillPoolWithResources(int, boolean)}. Handles are created on whatever device
     * is currently selected, so the caller is expected to have selected {@code deviceId} already.
     *
     * @param deviceId device whose pool receives the contexts; must already have a pool
     * @param numResources number of contexts to create
     */
    protected void addResourcesToPool(int deviceId, int numResources) {
        LinkedBlockingQueue<CudaContext> queue = this.pool.get(deviceId);
        if (queue == null) {
            throw new IllegalStateException("No context pool exists for device " + deviceId);
        }
        cublasHandle_t handle = this.createNewCublasHandle();
        cusolverDnHandle_t solverHandle = this.createNewSolverHandle();
        for (int cnt = 0; cnt < numResources; ++cnt) {
            CudaContext context = this.createNewStream(deviceId);
            context.initOldStream();
            this.getDeviceBuffers(context, deviceId);
            context.setHandle(handle);
            context.setSolverHandle(solverHandle);
            context.syncOldStream();
            queue.add(context);
        }
    }

    /**
     * Installs a fresh pool of {@code numResources} contexts for every available device.
     *
     * <p>Any existing queue for that device is replaced. Each device's contexts share one cuBLAS
     * handle and one cuSOLVER handle.
     *
     * @param numResources contexts to create per device
     * @param restoreDevice if true, re-select afterwards the device reported by
     *     {@link AtomicAllocator#getDeviceId()} before the call; otherwise the last device
     *     iterated stays selected
     */
    protected synchronized void fillPoolWithResources(int numResources, boolean restoreDevice) {
        List<Integer> devices = CudaEnvironment.getInstance().getConfiguration().getAvailableDevices();
        int cDevice = 0;
        if (restoreDevice) {
            cDevice = AtomicAllocator.getInstance().getDeviceId();
        }
        NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
        for (Integer device : devices) {
            nativeOps.setDevice((Pointer)new CudaPointer(device.intValue()));
            LinkedBlockingQueue<CudaContext> queue = new LinkedBlockingQueue<CudaContext>();
            this.pool.put(device, queue);
            cublasHandle_t handle = this.createNewCublasHandle();
            cusolverDnHandle_t solverHandle = this.createNewSolverHandle();
            for (int cnt = 0; cnt < numResources; ++cnt) {
                CudaContext context = this.createNewStream(device);
                context.initOldStream();
                this.getDeviceBuffers(context, device);
                context.setHandle(handle);
                context.setSolverHandle(solverHandle);
                context.syncOldStream();
                queue.add(context);
            }
        }
        if (restoreDevice) {
            nativeOps.setDevice((Pointer)new CudaPointer(cDevice));
        }
    }

    /**
     * Returns a context for {@code deviceId} that is bound to the calling thread.
     *
     * <p>If the thread already holds a context for the same device, that context is returned as is.
     * Otherwise the device is selected and a context is taken from its pool. If the pool is empty,
     * the method triggers a GC, waits up to one second per attempt, and grows the pool while below
     * the cap. At the cap it logs a warning, sleeps, and retries until a context becomes free.
     *
     * @param deviceId target device id
     * @return a context bound to the current thread for {@code deviceId}
     * @throws IllegalStateException if no pool exists for {@code deviceId}
     * @throws RuntimeException wrapping {@link InterruptedException} if the thread is interrupted
     *     while waiting; the interrupt flag is restored
     */
    @Override
    public CudaContext acquireContextForDevice(Integer deviceId) {
        long threadIdx = Thread.currentThread().getId();
        CudaContext context = this.acquired.get(threadIdx);
        if (context != null && deviceId.intValue() == context.getDeviceId()) {
            return context;
        }
        LinkedBlockingQueue<CudaContext> queue = this.pool.get(deviceId);
        if (queue == null) {
            throw new IllegalStateException("No context pool exists for device " + deviceId);
        }
        this.nativeOps.setDevice((Pointer)new CudaPointer(deviceId.intValue()));
        context = queue.poll();
        if (context != null) {
            this.bindToCurrentThread(context, deviceId, threadIdx);
            return context;
        }
        while (true) {
            Nd4j.getMemoryManager().invokeGc();
            try {
                context = queue.poll(1L, TimeUnit.SECONDS);
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
            if (context != null) {
                this.bindToCurrentThread(context, deviceId, threadIdx);
                return context;
            }
            int limit = CudaEnvironment.getInstance().getConfiguration().getPoolSize() * POOL_SIZE_LIMIT_FACTOR;
            if (this.currentPoolSize.get() < limit) {
                // Grow the pool of the device we are waiting on, not the thread's default device.
                this.addResourcesToPool(deviceId.intValue(), POOL_GROWTH_STEP);
                this.currentPoolSize.addAndGet(POOL_GROWTH_STEP);
                continue;
            }
            log.warn("Can't allocate new context, sleeping...");
            Nd4j.getMemoryManager().invokeGc();
            try {
                Thread.sleep(500L);
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Binds {@code context} to the calling thread.
     *
     * <p>A reference to the current thread is registered on a randomly chosen recycler queue and
     * the context is recorded in {@link #acquired}.
     */
    private void bindToCurrentThread(CudaContext context, Integer deviceId, long threadIdx) {
        int col = RandomUtils.nextInt(0, this.queueMap.size());
        GarbageResourceReference reference = new GarbageResourceReference(Thread.currentThread(), this.queueMap.get(col), context, deviceId);
        context.attachReference(reference);
        this.acquired.put(threadIdx, context);
        context.setDeviceId(deviceId);
    }

    /** Wraps {@link #acquireContextForDevice(Integer)} in a {@link ContextPack}. */
    @Override
    public ContextPack acquireContextPackForDevice(Integer deviceId) {
        return new ContextPack(this.acquireContextForDevice(deviceId));
    }

    /** Same as {@link #acquireContextForDevice(Integer)}. */
    @Override
    public CudaContext getContextForDevice(Integer deviceId) {
        return this.acquireContextForDevice(deviceId);
    }

    /**
     * Daemon thread that returns contexts to their device pools.
     *
     * <p>It does so as {@link GarbageResourceReference}s show up on its reference queue.
     */
    private class ResourceGarbageCollectorThread
    extends Thread {
        private final ReferenceQueue<Thread> queue;

        public ResourceGarbageCollectorThread(int threadId, @NonNull ReferenceQueue<Thread> queue) {
            if (queue == null) {
                throw new NullPointerException("queue is marked @NonNull but is null");
            }
            this.queue = queue;
            this.setDaemon(true);
            this.setName("ResourceGC thread " + threadId);
        }

        @Override
        public void run() {
            while (true) {
                GarbageResourceReference reference;
                try {
                    // Block instead of polling: avoids waking every 0.5ms, and avoids a hot spin
                    // after an interrupt (parkNanos returns immediately while the flag is set).
                    reference = (GarbageResourceReference)this.queue.remove();
                }
                catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }
                CudaContext context = reference.getContext();
                Long threadId = reference.getThreadId();
                int deviceId = reference.getDeviceId();
                // Conditional remove: thread ids may be reused, so only drop the binding if it
                // still points at the context being recycled.
                LimitedContextPool.this.acquired.remove(threadId, context);
                LimitedContextPool.this.pool.get(deviceId).add(context);
            }
        }
    }
}

