diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java index 5c115f0db8..359742f218 100644 --- a/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java @@ -12,6 +12,7 @@ import java.util.List; import java.util.stream.Collectors; import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.optimizer.rule.EliminateNested; import org.opensearch.sql.planner.optimizer.rule.MergeFilterAndFilter; import org.opensearch.sql.planner.optimizer.rule.PushFilterUnderSort; import org.opensearch.sql.planner.optimizer.rule.read.CreateTableScanBuilder; @@ -58,7 +59,11 @@ public static LogicalPlanOptimizer create() { TableScanPushDown.PUSH_DOWN_HIGHLIGHT, TableScanPushDown.PUSH_DOWN_NESTED, TableScanPushDown.PUSH_DOWN_PROJECT, - new CreateTableWriteBuilder())); + new CreateTableWriteBuilder(), + /* + * Phase 3: Transformations for others + */ + new EliminateNested())); } /** Optimize {@link LogicalPlan}. */ diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/EliminateNested.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/EliminateNested.java new file mode 100644 index 0000000000..f4f63717f6 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/EliminateNested.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.optimizer.rule; + +import static com.facebook.presto.matching.Pattern.typeOf; +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; + +import com.facebook.presto.matching.Capture; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import lombok.Getter; +import lombok.experimental.Accessors; +import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalNested; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.optimizer.Rule; + +/** + * Eliminate LogicalNested if its child is LogicalAggregation.
+ * LogicalNested - LogicalAggregation - Child --> LogicalAggregation - Child
+ * E.g. count(nested(foo.bar, foo)) + */ +public class EliminateNested implements Rule { + + private final Capture capture; + + @Accessors(fluent = true) + @Getter + private final Pattern pattern; + + public EliminateNested() { + this.capture = Capture.newCapture(); + this.pattern = + typeOf(LogicalNested.class) + .with(source().matching(typeOf(LogicalAggregation.class).capturedAs(capture))); + } + + @Override + public LogicalPlan apply(LogicalNested plan, Captures captures) { + return captures.get(capture); + } +} diff --git a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java index c25e415cfa..5a12b429ad 100644 --- a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java @@ -5,6 +5,7 @@ package org.opensearch.sql.planner.optimizer; +import static java.util.Collections.emptyList; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; @@ -122,6 +123,43 @@ void multiple_filter_should_eventually_be_merged() { DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1)))))); } + @Test + void eliminate_nested_in_aggregation() { + List> nestedArgs = + ImmutableList.of( + Map.of( + "field", new ReferenceExpression("message.info", STRING), + "path", new ReferenceExpression("message", STRING))); + List projectList = + ImmutableList.of( + DSL.named( + "count(nested(message.info, message))", + DSL.ref("count(nested(message.info, message))", INTEGER))); + + assertEquals( + aggregation( + tableScanBuilder, + ImmutableList.of( + DSL.named( + "count(nested(message.info, message))", + DSL.count( + DSL.nested(DSL.ref("message.info", STRING), DSL.ref("message", ARRAY))))), + emptyList()), + optimize( + nested( + aggregation( + relation("schema", table), + ImmutableList.of( + DSL.named( + "count(nested(message.info, message))", + DSL.count( + DSL.nested( + DSL.ref("message.info", STRING), DSL.ref("message", ARRAY))))), + emptyList()), + nestedArgs, + projectList))); + } + @Test void default_table_scan_builder_should_not_push_down_anything() { LogicalPlan[] plans = { diff --git a/docs/user/dql/aggregations.rst b/docs/user/dql/aggregations.rst index 42db4cdb4f..f81800f558 100644 --- a/docs/user/dql/aggregations.rst +++ b/docs/user/dql/aggregations.rst @@ -126,7 +126,7 @@ The aggregation could has expression as arguments:: | M | 202 | +----------+--------+ -COUNT Aggregations +COUNT Aggregation ------------------ Besides regular identifiers, ``COUNT`` aggregate function also accepts arguments such as ``*`` or literals like ``1``. The meaning of these different forms are as follows: @@ -135,6 +135,30 @@ Besides regular identifiers, ``COUNT`` aggregate function also accepts arguments 2. ``COUNT(*)`` will count the number of all its input rows. 3. ``COUNT(1)`` is same as ``COUNT(*)`` because any non-null literal will count. +NESTED Aggregation +------------------ +The nested aggregation lets you aggregate on fields inside a nested object. You can use ``nested`` function to return a nested field, ref :ref:`nested function `. + +To understand why we need nested aggregations, read `Nested Aggregations DSL doc `_ to get more details. + +The nested aggregation could be expression:: + + os> SELECT count(nested(message.info, message)) FROM nested; + fetched rows / total rows = 1/1 + +----------------------------------------+ + | count(nested(message.info, message)) | + |----------------------------------------| + | 2 | + +----------------------------------------+ + + os> SELECT count(nested(message.info)) FROM nested; + fetched rows / total rows = 1/1 + +-------------------------------+ + | count(nested(message.info)) | + |-------------------------------| + | 2 | + +-------------------------------+ + Aggregation Functions ===================== diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index b445fffa63..c636823a47 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -4419,6 +4419,7 @@ Another example to show how to set custom values for the optional parameters:: +-------------------------------------------+ +.. _nested_function_label: NESTED ------ diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java index 96bbae94e5..ccd8b72c4c 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java @@ -21,7 +21,6 @@ import org.json.JSONArray; import org.json.JSONObject; import org.junit.Test; -import org.junit.jupiter.api.Disabled; import org.opensearch.sql.legacy.SQLIntegTestCase; public class NestedIT extends SQLIntegTestCase { @@ -75,20 +74,18 @@ public void nested_function_in_select_test() { rows("zz", "bb", 6)); } - // Has to be tested with JSON format when https://github.com/opensearch-project/sql/issues/1317 - // gets resolved - @Disabled // TODO fix me when aggregation is supported + @Test public void nested_function_in_an_aggregate_function_in_select_test() { String query = - "SELECT sum(nested(message.dayOfWeek)) FROM " + TEST_INDEX_NESTED_TYPE_WITHOUT_ARRAYS; + "SELECT sum(nested(message.dayOfWeek, message)) FROM " + + TEST_INDEX_NESTED_TYPE_WITHOUT_ARRAYS; JSONObject result = executeJdbcRequest(query); verifyDataRows(result, rows(14)); } - // TODO Enable me when nested aggregation is supported - @Disabled + @Test public void nested_function_with_arrays_in_an_aggregate_function_in_select_test() { - String query = "SELECT sum(nested(message.dayOfWeek)) FROM " + TEST_INDEX_NESTED_TYPE; + String query = "SELECT sum(nested(message.dayOfWeek, message)) FROM " + TEST_INDEX_NESTED_TYPE; JSONObject result = executeJdbcRequest(query); verifyDataRows(result, rows(19)); } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java index 4df9537973..6da7028760 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java @@ -21,6 +21,7 @@ import lombok.RequiredArgsConstructor; import org.opensearch.search.aggregations.Aggregation; import org.opensearch.search.aggregations.Aggregations; +import org.opensearch.search.aggregations.bucket.nested.InternalNested; import org.opensearch.sql.common.utils.StringUtils; /** Parse multiple metrics in one bucket. */ @@ -44,6 +45,9 @@ public MetricParserHelper(List metricParserList) { public Map parse(Aggregations aggregations) { Map resultMap = new HashMap<>(); for (Aggregation aggregation : aggregations) { + if (aggregation instanceof InternalNested) { + aggregation = ((InternalNested) aggregation).getAggregations().asList().getFirst(); + } if (metricParserMap.containsKey(aggregation.getName())) { resultMap.putAll(metricParserMap.get(aggregation.getName()).parse(aggregation)); } else { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index c6afdb8511..a3d96fdb39 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -155,7 +155,8 @@ public Integer getMaxResultWindow() { @Override public PhysicalPlan implement(LogicalPlan plan) { // TODO: Leave it here to avoid impact Prometheus and AD operators. Need to move to Planner. - return plan.accept(new OpenSearchDefaultImplementor(client), null); + PhysicalPlan pp = plan.accept(new OpenSearchDefaultImplementor(client), null); + return pp; } @Override diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/AggregationBuilderHelper.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/AggregationBuilderHelper.java index 7dd02d82d0..a93222f79d 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/AggregationBuilderHelper.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/AggregationBuilderHelper.java @@ -9,13 +9,17 @@ import static org.opensearch.script.Script.DEFAULT_SCRIPT_TYPE; import static org.opensearch.sql.opensearch.storage.script.ExpressionScriptEngine.EXPRESSION_LANG_NAME; +import java.util.List; import java.util.function.Function; import lombok.RequiredArgsConstructor; import org.opensearch.script.Script; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @@ -25,18 +29,59 @@ public class AggregationBuilderHelper { private final ExpressionSerializer serializer; + /** Build Composite Builder from Expression. */ + public T buildComposite( + Expression expression, Function fieldBuilder, Function scriptBuilder) { + if (expression instanceof ReferenceExpression) { + String fieldName = ((ReferenceExpression) expression).getAttr(); + return fieldBuilder.apply( + OpenSearchTextType.convertTextToKeyword(fieldName, expression.type())); + } else if (expression instanceof FunctionExpression + || expression instanceof LiteralExpression) { + return scriptBuilder.apply( + new Script( + DEFAULT_SCRIPT_TYPE, + EXPRESSION_LANG_NAME, + serializer.serialize(expression), + emptyMap())); + } else { + throw new IllegalStateException( + String.format("bucket aggregation doesn't support " + "expression %s", expression)); + } + } + /** * Build AggregationBuilder from Expression. * * @param expression Expression * @return AggregationBuilder */ - public T build( - Expression expression, Function fieldBuilder, Function scriptBuilder) { + public AggregationBuilder build( + Expression expression, + Function fieldBuilder, + Function scriptBuilder) { if (expression instanceof ReferenceExpression) { String fieldName = ((ReferenceExpression) expression).getAttr(); return fieldBuilder.apply( OpenSearchTextType.convertTextToKeyword(fieldName, expression.type())); + } else if (expression instanceof FunctionExpression + && ((FunctionExpression) expression) + .getFunctionName() + .equals(BuiltinFunctionName.NESTED.getName())) { + List args = ((FunctionExpression) expression).getArguments(); + // NestedAnalyzer has validated the number of arguments. + // Here we can safety invoke args.getFirst(). + String fieldName = ((ReferenceExpression) args.getFirst()).getAttr(); + if (fieldName.contains("*")) { + throw new IllegalArgumentException("Nested aggregation doesn't support multiple fields"); + } + String path = + args.size() == 2 + ? ((ReferenceExpression) args.get(1)).getAttr() + : fieldName.substring(0, fieldName.lastIndexOf(".")); + AggregationBuilder subAgg = + fieldBuilder.apply(OpenSearchTextType.convertTextToKeyword(fieldName, expression.type())); + return AggregationBuilders.nested(path + "_nested", path).subAggregation(subAgg); } else if (expression instanceof FunctionExpression || expression instanceof LiteralExpression) { return scriptBuilder.apply( diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilder.java index ff66ec425a..2b3c6f0dd2 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilder.java @@ -68,7 +68,8 @@ private CompositeValuesSourceBuilder buildCompositeValuesSourceBuilder( if (List.of(TIMESTAMP, TIME, DATE).contains(expr.getDelegated().type())) { sourceBuilder.userValuetypeHint(ValueType.LONG); } - return helper.build(expr.getDelegated(), sourceBuilder::field, sourceBuilder::script); + return helper.buildComposite( + expr.getDelegated(), sourceBuilder::field, sourceBuilder::script); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index 779fe2f1c9..f93f0ee5d1 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -15,6 +15,7 @@ import org.opensearch.search.aggregations.AggregationBuilders; import org.opensearch.search.aggregations.AggregatorFactories; import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder; +import org.opensearch.search.aggregations.bucket.nested.NestedAggregationBuilder; import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; import org.opensearch.search.aggregations.metrics.ExtendedStats; import org.opensearch.search.aggregations.metrics.PercentilesAggregationBuilder; @@ -179,7 +180,7 @@ private Pair make( Expression condition, String name, MetricParser parser) { - ValuesSourceAggregationBuilder aggregationBuilder = + AggregationBuilder aggregationBuilder = helper.build(expression, builder::field, builder::script); if (condition != null) { return Pair.of( @@ -196,7 +197,7 @@ private Pair make( Expression condition, String name, MetricParser parser) { - CardinalityAggregationBuilder aggregationBuilder = + AggregationBuilder aggregationBuilder = helper.build(expression, builder::field, builder::script); if (condition != null) { return Pair.of( @@ -234,12 +235,23 @@ private Pair make( Expression condition, String name, MetricParser parser) { - PercentilesAggregationBuilder aggregationBuilder = + AggregationBuilder aggregationBuilder = helper.build(expression, builder::field, builder::script); + PercentilesAggregationBuilder percentilesBuilder; + if (aggregationBuilder instanceof NestedAggregationBuilder) { + percentilesBuilder = + aggregationBuilder.getSubAggregations().stream() + .filter(PercentilesAggregationBuilder.class::isInstance) + .map(a -> (PercentilesAggregationBuilder) a) + .findFirst() + .orElseThrow(); + } else { + percentilesBuilder = (PercentilesAggregationBuilder) aggregationBuilder; + } if (compression != null) { - aggregationBuilder.compression(compression.valueOf().doubleValue()); + percentilesBuilder.compression(compression.valueOf().doubleValue()); } - aggregationBuilder.percentiles(percent.valueOf().doubleValue()); + percentilesBuilder.percentiles(percent.valueOf().doubleValue()); if (condition != null) { return Pair.of( makeFilterAggregation(aggregationBuilder, condition, name), diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/AggregationResponseUtils.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/AggregationResponseUtils.java index ccdfdce7a4..f230bae5a8 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/AggregationResponseUtils.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/AggregationResponseUtils.java @@ -26,6 +26,8 @@ import org.opensearch.search.aggregations.bucket.histogram.HistogramAggregationBuilder; import org.opensearch.search.aggregations.bucket.histogram.ParsedDateHistogram; import org.opensearch.search.aggregations.bucket.histogram.ParsedHistogram; +import org.opensearch.search.aggregations.bucket.nested.NestedAggregationBuilder; +import org.opensearch.search.aggregations.bucket.nested.ParsedNested; import org.opensearch.search.aggregations.bucket.terms.DoubleTerms; import org.opensearch.search.aggregations.bucket.terms.LongTerms; import org.opensearch.search.aggregations.bucket.terms.ParsedDoubleTerms; @@ -87,6 +89,8 @@ public class AggregationResponseUtils { .put( TopHitsAggregationBuilder.NAME, (p, c) -> ParsedTopHits.fromXContent(p, (String) c)) + .put( + NestedAggregationBuilder.NAME, (p, c) -> ParsedNested.fromXContent(p, (String) c)) .build() .entrySet() .stream() diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilderTest.java index 4250b3297f..c871bf3b4f 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/BucketAggregationBuilderTest.java @@ -6,10 +6,12 @@ package org.opensearch.sql.opensearch.storage.script.aggregation.dsl; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.Mockito.when; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.data.type.ExprCoreType.UNKNOWN; import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.expression.DSL.named; import static org.opensearch.sql.expression.DSL.ref; @@ -34,10 +36,15 @@ import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceBuilder; import org.opensearch.search.aggregations.bucket.missing.MissingOrder; import org.opensearch.search.sort.SortOrder; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.ExpressionNodeVisitor; import org.opensearch.sql.expression.NamedExpression; +import org.opensearch.sql.expression.env.Environment; import org.opensearch.sql.expression.parse.ParseExpression; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.data.type.OpenSearchTextType; @@ -152,6 +159,38 @@ void terms_bucket_for_datetime_types_uses_long(ExprType dataType) { buildQuery(Arrays.asList(asc(named("date", ref("date", dataType)))))); } + /** Add this unit test in case build bucket aggregation with unsupported expression in future */ + @Test + void should_throw_exception_for_unsupported_expression() { + Expression unsupportedExpression = + new Expression() { + @Override + public ExprValue valueOf(Environment valueEnv) { + return ExprValueUtils.nullValue(); + } + + @Override + public ExprType type() { + return UNKNOWN; + } + + @Override + public T accept(ExpressionNodeVisitor visitor, C context) { + return null; + } + + @Override + public String toString() { + return "unknown"; + } + }; + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> buildQuery(Arrays.asList(asc(named("UNKNOWN", unsupportedExpression))))); + assertEquals("bucket aggregation doesn't support expression unknown", exception.getMessage()); + } + @SneakyThrows private String buildQuery( List> groupByExpressions) { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index 6d792dec25..8f82981f6c 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -15,6 +15,7 @@ import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.expression.DSL.named; +import static org.opensearch.sql.expression.DSL.nested; import static org.opensearch.sql.expression.DSL.ref; import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevPopulation; import static org.opensearch.sql.expression.aggregation.StdDevAggregator.stddevSample; @@ -145,6 +146,26 @@ void should_build_count_other_literal_aggregation() { named("count(1)", new CountAggregator(Arrays.asList(literal(1)), INTEGER))))); } + @Test + void should_build_max_with_literal_aggregation() { + var literal = literal(1); + when(serializer.serialize(literal)).thenReturn("mock-serialize"); + assertEquals( + format( + "{%n" + + " \"max(1)\" : {%n" + + " \"max\" : {%n" + + " \"script\" : {%n" + + " \"source\" : \"mock-serialize\",%n" + + " \"lang\" : \"opensearch_query_expression\"%n" + + " }%n" + + " }%n" + + " }%n" + + "}"), + buildQuery( + Arrays.asList(named("max(1)", new MaxAggregator(Arrays.asList(literal), INTEGER))))); + } + @Test void should_build_min_aggregation() { assertEquals( @@ -507,6 +528,242 @@ void should_throw_exception_for_unsupported_exception() { assertEquals("metric aggregation doesn't support expression age", exception.getMessage()); } + @Test + void should_build_nested_aggregation() { + assertEquals( + format( + "{%n" + + " \"message_nested\" : {%n" + + " \"nested\" : {%n" + + " \"path\" : \"message\"%n" + + " },%n" + + " \"aggregations\" : {%n" + + " \"count(nested(message.info, message))\" : {%n" + + " \"value_count\" : {%n" + + " \"field\" : \"message.info\"%n" + + " }%n" + + " }%n" + + " }%n" + + " }%n" + + "}"), + buildQuery( + Arrays.asList( + named( + "count(nested(message.info, message))", + new CountAggregator( + Arrays.asList(nested(ref("message.info", STRING), ref("message", ARRAY))), + INTEGER))))); + } + + @Test + void should_build_nested_aggregation_without_path() { + assertEquals( + format( + "{%n" + + " \"message_nested\" : {%n" + + " \"nested\" : {%n" + + " \"path\" : \"message\"%n" + + " },%n" + + " \"aggregations\" : {%n" + + " \"count(nested(message.info, message))\" : {%n" + + " \"value_count\" : {%n" + + " \"field\" : \"message.info\"%n" + + " }%n" + + " }%n" + + " }%n" + + " }%n" + + "}"), + buildQuery( + Arrays.asList( + named( + "count(nested(message.info, message))", + new CountAggregator( + Arrays.asList(nested(ref("message.info", STRING))), INTEGER))))); + } + + @Test + void should_build_nested_aggregation_cardinality() { + assertEquals( + format( + "{%n" + + " \"message_nested\" : {%n" + + " \"nested\" : {%n" + + " \"path\" : \"message\"%n" + + " },%n" + + " \"aggregations\" : {%n" + + " \"count(distinct nested(message.info, message))\" : {%n" + + " \"cardinality\" : {%n" + + " \"field\" : \"message.info\"%n" + + " }%n" + + " }%n" + + " }%n" + + " }%n" + + "}"), + buildQuery( + Arrays.asList( + named( + "count(distinct nested(message.info, message))", + new CountAggregator( + Arrays.asList( + nested(ref("message.info", STRING), ref("message", ARRAY))), + INTEGER) + .distinct(true))))); + } + + @Test + void should_build_nested_aggregation_filtered_cardinality() { + assertEquals( + format( + "{%n \"count(distinct nested(message.info, message)) filter(where age > 30)\" : {%n" + + " \"filter\" : {%n" + + " \"range\" : {%n" + + " \"age\" : {%n" + + " \"from\" : 30,%n" + + " \"to\" : null,%n" + + " \"include_lower\" : false,%n" + + " \"include_upper\" : true,%n" + + " \"boost\" : 1.0%n" + + " }%n" + + " }%n" + + " },%n" + + " \"aggregations\" : {%n" + + " \"message_nested\" : {%n" + + " \"nested\" : {%n" + + " \"path\" : \"message\"%n" + + " },%n" + + " \"aggregations\" : {%n" + + " \"count(distinct nested(message.info, message)) filter(where age >" + + " 30)\" : {%n" + + " \"cardinality\" : {%n" + + " \"field\" : \"message.info\"%n" + + " }%n" + + " }%n" + + " }%n" + + " }%n" + + " }%n" + + " }%n" + + "}"), + buildQuery( + Arrays.asList( + named( + "count(distinct nested(message.info, message)) filter(where age > 30)", + new CountAggregator( + Arrays.asList( + nested(ref("message.info", STRING), ref("message", ARRAY))), + INTEGER) + .condition(DSL.greater(ref("age", INTEGER), literal(30))) + .distinct(true))))); + } + + @Test + void should_build_nested_aggregation_nested_filtered_cardinality() { + assertEquals( + format( + "{%n \"count(distinct nested(message.info, message)) filter(where nested(message.age," + + " message) > 30)\" : {%n" + + " \"filter\" : {%n" + + " \"nested\" : {%n" + + " \"query\" : {%n" + + " \"range\" : {%n" + + " \"message.age\" : {%n" + + " \"from\" : 30,%n" + + " \"to\" : null,%n" + + " \"include_lower\" : false,%n" + + " \"include_upper\" : true,%n" + + " \"boost\" : 1.0%n" + + " }%n" + + " }%n" + + " },%n" + + " \"path\" : \"message\",%n" + + " \"ignore_unmapped\" : false,%n" + + " \"score_mode\" : \"none\",%n" + + " \"boost\" : 1.0%n" + + " }%n" + + " },%n" + + " \"aggregations\" : {%n" + + " \"message_nested\" : {%n" + + " \"nested\" : {%n" + + " \"path\" : \"message\"%n" + + " },%n" + + " \"aggregations\" : {%n" + + " \"count(distinct nested(message.info, message)) filter(where" + + " nested(message.age, message) > 30)\" : {%n" + + " \"cardinality\" : {%n" + + " \"field\" : \"message.info\"%n" + + " }%n" + + " }%n" + + " }%n" + + " }%n" + + " }%n" + + " }%n" + + "}"), + buildQuery( + Arrays.asList( + named( + "count(distinct nested(message.info, message)) filter(where nested(message.age," + + " message) > 30)", + new CountAggregator( + Arrays.asList( + nested(ref("message.info", STRING), ref("message", ARRAY))), + INTEGER) + .condition( + DSL.greater( + nested(ref("message.age", INTEGER), ref("message", ARRAY)), + literal(30))) + .distinct(true))))); + } + + @Test + void should_build_nested_aggregation_percentile() { + assertEquals( + format( + "{%n" + + " \"message_nested\" : {%n" + + " \"nested\" : {%n" + + " \"path\" : \"message\"%n" + + " },%n" + + " \"aggregations\" : {%n" + + " \"percentile(nested(message.info, message), 50)\" : {%n" + + " \"percentiles\" : {%n" + + " \"field\" : \"message.info\",%n" + + " \"percents\" : [ 50.0 ],%n" + + " \"keyed\" : true,%n" + + " \"tdigest\" : {%n" + + " \"compression\" : 100.0%n" + + " }%n" + + " }%n" + + " }%n" + + " }%n" + + " }%n" + + "}"), + buildQuery( + Arrays.asList( + named( + "percentile(nested(message.info, message), 50)", + new PercentileApproximateAggregator( + Arrays.asList( + nested(ref("message.info", STRING), ref("message", ARRAY)), + literal(50)), + DOUBLE))))); + } + + @Test + void should_throw_exception_for_nested_aggregation_on_star() { + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> + buildQuery( + Arrays.asList( + named( + "count(nested(message.*))", + new CountAggregator( + Arrays.asList( + nested(ref("message.*", STRING), ref("message", ARRAY))), + INTEGER))))); + assertEquals("Nested aggregation doesn't support multiple fields", exception.getMessage()); + } + @SneakyThrows private String buildQuery(List namedAggregatorList) { ObjectMapper objectMapper = new ObjectMapper();