package ai.djl.tensorflow.zoo.cv.objectdetction;

import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.ObjectDetectionTranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:ai/djl/tensorflow/zoo/cv/objectdetction/TfSsdTranslator.class */
public class TfSsdTranslator extends ObjectDetectionTranslator {
    private int maxBoxes;
    private String boundingBoxOutputName;
    private String scoresOutputName;
    private String classLabelOutputName;

    /* loaded from: input_file:ai/djl/tensorflow/zoo/cv/objectdetction/TfSsdTranslator$Builder.class */
    public static class Builder extends ObjectDetectionTranslator.ObjectDetectionBuilder<Builder> {
        private int maxBoxes = 10;
        private String boundingBoxOutputName = "detection_boxes";
        private String scoresOutputName = "detection_scores";
        private String classLabelOutputName = "detection_class_labels";

        public Builder optBoundingBoxOutputName(String str) {
            this.boundingBoxOutputName = str;
            return this;
        }

        public Builder optScoresOutputName(String str) {
            this.scoresOutputName = str;
            return this;
        }

        public Builder optClassLabelOutputName(String str) {
            this.classLabelOutputName = str;
            return this;
        }

        public Builder optMaxBoxes(int i) {
            this.maxBoxes = i;
            return this;
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /* renamed from: self, reason: merged with bridge method [inline-methods] */
        public Builder m2self() {
            return this;
        }

        protected void configPreProcess(Map<String, ?> map) {
            super.configPreProcess(map);
        }

        protected void configPostProcess(Map<String, ?> map) {
            super.configPostProcess(map);
            this.maxBoxes = TfSsdTranslator.getIntValue(map, "maxBoxes", 10);
            this.threshold = TfSsdTranslator.getFloatValue(map, "threshold", 0.4f);
            this.boundingBoxOutputName = TfSsdTranslator.getStringValue(map, "boundingBoxOutputName", "detection_boxes");
            this.scoresOutputName = TfSsdTranslator.getStringValue(map, "scoresOutputName", "detection_scores");
            this.classLabelOutputName = TfSsdTranslator.getStringValue(map, "classLabelOutputName", "detection_class_labels");
        }

        public TfSsdTranslator build() {
            validate();
            return new TfSsdTranslator(this);
        }
    }

    protected TfSsdTranslator(Builder builder) {
        super(builder);
        this.maxBoxes = builder.maxBoxes;
        this.boundingBoxOutputName = builder.boundingBoxOutputName;
        this.scoresOutputName = builder.scoresOutputName;
        this.classLabelOutputName = builder.classLabelOutputName;
    }

    public NDList processInput(TranslatorContext translatorContext, Image image) {
        return new NDList(new NDArray[]{((NDArray) super.processInput(translatorContext, image).get(0)).expandDims(0)});
    }

    public Batchifier getBatchifier() {
        return null;
    }

    /* renamed from: processOutput, reason: merged with bridge method [inline-methods] */
    public DetectedObjects m1processOutput(TranslatorContext translatorContext, NDList nDList) {
        int i = (int) ((NDArray) nDList.get(0)).getShape().get(0);
        float[] fArr = new float[i];
        long[] jArr = new long[i];
        NDArray nDArray = (NDArray) nDList.get(0);
        Iterator it = nDList.iterator();
        while (it.hasNext()) {
            NDArray nDArray2 = (NDArray) it.next();
            if (this.scoresOutputName.equals(nDArray2.getName())) {
                fArr = nDArray2.toFloatArray();
            } else if (this.boundingBoxOutputName.equals(nDArray2.getName())) {
                nDArray = nDArray2;
            } else {
                if (!this.classLabelOutputName.equals(nDArray2.getName())) {
                    throw new IllegalStateException("Unexpected result NDArray:" + nDArray2.getName());
                }
                jArr = nDArray2.toLongArray();
            }
        }
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        ArrayList arrayList3 = new ArrayList();
        for (int i2 = 0; i2 < Math.min(jArr.length, this.maxBoxes); i2++) {
            long j = jArr[i2];
            double d = fArr[i2];
            if (j >= 0 && d > this.threshold) {
                if (j >= this.classes.size()) {
                    throw new AssertionError("Unexpected index: " + j);
                }
                String str = (String) this.classes.get(((int) j) - 1);
                float[] floatArray = nDArray.get(new long[]{i2}).toFloatArray();
                float f = floatArray[0];
                Rectangle rectangle = new Rectangle(floatArray[1], f, floatArray[3] - r0, floatArray[2] - f);
                arrayList.add(str);
                arrayList2.add(Double.valueOf(d));
                arrayList3.add(rectangle);
            }
        }
        return new DetectedObjects(arrayList, arrayList2, arrayList3);
    }

    public static Builder builder() {
        return new Builder();
    }

    public static Builder builder(Map<String, ?> map) {
        Builder builder = new Builder();
        builder.configPreProcess(map);
        builder.configPostProcess(map);
        return builder;
    }
}
