/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.evaluator;

import ai.djl.modality.cv.MultiBoxTarget;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.evaluator.Evaluator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * An {@link Evaluator} that measures the error between target and predicted bounding-box offsets
 * of a multi-box detector (presumably SSD-style, judging by the accumulator name — confirm).
 *
 * <p>{@link #evaluate(NDList, NDList)} produces an element-wise error array for one batch. For
 * each accumulator key this class keeps the running sum of those errors plus the number of error
 * elements seen, and {@link #getAccumulator(String)} reports their ratio.
 */
public class BoundingBoxError extends Evaluator {

    /** Running sum of the bounding-box error, per accumulator key. */
    private final Map<String, Float> ssdBoxPredictionError;

    /** Derives box-offset targets and masks from anchors, labels and class predictions. */
    private final MultiBoxTarget multiBoxTarget = MultiBoxTarget.builder().build();

    /**
     * Creates a bounding-box error evaluator.
     *
     * @param name the name of this evaluator
     */
    public BoundingBoxError(String name) {
        super(name);
        this.ssdBoxPredictionError = new ConcurrentHashMap<>();
    }

    /**
     * Computes the element-wise absolute bounding-box error for one batch.
     *
     * @param labels the ground-truth labels; only {@code labels.head()} is used
     * @param predictions a list whose first three entries are, in order, the anchors, the class
     *     predictions and the bounding-box offset predictions
     * @return {@code abs((boxTargets - boxPredictions) * boxMasks)}, where the box targets and
     *     masks are the first two outputs of {@link MultiBoxTarget#target(NDList)}
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        NDArray anchors = predictions.get(0);
        NDArray classPredictions = predictions.get(1);
        NDArray boundingBoxPredictions = predictions.get(2);
        // The class predictions are handed to MultiBoxTarget with their last two axes swapped.
        // NOTE(review): presumably (batch, anchors, classes) -> (batch, classes, anchors); confirm.
        NDList targets =
                multiBoxTarget.target(
                        new NDList(anchors, labels.head(), classPredictions.transpose(0, 2, 1)));
        NDArray boundingBoxLabels = targets.get(0);
        NDArray boundingBoxMasks = targets.get(1);
        return boundingBoxLabels.sub(boundingBoxPredictions).mul(boundingBoxMasks).abs();
    }

    /**
     * Registers {@code key}, initialising both its element count and its error sum to zero.
     *
     * @param key the accumulator key
     */
    @Override
    public void addAccumulator(String key) {
        totalInstances.put(key, 0L);
        ssdBoxPredictionError.put(key, 0.0f);
    }

    /**
     * Adds one batch's bounding-box error to the accumulator for {@code key}.
     *
     * <p>The element count grows by the full size of the error array. That includes elements
     * zeroed out by the mask, so masked entries dilute the reported mean.
     *
     * @param key the accumulator key; it must already have been registered
     * @param labels the ground-truth labels
     * @param predictions the model predictions, as described in {@link #evaluate(NDList, NDList)}
     * @throws NullPointerException if no accumulator exists for {@code key} (the stored count is
     *     unboxed while updating)
     */
    @Override
    public void updateAccumulator(String key, NDList labels, NDList predictions) {
        NDArray boundingBoxError = evaluate(labels, predictions);
        float update = boundingBoxError.sum().getFloat();
        long count = boundingBoxError.size();
        totalInstances.compute(key, (k, v) -> v + count);
        ssdBoxPredictionError.compute(key, (k, v) -> v + update);
    }

    /**
     * Resets both the element count and the error sum for {@code key} to zero.
     *
     * @param key the accumulator key
     */
    @Override
    public void resetAccumulator(String key) {
        totalInstances.put(key, 0L);
        ssdBoxPredictionError.put(key, 0.0f);
    }

    /**
     * Returns the mean bounding-box error recorded for {@code key}.
     *
     * @param key the accumulator key
     * @return the accumulated error divided by the accumulated element count, or
     *     {@link Float#NaN} if no elements have been recorded yet
     * @throws IllegalArgumentException if no accumulator exists for {@code key}
     */
    @Override
    public float getAccumulator(String key) {
        Long total = totalInstances.get(key);
        if (total == null) {
            throw new IllegalArgumentException("No evaluator found at that path");
        }
        if (total == 0L) {
            return Float.NaN;
        }
        // Divide by the count validated above instead of re-reading the map, so a concurrent
        // reset cannot slip a zero denominator past the guard.
        return ssdBoxPredictionError.get(key) / (float) total.longValue();
    }
}

