package ai.djl.nn.core;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterBlock;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/* loaded from: input_file:ai/djl/nn/core/Embedding.class */
/**
 * A block that maps discrete items of type {@code T} to dense vectors by looking up rows of a
 * single {@code "embedding"} weight parameter of shape {@code (numItems, embeddingSize)}.
 *
 * <p>Items are first converted to integer indices with {@link #embed(NDManager, Object)} or
 * {@link #embed(NDManager, Object[])}; index arrays of that kind are then fed to {@link #forward}.
 * When the block is built with {@code useDefault} enabled, index {@code 0} is reserved for items
 * that were not registered, and the registered items occupy indices {@code 1..numItems-1}.
 *
 * @param <T> the type of item being embedded
 */
public class Embedding<T> extends ParameterBlock {

    /** Current serialization format version written by {@link #saveParameters}. */
    private static final byte VERSION = 2;

    private final int embeddingSize;
    private final boolean useDefault;
    private final DataType dataType;
    /** Item -> row index in the embedding weight. Never mutated after construction. */
    private final Map<T, Integer> embedder;
    /** Number of rows in the embedding weight (includes the reserved default row, if any). */
    private final int numItems;
    private final Parameter embedding;

    /**
     * Builder for {@link Embedding}.
     *
     * @param <T> the type of item being embedded
     */
    public static final class Builder<T> {
        private Class<T> embeddingType;
        private Collection<T> items;
        private int embeddingSize;
        private boolean useDefault = true;
        private DataType dataType = DataType.FLOAT32;

        Builder() {}

        /**
         * Creates a builder for a new item type, copying the embedding size, {@code useDefault}
         * flag and data type from {@code other}. Items are NOT copied, because their type changes.
         */
        private Builder(Class<T> embeddingType, Builder<?> other) {
            this.embeddingType = embeddingType;
            this.embeddingSize = other.embeddingSize;
            this.useDefault = other.useDefault;
            this.dataType = other.dataType;
        }

        /**
         * Returns the item class set via {@link #setType}, or {@code null} if none was set.
         *
         * @return the embedded item class
         */
        public Class<T> getEmbeddingType() {
            return embeddingType;
        }

        /**
         * Returns a builder retyped to {@code U}, carrying over the embedding size,
         * {@code useDefault} flag and data type. Previously set items are discarded.
         *
         * @param embeddingType the class of the items to embed
         * @param <U> the new item type
         * @return a new builder for items of type {@code U}
         */
        public <U> Builder<U> setType(Class<U> embeddingType) {
            return new Builder<>(embeddingType, this);
        }

        /**
         * Sets the items that get their own embedding row, in iteration order.
         *
         * @param items the items to embed
         * @return this builder
         */
        public Builder<T> setItems(Collection<T> items) {
            this.items = items;
            return this;
        }

        /**
         * Sets the length of each embedding vector.
         *
         * @param embeddingSize the embedding vector length; must be positive
         * @return this builder
         */
        public Builder<T> setEmbeddingSize(int embeddingSize) {
            this.embeddingSize = embeddingSize;
            return this;
        }

        /**
         * Sets whether unknown items map to a reserved default row (index 0) instead of throwing.
         * Defaults to {@code true}.
         *
         * @param useDefault whether to reserve a default row
         * @return this builder
         */
        public Builder<T> optUseDefault(boolean useDefault) {
            this.useDefault = useDefault;
            return this;
        }

        /**
         * Sets the data type passed to the embedding operator. Defaults to
         * {@link DataType#FLOAT32}.
         *
         * @param dataType the data type
         * @return this builder
         */
        public Builder<T> optDataType(DataType dataType) {
            this.dataType = dataType;
            return this;
        }

        /**
         * Builds the {@link Embedding} block.
         *
         * @return a new embedding block
         * @throws IllegalArgumentException if no items were set, or the embedding size was not
         *     set or is negative
         */
        public Embedding<T> build() {
            if (items == null) {
                throw new IllegalArgumentException("You must specify the items to embed");
            }
            if (embeddingSize == 0) {
                throw new IllegalArgumentException("You must specify the embedding size");
            }
            if (embeddingSize < 0) {
                throw new IllegalArgumentException(
                        "The embedding size must be positive, got: " + embeddingSize);
            }
            return new Embedding<>(this);
        }
    }

    /**
     * Creates an untrained embedding from a builder. Each item receives the next free row index
     * in iteration order, starting at 1 when {@code useDefault} reserves row 0.
     */
    Embedding(Builder<T> builder) {
        this.embeddingSize = builder.embeddingSize;
        this.useDefault = builder.useDefault;
        this.dataType = builder.dataType;
        this.embedding = new Parameter("embedding", this, ParameterType.WEIGHT);

        Map<T, Integer> indices = new ConcurrentHashMap<>(builder.items.size());
        int nextIndex = useDefault ? 1 : 0;
        for (T item : builder.items) {
            // A repeated item keeps only its last index, but still consumes a row.
            indices.put(item, nextIndex++);
        }
        this.embedder = indices;
        this.numItems = nextIndex;
        this.inputShapes = new Shape[] {new Shape(-1)};
    }

    /**
     * Creates an embedding from an existing weight array. Row {@code i} of {@code embedding} is
     * used for {@code items.get(i)}; no default row is reserved, so unknown items throw.
     *
     * <p>NOTE(review): the embedding size is read from dimension 1 of the array and the row
     * count is taken from {@code items.size()}; the two are not cross-checked here.
     *
     * @param embedding the pre-trained weight array
     * @param items the items, in row order
     */
    public Embedding(NDArray embedding, List<T> items) {
        this.embeddingSize = Math.toIntExact(embedding.getShape().get(1));
        this.useDefault = false;
        this.dataType = embedding.getDataType();
        this.embedding = new Parameter("embedding", this, ParameterType.WEIGHT);
        this.embedding.setArray(embedding);
        this.numItems = items.size();

        Map<T, Integer> indices = new ConcurrentHashMap<>(numItems);
        for (int i = 0; i < numItems; i++) {
            indices.put(items.get(i), i);
        }
        this.embedder = indices;
        this.inputShapes = new Shape[] {new Shape(-1)};
    }

    /**
     * Creates a builder for an {@link Embedding}.
     *
     * @return a new builder
     */
    public static Builder<?> builder() {
        return new Builder<>();
    }

    /** Returns the input shape with {@code embeddingSize} appended as the last dimension. */
    @Override // ai.djl.nn.Block
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        return new Shape[] {inputShapes[0].addAll(new Shape(embeddingSize))};
    }

    @Override // ai.djl.nn.Block
    public List<Parameter> getDirectParameters() {
        return Collections.singletonList(embedding);
    }

    /**
     * Returns {@code (numItems, embeddingSize)} for the {@code "embedding"} parameter.
     *
     * @throws IllegalArgumentException for any other parameter name
     */
    @Override // ai.djl.nn.Block
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        if ("embedding".equals(name)) {
            return new Shape(numItems, embeddingSize);
        }
        throw new IllegalArgumentException("Invalid parameter name");
    }

    /**
     * Looks up the embedding rows for the single index array in {@code inputs}. A scalar input
     * yields a 1-D result of length {@code embeddingSize}.
     */
    @Override // ai.djl.nn.Block
    public NDList forward(
            ParameterStore parameterStore, NDList inputs, PairList<String, Object> params) {
        NDList opInputs = opInputs(parameterStore, inputs);
        NDList result =
                opInputs.head()
                        .getNDArrayInternal()
                        .embedding(opInputs, numItems, embeddingSize, dataType, params);
        if (inputs.singletonOrThrow().getShape().dimension() == 0) {
            // The scalar was promoted to shape (1) in opInputs; drop that extra dimension again.
            result = new NDList(result.singletonOrThrow().reshape(embeddingSize));
        }
        return result;
    }

    /** Writes the format version, the input shapes, and the embedding parameter. */
    @Override // ai.djl.nn.Block
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        saveInputShapes(os);
        embedding.save(os);
    }

    /**
     * Reads parameters written by {@link #saveParameters}. Version 1 data has no input shapes.
     *
     * @throws MalformedModelException if the stored version is neither 1 nor {@link #VERSION}
     */
    @Override // ai.djl.nn.Block
    public void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version == VERSION) {
            readInputShapes(is);
        } else if (version != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        embedding.load(manager, is);
    }

    /**
     * Returns whether {@code item} has its own registered row (the default row does not count).
     *
     * @param item the item to check
     * @return {@code true} if the item was registered
     */
    public boolean hasItem(T item) {
        return embedder.containsKey(item);
    }

    /** Builds the operator inputs: the (at least 1-D) index array followed by the weight. */
    private NDList opInputs(ParameterStore parameterStore, NDList inputs) {
        NDArray items = inputs.singletonOrThrow();
        Device device = items.getDevice();

        NDList opInputs = new NDList(2);
        opInputs.add(items.getShape().dimension() == 0 ? items.reshape(1) : items);
        opInputs.add(parameterStore.getValue(embedding, device));
        return opInputs;
    }

    /**
     * Converts items to an int array of row indices.
     *
     * @param manager the manager used to create the array
     * @param items the items to convert
     * @return the row indices, one per item
     * @throws IllegalArgumentException if an item is unknown and no default row is reserved
     */
    public NDArray embed(NDManager manager, T[] items) {
        return manager.create(Arrays.stream(items).mapToInt(this::embedHelper).toArray());
    }

    /**
     * Converts a single item to a scalar row index.
     *
     * @param manager the manager used to create the array
     * @param item the item to convert
     * @return the row index
     * @throws IllegalArgumentException if the item is unknown and no default row is reserved
     */
    public NDArray embed(NDManager manager, T item) {
        return manager.create(embedHelper(item));
    }

    /** Returns the row index of {@code item}, falling back to row 0 when a default is reserved. */
    private int embedHelper(T item) {
        // ConcurrentHashMap never holds null values, so a null result means "absent".
        Integer index = embedder.get(item);
        if (index != null) {
            return index;
        }
        if (useDefault) {
            return 0;
        }
        throw new IllegalArgumentException("The provided item was not found");
    }
}
