Skip to content

Commit

Permalink
Add LOCAL INFILE support
Browse files Browse the repository at this point in the history
- Improve envelope codec for `SubsequenceClientMessage`
- Add `LocalInfileRequest`/`LocalInfileResponse` for text protocol
  • Loading branch information
mirromutth committed Jan 23, 2024
1 parent 8ec8b97 commit 9f99c07
Show file tree
Hide file tree
Showing 28 changed files with 831 additions and 76 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci-integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,7 @@ jobs:
mysql version: ${{ matrix.mysql-version }}
mysql database: r2dbc
mysql root password: r2dbc-password!@
- name: Enable LOCAL INFILE
run: mysql --protocol=tcp --password='r2dbc-password!@' "SET GLOBAL local_infile=on;"
- name: Integration test with MySQL ${{ matrix.mysql-version }}
run: ./mvnw -B verify -Dmaven.javadoc.skip=true -Dmaven.surefire.skip=true -Dtest.mysql.password=r2dbc-password!@ -Dtest.mysql.version=${{ matrix.mysql-version }} -Dorg.slf4j.simpleLogger.log.org.apache.maven.cli.transfer.Slf4jMavenTransferListener=WARN
4 changes: 4 additions & 0 deletions src/main/java/io/asyncer/r2dbc/mysql/ConnectionContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ public ZeroDateOption getZeroDateOption() {
return zeroDateOption;
}

public int getLocalInfileBufferSize() {
return 64 * 1024;
}

/**
* Get the bitmap of server statuses.
*
Expand Down
126 changes: 84 additions & 42 deletions src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
import io.asyncer.r2dbc.mysql.client.FluxExchangeable;
import io.asyncer.r2dbc.mysql.constant.ServerStatuses;
import io.asyncer.r2dbc.mysql.constant.SslMode;
import io.asyncer.r2dbc.mysql.internal.util.InternalArrays;
import io.asyncer.r2dbc.mysql.internal.util.StringUtils;
import io.asyncer.r2dbc.mysql.message.client.AuthResponse;
import io.asyncer.r2dbc.mysql.message.client.ClientMessage;
import io.asyncer.r2dbc.mysql.message.client.HandshakeResponse;
import io.asyncer.r2dbc.mysql.message.client.LoginClientMessage;
import io.asyncer.r2dbc.mysql.message.client.LocalInfileResponse;
import io.asyncer.r2dbc.mysql.message.client.SubsequenceClientMessage;
import io.asyncer.r2dbc.mysql.message.client.PingMessage;
import io.asyncer.r2dbc.mysql.message.client.PrepareQueryMessage;
import io.asyncer.r2dbc.mysql.message.client.PreparedCloseMessage;
Expand All @@ -44,6 +44,7 @@
import io.asyncer.r2dbc.mysql.message.server.ErrorMessage;
import io.asyncer.r2dbc.mysql.message.server.HandshakeHeader;
import io.asyncer.r2dbc.mysql.message.server.HandshakeRequest;
import io.asyncer.r2dbc.mysql.message.server.LocalInfileRequest;
import io.asyncer.r2dbc.mysql.message.server.OkMessage;
import io.asyncer.r2dbc.mysql.message.server.PreparedOkMessage;
import io.asyncer.r2dbc.mysql.message.server.ServerMessage;
Expand Down Expand Up @@ -74,6 +75,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Predicate;

Expand Down Expand Up @@ -209,36 +211,26 @@ static Mono<Client> login(Client client, SslMode sslMode, String database, Strin
* terminates with the last {@link CompleteMessage} or a {@link ErrorMessage}. The {@link ErrorMessage}
* will emit an exception. The exchange will be completed by {@link CompleteMessage} after receive the
* last result for the last binding.
* <p>
* Note: this method does not support {@code LOCAL INFILE} due to it should be used for excepted queries.
*
* @param client the {@link Client} to exchange messages with.
* @param sql the query to execute, can be contains multi-statements.
* @return receives complete signal.
*/
static Mono<Void> executeVoid(Client client, String sql) {
return Mono.defer(() -> execute0(client, sql).doOnNext(EXECUTE_VOID).then());
}
return Mono.defer(() -> client.<ServerMessage>exchange(new TextQueryMessage(sql), (message, sink) -> {
if (message instanceof ErrorMessage) {
sink.next(((ErrorMessage) message).offendedBy(sql));
sink.complete();
} else {
sink.next(message);

/**
* Execute multiple simple queries with one-by-one and return a {@link Mono} for the complete signal or
* error. Query execution terminates with the last {@link CompleteMessage} or a {@link ErrorMessage}. The
* {@link ErrorMessage} will emit an exception and cancel subsequent statements execution. The exchange
* will be completed by {@link CompleteMessage} after receive the last result for the last binding.
*
* @param client the {@link Client} to exchange messages with.
* @param statements the queries to execute, each element can be contains multi-statements.
* @return receives complete signal.
*/
static Mono<Void> executeVoid(Client client, String... statements) {
switch (statements.length) {
case 0:
return Mono.empty();
case 1:
return executeVoid(client, statements[0]);
default:
return client.exchange(new MultiQueryExchangeable(InternalArrays.asIterator(statements)))
.doOnNext(EXECUTE_VOID)
.then();
}
if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) {
sink.complete();
}
}
}).doOnSubscribe(ignored -> QueryLogger.log(sql)).doOnNext(EXECUTE_VOID).then());
}

/**
Expand Down Expand Up @@ -303,18 +295,7 @@ static Mono<Void> createSavepoint(Client client, ConnectionState state, String n
* @return the messages received in response to this exchange.
*/
private static Flux<ServerMessage> execute0(Client client, String sql) {
return client.<ServerMessage>exchange(new TextQueryMessage(sql), (message, sink) -> {
if (message instanceof ErrorMessage) {
sink.next(((ErrorMessage) message).offendedBy(sql));
sink.complete();
} else {
sink.next(message);

if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) {
sink.complete();
}
}
}).doOnSubscribe(ignored -> QueryLogger.log(sql));
return client.exchange(new SimpleQueryExchangeable(sql));
}

private QueryFlow() { }
Expand All @@ -339,6 +320,16 @@ public final void accept(ServerMessage message, SynchronousSink<ServerMessage> s
if (message instanceof ErrorMessage) {
sink.next(((ErrorMessage) message).offendedBy(offendingSql()));
sink.complete();
} else if (message instanceof LocalInfileRequest) {
LocalInfileRequest request = (LocalInfileRequest) message;
String path = request.getPath();

QueryLogger.logLocalInfile(path);

requests.emitNext(
new LocalInfileResponse(request.getEnvelopeId() + 1, path, sink),
Sinks.EmitFailureHandler.FAIL_FAST
);
} else {
sink.next(message);

Expand All @@ -353,6 +344,59 @@ public final void accept(ServerMessage message, SynchronousSink<ServerMessage> s
abstract protected String offendingSql();
}

final class SimpleQueryExchangeable extends BaseFluxExchangeable {

private static final int INIT = 0;

private static final int EXECUTE = 1;

private static final int DISPOSE = 2;

private final AtomicInteger state = new AtomicInteger(INIT);

private final String sql;

SimpleQueryExchangeable(String sql) {
this.sql = sql;
}

@Override
public void dispose() {
if (state.getAndSet(DISPOSE) != DISPOSE) {
requests.tryEmitComplete();
}
}

@Override
public boolean isDisposed() {
return state.get() == DISPOSE;
}

@Override
protected void tryNextOrComplete(@Nullable SynchronousSink<ServerMessage> sink) {
if (state.compareAndSet(INIT, EXECUTE)) {
QueryLogger.log(sql);

Sinks.EmitResult result = requests.tryEmitNext(new TextQueryMessage(sql));

if (result == Sinks.EmitResult.OK) {
return;
}

QueryFlow.logger.error("Emit request failed due to {}", result);
}

if (sink != null) {
sink.complete();
}
}

@Override
protected String offendingSql() {
return sql;
}
}

/**
* An implementation of {@link FluxExchangeable} that considers client-preparing requests.
*/
Expand Down Expand Up @@ -770,8 +814,8 @@ final class LoginExchangeable extends FluxExchangeable<Void> {

private static final int HANDSHAKE_VERSION = 10;

private final Sinks.Many<LoginClientMessage> requests = Sinks.many().unicast()
.onBackpressureBuffer(Queues.<LoginClientMessage>one().get());
private final Sinks.Many<SubsequenceClientMessage> requests = Sinks.many().unicast()
.onBackpressureBuffer(Queues.<SubsequenceClientMessage>one().get());

private final Client client;

Expand Down Expand Up @@ -879,7 +923,7 @@ public void dispose() {
this.requests.tryEmitComplete();
}

private void emitNext(LoginClientMessage message, SynchronousSink<Void> sink) {
private void emitNext(SubsequenceClientMessage message, SynchronousSink<Void> sink) {
Sinks.EmitResult result = requests.tryEmitNext(message);

if (result != Sinks.EmitResult.OK) {
Expand All @@ -903,8 +947,6 @@ private Capability clientCapability(Capability serverCapability) {

builder.disableDatabasePinned();
builder.disableCompression();
// TODO: support LOAD DATA LOCAL INFILE
builder.disableLoadDataInfile();
builder.disableIgnoreAmbiguitySpace();
builder.disableInteractiveTimeout();

Expand Down
4 changes: 4 additions & 0 deletions src/main/java/io/asyncer/r2dbc/mysql/QueryLogger.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,9 @@ static void log(int statementId, MySqlParameter[] values) {
logger.debug("Executing prepared statement {} with {}", statementId, values);
}

static void logLocalInfile(String path) {
logger.debug("Loading data from: {}", path);
}

private QueryLogger() { }
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ final class TextParametrizedStatement extends ParametrizedStatementSupport {

@Override
protected Flux<MySqlResult> execute(List<Binding> bindings) {
return Flux.defer(() -> QueryFlow.execute(client, query, returningIdentifiers(), bindings))
return Flux.defer(() -> QueryFlow.execute(client, query, returningIdentifiers(),
bindings))
.map(messages -> MySqlResult.toResult(false, codecs, context, syntheticKeyName(), messages));
}
}
4 changes: 2 additions & 2 deletions src/main/java/io/asyncer/r2dbc/mysql/TextSimpleStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ final class TextSimpleStatement extends SimpleStatementSupport {
public Flux<MySqlResult> execute() {
return Flux.defer(() -> QueryFlow.execute(
client,
StringUtils.extendReturning(sql, returningIdentifiers()))
).map(messages -> MySqlResult.toResult(false, codecs, context, syntheticKeyName(), messages));
StringUtils.extendReturning(sql, returningIdentifiers())
).map(messages -> MySqlResult.toResult(false, codecs, context, syntheticKeyName(), messages)));
}
}
20 changes: 10 additions & 10 deletions src/main/java/io/asyncer/r2dbc/mysql/client/MessageDuplexCodec.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import io.asyncer.r2dbc.mysql.ConnectionContext;
import io.asyncer.r2dbc.mysql.internal.util.OperatorUtils;
import io.asyncer.r2dbc.mysql.message.client.ClientMessage;
import io.asyncer.r2dbc.mysql.message.client.LoginClientMessage;
import io.asyncer.r2dbc.mysql.message.client.SubsequenceClientMessage;
import io.asyncer.r2dbc.mysql.message.client.PrepareQueryMessage;
import io.asyncer.r2dbc.mysql.message.client.PreparedFetchMessage;
import io.asyncer.r2dbc.mysql.message.client.SslRequest;
Expand Down Expand Up @@ -86,22 +86,22 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
if (msg instanceof ClientMessage) {
ByteBufAllocator allocator = ctx.alloc();

Flux<ByteBuf> encoded;
int envelopeId;

if (msg instanceof LoginClientMessage) {
LoginClientMessage message = (LoginClientMessage) msg;
if (msg instanceof SubsequenceClientMessage) {
SubsequenceClientMessage message = (SubsequenceClientMessage) msg;

encoded = Flux.from(message.encode(allocator, this.context));
envelopeId = message.getEnvelopeId();
int envelopeId = message.getEnvelopeId();

OperatorUtils.envelope(encoded, allocator, envelopeId, false)
.subscribe(new WriteSubscriber(ctx, promise));
} else {
encoded = Flux.from(((ClientMessage) msg).encode(allocator, this.context));
envelopeId = 0;
}

OperatorUtils.cumulateEnvelope(encoded, allocator, envelopeId)
.subscribe(new WriteSubscriber(ctx, promise));
OperatorUtils.envelope(encoded, allocator, 0, true)
.subscribe(new WriteSubscriber(ctx, promise));
}

if (msg instanceof PrepareQueryMessage) {
setDecodeContext(DecodeContext.prepareQuery());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ public <T> Flux<T> exchange(FluxExchangeable<T> exchangeable) {
.asFlux()
.doOnSubscribe(ignored -> exchangeable.subscribe(
this::emitNextRequest,
e -> requests.emitError(e, Sinks.EmitFailureHandler.FAIL_FAST))
e ->
requests.emitError(e, Sinks.EmitFailureHandler.FAIL_FAST))
)
.handle(exchangeable)
.doOnTerminate(() -> {
Expand Down
Loading

0 comments on commit 9f99c07

Please sign in to comment.