From 93e313444b16f0e16d4632dccee15c7670d2c00f Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 17 Oct 2023 15:56:24 +0000 Subject: [PATCH] Minor Refactoring (#2308) Signed-off-by: Vamsi Manohar (cherry picked from commit 69572c8cca278a500db7710fb415cc58a2589c78) Signed-off-by: github-actions[bot] --- spark/src/main/antlr/SqlBaseParser.g4 | 2 +- .../sql/spark/client/StartJobRequest.java | 2 + .../dispatcher/SparkQueryDispatcherTest.java | 329 +++++++++--------- 3 files changed, 158 insertions(+), 175 deletions(-) diff --git a/spark/src/main/antlr/SqlBaseParser.g4 b/spark/src/main/antlr/SqlBaseParser.g4 index 6a6d39e96c..77a9108e06 100644 --- a/spark/src/main/antlr/SqlBaseParser.g4 +++ b/spark/src/main/antlr/SqlBaseParser.g4 @@ -967,7 +967,6 @@ primaryExpression | qualifiedName DOT ASTERISK #star | LEFT_PAREN namedExpression (COMMA namedExpression)+ RIGHT_PAREN #rowConstructor | LEFT_PAREN query RIGHT_PAREN #subqueryExpression - | IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN #identifierClause | functionName LEFT_PAREN (setQuantifier? argument+=functionArgument (COMMA argument+=functionArgument)*)? RIGHT_PAREN (FILTER LEFT_PAREN WHERE where=booleanExpression RIGHT_PAREN)? @@ -1196,6 +1195,7 @@ qualifiedNameList functionName : IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN + | identFunc=IDENTIFIER_KW // IDENTIFIER itself is also a valid function name. | qualifiedName | FILTER | LEFT diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java index c4382239a1..f57c8facee 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/client/StartJobRequest.java @@ -7,12 +7,14 @@ import java.util.Map; import lombok.Data; +import lombok.EqualsAndHashCode; /** * This POJO carries all the fields required for emr serverless job submission. Used as model in * {@link EMRServerlessClient} interface. */ @Data +@EqualsAndHashCode public class StartJobRequest { public static final Long DEFAULT_JOB_TIMEOUT = 120L; diff --git a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java index ab9761da36..8c0ecb2ea2 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/dispatcher/SparkQueryDispatcherTest.java @@ -41,6 +41,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.action.support.master.AcknowledgedResponse; @@ -78,6 +80,8 @@ public class SparkQueryDispatcherTest { private SparkQueryDispatcher sparkQueryDispatcher; + @Captor ArgumentCaptor startJobRequestArgumentCaptor; + @BeforeEach void setUp() { sparkQueryDispatcher = @@ -96,19 +100,21 @@ void testDispatchSelectQuery() { tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); String query = "select * from my_glue.default.http_logs"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }); when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), + sparkSubmitParameters, tags, false, any()))) @@ -125,23 +131,18 @@ void testDispatchSelectQuery() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), - tags, - false, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -153,20 +154,22 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); String query = "select * from my_glue.default.http_logs"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "basicauth", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AUTH_USERNAME, "username"); + put(FLINT_INDEX_STORE_AUTH_PASSWORD, "password"); + } + }); when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "basicauth", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AUTH_USERNAME, "username"); - put(FLINT_INDEX_STORE_AUTH_PASSWORD, "password"); - } - }), + sparkSubmitParameters, tags, false, any()))) @@ -183,24 +186,18 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "basicauth", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AUTH_USERNAME, "username"); - put(FLINT_INDEX_STORE_AUTH_PASSWORD, "password"); - } - }), - tags, - false, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -212,18 +209,20 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { tags.put("datasource", "my_glue"); tags.put("cluster", TEST_CLUSTER_NAME); String query = "select * from my_glue.default.http_logs"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "noauth", + new HashMap<>() { + { + } + }); when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "noauth", - new HashMap<>() { - { - } - }), + sparkSubmitParameters, tags, false, any()))) @@ -240,22 +239,18 @@ void testDispatchSelectQueryWithNoAuthIndexStoreDatasource() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "noauth", - new HashMap<>() { - { - } - }), - tags, - false, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -272,20 +267,22 @@ void testDispatchIndexQuery() { String query = "CREATE INDEX elb_and_requestUri ON my_glue.default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; + String sparkSubmitParameters = + withStructuredStreaming( + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + })); when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - })), + sparkSubmitParameters, tags, true, any()))) @@ -302,24 +299,18 @@ void testDispatchIndexQuery() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - })), - tags, - true, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + true, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -332,19 +323,21 @@ void testDispatchWithPPLQuery() { tags.put("cluster", TEST_CLUSTER_NAME); String query = "source = my_glue.default.http_logs"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }); when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), + sparkSubmitParameters, tags, false, any()))) @@ -361,23 +354,18 @@ void testDispatchWithPPLQuery() { LangType.PPL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), - tags, - false, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -390,19 +378,21 @@ void testDispatchQueryWithoutATableAndDataSourceName() { tags.put("cluster", TEST_CLUSTER_NAME); String query = "show tables"; + String sparkSubmitParameters = + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + }); when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:non-index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), + sparkSubmitParameters, tags, false, any()))) @@ -419,23 +409,18 @@ void testDispatchQueryWithoutATableAndDataSourceName() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:non-index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - }), - tags, - false, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:non-index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + false, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -453,20 +438,22 @@ void testDispatchIndexQueryWithoutADatasourceName() { String query = "CREATE INDEX elb_and_requestUri ON default.http_logs(l_orderkey, l_quantity) WITH" + " (auto_refresh = true)"; + String sparkSubmitParameters = + withStructuredStreaming( + constructExpectedSparkSubmitParameterString( + "sigv4", + new HashMap<>() { + { + put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); + } + })); when(emrServerlessClient.startJobRun( new StartJobRequest( query, "TEST_CLUSTER:index-query", EMRS_APPLICATION_ID, EMRS_EXECUTION_ROLE, - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - })), + sparkSubmitParameters, tags, true, any()))) @@ -483,24 +470,18 @@ void testDispatchIndexQueryWithoutADatasourceName() { LangType.SQL, EMRS_EXECUTION_ROLE, TEST_CLUSTER_NAME)); - verify(emrServerlessClient, times(1)) - .startJobRun( - new StartJobRequest( - query, - "TEST_CLUSTER:index-query", - EMRS_APPLICATION_ID, - EMRS_EXECUTION_ROLE, - withStructuredStreaming( - constructExpectedSparkSubmitParameterString( - "sigv4", - new HashMap<>() { - { - put(FLINT_INDEX_STORE_AWSREGION_KEY, "eu-west-1"); - } - })), - tags, - true, - any())); + verify(emrServerlessClient, times(1)).startJobRun(startJobRequestArgumentCaptor.capture()); + StartJobRequest expected = + new StartJobRequest( + query, + "TEST_CLUSTER:index-query", + EMRS_APPLICATION_ID, + EMRS_EXECUTION_ROLE, + sparkSubmitParameters, + tags, + true, + null); + Assertions.assertEquals(expected, startJobRequestArgumentCaptor.getValue()); Assertions.assertEquals(EMR_JOB_ID, dispatchQueryResponse.getJobId()); Assertions.assertFalse(dispatchQueryResponse.isDropIndexQuery()); verifyNoInteractions(flintIndexMetadataReader); @@ -905,8 +886,8 @@ private String constructExpectedSparkSubmitParameterString( + " --conf" + " spark.hive.metastore.glue.role.arn=arn:aws:iam::924196221507:role/FlintOpensearchServiceRole" + " --conf spark.sql.catalog.my_glue=org.opensearch.sql.FlintDelegatingSessionCatalog " - + authParamConfigBuilder - + " --conf spark.flint.datasource.name=my_glue "; + + " --conf spark.flint.datasource.name=my_glue " + + authParamConfigBuilder; } private String withStructuredStreaming(String parameters) {