package org.apache.spark.ml.classification;

import org.apache.spark.SparkException;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.PredictionModel;
import org.apache.spark.ml.knn.KNN;
import org.apache.spark.ml.knn.KNNModelParams;
import org.apache.spark.ml.knn.KNNParams;
import org.apache.spark.ml.knn.Tree;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.param.DoubleParam;
import org.apache.spark.ml.param.IntArrayParam;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.LongParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.ml.param.StringArrayParam;
import org.apache.spark.ml.param.shared.HasInputCols;
import org.apache.spark.ml.param.shared.HasSeed;
import org.apache.spark.ml.param.shared.HasWeightCol;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import org.apache.spark.storage.StorageLevel$;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: KNNClassifier.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005ub\u0001B\u0001\u0003\u00015\u0011Qb\u0013(O\u00072\f7o]5gS\u0016\u0014(BA\u0002\u0005\u00039\u0019G.Y:tS\u001aL7-\u0019;j_:T!!\u0002\u0004\u0002\u00055d'BA\u0004\t\u0003\u0015\u0019\b/\u0019:l\u0015\tI!\"\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u0002\u0017\u0005\u0019qN]4\u0004\u0001M!\u0001A\u0004\u000f#!\u0015y\u0001C\u0005\r\u001a\u001b\u0005\u0011\u0011BA\t\u0003\u0005]\u0001&o\u001c2bE&d\u0017n\u001d;jG\u000ec\u0017m]:jM&,'\u000f\u0005\u0002\u0014-5\tAC\u0003\u0002\u0016\t\u00051A.\u001b8bY\u001eL!a\u0006\u000b\u0003\rY+7\r^8s!\ty\u0001\u0001\u0005\u0002\u00105%\u00111D\u0001\u0002\u0017\u0017:s5\t\\1tg&4\u0017nY1uS>tWj\u001c3fYB\u0011Q\u0004I\u0007\u0002=)\u0011q\u0004B\u0001\u0004W:t\u0017BA\u0011\u001f\u0005%YeJ\u0014)be\u0006l7\u000f\u0005\u0002$Q5\tAE\u0003\u0002&M\u000511\u000f[1sK\u0012T!a\n\u0003\u0002\u000bA\f'/Y7\n\u0005%\"#\u0001\u0004%bg^+\u0017n\u001a5u\u0007>d\u0007\u0002C\u0016\u0001\u0005\u000b\u0007I\u0011\t\u0017\u0002\u0007ULG-F\u0001.!\tqCG\u0004\u00020e5\t\u0001GC\u00012\u0003\u0015\u00198-\u00197b\u0013\t\u0019\u0004'\u0001\u0004Qe\u0016$WMZ\u0005\u0003kY\u0012aa\u0015;sS:<'BA\u001a1\u0011!A\u0004A!A!\u0002\u0013i\u0013\u0001B;jI\u0002BQA\u000f\u0001\u0005\u0002m\na\u0001P5oSRtDC\u0001\r=\u0011\u0015Y\u0013\b1\u0001.\u0011\u0015Q\u0004\u0001\"\u0001?)\u0005A\u0002\"\u0002!\u0001\t\u0003\n\u0015AD:fi\u001a+\u0017\r^;sKN\u001cu\u000e\u001c\u000b\u0003\u0005\u000ek\u0011\u0001\u0001\u0005\u0006\t~\u0002\r!L\u0001\u0006m\u0006dW/\u001a\u0005\u0006\r\u0002!\teR\u0001\fg\u0016$H*\u00192fY\u000e{G\u000e\u0006\u0002C\u0011\")A)\u0012a\u0001[!)!\n\u0001C\u0001\u0017\u0006a1/\u001a;XK&<\u0007\u000e^\"pYR\u0011!\t\u0014\u0005\u0006\t&\u0003\r!\f\u0005\u0006\u001d\u0002!\taT\u0001\u0005g\u0016$8\n\u0006\u0002C!\")A)\u0014a\u0001#B\u0011qFU\u0005\u0003'B\u00121!\u00138u\u0011\u0015)\u0006\u0001\"\u0001W\u00039\u0019X\r\u001e+paR\u0013X-Z*ju\u0016$\"AQ,\t\u000b\u0011#\u0006\u0019A)\t\u000be\u0003A\u0011\u
0001.\u0002%M,G\u000fV8q)J,W\rT3bMNK'0\u001a\u000b\u0003\u0005nCQ\u0001\u0012-A\u0002ECQ!\u0018\u0001\u0005\u0002y\u000b!c]3u'V\u0014GK]3f\u0019\u0016\fgmU5{KR\u0011!i\u0018\u0005\u0006\tr\u0003\r!\u0015\u0005\u0006C\u0002!\tAY\u0001\u0019g\u0016$()\u001e4gKJ\u001c\u0016N_3TC6\u0004H.Z*ju\u0016\u001cHC\u0001\"d\u0011\u0015!\u0005\r1\u0001e!\ryS-U\u0005\u0003MB\u0012Q!\u0011:sCfDQ\u0001\u001b\u0001\u0005\u0002%\f1c]3u\u0005\u0006d\u0017M\\2f)\"\u0014Xm\u001d5pY\u0012$\"A\u00116\t\u000b\u0011;\u0007\u0019A6\u0011\u0005=b\u0017BA71\u0005\u0019!u.\u001e2mK\")q\u000e\u0001C\u0001a\u000691/\u001a;TK\u0016$GC\u0001\"r\u0011\u0015!e\u000e1\u0001s!\ty3/\u0003\u0002ua\t!Aj\u001c8h\u0011\u00151\b\u0001\"\u0015x\u0003\u0015!(/Y5o)\tI\u0002\u0010C\u0003zk\u0002\u0007!0A\u0004eCR\f7/\u001a;1\u0007m\f9\u0001\u0005\u0003}\u007f\u0006\rQ\"A?\u000b\u0005y4\u0011aA:rY&\u0019\u0011\u0011A?\u0003\u000f\u0011\u000bG/Y:fiB!\u0011QAA\u0004\u0019\u0001!1\"!\u0003y\u0003\u0003\u0005\tQ!\u0001\u0002\f\t\u0019q\fJ\u0019\u0012\t\u00055\u00111\u0003\t\u0004_\u0005=\u0011bAA\ta\t9aj\u001c;iS:<\u0007cA\u0018\u0002\u0016%\u0019\u0011q\u0003\u0019\u0003\u0007\u0005s\u0017\u0010C\u0004\u0002\u001c\u0001!\t%!\b\u0002\u0007\u0019LG\u000fF\u0002\u001a\u0003?Aq!_A\r\u0001\u0004\t\t\u0003\r\u0003\u0002$\u0005\u001d\u0002\u0003\u0002?��\u0003K\u0001B!!\u0002\u0002(\u0011a\u0011\u0011FA\u0010\u0003\u0003\u0005\tQ!\u0001\u0002\f\t\u0019q\f\n\u001a\t\u000f\u00055\u0002\u0001\"\u0011\u00020\u0005!1m\u001c9z)\rA\u0012\u0011\u0007\u0005\t\u0003g\tY\u00031\u0001\u00026\u0005)Q\r\u001f;sCB!\u0011qGA\u001d\u001b\u00051\u0013bAA\u001eM\tA\u0001+\u0019:b[6\u000b\u0007\u000f")
/* loaded from: input_file:org/apache/spark/ml/classification/KNNClassifier.class */
/*
 * k-nearest-neighbour classifier estimator (decompiled from KNNClassifier.scala).
 *
 * The class mixes in the KNNParams and HasWeightCol traits. Each trait contributes a param
 * field, an accessor, a synthetic "_setter_" method that stores the param, and getters that
 * delegate to the trait implementation class. Training validates the labels with a
 * MultiClassSummarizer and then delegates the actual fit to a KNN estimator configured with
 * this instance's param values.
 */
public class KNNClassifier extends ProbabilisticClassifier<Vector, KNNClassifier, KNNClassificationModel> implements KNNParams, HasWeightCol {
    private final String uid;

    // The param fields below are written by the trait "_setter_$..._$eq" methods rather than
    // directly in the constructor, so they cannot be declared final in Java source.
    private Param<String> weightCol;
    private IntParam topTreeSize;
    private IntParam topTreeLeafSize;
    private IntParam subTreeLeafSize;
    private IntArrayParam bufferSizeSampleSizes;
    private DoubleParam balanceThreshold;
    private LongParam seed;
    private Param<String> neighborsCol;
    private Param<String> distanceCol;
    private IntParam k;
    private DoubleParam maxDistance;
    private DoubleParam bufferSize;
    private StringArrayParam inputCols;

    // ---------------------------------------------------------------------------------------
    // HasWeightCol
    // ---------------------------------------------------------------------------------------

    /** Returns the {@code weightCol} param stored on this instance. */
    public final Param<String> weightCol() {
        return this.weightCol;
    }

    /** Synthetic trait setter: stores the {@code weightCol} param. */
    public final void org$apache$spark$ml$param$shared$HasWeightCol$_setter_$weightCol_$eq(Param param) {
        this.weightCol = param;
    }

    /** Delegates to the HasWeightCol trait implementation of {@code getWeightCol}. */
    public final String getWeightCol() {
        return HasWeightCol.class.getWeightCol(this);
    }

    // ---------------------------------------------------------------------------------------
    // KNNParams: param accessors, synthetic setters and delegating getters
    // ---------------------------------------------------------------------------------------

    /** Returns the {@code topTreeSize} param stored on this instance. */
    @Override // org.apache.spark.ml.knn.KNNParams
    public IntParam topTreeSize() {
        return this.topTreeSize;
    }

    /** Returns the {@code topTreeLeafSize} param stored on this instance. */
    @Override // org.apache.spark.ml.knn.KNNParams
    public IntParam topTreeLeafSize() {
        return this.topTreeLeafSize;
    }

    /** Returns the {@code subTreeLeafSize} param stored on this instance. */
    @Override // org.apache.spark.ml.knn.KNNParams
    public IntParam subTreeLeafSize() {
        return this.subTreeLeafSize;
    }

    /** Returns the {@code bufferSizeSampleSizes} param stored on this instance. */
    @Override // org.apache.spark.ml.knn.KNNParams
    public IntArrayParam bufferSizeSampleSizes() {
        return this.bufferSizeSampleSizes;
    }

    /** Returns the {@code balanceThreshold} param stored on this instance. */
    @Override // org.apache.spark.ml.knn.KNNParams
    public DoubleParam balanceThreshold() {
        return this.balanceThreshold;
    }

    /** Synthetic trait setter: stores the {@code topTreeSize} param. */
    @Override // org.apache.spark.ml.knn.KNNParams
    public void org$apache$spark$ml$knn$KNNParams$_setter_$topTreeSize_$eq(IntParam intParam) {
        this.topTreeSize = intParam;
    }

    /** Synthetic trait setter: stores the {@code topTreeLeafSize} param. */
    @Override // org.apache.spark.ml.knn.KNNParams
    public void org$apache$spark$ml$knn$KNNParams$_setter_$topTreeLeafSize_$eq(IntParam intParam) {
        this.topTreeLeafSize = intParam;
    }

    /** Synthetic trait setter: stores the {@code subTreeLeafSize} param. */
    @Override // org.apache.spark.ml.knn.KNNParams
    public void org$apache$spark$ml$knn$KNNParams$_setter_$subTreeLeafSize_$eq(IntParam intParam) {
        this.subTreeLeafSize = intParam;
    }

    /** Synthetic trait setter: stores the {@code bufferSizeSampleSizes} param. */
    @Override // org.apache.spark.ml.knn.KNNParams
    public void org$apache$spark$ml$knn$KNNParams$_setter_$bufferSizeSampleSizes_$eq(IntArrayParam intArrayParam) {
        this.bufferSizeSampleSizes = intArrayParam;
    }

    /** Synthetic trait setter: stores the {@code balanceThreshold} param. */
    @Override // org.apache.spark.ml.knn.KNNParams
    public void org$apache$spark$ml$knn$KNNParams$_setter_$balanceThreshold_$eq(DoubleParam doubleParam) {
        this.balanceThreshold = doubleParam;
    }

    /** Delegates to {@code KNNParams.Cclass.getTopTreeSize}. */
    @Override // org.apache.spark.ml.knn.KNNParams
    public int getTopTreeSize() {
        return KNNParams.Cclass.getTopTreeSize(this);
    }

    /** Delegates to {@code KNNParams.Cclass.getTopTreeLeafSize}. */
    @Override // org.apache.spark.ml.knn.KNNParams
    public int getTopTreeLeafSize() {
        return KNNParams.Cclass.getTopTreeLeafSize(this);
    }

    /** Delegates to {@code KNNParams.Cclass.getSubTreeLeafSize}. */
    @Override // org.apache.spark.ml.knn.KNNParams
    public int getSubTreeLeafSize() {
        return KNNParams.Cclass.getSubTreeLeafSize(this);
    }

    /** Delegates to {@code KNNParams.Cclass.getBufferSizeSampleSizes}. */
    @Override // org.apache.spark.ml.knn.KNNParams
    public int[] getBufferSizeSampleSizes() {
        return KNNParams.Cclass.getBufferSizeSampleSizes(this);
    }

    /** Delegates to {@code KNNParams.Cclass.getBalanceThreshold}. */
    @Override // org.apache.spark.ml.knn.KNNParams
    public double getBalanceThreshold() {
        return KNNParams.Cclass.getBalanceThreshold(this);
    }

    /** Delegates schema validation/transformation to {@code KNNParams.Cclass}. */
    @Override // org.apache.spark.ml.knn.KNNParams
    public StructType validateAndTransformSchema(StructType structType) {
        return KNNParams.Cclass.validateAndTransformSchema(this, structType);
    }

    // ---------------------------------------------------------------------------------------
    // HasSeed
    // ---------------------------------------------------------------------------------------

    /** Returns the {@code seed} param stored on this instance. */
    public final LongParam seed() {
        return this.seed;
    }

    /** Synthetic trait setter: stores the {@code seed} param. */
    public final void org$apache$spark$ml$param$shared$HasSeed$_setter_$seed_$eq(LongParam longParam) {
        this.seed = longParam;
    }

    /** Delegates to the HasSeed trait implementation of {@code getSeed}. */
    public final long getSeed() {
        return HasSeed.class.getSeed(this);
    }

    // ---------------------------------------------------------------------------------------
    // KNNModelParams: param accessors, synthetic setters and delegating getters
    // ---------------------------------------------------------------------------------------

    /** Returns the {@code neighborsCol} param stored on this instance. */
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public Param<String> neighborsCol() {
        return this.neighborsCol;
    }

    /** Returns the {@code distanceCol} param stored on this instance. */
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public Param<String> distanceCol() {
        return this.distanceCol;
    }

    /** Returns the {@code k} param stored on this instance. */
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public IntParam k() {
        return this.k;
    }

    /** Returns the {@code maxDistance} param stored on this instance. */
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public DoubleParam maxDistance() {
        return this.maxDistance;
    }

    /** Returns the {@code bufferSize} param stored on this instance. */
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public DoubleParam bufferSize() {
        return this.bufferSize;
    }

    /** Synthetic trait setter: stores the {@code neighborsCol} param. */
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public void org$apache$spark$ml$knn$KNNModelParams$_setter_$neighborsCol_$eq(Param param) {
        this.neighborsCol = param;
    }

    /** Synthetic trait setter: stores the {@code distanceCol} param. */
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public void org$apache$spark$ml$knn$KNNModelParams$_setter_$distanceCol_$eq(Param param) {
        this.distanceCol = param;
    }

    /** Synthetic trait setter: stores the {@code k} param. */
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public void org$apache$spark$ml$knn$KNNModelParams$_setter_$k_$eq(IntParam intParam) {
        this.k = intParam;
    }

    /** Synthetic trait setter: stores the {@code maxDistance} param. */
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public void org$apache$spark$ml$knn$KNNModelParams$_setter_$maxDistance_$eq(DoubleParam doubleParam) {
        this.maxDistance = doubleParam;
    }

    /** Synthetic trait setter: stores the {@code bufferSize} param. */
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public void org$apache$spark$ml$knn$KNNModelParams$_setter_$bufferSize_$eq(DoubleParam doubleParam) {
        this.bufferSize = doubleParam;
    }

    /** Delegates to {@code KNNModelParams.Cclass.getNeighborsCol}. */
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public String getNeighborsCol() {
        return KNNModelParams.Cclass.getNeighborsCol(this);
    }

    /** Delegates to {@code KNNModelParams.Cclass.getDistanceCol}. */
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public String getDistanceCol() {
        return KNNModelParams.Cclass.getDistanceCol(this);
    }

    /** Delegates to {@code KNNModelParams.Cclass.getK}. */
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public int getK() {
        return KNNModelParams.Cclass.getK(this);
    }

    /** Delegates to {@code KNNModelParams.Cclass.getMaxDistance}. */
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public double getMaxDistance() {
        return KNNModelParams.Cclass.getMaxDistance(this);
    }

    /** Delegates to {@code KNNModelParams.Cclass.getBufferSize}. */
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public double getBufferSize() {
        return KNNModelParams.Cclass.getBufferSize(this);
    }

    /** RDD-based neighbour search; delegates to {@code KNNModelParams.Cclass.transform}. */
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public RDD<Tuple2<Object, Tuple2<Row, Object>[]>> transform(RDD<Vector> rdd, Broadcast<Tree> broadcast, RDD<Tree> rdd2) {
        return KNNModelParams.Cclass.transform(this, rdd, broadcast, rdd2);
    }

    /** Dataset-based neighbour search; delegates to {@code KNNModelParams.Cclass.transform}. */
    @Override // org.apache.spark.ml.knn.KNNModelParams
    public RDD<Tuple2<Object, Tuple2<Row, Object>[]>> transform(Dataset<?> dataset, Broadcast<Tree> broadcast, RDD<Tree> rdd) {
        return KNNModelParams.Cclass.transform(this, dataset, broadcast, rdd);
    }

    // ---------------------------------------------------------------------------------------
    // HasInputCols
    // ---------------------------------------------------------------------------------------

    /** Returns the {@code inputCols} param stored on this instance. */
    public final StringArrayParam inputCols() {
        return this.inputCols;
    }

    /** Synthetic trait setter: stores the {@code inputCols} param. */
    public final void org$apache$spark$ml$param$shared$HasInputCols$_setter_$inputCols_$eq(StringArrayParam stringArrayParam) {
        this.inputCols = stringArrayParam;
    }

    /** Delegates to the HasInputCols trait implementation of {@code getInputCols}. */
    public final String[] getInputCols() {
        return HasInputCols.class.getInputCols(this);
    }

    /** Returns the unique identifier passed to the constructor. */
    public String uid() {
        return this.uid;
    }

    // ---------------------------------------------------------------------------------------
    // Fluent setters
    // ---------------------------------------------------------------------------------------

    /**
     * Sets the features column.
     *
     * @param str name of the features column
     * @return the result of {@code set(featuresCol(), str)}
     */
    /* renamed from: setFeaturesCol, reason: merged with bridge method [inline-methods] */
    public KNNClassifier m13setFeaturesCol(String str) {
        return set(featuresCol(), str);
    }

    /**
     * Sets the label column and rebuilds {@code inputCols} so that it contains the new label
     * column, followed by the current weight column when that one is non-empty.
     *
     * @param str name of the label column
     * @return the result of the final {@code set(inputCols(), ...)} call
     */
    /* renamed from: setLabelCol, reason: merged with bridge method [inline-methods] */
    public KNNClassifier m12setLabelCol(String str) {
        set(labelCol(), str);
        return ((String) $(weightCol())).isEmpty() ? set(inputCols(), new String[]{str}) : set(inputCols(), new String[]{str, (String) $(weightCol())});
    }

    /**
     * Sets the weight column and rebuilds {@code inputCols} so that it contains the current
     * label column, followed by the new weight column when {@code str} is non-empty.
     *
     * @param str name of the weight column; an empty string means "no weight column"
     * @return the result of the final {@code set(inputCols(), ...)} call
     */
    public KNNClassifier setWeightCol(String str) {
        set(weightCol(), str);
        return str.isEmpty() ? set(inputCols(), new String[]{(String) $(labelCol())}) : set(inputCols(), new String[]{(String) $(labelCol()), str});
    }

    /** Sets the {@code k} param to {@code i}. */
    public KNNClassifier setK(int i) {
        return set(k(), BoxesRunTime.boxToInteger(i));
    }

    /** Sets the {@code topTreeSize} param to {@code i}. */
    public KNNClassifier setTopTreeSize(int i) {
        return set(topTreeSize(), BoxesRunTime.boxToInteger(i));
    }

    /** Sets the {@code topTreeLeafSize} param to {@code i}. */
    public KNNClassifier setTopTreeLeafSize(int i) {
        return set(topTreeLeafSize(), BoxesRunTime.boxToInteger(i));
    }

    /** Sets the {@code subTreeLeafSize} param to {@code i}. */
    public KNNClassifier setSubTreeLeafSize(int i) {
        return set(subTreeLeafSize(), BoxesRunTime.boxToInteger(i));
    }

    /** Sets the {@code bufferSizeSampleSizes} param to {@code iArr}. */
    public KNNClassifier setBufferSizeSampleSizes(int[] iArr) {
        return set(bufferSizeSampleSizes(), iArr);
    }

    /** Sets the {@code balanceThreshold} param to {@code d}. */
    public KNNClassifier setBalanceThreshold(double d) {
        return set(balanceThreshold(), BoxesRunTime.boxToDouble(d));
    }

    /** Sets the {@code seed} param to {@code j}. */
    public KNNClassifier setSeed(long j) {
        return set(seed(), BoxesRunTime.boxToLong(j));
    }

    // ---------------------------------------------------------------------------------------
    // Training
    // ---------------------------------------------------------------------------------------

    /**
     * Trains a KNN classification model on {@code dataset}.
     *
     * <p>The labeled points extracted from the dataset are aggregated into a
     * {@code MultiClassSummarizer}; the length of its histogram is used as the number of
     * classes. If no invalid labels were counted, a {@code KNN} estimator carrying this
     * instance's param values is fitted on the dataset and converted with
     * {@code toNewClassificationModel(uid(), numClasses)}.
     *
     * @param dataset training data
     * @return the fitted classification model
     * @throws SparkException if the summarizer counted one or more invalid labels
     */
    public KNNClassificationModel train(Dataset<?> dataset) {
        RDD instances = extractLabeledPoints(dataset).map(new KNNClassifier$$anonfun$1(this), ClassTag$.MODULE$.apply(Tuple2.class));

        // Cache the derived RDD only when the caller has not already cached the input.
        StorageLevel currentLevel = dataset.rdd().getStorageLevel();
        StorageLevel none = StorageLevel$.MODULE$.NONE();
        boolean handlePersistence = currentLevel != null ? currentLevel.equals(none) : none == null;
        if (handlePersistence) {
            instances.persist(StorageLevel$.MODULE$.MEMORY_AND_DISK());
        }

        MultiClassSummarizer labelSummary;
        try {
            MultiClassSummarizer zero = new MultiClassSummarizer();
            labelSummary = (MultiClassSummarizer) instances.treeAggregate(zero, new KNNClassifier$$anonfun$2(this), new KNNClassifier$$anonfun$3(this), instances.treeAggregate$default$4(zero), ClassTag$.MODULE$.apply(MultiClassSummarizer.class));
        } finally {
            // The aggregation above is the only consumer of the cached RDD (the KNN fit below
            // works on `dataset`), so release the cache here instead of leaking it.
            if (handlePersistence) {
                instances.unpersist(false);
            }
        }

        double[] histogram = labelSummary.histogram();
        long numInvalid = labelSummary.countInvalid();
        int numClasses = histogram.length;
        if (numInvalid == 0) {
            return copyValues(new KNN(), copyValues$default$2()).fit(dataset).toNewClassificationModel(uid(), numClasses);
        }

        // The same text is logged and used as the exception message.
        String message = "Classification labels should be in {0 to " + (numClasses - 1) + "}. " + "Found " + numInvalid + " invalid labels.";
        logError(new KNNClassifier$$anonfun$train$1(this, message));
        throw new SparkException(message);
    }

    /**
     * Fits a model: validates the dataset schema via {@code transformSchema}, trains, sets this
     * estimator as the model's parent, copies this estimator's param values onto the model and
     * finally re-applies the buffer size reported by the trained model.
     *
     * @param dataset training data
     * @return the fitted classification model
     */
    public KNNClassificationModel fit(Dataset<?> dataset) {
        transformSchema(dataset.schema(), true);
        KNNClassificationModel train = train(dataset);
        // copyValues runs first; setBufferSize is applied afterwards with the trained
        // model's own buffer size so that value is the one that ends up on the result.
        return copyValues(train.setParent(this), copyValues$default$2()).setBufferSize(train.getBufferSize());
    }

    /**
     * Copies this estimator via {@code defaultCopy}.
     *
     * @param paramMap extra param values passed through to {@code defaultCopy}
     * @return the copied estimator
     */
    /* renamed from: copy, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public KNNClassifier m8copy(ParamMap paramMap) {
        return defaultCopy(paramMap);
    }

    // Synthetic bridge methods emitted for the covariant return types of fit/train.

    /* renamed from: fit, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Model m9fit(Dataset dataset) {
        return fit((Dataset<?>) dataset);
    }

    /* renamed from: fit, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ PredictionModel m10fit(Dataset dataset) {
        return fit((Dataset<?>) dataset);
    }

    /* renamed from: train, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ PredictionModel m11train(Dataset dataset) {
        return train((Dataset<?>) dataset);
    }

    // ---------------------------------------------------------------------------------------
    // Construction
    // ---------------------------------------------------------------------------------------

    /**
     * Creates a classifier with the given unique identifier, runs the initialisers of the
     * mixed-in traits and then registers the defaults: {@code inputCols} defaults to the
     * current label column and {@code weightCol} defaults to the empty string.
     *
     * @param str unique identifier of this estimator
     */
    public KNNClassifier(String str) {
        this.uid = str;
        // Trait initialisers; presumably these populate the param fields through the
        // "_setter_" methods above (their bodies are not visible here).
        HasInputCols.class.$init$(this);
        KNNModelParams.Cclass.$init$(this);
        HasSeed.class.$init$(this);
        KNNParams.Cclass.$init$(this);
        HasWeightCol.class.$init$(this);
        setDefault(inputCols(), new String[]{(String) $(labelCol())});
        setDefault(Predef$.MODULE$.wrapRefArray(new ParamPair[]{weightCol().$minus$greater("")}));
    }

    /** Creates a classifier with a random unique identifier prefixed with {@code "knnc"}. */
    public KNNClassifier() {
        this(Identifiable$.MODULE$.randomUID("knnc"));
    }
}
