/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.remote;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
import lombok.Generated;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.bulk.BackoffPolicy;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.action.support.RetryableAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.TokenBucket;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.ConnectorClientConfig;
import org.opensearch.ml.common.connector.MLPreProcessFunction;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
import org.opensearch.ml.engine.algorithms.remote.ExecutionContext;
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorThrottlingException;
import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener;
import org.opensearch.ml.engine.processor.ProcessorChain;
import org.opensearch.script.ScriptService;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.client.Client;

public interface RemoteConnectorExecutor {
    /**
     * Name handed to the {@link RetryableAction} super-constructor by {@link RetryableActionExtension}.
     * NOTE(review): presumably the thread-pool executor on which retries are scheduled — confirm against
     * the {@code RetryableAction} constructor contract.
     */
    public static final String RETRY_EXECUTOR = "opensearch_ml_predict_remote";

    /**
     * Executes a connector action without a transport channel.
     *
     * <p>Delegates to the four-argument overload with a {@code null} channel, which makes that overload
     * take its non-streaming path.
     *
     * @param action         name of the connector action to run
     * @param mlInput        input to send to the remote service
     * @param actionListener receives the aggregated response or the failure
     */
    default public void executeAction(String action, MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
        this.executeAction(action, mlInput, actionListener, null);
    }

    /**
     * Executes a connector action and reports the result to {@code actionListener}.
     *
     * <p><b>Streaming path</b> ({@code channel != null}): the input is forwarded as a single invocation with
     * sequence number 0. The tensors of the one response are wrapped in a {@link ModelTensorOutput}.
     * Exceptions thrown synchronously while preparing the payload on this path are not caught here and
     * propagate to the caller.
     *
     * <p><b>Non-streaming path</b>: a {@link TextDocsInputDataSet} is split into chunks as decided by
     * {@code calculateChunkSize}; every chunk is invoked with its own sequence number and the responses are
     * gathered by a {@link GroupedActionListener}. Each response is stored at the array index equal to its
     * sequence number, so the output order follows the chunk order. Any other dataset type is sent as a
     * single invocation. Exceptions thrown synchronously on this path are routed to
     * {@code actionListener.onFailure}.
     *
     * @param action         name of the connector action to run
     * @param mlInput        input to send to the remote service
     * @param actionListener receives the aggregated response or the failure
     * @param channel        transport channel for streaming, or {@code null} for a regular request
     */
    default public void executeAction(String action, MLInput mlInput, ActionListener<MLTaskResponse> actionListener, TransportChannel channel) {
        if (channel != null) {
            ActionListener<Tuple<Integer, ModelTensors>> streamingListener = ActionListener.wrap(response -> {
                ModelTensors tensors = response.v2();
                actionListener.onResponse(new MLTaskResponse(new ModelTensorOutput(Arrays.asList(tensors))));
            }, actionListener::onFailure);
            this.preparePayloadAndInvoke(action, mlInput, new ExecutionContext(0), streamingListener, actionListener, channel);
            return;
        }

        // Invoked by the grouped listener only once every sub-request has succeeded.
        ActionListener<Collection<Tuple<Integer, ModelTensors>>> tensorActionListener = ActionListener.wrap(responses -> {
            ModelTensors[] modelTensors = new ModelTensors[responses.size()];
            for (Tuple<Integer, ModelTensors> sequenceNoAndModelTensor : responses) {
                // v1 is the chunk's sequence number; using it as the index restores request order.
                modelTensors[sequenceNoAndModelTensor.v1()] = sequenceNoAndModelTensor.v2();
            }
            actionListener.onResponse(new MLTaskResponse(new ModelTensorOutput(Arrays.asList(modelTensors))));
        }, actionListener::onFailure);

        try {
            MLInputDataset inputDataset = mlInput.getInputDataset();
            if (inputDataset instanceof TextDocsInputDataSet) {
                TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet) inputDataset;
                List<String> docs = textDocsInputDataSet.getDocs();
                Tuple<Integer, Integer> calculatedChunkSize = this.calculateChunkSize(action, textDocsInputDataSet);
                int chunkCount = calculatedChunkSize.v1();
                int stepSize = calculatedChunkSize.v2();
                GroupedActionListener<Tuple<Integer, ModelTensors>> groupedActionListener =
                    new GroupedActionListener<>(tensorActionListener, chunkCount);
                int sequence = 0;
                for (int processedDocs = 0; processedDocs < docs.size(); processedDocs += stepSize) {
                    List<String> textDocs = docs.subList(processedDocs, Math.min(processedDocs + stepSize, docs.size()));
                    MLInput chunkInput = MLInput.builder()
                        .algorithm(FunctionName.TEXT_EMBEDDING)
                        .parameters(mlInput.getParameters())
                        .inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build())
                        .build();
                    this.preparePayloadAndInvoke(action, chunkInput, new ExecutionContext(sequence++), groupedActionListener);
                }
            } else {
                this.preparePayloadAndInvoke(action, mlInput, new ExecutionContext(0), new GroupedActionListener<>(tensorActionListener, 1));
            }
        } catch (Exception e) {
            actionListener.onFailure(e);
        }
    }

    /**
     * Decides how a text-docs input is split into remote invocations.
     *
     * <p>If the connector parameters contain {@code input_docs_processed_step_size}, that value is the
     * number of docs per invocation and the invocation count is the doc count divided by it, rounded up.
     * Otherwise the decision depends on the action's pre-process function: with no pre-process function, or
     * with a built-in one other than the Bedrock embedding function, all docs go in one invocation; with the
     * Bedrock embedding function or a function that is not built in, each doc gets its own invocation.
     *
     * @param action               name of the connector action being executed
     * @param textDocsInputDataSet dataset whose docs are to be split
     * @return tuple of (number of invocations, docs per invocation)
     * @throws IllegalArgumentException if the step size is not a positive integer, or if the connector has
     *                                  no action named {@code action}
     */
    private Tuple<Integer, Integer> calculateChunkSize(String action, TextDocsInputDataSet textDocsInputDataSet) {
        int textDocsLength = textDocsInputDataSet.getDocs().size();
        Map<String, String> parameters = this.getConnector().getParameters();
        if (parameters != null && parameters.containsKey("input_docs_processed_step_size")) {
            int stepSize;
            try {
                stepSize = Integer.parseInt(parameters.get("input_docs_processed_step_size"));
            } catch (NumberFormatException e) {
                // Report non-numeric (or null) values with the same message as non-positive ones.
                throw new IllegalArgumentException("Invalid parameter: input_docs_processed_step_size. It must be positive integer.", e);
            }
            if (stepSize <= 0) {
                throw new IllegalArgumentException("Invalid parameter: input_docs_processed_step_size. It must be positive integer.");
            }
            // Ceiling division written with / and % so it cannot overflow for large step sizes.
            int chunkCount = textDocsLength / stepSize + (textDocsLength % stepSize == 0 ? 0 : 1);
            return Tuple.tuple(chunkCount, stepSize);
        }

        Optional<ConnectorAction> connectorAction = this.getConnector().findAction(action);
        if (connectorAction.isEmpty()) {
            throw new IllegalArgumentException("no " + action + " action found");
        }
        String preProcessFunction = connectorAction.get().getPreProcessFunction();
        boolean oneDocPerInvocation = preProcessFunction != null
            && ("connector.pre_process.bedrock.embedding".equals(preProcessFunction)
                || !MLPreProcessFunction.contains(preProcessFunction));
        if (oneDocPerInvocation) {
            return Tuple.tuple(textDocsLength, 1);
        }
        return Tuple.tuple(1, textDocsLength);
    }

    /** Sets the script service. The default implementation does nothing. */
    default public void setScriptService(ScriptService scriptService) {
    }

    /** Returns the script service that is passed to {@code ConnectorUtils.processInput} when building the payload. */
    public ScriptService getScriptService();

    /** Returns the connector that supplies default parameters, action definitions and payload creation/validation. */
    public Connector getConnector();

    /**
     * Returns the model-level rate limiter, or {@code null} when none is configured
     * (a {@code null} value skips the model-level throttling check).
     */
    public TokenBucket getRateLimiter();

    /**
     * Returns the per-user rate limiters, looked up by user name, or {@code null} when none are configured.
     * A missing map or a missing entry for the current user skips the user-level throttling check.
     */
    public Map<String, TokenBucket> getUserRateLimiterMap();

    /** Returns the guard used to validate the outgoing payload, or {@code null} to skip validation. */
    public MLGuard getMlGuard();

    /** Returns the client whose thread pool and thread context are used by this executor. */
    public Client getClient();

    /** Returns the logger used for throttling, guardrail and retry messages. */
    public Logger getLogger();

    /** Returns the client configuration that supplies the retry settings (max retries, backoff, timeout). */
    public ConnectorClientConfig getConnectorClientConfig();

    /** Sets the client. The default implementation does nothing. */
    default public void setClient(Client client) {
    }

    /** Sets the private-IP-enabled flag. The default implementation does nothing. */
    default public void setConnectorPrivateIpEnabled(AtomicBoolean connectorPrivateIpEnabled) {
    }

    /** Sets the XContent registry. The default implementation does nothing. */
    default public void setXContentRegistry(NamedXContentRegistry xContentRegistry) {
    }

    /** Sets the cluster service. The default implementation does nothing. */
    default public void setClusterService(ClusterService clusterService) {
    }

    /** Sets the model-level rate limiter. The default implementation does nothing. */
    default public void setRateLimiter(TokenBucket rateLimiter) {
    }

    /** Sets the per-user rate limiters. The default implementation does nothing. */
    default public void setUserRateLimiterMap(Map<String, TokenBucket> userRateLimiterMap) {
    }

    /** Sets the guard used for payload validation. The default implementation does nothing. */
    default public void setMlGuard(MLGuard mlGuard) {
    }

    /**
     * Builds the payload for {@code mlInput} and invokes the remote service, without an agent listener
     * or transport channel.
     *
     * <p>Delegates to the six-argument overload, passing {@code null} for both extra arguments.
     *
     * @param action           name of the connector action to run
     * @param mlInput          input to send to the remote service
     * @param executionContext context carrying the sequence number of this invocation
     * @param actionListener   receives the (sequence number, tensors) result or the failure
     */
    default public void preparePayloadAndInvoke(String action, MLInput mlInput, ExecutionContext executionContext, ActionListener<Tuple<Integer, ModelTensors>> actionListener) {
        this.preparePayloadAndInvoke(action, mlInput, executionContext, actionListener, null, null);
    }

    /**
     * Assembles the request parameters and payload for {@code action}, applies throttling and guardrail
     * checks, and dispatches the call to the remote service.
     *
     * <p>Parameters are merged in this order, later sources overriding earlier ones: connector parameters,
     * parameters of a {@link RemoteInferenceInputDataSet} (escaped first), algorithm parameters of
     * {@code mlInput}, parameters produced by {@code ConnectorUtils.processInput}, and finally the dataset
     * parameters once more.
     *
     * <p>Dispatch: if the configured max retry count is non-zero the call goes through
     * {@link #invokeRemoteServiceWithRetry}; otherwise, if the parameters contain the key {@code stream},
     * it goes through {@link #invokeRemoteServiceStream}; otherwise through {@link #invokeRemoteService}.
     *
     * @param action           name of the connector action to run
     * @param mlInput          input to send to the remote service
     * @param executionContext context carrying the sequence number of this invocation
     * @param actionListener   receives the result; also receives an {@link IOException} raised while
     *                         serializing the algorithm parameters
     * @param agentListener    forwarded to the stream listener only when {@code memory_id} or
     *                         {@code parent_interaction_id} is present in the parameters
     * @param channel          transport channel handed to the stream listener
     * @throws OpenSearchStatusException with {@code TOO_MANY_REQUESTS} when the model-level or user-level
     *                                   rate limiter rejects the request
     * @throws IllegalArgumentException  when the guard rejects the payload
     */
    default public void preparePayloadAndInvoke(String action, MLInput mlInput, ExecutionContext executionContext, ActionListener<Tuple<Integer, ModelTensors>> actionListener, ActionListener<MLTaskResponse> agentListener, TransportChannel channel) {
        Connector connector = this.getConnector();

        // Lowest precedence: parameters defined on the connector itself.
        Map<String, String> parameters = new HashMap<>();
        if (connector.getParameters() != null) {
            parameters.putAll(connector.getParameters());
        }

        // Parameters carried by a remote-inference dataset; escaped before they are copied.
        MLInputDataset inputDataset = mlInput.getInputDataset();
        Map<String, String> inputParameters = new HashMap<>();
        if (inputDataset instanceof RemoteInferenceInputDataSet) {
            RemoteInferenceInputDataSet remoteDataset = (RemoteInferenceInputDataSet) inputDataset;
            if (remoteDataset.getParameters() != null) {
                ConnectorUtils.escapeRemoteInferenceInputData(remoteDataset);
                inputParameters.putAll(remoteDataset.getParameters());
            }
        }
        parameters.putAll(inputParameters);

        // Algorithm parameters, flattened to strings.
        if (mlInput.getParameters() != null) {
            try {
                parameters.putAll(RemoteConnectorExecutor.getParams(mlInput));
            } catch (IOException e) {
                actionListener.onFailure(e);
                return;
            }
        }

        RemoteInferenceInputDataSet inputData = ConnectorUtils.processInput(action, mlInput, connector, parameters, this.getScriptService());
        if (inputData.getParameters() != null) {
            parameters.putAll(inputData.getParameters());
        }
        // Applied again so the dataset's own values override whatever input processing produced.
        parameters.putAll(inputParameters);

        String payload = (String) connector.createPayload(action, parameters);
        List<Map<String, Object>> processorConfigs = ProcessorChain.extractProcessorConfigs(parameters, "input_processors");
        if (!processorConfigs.isEmpty()) {
            Object processedPayload = new ProcessorChain(processorConfigs).process(payload);
            payload = StringUtils.toJson(processedPayload);
        }

        boolean skipValidation = Boolean.parseBoolean(parameters.getOrDefault("skip_validating_missing_parameters", "false"));
        if (!skipValidation) {
            connector.validatePayload(payload);
        }

        String userStr = (String) this.getClient().threadPool().getThreadContext().getTransient("_opendistro_security_user_info");
        User user = User.parse(userStr);

        // Model-level throttling.
        if (this.getRateLimiter() != null && !this.getRateLimiter().request()) {
            String modelThrottled = "Request is throttled at model level.";
            this.getLogger().error(modelThrottled);
            throw new OpenSearchStatusException(modelThrottled, RestStatus.TOO_MANY_REQUESTS);
        }
        // User-level throttling, only when a limiter is registered for this user name.
        if (user != null
            && this.getUserRateLimiterMap() != null
            && this.getUserRateLimiterMap().get(user.getName()) != null
            && !this.getUserRateLimiterMap().get(user.getName()).request()) {
            this.getLogger().error("Request is throttled at user level.");
            throw new OpenSearchStatusException("Request is throttled at user level. If you think there's an issue, please contact your cluster admin.", RestStatus.TOO_MANY_REQUESTS);
        }
        // Guardrail validation of the outgoing payload.
        if (this.getMlGuard() != null && !this.getMlGuard().validate(payload, MLGuard.Type.INPUT, parameters)) {
            String guardrailMessage = "guardrails triggered for user input";
            this.getLogger().error(guardrailMessage);
            throw new IllegalArgumentException(guardrailMessage);
        }

        // Retry takes priority over streaming: the stream flag is only consulted when retries are off.
        if (this.getConnectorClientConfig().getMaxRetryTimes() != 0) {
            this.invokeRemoteServiceWithRetry(action, mlInput, parameters, payload, executionContext, actionListener);
            return;
        }
        if (!parameters.containsKey("stream")) {
            this.invokeRemoteService(action, mlInput, parameters, payload, executionContext, actionListener);
            return;
        }
        String memoryId = parameters.get("memory_id");
        String parentInteractionId = parameters.get("parent_interaction_id");
        boolean isAgentRequest = memoryId != null || parentInteractionId != null;
        StreamPredictActionListener streamListener = new StreamPredictActionListener(channel, (ActionListener<MLTaskResponse>)(isAgentRequest ? agentListener : null), memoryId, parentInteractionId);
        this.invokeRemoteServiceStream(action, mlInput, parameters, payload, executionContext, streamListener);
    }

    /**
     * Converts the algorithm parameters of {@code mlInput} into a string-to-string map.
     *
     * <p>The parameters are serialized to JSON through an {@link XContentBuilder}, parsed back into a map,
     * and that map is passed through {@code StringUtils.getParameterMap}.
     *
     * @param mlInput input whose {@code getParameters()} is non-null
     * @return the parameters as returned by {@code StringUtils.getParameterMap}
     * @throws IOException if serializing or parsing the JSON fails
     */
    public static Map<String, String> getParams(MLInput mlInput) throws IOException {
        XContentBuilder jsonBuilder = XContentFactory.jsonBuilder();
        mlInput.getParameters().toXContent(jsonBuilder, ToXContent.EMPTY_PARAMS);
        jsonBuilder.flush();
        String serialized = jsonBuilder.toString();
        // Jackson can only be given the raw Map class here, hence the unchecked assignment.
        @SuppressWarnings("unchecked")
        Map<String, Object> parsed = StringUtils.MAPPER.readValue(serialized, Map.class);
        return StringUtils.getParameterMap(parsed);
    }

    /**
     * Builds the {@link BackoffPolicy} selected by the client configuration.
     *
     * <ul>
     *   <li>{@code EXPONENTIAL_EQUAL_JITTER}: equal-jitter backoff with the configured base delay and a
     *       maximum delay of the retry timeout converted to milliseconds;</li>
     *   <li>{@code EXPONENTIAL_FULL_JITTER}: full-jitter backoff with the configured base delay;</li>
     *   <li>any other value: constant backoff of the configured delay, with {@code Integer.MAX_VALUE}
     *       as the iteration limit.</li>
     * </ul>
     *
     * @param connectorClientConfig configuration supplying the policy type, backoff millis and timeout
     * @return the backoff policy to use for retries
     */
    default public BackoffPolicy getRetryBackoffPolicy(ConnectorClientConfig connectorClientConfig) {
        switch (connectorClientConfig.getRetryBackoffPolicy()) {
            case EXPONENTIAL_EQUAL_JITTER: {
                long baseDelayMillis = connectorClientConfig.getRetryBackoffMillis().intValue();
                // Multiply as long: an int product overflows once the timeout exceeds ~2.1M seconds.
                long maxDelayMillis = connectorClientConfig.getRetryTimeoutSeconds() * 1000L;
                return BackoffPolicy.exponentialEqualJitterBackoff(baseDelayMillis, maxDelayMillis);
            }
            case EXPONENTIAL_FULL_JITTER: {
                long baseDelayMillis = connectorClientConfig.getRetryBackoffMillis().intValue();
                return BackoffPolicy.exponentialFullJitterBackoff(baseDelayMillis);
            }
            default: {
                long delayMillis = connectorClientConfig.getRetryBackoffMillis().intValue();
                return BackoffPolicy.constantBackoff(TimeValue.timeValueMillis(delayMillis), Integer.MAX_VALUE);
            }
        }
    }

    /**
     * Invokes the remote service through a {@link RetryableActionExtension}.
     *
     * <p>The initial delay is the configured retry backoff in milliseconds, the overall timeout is the
     * configured retry timeout in seconds, and the backoff policy comes from
     * {@link #getRetryBackoffPolicy(ConnectorClientConfig)}.
     *
     * @param action           name of the connector action to run
     * @param mlInput          input forwarded to {@link #invokeRemoteService}
     * @param parameters       merged request parameters
     * @param payload          request payload
     * @param executionContext context carrying the sequence number of this invocation
     * @param actionListener   receives the final result or failure
     */
    default public void invokeRemoteServiceWithRetry(String action, MLInput mlInput, Map<String, String> parameters, String payload, ExecutionContext executionContext, ActionListener<Tuple<Integer, ModelTensors>> actionListener) {
        Logger retryLogger = this.getLogger();
        ThreadPool threadPool = this.getClient().threadPool();
        TimeValue initialDelay = TimeValue.timeValueMillis(this.getConnectorClientConfig().getRetryBackoffMillis().intValue());
        TimeValue timeout = TimeValue.timeValueSeconds(this.getConnectorClientConfig().getRetryTimeoutSeconds().intValue());
        BackoffPolicy backoffPolicy = this.getRetryBackoffPolicy(this.getConnectorClientConfig());
        RetryableActionExtensionArgs retryArgs = RetryableActionExtensionArgs.builder()
            .connectionExecutor(this)
            .mlInput(mlInput)
            .action(action)
            .parameters(parameters)
            .executionContext(executionContext)
            .payload(payload)
            .build();
        new RetryableActionExtension(retryLogger, threadPool, initialDelay, timeout, actionListener, backoffPolicy, retryArgs).run();
    }

    /**
     * Performs a single, non-streaming call to the remote service.
     *
     * <p>Argument roles below follow how this method is called from within this interface.
     *
     * @param var1 name of the connector action
     * @param var2 the ML input being processed
     * @param var3 merged request parameters
     * @param var4 request payload
     * @param var5 execution context of this invocation
     * @param var6 listener that receives the (sequence number, tensors) result or the failure
     */
    public void invokeRemoteService(String var1, MLInput var2, Map<String, String> var3, String var4, ExecutionContext var5, ActionListener<Tuple<Integer, ModelTensors>> var6);

    /**
     * Performs a streaming call to the remote service.
     *
     * <p>Argument roles below follow how this method is called from within this interface.
     *
     * @param var1 name of the connector action
     * @param var2 the ML input being processed
     * @param var3 merged request parameters
     * @param var4 request payload
     * @param var5 execution context of this invocation
     * @param var6 stream listener built around the transport channel
     */
    public void invokeRemoteServiceStream(String var1, MLInput var2, Map<String, String> var3, String var4, ExecutionContext var5, StreamPredictActionListener<MLTaskResponse, ?> var6);

    /**
     * {@link RetryableAction} that repeats {@link RemoteConnectorExecutor#invokeRemoteService} with the
     * arguments captured in a {@link RetryableActionExtensionArgs}.
     *
     * <p>A failure is retried only when its unwrapped cause is a
     * {@link RemoteConnectorThrottlingException} and the configured maximum retry count has not been
     * exceeded; a maximum of {@code -1} disables the count limit.
     */
    public static class RetryableActionExtension
    extends RetryableAction<Tuple<Integer, ModelTensors>> {
        private final RetryableActionExtensionArgs args;
        /** Number of times {@link #shouldRetry(Exception)} has been consulted so far. */
        int retryTimes = 0;

        RetryableActionExtension(Logger logger, ThreadPool threadPool, TimeValue initialDelay, TimeValue timeoutValue, ActionListener<Tuple<Integer, ModelTensors>> listener, BackoffPolicy backoffPolicy, RetryableActionExtensionArgs args) {
            super(logger, threadPool, initialDelay, timeoutValue, listener, backoffPolicy, RemoteConnectorExecutor.RETRY_EXECUTOR);
            this.args = args;
        }

        /** Runs one attempt of the remote call, reporting the outcome to {@code listener}. */
        public void tryAction(ActionListener<Tuple<Integer, ModelTensors>> listener) {
            RemoteConnectorExecutor executor = this.args.connectionExecutor;
            executor.invokeRemoteService(this.args.action, this.args.mlInput, this.args.parameters, this.args.payload, this.args.executionContext, listener);
        }

        /**
         * Decides whether a failed attempt is retried. Every call increments {@link #retryTimes},
         * including calls for failures that are not retryable.
         *
         * @param e the failure of the last attempt
         * @return {@code true} if another attempt should be made
         */
        public boolean shouldRetry(Exception e) {
            Throwable cause = ExceptionsHelper.unwrapCause(e);
            Integer maxRetryTimes = this.args.connectionExecutor.getConnectorClientConfig().getMaxRetryTimes();
            boolean throttled = cause instanceof RemoteConnectorThrottlingException;

            this.retryTimes++;
            boolean limitExceeded = this.retryTimes > maxRetryTimes && maxRetryTimes != -1;
            if (!throttled || limitExceeded) {
                return false;
            }

            String message = String.format(Locale.ROOT, "The %d-th retry for invoke remote model", this.retryTimes);
            this.args.connectionExecutor.getLogger().debug(message, (Throwable) e);
            return true;
        }
    }

    /**
     * Immutable bundle of the arguments that {@link RetryableActionExtension} passes to
     * {@link RemoteConnectorExecutor#invokeRemoteService} on every attempt. Instances are created
     * through {@link #builder()}.
     */
    public static class RetryableActionExtensionArgs {
        private final RemoteConnectorExecutor connectionExecutor;
        private final MLInput mlInput;
        private final String action;
        private final Map<String, String> parameters;
        private final ExecutionContext executionContext;
        private final String payload;

        @Generated
        RetryableActionExtensionArgs(RemoteConnectorExecutor connectionExecutor, MLInput mlInput, String action, Map<String, String> parameters, ExecutionContext executionContext, String payload) {
            this.connectionExecutor = connectionExecutor;
            this.mlInput = mlInput;
            this.action = action;
            this.parameters = parameters;
            this.executionContext = executionContext;
            this.payload = payload;
        }

        /** Returns a new, empty builder. */
        @Generated
        public static RetryableActionExtensionArgsBuilder builder() {
            return new RetryableActionExtensionArgsBuilder();
        }

        /** Fluent builder for {@link RetryableActionExtensionArgs}; unset values stay {@code null}. */
        @Generated
        public static class RetryableActionExtensionArgsBuilder {
            @Generated
            private RemoteConnectorExecutor connectionExecutor;
            @Generated
            private MLInput mlInput;
            @Generated
            private String action;
            @Generated
            private Map<String, String> parameters;
            @Generated
            private ExecutionContext executionContext;
            @Generated
            private String payload;

            @Generated
            RetryableActionExtensionArgsBuilder() {
            }

            /** Sets the executor whose {@code invokeRemoteService} is called on each attempt. */
            @Generated
            public RetryableActionExtensionArgsBuilder connectionExecutor(RemoteConnectorExecutor connectionExecutor) {
                this.connectionExecutor = connectionExecutor;
                return this;
            }

            /** Sets the ML input. */
            @Generated
            public RetryableActionExtensionArgsBuilder mlInput(MLInput mlInput) {
                this.mlInput = mlInput;
                return this;
            }

            /** Sets the connector action name. */
            @Generated
            public RetryableActionExtensionArgsBuilder action(String action) {
                this.action = action;
                return this;
            }

            /** Sets the merged request parameters. */
            @Generated
            public RetryableActionExtensionArgsBuilder parameters(Map<String, String> parameters) {
                this.parameters = parameters;
                return this;
            }

            /** Sets the execution context. */
            @Generated
            public RetryableActionExtensionArgsBuilder executionContext(ExecutionContext executionContext) {
                this.executionContext = executionContext;
                return this;
            }

            /** Sets the request payload. */
            @Generated
            public RetryableActionExtensionArgsBuilder payload(String payload) {
                this.payload = payload;
                return this;
            }

            /** Creates the args object from the values set so far. */
            @Generated
            public RetryableActionExtensionArgs build() {
                return new RetryableActionExtensionArgs(connectionExecutor, mlInput, action, parameters, executionContext, payload);
            }

            @Generated
            public String toString() {
                // String concatenation renders null references as "null", same as String.valueOf.
                return "RemoteConnectorExecutor.RetryableActionExtensionArgs.RetryableActionExtensionArgsBuilder(connectionExecutor=" + connectionExecutor
                    + ", mlInput=" + mlInput
                    + ", action=" + action
                    + ", parameters=" + parameters
                    + ", executionContext=" + executionContext
                    + ", payload=" + payload
                    + ")";
            }
        }
    }
}

