/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.classification;

import org.apache.spark.SparkException;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.classification.KNNClassificationModel;
import org.apache.spark.ml.classification.MultiClassSummarizer;
import org.apache.spark.ml.classification.ProbabilisticClassifier;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.ml.knn.KNN;
import org.apache.spark.ml.knn.KNNModelParams;
import org.apache.spark.ml.knn.KNNModelParams$class;
import org.apache.spark.ml.knn.KNNParams;
import org.apache.spark.ml.knn.KNNParams$class;
import org.apache.spark.ml.knn.Tree;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntArrayParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.LongParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.param.Params;
import org.apache.spark.ml.param.StringArrayParam;
import org.apache.spark.ml.param.shared.HasInputCols;
import org.apache.spark.ml.param.shared.HasSeed;
import org.apache.spark.ml.param.shared.HasWeightCol;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

@ScalaSignature(bytes="\u0006\u0001\u0005ub\u0001B\u0001\u0003\u00015\u0011Qb\u0013(O\u00072\f7o]5gS\u0016\u0014(BA\u0002\u0005\u00039\u0019G.Y:tS\u001aL7-\u0019;j_:T!!\u0002\u0004\u0002\u00055d'BA\u0004\t\u0003\u0015\u0019\b/\u0019:l\u0015\tI!\"\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u0002\u0017\u0005\u0019qN]4\u0004\u0001M!\u0001A\u0004\u000f#!\u0015y\u0001C\u0005\r\u001a\u001b\u0005\u0011\u0011BA\t\u0003\u0005]\u0001&o\u001c2bE&d\u0017n\u001d;jG\u000ec\u0017m]:jM&,'\u000f\u0005\u0002\u0014-5\tAC\u0003\u0002\u0016\t\u00051A.\u001b8bY\u001eL!a\u0006\u000b\u0003\rY+7\r^8s!\ty\u0001\u0001\u0005\u0002\u00105%\u00111D\u0001\u0002\u0017\u0017:s5\t\\1tg&4\u0017nY1uS>tWj\u001c3fYB\u0011Q\u0004I\u0007\u0002=)\u0011q\u0004B\u0001\u0004W:t\u0017BA\u0011\u001f\u0005%YeJ\u0014)be\u0006l7\u000f\u0005\u0002$Q5\tAE\u0003\u0002&M\u000511\u000f[1sK\u0012T!a\n\u0003\u0002\u000bA\f'/Y7\n\u0005%\"#\u0001\u0004%bg^+\u0017n\u001a5u\u0007>d\u0007\u0002C\u0016\u0001\u0005\u000b\u0007I\u0011\t\u0017\u0002\u0007ULG-F\u0001.!\tqCG\u0004\u00020e5\t\u0001GC\u00012\u0003\u0015\u00198-\u00197b\u0013\t\u0019\u0004'\u0001\u0004Qe\u0016$WMZ\u0005\u0003kY\u0012aa\u0015;sS:<'BA\u001a1\u0011!A\u0004A!A!\u0002\u0013i\u0013\u0001B;jI\u0002BQA\u000f\u0001\u0005\u0002m\na\u0001P5oSRtDC\u0001\r=\u0011\u0015Y\u0013\b1\u0001.\u0011\u0015Q\u0004\u0001\"\u0001?)\u0005A\u0002\"\u0002!\u0001\t\u0003\n\u0015AD:fi\u001a+\u0017\r^;sKN\u001cu\u000e\u001c\u000b\u0003\u0005\u000ek\u0011\u0001\u0001\u0005\u0006\t~\u0002\r!L\u0001\u0006m\u0006dW/\u001a\u0005\u0006\r\u0002!\teR\u0001\fg\u0016$H*\u00192fY\u000e{G\u000e\u0006\u0002C\u0011\")A)\u0012a\u0001[!)!\n\u0001C\u0001\u0017\u0006a1/\u001a;XK&<\u0007\u000e^\"pYR\u0011!\t\u0014\u0005\u0006\t&\u0003\r!\f\u0005\u0006\u001d\u0002!\taT\u0001\u0005g\u0016$8\n\u0006\u0002C!\")A)\u0014a\u0001#B\u0011qFU\u0005\u0003'B\u00121!\u00138u\u0011\u0015)\u0006\u0001\"\u0001W\u00039\u0019X\r\u001e+paR\u0013X-Z*ju\u0016$\"AQ,\t\u000b\u0011#\u0006\u0019A)\t\u000be\u0003A\u0011\u00
01.\u0002%M,G\u000fV8q)J,W\rT3bMNK'0\u001a\u000b\u0003\u0005nCQ\u0001\u0012-A\u0002ECQ!\u0018\u0001\u0005\u0002y\u000b!c]3u'V\u0014GK]3f\u0019\u0016\fgmU5{KR\u0011!i\u0018\u0005\u0006\tr\u0003\r!\u0015\u0005\u0006C\u0002!\tAY\u0001\u0019g\u0016$()\u001e4gKJ\u001c\u0016N_3TC6\u0004H.Z*ju\u0016\u001cHC\u0001\"d\u0011\u0015!\u0005\r1\u0001e!\ryS-U\u0005\u0003MB\u0012Q!\u0011:sCfDQ\u0001\u001b\u0001\u0005\u0002%\f1c]3u\u0005\u0006d\u0017M\\2f)\"\u0014Xm\u001d5pY\u0012$\"A\u00116\t\u000b\u0011;\u0007\u0019A6\u0011\u0005=b\u0017BA71\u0005\u0019!u.\u001e2mK\")q\u000e\u0001C\u0001a\u000691/\u001a;TK\u0016$GC\u0001\"r\u0011\u0015!e\u000e1\u0001s!\ty3/\u0003\u0002ua\t!Aj\u001c8h\u0011\u00151\b\u0001\"\u0015x\u0003\u0015!(/Y5o)\tI\u0002\u0010C\u0003zk\u0002\u0007!0A\u0004eCR\f7/\u001a;1\u0007m\f9\u0001\u0005\u0003}\u007f\u0006\rQ\"A?\u000b\u0005y4\u0011aA:rY&\u0019\u0011\u0011A?\u0003\u000f\u0011\u000bG/Y:fiB!\u0011QAA\u0004\u0019\u0001!1\"!\u0003y\u0003\u0003\u0005\tQ!\u0001\u0002\f\t\u0019q\fJ\u0019\u0012\t\u00055\u00111\u0003\t\u0004_\u0005=\u0011bAA\ta\t9aj\u001c;iS:<\u0007cA\u0018\u0002\u0016%\u0019\u0011q\u0003\u0019\u0003\u0007\u0005s\u0017\u0010C\u0004\u0002\u001c\u0001!\t%!\b\u0002\u0007\u0019LG\u000fF\u0002\u001a\u0003?Aq!_A\r\u0001\u0004\t\t\u0003\r\u0003\u0002$\u0005\u001d\u0002\u0003\u0002?\u0000\u0003K\u0001B!!\u0002\u0002(\u0011a\u0011\u0011FA\u0010\u0003\u0003\u0005\tQ!\u0001\u0002\f\t\u0019q\f\n\u001a\t\u000f\u00055\u0002\u0001\"\u0011\u00020\u0005!1m\u001c9z)\rA\u0012\u0011\u0007\u0005\t\u0003g\tY\u00031\u0001\u00026\u0005)Q\r\u001f;sCB!\u0011qGA\u001d\u001b\u00051\u0013bAA\u001eM\tA\u0001+\u0019:b[6\u000b\u0007\u000f")
/**
 * k-nearest-neighbour classifier estimator.
 *
 * <p>Training summarizes the labels to determine the number of classes, delegates tree construction
 * to a {@link KNN} estimator configured with this classifier's params, and converts the resulting
 * model into a {@link KNNClassificationModel}.
 *
 * <p>NOTE(review): this is CFR-decompiled output of a Scala class. Several constructs (for example
 * the trait {@code _setter_$..._$eq} methods assigning {@code final} fields, {@code X.class.method(..)}
 * calls standing in for Scala trait implementation classes, and anonymous {@code Serializable}
 * closures) are not valid Java source as written.
 */
public class KNNClassifier
extends ProbabilisticClassifier<Vector, KNNClassifier, KNNClassificationModel>
implements KNNParams,
HasWeightCol {
    private final String uid;
    private final Param<String> weightCol;
    private final IntParam topTreeSize;
    private final IntParam topTreeLeafSize;
    private final IntParam subTreeLeafSize;
    private final IntArrayParam bufferSizeSampleSizes;
    private final DoubleParam balanceThreshold;
    private final LongParam seed;
    private final Param<String> neighborsCol;
    private final Param<String> distanceCol;
    private final IntParam k;
    private final DoubleParam maxDistance;
    private final DoubleParam bufferSize;
    private final StringArrayParam inputCols;

    public final Param<String> weightCol() {
        return this.weightCol;
    }

    public final void org$apache$spark$ml$param$shared$HasWeightCol$_setter_$weightCol_$eq(Param x$1) {
        this.weightCol = x$1;
    }

    public final String getWeightCol() {
        return HasWeightCol.class.getWeightCol((HasWeightCol)this);
    }

    @Override
    public IntParam topTreeSize() {
        return this.topTreeSize;
    }

    @Override
    public IntParam topTreeLeafSize() {
        return this.topTreeLeafSize;
    }

    @Override
    public IntParam subTreeLeafSize() {
        return this.subTreeLeafSize;
    }

    @Override
    public IntArrayParam bufferSizeSampleSizes() {
        return this.bufferSizeSampleSizes;
    }

    @Override
    public DoubleParam balanceThreshold() {
        return this.balanceThreshold;
    }

    @Override
    public void org$apache$spark$ml$knn$KNNParams$_setter_$topTreeSize_$eq(IntParam x$1) {
        this.topTreeSize = x$1;
    }

    @Override
    public void org$apache$spark$ml$knn$KNNParams$_setter_$topTreeLeafSize_$eq(IntParam x$1) {
        this.topTreeLeafSize = x$1;
    }

    @Override
    public void org$apache$spark$ml$knn$KNNParams$_setter_$subTreeLeafSize_$eq(IntParam x$1) {
        this.subTreeLeafSize = x$1;
    }

    @Override
    public void org$apache$spark$ml$knn$KNNParams$_setter_$bufferSizeSampleSizes_$eq(IntArrayParam x$1) {
        this.bufferSizeSampleSizes = x$1;
    }

    @Override
    public void org$apache$spark$ml$knn$KNNParams$_setter_$balanceThreshold_$eq(DoubleParam x$1) {
        this.balanceThreshold = x$1;
    }

    @Override
    public int getTopTreeSize() {
        return KNNParams$class.getTopTreeSize(this);
    }

    @Override
    public int getTopTreeLeafSize() {
        return KNNParams$class.getTopTreeLeafSize(this);
    }

    @Override
    public int getSubTreeLeafSize() {
        return KNNParams$class.getSubTreeLeafSize(this);
    }

    @Override
    public int[] getBufferSizeSampleSizes() {
        return KNNParams$class.getBufferSizeSampleSizes(this);
    }

    @Override
    public double getBalanceThreshold() {
        return KNNParams$class.getBalanceThreshold(this);
    }

    @Override
    public StructType validateAndTransformSchema(StructType schema) {
        return KNNParams$class.validateAndTransformSchema(this, schema);
    }

    public final LongParam seed() {
        return this.seed;
    }

    public final void org$apache$spark$ml$param$shared$HasSeed$_setter_$seed_$eq(LongParam x$1) {
        this.seed = x$1;
    }

    public final long getSeed() {
        return HasSeed.class.getSeed((HasSeed)this);
    }

    @Override
    public Param<String> neighborsCol() {
        return this.neighborsCol;
    }

    @Override
    public Param<String> distanceCol() {
        return this.distanceCol;
    }

    @Override
    public IntParam k() {
        return this.k;
    }

    @Override
    public DoubleParam maxDistance() {
        return this.maxDistance;
    }

    @Override
    public DoubleParam bufferSize() {
        return this.bufferSize;
    }

    @Override
    public void org$apache$spark$ml$knn$KNNModelParams$_setter_$neighborsCol_$eq(Param x$1) {
        this.neighborsCol = x$1;
    }

    @Override
    public void org$apache$spark$ml$knn$KNNModelParams$_setter_$distanceCol_$eq(Param x$1) {
        this.distanceCol = x$1;
    }

    @Override
    public void org$apache$spark$ml$knn$KNNModelParams$_setter_$k_$eq(IntParam x$1) {
        this.k = x$1;
    }

    @Override
    public void org$apache$spark$ml$knn$KNNModelParams$_setter_$maxDistance_$eq(DoubleParam x$1) {
        this.maxDistance = x$1;
    }

    @Override
    public void org$apache$spark$ml$knn$KNNModelParams$_setter_$bufferSize_$eq(DoubleParam x$1) {
        this.bufferSize = x$1;
    }

    @Override
    public String getNeighborsCol() {
        return KNNModelParams$class.getNeighborsCol(this);
    }

    @Override
    public String getDistanceCol() {
        return KNNModelParams$class.getDistanceCol(this);
    }

    @Override
    public int getK() {
        return KNNModelParams$class.getK(this);
    }

    @Override
    public double getMaxDistance() {
        return KNNModelParams$class.getMaxDistance(this);
    }

    @Override
    public double getBufferSize() {
        return KNNModelParams$class.getBufferSize(this);
    }

    @Override
    public RDD<Tuple2<Object, Tuple2<Row, Object>[]>> transform(RDD<Vector> data, Broadcast<Tree> topTree, RDD<Tree> subTrees) {
        return KNNModelParams$class.transform((KNNModelParams)this, data, topTree, subTrees);
    }

    @Override
    public RDD<Tuple2<Object, Tuple2<Row, Object>[]>> transform(Dataset<?> dataset, Broadcast<Tree> topTree, RDD<Tree> subTrees) {
        return KNNModelParams$class.transform((KNNModelParams)this, dataset, topTree, subTrees);
    }

    public final StringArrayParam inputCols() {
        return this.inputCols;
    }

    public final void org$apache$spark$ml$param$shared$HasInputCols$_setter_$inputCols_$eq(StringArrayParam x$1) {
        this.inputCols = x$1;
    }

    public final String[] getInputCols() {
        return HasInputCols.class.getInputCols((HasInputCols)this);
    }

    public String uid() {
        return this.uid;
    }

    /** Sets the features column name. */
    public KNNClassifier setFeaturesCol(String value) {
        return (KNNClassifier)this.set(this.featuresCol(), value);
    }

    /**
     * Sets the label column and keeps {@code inputCols} in sync.
     *
     * <p>{@code inputCols} becomes {@code [label]} when the weight column is empty, otherwise
     * {@code [label, weightCol]}.
     */
    public KNNClassifier setLabelCol(String value) {
        this.set(this.labelCol(), value);
        return ((String)this.$(this.weightCol())).isEmpty() ? (KNNClassifier)this.set((Param)this.inputCols(), new String[]{value}) : (KNNClassifier)this.set((Param)this.inputCols(), new String[]{value, (String)this.$(this.weightCol())});
    }

    /**
     * Sets the weight column and keeps {@code inputCols} in sync.
     *
     * <p>An empty {@code value} leaves only the label column in {@code inputCols}; otherwise it is
     * {@code [labelCol, value]}.
     */
    public KNNClassifier setWeightCol(String value) {
        this.set(this.weightCol(), value);
        return value.isEmpty() ? (KNNClassifier)this.set((Param)this.inputCols(), new String[]{(String)this.$(this.labelCol())}) : (KNNClassifier)this.set((Param)this.inputCols(), new String[]{(String)this.$(this.labelCol()), value});
    }

    /** Sets the number of neighbours {@code k}. */
    public KNNClassifier setK(int value) {
        return (KNNClassifier)this.set((Param)this.k(), BoxesRunTime.boxToInteger((int)value));
    }

    /** Sets the {@code topTreeSize} param. */
    public KNNClassifier setTopTreeSize(int value) {
        return (KNNClassifier)this.set((Param)this.topTreeSize(), BoxesRunTime.boxToInteger((int)value));
    }

    /** Sets the {@code topTreeLeafSize} param. */
    public KNNClassifier setTopTreeLeafSize(int value) {
        return (KNNClassifier)this.set((Param)this.topTreeLeafSize(), BoxesRunTime.boxToInteger((int)value));
    }

    /** Sets the {@code subTreeLeafSize} param. */
    public KNNClassifier setSubTreeLeafSize(int value) {
        return (KNNClassifier)this.set((Param)this.subTreeLeafSize(), BoxesRunTime.boxToInteger((int)value));
    }

    /** Sets the {@code bufferSizeSampleSizes} param. */
    public KNNClassifier setBufferSizeSampleSizes(int[] value) {
        return (KNNClassifier)this.set((Param)this.bufferSizeSampleSizes(), value);
    }

    /** Sets the {@code balanceThreshold} param. */
    public KNNClassifier setBalanceThreshold(double value) {
        return (KNNClassifier)this.set((Param)this.balanceThreshold(), BoxesRunTime.boxToDouble((double)value));
    }

    /** Sets the random seed. */
    public KNNClassifier setSeed(long value) {
        return (KNNClassifier)this.set((Param)this.seed(), BoxesRunTime.boxToLong((long)value));
    }

    /**
     * Trains a {@link KNNClassificationModel} on {@code dataset}.
     *
     * <p>First aggregates a label histogram over the extracted labeled points to determine
     * {@code numClasses}. It then fits a {@link KNN} estimator that carries this classifier's param
     * values and converts the result into a classification model.
     *
     * @param dataset training data containing the label and features columns
     * @return the trained classification model
     * @throws SparkException if any label is reported invalid by {@link MultiClassSummarizer}
     */
    public KNNClassificationModel train(Dataset<?> dataset) {
        RDD instances = this.extractLabeledPoints(dataset).map((Function1)new Serializable(this){
            public static final long serialVersionUID = 0L;

            public final Tuple2<Object, Vector> apply(LabeledPoint x0$1) {
                LabeledPoint labeledPoint = x0$1;
                if (labeledPoint != null) {
                    double label = labeledPoint.label();
                    Vector features = labeledPoint.features();
                    double d = label;
                    if (features != null) {
                        Vector vector = features;
                        Tuple2 tuple2 = new Tuple2((Object)BoxesRunTime.boxToDouble((double)d), (Object)vector);
                        return tuple2;
                    }
                }
                throw new MatchError((Object)labeledPoint);
            }
        }, ClassTag$.MODULE$.apply(Tuple2.class));
        // Cache the extracted instances only when the caller has not already persisted the input.
        boolean handlePersistence = StorageLevel$.MODULE$.NONE().equals(dataset.rdd().getStorageLevel());
        if (handlePersistence) {
            instances.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK());
        }
        MultiClassSummarizer labelSummarizer;
        try {
            MultiClassSummarizer x$1 = new MultiClassSummarizer();
            // seqOp: add each instance's label to the partition-local summarizer.
            Serializable x$2 = new Serializable(this){
                public static final long serialVersionUID = 0L;

                public final MultiClassSummarizer apply(MultiClassSummarizer c, Tuple2<Object, Vector> v) {
                    Tuple2 tuple2 = new Tuple2((Object)c, v);
                    if (tuple2 != null) {
                        MultiClassSummarizer labelSummarizer = (MultiClassSummarizer)tuple2._1();
                        Tuple2 tuple22 = (Tuple2)tuple2._2();
                        if (labelSummarizer != null) {
                            MultiClassSummarizer multiClassSummarizer = labelSummarizer;
                            if (tuple22 != null) {
                                double label = tuple22._1$mcD$sp();
                                Vector features = (Vector)tuple22._2();
                                double d = label;
                                if (features != null) {
                                    MultiClassSummarizer multiClassSummarizer2 = multiClassSummarizer.add(d, multiClassSummarizer.add$default$2());
                                    return multiClassSummarizer2;
                                }
                            }
                        }
                    }
                    throw new MatchError((Object)tuple2);
                }
            };
            // combOp: merge two partial summarizers.
            Serializable x$3 = new Serializable(this){
                public static final long serialVersionUID = 0L;

                public final MultiClassSummarizer apply(MultiClassSummarizer c1, MultiClassSummarizer c2) {
                    Tuple2 tuple2 = new Tuple2((Object)c1, (Object)c2);
                    if (tuple2 != null) {
                        MultiClassSummarizer classSummarizer1 = (MultiClassSummarizer)tuple2._1();
                        MultiClassSummarizer classSummarizer2 = (MultiClassSummarizer)tuple2._2();
                        if (classSummarizer1 != null) {
                            MultiClassSummarizer multiClassSummarizer = classSummarizer1;
                            if (classSummarizer2 != null) {
                                MultiClassSummarizer multiClassSummarizer2 = classSummarizer2;
                                MultiClassSummarizer multiClassSummarizer3 = multiClassSummarizer.merge(multiClassSummarizer2);
                                return multiClassSummarizer3;
                            }
                        }
                    }
                    throw new MatchError((Object)tuple2);
                }
            };
            int x$4 = instances.treeAggregate$default$4((Object)x$1);
            labelSummarizer = (MultiClassSummarizer)instances.treeAggregate((Object)x$1, (Function2)x$2, (Function2)x$3, x$4, ClassTag$.MODULE$.apply(MultiClassSummarizer.class));
        } finally {
            // `instances` is only needed for the label summary. Release the cache we created,
            // even if the aggregation failed, so its blocks do not outlive this call.
            if (handlePersistence) {
                instances.unpersist(false);
            }
        }
        double[] histogram = labelSummarizer.histogram();
        long numInvalid = labelSummarizer.countInvalid();
        int numClasses = histogram.length;
        if (numInvalid != 0L) {
            String msg = "Classification labels should be in {0 to " + (numClasses - 1) + "}. Found " + numInvalid + " invalid labels.";
            this.logError((Function0)new Serializable(this, msg){
                public static final long serialVersionUID = 0L;
                private final String msg$1;

                public final String apply() {
                    return this.msg$1;
                }
                {
                    this.msg$1 = msg$1;
                }
            });
            throw new SparkException(msg);
        }
        // Tree construction is delegated to a KNN estimator that receives this classifier's params.
        org.apache.spark.ml.knn.KNNModel knnModel = ((KNN)this.copyValues((Params)new KNN(), this.copyValues$default$2())).fit((Dataset)dataset);
        return knnModel.toNewClassificationModel(this.uid(), numClasses);
    }

    /**
     * Validates the schema, trains the model, and copies this estimator's params onto it.
     *
     * <p>The model's {@code bufferSize} is read before {@code copyValues} and re-applied afterwards.
     * This keeps the value produced by training from being replaced by this estimator's own
     * {@code bufferSize} param.
     */
    public KNNClassificationModel fit(Dataset<?> dataset) {
        this.transformSchema(dataset.schema(), true);
        KNNClassificationModel model = this.train(dataset);
        double bufferSize = model.getBufferSize();
        return ((KNNClassificationModel)this.copyValues((Params)model.setParent((Estimator)this), this.copyValues$default$2())).setBufferSize(bufferSize);
    }

    /** Returns a copy of this estimator with {@code extra} params applied. */
    public KNNClassifier copy(ParamMap extra) {
        return (KNNClassifier)this.defaultCopy(extra);
    }

    /**
     * Creates a classifier with the given uid.
     *
     * <p>Initializes the mixed-in param traits. It then defaults {@code inputCols} to the current
     * label column and {@code weightCol} to the empty string (no weights).
     */
    public KNNClassifier(String uid) {
        this.uid = uid;
        HasInputCols.class.$init$((HasInputCols)this);
        KNNModelParams$class.$init$(this);
        HasSeed.class.$init$((HasSeed)this);
        KNNParams$class.$init$(this);
        HasWeightCol.class.$init$((HasWeightCol)this);
        this.setDefault((Param)this.inputCols(), new String[]{(String)this.$(this.labelCol())});
        this.setDefault((Seq)Predef$.MODULE$.wrapRefArray((Object[])new ParamPair[]{this.weightCol().$minus$greater((Object)"")}));
    }

    /** Creates a classifier with a random uid prefixed {@code "knnc"}. */
    public KNNClassifier() {
        this(Identifiable$.MODULE$.randomUID("knnc"));
    }
}

