package ai.djl.nn.norm;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Collections;
import java.util.List;

/**
 * A parameterless block that applies dropout to its single input array.
 *
 * <p>The actual dropout computation is delegated to the engine's internal NDArray implementation
 * (see {@link #forward}). Instances are created through {@link #builder()}.
 */
public class Dropout extends ParameterBlock {

    /** Current serialization format version written by {@link #saveParameters}. */
    private static final byte VERSION = 2;

    private final float probability;
    private final int[] sharedAxes;

    /** Builder for {@link Dropout}. */
    public static final class Builder {

        private float probability = 0.5f;
        private int[] sharedAxes = new int[0];

        Builder() {}

        /**
         * Sets the dropout probability passed to the engine's dropout operator.
         *
         * @param f the probability value; defaults to {@code 0.5f}
         * @return this builder
         */
        public Builder optProbability(float f) {
            this.probability = f;
            return this;
        }

        /**
         * Sets the axes passed to the engine's dropout operator as shared axes.
         *
         * <p>The array is copied, so later modifications by the caller do not affect this builder
         * or any block built from it. A {@code null} argument is treated as the default empty
         * array.
         *
         * @param iArr the shared axes; may be {@code null}
         * @return this builder
         */
        public Builder optSharedAxes(int[] iArr) {
            this.sharedAxes = iArr == null ? new int[0] : iArr.clone();
            return this;
        }

        /**
         * Builds a {@link Dropout} block from the current builder configuration.
         *
         * @return a new {@link Dropout}
         */
        public Dropout build() {
            return new Dropout(this);
        }
    }

    Dropout(Builder builder) {
        this.probability = builder.probability;
        // Copy so that blocks built from the same builder never share one mutable array.
        this.sharedAxes = builder.sharedAxes.clone();
    }

    /**
     * Applies dropout to the single array in {@code inputs} by delegating to that array's internal
     * engine implementation, passing along the configured probability, shared axes and the
     * optional {@code params}.
     *
     * @param parameterStore the parameter store (unused; this block has no parameters)
     * @param inputs the input list, expected to contain exactly one array
     * @param params optional extra arguments forwarded to the engine's dropout operator
     * @return the output of the engine's dropout operator
     */
    @Override
    public NDList forward(
            ParameterStore parameterStore, NDList inputs, PairList<String, Object> params) {
        return inputs.singletonOrThrow()
                .getNDArrayInternal()
                .dropout(inputs, this.probability, this.sharedAxes, params);
    }

    /**
     * Returns the output shape, which is the same as the first input shape.
     *
     * @param manager the manager (unused)
     * @param inputShapes the input shapes; only the first entry is used
     * @return a one-element array holding {@code inputShapes[0]}
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        return new Shape[] {inputShapes[0]};
    }

    /**
     * Returns an empty list because dropout has no learnable parameters.
     *
     * @return an empty, immutable list
     */
    @Override
    public List<Parameter> getDirectParameters() {
        return Collections.emptyList();
    }

    /**
     * Always throws, because this block declares no parameters.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        throw new IllegalArgumentException("Dropout has no parameters");
    }

    /**
     * Writes the format version byte followed by the recorded input shapes.
     *
     * @param os the stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        saveInputShapes(os);
    }

    /**
     * Reads data previously written by {@link #saveParameters}.
     *
     * <p>Version 2 data contains the input shapes, which are read back. Version 1 data carries
     * nothing after the version byte. Any other version is rejected.
     *
     * @param manager the manager (unused)
     * @param is the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if the encoded version is not 1 or 2
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version == VERSION) {
            readInputShapes(is);
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + ((int) version));
        }
        // version 1: legacy format with no payload after the version byte.
    }

    /**
     * Creates a new {@link Builder} with default settings.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
