/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.util;

import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.input.PortableDataStream;
import org.apache.spark.mllib.linalg.Matrices;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;
import scala.Tuple2;

/**
 * Static conversion helpers between Spark MLlib types ({@link Vector}, {@link Matrix},
 * {@link LabeledPoint}) and ND4J / DL4J types ({@link INDArray}, {@link DataSet}).
 *
 * <p>This class only holds static methods and cannot be instantiated.
 */
public class MLLibUtil {
    private MLLibUtil() {
    }

    /**
     * Returns the index of the largest entry of {@code vector}, widened to a {@code double}.
     *
     * <p>Only strictly greater values replace the current best, so ties resolve to the earliest
     * index. An empty vector, or one containing only NaN / negative infinity, yields {@code 0}.
     *
     * @param vector the vector to scan
     * @return the arg-max index as a {@code double}
     */
    public static double toClassifierPrediction(Vector vector) {
        double max = Double.NEGATIVE_INFINITY;
        int maxIndex = 0;
        for (int i = 0; i < vector.size(); ++i) {
            double curr = vector.apply(i);
            if (curr > max) {
                maxIndex = i;
                max = curr;
            }
        }
        return maxIndex;
    }

    /**
     * Converts an MLlib {@link Matrix} into an ND4J array of shape {@code [numRows, numCols]}.
     *
     * <p>{@link Matrix#toArray()} yields the entries in column-major order, hence the {@code 'f'}
     * ordering passed to ND4J.
     *
     * @param arr the MLlib matrix to copy
     * @return a new column-major {@link INDArray} holding the same values
     */
    public static INDArray toMatrix(Matrix arr) {
        return Nd4j.create(arr.toArray(), new int[] {arr.numRows(), arr.numCols()}, 'f');
    }

    /**
     * Converts an MLlib {@link Vector} into an ND4J array backed by a copy of its values.
     *
     * @param arr the MLlib vector to copy
     * @return a new {@link INDArray} created from a buffer of {@code arr.toArray()}
     */
    public static INDArray toVector(Vector arr) {
        DataBuffer buffer = Nd4j.createBuffer(arr.toArray());
        return Nd4j.create(buffer);
    }

    /**
     * Converts a 2-D ND4J array into a dense MLlib {@link Matrix}.
     *
     * <p>MLlib dense matrices are column-major. The array's own buffer is therefore used only when
     * it is already {@code 'f'}-ordered and not a view. In every other case a column-major copy is
     * taken first.
     *
     * @param arr the array to convert; must satisfy {@link INDArray#isMatrix()}
     * @return a dense MLlib matrix with the same shape and values
     * @throws IllegalArgumentException if {@code arr} is not a matrix
     */
    public static Matrix toMatrix(INDArray arr) {
        if (!arr.isMatrix()) {
            throw new IllegalArgumentException("passed in array must be a matrix");
        }
        // A view may not own a contiguous buffer; a 'c'-ordered array has the wrong layout.
        INDArray columnMajor = (arr.isView() || arr.ordering() != 'f') ? arr.dup('f') : arr;
        return Matrices.dense((int) arr.rows(), (int) arr.columns(), columnMajor.data().asDouble());
    }

    /**
     * Copies an ND4J vector element-by-element into a dense MLlib {@link Vector}.
     *
     * @param arr the array to convert; must satisfy {@link INDArray#isVector()}
     * @return a dense MLlib vector holding a copy of the values
     * @throws IllegalArgumentException if {@code arr} is not a vector
     */
    public static Vector toVector(INDArray arr) {
        if (!arr.isVector()) {
            throw new IllegalArgumentException("passed in array must be a vector");
        }
        double[] ret = new double[(int) arr.length()];
        // Element-wise reads respect views, offsets and strides, unlike reading the raw buffer.
        for (int i = 0; i < ret.length; i++) {
            ret[i] = arr.getDouble((long) i);
        }
        return Vectors.dense(ret);
    }

    /**
     * Parses each binary file into a {@link LabeledPoint}.
     *
     * <p>For every {@code (path, stream)} pair, {@code reader} is re-initialised on that stream.
     * Only the first record it produces is kept and converted with {@link #pointOf(Collection)}.
     * The stream is closed once that record has been read.
     *
     * @param binaryFiles pairs of file path and file contents
     * @param reader the record reader used to decode each file; it is serialized into the Spark closure
     * @return one labeled point per input file
     */
    public static JavaRDD<LabeledPoint> fromBinary(JavaPairRDD<String, PortableDataStream> binaryFiles, final RecordReader reader) {
        JavaRDD<Collection<Writable>> records = binaryFiles.map(new Function<Tuple2<String, PortableDataStream>, Collection<Writable>>() {

            @Override
            public Collection<Writable> call(Tuple2<String, PortableDataStream> file) throws Exception {
                // Closing here stops one file handle leaking per input file.
                try (InputStream in = file._2().open()) {
                    InputSplit split = new InputStreamInputSplit(in, file._1());
                    reader.initialize(split);
                    return reader.next();
                }
            }
        });
        return records.map(new Function<Collection<Writable>, LabeledPoint>() {

            @Override
            public LabeledPoint call(Collection<Writable> writables) throws Exception {
                return MLLibUtil.pointOf(writables);
            }
        });
    }

    /**
     * Convenience overload of {@link #fromBinary(JavaPairRDD, RecordReader)} for an RDD of tuples.
     *
     * @param binaryFiles pairs of file path and file contents
     * @param reader the record reader used to decode each file
     * @return one labeled point per input file
     */
    public static JavaRDD<LabeledPoint> fromBinary(JavaRDD<Tuple2<String, PortableDataStream>> binaryFiles, RecordReader reader) {
        JavaPairRDD<String, PortableDataStream> pairs = JavaPairRDD.fromJavaRDD(binaryFiles);
        return MLLibUtil.fromBinary(pairs, reader);
    }

    /**
     * Builds a {@link LabeledPoint} from one record.
     *
     * <p>All values except the last become the feature vector, and the last value becomes the
     * label. Each value is parsed from its {@code toString()} form as a {@code double}. Parsing as
     * {@code double} rather than {@code float} avoids silent precision loss.
     *
     * @param writables the record values, in order; must contain at least the label
     * @return the labeled point
     * @throws IllegalArgumentException if {@code writables} is empty
     * @throws NumberFormatException if a value is not numeric
     * @throws IllegalStateException if the label is negative
     */
    public static LabeledPoint pointOf(Collection<Writable> writables) {
        if (writables.isEmpty()) {
            throw new IllegalArgumentException("Record must contain at least one value (the target)");
        }
        int numFeatures = writables.size() - 1;
        double[] ret = new double[numFeatures];
        int count = 0;
        double target = 0.0;
        for (Writable w : writables) {
            double value = Double.parseDouble(w.toString());
            if (count < numFeatures) {
                ret[count++] = value;
            } else {
                target = value;
            }
        }
        if (target < 0.0) {
            throw new IllegalStateException("Target must be >= 0");
        }
        return new LabeledPoint(target, Vectors.dense(ret));
    }

    /**
     * Converts labeled points to single-example {@link DataSet}s with one-hot labels.
     *
     * <p>The result is repartitioned into roughly {@code count / batchSize} partitions. The
     * partition count is clamped to at least 1, because Spark requires a positive partition count.
     * It is also capped at {@code Integer.MAX_VALUE}. Counting triggers a Spark job.
     *
     * @param data the labeled points to convert
     * @param numPossibleLabels the number of classes (length of each one-hot label vector)
     * @param batchSize the desired number of examples per partition; must be positive
     * @return the converted data sets, repartitioned
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public static JavaRDD<DataSet> fromLabeledPoint(JavaRDD<LabeledPoint> data, final long numPossibleLabels, long batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
        }
        JavaRDD<DataSet> mappedData = data.map(new Function<LabeledPoint, DataSet>() {

            @Override
            public DataSet call(LabeledPoint lp) {
                return MLLibUtil.fromLabeledPoint(lp, numPossibleLabels);
            }
        });
        long requested = mappedData.count() / batchSize;
        // Fewer rows than batchSize would give 0 partitions; a huge count must not overflow the int cast.
        int numPartitions = (int) Math.max(1L, Math.min((long) Integer.MAX_VALUE, requested));
        return mappedData.repartition(numPartitions);
    }

    /**
     * Converts labeled points to single-example {@link DataSet}s with one-hot labels.
     *
     * @param sc unused
     * @param data the labeled points to convert
     * @param numPossibleLabels the number of classes
     * @return the converted data sets
     * @deprecated the Spark context is not needed; use {@link #fromLabeledPoint(JavaRDD, long)}
     */
    @Deprecated
    public static JavaRDD<DataSet> fromLabeledPoint(JavaSparkContext sc, JavaRDD<LabeledPoint> data, final long numPossibleLabels) {
        return MLLibUtil.fromLabeledPoint(data, numPossibleLabels);
    }

    /**
     * Converts labeled points to {@link DataSet}s whose label is the raw (continuous) label value.
     *
     * @param sc unused
     * @param data the labeled points to convert
     * @return the converted data sets
     * @deprecated the Spark context is not needed; use {@link #fromContinuousLabeledPoint(JavaRDD)}
     */
    @Deprecated
    public static JavaRDD<DataSet> fromContinuousLabeledPoint(JavaSparkContext sc, JavaRDD<LabeledPoint> data) {
        return MLLibUtil.fromContinuousLabeledPoint(data);
    }

    /**
     * Wraps one labeled point as a {@link DataSet}.
     *
     * <p>The features become an array built from {@code features().toArray()}, and the label
     * becomes a one-element array holding the raw label value.
     */
    private static DataSet convertToDataset(LabeledPoint lp) {
        Vector features = lp.features();
        double label = lp.label();
        return new DataSet(Nd4j.create(features.toArray()), Nd4j.create(new double[] {label}));
    }

    /**
     * Converts single-example {@link DataSet}s to labeled points.
     *
     * @param sc unused
     * @param data the data sets to convert
     * @return the converted labeled points
     * @deprecated the Spark context is not needed; use {@link #fromDataSet(JavaRDD)}
     */
    @Deprecated
    public static JavaRDD<LabeledPoint> fromDataSet(JavaSparkContext sc, JavaRDD<DataSet> data) {
        return MLLibUtil.fromDataSet(data);
    }

    /** Converts each data set in a local list via {@link #toLabeledPoint(DataSet)}. */
    private static List<LabeledPoint> toLabeledPoint(List<DataSet> labeledPoints) {
        List<LabeledPoint> ret = new ArrayList<LabeledPoint>(labeledPoints.size());
        for (DataSet point : labeledPoints) {
            ret.add(MLLibUtil.toLabeledPoint(point));
        }
        return ret;
    }

    /**
     * Converts a single-example {@link DataSet} to a {@link LabeledPoint}.
     *
     * <p>The label is the index returned by BLAS {@code iamax} on the label array, which is the
     * position of the entry with the largest absolute value.
     *
     * @throws IllegalArgumentException if the features are not a vector
     */
    private static LabeledPoint toLabeledPoint(DataSet point) {
        if (!point.getFeatures().isVector()) {
            throw new IllegalArgumentException("Feature matrix must be a vector");
        }
        // toVector copies element-wise, so no defensive dup() of the features is needed.
        Vector features = MLLibUtil.toVector(point.getFeatures());
        double label = Nd4j.getBlasWrapper().iamax(point.getLabels());
        return new LabeledPoint(label, features);
    }

    /**
     * Converts labeled points to {@link DataSet}s whose label is the raw (continuous) label value.
     *
     * @param data the labeled points to convert
     * @return the converted data sets
     */
    public static JavaRDD<DataSet> fromContinuousLabeledPoint(JavaRDD<LabeledPoint> data) {
        return MLLibUtil.fromContinuousLabeledPoint(data, false);
    }

    /**
     * Converts labeled points to {@link DataSet}s whose label is the raw (continuous) label value.
     *
     * @param data the labeled points to convert
     * @param preCache when {@code true}, {@code data} is cached first unless its storage level
     *     already uses memory
     * @return the converted data sets
     */
    public static JavaRDD<DataSet> fromContinuousLabeledPoint(JavaRDD<LabeledPoint> data, boolean preCache) {
        if (preCache && !data.getStorageLevel().useMemory()) {
            data.cache();
        }
        return data.map(new Function<LabeledPoint, DataSet>() {

            @Override
            public DataSet call(LabeledPoint lp) {
                return MLLibUtil.convertToDataset(lp);
            }
        });
    }

    /**
     * Converts labeled points to single-example {@link DataSet}s with one-hot labels.
     *
     * @param data the labeled points to convert
     * @param numPossibleLabels the number of classes
     * @return the converted data sets
     */
    public static JavaRDD<DataSet> fromLabeledPoint(JavaRDD<LabeledPoint> data, long numPossibleLabels) {
        return MLLibUtil.fromLabeledPoint(data, numPossibleLabels, false);
    }

    /**
     * Converts labeled points to single-example {@link DataSet}s with one-hot labels.
     *
     * @param data the labeled points to convert
     * @param numPossibleLabels the number of classes
     * @param preCache when {@code true}, {@code data} is cached first unless its storage level
     *     already uses memory
     * @return the converted data sets
     */
    public static JavaRDD<DataSet> fromLabeledPoint(JavaRDD<LabeledPoint> data, final long numPossibleLabels, boolean preCache) {
        if (preCache && !data.getStorageLevel().useMemory()) {
            data.cache();
        }
        return data.map(new Function<LabeledPoint, DataSet>() {

            @Override
            public DataSet call(LabeledPoint lp) {
                return MLLibUtil.fromLabeledPoint(lp, numPossibleLabels);
            }
        });
    }

    /**
     * Converts single-example {@link DataSet}s to labeled points.
     *
     * @param data the data sets to convert
     * @return the converted labeled points
     */
    public static JavaRDD<LabeledPoint> fromDataSet(JavaRDD<DataSet> data) {
        return MLLibUtil.fromDataSet(data, false);
    }

    /**
     * Converts single-example {@link DataSet}s to labeled points.
     *
     * @param data the data sets to convert
     * @param preCache when {@code true}, {@code data} is cached first unless its storage level
     *     already uses memory
     * @return the converted labeled points
     */
    public static JavaRDD<LabeledPoint> fromDataSet(JavaRDD<DataSet> data, boolean preCache) {
        if (preCache && !data.getStorageLevel().useMemory()) {
            data.cache();
        }
        return data.map(new Function<DataSet, LabeledPoint>() {

            @Override
            public LabeledPoint call(DataSet dataSet) {
                return MLLibUtil.toLabeledPoint(dataSet);
            }
        });
    }

    /** Converts each labeled point in a local list via {@link #fromLabeledPoint(LabeledPoint, long)}. */
    private static List<DataSet> fromLabeledPoint(List<LabeledPoint> labeledPoints, long numPossibleLabels) {
        List<DataSet> ret = new ArrayList<DataSet>(labeledPoints.size());
        for (LabeledPoint point : labeledPoints) {
            ret.add(MLLibUtil.fromLabeledPoint(point, numPossibleLabels));
        }
        return ret;
    }

    /**
     * Converts one labeled point to a {@link DataSet}.
     *
     * <p>The features become a {@code [1, numFeatures]} array. The label is truncated to an int
     * and passed to {@code FeatureUtil.toOutcomeVector} together with {@code numPossibleLabels}.
     */
    private static DataSet fromLabeledPoint(LabeledPoint point, long numPossibleLabels) {
        double[] fArr = point.features().toArray();
        INDArray features = Nd4j.create(fArr, new long[] {1L, fArr.length});
        // The explicit (long)(int) casts select the long overload after truncating, as before.
        INDArray labels = FeatureUtil.toOutcomeVector((long) ((int) point.label()), (long) ((int) numPossibleLabels));
        return new DataSet(features, labels);
    }
}

