Skip to content

Commit

Permalink
RATIS-1519. When DataStreamManagement#read an exception occurs, remov…
Browse files Browse the repository at this point in the history
…e DataStream (apache#596)
  • Loading branch information
guohao-rosicky authored Apr 8, 2024
1 parent f404244 commit bc6221b
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,19 @@ static class StreamInfo {
private final boolean primary;
private final LocalStream local;
private final Set<RemoteStream> remotes;
private final RaftServer server;
private final Division division;
private final AtomicReference<CompletableFuture<Void>> previous
= new AtomicReference<>(CompletableFuture.completedFuture(null));

StreamInfo(RaftClientRequest request, boolean primary, CompletableFuture<DataStream> stream, RaftServer server,
StreamInfo(RaftClientRequest request, boolean primary, CompletableFuture<DataStream> stream, Division division,
CheckedBiFunction<RaftClientRequest, Set<RaftPeer>, Set<DataStreamOutputImpl>, IOException> getStreams,
Function<RequestType, RequestMetrics> metricsConstructor)
throws IOException {
this.request = request;
this.primary = primary;
this.local = new LocalStream(stream, metricsConstructor.apply(RequestType.LOCAL_WRITE));
this.server = server;
final Set<RaftPeer> successors = getSuccessors(server.getId());
this.division = division;
final Set<RaftPeer> successors = getSuccessors(division.getId());
final Set<DataStreamOutputImpl> outs = getStreams.apply(request, successors);
this.remotes = outs.stream()
.map(o -> new RemoteStream(o, metricsConstructor.apply(RequestType.REMOTE_WRITE)))
Expand All @@ -167,16 +167,12 @@ RaftClientRequest getRequest() {
return request;
}

Division getDivision() throws IOException {
return server.getDivision(request.getRaftGroupId());
Division getDivision() {
return division;
}

Collection<CommitInfoProto> getCommitInfos() {
try {
return getDivision().getCommitInfos();
} catch (IOException e) {
throw new IllegalStateException(e);
}
return getDivision().getCommitInfos();
}

boolean isPrimary() {
Expand All @@ -196,7 +192,7 @@ public String toString() {
return JavaUtils.getClassSimpleName(getClass()) + ":" + request;
}

private Set<RaftPeer> getSuccessors(RaftPeerId peerId) throws IOException {
private Set<RaftPeer> getSuccessors(RaftPeerId peerId) {
final RaftConfiguration conf = getDivision().getRaftConf();
final RoutingTable routingTable = request.getRoutingTable();

Expand All @@ -208,7 +204,7 @@ private Set<RaftPeer> getSuccessors(RaftPeerId peerId) throws IOException {
// Default start topology
// get the other peers from the current configuration
return conf.getCurrentPeers().stream()
.filter(p -> !p.getId().equals(server.getId()))
.filter(p -> !p.getId().equals(division.getId()))
.collect(Collectors.toSet());
}

Expand Down Expand Up @@ -276,7 +272,8 @@ private StreamInfo newStreamInfo(ByteBuf buf,
final RaftClientRequest request = ClientProtoUtils.toRaftClientRequest(
RaftClientRequestProto.parseFrom(buf.nioBuffer()));
final boolean isPrimary = server.getId().equals(request.getServerId());
return new StreamInfo(request, isPrimary, computeDataStreamIfAbsent(request), server, getStreams,
final Division division = server.getDivision(request.getRaftGroupId());
return new StreamInfo(request, isPrimary, computeDataStreamIfAbsent(request), division, getStreams,
getMetrics()::newRequestMetrics);
} catch (Throwable e) {
throw new CompletionException(e);
Expand Down Expand Up @@ -411,6 +408,18 @@ void read(DataStreamRequestByteBuf request, ChannelHandlerContext ctx,
readImpl(request, ctx, getStreams);
} catch (Throwable t) {
replyDataStreamException(t, request, ctx);
removeDataStream(ClientInvocationId.valueOf(request.getClientId(), request.getStreamId()), null);
}
}

private void removeDataStream(ClientInvocationId invocationId, StreamInfo info) {
final StreamInfo removed = streams.remove(invocationId);
if (info == null) {
info = removed;
}
if (info != null) {
info.getDivision().getDataStreamMap().remove(invocationId);
info.getLocal().cleanUp();
}
}

Expand All @@ -429,8 +438,6 @@ private void readImpl(DataStreamRequestByteBuf request, ChannelHandlerContext ct
() -> newStreamInfo(request.slice(), getStreams));
info = streams.computeIfAbsent(key, id -> supplier.get());
if (!supplier.isInitialized()) {
final StreamInfo removed = streams.remove(key);
removed.getLocal().cleanUp();
throw new IllegalStateException("Failed to create a new stream for " + request
+ " since a stream already exists Key: " + key + " StreamInfo:" + info);
}
Expand Down Expand Up @@ -468,9 +475,8 @@ private void readImpl(DataStreamRequestByteBuf request, ChannelHandlerContext ct
}, requestExecutor)).whenComplete((v, exception) -> {
try {
if (exception != null) {
final StreamInfo removed = streams.remove(key);
replyDataStreamException(server, exception, info.getRequest(), request, ctx);
removed.getLocal().cleanUp();
removeDataStream(key, info);
}
} finally {
request.release();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class MultiDataStreamStateMachine extends BaseStateMachine {
@Override
public CompletableFuture<DataStream> stream(RaftClientRequest request) {
final SingleDataStream s = new SingleDataStream(request);
LOG.info("XXX {} put {}, {}", this, ClientInvocationId.valueOf(request), s);
streams.put(ClientInvocationId.valueOf(request), s);
return CompletableFuture.completedFuture(s);
}
Expand Down Expand Up @@ -179,7 +180,9 @@ SingleDataStream getSingleDataStream(RaftClientRequest request) {
}

SingleDataStream getSingleDataStream(ClientInvocationId invocationId) {
return streams.get(invocationId);
final SingleDataStream s = streams.get(invocationId);
LOG.info("XXX {}: get {} return {}", this, invocationId, s);
return s;
}

Collection<SingleDataStream> getStreams() {
Expand Down Expand Up @@ -329,6 +332,8 @@ static CompletableFuture<RaftClientReply> writeAndCloseAndAssertReplies(

static void assertHeader(RaftServer server, RaftClientRequest header, int dataSize, boolean stepDownLeader)
throws Exception {
LOG.info("XXX {}: dataSize={}, stepDownLeader={}, header={}",
server.getId(), dataSize, stepDownLeader, header);
// check header
Assertions.assertEquals(RaftClientRequest.dataStreamRequestType(), header.getType());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,18 @@ public void setup() {
RaftConfigKeys.DataStream.setType(properties, SupportedDataStreamType.NETTY);
}

RaftServer.Division mockDivision(RaftServer server) {

RaftServer.Division mockDivision(RaftServer server, RaftGroupId groupId) {
final RaftServer.Division division = mock(RaftServer.Division.class);
when(division.getRaftServer()).thenReturn(server);
when(division.getRaftConf()).thenAnswer(i -> getRaftConf());

final MultiDataStreamStateMachine stateMachine = new MultiDataStreamStateMachine();
try {
stateMachine.initialize(server, groupId, null);
} catch (IOException e) {
throw new IllegalStateException(e);
}
when(division.getStateMachine()).thenReturn(stateMachine);

final DataStreamMap streamMap = RaftServerTestUtil.newDataStreamMap(server.getId());
Expand Down Expand Up @@ -95,7 +101,7 @@ private void testMockCluster(int numServers, RaftException leaderException,
when(raftServer.getId()).thenReturn(peerId);
when(raftServer.getPeer()).thenReturn(RaftPeer.newBuilder().setId(peerId).build());
if (getStateMachineException == null) {
final RaftServer.Division myDivision = mockDivision(raftServer);
final RaftServer.Division myDivision = mockDivision(raftServer, groupId);
when(raftServer.getDivision(Mockito.any(RaftGroupId.class))).thenReturn(myDivision);
} else {
when(raftServer.getDivision(Mockito.any(RaftGroupId.class))).thenThrow(getStateMachineException);
Expand Down

0 comments on commit bc6221b

Please sign in to comment.