package org.deeplearning4j.spark.impl.repartitioner;

import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.deeplearning4j.spark.api.Repartitioner;
import org.deeplearning4j.spark.impl.common.CountPartitionsFunction;
import org.deeplearning4j.spark.impl.common.repartition.EqualPartitioner;
import org.deeplearning4j.spark.util.SparkUtils;
import org.nd4j.linalg.util.MathUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

/* loaded from: input_file:org/deeplearning4j/spark/impl/repartitioner/EqualRepartitioner.class */
/**
 * {@link Repartitioner} that redistributes the elements of an RDD so that partition sizes differ by at most one:
 * every output partition holds either {@code total / numPartitions} or {@code total / numPartitions + 1} elements.
 *
 * <p>If the RDD already has the requested number of partitions and every partition size is within that range,
 * the input RDD is returned unchanged and no shuffle is performed.
 */
public class EqualRepartitioner implements Repartitioner {
    private static final Logger log = LoggerFactory.getLogger(EqualRepartitioner.class);

    /**
     * Counts the elements in each partition of {@code rdd}, then repartitions it into {@code numPartitions}
     * equally sized partitions via {@link #repartition(JavaRDD, int, List)}.
     *
     * @param rdd                    the RDD to repartition
     * @param minObjectsPerPartition not used by this implementation (presumably a minimum-partition-size hint
     *                               from the {@link Repartitioner} contract — confirm)
     * @param numPartitions          the number of output partitions
     * @return the balanced RDD (possibly {@code rdd} itself)
     */
    @Override // org.deeplearning4j.spark.api.Repartitioner
    public <T> JavaRDD<T> repartition(JavaRDD<T> rdd, int minObjectsPerPartition, int numPartitions) {
        // CountPartitionsFunction is used raw here, so collect() yields a raw List that must be cast.
        @SuppressWarnings({"unchecked", "rawtypes"})
        List<Tuple2<Integer, Integer>> partitionCounts = (List<Tuple2<Integer, Integer>>) rdd
                .mapPartitionsWithIndex(new CountPartitionsFunction(), true).collect();
        return repartition(rdd, numPartitions, partitionCounts);
    }

    /**
     * Repartitions {@code rdd} into {@code numPartitions} partitions whose sizes differ by at most one element.
     *
     * @param rdd             the RDD to repartition
     * @param numPartitions   number of output partitions; must be positive
     * @param partitionCounts one tuple per existing partition of {@code rdd}; {@code _2()} is that partition's
     *                        element count ({@code _1()} is not read here)
     * @return {@code rdd} itself if it already has exactly {@code numPartitions} partitions that are balanced,
     *         otherwise a new, repartitioned RDD
     * @throws IllegalArgumentException if {@code numPartitions <= 0}
     */
    public static <T> JavaRDD<T> repartition(JavaRDD<T> rdd, int numPartitions,
                    List<Tuple2<Integer, Integer>> partitionCounts) {
        if (numPartitions <= 0) {
            throw new IllegalArgumentException("numPartitions must be positive, got " + numPartitions);
        }

        int totalObjects = 0;
        for (Tuple2<Integer, Integer> count : partitionCounts) {
            totalObjects += count._2();
        }

        int minAllowable = totalObjects / numPartitions;
        int remainder = totalObjects % numPartitions;
        // Ceiling via integer arithmetic: Math.ceil(totalObjects / numPartitions) on ints truncates before
        // rounding and would always equal the floor, wrongly rejecting already-optimal layouts.
        int maxAllowable = remainder == 0 ? minAllowable : minAllowable + 1;

        if (partitionCounts.size() == numPartitions
                        && allCountsWithin(partitionCounts, minAllowable, maxAllowable)) {
            // Already in the desired shape: avoid an unnecessary shuffle.
            return rdd;
        }

        JavaPairRDD<Integer, T> indexed = SparkUtils.indexedRDD(rdd);
        int[] remainderPartitions = remainder > 0 ? pickRemainderPartitions(numPartitions, remainder) : null;
        return indexed.partitionBy(new EqualPartitioner(numPartitions, minAllowable, remainderPartitions)).values();
    }

    /** Returns {@code true} iff every partition count ({@code _2()}) lies in {@code [min, max]}. */
    private static boolean allCountsWithin(List<Tuple2<Integer, Integer>> partitionCounts, int min, int max) {
        for (Tuple2<Integer, Integer> count : partitionCounts) {
            int size = count._2();
            if (size < min || size > max) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns {@code remainder} distinct partition indices from {@code [0, numPartitions)}, chosen at random by
     * shuffling all indices with {@link MathUtils#shuffleArray} and taking the first {@code remainder}.
     * These are passed to {@link EqualPartitioner} as the remainder partitions (presumably each receives one
     * extra element — confirm against EqualPartitioner).
     */
    private static int[] pickRemainderPartitions(int numPartitions, int remainder) {
        int[] indices = new int[numPartitions];
        for (int p = 0; p < numPartitions; p++) {
            indices[p] = p;
        }
        MathUtils.shuffleArray(indices, new Random());
        int[] chosen = new int[remainder];
        System.arraycopy(indices, 0, chosen, 0, remainder);
        return chosen;
    }
}
