/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.connector;

import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.connector.MLExecuteConnectorRequest;
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.script.ScriptService;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

/**
 * Transport action that runs an action directly against a stored remote connector.
 *
 * <p>The request must carry an {@code MLInput} whose input dataset is a
 * {@link RemoteInferenceInputDataSet}. The connector action defaults to
 * {@code ConnectorAction.ActionType.EXECUTE} and may be overridden through the dataset's
 * {@code connector_action} parameter. The connector is loaded and access-checked. Its
 * credentials are then decrypted for the chosen action, and the action is executed through the
 * {@link RemoteConnectorExecutor} created for the connector's protocol. Decrypted credentials are
 * removed from the connector when execution succeeds or fails.
 */
public class ExecuteConnectorTransportAction
extends HandledTransportAction<ActionRequest, MLTaskResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(ExecuteConnectorTransportAction.class);
    /** Dataset parameter that selects which connector action to run. */
    private static final String CONNECTOR_ACTION_PARAM = "connector_action";
    Client client;
    ClusterService clusterService;
    ScriptService scriptService;
    NamedXContentRegistry xContentRegistry;
    ConnectorAccessControlHelper connectorAccessControlHelper;
    EncryptorImpl encryptor;
    MLFeatureEnabledSetting mlFeatureEnabledSetting;

    /**
     * Registers the action under {@code cluster:admin/opensearch/ml/connectors/execute} and
     * stores the collaborators used to load, authorize, decrypt and execute connectors.
     */
    @Inject
    public ExecuteConnectorTransportAction(TransportService transportService, ActionFilters actionFilters, Client client, ClusterService clusterService, ScriptService scriptService, NamedXContentRegistry xContentRegistry, ConnectorAccessControlHelper connectorAccessControlHelper, EncryptorImpl encryptor, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
        super("cluster:admin/opensearch/ml/connectors/execute", transportService, actionFilters, MLExecuteConnectorRequest::new);
        this.client = client;
        this.clusterService = clusterService;
        this.scriptService = scriptService;
        this.xContentRegistry = xContentRegistry;
        this.connectorAccessControlHelper = connectorAccessControlHelper;
        this.encryptor = encryptor;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
    }

    /**
     * Validates the request, loads the connector with a stashed thread context, checks access,
     * and then executes the requested connector action.
     *
     * <p>Every path completes {@code actionListener} exactly once with either a response or a
     * failure: invalid input, a missing connector index, a failed lookup, denied access, or an
     * execution error.
     *
     * @param task           the task executing this request (unused)
     * @param request        the incoming request, converted to {@link MLExecuteConnectorRequest}
     * @param actionListener receives the connector's {@link MLTaskResponse} or the failure
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLTaskResponse> actionListener) {
        MLExecuteConnectorRequest executeConnectorRequest = MLExecuteConnectorRequest.fromActionRequest(request);
        String connectorId = executeConnectorRequest.getConnectorId();
        if (executeConnectorRequest.getMlInput() == null) {
            actionListener.onFailure(new IllegalArgumentException("MLInput cannot be null"));
            return;
        }
        if (!(executeConnectorRequest.getMlInput().getInputDataset() instanceof RemoteInferenceInputDataSet)) {
            actionListener.onFailure(new IllegalArgumentException("Input dataset must be of type RemoteInferenceInputDataSet"));
            return;
        }
        RemoteInferenceInputDataSet inputDataset = (RemoteInferenceInputDataSet) executeConnectorRequest.getMlInput().getInputDataset();
        String connectorAction = resolveConnectorAction(inputDataset);

        if (!MLIndicesHandler.doesMultiTenantIndexExist(this.clusterService, this.mlFeatureEnabledSetting.isMultiTenancyEnabled(), ".plugins-ml-connector")) {
            actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + connectorId));
            return;
        }

        ActionListener<Connector> connectorListener = ActionListener.wrap(connector -> {
            if (this.connectorAccessControlHelper.validateConnectorAccess(this.client, connector)) {
                this.executeWithConnector(connector, connectorAction, executeConnectorRequest, actionListener, true);
            } else {
                // Without this branch a denied request would never be answered and would hang.
                // NOTE(review): an OpenSearchStatusException with RestStatus.FORBIDDEN would map to 403.
                actionListener.onFailure(new IllegalArgumentException("You don't have permission to execute connector " + connectorId));
            }
        }, e -> {
            log.error("Failed to get connector {}", connectorId, e);
            actionListener.onFailure(e);
        });

        // The connector is fetched under a stashed context; the caller's context is restored
        // before connectorListener runs so the access check and execution see the original caller.
        try (ThreadContext.StoredContext threadContext = this.client.threadPool().getThreadContext().stashContext()) {
            this.connectorAccessControlHelper.getConnector(this.client, connectorId, ActionListener.runBefore(connectorListener, threadContext::restore));
        }
    }

    /**
     * Returns the action named by the dataset's {@code connector_action} parameter, or
     * {@code EXECUTE} when the parameter map or the entry is absent.
     */
    private static String resolveConnectorAction(RemoteInferenceInputDataSet inputDataset) {
        if (inputDataset.getParameters() == null) {
            return ConnectorAction.ActionType.EXECUTE.name();
        }
        String requested = (String) inputDataset.getParameters().get(CONNECTOR_ACTION_PARAM);
        return requested != null ? requested : ConnectorAction.ActionType.EXECUTE.name();
    }

    /**
     * Decrypts the connector's credentials for {@code action}, builds the protocol-specific
     * executor and runs the action.
     *
     * <p>{@code listener} is always completed through a wrapper that first calls
     * {@link Connector#removeCredential()}. The decrypted credentials are therefore cleared on
     * success, on an asynchronous failure, and on a synchronous failure during decryption or
     * executor setup.
     *
     * @param connector            the loaded, access-checked connector
     * @param action               connector action name to decrypt for and execute
     * @param request              original request supplying the {@code MLInput}
     * @param listener             receives the execution result or failure
     * @param decryptWithEncryptor if true, credentials go through {@code encryptor.decrypt};
     *                             otherwise they are used as stored
     */
    private void executeWithConnector(Connector connector, String action, MLExecuteConnectorRequest request, ActionListener<MLTaskResponse> listener, boolean decryptWithEncryptor) {
        ActionListener<MLTaskResponse> cleanupListener = ActionListener.runBefore(listener, connector::removeCredential);
        try {
            String connectorTenantId = connector.getTenantId();
            // Decryption sits inside the try so a failure here still clears any partially
            // decrypted credentials and reports to the caller directly.
            if (decryptWithEncryptor) {
                connector.decrypt(action, (credential, tenantId) -> this.encryptor.decrypt(credential, tenantId), connectorTenantId);
            } else {
                connector.decrypt(action, (credential, tenantId) -> credential, connectorTenantId);
            }
            RemoteConnectorExecutor connectorExecutor = (RemoteConnectorExecutor) MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
            connectorExecutor.setConnectorPrivateIpEnabled(this.mlFeatureEnabledSetting.isConnectorPrivateIpEnabled());
            connectorExecutor.setScriptService(this.scriptService);
            connectorExecutor.setClusterService(this.clusterService);
            connectorExecutor.setClient(this.client);
            connectorExecutor.setXContentRegistry(this.xContentRegistry);
            connectorExecutor.executeAction(action, request.getMlInput(), cleanupListener);
        } catch (Exception e) {
            cleanupListener.onFailure(e);
        }
    }
}

