Skip to content

Commit

Permalink
Add support for SQL mode NO_BACKSLASH_ESCAPES
Browse files Browse the repository at this point in the history
  • Loading branch information
mirromutth committed Apr 2, 2024
1 parent e97dc7d commit faa30a7
Show file tree
Hide file tree
Showing 11 changed files with 140 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ public boolean isMariaDb() {
return (capability != null && capability.isMariaDb()) || serverVersion.isMariaDb();
}

public boolean isNoBackslashEscapes() {
return (serverStatuses & ServerStatuses.NO_BACKSLASH_ESCAPES) != 0;
}

@Override
public ZeroDateOption getZeroDateOption() {
return zeroDateOption;
Expand Down
42 changes: 21 additions & 21 deletions r2dbc-mysql/src/main/java/io/asyncer/r2dbc/mysql/InitFlow.java
Original file line number Diff line number Diff line change
Expand Up @@ -209,31 +209,31 @@ private static Mono<SessionState> loadAndInitInnoDbEngineStatus(
Codecs codecs,
@Nullable Duration lockWaitTimeout
) {
return new TextSimpleStatement(client, codecs, "SHOW VARIABLES LIKE 'innodb\\\\_lock\\\\_wait\\\\_timeout'")
.execute()
.flatMap(r -> r.map(readable -> {
String value = readable.get(1, String.class);
return new TextSimpleStatement(
client,
codecs,
"SHOW VARIABLES LIKE 'innodb_lock_wait_timeout'"
).execute().flatMap(r -> r.map(readable -> {
String value = readable.get(1, String.class);

if (value == null || value.isEmpty()) {
return data;
} else {
return data.lockWaitTimeout(Duration.ofSeconds(Long.parseLong(value)));
if (value == null || value.isEmpty()) {
return data;
} else {
return data.lockWaitTimeout(Duration.ofSeconds(Long.parseLong(value)));
}
})).single(data).flatMap(d -> {
if (lockWaitTimeout != null) {
// Do not use context.isLockWaitTimeoutSupported() here, because its session variable is not set
if (d.lockWaitTimeoutSupported) {
return QueryFlow.executeVoid(client, StringUtils.lockWaitTimeoutStatement(lockWaitTimeout))
.then(Mono.fromSupplier(() -> d.lockWaitTimeout(lockWaitTimeout)));
}
}))
.single(data)
.flatMap(d -> {
if (lockWaitTimeout != null) {
// Do not use context.isLockWaitTimeoutSupported() here, because its session variable is not set
if (d.lockWaitTimeoutSupported) {
return QueryFlow.executeVoid(client, StringUtils.lockWaitTimeoutStatement(lockWaitTimeout))
.then(Mono.fromSupplier(() -> d.lockWaitTimeout(lockWaitTimeout)));
}

logger.warn("Lock wait timeout is not supported by server, ignore initial setting");
return Mono.just(d);
}
logger.warn("Lock wait timeout is not supported by server, ignore initial setting");
return Mono.just(d);
});
}
return Mono.just(d);
});
}

private static Mono<SessionState> loadSessionVariables(Client client, Codecs codecs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,20 @@ public final class ServerStatuses {
public static final short LAST_ROW_SENT = 128;

// public static final short DB_DROPPED = 256;
// public static final short NO_BACKSLASH_ESCAPES = 512;

/**
* Server does not permit backslash escapes.
*
* @since 1.1.3
*/
public static final short NO_BACKSLASH_ESCAPES = 512;

// public static final short METADATA_CHANGED = 1024;
// public static final short QUERY_WAS_SLOW = 2048;
// public static final short PS_OUT_PARAMS = 4096;
// public static final short IN_TRANS_READONLY = 8192;
// public static final short SESSION_STATE_CHANGED = 16384;

private ServerStatuses() { }
private ServerStatuses() {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,17 @@ final class ParamWriter extends ParameterWriter {

private final StringBuilder builder;

private final boolean noBackslashEscapes;

private final Query query;

private int index;

private Mode mode;

private ParamWriter(Query query) {
private ParamWriter(boolean noBackslashEscapes, Query query) {
this.builder = newBuilder(query);
this.noBackslashEscapes = noBackslashEscapes;
this.query = query;
this.index = 1;
this.mode = 1 < query.getPartSize() ? Mode.AVAILABLE : Mode.FULL;
Expand Down Expand Up @@ -318,15 +321,19 @@ private void write0(char[] s, int off, int len) {
}

private void escape(char c) {
if (c == '\'') {
// MySQL will auto-combine consecutive strings, whatever backslash is used or not, e.g. '1''2' -> '1\'2'
builder.append('\'').append('\'');
return;
} else if (noBackslashEscapes) {
builder.append(c);
return;
}

switch (c) {
case '\\':
builder.append('\\').append('\\');
break;
case '\'':
// MySQL will auto-combine consecutive strings, like '1''2' -> '12'.
// Sure, there can use '1\'2', but this will be better. (For some logging systems)
builder.append('\'').append('\'');
break;
// Maybe useful in the future, keep '"' here.
// case '"': buf.append('\\').append('"'); break;
// SHIFT-JIS, WINDOWS-932, EUC-JP and eucJP-OPEN will encode '\u00a5' (the sign of Japanese Yen
Expand All @@ -335,20 +342,19 @@ private void escape(char c) {
// case '\u00a5': do something; break;
// case '\u20a9': do something; break;
case 0:
// MySQL is based on C/C++, must escape '\0' which is an end flag in C style string.
// Should escape '\0' which is an end flag in C style string.
builder.append('\\').append('0');
break;
case '\032':
// It seems like a problem on Windows 32, maybe check current OS here?
// It gives some problems on Win32.
builder.append('\\').append('Z');
break;
case '\n':
// Should escape it for some logging such as Relational Database Service (RDS) Logging
// System, etc. Sure, it is not necessary, but this will be better.
// Should be escaped for better logging.
builder.append('\\').append('n');
break;
case '\r':
// Should escape it for some logging such as RDS Logging System, etc.
// Should be escaped for better logging.
builder.append('\\').append('r');
break;
default:
Expand All @@ -357,9 +363,9 @@ private void escape(char c) {
}
}

static Mono<String> publish(Query query, Flux<MySqlParameter> values) {
static Mono<String> publish(boolean noBackslashEscapes, Query query, Flux<MySqlParameter> values) {
return Mono.defer(() -> {
ParamWriter writer = new ParamWriter(query);
ParamWriter writer = new ParamWriter(noBackslashEscapes, query);

return OperatorUtils.discardOnCancel(values)
.doOnDiscard(MySqlParameter.class, DISPOSE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public Mono<ByteBuf> encode(ByteBufAllocator allocator, ConnectionContext contex
return Flux.fromArray(values);
});

return ParamWriter.publish(query, parameters).handle((it, sink) -> {
return ParamWriter.publish(context.isNoBackslashEscapes(), query, parameters).handle((it, sink) -> {
ByteBuf buf = allocator.buffer();

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import java.util.Collections;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Stream;

import static io.r2dbc.spi.IsolationLevel.READ_COMMITTED;
import static io.r2dbc.spi.IsolationLevel.READ_UNCOMMITTED;
Expand Down Expand Up @@ -80,6 +81,53 @@ void isInTransaction() {
.doOnSuccess(ignored -> assertThat(connection.context().isInTransaction()).isFalse()));
}

@ParameterizedTest
@ValueSource(strings = {
"test",
"test`data",
"test\ndata",
"I'm feeling good",
})
void sqlModeNoBackslashEscapes(String value) {
String tdl = "CREATE TEMPORARY TABLE `test` (`id` INT NOT NULL PRIMARY KEY, `value` VARCHAR(50) NOT NULL)";

// Add NO_BACKSLASH_ESCAPES instead of replace
castedComplete(connection -> Mono.fromRunnable(() -> assertThat(connection.context().isNoBackslashEscapes())
.isFalse())
.thenMany(connection.createStatement(tdl).execute())
.flatMap(MySqlResult::getRowsUpdated)
.thenMany(connection.createStatement("INSERT INTO test VALUES (1, ?)")
.bind(0, value)
.execute())
.flatMap(MySqlResult::getRowsUpdated)
.thenMany(connection.createStatement("SELECT COUNT(0) FROM `test` WHERE `value` = ?")
.bind(0, value)
.execute())
.flatMap(result -> result.map((row, metadata) -> row.get(0, Integer.class)))
.collectList()
.doOnNext(counts -> assertThat(counts).isEqualTo(Collections.singletonList(1)))
.thenMany(connection.createStatement("SELECT @@sql_mode").execute())
.flatMap(result -> result.map((row, metadata) -> row.get(0, String.class)))
.map(modes -> Stream.concat(Stream.of(modes.split(",")), Stream.of("NO_BACKSLASH_ESCAPES"))
.toArray(String[]::new))
.last()
.flatMapMany(modes -> connection.createStatement("SET sql_mode = ?")
.bind(0, modes)
.execute())
.flatMap(MySqlResult::getRowsUpdated)
.doOnComplete(() -> assertThat(connection.context().isNoBackslashEscapes()).isTrue())
.thenMany(connection.createStatement("INSERT INTO test VALUES (2, ?)")
.bind(0, value)
.execute())
.flatMap(MySqlResult::getRowsUpdated)
.thenMany(connection.createStatement("SELECT COUNT(0) FROM `test` WHERE `value` = ?")
.bind(0, value)
.execute())
.flatMap(result -> result.map((row, metadata) -> row.get(0, Integer.class)))
.collectList()
.doOnNext(counts -> assertThat(counts).isEqualTo(Collections.singletonList(2))));
}

@DisabledIf("envIsLessThanMySql56")
@Test
void startTransaction() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ void insertOnDuplicate() {
.bind(2, 20)
.execute())
.flatMap(IntegrationTestSupport::extractRowsUpdated)
.doOnNext(it -> assertThat(it).isOne()) // TODO: check capability flag
.doOnNext(it -> assertThat(it).isOne())
.thenMany(connection.createStatement("SELECT value FROM test WHERE id=?")
.bind(0, 1)
.execute())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ default void encodeStringify() {
Query query = Query.parse("?");

for (int i = 0; i < origin.length; ++i) {
ParameterWriter writer = ParameterWriterHelper.get(query);
ParameterWriter writer = ParameterWriterHelper.get(false, query);
codec.encode(origin[i], context())
.publishText(writer)
.as(StepVerifier::create)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ void stringifySet() {
Query query = Query.parse("?");

for (int i = 0; i < sets.length; ++i) {
ParameterWriter writer = ParameterWriterHelper.get(query);
ParameterWriter writer = ParameterWriterHelper.get(false, query);
codec.encode(sets[i], context())
.publishText(writer)
.as(StepVerifier::create)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import io.asyncer.r2dbc.mysql.Query;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import reactor.core.publisher.Flux;
import reactor.test.StepVerifier;

Expand Down Expand Up @@ -84,42 +86,42 @@ void badFollowNull() {

@Test
void appendPart() {
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(parameterOnly(1));
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(false, parameterOnly(1));
writer.append("define", 2, 5);
assertThat(ParameterWriterHelper.toSql(writer)).isEqualTo("'fin'");
}

@Test
void writePart() {
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(parameterOnly(1));
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(false, parameterOnly(1));
writer.write("define", 2, 3);
assertThat(ParameterWriterHelper.toSql(writer)).isEqualTo("'fin'");
}

@Test
void appendNull() {
assertThat(ParameterWriterHelper.toSql(ParameterWriterHelper.get(parameterOnly(1)).append(null)))
assertThat(ParameterWriterHelper.toSql(ParameterWriterHelper.get(false, parameterOnly(1)).append(null)))
.isEqualTo("'null'");
assertThat(ParameterWriterHelper.toSql(ParameterWriterHelper.get(parameterOnly(1))
assertThat(ParameterWriterHelper.toSql(ParameterWriterHelper.get(false, parameterOnly(1))
.append(null, 1, 3)))
.isEqualTo("'ul'");
}

@Test
void writeNull() {
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(parameterOnly(1));
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(false, parameterOnly(1));
writer.write((String) null);
assertThat(ParameterWriterHelper.toSql(writer)).isEqualTo("'null'");

writer = (ParamWriter) ParameterWriterHelper.get(parameterOnly(1));
writer = (ParamWriter) ParameterWriterHelper.get(false, parameterOnly(1));
writer.write((String) null, 1, 2);
assertThat(ParameterWriterHelper.toSql(writer)).isEqualTo("'ul'");

writer = (ParamWriter) ParameterWriterHelper.get(parameterOnly(1));
writer = (ParamWriter) ParameterWriterHelper.get(false, parameterOnly(1));
writer.write((char[]) null);
assertThat(ParameterWriterHelper.toSql(writer)).isEqualTo("'null'");

writer = (ParamWriter) ParameterWriterHelper.get(parameterOnly(1));
writer = (ParamWriter) ParameterWriterHelper.get(false, parameterOnly(1));
writer.write((char[]) null, 1, 2);
assertThat(ParameterWriterHelper.toSql(writer)).isEqualTo("'ul'");
}
Expand All @@ -132,7 +134,7 @@ void publishSuccess() {
values[i] = new MockMySqlParameter(true);
}

Flux.from(ParamWriter.publish(parameterOnly(SIZE), Flux.fromArray(values)))
Flux.from(ParamWriter.publish(false, parameterOnly(SIZE), Flux.fromArray(values)))
.as(StepVerifier::create)
.expectNext(new String(new char[SIZE]).replace("\0", "''"))
.verifyComplete();
Expand All @@ -154,7 +156,7 @@ void publishPartially() {
values[i] = new MockMySqlParameter(false);
}

Flux.from(ParamWriter.publish(parameterOnly(SIZE), Flux.fromArray(values)))
Flux.from(ParamWriter.publish(false, parameterOnly(SIZE), Flux.fromArray(values)))
.as(StepVerifier::create)
.verifyError(MockException.class);

Expand All @@ -169,13 +171,30 @@ void publishNothing() {
values[i] = new MockMySqlParameter(false);
}

Flux.from(ParamWriter.publish(parameterOnly(SIZE), Flux.fromArray(values)))
Flux.from(ParamWriter.publish(false, parameterOnly(SIZE), Flux.fromArray(values)))
.as(StepVerifier::create)
.verifyError(MockException.class);

assertThat(values).extracting(MockMySqlParameter::refCnt).containsOnly(0);
}

@ParameterizedTest
@ValueSource(strings = {
"abc",
"a'b'c",
"a\nb\rc",
"a\"b\"c",
"a\\b\\c",
"a\0b\0c",
"a\u00a5b\u20a9c",
"a\032b\032c",
})
void noBackslashEscapes(String value) {
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(true, parameterOnly(1));
writer.write(value);
assertThat(ParameterWriterHelper.toSql(writer)).isEqualTo("'" + value.replaceAll("'", "''") + "'");
}

private static Query parameterOnly(int parameters) {
char[] chars = new char[parameters];
Arrays.fill(chars, '?');
Expand All @@ -184,13 +203,13 @@ private static Query parameterOnly(int parameters) {
}

private static ParamWriter stringWriter() {
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(parameterOnly(1));
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(false, parameterOnly(1));
writer.write('0');
return writer;
}

private static ParamWriter nullWriter() {
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(parameterOnly(1));
ParamWriter writer = (ParamWriter) ParameterWriterHelper.get(false, parameterOnly(1));
writer.writeNull();
return writer;
}
Expand Down
Loading

0 comments on commit faa30a7

Please sign in to comment.