package ai.djl.pytorch.jni;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.engine.PtSymbolBlock;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import ai.djl.util.Preconditions;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;
import java.util.stream.Stream;

/* loaded from: input_file:ai/djl/pytorch/jni/IValueUtils.class */
/**
 * Static helpers that convert between DJL arrays and native PyTorch IValue handles and that drive
 * a module forward pass through {@link PyTorchLibrary}.
 */
public final class IValueUtils {

    /** Input names of the form {@code word[]} are packed together into a single list IValue. */
    private static final Pattern LIST_NAME_PATTERN = Pattern.compile("\\w+\\[]");

    /** How a group of input tensors is packed into one IValue argument. */
    private enum InputKind {
        SINGLE,
        LIST,
        DICT
    }

    /** A group of tensor handles that becomes exactly one IValue argument of the module call. */
    private static final class InputGroup {
        final InputKind kind;
        // Keys are the dict keys for DICT, the full list name for LIST, and null for SINGLE.
        final PairList<String, Long> entries = new PairList<>();

        InputGroup(InputKind kind) {
            this.kind = kind;
        }
    }

    private IValueUtils() {
    }

    /**
     * Wraps a tensor handle in a new native IValue.
     *
     * @param j the native tensor handle
     * @return the handle of the newly created IValue
     */
    public static long toIValuePointer(long j) {
        return PyTorchLibrary.LIB.iValueFromTensor(j);
    }

    /**
     * Creates a native list IValue from tensor handles.
     *
     * @param jArr the native tensor handles, in list order
     * @return the handle of the newly created IValue
     */
    public static long iValueFromList(long[] jArr) {
        return PyTorchLibrary.LIB.iValueFromList(jArr);
    }

    /**
     * Creates a native dict IValue from tensor handles and their keys.
     *
     * @param jArr the native tensor handles
     * @param strArr the dict keys, parallel to {@code jArr}
     * @return the handle of the newly created IValue
     */
    public static long iValueFromDict(long[] jArr, String[] strArr) {
        return PyTorchLibrary.LIB.iValueFromDict(jArr, strArr);
    }

    /** Returns whether the IValue holds a tensor (native {@code iValueIsTensor}). */
    public static boolean isNDArray(long j) {
        return PyTorchLibrary.LIB.iValueIsTensor(j);
    }

    /** Returns whether the IValue holds a tensor list (native {@code iValueIsTensorList}). */
    public static boolean isNDList(long j) {
        return PyTorchLibrary.LIB.iValueIsTensorList(j);
    }

    /** Returns whether the IValue holds a generic list (native {@code iValueIsList}). */
    public static boolean isList(long j) {
        return PyTorchLibrary.LIB.iValueIsList(j);
    }

    /** Returns whether the IValue holds a tuple (native {@code iValueIsTuple}). */
    public static boolean isTuple(long j) {
        return PyTorchLibrary.LIB.iValueIsTuple(j);
    }

    /** Returns whether the IValue holds a map (native {@code iValueIsMap}). */
    public static boolean isMap(long j) {
        return PyTorchLibrary.LIB.iValueIsMap(j);
    }

    /** Returns whether the IValue holds a string (native {@code iValueIsString}). */
    public static boolean isString(long j) {
        return PyTorchLibrary.LIB.iValueIsString(j);
    }

    /**
     * Extracts the tensor held by an IValue as a {@link PtNDArray} attached to {@code ptNDManager}.
     * The IValue handle itself is not deleted here.
     */
    public static PtNDArray toNDArray(long j, PtNDManager ptNDManager) {
        return new PtNDArray(ptNDManager, PyTorchLibrary.LIB.iValueToTensor(j));
    }

    /**
     * Extracts every tensor of a tensor-list IValue into a new {@link NDList}, preserving order.
     * The IValue handle itself is not deleted here.
     */
    public static NDList toNDList(long j, PtNDManager ptNDManager) {
        long[] tensorHandles = PyTorchLibrary.LIB.iValueToTensorList(j);
        NDList list = new NDList();
        for (long tensorHandle : tensorHandles) {
            list.add(new PtNDArray(ptNDManager, tensorHandle));
        }
        return list;
    }

    /** Returns the string held by an IValue (native {@code iValueToString}). */
    public static String toString(long j) {
        return PyTorchLibrary.LIB.iValueToString(j);
    }

    /**
     * Returns the element IValue handles of a tuple or list IValue.
     *
     * @param j the tuple or list IValue handle
     * @return the element handles, in order
     */
    public static long[] toIValueArray(long j) {
        return isTuple(j)
                ? PyTorchLibrary.LIB.iValueToListFromTuple(j)
                : PyTorchLibrary.LIB.iValueToList(j);
    }

    /**
     * Returns the key/value IValue handles of a map IValue.
     *
     * <p>The native call returns a flat {@code [key0, value0, key1, value1, ...]} array. A {@link
     * LinkedHashMap} is used so iteration follows that native order instead of the arbitrary hash
     * order of pointer values.
     *
     * @param j the map IValue handle
     * @return an insertion-ordered map from key handle to value handle
     */
    public static Map<Long, Long> toIValueMap(long j) {
        long[] flat = PyTorchLibrary.LIB.iValueToMap(j);
        Map<Long, Long> map = new LinkedHashMap<>();
        for (int i = 0; i + 1 < flat.length; i += 2) {
            map.put(flat[i], flat[i + 1]);
        }
        return map;
    }

    /**
     * Flattens an output IValue into an {@link NDList} and deletes the IValue handle.
     *
     * <p>Tensors, tensor lists, and (recursively) lists/tuples are flattened in order. For maps,
     * each value is converted to a tensor named after its string key. The handle {@code j} is
     * deleted in a {@code finally} block, so it is released even when a nested element fails.
     *
     * @throws UnsupportedOperationException if the IValue has none of the supported types
     */
    private static NDList forwardHelper(long j, PtNDManager ptNDManager) {
        NDList list = new NDList();
        try {
            if (isNDArray(j)) {
                list.add(toNDArray(j, ptNDManager));
            } else if (isNDList(j)) {
                list.addAll(toNDList(j, ptNDManager));
            } else if (isList(j) || isTuple(j)) {
                for (long child : toIValueArray(j)) {
                    // Each child handle is deleted by the recursive call.
                    list.addAll(forwardHelper(child, ptNDManager));
                }
            } else if (isMap(j)) {
                for (Map.Entry<Long, Long> entry : toIValueMap(j).entrySet()) {
                    long keyHandle = entry.getKey();
                    long valueHandle = entry.getValue();
                    String name = toString(keyHandle);
                    PyTorchLibrary.LIB.torchDeleteIValue(keyHandle);
                    PtNDArray array = toNDArray(valueHandle, ptNDManager);
                    PyTorchLibrary.LIB.torchDeleteIValue(valueHandle);
                    array.setName(name);
                    list.add(array);
                }
            } else {
                throw new UnsupportedOperationException("Unsupported IValue type");
            }
        } finally {
            PyTorchLibrary.LIB.torchDeleteIValue(j);
        }
        return list;
    }

    /**
     * Runs the module's forward pass on {@code inputs} and returns the flattened outputs.
     *
     * <p>Inputs are packed into IValue arguments by name (see {@code getInputs}). Output arrays are
     * attached to the manager of the first input.
     *
     * @param ptSymbolBlock the block whose native module handle is invoked
     * @param inputs the input arrays; must be {@link PtNDArray}s and must not be empty
     * @param isTrain flag passed unchanged to native {@code moduleForward} (presumably training
     *     mode; confirm against the native side)
     * @return the flattened outputs
     * @throws IllegalArgumentException if {@code inputs} is empty
     */
    public static NDList forward(PtSymbolBlock ptSymbolBlock, NDList inputs, boolean isTrain) {
        // Validate before the native call so that no output IValue is created and then leaked.
        Preconditions.checkArgument(!inputs.isEmpty(), "forward requires at least one input array");
        PtNDManager manager = (PtNDManager) inputs.get(0).getManager();
        long[] arrayHandles =
                inputs.stream()
                        .mapToLong(array -> ((Long) ((PtNDArray) array).getHandle()).longValue())
                        .toArray();
        String[] names = inputs.stream().map(NDArray::getName).toArray(String[]::new);
        long moduleHandle = ((Long) ptSymbolBlock.getHandle()).longValue();
        long result =
                PyTorchLibrary.LIB.moduleForward(
                        moduleHandle, getInputs(arrayHandles, names), isTrain);
        return forwardHelper(result, manager);
    }

    /** Returns whether {@code str} names an element of a list input, e.g. {@code "x[]"}. */
    private static boolean isNameList(String str) {
        return LIST_NAME_PATTERN.matcher(str).matches();
    }

    /** Returns whether {@code str} names an entry of a dict input, e.g. {@code "d.key"}. */
    private static boolean isNameDict(String str) {
        return str.contains(".");
    }

    /**
     * Packs tensor handles into the IValue arguments of a module call, based on their names.
     *
     * <ul>
     *   <li>{@code name.key} entries sharing {@code name} form one dict IValue.
     *   <li>{@code name[]} entries sharing {@code name[]} form one list IValue.
     *   <li>Any other (or null) name becomes its own tensor IValue.
     * </ul>
     *
     * Groups appear in the order of their first member. All names are validated before any native
     * IValue is created, so a bad name does not leak native handles.
     *
     * @throws IllegalArgumentException if a name contains more than one '.', or one name is used for
     *     both a list and a dict
     */
    private static long[] getInputs(long[] jArr, String[] strArr) {
        List<InputGroup> groups = new ArrayList<>();
        Map<String, InputGroup> namedGroups = new HashMap<>();
        for (int i = 0; i < jArr.length; i++) {
            String name = strArr[i];
            Long handle = jArr[i];
            if (name != null && isNameDict(name)) {
                String[] parts = name.split("\\.");
                Preconditions.checkArgument(
                        parts.length == 2,
                        "Please make sure you only include one '.' in the name. Nested Map is not supported!");
                getOrCreateGroup(groups, namedGroups, parts[0], InputKind.DICT)
                        .entries
                        .add(new Pair<>(parts[1], handle));
            } else if (name != null && isNameList(name)) {
                getOrCreateGroup(groups, namedGroups, name, InputKind.LIST)
                        .entries
                        .add(new Pair<>(name, handle));
            } else {
                InputGroup single = new InputGroup(InputKind.SINGLE);
                single.entries.add(new Pair<>(null, handle));
                groups.add(single);
            }
        }

        long[] iValueHandles = new long[groups.size()];
        for (int i = 0; i < groups.size(); i++) {
            InputGroup group = groups.get(i);
            long[] tensorHandles = toPrimitiveLongArray(group.entries.valueArray(new Long[0]));
            switch (group.kind) {
                case LIST:
                    iValueHandles[i] = iValueFromList(tensorHandles);
                    break;
                case DICT:
                    iValueHandles[i] =
                            iValueFromDict(tensorHandles, group.entries.keyArray(new String[0]));
                    break;
                default:
                    iValueHandles[i] = toIValuePointer(tensorHandles[0]);
                    break;
            }
        }
        return iValueHandles;
    }

    /**
     * Returns the named group, creating and registering it (in first-seen order) if absent.
     *
     * @throws IllegalArgumentException if the name is already used by a group of another kind
     */
    private static InputGroup getOrCreateGroup(
            List<InputGroup> groups,
            Map<String, InputGroup> namedGroups,
            String name,
            InputKind kind) {
        InputGroup group = namedGroups.get(name);
        if (group == null) {
            group = new InputGroup(kind);
            namedGroups.put(name, group);
            groups.add(group);
        } else {
            Preconditions.checkArgument(
                    group.kind == kind,
                    "Input name '" + name + "' is used for both a list and a dict input.");
        }
        return group;
    }

    /** Unboxes {@code lArr}; returns {@code null} for a {@code null} argument. */
    private static long[] toPrimitiveLongArray(Long[] lArr) {
        if (lArr == null) {
            return null;
        }
        return Stream.of(lArr).mapToLong(Long::longValue).toArray();
    }
}
