package org.nd4j.parameterserver.distributed.logic.routing;

import lombok.NonNull;
import org.apache.commons.lang3.RandomUtils;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.messages.TrainingMessage;
import org.nd4j.parameterserver.distributed.messages.VoidMessage;
import org.nd4j.parameterserver.distributed.transport.Transport;

/**
 * Router that assigns each outgoing message a target shard chosen at random.
 *
 * <p>The number of candidate shards is read from {@link VoidConfiguration#getNumberOfShards()}
 * during {@link #init(VoidConfiguration, Transport)}. Target ids are drawn from
 * {@code [0, numShards)} via commons-lang3 {@link RandomUtils#nextInt(int, int)}.
 */
@Deprecated
public class RandomRouter extends BaseRouter {

    /** Number of shards to choose from; set in {@link #init(VoidConfiguration, Transport)}. */
    protected int numShards;

    /**
     * Initializes the base router and records the configured number of shards.
     *
     * @param voidConfiguration configuration that supplies the shard count; must not be null
     * @param transport transport handed to {@code super.init}; must not be null
     * @throws NullPointerException if either argument is {@code null}
     */
    @Override
    public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport) {
        if (voidConfiguration == null) {
            throw new NullPointerException("voidConfiguration is marked @NonNull but is null");
        }
        if (transport == null) {
            throw new NullPointerException("transport is marked @NonNull but is null");
        }
        super.init(voidConfiguration, transport);
        // The shard count must be stored: if numShards stays 0, getNextShard() calls
        // RandomUtils.nextInt(0, 0), which always returns 0, so every message hits shard 0.
        this.numShards = voidConfiguration.getNumberOfShards();
    }

    /**
     * Marks the message's originator, then assigns it a randomly chosen target shard.
     *
     * @param trainingMessage message to route
     * @return the target id that was set on the message
     */
    @Override
    public int assignTarget(TrainingMessage trainingMessage) {
        setOriginator(trainingMessage);
        trainingMessage.setTargetId(getNextShard());
        return trainingMessage.getTargetId();
    }

    /**
     * Marks the message's originator, then assigns it a randomly chosen target shard.
     *
     * @param voidMessage message to route
     * @return the target id that was set on the message
     */
    @Override
    public int assignTarget(VoidMessage voidMessage) {
        setOriginator(voidMessage);
        voidMessage.setTargetId(getNextShard());
        return voidMessage.getTargetId();
    }

    /**
     * Picks a shard index in {@code [0, numShards)}.
     *
     * <p>If {@code numShards} is 0 (for example, before {@code init} runs), this returns 0.
     * A negative {@code numShards} makes {@code RandomUtils.nextInt} throw
     * {@link IllegalArgumentException}.
     *
     * @return the selected shard index, narrowed to {@code short}
     */
    protected short getNextShard() {
        return (short) RandomUtils.nextInt(0, this.numShards);
    }
}
