/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.function_calling;

import com.jayway.jsonpath.DocumentContext;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.PathNotFoundException;
import com.jayway.jsonpath.Predicate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.algorithms.agent.AgentUtils;
import org.opensearch.ml.engine.function_calling.FunctionCalling;
import org.opensearch.ml.engine.function_calling.GeminiMessage;
import org.opensearch.ml.engine.function_calling.LLMMessage;

public class GeminiV1BetaGenerateContentFunctionCalling
implements FunctionCalling {
    @Generated
    private static final Logger log = LogManager.getLogger(GeminiV1BetaGenerateContentFunctionCalling.class);
    public static final String CALL_PATH = "$.candidates[0].content.parts[*].functionCall";
    public static final String NAME = "name";
    public static final String INPUT = "args";
    public static final String ID_PATH = "name";
    public static final String TOOL_ERROR = "tool_error";
    public static final String GEMINI_TOOL_TEMPLATE = "{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"parameters\":${tool.attributes.input_schema_cleaned}}";

    @Override
    public void configure(Map<String, String> params) {
        if (!params.containsKey("no_escape_params")) {
            params.put("no_escape_params", "_chat_history,_tools,_interactions,tool_configs");
        }
        params.put("llm_response_filter", "$.candidates[0].content.parts[0].text");
        params.put("tool_template", GEMINI_TOOL_TEMPLATE);
        params.put("gemini.schema.cleaner", "true");
        params.put("tool_calls_path", CALL_PATH);
        params.put("tool_calls.tool_name", "name");
        params.put("tool_calls.tool_input", INPUT);
        params.put("tool_calls.id_path", "name");
        params.put("tool_configs", ", \"tools\": [{\"functionDeclarations\": [${parameters._tools:-}]}], \"toolConfig\": {\"functionCallingConfig\": {\"mode\": \"AUTO\"}}");
        params.put("interaction_template.assistant_tool_calls_path", "$.candidates[0].content");
        params.put("interaction_template.tool_response", "{\"role\":\"user\",\"parts\":[{\"functionResponse\":{\"name\":\"${_interactions.tool_call_id}\",\"response\":{\"text\":\"${_interactions.tool_response}\"}}}]}");
        params.put("chat_history_template.user_question", "{\"role\":\"user\",\"parts\":[{\"text\":\"${_chat_history.message.question}\"}]}");
        params.put("chat_history_template.ai_response", "{\"role\":\"model\",\"parts\":[{\"text\":\"${_chat_history.message.response}\"}]}");
        params.put("llm_finish_reason_path", "$.candidates[0].finishReason");
        params.put("llm_finish_reason_tool_use", "N/A");
    }

    @Override
    public List<Map<String, String>> handle(ModelTensorOutput tmpModelTensorOutput, Map<String, String> parameters) {
        List functionCalls;
        ArrayList<Map<String, String>> output = new ArrayList<Map<String, String>>();
        Map<String, ?> dataAsMap = ((ModelTensor)((ModelTensors)tmpModelTensorOutput.getMlModelOutputs().get(0)).getMlModelTensors().get(0)).getDataAsMap();
        String llmResponseExcludePath = parameters.get("llm_response_exclude_path");
        if (llmResponseExcludePath != null) {
            dataAsMap = AgentUtils.removeJsonPath(dataAsMap, llmResponseExcludePath, true);
        }
        try {
            functionCalls = (List)JsonPath.read(dataAsMap, (String)CALL_PATH, (Predicate[])new Predicate[0]);
        }
        catch (PathNotFoundException e) {
            return output;
        }
        if (CollectionUtils.isEmpty((Collection)functionCalls)) {
            return output;
        }
        for (Object call : functionCalls) {
            String toolName = (String)JsonPath.read(call, (String)"name", (Predicate[])new Predicate[0]);
            String toolInput = StringUtils.toJson((Object)JsonPath.read(call, (String)INPUT, (Predicate[])new Predicate[0]));
            String toolCallId = (String)JsonPath.read(call, (String)"name", (Predicate[])new Predicate[0]);
            output.add(Map.of("tool_name", toolName, "tool_input", toolInput, "tool_call_id", toolCallId));
        }
        return output;
    }

    @Override
    public List<LLMMessage> supply(List<Map<String, Object>> toolResults) {
        GeminiMessage toolMessage = new GeminiMessage();
        for (Map<String, Object> toolResult : toolResults) {
            String toolUseId = (String)toolResult.get("tool_call_id");
            if (toolUseId == null) continue;
            Map<String, Object> functionResponse = Map.of("name", toolUseId, "response", toolResult.get("tool_result"));
            toolMessage.getContent().add(Map.of("functionResponse", functionResponse));
            if (!toolResult.containsKey(TOOL_ERROR)) continue;
            log.debug("Tool error detected for function: {}", (Object)toolUseId);
        }
        return List.of(toolMessage);
    }

    @Override
    public Map<String, ?> filterToFirstToolCall(Map<String, ?> dataAsMap, Map<String, String> parameters) {
        try {
            List partsList = (List)JsonPath.read(dataAsMap, (String)"$.candidates[0].content.parts", (Predicate[])new Predicate[0]);
            if (partsList == null || partsList.size() <= 1) {
                return dataAsMap;
            }
            ArrayList filteredParts = new ArrayList();
            ArrayList<String> allToolNames = new ArrayList<String>();
            String selectedToolName = null;
            boolean foundFirstFunctionCall = false;
            for (Object item : partsList) {
                if (item instanceof Map && ((Map)item).containsKey("functionCall")) {
                    Map functionCallMap = (Map)((Map)item).get("functionCall");
                    String toolName = functionCallMap != null ? String.valueOf(functionCallMap.get("name")) : "unknown";
                    allToolNames.add(toolName);
                    if (foundFirstFunctionCall) continue;
                    filteredParts.add(item);
                    selectedToolName = toolName;
                    foundFirstFunctionCall = true;
                    continue;
                }
                filteredParts.add(item);
            }
            if (!foundFirstFunctionCall) {
                return dataAsMap;
            }
            if (allToolNames.size() > 1) {
                log.info("LLM suggested {} tool(s): {}. Selected first tool: {}", (Object)allToolNames.size(), allToolNames, selectedToolName);
            }
            Map mutableCopy = (Map)StringUtils.gson.fromJson(StringUtils.toJson(dataAsMap), Map.class);
            DocumentContext context = JsonPath.parse((Object)mutableCopy);
            context.set("$.candidates[0].content.parts", filteredParts, new Predicate[0]);
            return (Map)context.json();
        }
        catch (Exception e) {
            log.error("Failed to filter out to only first tool call", (Throwable)e);
            return dataAsMap;
        }
    }
}

