Skip to content

Commit

Permalink
CTE support for snapshot
Browse files Browse the repository at this point in the history
  • Loading branch information
codewithaloknath committed Apr 8, 2022
1 parent 7ad7194 commit 42d998a
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -749,9 +749,8 @@ else if (statement instanceof InsertCube) {
}

// Doesn't work with the following features
if (SystemSessionProperties.isReuseTableScanEnabled(session)
|| SystemSessionProperties.isCTEReuseEnabled(session)) {
reasons.add("No support along with reuse_table_scan or cte_reuse_enabled features");
if (SystemSessionProperties.isReuseTableScanEnabled(session)) {
reasons.add("No support along with reuse_table_scan feature");
}

// All input tables must support snapshotting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,27 @@
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.airlift.log.Logger;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
import io.hetu.core.transport.execution.buffer.PagesSerde;
import io.hetu.core.transport.execution.buffer.PagesSerdeFactory;
import io.hetu.core.transport.execution.buffer.SerializedPage;
import io.prestosql.block.BlockJsonSerde;
import io.prestosql.spi.Page;
import io.prestosql.spi.block.Block;
import io.prestosql.spi.block.BlockEncodingSerde;
import io.prestosql.spi.plan.PlanNodeId;
import io.prestosql.spi.snapshot.BlockEncodingSerdeProvider;
import io.prestosql.spi.snapshot.Restorable;
import io.prestosql.spi.snapshot.RestorableConfig;
import it.unimi.dsi.fastutil.longs.LongArrayList;

import javax.annotation.concurrent.GuardedBy;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.io.Serializable;
import java.util.*;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger;
Expand All @@ -37,11 +47,13 @@
import static com.google.common.base.Preconditions.checkArgument;
import static io.prestosql.operator.Operator.NOT_BLOCKED;

@RestorableConfig()
public class CommonTableExecutionContext
implements Restorable
{
private static final Logger LOG = Logger.get(CommonTableExecutionContext.class);
private final String name;
private final int queueCnt;
private int queueCnt;
private final PlanNodeId feederId;
private boolean isFeederInitialized;
private List<Integer> feeders = Collections.synchronizedList(new ArrayList<>());
Expand All @@ -52,9 +64,9 @@ public class CommonTableExecutionContext
private final Executor notificationExecutor;
@GuardedBy("this")
private SettableFuture<?> blockedFuture;
private final int taskCount;
private final int maxMainQueueSize;
private final int maxPrefetchQueueSize;
private int taskCount;
private int maxMainQueueSize;
private int maxPrefetchQueueSize;

public CommonTableExecutionContext(String name, Set<PlanNodeId> consumers, PlanNodeId feederId, Executor notificationExecutor,
int taskCount, int maxMainQueueSize, int maxPrefetchQueueSize)
Expand Down Expand Up @@ -264,4 +276,86 @@ public CTEDoneException()
super();
}
}

@Override
public Object capture(BlockEncodingSerdeProvider serdeProvider) {
BlockEncodingSerde blockSerde = serdeProvider.getBlockEncodingSerde();
CommonTableExecutionContextState myState = new CommonTableExecutionContextState();
PagesSerdeFactory pagesSerdeFactory = new PagesSerdeFactory(blockSerde,false);
PagesSerde pagesSerde = pagesSerdeFactory.createPagesSerde();

myState.queueCnt = queueCnt;
myState.maxPrefetchQueueSize = maxPrefetchQueueSize;
myState.maxMainQueueSize = maxMainQueueSize;
myState.taskCount = taskCount;
myState.feeders = new Object[feeders.size()];
for (int i = 0; i < feeders.size(); i++) {
myState.feeders[i] = feeders.get(i);
}
myState.prefetchedQueue = new Object[prefetchedQueue.size()];
Iterator iterator = prefetchedQueue.iterator();
while (iterator.hasNext()) {
Page page = (Page) iterator.next();
SerializedPage serializedPage = pagesSerde.serialize(page);
int p = 0;
myState.prefetchedQueue[p] = serializedPage;
p++;
}
myState.consumerQueues = new Object[consumerQueues.size()][2];
if (consumerQueues != null) {
int count = 0;
for (Map.Entry<PlanNodeId, LinkedList<Page>> entry : consumerQueues.entrySet()) {
myState.consumerQueues[count][0] = entry.getKey();
LinkedList<Page> pages = entry.getValue();
LinkedList<SerializedPage> serializedPages = pages.stream().map(page -> pagesSerde.serialize(page)).collect(Collectors.toCollection(LinkedList::new));
myState.consumerQueues[count][1] = serializedPages;
count ++;
}
}
return myState;
}

@Override
public void restore(Object state, BlockEncodingSerdeProvider serdeProvider)
{
BlockEncodingSerde blockSerde = serdeProvider.getBlockEncodingSerde();
CommonTableExecutionContextState myState = (CommonTableExecutionContextState) state;

PagesSerdeFactory pagesSerdeFactory = new PagesSerdeFactory(blockSerde,false);
PagesSerde pagesSerde = pagesSerdeFactory.createPagesSerde();

this.queueCnt = myState.queueCnt;
this.maxPrefetchQueueSize = myState.maxPrefetchQueueSize;
this.maxMainQueueSize = myState.maxMainQueueSize;
this.taskCount = myState.taskCount;
for (int i = 0; i < myState.feeders.length; i++) {
feeders.add((Integer) myState.feeders[i]);
}
for (int i = 0; i < myState.prefetchedQueue.length; i++) {
SerializedPage serializedPage = (SerializedPage) myState.prefetchedQueue[i];
Page page = pagesSerde.deserialize(serializedPage);
prefetchedQueue.add(page);
}
for (int i = 0; i < myState.consumerQueues.length; i++) {
LinkedList<SerializedPage> pages = (LinkedList<SerializedPage>) myState.consumerQueues[i][1];
LinkedList<Page> depages = pages.stream().map(page -> pagesSerde.deserialize(page)).collect(Collectors.toCollection(LinkedList::new));
consumerQueues.put((PlanNodeId) myState.consumerQueues[i][0], depages);
}

}

private static class CommonTableExecutionContextState
implements Serializable
{

private int queueCnt;
private int maxPrefetchQueueSize;
private int maxMainQueueSize;
private int taskCount;
private Object[] feeders;
private Object[] prefetchedQueue;
private Object[][] consumerQueues;
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.prestosql.snapshot.SingleInputSnapshotState;
import io.prestosql.spi.Page;
import io.prestosql.spi.plan.PlanNodeId;
import io.prestosql.spi.snapshot.BlockEncodingSerdeProvider;
import io.prestosql.spi.snapshot.RestorableConfig;
import io.prestosql.spi.type.Type;

import java.io.Closeable;
import java.io.IOException;
import java.io.Serializable;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
Expand All @@ -36,7 +39,7 @@
import static java.util.Objects.requireNonNull;

// TODO-cp-I2TJ3G: will add snapshot support later
@RestorableConfig(unsupported = true)
@RestorableConfig()
public class CommonTableExpressionOperator
implements Operator, Closeable
{
Expand All @@ -47,9 +50,10 @@ public class CommonTableExpressionOperator
private final PlanNodeId consumer;
private final CommonTableExecutionContext cteContext;
private final Function<Page, Page> pagePreprocessor;
private final int operatorInstaceId;
private int operatorInstaceId;
private boolean finish;
private boolean isFeeder;
private final SingleInputSnapshotState snapshotState;

public CommonTableExpressionOperator(
PlanNodeId self,
Expand All @@ -65,6 +69,7 @@ public CommonTableExpressionOperator(
this.cteContext = requireNonNull(cteContext, "CTE context is null");
this.operatorInstaceId = operatorInstaceId;
this.pagePreprocessor = pagePreprocessor;
this.snapshotState = operatorContext.isSnapshotEnabled() ? SingleInputSnapshotState.forOperator(this, operatorContext) : null;

synchronized (cteContext) {
if (cteContext.isFeeder(consumer)) {
Expand Down Expand Up @@ -176,6 +181,15 @@ public boolean needsInput()
@Override
public void addInput(Page page)
{

checkState(needsInput(), "Operator is already finishing");
requireNonNull(page, "page is null");

if (snapshotState != null) {
if (snapshotState.processPage(page)) {
return;
}
}
/* Got a new page... Place it in the Queue! */
Page addPage = pagePreprocessor.apply(page);
cteContext.addPage(addPage);
Expand All @@ -189,6 +203,13 @@ public void addInput(Page page)
@Override
public Page getOutput()
{
if (snapshotState != null) {
Page marker = snapshotState.nextMarker();
if (marker != null) {
return marker;
}
}

try {
Page page = cteContext.getPage(consumer);
if (page != null) {
Expand All @@ -210,8 +231,7 @@ public Page getOutput()
@Override
public Page pollMarker()
{
//TODO-cp-I2TJ3G: Operator currently not supported for Snapshot
return null;
return snapshotState.nextMarker();
}

/**
Expand Down Expand Up @@ -264,6 +284,9 @@ public void finish()
@Override
public boolean isFinished()
{
if (snapshotState != null && snapshotState.hasMarker()) {
return false;
}
return finish;
}

Expand All @@ -273,6 +296,45 @@ public boolean isFinished()
@Override
public void close() throws IOException
{
if (snapshotState != null) {
snapshotState.close();
}
LOG.debug("CTE(" + cteContext.getName() + ")[" + consumer + "-" + operatorInstaceId + "] Operator Closed");
}

@Override
public Object capture(BlockEncodingSerdeProvider serdeProvider)
{

CommonTableOperatorState myState = new CommonTableOperatorState();
myState.operatorContext = operatorContext.capture(serdeProvider);
if (isFeeder) {
myState.cteContext = cteContext.capture(serdeProvider);
}
myState.finish = finish;
myState.isFeeder = isFeeder;
return myState;

}

@Override
public void restore(Object state, BlockEncodingSerdeProvider serdeProvider)
{
CommonTableOperatorState myState = (CommonTableOperatorState) state;
this.operatorContext.restore(myState.operatorContext, serdeProvider);
isFeeder = myState.isFeeder;
if (isFeeder) {
this.cteContext.restore(myState.cteContext, serdeProvider);
}
finish = myState.finish;
}

private static class CommonTableOperatorState
implements Serializable
{
private Object operatorContext;
private Object cteContext;
private boolean finish;
private boolean isFeeder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,18 @@
import io.prestosql.testing.MaterializedResult;
import org.testng.annotations.Test;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;

import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.prestosql.SessionTestUtils.TEST_SESSION;
import static io.prestosql.metadata.MetadataManager.createTestMetadataManager;
import static io.prestosql.operator.OperatorAssertion.assertOperatorEquals;
import static io.prestosql.operator.OperatorAssertion.assertOperatorEqualsWithSimpleSelfStateComparison;
import static io.prestosql.spi.type.VarcharType.VARCHAR;
import static io.prestosql.testing.TestingTaskContext.createTaskContext;
import static java.util.concurrent.Executors.newCachedThreadPool;
Expand Down Expand Up @@ -94,6 +98,42 @@ public void testOperatorSource()
assertOperatorEquals(parent2, driverContext, ImmutableList.of(input), result);
}


@Test
public void testOperatorSourceSnapshot() {
final Page input = SequencePageBuilder.createSequencePage(ImmutableList.of(VARCHAR), 10_000, 0);
DriverContext driverContext = newDriverContext();

CommonTableExecutionContext cteContext = new CommonTableExecutionContext("test_cte_prod_1",
ImmutableSet.of(new PlanNodeId("consumer_1"), new PlanNodeId("consumer_2")), new PlanNodeId("consumer_1"),
driverContext.getNotificationExecutor(), 0, 1024, 512);

CommonTableExpressionOperator.CommonTableExpressionOperatorFactory parent1 = new CommonTableExpressionOperator.CommonTableExpressionOperatorFactory(
0,
new PlanNodeId("test"),
cteContext,
ImmutableList.of(VARCHAR),
new DataSize(0, DataSize.Unit.BYTE),
0,
symbol -> symbol);
parent1.addConsumer(new PlanNodeId("consumer_1"));

MaterializedResult result = MaterializedResult.resultBuilder(driverContext.getSession(), VARCHAR)
.page(input)
.build();
assertOperatorEqualsWithSimpleSelfStateComparison(parent1, driverContext, ImmutableList.of(input), result, createExpectedMapping());
}

private Map<String, Object> createExpectedMapping()
{
Map<String, Object> expectedMapping = new HashMap<>();
expectedMapping.put("operatorContext", 0);
expectedMapping.put("finish", false);
expectedMapping.put("isFeeder", true);
return expectedMapping;
}


private static List<Page> toPages(Operator operator)
{
ImmutableList.Builder<Page> outputPages = ImmutableList.builder();
Expand Down

0 comments on commit 42d998a

Please sign in to comment.