/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.util.generators.primitives.vector;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.util.generators.DataStreamGenerator;
import org.apache.ignite.ml.util.generators.primitives.scalar.DiscreteRandomProducer;
import org.apache.ignite.ml.util.generators.primitives.vector.VectorGenerator;

/**
 * A {@link VectorGenerator} that mixes several member generators. On every call one member is chosen at random by a
 * {@link DiscreteRandomProducer} built from the builder weights (normalized so that they sum to one), and the vector
 * produced by that member is returned.
 * <p>
 * Instances are created through {@link Builder}.
 */
public class VectorGeneratorsFamily implements VectorGenerator {
    /** Member generators; the position in this list is the distribution id reported by {@link #getWithId()}. */
    private final List<VectorGenerator> family;

    /** Random source of indices into {@link #family}. */
    private final DiscreteRandomProducer selector;

    /**
     * @param family Member generators (already transformed by the builder's mapper).
     * @param selector Random source of indices into {@code family}.
     */
    private VectorGeneratorsFamily(List<VectorGenerator> family, DiscreteRandomProducer selector) {
        this.family = family;
        this.selector = selector;
    }

    /**
     * Draws a member index from the selector and returns that member's next vector.
     *
     * @return Generated vector.
     */
    @Override
    public Vector get() {
        return family.get(selector.getInt()).get();
    }

    /**
     * Same as {@link #get()}, but also reports which member of the family produced the vector.
     *
     * @return Generated vector paired with the index of the member generator that produced it.
     */
    public VectorWithDistributionId getWithId() {
        int id = selector.getInt();
        return new VectorWithDistributionId(family.get(id).get(), id);
    }

    /**
     * Wraps this family as a {@link DataStreamGenerator}. Each labeled vector carries, as its {@code Double} label,
     * the index of the family member that produced it.
     *
     * @return Data stream generator backed by this family.
     */
    @Override
    public DataStreamGenerator asDataStream() {
        final VectorGeneratorsFamily gen = this;
        return new DataStreamGenerator() {
            /** {@inheritDoc} */
            @Override
            public Stream<LabeledVector<Double>> labeled() {
                // Stream.generate yields an infinite stream; consumers must bound it themselves.
                return Stream.generate(gen::getWithId)
                    .map(v -> new LabeledVector<Double>(v.vector(), Double.valueOf(v.distributionId())));
            }
        };
    }

    /** A generated vector together with the index of the family member that produced it. */
    public static class VectorWithDistributionId {
        /** Generated vector. */
        private final Vector vector;

        /** Index of the producing generator within the family. */
        private final int distributionId;

        /**
         * @param vector Generated vector.
         * @param distributionId Index of the producing generator within the family.
         */
        public VectorWithDistributionId(Vector vector, int distributionId) {
            this.vector = vector;
            this.distributionId = distributionId;
        }

        /** @return Generated vector. */
        public Vector vector() {
            return vector;
        }

        /** @return Index of the producing generator within the family. */
        public int distributionId() {
            return distributionId;
        }
    }

    /** Incrementally collects weighted generators and an optional transformation, then builds the family. */
    public static class Builder {
        /** Generators registered so far, in registration order. */
        private final List<VectorGenerator> family = new ArrayList<>();

        /** Relative weight of each registered generator; index-aligned with {@link #family}. */
        private final List<Double> weights = new ArrayList<>();

        /** Transformation applied to every registered generator at build time; identity by default. */
        private IgniteFunction<VectorGenerator, VectorGenerator> mapper = x -> x;

        /**
         * Registers a generator with the given relative weight.
         *
         * @param generator Member generator; must not be {@code null} (checked via {@code A.notNull}).
         * @param weight Relative selection weight; must be positive and finite (checked via {@code A.ensure}).
         * @return This builder.
         */
        public Builder add(VectorGenerator generator, double weight) {
            // Fail fast here rather than with an NPE deep inside get() or a user mapper.
            A.notNull(generator, "generator");
            // NaN already fails this comparison.
            A.ensure(weight > 0.0, "weight > 0");
            // An infinite weight would turn the normalized probabilities into NaN (Inf / Inf) and 0 (x / Inf).
            A.ensure(!Double.isInfinite(weight), "weight is finite");
            family.add(generator);
            weights.add(weight);
            return this;
        }

        /**
         * Registers a generator with weight {@code 1.0}.
         *
         * @param generator Member generator.
         * @return This builder.
         */
        public Builder add(VectorGenerator generator) {
            return add(generator, 1.0);
        }

        /**
         * Adds a transformation that is applied at build time to every registered generator. Successive calls
         * compose: the new mapper runs on the output of the previously registered ones.
         *
         * @param mapper Transformation of a member generator.
         * @return This builder.
         */
        public Builder map(IgniteFunction<VectorGenerator, VectorGenerator> mapper) {
            final IgniteFunction<VectorGenerator, VectorGenerator> prev = this.mapper;
            this.mapper = x -> mapper.apply(prev.apply(x));
            return this;
        }

        /**
         * Builds the family, seeding member selection with {@link System#currentTimeMillis()}.
         *
         * @return New generators family.
         */
        public VectorGeneratorsFamily build() {
            return build(System.currentTimeMillis());
        }

        /**
         * Builds the family. Weights are divided by their sum to obtain selection probabilities, and every
         * registered generator is passed through the accumulated mapper.
         *
         * @param seed Seed for the member-selection random producer.
         * @return New generators family.
         */
        public VectorGeneratorsFamily build(long seed) {
            A.notEmpty(family, "family.size != 0");

            double sumOfWeights = weights.stream().mapToDouble(Double::doubleValue).sum();
            // Each weight is finite, but their sum can still overflow to infinity and zero out every probability.
            A.ensure(Double.isFinite(sumOfWeights), "sum of weights is finite");
            double[] probs = weights.stream().mapToDouble(w -> w / sumOfWeights).toArray();

            List<VectorGenerator> mappedFamily = family.stream().map(mapper).collect(Collectors.toList());
            return new VectorGeneratorsFamily(mappedFamily, new DiscreteRandomProducer(seed, probs));
        }
    }
}

