/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.jita.allocator.tad;

import java.util.Arrays;
import java.util.concurrent.atomic.AtomicLong;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.jcublas.buffer.AddressRetriever;
import org.nd4j.linalg.jcublas.buffer.CudaLongDataBuffer;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link TADManager} implementation that allocates a new shape-info buffer and a new
 * offsets buffer on every {@link #getTADOnlyShapeInfo(INDArray, int[])} call.
 *
 * <p>Nothing in this class caches results or updates {@link #bytes}; both are left for
 * subclasses to use.
 */
public class BasicTADManager
implements TADManager {
    private static final Logger logger = LoggerFactory.getLogger(BasicTADManager.class);

    /**
     * Values written, in order, into the shape-info buffer for the "whole array" request
     * (a single {@code Integer.MAX_VALUE} dimension). The buffer is allocated with
     * {@code 2 * 2 + 4 = 8} elements in that case, matching this table's length.
     *
     * <p>NOTE(review): presumably laid out as rank, shape[2], stride[2], then three
     * trailing descriptor values; the final 99 is the ASCII code of {@code 'c'} and is
     * presumably the ordering marker. Confirm against the native shape-info format.
     */
    private static final int[] SCALAR_TAD_SHAPE_INFO = {2, 1, 1, 1, 1, 0, 0, 99};

    protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps();
    protected AtomicLong bytes = new AtomicLong(0L);

    /**
     * Builds the TAD shape-info buffer and the TAD offsets buffer for {@code array} along
     * the given dimensions.
     *
     * <p>A {@code null} dimension array, or a one-element array holding
     * {@code Integer.MAX_VALUE}, selects the whole-array case: the shape info is filled
     * from a fixed table and the offsets buffer has a single element that this method
     * does not write. Otherwise both buffers are populated by the native
     * {@code tadOnlyShapeInfo} call.
     *
     * @param array     array whose shape info and length drive the computation
     * @param dimension dimensions to build the TAD along; may be {@code null}. The
     *                  dimensions are used in ascending order, but the caller's array is
     *                  never modified: sorting happens on a private copy.
     * @return pair of (shape-info buffer, offsets buffer)
     */
    public Pair<DataBuffer, DataBuffer> getTADOnlyShapeInfo(INDArray array, int[] dimension) {
        // Normalise the request without touching the caller's array.
        int[] dims;
        if (dimension == null) {
            dims = new int[]{Integer.MAX_VALUE};
        } else if (dimension.length > 1) {
            dims = dimension.clone();
            Arrays.sort(dims);
        } else {
            dims = dimension;
        }

        boolean isScalar = dims.length == 1 && dims[0] == Integer.MAX_VALUE;
        int targetRank = isScalar ? 2 : array.rank();

        // One offset per TAD: total length divided by the product of the TAD dimensions.
        long offsetLength = 1L;
        if (!isScalar) {
            long tadLength = 1L;
            for (int dim : dims) {
                tadLength *= array.shape()[dim];
            }
            offsetLength = array.lengthLong() / tadLength;
        }

        CudaLongDataBuffer outputBuffer = new CudaLongDataBuffer(targetRank * 2 + 4);
        CudaLongDataBuffer offsetsBuffer = new CudaLongDataBuffer(offsetLength);
        AtomicAllocator.getInstance().getAllocationPoint(outputBuffer).tickHostWrite();
        AtomicAllocator.getInstance().getAllocationPoint(offsetsBuffer).tickHostWrite();

        // All pointers handed to the native call are host-side pointers.
        DataBuffer dimensionBuffer = AtomicAllocator.getInstance().getConstantBuffer(dims);
        Pointer dimensionPointer = AtomicAllocator.getInstance().getHostPointer(dimensionBuffer);
        Pointer xShapeInfo = AddressRetriever.retrieveHostPointer(array.shapeInfoDataBuffer());
        Pointer targetPointer = AddressRetriever.retrieveHostPointer(outputBuffer);
        Pointer offsetsPointer = AddressRetriever.retrieveHostPointer(offsetsBuffer);

        if (isScalar) {
            for (int i = 0; i < SCALAR_TAD_SHAPE_INFO.length; ++i) {
                outputBuffer.put((long)i, SCALAR_TAD_SHAPE_INFO[i]);
            }
        } else {
            this.nativeOps.tadOnlyShapeInfo((LongPointer)xShapeInfo, (IntPointer)dimensionPointer, dims.length, (LongPointer)targetPointer, (LongPointer)new LongPointerWrapper(offsetsPointer));
        }

        // Tick again after the buffers have been filled through their host pointers.
        AtomicAllocator.getInstance().getAllocationPoint(outputBuffer).tickHostWrite();
        AtomicAllocator.getInstance().getAllocationPoint(offsetsBuffer).tickHostWrite();
        return new Pair<DataBuffer, DataBuffer>(outputBuffer, offsetsBuffer);
    }

    /**
     * Does nothing: this class holds no buffers to purge.
     */
    public void purgeBuffers() {
    }

    /**
     * Returns the current value of the {@link #bytes} counter.
     *
     * @return the counter value; this class itself never changes it from its initial 0
     */
    public long getCachedBytes() {
        return this.bytes.get();
    }
}

