/*
 * Decompiled with CFR 0.152.
 */
package org.nuxeo.ai.model.serving;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import java.io.IOException;
import java.io.Serializable;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.apache.http.HttpResponse;
import org.apache.http.client.entity.EntityBuilder;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.methods.RequestBuilder;
import org.nuxeo.ai.enrichment.EnrichmentMetadata;
import org.nuxeo.ai.enrichment.EnrichmentService;
import org.nuxeo.ai.metadata.AIMetadata;
import org.nuxeo.ai.metadata.Suggestion;
import org.nuxeo.ai.metadata.SuggestionMetadata;
import org.nuxeo.ai.model.ModelProperty;
import org.nuxeo.ai.model.serving.AbstractRuntimeModel;
import org.nuxeo.ai.model.serving.ModelDescriptor;
import org.nuxeo.ai.pipes.services.JacksonUtil;
import org.nuxeo.ai.pipes.types.BlobTextFromDocument;
import org.nuxeo.ai.rest.RestClient;
import org.nuxeo.ecm.core.api.Blob;
import org.nuxeo.ecm.core.api.DocumentModel;
import org.nuxeo.ecm.core.api.NuxeoException;

/**
 * A runtime model that posts its inputs as JSON to a REST endpoint and turns the labels and
 * probabilities found in the JSON response into {@link SuggestionMetadata}.
 * <p>
 * It also acts as an {@link EnrichmentService}, so the same remote model can be used to enrich
 * the blobs and properties carried by a {@link BlobTextFromDocument}.
 * <p>
 * NOTE(review): the class name suggests a TensorFlow-backed model server; the code itself only
 * relies on the JSON layout described on {@link #parseResponse(String)}.
 */
public class TFRuntimeModel extends AbstractRuntimeModel implements EnrichmentService {

    public static final String VERB_CLASSIFY = "classify";

    public static final String VERB_REGRESS = "regress";

    /** Verb appended to the base URI by {@link #predict(Map)}. */
    public static final String VERB_PREDICT = "predict";

    /** Default value of {@link #getKind()} when no {@value #KIND_CONFIG} option is configured. */
    public static final String PREDICTION_CUSTOM = "/prediction/custommodel";

    /** Configuration key holding the kind reported by this model. */
    public static final String KIND_CONFIG = "kind";

    /** Configuration key selecting between flattened labels and suggestions in {@link #enrich}. */
    public static final String USE_LABELS = "useLabels";

    /** Response field containing the list of result objects. */
    public static final String JSON_RESULTS = "results";

    /** Field of a result object listing the names of the outputs it contains. */
    public static final String JSON_OUTPUTS = "output_names";

    /** Suffix appended to an output name to locate its array of labels. */
    public static final String JSON_LABELS = "_labels";

    /** Suffix appended to an output name to locate its array of probabilities. */
    public static final String JSON_PROBABILITIES = "_prob";

    protected RestClient client;

    /** Names of the model input properties, used as the enrichment context. */
    protected Set<String> inputNames;

    protected String kind;

    protected boolean useLabels;

    /**
     * Initialises the REST client and reads the {@value #USE_LABELS} (default {@code true}) and
     * {@value #KIND_CONFIG} (default {@value #PREDICTION_CUSTOM}) options from the descriptor.
     *
     * @param descriptor the model descriptor providing the configuration map
     */
    @Override
    public void init(ModelDescriptor descriptor) {
        super.init(descriptor);
        client = new RestClient(descriptor.configuration, "", null);
        String useLabelsOption = descriptor.configuration.getOrDefault(USE_LABELS, Boolean.TRUE.toString());
        useLabels = Boolean.parseBoolean(useLabelsOption);
        kind = descriptor.configuration.getOrDefault(KIND_CONFIG, PREDICTION_CUSTOM);
        inputNames = inputs.stream().map(ModelProperty::getName).collect(Collectors.toSet());
    }

    /**
     * Predicts suggestions for a document using the properties extracted by {@code getProperties}.
     *
     * @param doc the document to predict on
     * @return the prediction, or {@code null} under the same conditions as {@link #predict(Map)}
     */
    @Override
    public SuggestionMetadata predict(DocumentModel doc) {
        return predict(getProperties(doc));
    }

    /**
     * Sends the given input values to the remote model and parses the answer.
     *
     * @param inputValues the input values keyed by input name
     * @return the prediction, or {@code null} when the HTTP status is not 2xx or when the response
     *         contains no label above the minimum confidence
     */
    public SuggestionMetadata predict(Map<String, Serializable> inputValues) {
        return (SuggestionMetadata) client.call(
                builder -> prepareRequest(VERB_PREDICT, (RequestBuilder) builder, inputValues),
                response -> {
                    int statusCode = response.getStatusLine().getStatusCode();
                    if (statusCode < 200 || statusCode >= 300) {
                        log.warn(String.format("Unsuccessful call to custom model (%s), status is %d", getName(),
                                statusCode));
                        return null;
                    }
                    SuggestionMetadata meta = handlePredict(response);
                    if (log.isDebugEnabled()) {
                        log.debug("Prediction is " + JacksonUtil.MAPPER.writeValueAsString(meta));
                    }
                    return meta;
                });
    }

    /**
     * Converts a successful HTTP response into suggestion metadata, one suggestion per output.
     * The raw response content is stored through {@code saveJsonAsRawBlob} and referenced by the
     * returned metadata.
     *
     * @param response the HTTP response returned by the model endpoint
     * @return the metadata, or {@code null} when no label could be extracted from the response
     */
    protected SuggestionMetadata handlePredict(HttpResponse response) {
        String content = client.getContent(response);
        Map<String, List<AIMetadata.Label>> labelledResults = parseResponse(content);
        if (labelledResults.isEmpty()) {
            return null;
        }
        List<Suggestion> suggestions = labelledResults.entrySet()
                                                      .stream()
                                                      .map(entry -> new Suggestion(entry.getKey(), entry.getValue()))
                                                      .collect(Collectors.toList());
        SuggestionMetadata.Builder builder = new SuggestionMetadata.Builder(getKind(), getId());
        builder.withSuggestions(suggestions);
        return (SuggestionMetadata) builder.withRawKey(saveJsonAsRawBlob(content)).build();
    }

    /**
     * Parses the JSON response of the model. The expected layout is:
     *
     * <pre>
     * { "results": [ { "output_names": ["x"], "x_labels": [...], "x_prob": [...] } ] }
     * </pre>
     *
     * Only labels whose probability is strictly greater than {@code minConfidence} are kept, and
     * outputs left without any label are omitted. Malformed parts of the response are logged and
     * skipped so that one bad output does not discard the others.
     *
     * @param content the response body
     * @return the labels keyed by output name; never {@code null}, possibly empty
     */
    protected Map<String, List<AIMetadata.Label>> parseResponse(String content) {
        Map<String, List<AIMetadata.Label>> results = new HashMap<>();
        if (StringUtils.isBlank(content)) {
            log.warn(String.format("Unable to read the json response: %s", content));
            return results;
        }
        try {
            JsonNode jsonResponse = JacksonUtil.MAPPER.readTree(content);
            JsonNode resultNodes = jsonResponse == null ? null : jsonResponse.get(JSON_RESULTS);
            if (resultNodes == null) {
                log.warn(String.format("Unable to read the json response: %s", content));
                return results;
            }
            for (JsonNode resultsNode : resultNodes) {
                JsonNode outputNames = resultsNode.get(JSON_OUTPUTS);
                if (outputNames == null) {
                    log.warn(String.format("Unable to read the json response: %s", content));
                    continue;
                }
                for (JsonNode outputNode : outputNames) {
                    String outputName = outputNode.asText();
                    List<AIMetadata.Label> labels = readLabels(resultsNode, outputName);
                    if (!labels.isEmpty()) {
                        results.put(outputName, labels);
                    }
                }
            }
        } catch (IOException e) {
            log.warn(String.format("Unable to read the json response: %s", content), e);
        }
        return results;
    }

    /**
     * Pairs the labels and probabilities of one output, keeping those above {@code minConfidence}.
     *
     * @return the retained labels; empty when either array is missing, is not an array, or the two
     *         arrays differ in size
     */
    private List<AIMetadata.Label> readLabels(JsonNode resultsNode, String outputName) {
        List<AIMetadata.Label> labels = new ArrayList<>();
        JsonNode probabilities = resultsNode.get(outputName + JSON_PROBABILITIES);
        JsonNode labelNames = resultsNode.get(outputName + JSON_LABELS);
        if (probabilities == null || labelNames == null || !probabilities.isArray() || !labelNames.isArray()) {
            log.warn(String.format("Missing or invalid labels/probabilities for output %s", outputName));
            return labels;
        }
        if (labelNames.size() != probabilities.size()) {
            log.warn(String.format("Labels and probabilities differ in size for output %s", outputName));
            return labels;
        }
        for (int i = 0; i < labelNames.size(); i++) {
            float confidence = probabilities.get(i).floatValue();
            if (confidence > minConfidence) {
                labels.add(new AIMetadata.Label(labelNames.get(i).asText(), confidence));
            }
        }
        return labels;
    }

    /**
     * Completes the request: appends the verb to the builder URI and sets the JSON body, which
     * wraps the inputs in a {@link TensorInstances} object.
     *
     * @param verb the verb to append to the base URI
     * @param builder the request builder, already carrying the base URI
     * @param inputs the input values to serialize
     * @return the request ready to be executed
     * @throws NuxeoException if the inputs cannot be serialized to JSON
     */
    protected HttpUriRequest prepareRequest(String verb, RequestBuilder builder, Map<String, Serializable> inputs) {
        builder.setUri(buildUri(verb, builder.getUri().toString()));
        try {
            String json = JacksonUtil.MAPPER.writeValueAsString(new TensorInstances(inputs));
            builder.setEntity(EntityBuilder.create().setText(json).build());
        } catch (JsonProcessingException e) {
            log.warn("Failed to serialize model inputs", e);
            throw new NuxeoException("Unable to make a valid json request", e);
        }
        return builder.build();
    }

    /**
     * Builds the target URI by concatenating the base URI and the verb, without any separator.
     *
     * @param verb the verb to append
     * @param baseUri the base URI
     * @return {@code baseUri + verb}
     */
    protected String buildUri(String verb, String baseUri) {
        return baseUri + verb;
    }

    /**
     * @return the parent name, suffixed with {@code "_" + version} when the version is not blank
     */
    @Override
    public String getName() {
        String version = getVersion();
        String baseName = super.getName();
        return StringUtils.isNotBlank(version) ? baseName + "_" + version : baseName;
    }

    /**
     * @return the kind configured at {@link #init(ModelDescriptor)} time
     */
    public String getKind() {
        return kind;
    }

    /**
     * Enriches the given blobs and properties with the predictions of the model. Blobs are
     * converted with {@link #convertImageBlob(Blob)}; properties are passed as is and override a
     * blob registered under the same name.
     * <p>
     * When {@code useLabels} is set, the labels of all outputs are flattened into a single list
     * (the output names are lost); otherwise the suggestions are attached unchanged.
     *
     * @param blobtext the blobs and text properties of a document
     * @return a single enrichment, or an empty list when there is no input or no prediction
     */
    public Collection<EnrichmentMetadata> enrich(BlobTextFromDocument blobtext) {
        Map<String, Serializable> inputProperties = new HashMap<>();
        blobtext.getBlobs().forEach((name, blob) -> inputProperties.put(name, convertImageBlob(blob)));
        inputProperties.putAll(blobtext.getProperties());
        if (inputProperties.isEmpty()) {
            log.warn(String.format("(%s) unable to enrich doc properties for doc %s", getName(), blobtext.getId()));
            return Collections.emptyList();
        }

        // predict() returns null on an unsuccessful call or when nothing passes the confidence filter
        SuggestionMetadata suggestions = predict(inputProperties);
        if (suggestions == null || suggestions.getSuggestions().isEmpty()) {
            return Collections.emptyList();
        }

        AIMetadata.Context context = new AIMetadata.Context(blobtext.getRepositoryName(), blobtext.getId(), null,
                inputNames);
        EnrichmentMetadata.Builder builder = new EnrichmentMetadata.Builder(Instant.now(), getKind(), getId(),
                context);
        builder.withRawKey(suggestions.getRawKey());
        if (useLabels) {
            if (suggestions.getSuggestions().size() != 1) {
                log.error("Multiple outputs is currently unsupported.  The output name will be lost.");
            }
            List<AIMetadata.Label> flattened = suggestions.getSuggestions()
                                                          .stream()
                                                          .map(Suggestion::getValues)
                                                          .flatMap(Collection::stream)
                                                          .collect(Collectors.toList());
            builder.withLabels(flattened);
        } else {
            builder.withSuggestions(suggestions.getSuggestions());
        }
        return Collections.singletonList(builder.build());
    }

    /**
     * Wraps the parent conversion result in a {@link TensorImage}.
     *
     * @param sourceBlob the blob to convert
     * @return the wrapped image, or {@code null} when the parent conversion did not yield a String
     */
    @Override
    protected Serializable convertImageBlob(Blob sourceBlob) {
        Serializable converted = super.convertImageBlob(sourceBlob);
        if (converted instanceof String) {
            return new TensorImage((String) converted);
        }
        return null;
    }

    /**
     * JSON representation of an image: an object with a single {@code b64} field.
     * NOTE(review): the field is presumably base64-encoded image data — confirm against
     * {@code AbstractRuntimeModel.convertImageBlob}.
     */
    protected static class TensorImage implements Serializable {

        private static final long serialVersionUID = 2603715122387085509L;

        // kept as a public field: there is no getter, so its visibility is what exposes it
        public final String b64;

        public TensorImage(String b64) {
            this.b64 = b64;
        }
    }

    /**
     * JSON request payload: an object with an {@code instances} array holding a single map of
     * input values.
     */
    protected static class TensorInstances {

        // kept as a public field: there is no getter, so its visibility is what exposes it
        public final List<Map<String, Serializable>> instances = new ArrayList<>();

        public TensorInstances(Map<String, Serializable> inputs) {
            instances.add(inputs);
        }
    }
}

