/*
 * Decompiled with CFR 0.152.
 */
package org.apache.ignite.ml.composition.combinators.parallel;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.composition.combinators.parallel.ModelsParallelComposition;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.parallelism.Promise;
import org.apache.ignite.ml.math.functions.IgniteSupplier;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.trainers.DatasetTrainer;

/**
 * Trainer that runs a list of trainers on the same dataset as separate tasks of the learning
 * environment's parallelism strategy, and wraps the resulting models in a
 * {@link ModelsParallelComposition} (a model from {@code I} to {@code List<O>}).
 *
 * <p>Models are collected in the order of the promises returned by the parallelism strategy,
 * which is expected to match the order of the supplied trainers.
 *
 * @param <I> Type of model input.
 * @param <O> Type of each submodel's output.
 * @param <L> Type of labels.
 */
public class TrainersParallelComposition<I, O, L>
extends DatasetTrainer<IgniteModel<I, List<O>>, L> {
    /** Trainers to run; each produces one submodel of the composition. */
    private final List<DatasetTrainer<IgniteModel<I, O>, L>> trainers;

    /**
     * Creates a composition of the given trainers.
     *
     * @param trainers Trainers to compose. Each is converted with
     *     {@link CompositionUtils#unsafeCoerce}; as the name suggests, that conversion is presumably
     *     unchecked, so the caller is responsible for type compatibility.
     * @param <T> Type of the supplied trainers.
     */
    public <T extends DatasetTrainer<? extends IgniteModel<I, O>, L>> TrainersParallelComposition(List<T> trainers) {
        this.trainers = trainers.stream().map(CompositionUtils::unsafeCoerce).collect(Collectors.toList());
    }

    /**
     * Static factory that creates a composition of the given trainers.
     *
     * @param trainers Trainers to compose.
     * @param <I> Type of model input.
     * @param <O> Type of each submodel's output.
     * @param <M> Type of model produced by each trainer.
     * @param <T> Type of the supplied trainers.
     * @param <L> Type of labels.
     * @return Composition of the given trainers.
     */
    public static <I, O, M extends IgniteModel<I, O>, T extends DatasetTrainer<M, L>, L> TrainersParallelComposition<I, O, L> of(List<T> trainers) {
        List<DatasetTrainer<IgniteModel<I, O>, L>> trs =
            trainers.stream().map(CompositionUtils::unsafeCoerce).collect(Collectors.toList());
        return new TrainersParallelComposition<>(trs);
    }

    /**
     * Fits every trainer on the same dataset, one task per trainer, and combines the results.
     *
     * @param datasetBuilder Dataset builder shared by all trainers.
     * @param preprocessor Preprocessor shared by all trainers.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @return Composition of the fitted submodels.
     */
    @Override
    public <K, V> IgniteModel<I, List<O>> fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> preprocessor) {
        // The explicit cast gives the inner lambda a functional-interface target type; without it
        // the element type of the mapped stream cannot be inferred.
        List<IgniteSupplier<IgniteModel<I, O>>> tasks = this.trainers.stream()
            .map(tr -> (IgniteSupplier<IgniteModel<I, O>>)() -> tr.fit(datasetBuilder, preprocessor))
            .collect(Collectors.toList());
        return new ModelsParallelComposition<>(submitAndAwait(tasks));
    }

    /**
     * Updates each submodel of {@code mdl} with the trainer at the same index, one task per
     * trainer, and combines the results.
     *
     * @param mdl Model to update; must be a {@link ModelsParallelComposition} with exactly one
     *     submodel per trainer.
     * @param datasetBuilder Dataset builder shared by all trainers.
     * @param preprocessor Preprocessor shared by all trainers.
     * @param <K> Type of upstream keys.
     * @param <V> Type of upstream values.
     * @return Composition of the updated submodels.
     * @throws ClassCastException If {@code mdl} is not a {@link ModelsParallelComposition}.
     * @throws IllegalArgumentException If the number of submodels differs from the number of trainers.
     */
    @Override
    public <K, V> IgniteModel<I, List<O>> update(IgniteModel<I, List<O>> mdl, DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> preprocessor) {
        this.learningEnvironment().initDeployingContext(preprocessor);

        ModelsParallelComposition<I, O> typedMdl = (ModelsParallelComposition<I, O>)mdl;
        List<IgniteModel<I, O>> submodels = typedMdl.submodels();

        // Checked explicitly rather than with `assert`: with assertions disabled, a mismatch would
        // otherwise silently drop extra submodels or fail later inside a task.
        if (submodels.size() != this.trainers.size()) {
            throw new IllegalArgumentException("Model has " + submodels.size()
                + " submodels, but composition has " + this.trainers.size() + " trainers");
        }

        List<IgniteSupplier<IgniteModel<I, O>>> tasks = new ArrayList<>(this.trainers.size());
        for (int i = 0; i < this.trainers.size(); i++) {
            DatasetTrainer<IgniteModel<I, O>, L> trainer = this.trainers.get(i);
            IgniteModel<I, O> submodel = submodels.get(i);
            tasks.add(() -> trainer.update(submodel, datasetBuilder, preprocessor));
        }
        return new ModelsParallelComposition<>(submitAndAwait(tasks));
    }

    /**
     * Not supported. {@link #update} is overridden in this class and does not consult this method.
     *
     * @param mdl Ignored.
     * @return Never returns normally.
     * @throws IllegalStateException Always.
     */
    @Override
    public boolean isUpdateable(IgniteModel<I, List<O>> mdl) {
        throw new IllegalStateException("isUpdateable is not supported; TrainersParallelComposition overrides update");
    }

    /**
     * Not supported. {@link #update} is overridden in this class and does not delegate here.
     *
     * @throws IllegalStateException Always.
     */
    @Override
    protected <K, V> IgniteModel<I, List<O>> updateModel(IgniteModel<I, List<O>> mdl, DatasetBuilder<K, V> datasetBuilder,
        Preprocessor<K, V> preprocessor) {
        throw new IllegalStateException("updateModel is not supported; TrainersParallelComposition overrides update");
    }

    /**
     * Submits all tasks to the environment's parallelism strategy, then blocks on each resulting
     * promise in turn.
     *
     * @param tasks Tasks producing submodels.
     * @return Produced submodels, in the order of the returned promises.
     */
    private List<IgniteModel<I, O>> submitAndAwait(List<IgniteSupplier<IgniteModel<I, O>>> tasks) {
        return this.environment.parallelismStrategy().submit(tasks).stream()
            .map(Promise::unsafeGet)
            .collect(Collectors.toList());
    }
}

