Skip to content

Commit

Permalink
Support tinyInt1isBit
Browse files Browse the repository at this point in the history
Motivation:
Currently, the R2DBC MySQL connector does not  support the `tinyInt1isBit` MySQL parameter.

Modifications:
Updated the mapping of `TINYINT(1)` or `BIT(1)` values to boolean based on the `tinyInt1isBit` parameter.

Result:
Properly supports the `tinyInt1isBit` parameter.
Resolves #277
  • Loading branch information
jchrys committed Jul 28, 2024
1 parent e37cbdd commit e061056
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ public final class ConnectionContext implements CodecContext {

private final boolean preserveInstants;

private final boolean tinyInt1isBit;

private int connectionId = -1;

private ServerVersion serverVersion = NONE_VERSION;
Expand Down Expand Up @@ -104,16 +106,18 @@ public final class ConnectionContext implements CodecContext {
private volatile short serverStatuses = ServerStatuses.AUTO_COMMIT;

ConnectionContext(
ZeroDateOption zeroDateOption,
@Nullable Path localInfilePath,
int localInfileBufferSize,
boolean preserveInstants,
@Nullable ZoneId timeZone
ZeroDateOption zeroDateOption,
@Nullable Path localInfilePath,
int localInfileBufferSize,
boolean preserveInstants,
boolean tinyInt1isBit,
@Nullable ZoneId timeZone
) {
this.zeroDateOption = requireNonNull(zeroDateOption, "zeroDateOption must not be null");
this.localInfilePath = localInfilePath;
this.localInfileBufferSize = localInfileBufferSize;
this.preserveInstants = preserveInstants;
this.tinyInt1isBit = tinyInt1isBit;
this.timeZone = timeZone;
}

Expand Down Expand Up @@ -333,4 +337,9 @@ boolean isAutoCommit() {
return (serverStatuses & ServerStatuses.IN_TRANSACTION) == 0 &&
(serverStatuses & ServerStatuses.AUTO_COMMIT) != 0;
}

@Override
public boolean isTinyInt1isBit() {
return tinyInt1isBit;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,22 +131,24 @@ public final class MySqlConnectionConfiguration {
@Nullable
private final AddressResolverGroup<?> resolver;

private final boolean tinyInt1isBit;

private MySqlConnectionConfiguration(
boolean isHost, String domain, int port, MySqlSslConfiguration ssl,
boolean tcpKeepAlive, boolean tcpNoDelay, @Nullable Duration connectTimeout,
ZeroDateOption zeroDateOption,
boolean preserveInstants,
String connectionTimeZone,
boolean forceConnectionTimeZoneToSession,
String user, @Nullable CharSequence password, @Nullable String database,
boolean createDatabaseIfNotExist, @Nullable Predicate<String> preferPrepareStatement,
List<String> sessionVariables, @Nullable Duration lockWaitTimeout, @Nullable Duration statementTimeout,
@Nullable Path loadLocalInfilePath, int localInfileBufferSize,
int queryCacheSize, int prepareCacheSize,
Set<CompressionAlgorithm> compressionAlgorithms, int zstdCompressionLevel,
@Nullable LoopResources loopResources,
Extensions extensions, @Nullable Publisher<String> passwordPublisher,
@Nullable AddressResolverGroup<?> resolver
boolean isHost, String domain, int port, MySqlSslConfiguration ssl,
boolean tcpKeepAlive, boolean tcpNoDelay, @Nullable Duration connectTimeout,
ZeroDateOption zeroDateOption,
boolean preserveInstants,
String connectionTimeZone,
boolean forceConnectionTimeZoneToSession,
String user, @Nullable CharSequence password, @Nullable String database,
boolean createDatabaseIfNotExist, @Nullable Predicate<String> preferPrepareStatement,
List<String> sessionVariables, @Nullable Duration lockWaitTimeout, @Nullable Duration statementTimeout,
@Nullable Path loadLocalInfilePath, int localInfileBufferSize,
int queryCacheSize, int prepareCacheSize,
Set<CompressionAlgorithm> compressionAlgorithms, int zstdCompressionLevel,
@Nullable LoopResources loopResources,
Extensions extensions, @Nullable Publisher<String> passwordPublisher,
@Nullable AddressResolverGroup<?> resolver, boolean tinyInt1isBit
) {
this.isHost = isHost;
this.domain = domain;
Expand Down Expand Up @@ -177,6 +179,7 @@ private MySqlConnectionConfiguration(
this.extensions = extensions;
this.passwordPublisher = passwordPublisher;
this.resolver = resolver;
this.tinyInt1isBit = tinyInt1isBit;
}

/**
Expand Down Expand Up @@ -312,6 +315,10 @@ AddressResolverGroup<?> getResolver() {
return resolver;
}

boolean getTinyInt1isBit() {
return tinyInt1isBit;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down Expand Up @@ -498,6 +505,8 @@ public static final class Builder {
@Nullable
private AddressResolverGroup<?> resolver;

private boolean tinyInt1isBit = true;

/**
* Builds an immutable {@link MySqlConnectionConfiguration} with current options.
*
Expand Down Expand Up @@ -532,7 +541,7 @@ public MySqlConnectionConfiguration build() {
loadLocalInfilePath,
localInfileBufferSize, queryCacheSize, prepareCacheSize,
compressionAlgorithms, zstdCompressionLevel, loopResources,
Extensions.from(extensions, autodetectExtensions), passwordPublisher, resolver);
Extensions.from(extensions, autodetectExtensions), passwordPublisher, resolver, tinyInt1isBit);
}

/**
Expand Down Expand Up @@ -1175,6 +1184,16 @@ public Builder resolver(AddressResolverGroup<?> resolver) {
return this;
}

/**
* Sets
* @param tinyInt1isBit
* @return
*/
public Builder tinyInt1isBit(boolean tinyInt1isBit) {
this.tinyInt1isBit = tinyInt1isBit;
return this;
}

private SslMode requireSslMode() {
SslMode sslMode = this.sslMode;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ private static Mono<MySqlConnection> getMySqlConnection(
configuration.getLoadLocalInfilePath(),
configuration.getLocalInfileBufferSize(),
configuration.isPreserveInstants(),
configuration.getTinyInt1isBit(),
connectionTimeZone
);
}).flatMap(context -> Client.connect(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,13 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr
*/
public static final Option<AddressResolverGroup<?>> RESOLVER = Option.valueOf("resolver");

/**
*
* @param options
* @return
*/
public static final Option<Boolean> TINY_INT_1_IS_BIT = Option.valueOf("tinyInt1isBit");

@Override
public ConnectionFactory create(ConnectionFactoryOptions options) {
requireNonNull(options, "connectionFactoryOptions must not be null");
Expand Down Expand Up @@ -413,6 +420,8 @@ static MySqlConnectionConfiguration setup(ConnectionFactoryOptions options) {
.to(builder::lockWaitTimeout);
mapper.optional(STATEMENT_TIMEOUT).as(Duration.class, Duration::parse)
.to(builder::statementTimeout);
mapper.optional(TINY_INT_1_IS_BIT).asBoolean()
.to(builder::tinyInt1isBit);

return builder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public MySqlParameter encode(Object value, CodecContext context) {
@Override
public boolean doCanDecode(MySqlReadableMetadata metadata) {
MySqlType type = metadata.getType();
return (type == MySqlType.BIT || type == MySqlType.TINYINT) &&
return (type == MySqlType.BIT || type == MySqlType.TINYINT || type == MySqlType.TINYINT_UNSIGNED) &&
Integer.valueOf(1).equals(metadata.getPrecision());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ public interface CodecContext {
*/
boolean isPreserveInstants();


boolean isTinyInt1isBit();

/**
* Gets the {@link ZoneId} of connection.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import io.asyncer.r2dbc.mysql.MySqlParameter;
import io.asyncer.r2dbc.mysql.api.MySqlReadableMetadata;
import io.asyncer.r2dbc.mysql.constant.MySqlType;
import io.asyncer.r2dbc.mysql.internal.util.InternalArrays;
import io.asyncer.r2dbc.mysql.message.FieldValue;
import io.asyncer.r2dbc.mysql.message.LargeFieldValue;
Expand Down Expand Up @@ -150,11 +151,12 @@ public <T> T decode(FieldValue value, MySqlReadableMetadata metadata, Class<?> t
return null;
}

Class<?> target = chooseClass(metadata, type);
Class<?> target = chooseClass(metadata, type, context);

if (value instanceof NormalFieldValue) {
return decodeNormal((NormalFieldValue) value, metadata, target, binary, context);
} else if (value instanceof LargeFieldValue) {
}
if (value instanceof LargeFieldValue) {
return decodeMassive((LargeFieldValue) value, metadata, target, binary, context);
}

Expand All @@ -171,9 +173,11 @@ public <T> T decode(FieldValue value, MySqlReadableMetadata metadata, Parameteri

if (value.isNull()) {
return null;
} else if (value instanceof NormalFieldValue) {
}
if (value instanceof NormalFieldValue) {
return decodeNormal((NormalFieldValue) value, metadata, type, binary, context);
} else if (value instanceof LargeFieldValue) {
}
if (value instanceof LargeFieldValue) {
return decodeMassive((LargeFieldValue) value, metadata, type, binary, context);
}

Expand Down Expand Up @@ -358,11 +362,21 @@ private <T> T decodeMassive(LargeFieldValue value, MySqlReadableMetadata metadat
* @param type the {@link Class} specified by the user.
* @return the {@link Class} to use for decoding.
*/
private static Class<?> chooseClass(MySqlReadableMetadata metadata, Class<?> type) {
Class<?> javaType = metadata.getType().getJavaType();
private static Class<?> chooseClass(MySqlReadableMetadata metadata, Class<?> type, CodecContext context) {
Class<?> javaType = resolveJavaType(metadata, context);
return type.isAssignableFrom(javaType) ? javaType : type;
}

private static Class<?> resolveJavaType(final MySqlReadableMetadata metadata, final CodecContext context) {
final MySqlType mySqlType = metadata.getType();
final Integer precision = metadata.getPrecision();
if (precision != null && precision == 1 && context.isTinyInt1isBit() &&
(mySqlType == MySqlType.TINYINT || mySqlType == MySqlType.TINYINT_UNSIGNED || mySqlType == MySqlType.BIT)) {
return Boolean.class;
}
return metadata.getType().getJavaType();
}

static final class Builder implements CodecsBuilder {

@GuardedBy("lock")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ void getTimeZone() {
String id = i < 0 ? "UTC" + i : "UTC+" + i;
ConnectionContext context = new ConnectionContext(
ZeroDateOption.USE_NULL, null,
8192, true, ZoneId.of(id));
8192, true, true, ZoneId.of(id));

assertThat(context.getTimeZone()).isEqualTo(ZoneId.of(id));
}
Expand All @@ -48,7 +48,7 @@ void getTimeZone() {
@Test
void setTwiceTimeZone() {
ConnectionContext context = new ConnectionContext(ZeroDateOption.USE_NULL, null,
8192, true, null);
8192, true, true, null);

context.initSession(
Caches.createPrepareCache(0),
Expand All @@ -70,7 +70,7 @@ void setTwiceTimeZone() {
@Test
void badSetTimeZone() {
ConnectionContext context = new ConnectionContext(ZeroDateOption.USE_NULL, null,
8192, true, ZoneId.systemDefault());
8192, true, true, ZoneId.systemDefault());
assertThatIllegalStateException().isThrownBy(() -> context.initSession(
Caches.createPrepareCache(0),
IsolationLevel.REPEATABLE_READ,
Expand All @@ -91,7 +91,7 @@ public static ConnectionContext mock(boolean isMariaDB) {

public static ConnectionContext mock(boolean isMariaDB, ZoneId zoneId) {
ConnectionContext context = new ConnectionContext(ZeroDateOption.USE_NULL, null,
8192, true, zoneId);
8192, true, true, zoneId);

context.initHandshake(1, ServerVersion.parse(isMariaDB ? "11.2.22.MOCKED" : "8.0.11.MOCKED"),
Capability.of(~(isMariaDB ? 1 : 0)));
Expand Down

0 comments on commit e061056

Please sign in to comment.