/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.vertexai.palm2.api;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.Assert;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;

public class VertexAiPaLm2Api {
    public static final String DEFAULT_GENERATE_MODEL = "chat-bison-001";
    public static final String DEFAULT_EMBEDDING_MODEL = "embedding-gecko-001";
    public static final String DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta3";
    private final RestClient restClient;
    private final String apiKey;
    private final String chatModel;
    private final String embeddingModel;

    public VertexAiPaLm2Api(String apiKey) {
        this(DEFAULT_BASE_URL, apiKey, DEFAULT_GENERATE_MODEL, DEFAULT_EMBEDDING_MODEL, RestClient.builder());
    }

    public VertexAiPaLm2Api(String baseUrl, String apiKey, String model, String embeddingModel, RestClient.Builder restClientBuilder) {
        this.chatModel = model;
        this.embeddingModel = embeddingModel;
        this.apiKey = apiKey;
        Consumer<HttpHeaders> jsonContentHeaders = headers -> {
            headers.setAccept(List.of(MediaType.APPLICATION_JSON));
            headers.setContentType(MediaType.APPLICATION_JSON);
        };
        ResponseErrorHandler responseErrorHandler = new ResponseErrorHandler(){

            public boolean hasError(ClientHttpResponse response) throws IOException {
                return response.getStatusCode().isError();
            }

            public void handleError(ClientHttpResponse response) throws IOException {
                if (response.getStatusCode().isError()) {
                    throw new RuntimeException(String.format("%s - %s", response.getStatusCode().value(), new ObjectMapper().readValue(response.getBody(), ResponseError.class)));
                }
            }
        };
        this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).defaultStatusHandler(responseErrorHandler).build();
    }

    public GenerateMessageResponse generateMessage(GenerateMessageRequest request) {
        Assert.notNull((Object)request, (String)"The request body can not be null.");
        return (GenerateMessageResponse)((RestClient.RequestBodySpec)this.restClient.post().uri("/models/{model}:generateMessage?key={apiKey}", new Object[]{this.chatModel, this.apiKey})).body((Object)request).retrieve().body(GenerateMessageResponse.class);
    }

    public Embedding embedText(String text) {
        Assert.hasText((String)text, (String)"The text can not be null or empty.");
        @JsonInclude(value=JsonInclude.Include.NON_NULL)
        record EmbeddingResponse(Embedding embedding) {
        }
        EmbeddingResponse response = (EmbeddingResponse)((RestClient.RequestBodySpec)this.restClient.post().uri("/models/{model}:embedText?key={apiKey}", new Object[]{this.embeddingModel, this.apiKey})).body(Map.of("text", text)).retrieve().body(EmbeddingResponse.class);
        return response != null ? response.embedding() : null;
    }

    public List<Embedding> batchEmbedText(List<String> texts) {
        Assert.notNull(texts, (String)"The texts can not be null.");
        BatchEmbeddingResponse response = (BatchEmbeddingResponse)((RestClient.RequestBodySpec)this.restClient.post().uri("/models/{model}:batchEmbedText?key={apiKey}", new Object[]{this.embeddingModel, this.apiKey})).body(Map.of("texts", texts)).retrieve().body(BatchEmbeddingResponse.class);
        return response != null ? response.embeddings() : null;
    }

    public Integer countMessageTokens(MessagePrompt prompt) {
        Assert.notNull((Object)prompt, (String)"The message prompt can not be null.");
        record TokenCount(@JsonProperty(value="tokenCount") Integer tokenCount) {
        }
        TokenCount tokenCountResponse = (TokenCount)((RestClient.RequestBodySpec)this.restClient.post().uri("/models/{model}:countMessageTokens?key={apiKey}", new Object[]{this.chatModel, this.apiKey})).body(Map.of("prompt", prompt)).retrieve().body(TokenCount.class);
        return tokenCountResponse != null ? tokenCountResponse.tokenCount() : null;
    }

    public List<String> listModels() {
        @JsonInclude(value=JsonInclude.Include.NON_NULL)
        record ModelList(@JsonProperty(value="models") List<1ModelList.ModelName> models) {

            record 1ModelList.ModelName(String name) {
            }
        }
        ModelList modelList = (ModelList)this.restClient.get().uri("/models?key={apiKey}", new Object[]{this.apiKey}).retrieve().body(ModelList.class);
        return modelList == null ? List.of() : modelList.models().stream().map(1ModelList.ModelName::name).toList();
    }

    public Model getModel(String modelName) {
        Assert.hasText((String)modelName, (String)"The model name can not be null or empty.");
        if (modelName.startsWith("models/")) {
            modelName = modelName.substring("models/".length());
        }
        return (Model)this.restClient.get().uri("/models/{model}?key={apiKey}", new Object[]{modelName, this.apiKey}).retrieve().body(Model.class);
    }

    @JsonInclude(value=JsonInclude.Include.NON_NULL)
    public record GenerateMessageResponse(@JsonProperty(value="candidates") List<Message> candidates, @JsonProperty(value="messages") List<Message> messages, @JsonProperty(value="filters") List<ContentFilter> filters) {

        @JsonInclude(value=JsonInclude.Include.NON_NULL)
        public record ContentFilter(@JsonProperty(value="reason") BlockedReason reason, @JsonProperty(value="message") String message) {

            public static enum BlockedReason {
                BLOCKED_REASON_UNSPECIFIED,
                SAFETY,
                OTHER;

            }
        }
    }

    @JsonInclude(value=JsonInclude.Include.NON_NULL)
    public record Embedding(@JsonProperty(value="value") float[] value) {
        @Override
        public final int hashCode() {
            return Arrays.hashCode(this.value);
        }

        @Override
        public final boolean equals(Object arg0) {
            return Arrays.equals(this.value, ((Embedding)arg0).value);
        }
    }

    @JsonInclude(value=JsonInclude.Include.NON_NULL)
    record BatchEmbeddingResponse(List<Embedding> embeddings) {
    }

    @JsonInclude(value=JsonInclude.Include.NON_NULL)
    public record Model(@JsonProperty(value="name") String name, @JsonProperty(value="baseModelId") String baseModelId, @JsonProperty(value="version") String version, @JsonProperty(value="displayName") String displayName, @JsonProperty(value="description") String description, @JsonProperty(value="inputTokenLimit") Integer inputTokenLimit, @JsonProperty(value="outputTokenLimit") Integer outputTokenLimit, @JsonProperty(value="supportedGenerationMethods") List<String> supportedGenerationMethods, @JsonProperty(value="temperature") Float temperature, @JsonProperty(value="topP") Float topP, @JsonProperty(value="topK") Integer topK) {
    }

    @JsonInclude(value=JsonInclude.Include.NON_NULL)
    public record GenerateMessageRequest(@JsonProperty(value="prompt") MessagePrompt prompt, @JsonProperty(value="temperature") Float temperature, @JsonProperty(value="candidateCount") Integer candidateCount, @JsonProperty(value="topP") Float topP, @JsonProperty(value="topK") Integer topK) {
        public GenerateMessageRequest(MessagePrompt prompt) {
            this(prompt, null, null, null, null);
        }

        public GenerateMessageRequest(MessagePrompt prompt, Float temperature, Integer topK) {
            this(prompt, temperature, null, null, topK);
        }
    }

    @JsonInclude(value=JsonInclude.Include.NON_NULL)
    public record MessagePrompt(@JsonProperty(value="context") String context, @JsonProperty(value="examples") List<Example> examples, @JsonProperty(value="messages") List<Message> messages) {
        public MessagePrompt(List<Message> messages) {
            this(null, null, messages);
        }

        public MessagePrompt(String context, List<Message> messages) {
            this(context, null, messages);
        }

        @JsonInclude(value=JsonInclude.Include.NON_NULL)
        public record Example(@JsonProperty(value="input") Message input, @JsonProperty(value="output") Message output) {
        }
    }

    @JsonInclude(value=JsonInclude.Include.NON_NULL)
    public record Message(@JsonProperty(value="author") String author, @JsonProperty(value="content") String content, @JsonProperty(value="citationMetadata") CitationMetadata citationMetadata) {
        public Message(String author, String content) {
            this(author, content, null);
        }

        @JsonInclude(value=JsonInclude.Include.NON_NULL)
        public record CitationMetadata(@JsonProperty(value="citationSources") List<CitationSource> citationSources) {
        }

        @JsonInclude(value=JsonInclude.Include.NON_NULL)
        public record CitationSource(@JsonProperty(value="startIndex") Integer startIndex, @JsonProperty(value="endIndex") Integer endIndex, @JsonProperty(value="uri") String uri, @JsonProperty(value="license") String license) {
        }
    }

    @JsonInclude(value=JsonInclude.Include.NON_NULL)
    public record ResponseError(@JsonProperty(value="error") Error error) {

        @JsonInclude(value=JsonInclude.Include.NON_NULL)
        public record Error(@JsonProperty(value="message") String message, @JsonProperty(value="code") String code, @JsonProperty(value="status") String status) {
        }
    }
}

