package org.openimaj.ml.linear.data;

import com.jmatio.io.MatFileReader;
import com.jmatio.types.MLArray;
import com.jmatio.types.MLCell;
import com.jmatio.types.MLChar;
import com.jmatio.types.MLDouble;
import com.jmatio.types.MLSparse;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.openimaj.util.filter.FilterUtils;
import org.openimaj.util.function.Predicate;
import org.openimaj.util.pair.Pair;

/* loaded from: input_file:org/openimaj/ml/linear/data/BillMatlabFileDataGenerator.class */
public class BillMatlabFileDataGenerator implements MatrixDataGenerator<Matrix> {
    // Named arrays read from the .mat file currently being processed; the prepare*() methods
    // pull their inputs from here. Some constructors reset this to null once loading is done.
    private Map<String, MLArray> content;
    // Train/test/validation splits, either read from the "set_fold" entry or handed to a
    // constructor. Stays null when the file has no "set_fold" entry.
    private List<Fold> folds;
    // Number of days, as passed to the constructor; one word matrix and one poll matrix is built per day.
    private int ndays;
    // Row count of the main matrix divided by ndays (integer division).
    private int nusers;
    // Vocabulary size when a vocabulary was loaded, otherwise the column count of the main matrix.
    private int nwords;
    // One sparse (nwords x nusers) matrix per day.
    private List<Matrix> dayWords;
    // One sparse (1 x ntasks) matrix per day.
    private List<Matrix> dayPolls;
    // Position inside `indexes` of the next day to hand out.
    private int currentIndex;
    // Number of content keys ending in "per_unique_extended".
    private int ntasks;
    // Day indices selected by setFold(); null until setFold() has been called.
    private int[] indexes;
    // Compact word index -> term string; only created when the file has a "voc" entry.
    private Map<Integer, String> voc;
    // The content keys ending in "per_unique_extended", in the order they were found.
    private String[] tasks;
    // Values of the "voc_keep_terms_index" entry, each reduced by one
    // (presumably MATLAB 1-based -> 0-based; confirm against the data files).
    private Set<Integer> keepIndex;
    // Index of a term in the "voc" cell array -> its compact index in `voc`;
    // only created when the file has a "voc" entry.
    private Map<Integer, Integer> indexToVoc;
    // Constructor flag: when true, words outside keepIndex are dropped.
    private boolean filter;
    // Key of the sparse matrix to load; "user_vsr_for_polls" unless a constructor is given another key.
    String mainMatrixKey;

    /* loaded from: input_file:org/openimaj/ml/linear/data/BillMatlabFileDataGenerator$Fold.class */
    /**
     * One split of the data: three arrays of indices used for training, testing and validation.
     * The arrays are stored as given, without copying.
     */
    public static class Fold {
        /** Indices used for training. */
        int[] training;
        /** Indices used for testing. */
        int[] test;
        /** Indices used for validation. */
        int[] validation;

        /**
         * Creates a fold from three index arrays.
         *
         * @param iArr  the training indices
         * @param iArr2 the test indices
         * @param iArr3 the validation indices
         */
        public Fold(int[] iArr, int[] iArr2, int[] iArr3) {
            validation = iArr3;
            test = iArr2;
            training = iArr;
        }
    }

    /* loaded from: input_file:org/openimaj/ml/linear/data/BillMatlabFileDataGenerator$Mode.class */
    /**
     * Chooses which of a {@link Fold}'s index arrays is used.
     */
    public enum Mode {
        /** Selects the fold's training indices. */
        TRAINING {
            @Override
            public int[] indexes(Fold fold) {
                return fold.training;
            }
        },
        /** Selects the fold's test indices. */
        TEST {
            @Override
            public int[] indexes(Fold fold) {
                return fold.test;
            }
        },
        /** Selects the fold's validation indices. */
        VALIDATION {
            @Override
            public int[] indexes(Fold fold) {
                return fold.validation;
            }
        },
        /** Has no per-fold index array: always yields {@code null} and never reads the fold. */
        ALL {
            @Override
            public int[] indexes(Fold fold) {
                return null;
            }
        };

        /**
         * Picks this mode's index array out of a fold.
         *
         * @param fold the fold to read from
         * @return the matching index array of {@code fold}, or {@code null} for {@link #ALL}
         */
        public abstract int[] indexes(Fold fold);
    }

    /**
     * Loads vocabulary, folds, per-day word matrices and per-day poll matrices from a single
     * .mat file, reading the word matrix from the default key {@code "user_vsr_for_polls"}.
     * <p>
     * The parsed file content is released once everything has been extracted, as the other
     * constructors do; nothing in this class reads it after construction.
     *
     * @param file the .mat file holding both the word matrix and the poll entries
     * @param i    the number of days the data is split into
     * @param z    when {@code true}, only words listed in the file's keep-index are retained
     * @throws IOException if the file cannot be read
     */
    public BillMatlabFileDataGenerator(File file, int i, boolean z) throws IOException {
        this.mainMatrixKey = "user_vsr_for_polls";
        MatFileReader matFileReader = new MatFileReader(file);
        this.ndays = i;
        this.content = matFileReader.getContent();
        this.currentIndex = 0;
        this.filter = z;
        prepareVocabulary();
        prepareFolds();
        prepareDayUserWords();
        prepareDayPolls();
        // Drop the reference to the raw file content so it can be garbage collected.
        this.content = null;
    }

    /**
     * Loads vocabulary, folds and per-day word matrices from one .mat file and the per-day poll
     * matrices from a second one. The parsed content is released when loading has finished.
     *
     * @param file  the .mat file holding the vocabulary, folds and word matrix
     * @param str   the key of the sparse word matrix inside {@code file}
     * @param file2 the .mat file holding the poll entries
     * @param i     the number of days the data is split into
     * @param z     when {@code true}, only words listed in the file's keep-index are retained
     * @throws IOException if either file cannot be read
     */
    public BillMatlabFileDataGenerator(File file, String str, File file2, int i, boolean z) throws IOException {
        final Map<String, MLArray> wordContent = new MatFileReader(file).getContent();
        this.mainMatrixKey = str;
        this.ndays = i;
        this.filter = z;
        this.currentIndex = 0;

        // First pass: everything that lives in the word file.
        this.content = wordContent;
        prepareVocabulary();
        prepareFolds();
        prepareDayUserWords();

        // Second pass: the polls come from the other file.
        this.content = new MatFileReader(file2).getContent();
        prepareDayPolls();

        this.content = null;
    }

    /**
     * Same as the two-file constructor, except that the folds are supplied by the caller instead
     * of being read from the word file. The list is stored as given.
     *
     * @param file  the .mat file holding the vocabulary and word matrix
     * @param str   the key of the sparse word matrix inside {@code file}
     * @param file2 the .mat file holding the poll entries
     * @param i     the number of days the data is split into
     * @param z     when {@code true}, only words listed in the file's keep-index are retained
     * @param list  the folds to use
     * @throws IOException if either file cannot be read
     */
    public BillMatlabFileDataGenerator(File file, String str, File file2, int i, boolean z, List<Fold> list) throws IOException {
        final Map<String, MLArray> wordContent = new MatFileReader(file).getContent();
        this.mainMatrixKey = str;
        this.ndays = i;
        this.filter = z;
        this.currentIndex = 0;

        // First pass: vocabulary and words from the word file; folds come from the caller.
        this.content = wordContent;
        prepareVocabulary();
        this.folds = list;
        prepareDayUserWords();

        // Second pass: the polls come from the other file.
        this.content = new MatFileReader(file2).getContent();
        prepareDayPolls();

        this.content = null;
    }

    /**
     * Returns the vocabulary map held by this generator (the internal instance, not a copy).
     *
     * @return the map from compact word index to term, or {@code null} if the loaded file had
     *         no "voc" entry
     */
    public Map<Integer, String> getVocabulary() {
        return voc;
    }

    /**
     * Reads the keep-index ("voc_keep_terms_index") and the vocabulary ("voc") from the current
     * content, and builds the compact vocabulary plus the mapping from original term index to
     * compact index.
     * <p>
     * A term is kept when filtering is off, or when its index is in the keep-index. This is the
     * same selection rule {@code prepareDayUserWords()} applies to matrix columns, so every
     * column that method keeps has an entry in {@code indexToVoc}. Previously terms were only
     * recorded when filtering was on, which left the mapping empty for unfiltered loads and made
     * the column lookup fail with a NullPointerException.
     * <p>
     * Both entries are optional: a missing keep-index leaves the set empty, and a missing
     * vocabulary leaves {@code voc} and {@code indexToVoc} null.
     */
    private void prepareVocabulary() {
        this.keepIndex = new HashSet<>();
        final MLDouble keepTerms = (MLDouble) this.content.get("voc_keep_terms_index");
        if (keepTerms != null) {
            // Only the first row of the array is read; each stored value is shifted down by one
            // (presumably MATLAB 1-based -> 0-based; confirm against the data files).
            for (final double stored : keepTerms.getArray()[0]) {
                this.keepIndex.add((int) stored - 1);
            }
        }

        final MLCell vocCells = (MLCell) this.content.get("voc");
        if (vocCells == null) {
            return;
        }
        this.indexToVoc = new HashMap<>();
        this.voc = new HashMap<>();
        int originalIndex = 0;
        int compactIndex = 0;
        for (final MLArray cell : vocCells.cells()) {
            final String term = ((MLChar) cell).getString(0);
            if (!this.filter || this.keepIndex.contains(originalIndex)) {
                this.voc.put(compactIndex, term);
                this.indexToVoc.put(originalIndex, compactIndex);
                compactIndex++;
            }
            originalIndex++;
        }
    }

    /**
     * Selects which days the generator will yield from now on and rewinds it to the first one.
     * <p>
     * Every loaded day is selected, in order, when {@code i} is {@code -1} or when {@code mode}
     * is {@link Mode#ALL}. {@code Mode.ALL} has no per-fold index array (its {@code indexes}
     * returns {@code null}), so it previously left the generator with a null index array that
     * failed on the next call to generate. For any other mode the matching index array of fold
     * {@code i} is used as-is.
     *
     * @param i    the fold number, or {@code -1} for all days
     * @param mode which part of the fold to iterate over
     * @throws IllegalStateException     if a fold is requested but no folds were loaded or supplied
     * @throws IndexOutOfBoundsException if {@code i} is not {@code -1} and not a valid fold number
     */
    public void setFold(int i, Mode mode) {
        if (i == -1 || mode == Mode.ALL) {
            final int[] everyDay = new int[this.dayWords.size()];
            for (int day = 0; day < everyDay.length; day++) {
                everyDay[day] = day;
            }
            this.indexes = everyDay;
        } else {
            if (this.folds == null) {
                throw new IllegalStateException(
                        "No folds are available (fold " + i + " requested); use fold -1 to iterate over all days");
            }
            this.indexes = mode.indexes(this.folds.get(i));
        }
        this.currentIndex = 0;
    }

    /**
     * Builds one sparse (1 x ntasks) poll matrix per day from the current content.
     * <p>
     * Every content key ending in "per_unique_extended" is treated as one task. For task
     * {@code t} and day {@code d}, element (d, 0) of that task's array is written to column
     * {@code t} of day {@code d}'s matrix. The task keys are also recorded in {@code tasks}.
     */
    private void prepareDayPolls() {
        final List<String> pollKeys = FilterUtils.filter(this.content.keySet(), new Predicate<String>() {
            @Override
            public boolean test(String key) {
                return key.endsWith("per_unique_extended");
            }
        });
        this.ntasks = pollKeys.size();

        // One empty row vector per day, filled in task by task below.
        this.dayPolls = new ArrayList<>();
        for (int day = 0; day < this.ndays; day++) {
            this.dayPolls.add(SparseMatrixFactoryMTJ.INSTANCE.createMatrix(1, this.ntasks));
        }

        this.tasks = new String[this.ntasks];
        int task = 0;
        for (final String key : pollKeys) {
            this.tasks[task] = key;
            final MLDouble poll = (MLDouble) this.content.get(key);
            for (int day = 0; day < this.ndays; day++) {
                this.dayPolls.get(day).setElement(0, task, poll.get(day, 0));
            }
            task++;
        }
    }

    /**
     * Returns the task names, i.e. the content keys the poll matrices were built from.
     *
     * @return the internal task-name array (not a copy), indexed by poll-matrix column
     */
    public String[] getTasks() {
        return tasks;
    }

    /**
     * Splits the sparse main matrix (looked up under {@code mainMatrixKey}) into one sparse
     * (nwords x nusers) matrix per day.
     * <p>
     * The rows of the main matrix are treated as {@code ndays} consecutive bands of
     * {@code nusers} rows each: entry (row, col) goes to day {@code row / nusers}, at user
     * {@code row - day * nusers}. The column is the word; when filtering is on, words outside
     * the keep-index are skipped, and when a vocabulary mapping exists the column is translated
     * to its compact index first. Note that the per-day matrices are indexed (word, user).
     */
    private void prepareDayUserWords() {
        final MLSparse userWords = (MLSparse) this.content.get(this.mainMatrixKey);
        final Double[] values = userWords.exportReal();
        final int[] rows = userWords.getIR();
        final int[] cols = userWords.getIC();

        this.nwords = (this.voc != null) ? this.voc.size() : userWords.getN();
        this.nusers = userWords.getM() / this.ndays;

        this.dayWords = new ArrayList<>();
        for (int day = 0; day < this.ndays; day++) {
            this.dayWords.add(SparseMatrixFactoryMTJ.INSTANCE.createMatrix(this.nwords, this.nusers));
        }

        for (int k = 0; k < rows.length; k++) {
            final int word = cols[k];
            if (this.filter && !this.keepIndex.contains(word)) {
                continue;
            }
            final int wordRow = (this.indexToVoc == null) ? word : this.indexToVoc.get(word);
            final int day = rows[k] / this.nusers;
            final int user = rows[k] - (day * this.nusers);
            this.dayWords.get(day).setElement(wordRow, user, values[k]);
        }
    }

    /**
     * Reads the folds from the "set_fold" entry of the current content, if there is one.
     * <p>
     * The entry must be a cell array; each of its rows becomes one {@link Fold}, with columns
     * 0, 1 and 2 supplying the training, test and validation indices respectively. When the
     * entry is absent, {@code folds} is left untouched. The number of folds found is printed
     * to standard output.
     *
     * @throws RuntimeException if the entry exists but is not a cell array
     */
    private void prepareFolds() {
        final MLArray rawFolds = this.content.get("set_fold");
        if (rawFolds == null) {
            return;
        }
        if (!rawFolds.isCell()) {
            throw new RuntimeException("Can't find set_folds in expected format");
        }

        final MLCell foldCells = (MLCell) rawFolds;
        final int foldCount = foldCells.getM();
        this.folds = new ArrayList<>();
        System.out.println(String.format("Found %d folds", foldCount));

        for (int row = 0; row < foldCount; row++) {
            final int[] training = toIntArray((MLDouble) foldCells.get(row, 0));
            final int[] test = toIntArray((MLDouble) foldCells.get(row, 1));
            final int[] validation = toIntArray((MLDouble) foldCells.get(row, 2));
            this.folds.add(new Fold(training, test, validation));
        }
    }

    /**
     * Copies the first row of a double array into an {@code int[]}, truncating each value.
     * The values are taken as stored; no index-base adjustment is applied.
     *
     * @param mLDouble the array to read; only row 0 is used
     * @return a new array with one entry per column of {@code mLDouble}
     */
    private int[] toIntArray(MLDouble mLDouble) {
        final int columns = mLDouble.getN();
        final int[] values = new int[columns];
        int col = 0;
        while (col < columns) {
            values[col] = mLDouble.get(0, col).intValue();
            col++;
        }
        return values;
    }

    /**
     * Returns the next (words, polls) pair for the currently selected days and advances the
     * cursor. The decompiler recorded that this method was renamed from {@code generate}.
     *
     * @return the word and poll matrices of the next selected day, or {@code null} once every
     *         selected day has been returned
     */
    @Override
    public Pair<Matrix> mo9generate() {
        final int cursor = this.currentIndex;
        if (cursor < this.indexes.length) {
            final int day = this.indexes[cursor];
            final Pair<Matrix> sample = new Pair<>(this.dayWords.get(day), this.dayPolls.get(day));
            // Advance only after the pair was built successfully.
            this.currentIndex = cursor + 1;
            return sample;
        }
        return null;
    }

    /**
     * Returns how many folds are available.
     * <p>
     * No folds exist when the loaded file had no "set_fold" entry, or when a null list was
     * passed to the constructor; this is reported as zero rather than failing with a
     * NullPointerException.
     *
     * @return the number of folds, or 0 if there are none
     */
    public int nFolds() {
        return this.folds == null ? 0 : this.folds.size();
    }

    /**
     * Drains the generator: collects every remaining pair of the current selection, in order.
     * Afterwards the generator is exhausted until the fold is set again.
     *
     * @return a new list with all pairs that were still to be generated; empty if there were none
     */
    public List<Pair<Matrix>> generateAll() {
        final List<Pair<Matrix>> remaining = new ArrayList<>();
        for (Pair<Matrix> next = mo9generate(); next != null; next = mo9generate()) {
            remaining.add(next);
        }
        return remaining;
    }
}
