package org.nd4j.arrow;

import com.google.flatbuffers.FlatBufferBuilder;
import org.apache.arrow.flatbuf.Buffer;
import org.apache.arrow.flatbuf.Tensor;
import org.apache.arrow.flatbuf.TensorDim;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/nd4j/arrow/ArrowSerde.class */
/**
 * Converts between ND4J {@link INDArray}s and Apache Arrow {@link Tensor} flatbuffers.
 *
 * <p>Arrow {@code Type} union tags written by {@link #addTypeTypeRelativeToNDArray} and read back
 * by {@link #typeFromTensorType}:
 * <ul>
 *   <li>{@link DataType#INT} and {@link DataType#LONG} are both written as {@code Int} (2); on read
 *       they are distinguished by the element width derived from the data buffer length.</li>
 *   <li>{@link DataType#FLOAT} is written as {@code FloatingPoint} (3).</li>
 *   <li>{@link DataType#DOUBLE} is written as tag 7 ({@code Decimal}); the reader maps 7 back to
 *       DOUBLE, so writer and reader must stay in sync.</li>
 * </ul>
 * Any other data type is rejected.
 */
public class ArrowSerde {

    /** Arrow {@code Type.Int} union tag; used for INT and LONG arrays. */
    private static final byte ARROW_TYPE_INT = 2;

    /** Arrow {@code Type.FloatingPoint} union tag; used for FLOAT arrays. */
    private static final byte ARROW_TYPE_FLOATING_POINT = 3;

    /** Arrow {@code Type.Decimal} union tag; this class uses it to mark DOUBLE arrays. */
    private static final byte ARROW_TYPE_DECIMAL = 7;

    /**
     * Deserializes an Arrow tensor into an {@link INDArray} backed by the tensor's byte buffer.
     *
     * @param tensor the Arrow tensor to read
     * @return an array with the tensor's shape and (element-converted) strides
     * @throws ND4JIllegalStateException if the tensor has no data buffer, has no elements, has a data
     *     buffer too small for its element count, or has a stride count that does not match its rank
     * @throws IllegalArgumentException if the tensor's type tag / element width is unsupported
     */
    public static INDArray fromTensor(Tensor tensor) {
        byte typeType = tensor.typeType();
        int rank = tensor.shapeLength();
        int[] shape = new int[rank];
        for (int i = 0; i < rank; i++) {
            shape[i] = (int) tensor.shape(i).size();
        }

        Buffer data = tensor.data();
        if (data == null) {
            throw new ND4JIllegalStateException("Buffer was not serialized properly.");
        }

        int numElements = ArrayUtil.prod(shape);
        if (numElements <= 0) {
            throw new ND4JIllegalStateException(
                    "Unable to deserialize tensor with " + numElements + " elements: element size cannot be derived.");
        }
        // The Int type tag does not say whether values are 4 or 8 bytes wide, so the element
        // width is derived from the serialized byte length.
        int elementSize = ((int) data.length()) / numElements;
        if (elementSize <= 0) {
            throw new ND4JIllegalStateException("Buffer length " + data.length()
                    + " is too small for " + numElements + " elements.");
        }

        int[] strides = elementStrides(tensor, shape, elementSize);

        INDArray result = Nd4j.create(
                DataBufferStruct.createFromByteBuffer(tensor.getByteBuffer(), (int) data.offset(),
                        typeFromTensorType(typeType, elementSize), numElements),
                shape);
        result.setShapeAndStride(shape, strides);
        return result;
    }

    /**
     * Reads the tensor's byte strides and converts them to element strides.
     * When the tensor carries no strides, row-major (C-order) strides are computed from the shape.
     */
    private static int[] elementStrides(Tensor tensor, int[] shape, int elementSize) {
        int count = tensor.stridesLength();
        int[] strides = new int[shape.length];
        if (count == 0) {
            // Arrow allows strides to be omitted, meaning row-major layout.
            int running = 1;
            for (int i = shape.length - 1; i >= 0; i--) {
                strides[i] = running;
                running *= shape[i];
            }
            return strides;
        }
        if (count != shape.length) {
            throw new ND4JIllegalStateException(
                    "Tensor has " + count + " strides but rank " + shape.length + ".");
        }
        for (int i = 0; i < count; i++) {
            strides[i] = ((int) tensor.strides(i)) / elementSize;
        }
        return strides;
    }

    /**
     * Serializes an {@link INDArray} into a standalone Arrow tensor flatbuffer.
     *
     * @param iNDArray the array to serialize; views are copied first
     * @return a tensor rooted in a freshly built flatbuffer
     * @throws IllegalArgumentException if the array's data type is not INT, LONG, FLOAT or DOUBLE
     */
    public static Tensor toTensor(INDArray iNDArray) {
        // addDataForArr serializes dup().data() for views; copy once here so the shape and strides
        // written below describe the same buffer that is actually serialized.
        INDArray source = iNDArray.isView() ? iNDArray.dup() : iNDArray;

        FlatBufferBuilder builder = new FlatBufferBuilder(1024);
        long[] arrowStrides = getArrowStrides(source);
        // Vectors must be created before the Tensor table is started.
        int shapeOffset = createDims(builder, source);
        int stridesOffset = Tensor.createStridesVector(builder, arrowStrides);

        Tensor.startTensor(builder);
        addTypeTypeRelativeToNDArray(builder, source);
        Tensor.addShape(builder, shapeOffset);
        Tensor.addStrides(builder, stridesOffset);
        Tensor.addData(builder, addDataForArr(builder, source));
        Tensor.finishTensorBuffer(builder, Tensor.endTensor(builder));
        return Tensor.getRootAsTensor(builder.dataBuffer());
    }

    /**
     * Writes the array's data buffer and an Arrow {@code Buffer} struct describing it.
     *
     * @param flatBufferBuilder the builder to write into
     * @param iNDArray the array whose data is written (views are duplicated first)
     * @return the offset of the created {@code Buffer} struct
     */
    public static int addDataForArr(FlatBufferBuilder flatBufferBuilder, INDArray iNDArray) {
        DataBuffer toAdd = iNDArray.isView() ? iNDArray.dup().data() : iNDArray.data();
        int structOffset = DataBufferStruct.createDataBufferStruct(flatBufferBuilder, toAdd);
        return Buffer.createBuffer(flatBufferBuilder, structOffset, toAdd.length() * toAdd.getElementSize());
    }

    /**
     * Adds the Arrow type tag matching the array's data type to the tensor currently being built.
     *
     * @throws IllegalArgumentException if the data type is not INT, LONG, FLOAT or DOUBLE
     */
    public static void addTypeTypeRelativeToNDArray(FlatBufferBuilder flatBufferBuilder, INDArray iNDArray) {
        DataType dataType = iNDArray.data().dataType();
        switch (dataType) {
            case LONG:
            case INT:
                Tensor.addTypeType(flatBufferBuilder, ARROW_TYPE_INT);
                break;
            case FLOAT:
                Tensor.addTypeType(flatBufferBuilder, ARROW_TYPE_FLOATING_POINT);
                break;
            case DOUBLE:
                Tensor.addTypeType(flatBufferBuilder, ARROW_TYPE_DECIMAL);
                break;
            default:
                // Writing no tag would yield a tensor that fromTensor cannot read back.
                throw new IllegalArgumentException("Unsupported data type for Arrow serialization: " + dataType);
        }
    }

    /**
     * Writes one {@code TensorDim} per dimension (with an empty name) and returns the shape vector.
     *
     * @return the offset of the created shape vector
     */
    public static int createDims(FlatBufferBuilder flatBufferBuilder, INDArray iNDArray) {
        int rank = iNDArray.rank();
        int[] dimOffsets = new int[rank];
        for (int i = 0; i < rank; i++) {
            int nameOffset = flatBufferBuilder.createString("");
            dimOffsets[i] = TensorDim.createTensorDim(flatBufferBuilder, iNDArray.size(i), nameOffset);
        }
        return Tensor.createShapeVector(flatBufferBuilder, dimOffsets);
    }

    /**
     * Converts the array's element strides into byte strides.
     *
     * @return one byte stride per dimension
     */
    public static long[] getArrowStrides(INDArray iNDArray) {
        int rank = iNDArray.rank();
        int elementSize = iNDArray.data().getElementSize();
        long[] strides = new long[rank];
        for (int i = 0; i < rank; i++) {
            strides[i] = iNDArray.stride(i) * elementSize;
        }
        return strides;
    }

    /**
     * Maps an Arrow type tag plus element width back to an ND4J data type.
     *
     * @param b the Arrow {@code Type} union tag
     * @param i the element width in bytes (only consulted for the Int tag)
     * @return the matching ND4J data type
     * @throws IllegalArgumentException if the tag or the Int element width is unsupported
     */
    public static DataType typeFromTensorType(byte b, int i) {
        switch (b) {
            case ARROW_TYPE_FLOATING_POINT:
                return DataType.FLOAT;
            case ARROW_TYPE_DECIMAL:
                return DataType.DOUBLE;
            case ARROW_TYPE_INT:
                if (i == 4) {
                    return DataType.INT;
                }
                if (i == 8) {
                    return DataType.LONG;
                }
                throw new IllegalArgumentException("Unable to determine data type for Int tensor with element size " + i);
            default:
                throw new IllegalArgumentException(
                        "Only valid types are Type.FloatingPoint, Type.Decimal and Type.Int, got type tag " + b);
        }
    }
}
