Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support nested aggregation #2814

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.util.List;
import java.util.stream.Collectors;
import org.opensearch.sql.planner.logical.LogicalPlan;
import org.opensearch.sql.planner.optimizer.rule.EliminateNested;
import org.opensearch.sql.planner.optimizer.rule.MergeFilterAndFilter;
import org.opensearch.sql.planner.optimizer.rule.PushFilterUnderSort;
import org.opensearch.sql.planner.optimizer.rule.read.CreateTableScanBuilder;
Expand Down Expand Up @@ -58,7 +59,11 @@ public static LogicalPlanOptimizer create() {
TableScanPushDown.PUSH_DOWN_HIGHLIGHT,
TableScanPushDown.PUSH_DOWN_NESTED,
TableScanPushDown.PUSH_DOWN_PROJECT,
new CreateTableWriteBuilder()));
new CreateTableWriteBuilder(),
/*
* Phase 3: Transformations for others
*/
new EliminateNested()));
}

/** Optimize {@link LogicalPlan}. */
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.planner.optimizer.rule;

import static com.facebook.presto.matching.Pattern.typeOf;
import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source;

import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import lombok.Getter;
import lombok.experimental.Accessors;
import org.opensearch.sql.planner.logical.LogicalAggregation;
import org.opensearch.sql.planner.logical.LogicalNested;
import org.opensearch.sql.planner.logical.LogicalPlan;
import org.opensearch.sql.planner.optimizer.Rule;

/**
 * Optimizer rule that removes a {@link LogicalNested} node whose direct child is a {@link
 * LogicalAggregation}, leaving the aggregation in its place.<br>
 * LogicalNested - LogicalAggregation - Child --> LogicalAggregation - Child<br>
 * E.g. count(nested(foo.bar, foo))
 */
public class EliminateNested implements Rule<LogicalNested> {

  /** Handle used to read the matched aggregation child back out of the captures. */
  private final Capture<LogicalAggregation> aggregationCapture;

  /** Pattern selecting a LogicalNested whose source is a LogicalAggregation. */
  @Accessors(fluent = true)
  @Getter
  private final Pattern<LogicalNested> pattern;

  /** Creates the rule together with its capture and matching pattern. */
  public EliminateNested() {
    Capture<LogicalAggregation> childCapture = Capture.newCapture();
    Pattern<LogicalAggregation> aggregationChild =
        typeOf(LogicalAggregation.class).capturedAs(childCapture);

    this.aggregationCapture = childCapture;
    this.pattern = typeOf(LogicalNested.class).with(source().matching(aggregationChild));
  }

  /**
   * Replaces the matched nested node with the aggregation captured beneath it.
   *
   * @param plan the matched LogicalNested node (not used; it is the node being dropped)
   * @param captures captures produced while matching {@link #pattern()}
   * @return the captured LogicalAggregation
   */
  @Override
  public LogicalPlan apply(LogicalNested plan, Captures captures) {
    LogicalAggregation aggregation = captures.get(aggregationCapture);
    return aggregation;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.sql.planner.optimizer;

import static java.util.Collections.emptyList;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
Expand Down Expand Up @@ -122,6 +123,43 @@ void multiple_filter_should_eventually_be_merged() {
DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))))));
}

/** A LogicalNested placed directly above a LogicalAggregation is removed by the optimizer. */
@Test
void eliminate_nested_in_aggregation() {
  final String aggregatorName = "count(nested(message.info, message))";

  // Arguments and projection carried by the LogicalNested node under test.
  ReferenceExpression nestedField = new ReferenceExpression("message.info", STRING);
  ReferenceExpression nestedPath = new ReferenceExpression("message", STRING);
  List<Map<String, ReferenceExpression>> nestedArgs =
      ImmutableList.of(Map.of("field", nestedField, "path", nestedPath));
  List<NamedExpression> projectList =
      ImmutableList.of(DSL.named(aggregatorName, DSL.ref(aggregatorName, INTEGER)));

  // Expected result: only the aggregation remains, sitting on the table scan builder.
  LogicalPlan expected =
      aggregation(
          tableScanBuilder,
          ImmutableList.of(
              DSL.named(
                  aggregatorName,
                  DSL.count(
                      DSL.nested(DSL.ref("message.info", STRING), DSL.ref("message", ARRAY))))),
          emptyList());

  // Input plan: nested -> aggregation -> relation.
  LogicalPlan nestedOverAggregation =
      nested(
          aggregation(
              relation("schema", table),
              ImmutableList.of(
                  DSL.named(
                      aggregatorName,
                      DSL.count(
                          DSL.nested(
                              DSL.ref("message.info", STRING), DSL.ref("message", ARRAY))))),
              emptyList()),
          nestedArgs,
          projectList);

  assertEquals(expected, optimize(nestedOverAggregation));
}

@Test
void default_table_scan_builder_should_not_push_down_anything() {
LogicalPlan[] plans = {
Expand Down
26 changes: 25 additions & 1 deletion docs/user/dql/aggregations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ The aggregation could has expression as arguments::
| M | 202 |
+----------+--------+

COUNT Aggregations
COUNT Aggregation
------------------

Besides regular identifiers, the ``COUNT`` aggregate function also accepts arguments such as ``*`` or literals like ``1``. The meanings of these different forms are as follows:
Expand All @@ -135,6 +135,30 @@ Besides regular identifiers, ``COUNT`` aggregate function also accepts arguments
2. ``COUNT(*)`` will count the number of all its input rows.
3. ``COUNT(1)`` is same as ``COUNT(*)`` because any non-null literal will count.

NESTED Aggregation
------------------
The nested aggregation lets you aggregate on fields inside a nested object. You can use the ``nested`` function to return a nested field; see :ref:`nested function <nested_function_label>`.

To understand why nested aggregations are needed, see the `Nested Aggregations DSL doc <https://opensearch.org/docs/latest/aggregations/bucket/nested/>`_ for more details.

A nested aggregation can take an expression as its argument::

os> SELECT count(nested(message.info, message)) FROM nested;
fetched rows / total rows = 1/1
+----------------------------------------+
| count(nested(message.info, message)) |
|----------------------------------------|
| 2 |
+----------------------------------------+

os> SELECT count(nested(message.info)) FROM nested;
fetched rows / total rows = 1/1
+-------------------------------+
| count(nested(message.info)) |
|-------------------------------|
| 2 |
+-------------------------------+

Aggregation Functions
=====================

Expand Down
1 change: 1 addition & 0 deletions docs/user/dql/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4419,6 +4419,7 @@ Another example to show how to set custom values for the optional parameters::
+-------------------------------------------+


.. _nested_function_label:

NESTED
------

Expand Down
13 changes: 5 additions & 8 deletions integ-test/src/test/java/org/opensearch/sql/sql/NestedIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.json.JSONArray;
import org.json.JSONObject;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;
import org.opensearch.sql.legacy.SQLIntegTestCase;

public class NestedIT extends SQLIntegTestCase {
Expand Down Expand Up @@ -75,20 +74,18 @@ public void nested_function_in_select_test() {
rows("zz", "bb", 6));
}

// Has to be tested with JSON format when https://github.com/opensearch-project/sql/issues/1317
// gets resolved
@Disabled // TODO fix me when aggregation is supported
@Test
public void nested_function_in_an_aggregate_function_in_select_test() {
String query =
"SELECT sum(nested(message.dayOfWeek)) FROM " + TEST_INDEX_NESTED_TYPE_WITHOUT_ARRAYS;
"SELECT sum(nested(message.dayOfWeek, message)) FROM "
+ TEST_INDEX_NESTED_TYPE_WITHOUT_ARRAYS;
JSONObject result = executeJdbcRequest(query);
verifyDataRows(result, rows(14));
}

// TODO Enable me when nested aggregation is supported
@Disabled
@Test
public void nested_function_with_arrays_in_an_aggregate_function_in_select_test() {
String query = "SELECT sum(nested(message.dayOfWeek)) FROM " + TEST_INDEX_NESTED_TYPE;
String query = "SELECT sum(nested(message.dayOfWeek, message)) FROM " + TEST_INDEX_NESTED_TYPE;
JSONObject result = executeJdbcRequest(query);
verifyDataRows(result, rows(19));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import lombok.RequiredArgsConstructor;
import org.opensearch.search.aggregations.Aggregation;
import org.opensearch.search.aggregations.Aggregations;
import org.opensearch.search.aggregations.bucket.nested.InternalNested;
import org.opensearch.sql.common.utils.StringUtils;

/** Parse multiple metrics in one bucket. */
Expand All @@ -44,6 +45,9 @@ public MetricParserHelper(List<MetricParser> metricParserList) {
public Map<String, Object> parse(Aggregations aggregations) {
Map<String, Object> resultMap = new HashMap<>();
for (Aggregation aggregation : aggregations) {
if (aggregation instanceof InternalNested) {
aggregation = ((InternalNested) aggregation).getAggregations().asList().getFirst();
}
if (metricParserMap.containsKey(aggregation.getName())) {
resultMap.putAll(metricParserMap.get(aggregation.getName()).parse(aggregation));
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ public Integer getMaxResultWindow() {
@Override
public PhysicalPlan implement(LogicalPlan plan) {
// TODO: Leave it here to avoid impact Prometheus and AD operators. Need to move to Planner.
return plan.accept(new OpenSearchDefaultImplementor(client), null);
PhysicalPlan pp = plan.accept(new OpenSearchDefaultImplementor(client), null);
return pp;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@
import static org.opensearch.script.Script.DEFAULT_SCRIPT_TYPE;
import static org.opensearch.sql.opensearch.storage.script.ExpressionScriptEngine.EXPRESSION_LANG_NAME;

import java.util.List;
import java.util.function.Function;
import lombok.RequiredArgsConstructor;
import org.opensearch.script.Script;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.LiteralExpression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.opensearch.data.type.OpenSearchTextType;
import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer;

Expand All @@ -25,18 +29,59 @@ public class AggregationBuilderHelper {

private final ExpressionSerializer serializer;

/**
 * Build Composite Builder from Expression.
 *
 * <p>A {@link ReferenceExpression} is handed to {@code fieldBuilder} using the name returned by
 * {@code OpenSearchTextType.convertTextToKeyword}. A {@link FunctionExpression} or {@link
 * LiteralExpression} is serialized and handed to {@code scriptBuilder} wrapped in a {@link
 * Script}.
 *
 * @param expression expression to build from
 * @param fieldBuilder builder invoked with the field name for reference expressions
 * @param scriptBuilder builder invoked with the script for function and literal expressions
 * @param <T> type produced by the builders
 * @return result of whichever builder was applied
 * @throws IllegalStateException if the expression is of any other kind
 */
public <T> T buildComposite(
    Expression expression, Function<String, T> fieldBuilder, Function<Script, T> scriptBuilder) {
  // Plain field reference: delegate to the field-based builder.
  if (expression instanceof ReferenceExpression) {
    ReferenceExpression reference = (ReferenceExpression) expression;
    String keywordField =
        OpenSearchTextType.convertTextToKeyword(reference.getAttr(), expression.type());
    return fieldBuilder.apply(keywordField);
  }

  boolean scriptable =
      expression instanceof FunctionExpression || expression instanceof LiteralExpression;
  if (!scriptable) {
    throw new IllegalStateException(
        String.format("bucket aggregation doesn't support expression %s", expression));
  }

  // Functions and literals are shipped as a serialized expression script.
  Script script =
      new Script(
          DEFAULT_SCRIPT_TYPE, EXPRESSION_LANG_NAME, serializer.serialize(expression), emptyMap());
  return scriptBuilder.apply(script);
}

/**
* Build AggregationBuilder from Expression.
*
* @param expression Expression
* @return AggregationBuilder
*/
public <T> T build(
Expression expression, Function<String, T> fieldBuilder, Function<Script, T> scriptBuilder) {
public AggregationBuilder build(
Expression expression,
Function<String, AggregationBuilder> fieldBuilder,
Function<Script, AggregationBuilder> scriptBuilder) {
if (expression instanceof ReferenceExpression) {
String fieldName = ((ReferenceExpression) expression).getAttr();
return fieldBuilder.apply(
OpenSearchTextType.convertTextToKeyword(fieldName, expression.type()));
} else if (expression instanceof FunctionExpression
&& ((FunctionExpression) expression)
.getFunctionName()
.equals(BuiltinFunctionName.NESTED.getName())) {
List<Expression> args = ((FunctionExpression) expression).getArguments();
// NestedAnalyzer has validated the number of arguments.
// Here we can safely invoke args.getFirst().
String fieldName = ((ReferenceExpression) args.getFirst()).getAttr();
if (fieldName.contains("*")) {
throw new IllegalArgumentException("Nested aggregation doesn't support multiple fields");
}
String path =
args.size() == 2
? ((ReferenceExpression) args.get(1)).getAttr()
: fieldName.substring(0, fieldName.lastIndexOf("."));
AggregationBuilder subAgg =
fieldBuilder.apply(OpenSearchTextType.convertTextToKeyword(fieldName, expression.type()));
return AggregationBuilders.nested(path + "_nested", path).subAggregation(subAgg);
} else if (expression instanceof FunctionExpression
|| expression instanceof LiteralExpression) {
return scriptBuilder.apply(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ private CompositeValuesSourceBuilder<?> buildCompositeValuesSourceBuilder(
if (List.of(TIMESTAMP, TIME, DATE).contains(expr.getDelegated().type())) {
sourceBuilder.userValuetypeHint(ValueType.LONG);
}
return helper.build(expr.getDelegated(), sourceBuilder::field, sourceBuilder::script);
return helper.buildComposite(
expr.getDelegated(), sourceBuilder::field, sourceBuilder::script);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.search.aggregations.AggregatorFactories;
import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder;
import org.opensearch.search.aggregations.bucket.nested.NestedAggregationBuilder;
import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder;
import org.opensearch.search.aggregations.metrics.ExtendedStats;
import org.opensearch.search.aggregations.metrics.PercentilesAggregationBuilder;
Expand Down Expand Up @@ -179,7 +180,7 @@ private Pair<AggregationBuilder, MetricParser> make(
Expression condition,
String name,
MetricParser parser) {
ValuesSourceAggregationBuilder aggregationBuilder =
AggregationBuilder aggregationBuilder =
helper.build(expression, builder::field, builder::script);
if (condition != null) {
return Pair.of(
Expand All @@ -196,7 +197,7 @@ private Pair<AggregationBuilder, MetricParser> make(
Expression condition,
String name,
MetricParser parser) {
CardinalityAggregationBuilder aggregationBuilder =
AggregationBuilder aggregationBuilder =
helper.build(expression, builder::field, builder::script);
if (condition != null) {
return Pair.of(
Expand Down Expand Up @@ -234,12 +235,23 @@ private Pair<AggregationBuilder, MetricParser> make(
Expression condition,
String name,
MetricParser parser) {
PercentilesAggregationBuilder aggregationBuilder =
AggregationBuilder aggregationBuilder =
helper.build(expression, builder::field, builder::script);
PercentilesAggregationBuilder percentilesBuilder;
if (aggregationBuilder instanceof NestedAggregationBuilder) {
percentilesBuilder =
aggregationBuilder.getSubAggregations().stream()
.filter(PercentilesAggregationBuilder.class::isInstance)
.map(a -> (PercentilesAggregationBuilder) a)
.findFirst()
.orElseThrow();
} else {
percentilesBuilder = (PercentilesAggregationBuilder) aggregationBuilder;
}
if (compression != null) {
aggregationBuilder.compression(compression.valueOf().doubleValue());
percentilesBuilder.compression(compression.valueOf().doubleValue());
}
aggregationBuilder.percentiles(percent.valueOf().doubleValue());
percentilesBuilder.percentiles(percent.valueOf().doubleValue());
if (condition != null) {
return Pair.of(
makeFilterAggregation(aggregationBuilder, condition, name),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.opensearch.search.aggregations.bucket.histogram.HistogramAggregationBuilder;
import org.opensearch.search.aggregations.bucket.histogram.ParsedDateHistogram;
import org.opensearch.search.aggregations.bucket.histogram.ParsedHistogram;
import org.opensearch.search.aggregations.bucket.nested.NestedAggregationBuilder;
import org.opensearch.search.aggregations.bucket.nested.ParsedNested;
import org.opensearch.search.aggregations.bucket.terms.DoubleTerms;
import org.opensearch.search.aggregations.bucket.terms.LongTerms;
import org.opensearch.search.aggregations.bucket.terms.ParsedDoubleTerms;
Expand Down Expand Up @@ -87,6 +89,8 @@ public class AggregationResponseUtils {
.put(
TopHitsAggregationBuilder.NAME,
(p, c) -> ParsedTopHits.fromXContent(p, (String) c))
.put(
NestedAggregationBuilder.NAME, (p, c) -> ParsedNested.fromXContent(p, (String) c))
.build()
.entrySet()
.stream()
Expand Down
Loading
Loading