package org.datavec.arrow.recordreader;

import java.io.DataInputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import org.apache.commons.io.IOUtils;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.Record;
import org.datavec.api.records.listener.RecordListener;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.metadata.RecordMetaDataIndex;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.datavec.arrow.ArrowConverter;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/datavec/arrow/recordreader/ArrowRecordReader.class */
/**
 * {@link RecordReader} that reads records from Arrow files.
 *
 * <p>Each location produced by the {@link InputSplit} is read fully into memory and decoded with
 * {@link ArrowConverter#readFromBytes} into one {@link ArrowWritableRecordBatch}. Records are then
 * served one at a time from that batch. {@link #next(int)} may instead return a whole untouched
 * batch at once. The {@link Schema} of the first file loaded is retained for the reader's lifetime.
 *
 * <p>All iteration state lives in unsynchronized instance fields; an instance should not be shared
 * between threads without external synchronization.
 */
public class ArrowRecordReader implements RecordReader {
    private InputSplit split;
    private Configuration configuration;
    /** Locations of {@link #split} that have not been loaded yet; {@code null} until initialized. */
    private Iterator<String> pathsIter;
    /** Index within {@link #currentBatch} of the next record {@link #next()} will return. */
    private int currIdx;
    /** Location {@link #currentBatch} was loaded from; used to build record URIs. */
    private String currentPath;
    /** Schema of the first batch loaded; later batches do not overwrite it. */
    private Schema schema;
    /** The record most recently returned by {@link #next()}. */
    private List<Writable> recordAllocation = new ArrayList<>();
    private ArrowWritableRecordBatch currentBatch;
    private List<RecordListener> recordListeners;

    /**
     * Starts reading from {@code inputSplit}, discarding any batch left over from a previous split.
     *
     * @param inputSplit the locations of the Arrow files to read
     */
    public void initialize(InputSplit inputSplit) {
        this.split = inputSplit;
        this.pathsIter = inputSplit.locationsPathIterator();
        clearBatchState();
    }

    /**
     * Stores {@code configuration} (see {@link #getConf()}) and starts reading from {@code inputSplit}.
     *
     * @param configuration configuration for this reader
     * @param inputSplit the locations of the Arrow files to read
     */
    public void initialize(Configuration configuration, InputSplit inputSplit) {
        this.configuration = configuration;
        initialize(inputSplit);
    }

    public boolean batchesSupported() {
        return true;
    }

    /**
     * Returns up to {@code i} records.
     *
     * <p>If the current batch is untouched and contains exactly {@code i} records, the batch object
     * itself is returned, without copying. Otherwise records are collected one by one via
     * {@link #next()}, possibly spanning several files, until {@code i} records are gathered or the
     * input is exhausted.
     *
     * @param i maximum number of records to return
     * @return the records read
     * @throws NoSuchElementException if no records remain
     */
    public List<List<Writable>> next(int i) {
        ensureBatchAvailable();
        // Only hand out the whole batch when none of its records has been consumed yet.
        if (this.currIdx == 0 && i == this.currentBatch.getArrowRecordBatch().getLength()) {
            this.currIdx += i;
            return this.currentBatch;
        }
        List<List<Writable>> records = new ArrayList<>(Math.max(i, 0));
        while (records.size() < i && hasNext()) {
            records.add(next());
        }
        return records;
    }

    /**
     * Returns the next record, loading the next Arrow file when the current batch is exhausted.
     *
     * @return the next record
     * @throws NoSuchElementException if no records remain
     * @throws IllegalStateException if the reader is not initialized or a file cannot be read
     */
    public List<Writable> next() {
        ensureBatchAvailable();
        this.recordAllocation = this.currentBatch.m3get(this.currIdx++);
        return this.recordAllocation;
    }

    /**
     * Loads files until the current batch has at least one unread record. Files that decode to
     * empty batches are skipped.
     */
    private void ensureBatchAvailable() {
        while (this.currentBatch == null || this.currIdx >= this.currentBatch.size()) {
            loadNextBatch();
        }
    }

    /**
     * Reads the next location of the split and makes its batch current, positioned at record 0.
     *
     * @throws IllegalStateException if the reader is not initialized or the file cannot be read
     * @throws NoSuchElementException if the split has no more locations
     */
    private void loadNextBatch() {
        if (this.pathsIter == null) {
            throw new IllegalStateException("ArrowRecordReader has not been initialized");
        }
        if (!this.pathsIter.hasNext()) {
            throw new NoSuchElementException("No more Arrow files to read");
        }
        String path = this.pathsIter.next();
        Pair<Schema, ArrowWritableRecordBatch> read;
        try (InputStream in = this.split.openInputStreamFor(path)) {
            read = ArrowConverter.readFromBytes(IOUtils.toByteArray(in));
        } catch (Exception e) {
            // Fail loudly: continuing would silently serve records from a stale batch.
            throw new IllegalStateException("Unable to read Arrow batch from " + path, e);
        }
        if (this.schema == null) {
            this.schema = read.getFirst();
        }
        // NOTE(review): the previous batch is not closed here because callers may still hold it
        // (next(int) returns it directly and ArrowRecord references it).
        this.currentBatch = read.getRight();
        this.currIdx = 0;
        this.currentPath = path;
    }

    /**
     * Returns whether another record is available, either in the current batch or in a
     * not-yet-loaded file. Returns {@code false} if the reader was never initialized.
     */
    public boolean hasNext() {
        if (this.currentBatch != null && this.currIdx < this.currentBatch.size()) {
            return true;
        }
        return this.pathsIter != null && this.pathsIter.hasNext();
    }

    public List<String> getLabels() {
        throw new UnsupportedOperationException();
    }

    /** Resets the split and restarts iteration from its first location. */
    public void reset() {
        if (this.split != null) {
            this.split.reset();
            this.pathsIter = this.split.locationsPathIterator();
        }
        clearBatchState();
    }

    /** Forgets the current batch and position without closing the batch. */
    private void clearBatchState() {
        this.currentBatch = null;
        this.currentPath = null;
        this.currIdx = 0;
    }

    public boolean resetSupported() {
        return true;
    }

    public List<Writable> record(URI uri, DataInputStream dataInputStream) {
        throw new UnsupportedOperationException();
    }

    /**
     * Reads the next record and wraps it in an {@link ArrowRecord} that references the current
     * batch, the record's index within it, and the file it came from.
     */
    public Record nextRecord() {
        next();
        return new ArrowRecord(this.currentBatch, this.currIdx - 1, URI.create(this.currentPath));
    }

    /**
     * Loads the single record described by {@code recordMetaData}.
     *
     * <p>Side effect: the reader is re-initialized onto the record's file.
     *
     * @param recordMetaData must be a {@link RecordMetaDataIndex}
     * @return the record at the indexed position in the referenced file
     * @throws IllegalArgumentException if the metadata carries no index or the index is out of range
     */
    public Record loadFromMetaData(RecordMetaData recordMetaData) {
        if (!(recordMetaData instanceof RecordMetaDataIndex)) {
            throw new IllegalArgumentException("Unable to load from meta data. No index specified for record");
        }
        RecordMetaDataIndex indexed = (RecordMetaDataIndex) recordMetaData;
        initialize(new FileSplit(new File(indexed.getURI())));
        loadNextBatch();
        return recordAt(indexed.getIndex());
    }

    /**
     * Loads every record described by {@code list}. Each file is read once.
     *
     * <p>Results are grouped by file, in order of each file's first appearance in {@code list}.
     * Within a file they follow the order in {@code list}. Side effect: the reader is left
     * initialized onto the last file visited.
     *
     * @param list metadata entries, each of which must be a {@link RecordMetaDataIndex}
     * @return the loaded records
     * @throws IllegalArgumentException if any entry carries no index or an index is out of range
     */
    public List<Record> loadFromMetaData(List<RecordMetaData> list) {
        Map<String, List<RecordMetaDataIndex>> byUri = new LinkedHashMap<>();
        for (RecordMetaData recordMetaData : list) {
            if (!(recordMetaData instanceof RecordMetaDataIndex)) {
                throw new IllegalArgumentException("Unable to load from meta data. No index specified for record");
            }
            RecordMetaDataIndex indexed = (RecordMetaDataIndex) recordMetaData;
            byUri.computeIfAbsent(indexed.getURI().toString(), k -> new ArrayList<>()).add(indexed);
        }
        List<Record> records = new ArrayList<>(list.size());
        for (Map.Entry<String, List<RecordMetaDataIndex>> entry : byUri.entrySet()) {
            initialize(new FileSplit(new File(URI.create(entry.getKey()))));
            loadNextBatch();
            for (RecordMetaDataIndex indexed : entry.getValue()) {
                records.add(recordAt(indexed.getIndex()));
            }
        }
        return records;
    }

    /** Returns the record at {@code index} of the already-loaded current batch. */
    private Record recordAt(long index) {
        int size = this.currentBatch.size();
        if (index < 0 || index >= size) {
            throw new IllegalArgumentException(
                    "Record index " + index + " out of range for " + this.currentPath + " (size " + size + ")");
        }
        this.currIdx = (int) index;
        return nextRecord();
    }

    public List<RecordListener> getListeners() {
        return this.recordListeners;
    }

    public void setListeners(RecordListener... recordListenerArr) {
        this.recordListeners = new ArrayList<>(Arrays.asList(recordListenerArr));
    }

    public void setListeners(Collection<RecordListener> collection) {
        this.recordListeners = new ArrayList<>(collection);
    }

    /** Closes the current batch, if any, and drops the reference to it. */
    public void close() {
        if (this.currentBatch != null) {
            try {
                this.currentBatch.close();
            } catch (IOException e) {
                // Best-effort close: this signature cannot propagate IOException.
                e.printStackTrace();
            } finally {
                this.currentBatch = null;
            }
        }
    }

    public void setConf(Configuration configuration) {
        this.configuration = configuration;
    }

    public Configuration getConf() {
        return this.configuration;
    }

    /** Returns the batch records are currently served from, or {@code null} if none is loaded. */
    public ArrowWritableRecordBatch getCurrentBatch() {
        return this.currentBatch;
    }
}
