package org.deeplearning4j.spark.datavec;

import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.Text;
import org.apache.spark.api.java.function.PairFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;
import scala.Tuple2;

/* loaded from: input_file:org/deeplearning4j/spark/datavec/DataVecByteDataSetFunction.class */
public class DataVecByteDataSetFunction implements PairFunction<Tuple2<Text, BytesWritable>, Double, DataSet> {
    private int labelIndex;
    private int numPossibleLabels;
    private int byteFileLen;
    private int batchSize;
    private int numExamples;
    private boolean regression;
    private DataSetPreProcessor preProcessor;

    public DataVecByteDataSetFunction(int i, int i2, int i3, int i4) {
        this(i, i2, i3, i4, false, null);
    }

    public DataVecByteDataSetFunction(int i, int i2, int i3, int i4, boolean z) {
        this(i, i2, i3, i4, z, null);
    }

    public DataVecByteDataSetFunction(int i, int i2, int i3, int i4, boolean z, DataSetPreProcessor dataSetPreProcessor) {
        this.labelIndex = 0;
        this.regression = false;
        this.labelIndex = i;
        this.numPossibleLabels = i2;
        this.batchSize = i3;
        this.byteFileLen = i4;
        this.regression = z;
        this.preProcessor = dataSetPreProcessor;
    }

    public Tuple2<Double, DataSet> call(Tuple2<Text, BytesWritable> tuple2) throws Exception {
        int i = 0;
        if (this.numPossibleLabels >= 1) {
            i = this.byteFileLen - 1;
            if (this.labelIndex < 0) {
                this.labelIndex = this.byteFileLen - 1;
            }
        }
        DataInputStream dataInputStream = new DataInputStream(new ByteArrayInputStream(((BytesWritable) tuple2._2()).getBytes()));
        int i2 = 0;
        byte[] bArr = new byte[this.byteFileLen];
        ArrayList<DataSet> arrayList = new ArrayList();
        try {
            INDArray create = Nd4j.create(i);
            while (dataInputStream.read(bArr) != -1 && i2 != this.batchSize) {
                int i3 = 0;
                INDArray outcomeVector = FeatureUtil.toOutcomeVector(bArr[this.labelIndex], this.numPossibleLabels);
                for (int i4 = 1; i4 <= create.length(); i4++) {
                    int i5 = i3;
                    i3++;
                    create.putScalar(i5, bArr[i4]);
                }
                arrayList.add(new DataSet(create, outcomeVector));
                i2++;
                bArr = new byte[this.byteFileLen];
                create = Nd4j.create(i);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        for (DataSet dataSet : arrayList) {
            arrayList2.add(dataSet.getFeatures());
            arrayList3.add(dataSet.getLabels());
        }
        DataSet dataSet2 = new DataSet(Nd4j.vstack((INDArray[]) arrayList2.toArray(new INDArray[0])), Nd4j.vstack((INDArray[]) arrayList3.toArray(new INDArray[0])));
        if (this.preProcessor != null) {
            this.preProcessor.preProcess(dataSet2);
        }
        return new Tuple2<>(Double.valueOf(i2), dataSet2);
    }
}
