package org.nd4j.parameterserver.distributed.v2;

import io.reactivex.Flowable;
import io.reactivex.disposables.Disposable;
import io.reactivex.functions.Consumer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Atomic;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.v2.enums.PropagationMode;
import org.nd4j.parameterserver.distributed.v2.messages.impl.GradientsUpdateMessage;
import org.nd4j.parameterserver.distributed.v2.messages.pairs.handshake.HandshakeResponse;
import org.nd4j.parameterserver.distributed.v2.messages.pairs.params.ModelParametersMessage;
import org.nd4j.parameterserver.distributed.v2.messages.pairs.params.ModelParametersRequest;
import org.nd4j.parameterserver.distributed.v2.messages.pairs.params.UpdaterParametersMessage;
import org.nd4j.parameterserver.distributed.v2.messages.pairs.params.UpdaterParametersRequest;
import org.nd4j.parameterserver.distributed.v2.transport.RestartCallback;
import org.nd4j.parameterserver.distributed.v2.transport.Transport;
import org.nd4j.parameterserver.distributed.v2.transport.UpdaterParametersProvider;
import org.nd4j.parameterserver.distributed.v2.transport.UpdatesHandler;
import org.nd4j.parameterserver.distributed.v2.transport.impl.StaticPortSupplier;
import org.nd4j.parameterserver.distributed.v2.util.AbstractSubscriber;
import org.nd4j.parameterserver.distributed.v2.util.UpdaterParametersHolder;
import org.reactivestreams.Subscriber;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Parameter-server endpoint that exchanges gradient updates, model parameters and updater
 * parameters with other nodes through a {@link Transport}.
 *
 * <p>A shared instance is available via {@link #getInstance()}; it must be configured with one of
 * the {@code configure(...)} overloads before {@link #launch()} is called.
 *
 * <p>Incoming {@link GradientsUpdateMessage}s are forwarded to every registered
 * {@link UpdatesHandler}; while no handler is registered they are buffered in a bounded queue that
 * can be drained with {@link #getUpdates()}.
 */
public final class ModelParameterServer {
    private static final Logger log = LoggerFactory.getLogger(ModelParameterServer.class);

    /** Capacity of the buffer that holds updates received while no {@link UpdatesHandler} is registered. */
    private static final int UPDATES_QUEUE_CAPACITY = 4096;

    protected static final ModelParameterServer INSTANCE = new ModelParameterServer();

    /** Transport used for every message this server sends or receives. */
    private Transport transport;
    // Not assigned anywhere in this class; exposed through getMasterModelParams().
    private INDArray masterModelParams;
    // Not assigned anywhere in this class; exposed through getMasterUpdaterParams().
    private INDArray masterUpdaterParams;
    /** Source of updater parameters when answering requests in non-master mode (may be null). */
    private UpdaterParametersProvider updaterParametersProvider;
    /** Buffered updates, filled only while {@link #updatesSubscribers} is empty. */
    private final BlockingQueue<INDArray> updatesQueue;
    protected final List<UpdatesHandler> updatesSubscribers;
    protected final List<Subscriber<INDArray>> modelParamsSubsribers;
    protected final List<Subscriber<INDArray>> updaterParamsSubscribers;
    private boolean masterMode;
    protected VoidConfiguration configuration;
    /** True while the server is launched. */
    private final AtomicBoolean launchLock;
    /** True after shutdown() has completed; reset by launch(). */
    private final AtomicBoolean stopLock;
    // Not used within this class.
    protected BlockingQueue<INDArray> updatesBacklog;
    /** Latest known updater parameters (master mode only); guarded by {@link #updaterParamsLock}. */
    protected final Atomic<UpdaterParametersHolder> updaterParameters;
    protected final ReentrantReadWriteLock updaterParamsLock;
    /** When set, master mode stops refreshing {@link #updaterParameters}; not set within this class. */
    protected final AtomicBoolean gotFinalState;
    /** Subscription to the transport's incoming publisher; null until launch() runs. */
    private Disposable disposable;
    private AtomicInteger iterationNumber;
    private AtomicInteger epochNumber;

    protected ModelParameterServer() {
        this.updatesQueue = new LinkedBlockingQueue<>(UPDATES_QUEUE_CAPACITY);
        this.updatesSubscribers = new CopyOnWriteArrayList<>();
        this.modelParamsSubsribers = new CopyOnWriteArrayList<>();
        this.updaterParamsSubscribers = new CopyOnWriteArrayList<>();
        this.launchLock = new AtomicBoolean(false);
        this.stopLock = new AtomicBoolean(false);
        this.updatesBacklog = new LinkedBlockingQueue<>();
        this.updaterParameters = new Atomic<>();
        this.updaterParamsLock = new ReentrantReadWriteLock();
        this.gotFinalState = new AtomicBoolean(false);
        this.iterationNumber = new AtomicInteger(0);
        this.epochNumber = new AtomicInteger(0);
    }

    /**
     * Returns the shared instance. It is created unconfigured; call one of the
     * {@code configure(...)} overloads before {@link #launch()}.
     */
    public static ModelParameterServer getInstance() {
        return INSTANCE;
    }

    /** Creates a non-master server with a default configuration (static port 40123, stream id 119). */
    protected ModelParameterServer(@NonNull Transport transport) {
        this(transport, false);
        if (transport == null) {
            throw new NullPointerException("transport is marked @NonNull but is null");
        }
    }

    /** Creates a server with a default configuration (static port 40123, stream id 119). */
    protected ModelParameterServer(@NonNull Transport transport, boolean z) {
        this(VoidConfiguration.builder().portSupplier(new StaticPortSupplier(40123)).streamId(119).build(), transport, z);
        if (transport == null) {
            throw new NullPointerException("transport is marked @NonNull but is null");
        }
    }

    /**
     * Creates and configures a server.
     *
     * @param voidConfiguration network configuration
     * @param transport         transport used for all messaging
     * @param z                 {@code true} to run in master mode
     */
    public ModelParameterServer(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport, boolean z) {
        this();
        if (voidConfiguration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        if (transport == null) {
            throw new NullPointerException("transport is marked @NonNull but is null");
        }
        configure(voidConfiguration, transport, z);
    }

    /**
     * Sets the configuration, transport and master flag. Must be called before {@link #launch()}.
     *
     * @param z {@code true} to run in master mode
     */
    public void configure(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport, boolean z) {
        if (voidConfiguration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        if (transport == null) {
            throw new NullPointerException("transport is marked @NonNull but is null");
        }
        this.transport = transport;
        this.masterMode = z;
        this.configuration = voidConfiguration;
    }

    /**
     * Configures a non-master server that answers updater-parameter requests using
     * {@code updaterParametersProvider}.
     */
    public void configure(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport, @NonNull UpdaterParametersProvider updaterParametersProvider) {
        if (voidConfiguration == null) {
            throw new NullPointerException("configuration is marked @NonNull but is null");
        }
        if (transport == null) {
            throw new NullPointerException("transport is marked @NonNull but is null");
        }
        if (updaterParametersProvider == null) {
            throw new NullPointerException("updaterProvider is marked @NonNull but is null");
        }
        this.transport = transport;
        this.masterMode = false;
        this.configuration = voidConfiguration;
        this.updaterParametersProvider = updaterParametersProvider;
    }

    /** Registers a handler that receives every incoming gradient update. */
    public void addUpdatesSubscriber(@NonNull UpdatesHandler updatesHandler) {
        if (updatesHandler == null) {
            throw new NullPointerException("s is marked @NonNull but is null");
        }
        this.updatesSubscribers.add(updatesHandler);
    }

    /** Registers a subscriber that receives model parameters fetched after a transport restart. */
    public void addModelParamsSubscriber(@NonNull Subscriber<INDArray> subscriber) {
        if (subscriber == null) {
            throw new NullPointerException("s is marked @NonNull but is null");
        }
        this.modelParamsSubsribers.add(subscriber);
    }

    /** Registers a subscriber that receives updater parameters. */
    public void addUpdaterParamsSubscriber(@NonNull Subscriber<INDArray> subscriber) {
        if (subscriber == null) {
            throw new NullPointerException("s is marked @NonNull but is null");
        }
        this.updaterParamsSubscribers.add(subscriber);
    }

    /** Returns {@code true} while the server is launched. */
    public boolean isInitialized() {
        return this.launchLock.get();
    }

    /** Returns the last known (iteration, epoch) pair. */
    public Pair<Integer, Integer> getStartPosition() {
        return Pair.makePair(this.iterationNumber.get(), this.epochNumber.get());
    }

    /**
     * Wires request handlers and the incoming-update subscription into the transport, then launches
     * it. Calling this while already launched is a no-op.
     *
     * @throws IllegalStateException if no configuration/transport has been supplied yet
     */
    public synchronized void launch() {
        if (this.launchLock.get()) {
            return;
        }
        if (this.configuration == null || this.transport == null) {
            throw new IllegalStateException("ModelParameterServer must be configured before launch()");
        }
        log.info("ModelParameterServer starting");
        this.configuration.setUnicastControllerPort(this.configuration.getPortSupplier().getPort());

        // After a transport restart, re-fetch model and updater parameters from the root node.
        this.transport.setRestartCallback(new RestartCallback() {
            @Override
            public void call(HandshakeResponse handshakeResponse) {
                try {
                    log.info("Restart callback started...");
                    String rootId = transport.getRootId();
                    ModelParametersMessage modelMessage =
                            (ModelParametersMessage) transport.sendMessageBlocking(new ModelParametersRequest(), rootId);
                    INDArray modelParams = modelMessage.getPayload();
                    modelParamsSubsribers.forEach(s -> s.onNext(modelParams));
                    iterationNumber.set(modelMessage.getIterationNumber());
                    epochNumber.set(modelMessage.getEpochNumber());

                    INDArray updaterParams =
                            ((UpdaterParametersMessage) transport.sendMessageBlocking(new UpdaterParametersRequest(), rootId)).getPayload();
                    if (updaterParams != null) {
                        updaterParamsSubscribers.forEach(s -> s.onNext(updaterParams));
                        log.debug("Updater parameters propagated...");
                    }
                } catch (Exception e) {
                    // pass the exception as the throwable argument so SLF4J logs its stack trace
                    log.error("RestartCallback processing exception", e);
                    throw new RuntimeException(e);
                }
            }
        });

        // Answer model-parameter requests with the first UpdatesHandler's parameters.
        this.transport.addRequestConsumer(ModelParametersRequest.class, new Consumer<ModelParametersRequest>() {
            @Override
            public void accept(ModelParametersRequest request) throws Exception {
                if (updatesSubscribers.isEmpty()) {
                    throw new IllegalStateException("Can't answer ModelParametersRequest: no UpdatesHandler registered");
                }
                ModelParametersMessage response =
                        new ModelParametersMessage(UUID.randomUUID().toString(), updatesSubscribers.get(0).getParametersArray());
                response.setRequestId(request.getRequestId());
                response.setIterationNumber(iterationNumber.get());
                response.setEpochNumber(epochNumber.get());
                transport.sendMessage(response, request.getOriginatorId());
            }
        });

        if (this.masterMode) {
            // Keep the latest updater parameters locally so requests can be answered from them.
            addUpdaterParamsSubscriber(new AbstractSubscriber<INDArray>() {
                @Override
                public void onNext(INDArray array) {
                    if (gotFinalState.get()) {
                        return;
                    }
                    updaterParamsLock.writeLock().lock();
                    try {
                        UpdaterParametersHolder holder = updaterParameters.get();
                        if (holder == null) {
                            // no holder yet: create one instead of dereferencing null
                            updaterParameters.set(new UpdaterParametersHolder(array, System.currentTimeMillis(), false));
                        } else {
                            holder.setParameters(array);
                            holder.setTimeReceived(System.currentTimeMillis());
                        }
                    } finally {
                        updaterParamsLock.writeLock().unlock();
                    }
                }
            });

            this.transport.addRequestConsumer(UpdaterParametersRequest.class, new Consumer<UpdaterParametersRequest>() {
                @Override
                public void accept(UpdaterParametersRequest request) throws Exception {
                    if (!gotFinalState.get()) {
                        // refresh the local copy from a random downstream node first
                        String donorId = transport.getRandomDownstreamFrom(transport.getRootId(), request.getOriginatorId());
                        log.debug("Sending UpdaterParameters request to [{}]", donorId);
                        INDArray fetched =
                                ((UpdaterParametersMessage) transport.sendMessageBlocking(new UpdaterParametersRequest(), donorId)).getPayload();
                        updaterParamsLock.writeLock().lock();
                        try {
                            UpdaterParametersHolder holder = updaterParameters.get();
                            if (holder == null) {
                                updaterParameters.set(new UpdaterParametersHolder(fetched, System.currentTimeMillis(), false));
                            } else {
                                holder.setParameters(fetched);
                            }
                        } finally {
                            updaterParamsLock.writeLock().unlock();
                        }
                    }
                    updaterParamsLock.readLock().lock();
                    try {
                        log.debug("Trying to send back Updater parameters...");
                        UpdaterParametersHolder holder = updaterParameters.get();
                        // reply with a null payload (as the non-master path does) rather than NPE
                        INDArray params = holder == null ? null : holder.getParameters();
                        UpdaterParametersMessage response = new UpdaterParametersMessage(UUID.randomUUID().toString(), params);
                        response.setRequestId(request.getRequestId());
                        transport.sendMessage(response, request.getOriginatorId());
                    } finally {
                        updaterParamsLock.readLock().unlock();
                    }
                }
            });
        } else {
            // Non-master: answer from the provider, or with a null payload if none was configured.
            this.transport.addRequestConsumer(UpdaterParametersRequest.class, new Consumer<UpdaterParametersRequest>() {
                @Override
                public void accept(UpdaterParametersRequest request) throws Exception {
                    log.debug("Trying to send back Updater parameters...");
                    UpdaterParametersProvider provider = updaterParametersProvider;
                    INDArray params = null;
                    if (provider != null) {
                        params = provider.getUpdaterParameters();
                    } else {
                        log.warn("UpdaterParametersProvider wasn't set!");
                    }
                    UpdaterParametersMessage response = new UpdaterParametersMessage(UUID.randomUUID().toString(), params);
                    response.setRequestId(request.getRequestId());
                    transport.sendMessage(response, request.getOriginatorId());
                }
            });
        }

        this.disposable = Flowable.fromPublisher(this.transport.incomingPublisher()).subscribe(message -> {
            if (!(message instanceof GradientsUpdateMessage)) {
                throw new UnsupportedOperationException("Unknown message received: [" + message.getClass().getCanonicalName() + "]");
            }
            GradientsUpdateMessage update = (GradientsUpdateMessage) message;
            // counters only ever move forward; accumulateAndGet makes the max-update atomic
            this.iterationNumber.accumulateAndGet(update.getIteration(), Math::max);
            this.epochNumber.accumulateAndGet(update.getEpoch(), Math::max);

            INDArray payload = message.getPayload();
            if (this.updatesSubscribers.isEmpty()) {
                // add() would throw on a full queue and, with no onError handler, terminate this
                // subscription permanently; drop the single update and warn instead
                if (!this.updatesQueue.offer(payload)) {
                    log.warn("Updates queue is full ({} entries); dropping incoming update", UPDATES_QUEUE_CAPACITY);
                }
            } else {
                this.updatesSubscribers.forEach(handler -> handler.onNext(payload));
            }
        });

        if (this.masterMode) {
            this.transport.launchAsMaster();
        } else {
            this.transport.launch();
        }
        this.stopLock.set(false);
        this.launchLock.set(true);
    }

    /**
     * Shuts down the transport, cancels the update subscription and clears all subscribers and
     * buffered updates. Repeated calls are no-ops until the next {@link #launch()}; safe to call
     * even if the server was never launched.
     */
    public synchronized void shutdown() {
        if (this.stopLock.get()) {
            return;
        }
        if (this.transport != null) {
            this.transport.shutdown();
        }
        if (this.disposable != null) {
            this.disposable.dispose();
            this.disposable = null;
        }
        this.updaterParamsSubscribers.clear();
        this.modelParamsSubsribers.clear();
        this.updatesSubscribers.clear();
        this.updatesQueue.clear();
        this.launchLock.set(false);
        this.stopLock.set(true);
    }

    /**
     * Propagates a gradient update through the transport in both directions.
     *
     * @param iNDArray update payload
     * @param i        iteration number attached to the update
     * @param i2       epoch number attached to the update
     * @throws RuntimeException wrapping any exception raised by the transport
     */
    public void sendUpdate(@NonNull INDArray iNDArray, int i, int i2) {
        if (iNDArray == null) {
            throw new NullPointerException("array is marked @NonNull but is null");
        }
        try {
            GradientsUpdateMessage message = new GradientsUpdateMessage(UUID.randomUUID().toString(), iNDArray);
            message.setOriginatorId(this.transport.id());
            message.setIteration(i);
            message.setEpoch(i2);
            this.transport.propagateMessage(message, PropagationMode.BOTH_WAYS);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Propagates a gradient update tagged with iteration 0 and epoch 0. */
    public void sendUpdate(@NonNull INDArray iNDArray) {
        if (iNDArray == null) {
            throw new NullPointerException("array is marked @NonNull but is null");
        }
        sendUpdate(iNDArray, 0, 0);
    }

    /** Drains and returns every update buffered while no {@link UpdatesHandler} was registered. */
    public Collection<INDArray> getUpdates() {
        List<INDArray> drained = new ArrayList<>();
        this.updatesQueue.drainTo(drained);
        return drained;
    }

    public Transport getTransport() {
        return this.transport;
    }

    public INDArray getMasterModelParams() {
        return this.masterModelParams;
    }

    public INDArray getMasterUpdaterParams() {
        return this.masterUpdaterParams;
    }
}
