From deb3ccf015e9dea35bf844603b0384224aff7e2d Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Mon, 30 Oct 2023 13:07:25 -0700 Subject: [PATCH] add concurrent limit on datasource and sessions (#2390) (#2395) * add concurrent limit on datasource and sessions * fix ut coverage --------- (cherry picked from commit d3ce049be22e6df7e83b57f8b61f8533241aab83) Signed-off-by: Peng Huo Signed-off-by: github-actions[bot] Co-authored-by: github-actions[bot] --- .../sql/common/setting/Settings.java | 1 + datasources/build.gradle | 3 +- .../TransportCreateDataSourceAction.java | 41 ++-- .../TransportCreateDataSourceActionTest.java | 42 +++- docs/user/admin/settings.rst | 37 ++- .../sql/datasource/DataSourceAPIsIT.java | 51 ++++ .../setting/OpenSearchSettings.java | 14 ++ .../org/opensearch/sql/plugin/SQLPlugin.java | 6 +- .../spark/dispatcher/AsyncQueryHandler.java | 8 +- .../spark/dispatcher/BatchQueryHandler.java | 36 +++ .../dispatcher/InteractiveQueryHandler.java | 64 +++++ .../dispatcher/SparkQueryDispatcher.java | 221 ++++-------------- .../dispatcher/StreamingQueryHandler.java | 67 ++++++ .../model/DispatchQueryContext.java | 21 ++ .../execution/session/SessionManager.java | 16 -- .../execution/statestore/StateStore.java | 5 +- .../ConcurrencyLimitExceededException.java | 13 ++ .../leasemanager/DefaultLeaseManager.java | 69 ++++++ .../sql/spark/leasemanager/LeaseManager.java | 19 ++ .../leasemanager/model/LeaseRequest.java | 18 ++ ...AsyncQueryExecutorServiceImplSpecTest.java | 85 ++++++- .../dispatcher/SparkQueryDispatcherTest.java | 36 +-- .../leasemanager/DefaultLeaseManagerTest.java | 27 +++ 23 files changed, 649 insertions(+), 251 deletions(-) create mode 100644 spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryContext.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/leasemanager/ConcurrencyLimitExceededException.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/leasemanager/DefaultLeaseManager.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/leasemanager/LeaseManager.java create mode 100644 spark/src/main/java/org/opensearch/sql/spark/leasemanager/model/LeaseRequest.java create mode 100644 spark/src/test/java/org/opensearch/sql/spark/leasemanager/DefaultLeaseManagerTest.java diff --git a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java index 6ef3921b39..61d23a1a34 100644 --- a/common/src/main/java/org/opensearch/sql/common/setting/Settings.java +++ b/common/src/main/java/org/opensearch/sql/common/setting/Settings.java @@ -34,6 +34,7 @@ public enum Key { QUERY_SIZE_LIMIT("plugins.query.size_limit"), ENCYRPTION_MASTER_KEY("plugins.query.datasources.encryption.masterkey"), DATASOURCES_URI_HOSTS_DENY_LIST("plugins.query.datasources.uri.hosts.denylist"), + DATASOURCES_LIMIT("plugins.query.datasources.limit"), METRICS_ROLLING_WINDOW("plugins.query.metrics.rolling_window"), METRICS_ROLLING_INTERVAL("plugins.query.metrics.rolling_interval"), diff --git a/datasources/build.gradle b/datasources/build.gradle index cdd2ee813b..be4e12b3bd 100644 --- a/datasources/build.gradle +++ b/datasources/build.gradle @@ -16,6 +16,7 @@ repositories { dependencies { implementation project(':core') implementation project(':protocol') + implementation project(':opensearch') implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation group: 'org.opensearch', name: 'opensearch-x-content', version: "${opensearch_version}" implementation group: 'org.opensearch', name: 'common-utils', version: "${opensearch_build}" @@ -35,7 +36,7 @@ dependencies { test { useJUnitPlatform() testLogging { - events "passed", "skipped", "failed" + events "skipped", "failed" exceptionFormat "full" } } diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/transport/TransportCreateDataSourceAction.java b/datasources/src/main/java/org/opensearch/sql/datasources/transport/TransportCreateDataSourceAction.java index b3c1ba4196..95e6493e05 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/transport/TransportCreateDataSourceAction.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/transport/TransportCreateDataSourceAction.java @@ -7,6 +7,7 @@ package org.opensearch.sql.datasources.transport; +import static org.opensearch.sql.common.setting.Settings.Key.DATASOURCES_LIMIT; import static org.opensearch.sql.protocol.response.format.JsonResponseFormatter.Style.PRETTY; import org.opensearch.action.ActionType; @@ -30,6 +31,7 @@ public class TransportCreateDataSourceAction new ActionType<>(NAME, CreateDataSourceActionResponse::new); private DataSourceService dataSourceService; + private org.opensearch.sql.opensearch.setting.OpenSearchSettings settings; /** * TransportCreateDataSourceAction action for creating datasource. @@ -42,13 +44,15 @@ public class TransportCreateDataSourceAction public TransportCreateDataSourceAction( TransportService transportService, ActionFilters actionFilters, - DataSourceServiceImpl dataSourceService) { + DataSourceServiceImpl dataSourceService, + org.opensearch.sql.opensearch.setting.OpenSearchSettings settings) { super( TransportCreateDataSourceAction.NAME, transportService, actionFilters, CreateDataSourceActionRequest::new); this.dataSourceService = dataSourceService; + this.settings = settings; } @Override @@ -56,19 +60,28 @@ protected void doExecute( Task task, CreateDataSourceActionRequest request, ActionListener actionListener) { - try { - DataSourceMetadata dataSourceMetadata = request.getDataSourceMetadata(); - dataSourceService.createDataSource(dataSourceMetadata); - String responseContent = - new JsonResponseFormatter(PRETTY) { - @Override - protected Object buildJsonObject(String response) { - return response; - } - }.format("Created DataSource with name " + dataSourceMetadata.getName()); - actionListener.onResponse(new CreateDataSourceActionResponse(responseContent)); - } catch (Exception e) { - actionListener.onFailure(e); + int dataSourceLimit = settings.getSettingValue(DATASOURCES_LIMIT); + if (dataSourceService.getDataSourceMetadata(false).size() >= dataSourceLimit) { + actionListener.onFailure( + new IllegalStateException( + String.format( + "domain concurrent datasources can not" + " exceed %d", dataSourceLimit))); + } else { + try { + + DataSourceMetadata dataSourceMetadata = request.getDataSourceMetadata(); + dataSourceService.createDataSource(dataSourceMetadata); + String responseContent = + new JsonResponseFormatter(PRETTY) { + @Override + protected Object buildJsonObject(String response) { + return response; + } + }.format("Created DataSource with name " + dataSourceMetadata.getName()); + actionListener.onResponse(new CreateDataSourceActionResponse(responseContent)); + } catch (Exception e) { + actionListener.onFailure(e); + } } } } diff --git a/datasources/src/test/java/org/opensearch/sql/datasources/transport/TransportCreateDataSourceActionTest.java b/datasources/src/test/java/org/opensearch/sql/datasources/transport/TransportCreateDataSourceActionTest.java index e104672fa3..2b9973b31b 100644 --- a/datasources/src/test/java/org/opensearch/sql/datasources/transport/TransportCreateDataSourceActionTest.java +++ b/datasources/src/test/java/org/opensearch/sql/datasources/transport/TransportCreateDataSourceActionTest.java @@ -1,14 +1,19 @@ package org.opensearch.sql.datasources.transport; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.common.setting.Settings.Key.DATASOURCES_LIMIT; import java.util.HashSet; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Answers; import org.mockito.ArgumentCaptor; import org.mockito.Captor; import org.mockito.Mock; @@ -21,6 +26,7 @@ import org.opensearch.sql.datasources.model.transport.CreateDataSourceActionRequest; import org.opensearch.sql.datasources.model.transport.CreateDataSourceActionResponse; import org.opensearch.sql.datasources.service.DataSourceServiceImpl; +import org.opensearch.sql.opensearch.setting.OpenSearchSettings; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -29,9 +35,13 @@ public class TransportCreateDataSourceActionTest { @Mock private TransportService transportService; @Mock private TransportCreateDataSourceAction action; - @Mock private DataSourceServiceImpl dataSourceService; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private DataSourceServiceImpl dataSourceService; + @Mock private Task task; @Mock private ActionListener actionListener; + @Mock private OpenSearchSettings settings; @Captor private ArgumentCaptor @@ -43,7 +53,9 @@ public class TransportCreateDataSourceActionTest { public void setUp() { action = new TransportCreateDataSourceAction( - transportService, new ActionFilters(new HashSet<>()), dataSourceService); + transportService, new ActionFilters(new HashSet<>()), dataSourceService, settings); + when(dataSourceService.getDataSourceMetadata(false).size()).thenReturn(1); + when(settings.getSettingValue(DATASOURCES_LIMIT)).thenReturn(20); } @Test @@ -79,4 +91,30 @@ public void testDoExecuteWithException() { Assertions.assertTrue(exception instanceof RuntimeException); Assertions.assertEquals("Error", exception.getMessage()); } + + @Test + public void testDataSourcesLimit() { + DataSourceMetadata dataSourceMetadata = new DataSourceMetadata(); + dataSourceMetadata.setName("test_datasource"); + dataSourceMetadata.setConnector(DataSourceType.PROMETHEUS); + CreateDataSourceActionRequest request = new CreateDataSourceActionRequest(dataSourceMetadata); + when(dataSourceService.getDataSourceMetadata(false).size()).thenReturn(1); + when(settings.getSettingValue(DATASOURCES_LIMIT)).thenReturn(1); + + action.doExecute( + task, + request, + new ActionListener() { + @Override + public void onResponse(CreateDataSourceActionResponse createDataSourceActionResponse) { + fail(); + } + + @Override + public void onFailure(Exception e) { + assertEquals("domain concurrent datasources can not exceed 1", e.getMessage()); + } + }); + verify(dataSourceService, times(0)).createDataSource(dataSourceMetadata); + } } diff --git a/docs/user/admin/settings.rst b/docs/user/admin/settings.rst index dfce554e2b..3acb005c12 100644 --- a/docs/user/admin/settings.rst +++ b/docs/user/admin/settings.rst @@ -348,12 +348,12 @@ SQL query:: } plugins.query.executionengine.spark.session.limit -=================================================== +================================================== Description ----------- -Each datasource can have maximum 100 sessions running in parallel by default. You can increase limit by this setting. +Each cluster can have maximum 100 sessions running in parallel by default. You can increase limit by this setting. 1. The default value is 100. 2. This setting is node scope. @@ -383,3 +383,36 @@ SQL query:: } } + +plugins.query.datasources.limit +=============================== + +Description +----------- + +Each cluster can have maximum 20 datasources. You can increase limit by this setting. + +1. The default value is 20. +2. This setting is node scope. +3. This setting can be updated dynamically. + +You can update the setting with a new value like this. + +SQL query:: + + sh$ curl -sS -H 'Content-Type: application/json' -X PUT localhost:9200/_plugins/_query/settings \ + ... -d '{"transient":{"plugins.query.datasources.limit":25}}' + { + "acknowledged": true, + "persistent": {}, + "transient": { + "plugins": { + "query": { + "datasources": { + "limit": "25" + } + } + } + } + } + diff --git a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java index 92c1a4df16..c681b58eb4 100644 --- a/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/datasource/DataSourceAPIsIT.java @@ -22,6 +22,7 @@ import java.util.Map; import lombok.SneakyThrows; import org.apache.commons.lang3.StringUtils; +import org.junit.After; import org.junit.AfterClass; import org.junit.Assert; import org.junit.Test; @@ -34,6 +35,11 @@ public class DataSourceAPIsIT extends PPLIntegTestCase { + @After + public void cleanUp() throws IOException { + wipeAllClusterSettings(); + } + @AfterClass protected static void deleteDataSourcesCreated() throws IOException { Request deleteRequest = getDeleteDataSourceRequest("create_prometheus"); @@ -51,6 +57,10 @@ protected static void deleteDataSourcesCreated() throws IOException { deleteRequest = getDeleteDataSourceRequest("Create_Prometheus"); deleteResponse = client().performRequest(deleteRequest); Assert.assertEquals(204, deleteResponse.getStatusLine().getStatusCode()); + + deleteRequest = getDeleteDataSourceRequest("duplicate_prometheus"); + deleteResponse = client().performRequest(deleteRequest); + Assert.assertEquals(204, deleteResponse.getStatusLine().getStatusCode()); } @SneakyThrows @@ -283,4 +293,45 @@ public void issue2196() { Assert.assertNull(dataSourceMetadata.getProperties().get("prometheus.auth.password")); Assert.assertEquals("Prometheus Creation for Integ test", dataSourceMetadata.getDescription()); } + + @Test + public void datasourceLimitTest() throws InterruptedException, IOException { + DataSourceMetadata d1 = mockDataSourceMetadata("duplicate_prometheus"); + Request createRequest = getCreateDataSourceRequest(d1); + Response response = client().performRequest(createRequest); + Assert.assertEquals(201, response.getStatusLine().getStatusCode()); + // Datasource is not immediately created. so introducing a sleep of 2s. + Thread.sleep(2000); + + updateClusterSettings(new ClusterSetting(TRANSIENT, "plugins.query.datasources.limit", "1")); + + DataSourceMetadata d2 = mockDataSourceMetadata("d2"); + ResponseException exception = + Assert.assertThrows( + ResponseException.class, () -> client().performRequest(getCreateDataSourceRequest(d2))); + Assert.assertEquals(400, exception.getResponse().getStatusLine().getStatusCode()); + String prometheusGetResponseString = getResponseBody(exception.getResponse()); + JsonObject errorMessage = new Gson().fromJson(prometheusGetResponseString, JsonObject.class); + Assert.assertEquals( + "domain concurrent datasources can not exceed 1", + errorMessage.get("error").getAsJsonObject().get("details").getAsString()); + } + + public DataSourceMetadata mockDataSourceMetadata(String name) { + return new DataSourceMetadata( + name, + "Prometheus Creation for Integ test", + DataSourceType.PROMETHEUS, + ImmutableList.of(), + ImmutableMap.of( + "prometheus.uri", + "https://localhost:9090", + "prometheus.auth.type", + "basicauth", + "prometheus.auth.username", + "username", + "prometheus.auth.password", + "password"), + null); + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java index d041eb386e..6b5f3cf0f1 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/setting/OpenSearchSettings.java @@ -172,6 +172,13 @@ public class OpenSearchSettings extends Settings { Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting DATASOURCES_LIMIT_SETTING = + Setting.intSetting( + Key.DATASOURCES_LIMIT.getKeyValue(), + 20, + Setting.Property.NodeScope, + Setting.Property.Dynamic); + /** Construct OpenSearchSetting. The OpenSearchSetting must be singleton. */ @SuppressWarnings("unchecked") public OpenSearchSettings(ClusterSettings clusterSettings) { @@ -272,6 +279,12 @@ public OpenSearchSettings(ClusterSettings clusterSettings) { Key.AUTO_INDEX_MANAGEMENT_ENABLED, AUTO_INDEX_MANAGEMENT_ENABLED_SETTING, new Updater(Key.AUTO_INDEX_MANAGEMENT_ENABLED)); + register( + settingBuilder, + clusterSettings, + Key.DATASOURCES_LIMIT, + DATASOURCES_LIMIT_SETTING, + new Updater(Key.DATASOURCES_LIMIT)); registerNonDynamicSettings( settingBuilder, clusterSettings, Key.CLUSTER_NAME, ClusterName.CLUSTER_NAME_SETTING); defaultSettings = settingBuilder.build(); @@ -342,6 +355,7 @@ public static List> pluginSettings() { .add(SESSION_INDEX_TTL_SETTING) .add(RESULT_INDEX_TTL_SETTING) .add(AUTO_INDEX_MANAGEMENT_ENABLED_SETTING) + .add(DATASOURCES_LIMIT_SETTING) .build(); } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index 905c697e5b..9d37fe28d0 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -98,6 +98,7 @@ import org.opensearch.sql.spark.execution.session.SessionManager; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl; +import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.RestAsyncQueryManagementAction; import org.opensearch.sql.spark.storage.SparkStorageFactory; @@ -258,7 +259,7 @@ public Collection createComponents( OpenSearchSettings.AUTO_INDEX_MANAGEMENT_ENABLED_SETTING, environment.settings()); return ImmutableList.of( - dataSourceService, asyncQueryExecutorService, clusterManagerEventListener); + dataSourceService, asyncQueryExecutorService, clusterManagerEventListener, pluginSettings); } @Override @@ -333,7 +334,8 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService( jobExecutionResponseReader, new FlintIndexMetadataReaderImpl(client), client, - new SessionManager(stateStore, emrServerlessClient, pluginSettings)); + new SessionManager(stateStore, emrServerlessClient, pluginSettings), + new DefaultLeaseManager(pluginSettings, stateStore)); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java index 77a0e1cd09..2823e64af7 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/AsyncQueryHandler.java @@ -12,6 +12,9 @@ import com.amazonaws.services.emrserverless.model.JobRunState; import org.json.JSONObject; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; /** Process async query request. */ public abstract class AsyncQueryHandler { @@ -45,5 +48,8 @@ protected abstract JSONObject getResponseFromResultIndex( protected abstract JSONObject getResponseFromExecutor( AsyncQueryJobMetadata asyncQueryJobMetadata); - abstract String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata); + public abstract String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata); + + public abstract DispatchQueryResponse submit( + DispatchQueryRequest request, DispatchQueryContext context); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java index 8a582278e1..c6bac9b288 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/BatchQueryHandler.java @@ -7,12 +7,21 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; +import static org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher.JOB_TYPE_TAG_KEY; import com.amazonaws.services.emrserverless.model.GetJobRunResult; +import java.util.Map; import lombok.RequiredArgsConstructor; import org.json.JSONObject; +import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; +import org.opensearch.sql.spark.dispatcher.model.JobType; import org.opensearch.sql.spark.response.JobExecutionResponseReader; @RequiredArgsConstructor @@ -47,4 +56,31 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { asyncQueryJobMetadata.getApplicationId(), asyncQueryJobMetadata.getJobId()); return asyncQueryJobMetadata.getQueryId().getId(); } + + @Override + public DispatchQueryResponse submit( + DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) { + String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query"; + Map tags = context.getTags(); + DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata(); + + tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); + StartJobRequest startJobRequest = + new StartJobRequest( + dispatchQueryRequest.getQuery(), + jobName, + dispatchQueryRequest.getApplicationId(), + dispatchQueryRequest.getExecutionRoleARN(), + SparkSubmitParameters.Builder.builder() + .dataSource(context.getDataSourceMetadata()) + .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) + .build() + .toString(), + tags, + false, + dataSourceMetadata.getResultIndex()); + String jobId = emrServerlessClient.startJobRun(startJobRequest); + return new DispatchQueryResponse( + context.getQueryId(), jobId, false, dataSourceMetadata.getResultIndex(), null); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java index 52cc2efbe2..d75f568275 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/InteractiveQueryHandler.java @@ -6,24 +6,38 @@ package org.opensearch.sql.spark.dispatcher; import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; +import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME; import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; +import static org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher.JOB_TYPE_TAG_KEY; +import java.util.Map; import java.util.Optional; import lombok.RequiredArgsConstructor; import org.json.JSONObject; +import org.opensearch.sql.datasource.model.DataSourceMetadata; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; +import org.opensearch.sql.spark.dispatcher.model.JobType; +import org.opensearch.sql.spark.execution.session.CreateSessionRequest; import org.opensearch.sql.spark.execution.session.Session; import org.opensearch.sql.spark.execution.session.SessionId; import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.statement.QueryRequest; import org.opensearch.sql.spark.execution.statement.Statement; import org.opensearch.sql.spark.execution.statement.StatementId; import org.opensearch.sql.spark.execution.statement.StatementState; +import org.opensearch.sql.spark.leasemanager.LeaseManager; +import org.opensearch.sql.spark.leasemanager.model.LeaseRequest; import org.opensearch.sql.spark.response.JobExecutionResponseReader; @RequiredArgsConstructor public class InteractiveQueryHandler extends AsyncQueryHandler { private final SessionManager sessionManager; private final JobExecutionResponseReader jobExecutionResponseReader; + private final LeaseManager leaseManager; @Override protected JSONObject getResponseFromResultIndex(AsyncQueryJobMetadata asyncQueryJobMetadata) { @@ -50,6 +64,56 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { return queryId; } + @Override + public DispatchQueryResponse submit( + DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) { + Session session = null; + String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query"; + Map tags = context.getTags(); + DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata(); + + // todo, manage lease lifecycle + leaseManager.borrow( + new LeaseRequest(JobType.INTERACTIVE, dispatchQueryRequest.getDatasource())); + + if (dispatchQueryRequest.getSessionId() != null) { + // get session from request + SessionId sessionId = new SessionId(dispatchQueryRequest.getSessionId()); + Optional createdSession = sessionManager.getSession(sessionId); + if (createdSession.isPresent()) { + session = createdSession.get(); + } + } + if (session == null || !session.isReady()) { + // create session if not exist or session dead/fail + tags.put(JOB_TYPE_TAG_KEY, JobType.INTERACTIVE.getText()); + session = + sessionManager.createSession( + new CreateSessionRequest( + jobName, + dispatchQueryRequest.getApplicationId(), + dispatchQueryRequest.getExecutionRoleARN(), + SparkSubmitParameters.Builder.builder() + .className(FLINT_SESSION_CLASS_NAME) + .dataSource(dataSourceMetadata) + .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()), + tags, + dataSourceMetadata.getResultIndex(), + dataSourceMetadata.getName())); + } + session.submit( + new QueryRequest( + context.getQueryId(), + dispatchQueryRequest.getLangType(), + dispatchQueryRequest.getQuery())); + return new DispatchQueryResponse( + context.getQueryId(), + session.getSessionModel().getJobId(), + false, + dataSourceMetadata.getResultIndex(), + session.getSessionId().getSessionId()); + } + private Statement getStatementByQueryId(String sid, String qid) { SessionId sessionId = new SessionId(sid); Optional session = sessionManager.getSession(sessionId); diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java index b603ee6909..a800e45dd6 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcher.java @@ -7,7 +7,6 @@ import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD; import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD; -import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME; import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD; import com.amazonaws.services.emrserverless.model.JobRunState; @@ -15,7 +14,6 @@ import java.util.Base64; import java.util.HashMap; import java.util.Map; -import java.util.Optional; import java.util.concurrent.ExecutionException; import lombok.AllArgsConstructor; import lombok.Getter; @@ -33,17 +31,16 @@ import org.opensearch.sql.datasources.auth.DataSourceUserAuthorizationHelperImpl; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; -import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.client.StartJobRequest; -import org.opensearch.sql.spark.dispatcher.model.*; -import org.opensearch.sql.spark.execution.session.CreateSessionRequest; -import org.opensearch.sql.spark.execution.session.Session; -import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; +import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType; +import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; import org.opensearch.sql.spark.execution.session.SessionManager; -import org.opensearch.sql.spark.execution.statement.QueryRequest; import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataReader; +import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; import org.opensearch.sql.spark.utils.SQLQueryUtils; @@ -72,19 +69,49 @@ public class SparkQueryDispatcher { private SessionManager sessionManager; + private LeaseManager leaseManager; + public DispatchQueryResponse dispatch(DispatchQueryRequest dispatchQueryRequest) { - if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) { - return handleSQLQuery(dispatchQueryRequest); - } else { - // Since we don't need any extra handling for PPL, we are treating it as normal dispatch - // Query. - return handleNonIndexQuery(dispatchQueryRequest); + DataSourceMetadata dataSourceMetadata = + this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); + dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); + + AsyncQueryHandler asyncQueryHandler = + sessionManager.isEnabled() + ? new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager) + : new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader); + DispatchQueryContext.DispatchQueryContextBuilder contextBuilder = + DispatchQueryContext.builder() + .dataSourceMetadata(dataSourceMetadata) + .tags(getDefaultTagsForJobSubmission(dispatchQueryRequest)) + .queryId(AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName())); + + // override asyncQueryHandler with specific. + if (LangType.SQL.equals(dispatchQueryRequest.getLangType()) + && SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) { + IndexQueryDetails indexQueryDetails = + SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery()); + fillMissingDetails(dispatchQueryRequest, indexQueryDetails); + contextBuilder.indexQueryDetails(indexQueryDetails); + + if (IndexQueryActionType.DROP.equals(indexQueryDetails.getIndexQueryActionType())) { + // todo, fix in DROP INDEX PR. + return handleDropIndexQuery(dispatchQueryRequest, indexQueryDetails); + } else if (IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType()) + && indexQueryDetails.isAutoRefresh()) { + asyncQueryHandler = + new StreamingQueryHandler(emrServerlessClient, jobExecutionResponseReader); + } else if (IndexQueryActionType.REFRESH.equals(indexQueryDetails.getIndexQueryActionType())) { + // manual refresh should be handled by batch handler + asyncQueryHandler = new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader); + } } + return asyncQueryHandler.submit(dispatchQueryRequest, contextBuilder.build()); } public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) { if (asyncQueryJobMetadata.getSessionId() != null) { - return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader) + return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager) .getQueryResponse(asyncQueryJobMetadata); } else { return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader) @@ -94,7 +121,7 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata) public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { if (asyncQueryJobMetadata.getSessionId() != null) { - return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader) + return new InteractiveQueryHandler(sessionManager, jobExecutionResponseReader, leaseManager) .cancelJob(asyncQueryJobMetadata); } else { return new BatchQueryHandler(emrServerlessClient, jobExecutionResponseReader) @@ -102,25 +129,6 @@ public String cancelJob(AsyncQueryJobMetadata asyncQueryJobMetadata) { } } - private DispatchQueryResponse handleSQLQuery(DispatchQueryRequest dispatchQueryRequest) { - if (SQLQueryUtils.isFlintExtensionQuery(dispatchQueryRequest.getQuery())) { - IndexQueryDetails indexQueryDetails = - SQLQueryUtils.extractIndexDetails(dispatchQueryRequest.getQuery()); - fillMissingDetails(dispatchQueryRequest, indexQueryDetails); - - // TODO: refactor this code properly. - if (IndexQueryActionType.DROP.equals(indexQueryDetails.getIndexQueryActionType())) { - return handleDropIndexQuery(dispatchQueryRequest, indexQueryDetails); - } else if (IndexQueryActionType.CREATE.equals(indexQueryDetails.getIndexQueryActionType())) { - return handleStreamingQueries(dispatchQueryRequest, indexQueryDetails); - } else { - return handleFlintNonStreamingQueries(dispatchQueryRequest, indexQueryDetails); - } - } else { - return handleNonIndexQuery(dispatchQueryRequest); - } - } - // TODO: Revisit this logic. // Currently, Spark if datasource is not provided in query. // Spark Assumes the datasource to be catalog. @@ -135,151 +143,10 @@ private static void fillMissingDetails( } } - private DispatchQueryResponse handleStreamingQueries( - DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) { - DataSourceMetadata dataSourceMetadata = - this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); - dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); - String jobName = dispatchQueryRequest.getClusterName() + ":" + "index-query"; - Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); - tags.put(INDEX_TAG_KEY, indexQueryDetails.openSearchIndexName()); - if (indexQueryDetails.isAutoRefresh()) { - tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText()); - } - StartJobRequest startJobRequest = - new StartJobRequest( - dispatchQueryRequest.getQuery(), - jobName, - dispatchQueryRequest.getApplicationId(), - dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.Builder.builder() - .dataSource( - dataSourceService.getRawDataSourceMetadata( - dispatchQueryRequest.getDatasource())) - .structuredStreaming(indexQueryDetails.isAutoRefresh()) - .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) - .build() - .toString(), - tags, - indexQueryDetails.isAutoRefresh(), - dataSourceMetadata.getResultIndex()); - String jobId = emrServerlessClient.startJobRun(startJobRequest); - return new DispatchQueryResponse( - AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()), - jobId, - false, - dataSourceMetadata.getResultIndex(), - null); - } - - private DispatchQueryResponse handleFlintNonStreamingQueries( - DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) { - DataSourceMetadata dataSourceMetadata = - this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); - dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); - String jobName = dispatchQueryRequest.getClusterName() + ":" + "index-query"; - Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); - StartJobRequest startJobRequest = - new StartJobRequest( - dispatchQueryRequest.getQuery(), - jobName, - dispatchQueryRequest.getApplicationId(), - dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.Builder.builder() - .dataSource( - dataSourceService.getRawDataSourceMetadata( - dispatchQueryRequest.getDatasource())) - .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) - .build() - .toString(), - tags, - indexQueryDetails.isAutoRefresh(), - dataSourceMetadata.getResultIndex()); - String jobId = emrServerlessClient.startJobRun(startJobRequest); - return new DispatchQueryResponse( - AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()), - jobId, - false, - dataSourceMetadata.getResultIndex(), - null); - } - - private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQueryRequest) { - DataSourceMetadata dataSourceMetadata = - this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); - AsyncQueryId queryId = AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()); - dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); - String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query"; - Map tags = getDefaultTagsForJobSubmission(dispatchQueryRequest); - - if (sessionManager.isEnabled()) { - Session session = null; - - if (dispatchQueryRequest.getSessionId() != null) { - // get session from request - SessionId sessionId = new SessionId(dispatchQueryRequest.getSessionId()); - Optional createdSession = sessionManager.getSession(sessionId); - if (createdSession.isPresent()) { - session = createdSession.get(); - } - } - if (session == null || !session.isReady()) { - // create session if not exist or session dead/fail - tags.put(JOB_TYPE_TAG_KEY, JobType.INTERACTIVE.getText()); - session = - sessionManager.createSession( - new CreateSessionRequest( - jobName, - dispatchQueryRequest.getApplicationId(), - dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.Builder.builder() - .className(FLINT_SESSION_CLASS_NAME) - .dataSource( - dataSourceService.getRawDataSourceMetadata( - dispatchQueryRequest.getDatasource())) - .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()), - tags, - dataSourceMetadata.getResultIndex(), - dataSourceMetadata.getName())); - } - session.submit( - new QueryRequest( - queryId, dispatchQueryRequest.getLangType(), dispatchQueryRequest.getQuery())); - return new DispatchQueryResponse( - queryId, - session.getSessionModel().getJobId(), - false, - dataSourceMetadata.getResultIndex(), - session.getSessionId().getSessionId()); - } else { - tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); - StartJobRequest startJobRequest = - new StartJobRequest( - dispatchQueryRequest.getQuery(), - jobName, - dispatchQueryRequest.getApplicationId(), - dispatchQueryRequest.getExecutionRoleARN(), - SparkSubmitParameters.Builder.builder() - .dataSource( - dataSourceService.getRawDataSourceMetadata( - dispatchQueryRequest.getDatasource())) - .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) - .build() - .toString(), - tags, - false, - dataSourceMetadata.getResultIndex()); - String jobId = emrServerlessClient.startJobRun(startJobRequest); - return new DispatchQueryResponse( - queryId, jobId, false, dataSourceMetadata.getResultIndex(), null); - } - } - private DispatchQueryResponse handleDropIndexQuery( DispatchQueryRequest dispatchQueryRequest, IndexQueryDetails indexQueryDetails) { DataSourceMetadata dataSourceMetadata = this.dataSourceService.getRawDataSourceMetadata(dispatchQueryRequest.getDatasource()); - dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata); FlintIndexMetadata indexMetadata = flintIndexMetadataReader.getFlintIndexMetadata(indexQueryDetails); // if index is created without auto refresh. there is no job to cancel. diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java new file mode 100644 index 0000000000..81c3438532 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/StreamingQueryHandler.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher; + +import static org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher.INDEX_TAG_KEY; +import static org.opensearch.sql.spark.dispatcher.SparkQueryDispatcher.JOB_TYPE_TAG_KEY; + +import java.util.Map; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; +import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters; +import org.opensearch.sql.spark.client.EMRServerlessClient; +import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryContext; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryRequest; +import org.opensearch.sql.spark.dispatcher.model.DispatchQueryResponse; +import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; +import org.opensearch.sql.spark.dispatcher.model.JobType; +import org.opensearch.sql.spark.response.JobExecutionResponseReader; + +/** Handle Streaming Query. */ +public class StreamingQueryHandler extends BatchQueryHandler { + private final EMRServerlessClient emrServerlessClient; + + public StreamingQueryHandler( + EMRServerlessClient emrServerlessClient, + JobExecutionResponseReader jobExecutionResponseReader) { + super(emrServerlessClient, jobExecutionResponseReader); + this.emrServerlessClient = emrServerlessClient; + } + + @Override + public DispatchQueryResponse submit( + DispatchQueryRequest dispatchQueryRequest, DispatchQueryContext context) { + String jobName = dispatchQueryRequest.getClusterName() + ":" + "index-query"; + IndexQueryDetails indexQueryDetails = context.getIndexQueryDetails(); + Map tags = context.getTags(); + tags.put(INDEX_TAG_KEY, indexQueryDetails.openSearchIndexName()); + DataSourceMetadata dataSourceMetadata = context.getDataSourceMetadata(); + tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText()); + StartJobRequest startJobRequest = + new StartJobRequest( + dispatchQueryRequest.getQuery(), + jobName, + dispatchQueryRequest.getApplicationId(), + dispatchQueryRequest.getExecutionRoleARN(), + SparkSubmitParameters.Builder.builder() + .dataSource(dataSourceMetadata) + .structuredStreaming(true) + .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()) + .build() + .toString(), + tags, + indexQueryDetails.isAutoRefresh(), + dataSourceMetadata.getResultIndex()); + String jobId = emrServerlessClient.startJobRun(startJobRequest); + return new DispatchQueryResponse( + AsyncQueryId.newAsyncQueryId(dataSourceMetadata.getName()), + jobId, + false, + dataSourceMetadata.getResultIndex(), + null); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryContext.java b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryContext.java new file mode 100644 index 0000000000..d3400d86bf --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/dispatcher/model/DispatchQueryContext.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.dispatcher.model; + +import java.util.Map; +import lombok.Builder; +import lombok.Getter; +import org.opensearch.sql.datasource.model.DataSourceMetadata; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; + +@Getter +@Builder +public class DispatchQueryContext { + private final AsyncQueryId queryId; + private final DataSourceMetadata dataSourceMetadata; + private final Map tags; + private final IndexQueryDetails indexQueryDetails; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index 81b9fdaee0..c0f7bbcde8 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -6,11 +6,8 @@ package org.opensearch.sql.spark.execution.session; import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_SESSION_ENABLED; -import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_SESSION_LIMIT; import static org.opensearch.sql.spark.execution.session.SessionId.newSessionId; -import static org.opensearch.sql.spark.execution.statestore.StateStore.activeSessionsCount; -import java.util.Locale; import java.util.Optional; import lombok.RequiredArgsConstructor; import org.opensearch.sql.common.setting.Settings; @@ -29,15 +26,6 @@ public class SessionManager { private final Settings settings; public Session createSession(CreateSessionRequest request) { - int sessionMaxLimit = sessionMaxLimit(); - if (activeSessionsCount(stateStore, request.getDatasourceName()).get() >= sessionMaxLimit) { - String errorMsg = - String.format( - Locale.ROOT, - "The maximum number of active sessions can be " + "supported is %d", - sessionMaxLimit); - throw new IllegalArgumentException(errorMsg); - } InteractiveSession session = InteractiveSession.builder() .sessionId(newSessionId(request.getDatasourceName())) @@ -67,8 +55,4 @@ public Optional getSession(SessionId sid) { public boolean isEnabled() { return settings.getSettingValue(SPARK_EXECUTION_SESSION_ENABLED); } - - public int sessionMaxLimit() { - return settings.getSettingValue(SPARK_EXECUTION_SESSION_LIMIT); - } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java index f471d79c22..86d15a7036 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java @@ -63,6 +63,7 @@ public class StateStore { datasourceName -> String.format( "%s_%s", SPARK_REQUEST_BUFFER_INDEX_NAME, datasourceName.toLowerCase(Locale.ROOT)); + public static String ALL_DATASOURCE = "*"; private static final Logger LOG = LogManager.getLogger(); @@ -192,9 +193,6 @@ private void createIndex(String indexName) { } private long count(String indexName, QueryBuilder query) { - if (!this.clusterService.state().routingTable().hasIndex(indexName)) { - return 0; - } SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.query(query); searchSourceBuilder.size(0); @@ -301,7 +299,6 @@ public static Supplier activeSessionsCount(StateStore stateStore, String d .must( QueryBuilders.termQuery( SessionModel.SESSION_TYPE, SessionType.INTERACTIVE.getSessionType())) - .must(QueryBuilders.termQuery(SessionModel.DATASOURCE_NAME, datasourceName)) .must( QueryBuilders.termQuery( SessionModel.SESSION_STATE, SessionState.RUNNING.getSessionState()))); diff --git a/spark/src/main/java/org/opensearch/sql/spark/leasemanager/ConcurrencyLimitExceededException.java b/spark/src/main/java/org/opensearch/sql/spark/leasemanager/ConcurrencyLimitExceededException.java new file mode 100644 index 0000000000..ab6305c835 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/leasemanager/ConcurrencyLimitExceededException.java @@ -0,0 +1,13 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.leasemanager; + +/** Concurrency limit exceeded. */ +public class ConcurrencyLimitExceededException extends RuntimeException { + public ConcurrencyLimitExceededException(String message) { + super(message); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/leasemanager/DefaultLeaseManager.java b/spark/src/main/java/org/opensearch/sql/spark/leasemanager/DefaultLeaseManager.java new file mode 100644 index 0000000000..1635a1801b --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/leasemanager/DefaultLeaseManager.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.leasemanager; + +import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_SESSION_LIMIT; +import static org.opensearch.sql.spark.execution.statestore.StateStore.ALL_DATASOURCE; +import static org.opensearch.sql.spark.execution.statestore.StateStore.activeSessionsCount; + +import java.util.Arrays; +import java.util.List; +import java.util.function.Predicate; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.spark.dispatcher.model.JobType; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.leasemanager.model.LeaseRequest; + +/** + * Default Lease Manager + *
  • QueryHandler borrow lease before execute the query. + *
  • LeaseManagerService check request against domain level concurrent limit. + *
  • LeaseManagerService running on data node and check limit based on cluster settings. + */ +public class DefaultLeaseManager implements LeaseManager { + + private final List> concurrentLimitRules; + private final Settings settings; + private final StateStore stateStore; + + public DefaultLeaseManager(Settings settings, StateStore stateStore) { + this.settings = settings; + this.stateStore = stateStore; + this.concurrentLimitRules = Arrays.asList(new ConcurrentSessionRule()); + } + + @Override + public void borrow(LeaseRequest request) { + for (Rule rule : concurrentLimitRules) { + if (!rule.test(request)) { + throw new ConcurrencyLimitExceededException(rule.description()); + } + } + } + + interface Rule extends Predicate { + String description(); + } + + public class ConcurrentSessionRule implements Rule { + @Override + public String description() { + return String.format("domain concurrent active session can not exceed %d", sessionMaxLimit()); + } + + @Override + public boolean test(LeaseRequest leaseRequest) { + if (leaseRequest.getJobType() != JobType.INTERACTIVE) { + return true; + } + return activeSessionsCount(stateStore, ALL_DATASOURCE).get() < sessionMaxLimit(); + } + + public int sessionMaxLimit() { + return settings.getSettingValue(SPARK_EXECUTION_SESSION_LIMIT); + } + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/leasemanager/LeaseManager.java b/spark/src/main/java/org/opensearch/sql/spark/leasemanager/LeaseManager.java new file mode 100644 index 0000000000..6cc74ecdc5 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/leasemanager/LeaseManager.java @@ -0,0 +1,19 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.leasemanager; + +import org.opensearch.sql.spark.leasemanager.model.LeaseRequest; + +/** Lease manager */ +public interface LeaseManager { + + /** + * Borrow from LeaseManager. If no exception, lease successfully. + * + * @throws ConcurrencyLimitExceededException + */ + void borrow(LeaseRequest request); +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/leasemanager/model/LeaseRequest.java b/spark/src/main/java/org/opensearch/sql/spark/leasemanager/model/LeaseRequest.java new file mode 100644 index 0000000000..190c033198 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/leasemanager/model/LeaseRequest.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.leasemanager.model; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.spark.dispatcher.model.JobType; + +/** Lease Request. */ +@Getter +@RequiredArgsConstructor +public class LeaseRequest { + private final JobType jobType; + private final String datasourceName; +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java index 5b04c8f7ea..862da697d1 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceImplSpecTest.java @@ -67,6 +67,8 @@ import org.opensearch.sql.spark.execution.statement.StatementState; import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.sql.spark.flint.FlintIndexMetadataReaderImpl; +import org.opensearch.sql.spark.leasemanager.ConcurrencyLimitExceededException; +import org.opensearch.sql.spark.leasemanager.DefaultLeaseManager; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; @@ -76,6 +78,7 @@ public class AsyncQueryExecutorServiceImplSpecTest extends OpenSearchIntegTestCase { public static final String DATASOURCE = "mys3"; + public static final String DSOTHER = "mytest"; private ClusterService clusterService; private org.opensearch.sql.common.setting.Settings pluginSettings; @@ -112,7 +115,7 @@ public void setup() { clusterSettings = clusterService.getClusterSettings(); pluginSettings = new OpenSearchSettings(clusterSettings); dataSourceService = createDataSourceService(); - DataSourceMetadata dataSourceMetadata = + DataSourceMetadata dm = new DataSourceMetadata( DATASOURCE, StringUtils.EMPTY, @@ -128,9 +131,27 @@ public void setup() { "glue.indexstore.opensearch.auth", "noauth"), null); - dataSourceService.createDataSource(dataSourceMetadata); + dataSourceService.createDataSource(dm); + DataSourceMetadata otherDm = + new DataSourceMetadata( + DSOTHER, + StringUtils.EMPTY, + DataSourceType.S3GLUE, + ImmutableList.of(), + ImmutableMap.of( + "glue.auth.type", + "iam_role", + "glue.auth.role_arn", + "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole", + "glue.indexstore.opensearch.uri", + "http://localhost:9200", + "glue.indexstore.opensearch.auth", + "noauth"), + null); + dataSourceService.createDataSource(otherDm); stateStore = new StateStore(client, clusterService); - createIndex(dataSourceMetadata.fromNameToCustomResultIndex()); + createIndex(dm.fromNameToCustomResultIndex()); + createIndex(otherDm.fromNameToCustomResultIndex()); } @After @@ -186,6 +207,28 @@ public void withoutSessionCreateAsyncQueryThenGetResultThenCancel() { emrsClient.cancelJobRunCalled(1); } + @Test + public void sessionLimitNotImpactBatchQuery() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // disable session + enableSession(false); + setSessionLimit(0); + + // 1. create async query. + CreateAsyncQueryResponse response = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + emrsClient.startJobRunCalled(1); + + CreateAsyncQueryResponse resp2 = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + emrsClient.startJobRunCalled(2); + } + @Test public void createAsyncQueryCreateJobWithCorrectParameters() { LocalEMRSClient emrsClient = new LocalEMRSClient(); @@ -431,14 +474,13 @@ public void createSessionMoreThanLimitFailed() { setSessionState(first.getSessionId(), SessionState.RUNNING); // 2. create async query without session. - IllegalArgumentException exception = + ConcurrencyLimitExceededException exception = assertThrows( - IllegalArgumentException.class, + ConcurrencyLimitExceededException.class, () -> asyncQueryExecutorService.createAsyncQuery( new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null))); - assertEquals( - "The maximum number of active sessions can be supported is 1", exception.getMessage()); + assertEquals("domain concurrent active session can not exceed 1", exception.getMessage()); } // https://github.com/opensearch-project/sql/issues/2360 @@ -535,6 +577,32 @@ public void datasourceNameIncludeUppercase() { "--conf spark.sql.catalog.TESTS3=org.opensearch.sql.FlintDelegatingSessionCatalog")); } + @Test + public void concurrentSessionLimitIsDomainLevel() { + LocalEMRSClient emrsClient = new LocalEMRSClient(); + AsyncQueryExecutorService asyncQueryExecutorService = + createAsyncQueryExecutorService(emrsClient); + + // only allow one session in domain. + setSessionLimit(1); + + // 1. create async query. + CreateAsyncQueryResponse first = + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null)); + assertNotNull(first.getSessionId()); + setSessionState(first.getSessionId(), SessionState.RUNNING); + + // 2. create async query without session. + ConcurrencyLimitExceededException exception = + assertThrows( + ConcurrencyLimitExceededException.class, + () -> + asyncQueryExecutorService.createAsyncQuery( + new CreateAsyncQueryRequest("select 1", DSOTHER, LangType.SQL, null))); + assertEquals("domain concurrent active session can not exceed 1", exception.getMessage()); + } + private DataSourceServiceImpl createDataSourceService() { String masterKey = "a57d991d9b573f75b9bba1df"; DataSourceMetadataStorage dataSourceMetadataStorage = @@ -562,7 +630,8 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService( jobExecutionResponseReader, new FlintIndexMetadataReaderImpl(client), client, - new SessionManager(stateStore, emrServerlessClient, pluginSettings)); + new SessionManager(stateStore, emrServerlessClient, pluginSettings), + new DefaultLeaseManager(pluginSettings, stateStore)); return new AsyncQueryExecutorServiceImpl( asyncQueryJobMetadataStorageService, sparkQueryDispatcher, diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index aaef4db6b8..7663ece350 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -74,6 +74,7 @@ import org.opensearch.sql.spark.flint.FlintIndexMetadata; import org.opensearch.sql.spark.flint.FlintIndexMetadataReader; import org.opensearch.sql.spark.flint.FlintIndexType; +import org.opensearch.sql.spark.leasemanager.LeaseManager; import org.opensearch.sql.spark.response.JobExecutionResponseReader; import org.opensearch.sql.spark.rest.model.LangType; @@ -93,6 +94,8 @@ public class SparkQueryDispatcherTest { @Mock private SessionManager sessionManager; + @Mock private LeaseManager leaseManager; + @Mock(answer = RETURNS_DEEP_STUBS) private Session session; @@ -115,7 +118,8 @@ void setUp() { jobExecutionResponseReader, flintIndexMetadataReader, openSearchClient, - sessionManager); + sessionManager, + leaseManager); } @Test @@ -636,6 +640,7 @@ void testDispatchShowMVQuery() { HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); + tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "SHOW MATERIALIZED VIEW IN mys3.default"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString( @@ -672,7 +677,7 @@ void testDispatchShowMVQuery() { StartJobRequest expected = new StartJobRequest( query, - "TEST_CLUSTER:index-query", + "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, sparkSubmitParameters, @@ -690,6 +695,7 @@ void testRefreshIndexQuery() { HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); + tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "REFRESH SKIPPING INDEX ON my_glue.default.http_logs"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString( @@ -702,7 +708,7 @@ void testRefreshIndexQuery() { when(emrServerlessClient.startJobRun( new StartJobRequest( query, - "TEST_CLUSTER:index-query", + "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, sparkSubmitParameters, @@ -726,7 +732,7 @@ void testRefreshIndexQuery() { StartJobRequest expected = new StartJobRequest( query, - "TEST_CLUSTER:index-query", + "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, sparkSubmitParameters, @@ -744,6 +750,7 @@ void testDispatchDescribeIndexQuery() { HashMap tags = new HashMap<>(); tags.put(DATASOURCE_TAG_KEY, "my_glue"); tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME); + tags.put(JOB_TYPE_TAG_KEY, JobType.BATCH.getText()); String query = "DESCRIBE SKIPPING INDEX ON mys3.default.http_logs"; String sparkSubmitParameters = constructExpectedSparkSubmitParameterString( @@ -780,7 +787,7 @@ void testDispatchDescribeIndexQuery() { StartJobRequest expected = new StartJobRequest( query, - "TEST_CLUSTER:index-query", + "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, sparkSubmitParameters, @@ -977,15 +984,6 @@ void testGetQueryResponseWithStatementNotExist() { @Test void testGetQueryResponseWithSuccess() { - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - dataSourceService, - dataSourceUserAuthorizationHelper, - jobExecutionResponseReader, - flintIndexMetadataReader, - openSearchClient, - sessionManager); JSONObject queryResult = new JSONObject(); Map resultMap = new HashMap<>(); resultMap.put(STATUS_FIELD, "SUCCESS"); @@ -1013,16 +1011,6 @@ void testGetQueryResponseWithSuccess() { // todo. refactor query process logic in plugin. @Test void testGetQueryResponseOfDropIndex() { - SparkQueryDispatcher sparkQueryDispatcher = - new SparkQueryDispatcher( - emrServerlessClient, - dataSourceService, - dataSourceUserAuthorizationHelper, - jobExecutionResponseReader, - flintIndexMetadataReader, - openSearchClient, - sessionManager); - String jobId = new SparkQueryDispatcher.DropIndexResult(JobRunState.SUCCESS.toString()).toJobId(); diff --git a/spark/src/test/java/org/opensearch/sql/spark/leasemanager/DefaultLeaseManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/leasemanager/DefaultLeaseManagerTest.java new file mode 100644 index 0000000000..47111c3a38 --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/leasemanager/DefaultLeaseManagerTest.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.leasemanager; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.spark.dispatcher.model.JobType; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.leasemanager.model.LeaseRequest; + +@ExtendWith(MockitoExtension.class) +class DefaultLeaseManagerTest { + @Mock private Settings settings; + + @Mock private StateStore stateStore; + + @Test + public void concurrentSessionRuleOnlyApplyToInteractiveQuery() { + new DefaultLeaseManager(settings, stateStore).borrow(new LeaseRequest(JobType.BATCH, "mys3")); + } +}