package org.openimaj.ml.linear.experiments.sinabill;

import com.google.common.primitives.Doubles;
import com.jmatio.io.MatFileWriter;
import gov.sandia.cognition.math.matrix.Matrix;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.log4j.ConsoleAppender;
import org.apache.log4j.FileAppender;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.log4j.PatternLayout;
import org.openimaj.io.IOUtils;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.data.BillMatlabFileDataGenerator;
import org.openimaj.ml.linear.evaluation.RootMeanSumLossEvaluator;
import org.openimaj.ml.linear.learner.BilinearLearnerParameters;
import org.openimaj.ml.linear.learner.BilinearSparseOnlineLearner;
import org.openimaj.ml.linear.learner.init.SparseZerosInitStrategy;
import org.openimaj.ml.linear.learner.loss.MatSquareLossFunction;
import org.openimaj.util.pair.Pair;

/* loaded from: input_file:org/openimaj/ml/linear/experiments/sinabill/LambdaSearchAustrian.class */
/**
 * Experiment driver that, for each configured fold, runs a line search over
 * {@link BilinearLearnerParameters}, trains a {@link BilinearSparseOnlineLearner}
 * per candidate on the fold's training split, keeps the candidate with the
 * lowest validation RMSE, and finally reports that learner's RMSE on the test split.
 * Every new best learner is written to the experiment output directory.
 */
public class LambdaSearchAustrian {
    private static final int NFOLDS = 1;
    private static final String ROOT = "/Users/ss/Experiments/bilinear/austrian/";
    private static final String OUTPUT_ROOT = "/Users/ss/Dropbox/TrendMiner/Collaboration/StreamingBilinear2014/experiments";
    private final Logger logger = Logger.getLogger(getClass());
    /** Wall-clock start time; used as the per-run output sub-directory name. */
    private final long expStartTime = System.currentTimeMillis();

    public static void main(String[] strArr) throws IOException {
        new LambdaSearchAustrian().performExperiment();
    }

    /**
     * Runs the full experiment: loads the data, sets up logging, and for each fold
     * selects the best learner on validation and evaluates it on test.
     *
     * @throws IOException if the experiment output directory or log cannot be created
     * @throws IllegalStateException if the line search selects no learner for a fold
     */
    public void performExperiment() throws IOException {
        // NOTE(review): meaning of the 98/false generator arguments is not visible here.
        final BillMatlabFileDataGenerator generator = new BillMatlabFileDataGenerator(
                new File(dataFromRoot("normalised.mat")), "user_vsr_for_polls_SINA",
                new File(dataFromRoot("unnormalised.mat")), 98, false, prepareFolds());
        prepareExperimentLog();
        final RootMeanSumLossEvaluator evaluator = new RootMeanSumLossEvaluator();
        for (int fold = 0; fold < generator.nFolds(); fold++) {
            this.logger.info("Starting Fold: " + fold);
            final BilinearSparseOnlineLearner best = lineSearchParams(fold, generator);
            if (best == null) {
                // Happens when the search space is empty or every validation loss was
                // NaN or >= Double.MAX_VALUE; fail clearly instead of an NPE later on.
                throw new IllegalStateException("Line search selected no learner for fold " + fold
                        + " (empty search space, or every validation loss was NaN or >= Double.MAX_VALUE)");
            }
            this.logger.debug("Best params found! Starting test...");
            generator.setFold(fold, BillMatlabFileDataGenerator.Mode.TEST);
            evaluator.setLearner(best);
            this.logger.debug("Test RMSE: " + evaluator.evaluate(generator.generateAll()));
        }
    }

    /**
     * Trains one learner per candidate parameter set on the fold's training split and
     * returns the one with the lowest validation RMSE.
     *
     * @param fold      index of the fold to use
     * @param generator data source; its fold/mode is changed by this method
     * @return the best learner, or {@code null} if no candidate scored below
     *         {@code Double.MAX_VALUE} (including when every score is NaN)
     */
    private BilinearSparseOnlineLearner lineSearchParams(int fold, BillMatlabFileDataGenerator generator) {
        BilinearSparseOnlineLearner best = null;
        double bestLoss = Double.MAX_VALUE;
        final RootMeanSumLossEvaluator evaluator = new RootMeanSumLossEvaluator();
        final List<BilinearLearnerParameters> candidates = parameterLineSearch();
        this.logger.info("Optimising params, searching: " + candidates.size());
        for (int index = 0; index < candidates.size(); index++) {
            final BilinearLearnerParameters params = candidates.get(index);
            this.logger.info(String.format("Optimising params %d/%d", index + 1, candidates.size()));
            this.logger.debug("Current Params:\n" + params.toString());
            final BilinearSparseOnlineLearner learner = new BilinearSparseOnlineLearner(params);

            generator.setFold(fold, BillMatlabFileDataGenerator.Mode.TRAINING);
            this.logger.debug("Training...");
            // The generator signals the end of the split by returning null.
            Pair<Matrix> sample;
            while ((sample = generator.mo9generate()) != null) {
                learner.process((Matrix) sample.firstObject(), (Matrix) sample.secondObject());
            }

            this.logger.debug("Generating score of validation set");
            generator.setFold(fold, BillMatlabFileDataGenerator.Mode.VALIDATION);
            evaluator.setLearner(learner);
            final double loss = evaluator.evaluate(generator.generateAll());
            this.logger.debug("Total RMSE: " + loss);
            this.logger.debug("U sparcity: " + CFMatrixUtils.sparsity(learner.getU()));
            this.logger.debug("W sparcity: " + CFMatrixUtils.sparsity(learner.getW()));
            if (loss < bestLoss) {
                this.logger.info("New best score detected!");
                bestLoss = loss;
                best = learner;
                this.logger.info("New Best Config:\n" + best.getParams());
                this.logger.info("New Best Loss:" + loss);
                saveFoldParameterLearner(fold, index, learner);
            }
        }
        return best;
    }

    /**
     * Writes a learner to {@code <output>/fold_<fold>/learner_<index>} (binary form)
     * and {@code learner_<index>.mat} (u, w and, if present, bias matrices).
     *
     * @throws RuntimeException wrapping any {@link IOException} raised while writing
     */
    private void saveFoldParameterLearner(int fold, int index, BilinearSparseOnlineLearner learner) {
        final File foldDir = new File(String.format("%s/fold_%d", currentOutputRoot(), fold));
        final File binaryFile = new File(foldDir, String.format("learner_%d", index));
        final File matFile = new File(foldDir, String.format("learner_%d.mat", index));
        try {
            if (!foldDir.isDirectory() && !foldDir.mkdirs()) {
                throw new IOException("Couldn't create fold output directory: " + foldDir);
            }
            IOUtils.writeBinary(binaryFile, learner);
            // Left raw: the element type is the MAT-library array type returned by toMLArray.
            final List mlArrays = new ArrayList();
            mlArrays.add(CFMatrixUtils.toMLArray("u", learner.getU()));
            mlArrays.add(CFMatrixUtils.toMLArray("w", learner.getW()));
            if (learner.getBias() != null) {
                mlArrays.add(CFMatrixUtils.toMLArray("b", learner.getBias()));
            }
            new MatFileWriter(matFile, mlArrays);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Expands the line-search grid over the learning-rate and regulariser parameters
     * (each currently a single value) into a concrete list of parameter sets.
     */
    private List<BilinearLearnerParameters> parameterLineSearch() {
        final BilinearLearnerParametersLineSearch search = new BilinearLearnerParametersLineSearch(prepareParams());
        search.addIteration(BilinearLearnerParameters.ETA0_U, Doubles.asList(new double[]{1.0E-4d}));
        search.addIteration(BilinearLearnerParameters.ETA0_W, Doubles.asList(new double[]{0.005d}));
        search.addIteration(BilinearLearnerParameters.ETA0_BIAS, Doubles.asList(new double[]{50.0d}));
        search.addIteration(BilinearLearnerParameters.LAMBDA_U, Doubles.asList(new double[]{1.0E-5d}));
        search.addIteration(BilinearLearnerParameters.LAMBDA_W, Doubles.asList(new double[]{1.0E-5d}));
        final List<BilinearLearnerParameters> expanded = new ArrayList<BilinearLearnerParameters>();
        final Iterator<BilinearLearnerParameters> it = search.iterator();
        while (it.hasNext()) {
            expanded.add(it.next());
        }
        return expanded;
    }

    /**
     * Builds {@link #NFOLDS} folds of item indices. For fold {@code f}, with
     * {@code end = 48 + 5 * f}:
     * <ul>
     * <li>validation: 8 consecutive indices centred on the midpoint of {@code [0, end)},
     *     starting at {@code round(end / 2.0) - 1 - 4};</li>
     * <li>training: every other index in {@code [0, end)};</li>
     * <li>test: the 5 indices {@code [end, end + 5)}.</li>
     * </ul>
     */
    private List<BillMatlabFileDataGenerator.Fold> prepareFolds() {
        final int nTest = 5;
        final int nValidation = 8;
        final int baseEnd = 48;
        final List<BillMatlabFileDataGenerator.Fold> folds = new ArrayList<BillMatlabFileDataGenerator.Fold>();
        for (int fold = 0; fold < NFOLDS; fold++) {
            final int end = baseEnd + fold * nTest;
            final int[] training = new int[end - nValidation];
            final int[] test = new int[nTest];
            final int[] validation = new int[nValidation];
            final int midpoint = (int) Math.round(end / 2.0d) - 1;
            final int validationStart = midpoint - nValidation / 2;

            int item = 0;
            int t = 0;
            for (; item < validationStart; item++) {
                training[t++] = item;
            }
            for (int v = 0; v < nValidation; v++) {
                validation[v] = item++;
            }
            for (; item < end; item++) {
                training[t++] = item;
            }
            for (int k = 0; k < nTest; k++) {
                test[k] = item++;
            }
            folds.add(new BillMatlabFileDataGenerator.Fold(training, test, validation));
        }
        return folds;
    }

    /**
     * Base learner parameters. The ETA0_* and LAMBDA_* entries are null placeholders
     * that {@link #parameterLineSearch()} fills in for each candidate.
     */
    private BilinearLearnerParameters prepareParams() {
        final BilinearLearnerParameters params = new BilinearLearnerParameters();
        params.put(BilinearLearnerParameters.ETA0_U, null);
        params.put(BilinearLearnerParameters.ETA0_W, null);
        params.put(BilinearLearnerParameters.LAMBDA_U, null);
        params.put(BilinearLearnerParameters.LAMBDA_W, null);
        params.put(BilinearLearnerParameters.ETA0_BIAS, null);
        params.put(BilinearLearnerParameters.BICONVEX_TOL, Double.valueOf(0.01d));
        params.put(BilinearLearnerParameters.BICONVEX_MAXITER, 10);
        params.put(BilinearLearnerParameters.BIAS, true);
        params.put(BilinearLearnerParameters.WINITSTRAT, new SparseZerosInitStrategy());
        params.put(BilinearLearnerParameters.UINITSTRAT, new SparseZerosInitStrategy());
        params.put(BilinearLearnerParameters.LOSS, new MatSquareLossFunction());
        return params;
    }

    /**
     * Resolves a data file name against {@link #ROOT}. Uses {@link File} joining so that
     * the trailing slash on ROOT does not produce a doubled separator.
     */
    public static String dataFromRoot(String str) {
        return new File(ROOT, str).getPath();
    }

    /**
     * Attaches a console appender (INFO and above) and a file appender (DEBUG and above)
     * writing to {@code <experiment root>/log} to the root logger. The log file is
     * truncated on open so each run starts with a fresh log.
     *
     * @throws IOException if the experiment root or log file cannot be created
     */
    protected void prepareExperimentLog() throws IOException {
        final ConsoleAppender console = new ConsoleAppender();
        console.setLayout(new PatternLayout("[%p->%C{1}] %m%n"));
        console.setThreshold(Level.INFO);
        console.activateOptions();
        Logger.getRootLogger().addAppender(console);

        final File experimentRoot = prepareExperimentRoot();
        final File logFile = new File(experimentRoot, "log");
        // append=false truncates any stale log, even if it could not be deleted.
        final FileAppender fileAppender = new FileAppender(
                new PatternLayout("[%d{HH:mm:ss} %p->%C{1}] %m%n"), logFile.getAbsolutePath(), false);
        fileAppender.setThreshold(Level.DEBUG);
        fileAppender.activateOptions();
        Logger.getRootLogger().addAppender(fileAppender);
        this.logger.info("Experiment root: " + experimentRoot);
    }

    /**
     * Returns the per-run output directory, creating it if necessary.
     *
     * @throws IOException if the directory does not exist and cannot be created
     */
    public File prepareExperimentRoot() throws IOException {
        final File root = new File(currentOutputRoot());
        if (root.isDirectory()) {
            return root;
        }
        this.logger.debug("Experiment root: " + root);
        if (root.mkdirs()) {
            return root;
        }
        throw new IOException("Couldn't prepare experiment output: " + root);
    }

    /** {@code <OUTPUT_ROOT>/<experiment set name>/<start time millis>}. */
    private String currentOutputRoot() {
        return String.format("%s/%s/%d", OUTPUT_ROOT, getExperimentSetName(), currentExperimentTime());
    }

    private long currentExperimentTime() {
        return this.expStartTime;
    }

    private String getExperimentSetName() {
        return "streamingBilinear/optimiselambda";
    }
}
