/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.remote.streaming;

import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Predicate;
import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import lombok.Generated;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.internal.http2.StreamResetException;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorClientConfig;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
import org.opensearch.ml.engine.algorithms.remote.streaming.BaseStreamingHandler;
import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener;

/**
 * {@link BaseStreamingHandler} that streams a remote model's response over Server-Sent Events
 * using OkHttp.
 *
 * <p>Each {@link #startStream} call opens a new SSE connection with a fresh
 * {@link HTTPEventSourceListener}. The listener converts remote events into
 * {@link MLTaskResponse}s for the supplied {@link StreamPredictActionListener}. Only the
 * {@code openai/v1/chat/completions} event format is currently understood.
 */
public class HttpStreamingHandler
extends BaseStreamingHandler {
    @Generated
    private static final Logger log = LogManager.getLogger(HttpStreamingHandler.class);
    private final Connector connector;
    private OkHttpClient okHttpClient;
    private String llmInterface;

    /**
     * Creates a handler and builds the OkHttp client used for every stream it starts.
     *
     * @param llmInterface identifier of the remote API's event format, e.g.
     *     {@code openai/v1/chat/completions}
     * @param connector connector used to build each streaming request
     * @param connectorClientConfig source of the connect and read timeouts, both in seconds
     * @throws RuntimeException if the OkHttp client cannot be built
     */
    public HttpStreamingHandler(String llmInterface, Connector connector, ConnectorClientConfig connectorClientConfig) {
        this.connector = connector;
        this.llmInterface = llmInterface;
        Duration connectionTimeout = Duration.ofSeconds(connectorClientConfig.getConnectionTimeout().intValue());
        Duration readTimeout = Duration.ofSeconds(connectorClientConfig.getReadTimeout().intValue());
        try {
            // The explicit cast is required: a bare lambda matches both doPrivileged(PrivilegedAction)
            // and doPrivileged(PrivilegedExceptionAction), which makes the call ambiguous.
            this.okHttpClient = AccessController.doPrivileged((PrivilegedExceptionAction<OkHttpClient>) () -> new OkHttpClient.Builder()
                .connectTimeout(connectionTimeout)
                .readTimeout(readTimeout)
                .retryOnConnectionFailure(true)
                .build());
        }
        catch (Exception e) {
            throw new RuntimeException("Failed to build OkHttpClient", e);
        }
    }

    /**
     * Opens an SSE connection for one streaming prediction.
     *
     * <p>Events are delivered to a new {@link HTTPEventSourceListener}. Failures while building
     * the request or creating the event source are reported through {@link #handleError}.
     *
     * @param action connector action whose request is built
     * @param parameters request parameters passed to the request builder
     * @param payload request body passed to the request builder
     * @param actionListener receiver of the streamed responses and of any failure
     */
    @Override
    public void startStream(String action, Map<String, String> parameters, String payload, StreamPredictActionListener<MLTaskResponse, ?> actionListener) {
        try {
            log.info("Creating SSE connection for streaming request");
            HTTPEventSourceListener listener = new HTTPEventSourceListener(actionListener, this.llmInterface);
            Request request = ConnectorUtils.buildOKHttpStreamingRequest(action, this.connector, parameters, payload);
            AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
                EventSources.createFactory(this.okHttpClient).newEventSource(request, listener);
                return null;
            });
        }
        catch (Exception e) {
            log.error("Failed to start HTTP streaming", e);
            this.handleError(e, actionListener);
        }
    }

    /**
     * Logs {@code error} and fails {@code listener} with an {@link MLException} wrapping it.
     */
    @Override
    public void handleError(Throwable error, StreamPredictActionListener<MLTaskResponse, ?> listener) {
        log.error("HTTP streaming error", error);
        listener.onFailure(new MLException("Fail to execute streaming", error));
    }

    /**
     * Per-stream SSE callback that converts remote events into {@link MLTaskResponse}s.
     *
     * <p>Text deltas are forwarded as they arrive. Tool-call deltas are forwarded raw and also
     * accumulated. When the model finishes with {@code finish_reason = "tool_calls"}, the
     * accumulated call is emitted once more as a complete OpenAI-style message. A later
     * {@code [DONE]} event then does not complete the stream (presumably so an agent can run the
     * tool first; confirm against callers).
     */
    public final class HTTPEventSourceListener
    extends EventSourceListener {
        private final StreamPredictActionListener<MLTaskResponse, ?> streamActionListener;
        private final String llmInterface;
        // Handed to sendCompletionResponse on every completion path of this stream.
        private final AtomicBoolean isStreamClosed;
        // True while tool-call deltas have been accumulated but not yet completed.
        private boolean functionCallInProgress = false;
        // Set once a tool call is completed; stops [DONE] from sending a completion response.
        private boolean agentExecutionInProgress = false;
        private String accumulatedToolCallId = null;
        private String accumulatedToolName = null;
        private String accumulatedArguments = "";

        /**
         * @param streamActionListener receiver of streamed responses and failures
         * @param llmInterface event format to decode; see {@link #onEvent}
         */
        public HTTPEventSourceListener(StreamPredictActionListener<MLTaskResponse, ?> streamActionListener, String llmInterface) {
            this.streamActionListener = streamActionListener;
            this.llmInterface = llmInterface;
            this.isStreamClosed = new AtomicBoolean(false);
        }

        @Override
        public void onOpen(EventSource eventSource, Response response) {
            log.debug("Connected to SSE Endpoint.");
        }

        /**
         * Dispatches one SSE event to the decoder for this listener's LLM interface.
         *
         * @throws IllegalArgumentException if the LLM interface is not supported
         */
        @Override
        public void onEvent(EventSource eventSource, String id, String type, String data) {
            log.debug("The data is: {}", data);
            switch (this.llmInterface) {
                case "openai/v1/chat/completions": {
                    this.onOpenAIEvent(data);
                    break;
                }
                default: {
                    throw new IllegalArgumentException(String.format("Unsupported llm interface: %s", this.llmInterface));
                }
            }
        }

        @Override
        public void onClosed(EventSource eventSource) {
            log.debug("SSE CLOSED.");
        }

        /**
         * Reports a stream failure to the listener.
         *
         * <p>An HTTP/2 stream reset whose message mentions {@code NO_ERROR} is logged but not
         * reported. For a failure without a throwable, the remote error body is included.
         */
        @Override
        public void onFailure(EventSource eventSource, Throwable t, Response response) {
            if (t != null) {
                log.error("Error: {}", t.getMessage(), t);
                if (!this.isGracefulReset(t)) {
                    this.streamActionListener.onFailure(new MLException("SSE failure with network error", t));
                }
            } else if (response != null) {
                try {
                    String errorBody = response.body() != null ? response.body().string() : "";
                    this.streamActionListener.onFailure(new MLException("Error from remote service: " + errorBody));
                }
                catch (IOException e) {
                    this.streamActionListener.onFailure(new MLException("SSE failure - unable to read error details", e));
                }
            } else {
                this.streamActionListener.onFailure(new MLException("SSE failure"));
            }
        }

        /** Returns whether {@code t} is a stream reset whose message mentions NO_ERROR; null-safe. */
        private boolean isGracefulReset(Throwable t) {
            String message = t.getMessage();
            return t instanceof StreamResetException && message != null && message.contains("NO_ERROR");
        }

        /**
         * Handles one OpenAI chat-completions SSE payload.
         *
         * <p>The payload is either the {@code [DONE]} sentinel or a JSON chunk. A chunk that
         * cannot be parsed or processed is logged at debug level and skipped.
         */
        private void onOpenAIEvent(String data) {
            if ("[DONE]".equals(data)) {
                this.handleDoneEvent();
                return;
            }
            try {
                @SuppressWarnings("unchecked")
                Map<String, Object> dataMap = StringUtils.gson.fromJson(data, Map.class);
                this.processStreamChunk(dataMap);
            }
            catch (Exception e) {
                log.debug("Skipping malformed chunk: {}", data, e);
            }
        }

        /** Sends the completion response unless a completed tool call is awaiting execution. */
        private void handleDoneEvent() {
            if (!this.agentExecutionInProgress) {
                HttpStreamingHandler.this.sendCompletionResponse(this.isStreamClosed, this.streamActionListener);
            }
        }

        /**
         * Processes one parsed chat-completion chunk: forwards text, completes on
         * {@code finish_reason = "stop"}, and accumulates/forwards tool-call deltas.
         */
        private void processStreamChunk(Map<String, Object> dataMap) {
            Object finishReason = this.extractPath(dataMap, "$.choices[0].finish_reason");
            // Forward text before handling "stop": some OpenAI-compatible servers put the final
            // piece of content in the same chunk that carries the finish reason.
            Object content = this.extractPath(dataMap, "$.choices[0].delta.content");
            if (content instanceof String && !((String) content).isEmpty()) {
                HttpStreamingHandler.this.sendContentResponse((String) content, false, this.streamActionListener);
            }
            if ("stop".equals(finishReason)) {
                this.agentExecutionInProgress = false;
                HttpStreamingHandler.this.sendCompletionResponse(this.isStreamClosed, this.streamActionListener);
                return;
            }
            Object toolCalls = this.extractPath(dataMap, "$.choices[0].delta.tool_calls");
            if (toolCalls instanceof List) {
                List<?> toolCallList = (List<?>) toolCalls;
                this.accumulateFunctionCall(toolCallList);
                HttpStreamingHandler.this.sendContentResponse(StringUtils.toJson((Object) toolCallList), false, this.streamActionListener);
            }
            if ("tool_calls".equals(finishReason) && this.functionCallInProgress) {
                this.completeToolCall();
            }
        }

        /**
         * Reads a JsonPath expression from {@code dataMap}.
         *
         * @return the value, or {@code null} if the path is absent or cannot be evaluated
         */
        private <T> T extractPath(Map<String, Object> dataMap, String path) {
            try {
                return JsonPath.read(dataMap, path);
            }
            catch (Exception e) {
                // Missing paths are normal (e.g. a text delta has no tool_calls); treat them as absent.
                return null;
            }
        }

        /**
         * Emits the accumulated tool call as a complete response and resets the accumulators.
         *
         * <p>If the call id or function name was never received, the listener is failed instead.
         * Previously this path threw an NPE that was swallowed, silently dropping the call.
         */
        private void completeToolCall() {
            // Set first, as before, so a subsequent [DONE] never sends a completion response
            // after a tool call was emitted or reported as failed.
            this.agentExecutionInProgress = true;
            this.functionCallInProgress = false;
            try {
                if (this.accumulatedToolCallId == null || this.accumulatedToolName == null) {
                    this.streamActionListener.onFailure(new MLException("Incomplete tool call in streaming response: missing tool call id or function name"));
                    return;
                }
                String completeFunctionCall = this.buildCompleteFunctionCallResponse();
                HttpStreamingHandler.this.sendContentResponse(completeFunctionCall, false, this.streamActionListener);
                @SuppressWarnings("unchecked")
                Map<String, Object> response = StringUtils.gson.fromJson(completeFunctionCall, Map.class);
                ModelTensorOutput output = this.createModelTensorOutput(response);
                this.streamActionListener.onResponse(new MLTaskResponse((MLOutput) output));
            }
            finally {
                this.resetAccumulatedToolCall();
            }
        }

        /** Clears the accumulated tool-call id, name and arguments. */
        private void resetAccumulatedToolCall() {
            this.accumulatedToolCallId = null;
            this.accumulatedToolName = null;
            this.accumulatedArguments = "";
        }

        /**
         * Serializes the accumulated tool call as an OpenAI-style
         * {@code {"choices":[{"message":{"tool_calls":[...]},"finish_reason":"tool_calls"}]}} object.
         *
         * <p>Requires {@code accumulatedToolCallId} and {@code accumulatedToolName} to be
         * non-null, because {@link Map#of} rejects null values.
         */
        private String buildCompleteFunctionCallResponse() {
            Map<String, Object> function = Map.of("name", this.accumulatedToolName, "arguments", this.accumulatedArguments);
            Map<String, Object> toolCall = Map.of("id", this.accumulatedToolCallId, "type", "function", "function", function);
            Map<String, Object> message = Map.of("tool_calls", List.of(toolCall));
            Map<String, Object> choice = Map.of("message", message, "finish_reason", "tool_calls");
            Map<String, Object> response = Map.of("choices", List.of(choice));
            return StringUtils.toJson(response);
        }

        /** Wraps {@code responseData} in a single tensor named {@code "response"}. */
        private ModelTensorOutput createModelTensorOutput(Map<String, Object> responseData) {
            ModelTensor tensor = ModelTensor.builder().name("response").dataAsMap(responseData).build();
            ModelTensors tensors = ModelTensors.builder().mlModelTensors(List.of(tensor)).build();
            return ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build();
        }

        /**
         * Merges one batch of streamed tool-call deltas into the accumulators.
         *
         * <p>The id and name keep their latest non-null string value. Argument fragments are
         * concatenated. Explicit JSON nulls are ignored, so they neither clobber earlier values
         * nor append the text "null".
         */
        private void accumulateFunctionCall(List<?> toolCalls) {
            this.functionCallInProgress = true;
            for (Object toolCall : toolCalls) {
                if (!(toolCall instanceof Map)) {
                    continue;
                }
                Map<?, ?> tcMap = (Map<?, ?>) toolCall;
                Object id = tcMap.get("id");
                if (id instanceof String) {
                    this.accumulatedToolCallId = (String) id;
                }
                Object function = tcMap.get("function");
                if (!(function instanceof Map)) {
                    continue;
                }
                Map<?, ?> func = (Map<?, ?>) function;
                Object name = func.get("name");
                if (name instanceof String) {
                    this.accumulatedToolName = (String) name;
                }
                Object arguments = func.get("arguments");
                if (arguments instanceof String) {
                    this.accumulatedArguments = this.accumulatedArguments + (String) arguments;
                }
            }
        }
    }
}

