package org.nd4j.python4j.numpy;

import java.io.File;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import org.bytedeco.cpython.PyObject;
import org.bytedeco.cpython.PyTypeObject;
import org.bytedeco.cpython.global.python;
import org.bytedeco.javacpp.SizeTPointer;
import org.bytedeco.numpy.PyArrayObject;
import org.bytedeco.numpy.global.numpy;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.python4j.PythonConstants;
import org.nd4j.python4j.PythonContextManager;
import org.nd4j.python4j.PythonException;
import org.nd4j.python4j.PythonExecutioner;
import org.nd4j.python4j.PythonGIL;
import org.nd4j.python4j.PythonObject;
import org.nd4j.python4j.PythonType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/python4j/numpy/NumpyArray.class */
/**
 * {@link PythonType} that converts between ND4J {@link INDArray} instances and
 * Python {@code numpy.ndarray} objects.
 *
 * <p>Both directions work on raw memory addresses: {@link #toPython(INDArray)} hands the
 * address of the array's data buffer to numpy, and {@link #m1toJava(PythonObject)} creates a
 * {@link DataBuffer} over the address reported by {@code PyArray_DATA}. Every buffer that takes
 * part in a conversion is stored in a static cache, keyed by address, element count and data
 * type, so that the same memory maps back to the same {@link DataBuffer}.
 *
 * <p>NOTE(review): the cache is a plain {@link HashMap} that is never cleared and is accessed
 * without synchronization; this presumably relies on callers serializing access (e.g. via the
 * GIL) — confirm.
 */
public class NumpyArray extends PythonType<INDArray> {
    /** Shared instance; created and initialized by the static initializer at the end of the class. */
    public static final NumpyArray INSTANCE;

    /** System property that controls whether {@link #init()} initializes Python and imports numpy. */
    public static final String IMPORT_NUMPY_ARRAY = "org.eclipse.python4j.numpyimport";

    /**
     * System property that controls whether the JavaCPP numpy packages are added to the Python path.
     *
     * <p>NOTE(review): this has the same value as {@link #IMPORT_NUMPY_ARRAY}, so the two
     * switches cannot be set independently. It looks like a copy/paste slip, but the value is
     * public API and is therefore left untouched here.
     */
    public static final String ADD_JAVACPP_NUMPY_TO_PATH = "org.eclipse.python4j.numpyimport";

    /** Default for {@link #IMPORT_NUMPY_ARRAY} when the property is not set. */
    public static final String DEFAULT_IMPORT_NUMPY_ARRAY = "true";

    /** Default for {@link #ADD_JAVACPP_NUMPY_TO_PATH} when the property is not set. */
    public static final String DEFAULT_ADD_JAVACPP_NUMPY_TO_PATH = "true";

    private static final Logger log = LoggerFactory.getLogger(NumpyArray.class);

    /** Set once {@link #init()} has started, so that later calls return immediately. */
    private static final AtomicBoolean init = new AtomicBoolean(false);

    /** Buffers that have crossed the Java/Python boundary, keyed by {@link #cacheKey}. */
    private static final Map<String, DataBuffer> cache = new HashMap<>();

    // Type numbers of the NumPy C API NPY_TYPES enum that this class maps explicitly.
    private static final int NPY_BOOL = 0;
    private static final int NPY_BYTE = 1;
    private static final int NPY_UBYTE = 2;
    private static final int NPY_SHORT = 3;
    private static final int NPY_USHORT = 4;
    private static final int NPY_INT = 5;
    private static final int NPY_UINT = 6;
    private static final int NPY_LONG = 7;
    private static final int NPY_ULONG = 8;
    private static final int NPY_LONGLONG = 9;
    private static final int NPY_ULONGLONG = 10;
    private static final int NPY_FLOAT = 11;
    private static final int NPY_DOUBLE = 12;
    private static final int NPY_HALF = 23;

    /**
     * Flags passed to {@code PyArray_New}: C-contiguous (0x1) | aligned (0x100) | writeable (0x400),
     * i.e. the value NumPy names {@code NPY_ARRAY_CARRAY}.
     */
    private static final int NPY_ARRAY_CARRAY_FLAGS = 1281;

    /**
     * Returns the package directory reported by {@code numpy.cachePackage()}.
     *
     * @return a single-element array holding that file
     * @throws PythonException wrapping any exception raised while resolving the package
     */
    public File[] packages() {
        try {
            return new File[]{numpy.cachePackage()};
        } catch (Exception e) {
            throw new PythonException(e);
        }
    }

    /**
     * Performs the one-time numpy setup. Calls after the first one return immediately.
     *
     * <p>Unless disabled through {@link #IMPORT_NUMPY_ARRAY}, this optionally adds the JavaCPP
     * numpy packages to the Python path, initializes the interpreter and imports the numpy
     * C API.
     *
     * @throws PythonException if the numpy import fails, or if the GIL is reported as locked
     *                         once the setup steps have run
     */
    public synchronized void init() {
        if (init.get()) {
            return;
        }
        init.set(true);
        boolean importNumpy = Boolean.parseBoolean(
                System.getProperty(IMPORT_NUMPY_ARRAY, DEFAULT_IMPORT_NUMPY_ARRAY));
        if (importNumpy) {
            boolean addToPath = Boolean.parseBoolean(
                    System.getProperty(ADD_JAVACPP_NUMPY_TO_PATH, DEFAULT_ADD_JAVACPP_NUMPY_TO_PATH));
            if (addToPath) {
                python.Py_AddPath(numpy.cachePackages());
            }
            // Python is initialized right here, so PythonConstants is told not to do it again.
            PythonConstants.setInitializePython(false);
            python.Py_Initialize();
            if (numpy._import_array() < 0) {
                log.error("Numpy import failed!");
                throw new PythonException("Numpy import failed!");
            }
        }
        if (PythonGIL.locked()) {
            throw new PythonException("Can not initialize numpy - GIL already acquired.");
        }
    }

    /** Registers this type under the Python type name {@code numpy.ndarray}. */
    public NumpyArray() {
        super("numpy.ndarray", INDArray.class);
    }

    /**
     * Wraps a numpy array as an {@link INDArray} backed by the numpy array's own memory.
     *
     * <p>Shape and strides are read from the numpy object; numpy's byte strides are divided by
     * the buffer's element size before being handed to ND4J.
     *
     * @param pythonObject object expected to be an instance of {@code numpy.ndarray}
     * @return an array created over the numpy data pointer, tagged as located on the host
     * @throws PythonException if the object is not a numpy array or its dtype is not supported
     */
    /* renamed from: toJava, reason: merged with bridge method [inline-methods] */
    public INDArray m1toJava(PythonObject pythonObject) {
        log.debug("Converting PythonObject to INDArray...");
        PyObject nativeObject = pythonObject.getNativePythonObject();
        requireNdarray(nativeObject);

        PyArrayObject npArray = new PyArrayObject(nativeObject);
        long[] shape = new long[numpy.PyArray_NDIM(npArray)];
        SizeTPointer shapePointer = numpy.PyArray_SHAPE(npArray);
        if (shapePointer != null) {
            shapePointer.get(shape, 0, shape.length);
        }
        long[] byteStrides = new long[shape.length];
        SizeTPointer stridesPointer = numpy.PyArray_STRIDES(npArray);
        if (stridesPointer != null) {
            stridesPointer.get(byteStrides, 0, byteStrides.length);
        }

        DataType dataType = toNd4jType(numpy.PyArray_TYPE(npArray));

        long length = 1;
        for (long dimension : shape) {
            length *= dimension;
        }

        long address = numpy.PyArray_DATA(npArray).address();
        String key = cacheKey(address, length, dataType);
        DataBuffer buffer = cache.get(key);
        if (buffer == null) {
            // The buffer is created while scoped out of any active workspace.
            try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                buffer = Nd4j.createBuffer(
                        NativeOpsHolder.getInstance().getDeviceNativeOps()
                                .pointerForAddress(address).limit(length).capacity(length),
                        length,
                        dataType);
                cache.put(key, buffer);
            }
        }

        int elementSize = buffer.getElementSize();
        long[] elementStrides = new long[byteStrides.length];
        for (int d = 0; d < byteStrides.length; d++) {
            elementStrides[d] = byteStrides[d] / elementSize;
        }

        INDArray result = Nd4j.create(
                buffer, shape, elementStrides, 0L, Shape.getOrder(shape, elementStrides, 1L), dataType);
        Nd4j.getAffinityManager().tagLocation(result, AffinityManager.Location.HOST);
        log.debug("Done creating numpy array.");
        return result;
    }

    /**
     * Exposes an {@link INDArray} to Python as a numpy array over the same memory.
     *
     * <p>{@code BFLOAT16} arrays are first cast to {@code FLOAT}, so for them the numpy array
     * refers to the copy rather than to the original data. The exported buffer is stored in the
     * cache. Depending on {@link PythonConstants}, the numpy object is created either through
     * {@code PyArray_New} or by running a small ctypes/numpy script in a dedicated Python
     * context.
     *
     * @param iNDArray array to expose
     * @return the Python object holding the numpy array
     * @throws RuntimeException if the array's data type has no numpy mapping
     */
    public PythonObject toPython(INDArray iNDArray) {
        log.debug("Converting INDArray to PythonObject...");
        DataType dataType = iNDArray.dataType();

        final int npType;
        final String ctype;
        switch (dataType) {
            case DOUBLE:
                npType = NPY_DOUBLE;
                ctype = "c_double";
                break;
            case FLOAT:
            case BFLOAT16:
                npType = NPY_FLOAT;
                ctype = "c_float";
                break;
            case SHORT:
                npType = NPY_SHORT;
                ctype = "c_short";
                break;
            case INT:
                npType = NPY_INT;
                ctype = "c_int";
                break;
            case LONG:
                npType = numpy.NPY_INT64;
                ctype = "c_int64";
                break;
            case UINT16:
                npType = NPY_USHORT;
                ctype = "c_uint16";
                break;
            case UINT32:
                npType = NPY_UINT;
                ctype = "c_uint";
                break;
            case UINT64:
                npType = numpy.NPY_UINT64;
                ctype = "c_uint64";
                break;
            case BOOL:
                npType = NPY_BOOL;
                ctype = "c_bool";
                break;
            case BYTE:
                npType = NPY_BYTE;
                ctype = "c_byte";
                break;
            case UBYTE:
                npType = NPY_UBYTE;
                ctype = "c_ubyte";
                break;
            case HALF:
                // ctypes has no half-precision type; a 16-bit integer type provides the storage
                // and the numpy dtype is set separately in the script below.
                npType = NPY_HALF;
                ctype = "c_short";
                break;
            default:
                throw new RuntimeException("Unsupported dtype: " + dataType);
        }

        long[] shape = iNDArray.shape();
        INDArray exported = iNDArray;
        if (dataType == DataType.BFLOAT16) {
            log.warn("Creating copy of array as bfloat16 is not supported by numpy.");
            exported = iNDArray.castTo(DataType.FLOAT);
        }
        Nd4j.getAffinityManager().ensureLocation(exported, AffinityManager.Location.HOST);

        // Cache the buffer whose address is handed to numpy. For bfloat16 input this is the
        // FLOAT copy, which would otherwise be referenced only by a local variable.
        DataBuffer exportedData = exported.data();
        long address = exportedData.pointer().address();
        cache.put(cacheKey(address, exportedData.length(), exported.dataType()), exportedData);

        if (PythonConstants.releaseGilAutomatically() && !PythonConstants.createNpyViaPython()) {
            log.debug("NUMPY: PyArray_Type()");
            PyTypeObject arrayType = numpy.PyArray_Type();
            log.debug("NUMPY: PyArray_New()");
            PyObject npArray = numpy.PyArray_New(
                    arrayType,
                    shape.length,
                    new SizeTPointer(shape),
                    npType,
                    (SizeTPointer) null,
                    exportedData.addressPointer(),
                    0,
                    NPY_ARRAY_CARRAY_FLAGS,
                    (PyObject) null);
            log.debug("Created numpy array.");
            return new PythonObject(npArray);
        }

        try (PythonContextManager.Context ignored = new PythonContextManager.Context("__np_array_converter")) {
            log.debug("Stringing exec...");
            String dtypeExpression = (npType == NPY_HALF) ? "np.float16" : "ctypes." + ctype;
            String code = "import ctypes\nimport numpy as np\n"
                    + "cArr = (ctypes." + ctype + "*" + exported.length() + ").from_address(" + address + ")\n"
                    + "npArr = np.frombuffer(cArr, dtype=" + dtypeExpression
                    + ").reshape(" + Arrays.toString(shape) + ")";
            PythonExecutioner.exec(code);
            log.debug("exec done.");
            PythonObject result = PythonExecutioner.getVariable("npArr");
            // Take an extra reference before the context is closed.
            python.Py_IncRef(result.getNativePythonObject());
            return result;
        }
    }

    /**
     * Reports whether this type can handle the given Java object.
     *
     * @param obj candidate object, may be {@code null}
     * @return {@code true} only if {@code obj} is an {@link INDArray}
     */
    public boolean accepts(Object obj) {
        return obj instanceof INDArray;
    }

    /**
     * Returns the given object as an {@link INDArray}.
     *
     * @param obj object to adapt
     * @return {@code obj} itself, cast to {@link INDArray}
     * @throws PythonException if {@code obj} is {@code null} or not an {@link INDArray}
     */
    /* renamed from: adapt, reason: merged with bridge method [inline-methods] */
    public INDArray m2adapt(Object obj) {
        if (obj instanceof INDArray) {
            return (INDArray) obj;
        }
        String typeName = (obj == null) ? "null" : obj.getClass().getName();
        throw new PythonException("Cannot cast object of type " + typeName + " to INDArray");
    }

    /** Builds the cache key identifying a block of memory by address, element count and data type. */
    private static String cacheKey(long address, long length, DataType dataType) {
        return address + "_" + length + "_" + dataType;
    }

    /** Throws unless {@code candidate} is an instance of {@code numpy.ndarray}; releases the looked-up references either way. */
    private static void requireNdarray(PyObject candidate) {
        PyObject numpyModule = python.PyImport_ImportModule("numpy");
        PyObject ndarrayType = python.PyObject_GetAttrString(numpyModule, "ndarray");
        final boolean isNdarray;
        try {
            isNdarray = python.PyObject_IsInstance(candidate, ndarrayType) == 1;
        } finally {
            python.Py_DecRef(ndarrayType);
            python.Py_DecRef(numpyModule);
        }
        if (!isNdarray) {
            throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array.");
        }
    }

    /** Maps a numpy type number to the ND4J data type, rejecting every type without a mapping. */
    private static DataType toNd4jType(int npType) {
        switch (npType) {
            case NPY_BOOL:
                return DataType.BOOL;
            case NPY_BYTE:
                return DataType.INT8;
            case NPY_UBYTE:
                return DataType.UINT8;
            case NPY_SHORT:
                return DataType.SHORT;
            case NPY_USHORT:
                return DataType.UINT16;
            case NPY_INT:
                return DataType.INT32;
            case NPY_UINT:
                return DataType.UINT32;
            // NOTE(review): NPY_LONG / NPY_ULONG are treated as 64-bit, which presumably only
            // holds on platforms where C long is 64 bits wide — confirm for Windows builds.
            case NPY_LONG:
            case NPY_LONGLONG:
                return DataType.INT64;
            case NPY_ULONG:
            case NPY_ULONGLONG:
                return DataType.UINT64;
            case NPY_FLOAT:
                return DataType.FLOAT;
            case NPY_DOUBLE:
                return DataType.DOUBLE;
            case NPY_HALF:
                return DataType.FLOAT16;
            default:
                throw new PythonException("Unsupported array data type: " + npType);
        }
    }

    static {
        // Instantiated only for its side effects, before the numpy type is initialized.
        // NOTE(review): presumably this forces PythonExecutioner's own static setup — confirm.
        new PythonExecutioner();
        INSTANCE = new NumpyArray();
        INSTANCE.init();
    }
}
