package ai.djl.mxnet.zoo.nlp.qa;

import ai.djl.Model;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.Utils;
import java.io.IOException;
import java.util.List;
import java.util.Locale;

/* loaded from: input_file:ai/djl/mxnet/zoo/nlp/qa/BertQATranslator.class */
public class BertQATranslator implements Translator<QAInput, String> {
    private List<String> tokens;
    private BertDataParser parser;

    public Batchifier getBatchifier() {
        return null;
    }

    public void prepare(NDManager nDManager, Model model) throws IOException {
        this.parser = (BertDataParser) model.getArtifact("vocab.json", BertDataParser::parse);
    }

    public NDList processInput(TranslatorContext translatorContext, QAInput qAInput) throws IOException {
        List<String> list = BertDataParser.tokenizer(qAInput.getQuestion().toLowerCase());
        List<String> list2 = BertDataParser.tokenizer(qAInput.getParagraph().toLowerCase());
        int size = list.size() + list2.size();
        List<Float> tokenTypes = BertDataParser.getTokenTypes(list, list2, qAInput.getSeqLength());
        this.tokens = BertDataParser.formTokens(list, list2, qAInput.getSeqLength());
        List<Integer> list3 = this.parser.token2idx(this.tokens);
        float[] floatArray = Utils.toFloatArray(tokenTypes);
        float[] floatArray2 = Utils.toFloatArray(list3);
        int seqLength = qAInput.getSeqLength();
        NDManager nDManager = translatorContext.getNDManager();
        NDArray create = nDManager.create(floatArray2, new Shape(new long[]{1, seqLength}));
        create.setName("data0");
        NDArray create2 = nDManager.create(floatArray, new Shape(new long[]{1, seqLength}));
        create2.setName("data1");
        NDArray create3 = nDManager.create(new float[]{size});
        create3.setName("data2");
        return new NDList(new NDArray[]{create, create2, create3});
    }

    /* renamed from: processOutput, reason: merged with bridge method [inline-methods] */
    public String m10processOutput(TranslatorContext translatorContext, NDList nDList) {
        NDList split = nDList.singletonOrThrow().split(2L, 2);
        NDArray reshape = ((NDArray) split.get(0)).reshape(new Shape(new long[]{1, -1}));
        NDArray reshape2 = ((NDArray) split.get(1)).reshape(new Shape(new long[]{1, -1}));
        return this.tokens.subList((int) reshape.softmax(-1).argMax(1).getLong(new long[0]), ((int) reshape2.softmax(-1).argMax(1).getLong(new long[0])) + 1).toString();
    }
}
