package org.nuxeo.ai.model.serving;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import java.io.IOException;
import java.io.Serializable;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.http.HttpResponse;
import org.apache.http.client.entity.EntityBuilder;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.methods.RequestBuilder;
import org.apache.http.entity.ContentType;
import org.nuxeo.ai.enrichment.EnrichmentMetadata;
import org.nuxeo.ai.enrichment.EnrichmentService;
import org.nuxeo.ai.metadata.AIMetadata;
import org.nuxeo.ai.metadata.Suggestion;
import org.nuxeo.ai.metadata.SuggestionMetadata;
import org.nuxeo.ai.pipes.services.JacksonUtil;
import org.nuxeo.ai.pipes.types.BlobTextFromDocument;
import org.nuxeo.ai.rest.RestClient;
import org.nuxeo.ecm.core.api.Blob;
import org.nuxeo.ecm.core.api.DocumentModel;
import org.nuxeo.ecm.core.api.NuxeoException;

/**
 * Runtime model that delegates predictions to a remotely served model through a {@link RestClient}.
 * <p>
 * Model inputs are serialized as a {@link TensorInstances} JSON payload. The JSON response is parsed into
 * per-output label lists (see {@link #parseResponse(String)}). The class also acts as an
 * {@link EnrichmentService}, turning document blobs/properties into {@link EnrichmentMetadata}.
 * <p>
 * NOTE(review): the "TF" prefix and the {@code instances}/{@code b64} payload shape suggest a
 * TensorFlow-Serving compatible endpoint — confirm against the deployed service.
 */
public class TFRuntimeModel extends AbstractRuntimeModel implements EnrichmentService {
    public static final String VERB_CLASSIFY = "classify";
    public static final String VERB_REGRESS = "regress";
    public static final String VERB_PREDICT = "predict";
    public static final String PREDICTION_CUSTOM = "/prediction/custommodel";
    public static final String KIND_CONFIG = "kind";
    public static final String USE_LABELS = "useLabels";
    public static final String JSON_RESULTS = "results";
    public static final String JSON_OUTPUTS = "output_names";
    public static final String JSON_LABELS = "_labels";
    public static final String JSON_PROBABILITIES = "_prob";

    /** REST client configured from the model descriptor in {@link #init(ModelDescriptor)}. */
    protected RestClient client;

    /** Names of the model inputs, reported in the produced metadata. */
    protected Set<String> inputNames;

    /** Metadata kind; configured via {@value #KIND_CONFIG}, defaults to {@value #PREDICTION_CUSTOM}. */
    protected String kind;

    /** When true (the default), enrichment flattens all suggestions into a single label list. */
    protected boolean useLabels;

    /**
     * Image input wrapper holding the string produced by {@code AbstractRuntimeModel#convertImageBlob}.
     * The field name suggests a base64 payload (presumably serialized as {@code {"b64": ...}}) — confirm.
     */
    public static class TensorImage implements Serializable {
        private static final long serialVersionUID = 2603715122387085509L;

        public final String b64;

        public TensorImage(String b64) {
            this.b64 = b64;
        }
    }

    /** JSON request body: an {@code instances} list containing a single map of model inputs. */
    public static class TensorInstances {
        public final List<Map<String, Serializable>> instances = new ArrayList<>();

        public TensorInstances(Map<String, Serializable> inputs) {
            this.instances.add(inputs);
        }
    }

    /**
     * Initializes the REST client and reads the {@value #USE_LABELS} and {@value #KIND_CONFIG} settings.
     *
     * @param modelDescriptor the model descriptor holding the configuration map
     */
    @Override
    public void init(ModelDescriptor modelDescriptor) {
        super.init(modelDescriptor);
        this.client = new RestClient(modelDescriptor.configuration, "", (Function) null);
        this.useLabels = Boolean.parseBoolean(
                modelDescriptor.configuration.getOrDefault(USE_LABELS, Boolean.TRUE.toString()));
        this.kind = modelDescriptor.configuration.getOrDefault(KIND_CONFIG, PREDICTION_CUSTOM);
        this.inputNames = this.inputs.stream().map(input -> input.getName()).collect(Collectors.toSet());
    }

    /**
     * Predicts using the model inputs extracted from the given document.
     *
     * @return the suggestions, or {@code null} if the call failed or produced nothing usable
     */
    @Override
    public SuggestionMetadata predict(DocumentModel documentModel) {
        return predict(getProperties(documentModel));
    }

    /**
     * Calls the remote model with the given inputs.
     *
     * @param map model inputs keyed by input name
     * @return the parsed suggestions, or {@code null} on a non-2xx response or when no label passed the
     *         confidence threshold
     */
    public SuggestionMetadata predict(Map<String, Serializable> map) {
        return (SuggestionMetadata) this.client.call(
                requestBuilder -> prepareRequest(VERB_PREDICT, requestBuilder, map),
                httpResponse -> {
                    int statusCode = httpResponse.getStatusLine().getStatusCode();
                    if (statusCode < 200 || statusCode >= 300) {
                        log.warn(String.format("Unsuccessful call to custom model (%s), status is %d", getName(),
                                statusCode));
                        return null;
                    }
                    SuggestionMetadata suggestions = handlePredict(httpResponse);
                    if (log.isDebugEnabled()) {
                        log.debug("Prediction is " + JacksonUtil.MAPPER.writeValueAsString(suggestions));
                    }
                    return suggestions;
                });
    }

    /**
     * Converts a successful HTTP response into {@link SuggestionMetadata}, saving the raw JSON alongside.
     *
     * @return the metadata, or {@code null} when the response yielded no labels
     */
    protected SuggestionMetadata handlePredict(HttpResponse httpResponse) {
        String content = this.client.getContent(httpResponse);
        Map<String, List<AIMetadata.Label>> labelsByOutput = parseResponse(content);
        if (labelsByOutput.isEmpty()) {
            return null;
        }
        List<Suggestion> suggestions = new ArrayList<>(labelsByOutput.size());
        labelsByOutput.forEach((outputName, labels) -> suggestions.add(new Suggestion(outputName, labels)));

        SuggestionMetadata.Builder builder = new SuggestionMetadata.Builder(getKind(), getId(), this.inputNames);
        builder.withSuggestions(suggestions);
        return builder.withRawKey(saveJsonAsRawBlob(content)).build();
    }

    /**
     * Parses the model JSON response into labels grouped by output name.
     * <p>
     * Expected shape:
     * {@code {"results": [{"output_names": ["out"], "out_prob": [0.9, ...], "out_labels": ["cat", ...]}]}}.
     * Only labels whose probability is strictly greater than {@code minConfidence} are kept, and outputs
     * with no remaining labels are omitted. Malformed outputs are skipped with a warning instead of aborting
     * the whole parse.
     *
     * @param str the raw JSON response
     * @return a mutable map of output name to accepted labels; empty if nothing could be parsed
     */
    protected Map<String, List<AIMetadata.Label>> parseResponse(String str) {
        Map<String, List<AIMetadata.Label>> labelsByOutput = new HashMap<>();
        if (StringUtils.isBlank(str)) {
            log.warn(String.format("Unable to read the json response: %s", str));
            return labelsByOutput;
        }
        try {
            JsonNode root = JacksonUtil.MAPPER.readTree(str);
            // readTree may yield null (older Jackson) or a MissingNode for empty content.
            JsonNode results = root == null ? null : root.get(JSON_RESULTS);
            if (results == null) {
                log.warn(String.format("Unable to read the json response: %s", str));
                return labelsByOutput;
            }
            for (JsonNode result : results) {
                collectLabels(result, labelsByOutput);
            }
        } catch (IOException e) {
            log.warn(String.format("Unable to read the json response: %s", str), e);
        }
        return labelsByOutput;
    }

    /** Adds the above-threshold labels of every output declared by one result node to {@code labelsByOutput}. */
    private void collectLabels(JsonNode result, Map<String, List<AIMetadata.Label>> labelsByOutput) {
        JsonNode outputNames = result.get(JSON_OUTPUTS);
        if (outputNames == null) {
            log.warn(String.format("Missing '%s' in model result: %s", JSON_OUTPUTS, result));
            return;
        }
        for (JsonNode outputNameNode : outputNames) {
            String outputName = outputNameNode.asText();
            JsonNode probabilities = result.get(outputName + JSON_PROBABILITIES);
            JsonNode labels = result.get(outputName + JSON_LABELS);
            if (probabilities == null || labels == null || !probabilities.isArray() || !labels.isArray()) {
                log.warn(String.format("Missing probabilities or labels for output '%s' in model result: %s",
                        outputName, result));
                continue;
            }
            // Mismatched lengths cannot be paired reliably; such outputs yield no labels.
            if (labels.size() != probabilities.size()) {
                continue;
            }
            List<AIMetadata.Label> accepted = new ArrayList<>();
            for (int i = 0; i < labels.size(); i++) {
                float confidence = probabilities.get(i).floatValue();
                if (confidence > this.minConfidence) {
                    accepted.add(new AIMetadata.Label(labels.get(i).asText(), confidence));
                }
            }
            if (!accepted.isEmpty()) {
                labelsByOutput.put(outputName, accepted);
            }
        }
    }

    /**
     * Points the request at the given verb and attaches the inputs as a JSON body.
     *
     * @param str the verb appended to the client's base URI (e.g. {@value #VERB_PREDICT})
     * @param requestBuilder the builder supplied by the REST client
     * @param map model inputs keyed by input name
     * @return the built request
     * @throws NuxeoException if the inputs cannot be serialized to JSON
     */
    protected HttpUriRequest prepareRequest(String str, RequestBuilder requestBuilder, Map<String, Serializable> map) {
        requestBuilder.setUri(buildUri(str, requestBuilder.getUri().toString()));
        String json;
        try {
            json = JacksonUtil.MAPPER.writeValueAsString(new TensorInstances(map));
        } catch (JsonProcessingException e) {
            log.warn("Failed to serialize model inputs", e);
            throw new NuxeoException("Unable to make a valid json request", e);
        }
        // Without an explicit content type EntityBuilder encodes text as text/plain ISO-8859-1, which
        // mangles non-Latin-1 characters in text inputs; APPLICATION_JSON encodes as UTF-8.
        requestBuilder.setEntity(EntityBuilder.create().setText(json).setContentType(ContentType.APPLICATION_JSON)
                                              .build());
        return requestBuilder.build();
    }

    /**
     * Builds the target URI by appending the verb to the base URI.
     *
     * @param verb the verb to append
     * @param baseUri the client's base URI
     * @return {@code baseUri + verb}
     */
    protected String buildUri(String verb, String baseUri) {
        return baseUri + verb;
    }

    /** Returns the base model name, suffixed with {@code _<version>} when a non-blank version is set. */
    @Override
    public String getName() {
        String baseName = super.getName();
        String version = getVersion();
        return StringUtils.isNotBlank(version) ? baseName + "_" + version : baseName;
    }

    /** Returns the configured metadata kind. */
    public String getKind() {
        return this.kind;
    }

    /**
     * Enriches a document by sending its (converted) blobs and text properties to the model.
     *
     * @param blobTextFromDocument the blobs and properties extracted from a document
     * @return a single enrichment, or an empty list when there is no input or no usable prediction
     */
    public Collection<EnrichmentMetadata> enrich(BlobTextFromDocument blobTextFromDocument) {
        Map<String, Serializable> modelInputs = new HashMap<>();
        blobTextFromDocument.getBlobs().forEach((name, blob) -> modelInputs.put(name, convertImageBlob(blob)));
        modelInputs.putAll(blobTextFromDocument.getProperties());

        if (modelInputs.isEmpty()) {
            log.warn(String.format("(%s) unable to enrich doc properties for doc %s", getName(),
                    blobTextFromDocument.getId()));
            return Collections.emptyList();
        }

        SuggestionMetadata prediction = predict(modelInputs);
        // predict() returns null on a failed call or when nothing passed the confidence threshold.
        if (prediction == null || prediction.getSuggestions().isEmpty()) {
            return Collections.emptyList();
        }

        AIMetadata.Context context = new AIMetadata.Context(blobTextFromDocument.getRepositoryName(),
                blobTextFromDocument.getId(), (Set) null, this.inputNames);
        EnrichmentMetadata.Builder builder = new EnrichmentMetadata.Builder(Instant.now(), getKind(), getId(),
                context);
        builder.withRawKey(prediction.getRawKey());
        if (this.useLabels) {
            if (prediction.getSuggestions().size() != 1) {
                log.error("Multiple outputs is currently unsupported.  The output name will be lost.");
            }
            builder.withLabels(prediction.getSuggestions()
                                         .stream()
                                         .flatMap(suggestion -> suggestion.getValues().stream())
                                         .collect(Collectors.toList()));
        } else {
            builder.withSuggestions(prediction.getSuggestions());
        }
        return Collections.singletonList(builder.build());
    }

    /**
     * Wraps the parent's string conversion of an image blob into a {@link TensorImage}.
     *
     * @return the wrapped image, or {@code null} if the parent conversion did not produce a String
     */
    @Override
    public Serializable convertImageBlob(Blob blob) {
        Serializable converted = super.convertImageBlob(blob);
        return converted instanceof String ? new TensorImage((String) converted) : null;
    }
}
