package org.nd4j.linalg.memory.deallocation;

import java.lang.ref.ReferenceQueue;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import lombok.NonNull;
import org.apache.commons.lang3.RandomUtils;
import org.nd4j.linalg.api.memory.Deallocatable;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/memory/deallocation/DeallocatorService.class */
/**
 * Releases resources held by {@link Deallocatable} objects once they are reclaimed by the JVM.
 *
 * <p>The constructor starts {@code max(2, 2 * numberOfDevices)} daemon threads. Thread {@code i}
 * drains its own {@link ReferenceQueue} and is pinned to device {@code i % numberOfDevices}, so
 * every device is served by two queues. {@link #pickObject(Deallocatable)} registers an object on
 * a randomly chosen queue of that object's target device.
 */
public class DeallocatorService {
    private static final Logger log = LoggerFactory.getLogger(DeallocatorService.class);
    private final Thread[] deallocatorThreads;
    private final ReferenceQueue<Deallocatable>[] queues;
    // Keeps each registered reference strongly reachable until a deallocator thread has processed
    // it; entries are removed by release(). Keyed by Deallocatable#getUniqueId().
    private final Map<String, DeallocatableReference> referenceMap = new ConcurrentHashMap<>();
    // Index = device id; value = the queues whose threads run on that device.
    private final List<List<ReferenceQueue<Deallocatable>>> deviceMap = new ArrayList<>();

    /**
     * Worker that takes references off one queue and runs their deallocators.
     *
     * <p>The thread with index 0 also drives periodic GC while the memory manager has it enabled
     * with a positive window. It polls instead of blocking so that it can wake up and invoke GC.
     */
    private class DeallocatorServiceThread extends Thread {
        private final ReferenceQueue<Deallocatable> queue;
        private final int threadIdx;
        public static final String DeallocatorThreadNamePrefix = "DeallocatorServiceThread thread ";
        private final int deviceId;

        private DeallocatorServiceThread(@NonNull ReferenceQueue<Deallocatable> queue, int threadIdx, int deviceId) {
            if (queue == null) {
                throw new NullPointerException("queue is marked @NonNull but is null");
            }
            this.queue = queue;
            this.threadIdx = threadIdx;
            this.deviceId = deviceId;
            // Default name only; DeallocatorService's constructor renames the thread before start().
            setName(DeallocatorThreadNamePrefix + threadIdx);
            setContextClassLoader(null);
        }

        @Override
        public void run() {
            Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
            boolean canRun = true;
            while (canRun) {
                // The mode is re-evaluated on every iteration because periodic GC can be toggled at runtime.
                canRun = isPeriodicGcDriver() ? pollOrCollectGarbage() : awaitAndRelease();
            }
        }

        /** Returns true if this thread should run periodic GC instead of blocking on its queue. */
        private boolean isPeriodicGcDriver() {
            return threadIdx == 0
                    && Nd4j.getMemoryManager().isPeriodicGcActive()
                    && Nd4j.getMemoryManager().getAutoGcWindow() > 0;
        }

        /**
         * Releases one pending reference if one is queued. Otherwise sleeps for the auto-GC window
         * and then invokes GC.
         *
         * @return {@code false} if the thread was interrupted and should stop
         */
        private boolean pollOrCollectGarbage() {
            DeallocatableReference reference = (DeallocatableReference) queue.poll();
            if (reference != null) {
                release(reference);
                return true;
            }
            try {
                Thread.sleep(Nd4j.getMemoryManager().getAutoGcWindow());
                Nd4j.getMemoryManager().invokeGc();
                return true;
            } catch (InterruptedException e) {
                // Restore the interrupt status so the owner of this thread can still observe it.
                Thread.currentThread().interrupt();
                return false;
            }
        }

        /**
         * Blocks until a reference is enqueued and then releases it.
         *
         * @return {@code false} if the thread was interrupted and should stop
         * @throws RuntimeException wrapping any failure raised while releasing a reference
         */
        private boolean awaitAndRelease() {
            try {
                DeallocatableReference reference = (DeallocatableReference) queue.remove();
                if (reference != null) {
                    release(reference);
                }
                return true;
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return false;
            } catch (Exception e) {
                throw new RuntimeException("Deallocation failed in " + getName(), e);
            }
        }

        /**
         * Runs the reference's deallocator and always drops it from {@code referenceMap}.
         * The reference has already been dequeued and will not be seen again, so a failed
         * deallocation must not leave its map entry behind.
         */
        private void release(DeallocatableReference reference) {
            try {
                reference.getDeallocator().deallocate();
            } finally {
                referenceMap.remove(reference.getId());
            }
        }
    }

    /**
     * Creates one reference queue per worker thread, groups the queues by device, and starts
     * the daemon deallocator threads.
     *
     * @throws IllegalStateException if the affinity manager reports fewer than one device
     */
    public DeallocatorService() {
        int numberOfDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        if (numberOfDevices < 1) {
            // Without this check, "i % numberOfDevices" below would fail with an opaque ArithmeticException.
            throw new IllegalStateException("Expected at least one device, but got " + numberOfDevices);
        }
        int threadCount = Math.max(2, numberOfDevices * 2);
        for (int device = 0; device < numberOfDevices; device++) {
            deviceMap.add(new ArrayList<>());
        }

        // Generic array creation is unchecked. Every slot is filled with a ReferenceQueue<Deallocatable> below.
        @SuppressWarnings("unchecked")
        ReferenceQueue<Deallocatable>[] createdQueues =
                (ReferenceQueue<Deallocatable>[]) new ReferenceQueue<?>[threadCount];
        this.queues = createdQueues;
        this.deallocatorThreads = new Thread[threadCount];

        for (int i = 0; i < threadCount; i++) {
            log.trace("Starting deallocator thread {}", i + 1);
            ReferenceQueue<Deallocatable> queue = new ReferenceQueue<>();
            int deviceId = i % numberOfDevices;
            Thread thread = new DeallocatorServiceThread(queue, i, deviceId);
            thread.setName("DeallocatorServiceThread_" + i);
            thread.setDaemon(true);
            queues[i] = queue;
            deallocatorThreads[i] = thread;
            deviceMap.get(deviceId).add(queue);
            thread.start();
        }
    }

    /**
     * Registers {@code deallocatable} so that its deallocator runs once the object is reclaimed.
     * The reference is placed on a randomly chosen queue belonging to the object's target device.
     *
     * @param deallocatable object to track; must not be null
     * @throws NullPointerException if {@code deallocatable} is null
     */
    public void pickObject(@NonNull Deallocatable deallocatable) {
        if (deallocatable == null) {
            throw new NullPointerException("deallocatable is marked @NonNull but is null");
        }
        List<ReferenceQueue<Deallocatable>> candidates = deviceMap.get(deallocatable.targetDevice());
        ReferenceQueue<Deallocatable> queue = candidates.get(RandomUtils.nextInt(0, candidates.size()));
        // Presumably DeallocatableReference is a java.lang.ref.Reference subclass: holding it in the
        // map keeps it reachable so it can actually be enqueued. TODO confirm.
        referenceMap.put(deallocatable.getUniqueId(), new DeallocatableReference(deallocatable, queue));
    }
}
