/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.datavec;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.api.java.function.Function;
import org.datavec.api.io.WritableConverter;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.StringSplit;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

public class RecordReaderFunction
implements Function<String, DataSet> {
    private RecordReader recordReader;
    private int labelIndex = -1;
    private int numPossibleLabels = -1;
    private WritableConverter converter;

    public RecordReaderFunction(RecordReader recordReader, int labelIndex, int numPossibleLabels, WritableConverter converter) {
        this.recordReader = recordReader;
        this.labelIndex = labelIndex;
        this.numPossibleLabels = numPossibleLabels;
        this.converter = converter;
    }

    public RecordReaderFunction(RecordReader recordReader, int labelIndex, int numPossibleLabels) {
        this(recordReader, labelIndex, numPossibleLabels, null);
    }

    public DataSet call(String v1) throws Exception {
        this.recordReader.initialize((InputSplit)new StringSplit(v1));
        ArrayList<DataSet> dataSets = new ArrayList<DataSet>();
        List currList = this.recordReader.next();
        INDArray label = null;
        INDArray featureVector = Nd4j.create((int)1, (int)(this.labelIndex >= 0 ? currList.size() - 1 : currList.size()));
        int count = 0;
        for (int j = 0; j < currList.size(); ++j) {
            Writable current;
            if (this.labelIndex >= 0 && j == this.labelIndex) {
                if (this.numPossibleLabels < 1) {
                    throw new IllegalStateException("Number of possible labels invalid, must be >= 1");
                }
                current = (Writable)currList.get(j);
                if (this.converter != null) {
                    current = this.converter.convert(current);
                }
                label = FeatureUtil.toOutcomeVector((long)current.toInt(), (long)this.numPossibleLabels);
                continue;
            }
            current = (Writable)currList.get(j);
            featureVector.putScalar((long)count++, current.toDouble());
        }
        dataSets.add(new DataSet(featureVector, this.labelIndex >= 0 ? label : featureVector));
        ArrayList<INDArray> inputs = new ArrayList<INDArray>();
        ArrayList<INDArray> labels = new ArrayList<INDArray>();
        for (DataSet data : dataSets) {
            inputs.add(data.getFeatures());
            labels.add(data.getLabels());
        }
        DataSet ret = new DataSet(Nd4j.vstack((INDArray[])inputs.toArray(new INDArray[inputs.size()])), Nd4j.vstack((INDArray[])labels.toArray(new INDArray[inputs.size()])));
        return ret;
    }
}

