/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.rest.mcpserver;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.action.ActionType;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.annotation.ExperimentalApi;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.bytes.CompositeBytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.http.HttpChunk;
import org.opensearch.ml.action.mcpserver.McpAsyncServerHolder;
import org.opensearch.ml.common.settings.MLCommonsSettings;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.mcpserver.action.MLMcpMessageAction;
import org.opensearch.ml.common.transport.mcpserver.requests.message.MLMcpMessageRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.rest.StreamingRestChannel;
import org.opensearch.transport.client.node.NodeClient;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

@ExperimentalApi
public class RestMcpConnectionMessageStreamingAction
extends BaseRestHandler {
    @Generated
    private static final Logger log = LogManager.getLogger(RestMcpConnectionMessageStreamingAction.class);
    private static final String MCP_ACTION = "mcp_action";
    public static final String MESSAGE_ENDPOINT = "/_plugins/_ml/mcp/sse/message";
    public static final String SSE_ENDPOINT = "/_plugins/_ml/mcp/sse";
    private final ClusterService clusterService;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;

    public RestMcpConnectionMessageStreamingAction(ClusterService clusterService, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
        this.clusterService = clusterService;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
    }

    public List<RestHandler.Route> routes() {
        return List.of(new RestHandler.Route(RestRequest.Method.GET, SSE_ENDPOINT), new RestHandler.Route(RestRequest.Method.POST, MESSAGE_ENDPOINT));
    }

    public String getName() {
        return MCP_ACTION;
    }

    public BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) {
        if (!this.mlFeatureEnabledSetting.isMcpServerEnabled()) {
            throw new OpenSearchException(MLCommonsSettings.ML_COMMONS_MCP_SERVER_DISABLED_MESSAGE, new Object[0]);
        }
        String path = request.path();
        String sessionId = request.param("sessionId");
        String sAppendToBaseUrl = request.param("append_to_base_url");
        boolean appendToBaseUrl = Optional.ofNullable(sAppendToBaseUrl).map(x -> Boolean.parseBoolean(sAppendToBaseUrl)).orElse(false);
        BaseRestHandler.StreamingRestChannelConsumer consumer = channel -> this.prepareRequestInternal(path, appendToBaseUrl, sessionId, (StreamingRestChannel)channel, client);
        return channel -> {
            if (channel instanceof StreamingRestChannel) {
                consumer.accept((Object)((StreamingRestChannel)channel));
            } else {
                ActionRequestValidationException validationError = new ActionRequestValidationException();
                validationError.addValidationError("Unable to initiate request / response streaming over non-streaming channel");
                channel.sendResponse((RestResponse)new BytesRestResponse(channel, (Exception)validationError));
            }
        };
    }

    @VisibleForTesting
    protected void prepareRequestInternal(String path, boolean appendToBaseUrl, String sessionId, StreamingRestChannel channel, NodeClient client) {
        if (path.equals(SSE_ENDPOINT)) {
            channel.prepareResponse(RestStatus.OK, Map.of("Content-Type", List.of("text/event-stream"), "Cache-Control", List.of("no-cache"), "Connection", List.of("keep-alive")));
            Mono.from((Publisher)channel).ofType(HttpChunk.class).map(HttpChunk::content).flatMap(x -> McpAsyncServerHolder.getMcpServerTransportProviderInstance().handleSseConnection(channel, appendToBaseUrl, this.clusterService.localNode().getId(), client)).flatMap(y -> Mono.fromRunnable(() -> {
                log.debug("starting to send sse connection chunk result");
                channel.sendChunk(y);
            })).onErrorResume(e -> Mono.fromRunnable(() -> {
                try {
                    channel.sendResponse((RestResponse)new BytesRestResponse((RestChannel)channel, new Exception((Throwable)e)));
                }
                catch (IOException ex) {
                    log.error("Failed to send exception response to client during connection due to IOException");
                    throw new RuntimeException(ex);
                }
            })).subscribe();
        } else if (path.equals(MESSAGE_ENDPOINT)) {
            channel.prepareResponse(RestStatus.OK, Map.of("Content-Type", List.of("text/plain")));
            if (sessionId == null) {
                try {
                    channel.sendResponse((RestResponse)new BytesRestResponse((RestChannel)channel, (Exception)new IllegalArgumentException("Session ID missing in message endpoint")));
                }
                catch (IOException ex) {
                    log.error("Failed to send exception response to client when sessionId is null");
                }
            } else {
                Flux.from((Publisher)channel).ofType(HttpChunk.class).takeUntil(HttpChunk::isLast).map(HttpChunk::content).reduce((xva$0, xva$1) -> CompositeBytesReference.of((BytesReference[])new BytesReference[]{xva$0, xva$1})).doOnSuccess(x -> {
                    String requestBody = x.utf8ToString();
                    ActionListener listener = ActionListener.wrap(r -> {
                        if (r.isExists()) {
                            String nodeId = String.valueOf(r.getSourceAsMap().get("node_id"));
                            DiscoveryNode node = (DiscoveryNode)this.clusterService.state().getNodes().getNodes().get(nodeId);
                            if (node == null) {
                                log.error("The node:{} is no longer in the current cluster, can not handle the mcp request", (Object)nodeId);
                                channel.sendResponse((RestResponse)new BytesRestResponse((RestChannel)channel, (Exception)new IllegalStateException("Session no longer exists as corresponding node crashed, please recreate a new session in client side")));
                            } else if (this.clusterService.localNode().getId().equals(nodeId)) {
                                McpAsyncServerHolder.getMcpServerTransportProviderInstance().handleMessage(sessionId, requestBody).doOnSuccess(y -> {
                                    log.debug("Starting to send rest response to client in local node");
                                    channel.sendChunk(this.createRestResponse());
                                }).onErrorResume(e -> Mono.fromRunnable(() -> {
                                    try {
                                        log.error("Error occurred when handling message", e);
                                        channel.sendResponse((RestResponse)new BytesRestResponse((RestChannel)channel, new Exception((Throwable)e)));
                                    }
                                    catch (IOException ex) {
                                        log.error("Failed to send exception response to client during message handling in local due to IOException, nodeId: {}", (Object)nodeId);
                                    }
                                })).subscribe();
                            } else {
                                ActionListener actionListener = ActionListener.wrap(y -> {
                                    if (y.isAcknowledged()) {
                                        log.debug("Starting to send rest response to client as peer node returns successfully");
                                        channel.sendChunk(this.createRestResponse());
                                    }
                                }, e -> {
                                    log.error("MCP request has been dispatched to corresponding node but peer node failed to handle it", (Throwable)e);
                                    try {
                                        channel.sendResponse((RestResponse)new BytesRestResponse((RestChannel)channel, e));
                                    }
                                    catch (IOException ex) {
                                        log.error("Failed to send exception response to client during message handling in remote node due to IOException, nodeId: {}", (Object)nodeId);
                                    }
                                });
                                client.execute((ActionType)MLMcpMessageAction.INSTANCE, (ActionRequest)new MLMcpMessageRequest(nodeId, sessionId, requestBody), actionListener);
                            }
                        } else {
                            log.error("SessionId not found in cluster, sessionId: {}", (Object)sessionId);
                            try {
                                channel.sendResponse((RestResponse)new BytesRestResponse((RestChannel)channel, (Exception)new IllegalArgumentException("SessionId not found in cluster, please try to create session first, sessionId: " + sessionId)));
                            }
                            catch (IOException ex) {
                                log.error("Failed to send exception response to client when session ID not found in cluster state, sessionId: {}", (Object)sessionId);
                            }
                        }
                    }, e -> {
                        try {
                            channel.sendResponse((RestResponse)new BytesRestResponse((RestChannel)channel, e));
                        }
                        catch (IOException ex) {
                            log.error("Failed to get the session management index result with sessionId: {}", (Object)sessionId);
                        }
                    });
                    this.getDiscoveryNode(client, sessionId, (ActionListener<GetResponse>)listener);
                }).onErrorResume(e -> Mono.fromRunnable(() -> {
                    try {
                        channel.sendResponse((RestResponse)new BytesRestResponse((RestChannel)channel, new Exception((Throwable)e)));
                    }
                    catch (IOException ex) {
                        log.error("Failed to send exception response to client during message handling due to IOException", (Throwable)ex);
                    }
                })).subscribe();
            }
        }
    }

    private HttpChunk createRestResponse() {
        return new HttpChunk(this){

            public boolean isLast() {
                return true;
            }

            public BytesReference content() {
                return BytesReference.fromByteBuffer((ByteBuffer)ByteBuffer.wrap("OK".getBytes(StandardCharsets.UTF_8)));
            }

            public void close() {
            }
        };
    }

    private void getDiscoveryNode(NodeClient client, String sessionId, ActionListener<GetResponse> listener) {
        GetRequest getRequest = new GetRequest(".plugins-ml-mcp-session-management", sessionId);
        client.get(getRequest, listener);
    }

    public boolean supportsContentStream() {
        return true;
    }

    public boolean supportsStreaming() {
        return true;
    }

    public boolean allowsUnsafeBuffers() {
        return true;
    }
}

