From 42d998aabc644c2a423328c8f2fdff1c115e60a5 Mon Sep 17 00:00:00 2001 From: aloknath Date: Fri, 8 Apr 2022 18:13:46 +0530 Subject: [PATCH] CTE support for snapshot --- .../execution/SqlQueryExecution.java | 5 +- .../operator/CommonTableExecutionContext.java | 114 ++++++++++++++++-- .../CommonTableExpressionOperator.java | 70 ++++++++++- .../TestCommonTableExpressionOperator.java | 40 ++++++ 4 files changed, 212 insertions(+), 17 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java b/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java index 203a758b2..1814b33ac 100644 --- a/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java +++ b/presto-main/src/main/java/io/prestosql/execution/SqlQueryExecution.java @@ -749,9 +749,8 @@ else if (statement instanceof InsertCube) { } // Doesn't work with the following features - if (SystemSessionProperties.isReuseTableScanEnabled(session) - || SystemSessionProperties.isCTEReuseEnabled(session)) { - reasons.add("No support along with reuse_table_scan or cte_reuse_enabled features"); + if (SystemSessionProperties.isReuseTableScanEnabled(session)) { + reasons.add("No support along with reuse_table_scan feature"); } // All input tables must support snapshotting diff --git a/presto-main/src/main/java/io/prestosql/operator/CommonTableExecutionContext.java b/presto-main/src/main/java/io/prestosql/operator/CommonTableExecutionContext.java index 8a7394569..c722cf103 100644 --- a/presto-main/src/main/java/io/prestosql/operator/CommonTableExecutionContext.java +++ b/presto-main/src/main/java/io/prestosql/operator/CommonTableExecutionContext.java @@ -18,17 +18,27 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.airlift.log.Logger; +import io.airlift.slice.DynamicSliceOutput; +import io.airlift.slice.Slice; +import io.airlift.slice.SliceOutput; +import io.airlift.slice.Slices; +import io.hetu.core.transport.execution.buffer.PagesSerde; +import io.hetu.core.transport.execution.buffer.PagesSerdeFactory; +import io.hetu.core.transport.execution.buffer.SerializedPage; +import io.prestosql.block.BlockJsonSerde; import io.prestosql.spi.Page; +import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.BlockEncodingSerde; import io.prestosql.spi.plan.PlanNodeId; +import io.prestosql.spi.snapshot.BlockEncodingSerdeProvider; +import io.prestosql.spi.snapshot.Restorable; +import io.prestosql.spi.snapshot.RestorableConfig; +import it.unimi.dsi.fastutil.longs.LongArrayList; import javax.annotation.concurrent.GuardedBy; -import java.util.ArrayList; -import java.util.Collections; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.io.Serializable; +import java.util.*; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicInteger; @@ -37,11 +47,13 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.prestosql.operator.Operator.NOT_BLOCKED; +@RestorableConfig() public class CommonTableExecutionContext + implements Restorable { private static final Logger LOG = Logger.get(CommonTableExecutionContext.class); private final String name; - private final int queueCnt; + private int queueCnt; private final PlanNodeId feederId; private boolean isFeederInitialized; private List feeders = Collections.synchronizedList(new ArrayList<>()); @@ -52,9 +64,9 @@ public class CommonTableExecutionContext private final Executor notificationExecutor; @GuardedBy("this") private SettableFuture blockedFuture; - private final int taskCount; - private final int maxMainQueueSize; - private final int maxPrefetchQueueSize; + private int taskCount; + private int maxMainQueueSize; + private int maxPrefetchQueueSize; public CommonTableExecutionContext(String name, Set consumers, PlanNodeId feederId, Executor notificationExecutor, int taskCount, int maxMainQueueSize, int maxPrefetchQueueSize) @@ -264,4 +276,86 @@ public CTEDoneException() super(); } } + + @Override + public Object capture(BlockEncodingSerdeProvider serdeProvider) { + BlockEncodingSerde blockSerde = serdeProvider.getBlockEncodingSerde(); + CommonTableExecutionContextState myState = new CommonTableExecutionContextState(); + PagesSerdeFactory pagesSerdeFactory = new PagesSerdeFactory(blockSerde,false); + PagesSerde pagesSerde = pagesSerdeFactory.createPagesSerde(); + + myState.queueCnt = queueCnt; + myState.maxPrefetchQueueSize = maxPrefetchQueueSize; + myState.maxMainQueueSize = maxMainQueueSize; + myState.taskCount = taskCount; + myState.feeders = new Object[feeders.size()]; + for (int i = 0; i < feeders.size(); i++) { + myState.feeders[i] = feeders.get(i); + } + myState.prefetchedQueue = new Object[prefetchedQueue.size()]; + Iterator iterator = prefetchedQueue.iterator(); + while (iterator.hasNext()) { + Page page = (Page) iterator.next(); + SerializedPage serializedPage = pagesSerde.serialize(page); + int p = 0; + myState.prefetchedQueue[p] = serializedPage; + p++; + } + myState.consumerQueues = new Object[consumerQueues.size()][2]; + if (consumerQueues != null) { + int count = 0; + for (Map.Entry> entry : consumerQueues.entrySet()) { + myState.consumerQueues[count][0] = entry.getKey(); + LinkedList pages = entry.getValue(); + LinkedList serializedPages = pages.stream().map(page -> pagesSerde.serialize(page)).collect(Collectors.toCollection(LinkedList::new)); + myState.consumerQueues[count][1] = serializedPages; + count ++; + } + } + return myState; + } + + @Override + public void restore(Object state, BlockEncodingSerdeProvider serdeProvider) + { + BlockEncodingSerde blockSerde = serdeProvider.getBlockEncodingSerde(); + CommonTableExecutionContextState myState = (CommonTableExecutionContextState) state; + + PagesSerdeFactory pagesSerdeFactory = new PagesSerdeFactory(blockSerde,false); + PagesSerde pagesSerde = pagesSerdeFactory.createPagesSerde(); + + this.queueCnt = myState.queueCnt; + this.maxPrefetchQueueSize = myState.maxPrefetchQueueSize; + this.maxMainQueueSize = myState.maxMainQueueSize; + this.taskCount = myState.taskCount; + for (int i = 0; i < myState.feeders.length; i++) { + feeders.add((Integer) myState.feeders[i]); + } + for (int i = 0; i < myState.prefetchedQueue.length; i++) { + SerializedPage serializedPage = (SerializedPage) myState.prefetchedQueue[i]; + Page page = pagesSerde.deserialize(serializedPage); + prefetchedQueue.add(page); + } + for (int i = 0; i < myState.consumerQueues.length; i++) { + LinkedList pages = (LinkedList) myState.consumerQueues[i][1]; + LinkedList depages = pages.stream().map(page -> pagesSerde.deserialize(page)).collect(Collectors.toCollection(LinkedList::new)); + consumerQueues.put((PlanNodeId) myState.consumerQueues[i][0], depages); + } + + } + + private static class CommonTableExecutionContextState + implements Serializable + { + + private int queueCnt; + private int maxPrefetchQueueSize; + private int maxMainQueueSize; + private int taskCount; + private Object[] feeders; + private Object[] prefetchedQueue; + private Object[][] consumerQueues; + } + + } diff --git a/presto-main/src/main/java/io/prestosql/operator/CommonTableExpressionOperator.java b/presto-main/src/main/java/io/prestosql/operator/CommonTableExpressionOperator.java index de27e0071..4810a45ec 100644 --- a/presto-main/src/main/java/io/prestosql/operator/CommonTableExpressionOperator.java +++ b/presto-main/src/main/java/io/prestosql/operator/CommonTableExpressionOperator.java @@ -18,13 +18,16 @@ import com.google.common.util.concurrent.ListenableFuture; import io.airlift.log.Logger; import io.airlift.units.DataSize; +import io.prestosql.snapshot.SingleInputSnapshotState; import io.prestosql.spi.Page; import io.prestosql.spi.plan.PlanNodeId; +import io.prestosql.spi.snapshot.BlockEncodingSerdeProvider; import io.prestosql.spi.snapshot.RestorableConfig; import io.prestosql.spi.type.Type; import java.io.Closeable; import java.io.IOException; +import java.io.Serializable; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -36,7 +39,7 @@ import static java.util.Objects.requireNonNull; // TODO-cp-I2TJ3G: will add snapshot support later -@RestorableConfig(unsupported = true) +@RestorableConfig() public class CommonTableExpressionOperator implements Operator, Closeable { @@ -47,9 +50,10 @@ public class CommonTableExpressionOperator private final PlanNodeId consumer; private final CommonTableExecutionContext cteContext; private final Function pagePreprocessor; - private final int operatorInstaceId; + private int operatorInstaceId; private boolean finish; private boolean isFeeder; + private final SingleInputSnapshotState snapshotState; public CommonTableExpressionOperator( PlanNodeId self, @@ -65,6 +69,7 @@ public CommonTableExpressionOperator( this.cteContext = requireNonNull(cteContext, "CTE context is null"); this.operatorInstaceId = operatorInstaceId; this.pagePreprocessor = pagePreprocessor; + this.snapshotState = operatorContext.isSnapshotEnabled() ? SingleInputSnapshotState.forOperator(this, operatorContext) : null; synchronized (cteContext) { if (cteContext.isFeeder(consumer)) { @@ -176,6 +181,15 @@ public boolean needsInput() @Override public void addInput(Page page) { + + checkState(needsInput(), "Operator is already finishing"); + requireNonNull(page, "page is null"); + + if (snapshotState != null) { + if (snapshotState.processPage(page)) { + return; + } + } /* Got a new page... Place it in the Queue! */ Page addPage = pagePreprocessor.apply(page); cteContext.addPage(addPage); @@ -189,6 +203,13 @@ public void addInput(Page page) @Override public Page getOutput() { + if (snapshotState != null) { + Page marker = snapshotState.nextMarker(); + if (marker != null) { + return marker; + } + } + try { Page page = cteContext.getPage(consumer); if (page != null) { @@ -210,8 +231,7 @@ public Page getOutput() @Override public Page pollMarker() { - //TODO-cp-I2TJ3G: Operator currently not supported for Snapshot - return null; + return snapshotState.nextMarker(); } /** @@ -264,6 +284,9 @@ public void finish() @Override public boolean isFinished() { + if (snapshotState != null && snapshotState.hasMarker()) { + return false; + } return finish; } @@ -273,6 +296,45 @@ public boolean isFinished() @Override public void close() throws IOException { + if (snapshotState != null) { + snapshotState.close(); + } LOG.debug("CTE(" + cteContext.getName() + ")[" + consumer + "-" + operatorInstaceId + "] Operator Closed"); } + + @Override + public Object capture(BlockEncodingSerdeProvider serdeProvider) + { + + CommonTableOperatorState myState = new CommonTableOperatorState(); + myState.operatorContext = operatorContext.capture(serdeProvider); + if (isFeeder) { + myState.cteContext = cteContext.capture(serdeProvider); + } + myState.finish = finish; + myState.isFeeder = isFeeder; + return myState; + + } + + @Override + public void restore(Object state, BlockEncodingSerdeProvider serdeProvider) + { + CommonTableOperatorState myState = (CommonTableOperatorState) state; + this.operatorContext.restore(myState.operatorContext, serdeProvider); + isFeeder = myState.isFeeder; + if (isFeeder) { + this.cteContext.restore(myState.cteContext, serdeProvider); + } + finish = myState.finish; + } + + private static class CommonTableOperatorState + implements Serializable + { + private Object operatorContext; + private Object cteContext; + private boolean finish; + private boolean isFeeder; + } } diff --git a/presto-main/src/test/java/io/prestosql/operator/TestCommonTableExpressionOperator.java b/presto-main/src/test/java/io/prestosql/operator/TestCommonTableExpressionOperator.java index ef11595ce..47f47975e 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestCommonTableExpressionOperator.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestCommonTableExpressionOperator.java @@ -27,7 +27,10 @@ import io.prestosql.testing.MaterializedResult; import org.testng.annotations.Test; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; @@ -35,6 +38,7 @@ import static io.prestosql.SessionTestUtils.TEST_SESSION; import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; import static io.prestosql.operator.OperatorAssertion.assertOperatorEquals; +import static io.prestosql.operator.OperatorAssertion.assertOperatorEqualsWithSimpleSelfStateComparison; import static io.prestosql.spi.type.VarcharType.VARCHAR; import static io.prestosql.testing.TestingTaskContext.createTaskContext; import static java.util.concurrent.Executors.newCachedThreadPool; @@ -94,6 +98,42 @@ public void testOperatorSource() assertOperatorEquals(parent2, driverContext, ImmutableList.of(input), result); } + + @Test + public void testOperatorSourceSnapshot() { + final Page input = SequencePageBuilder.createSequencePage(ImmutableList.of(VARCHAR), 10_000, 0); + DriverContext driverContext = newDriverContext(); + + CommonTableExecutionContext cteContext = new CommonTableExecutionContext("test_cte_prod_1", + ImmutableSet.of(new PlanNodeId("consumer_1"), new PlanNodeId("consumer_2")), new PlanNodeId("consumer_1"), + driverContext.getNotificationExecutor(), 0, 1024, 512); + + CommonTableExpressionOperator.CommonTableExpressionOperatorFactory parent1 = new CommonTableExpressionOperator.CommonTableExpressionOperatorFactory( + 0, + new PlanNodeId("test"), + cteContext, + ImmutableList.of(VARCHAR), + new DataSize(0, DataSize.Unit.BYTE), + 0, + symbol -> symbol); + parent1.addConsumer(new PlanNodeId("consumer_1")); + + MaterializedResult result = MaterializedResult.resultBuilder(driverContext.getSession(), VARCHAR) + .page(input) + .build(); + assertOperatorEqualsWithSimpleSelfStateComparison(parent1, driverContext, ImmutableList.of(input), result, createExpectedMapping()); + } + + private Map createExpectedMapping() + { + Map expectedMapping = new HashMap<>(); + expectedMapping.put("operatorContext", 0); + expectedMapping.put("finish", false); + expectedMapping.put("isFeeder", true); + return expectedMapping; + } + + private static List toPages(Operator operator) { ImmutableList.Builder outputPages = ImmutableList.builder();