Skip to content

Commit

Permalink
feat: support config AddressResolverGroup in r2dbc-mysql (#279)
Browse files Browse the repository at this point in the history
Motivation:
Currently,`AddressResolverGroup` can't be configured. The DnsResolver
default start address listen to "0.0.0.0", which may have some security
risks.
also see netty/netty#11061

Modification:
Add `AddressResolverGroup` in Client's connect method

---------

Signed-off-by: ZhangJian He <[email protected]>
  • Loading branch information
ZhangJian He authored Jul 25, 2024
1 parent 508d6c3 commit e37cbdd
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.asyncer.r2dbc.mysql.extension.Extension;
import io.asyncer.r2dbc.mysql.internal.util.InternalArrays;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.resolver.AddressResolverGroup;
import org.jetbrains.annotations.Nullable;
import org.reactivestreams.Publisher;
import reactor.netty.resources.LoopResources;
Expand Down Expand Up @@ -127,6 +128,9 @@ public final class MySqlConnectionConfiguration {
@Nullable
private final Publisher<String> passwordPublisher;

@Nullable
private final AddressResolverGroup<?> resolver;

private MySqlConnectionConfiguration(
boolean isHost, String domain, int port, MySqlSslConfiguration ssl,
boolean tcpKeepAlive, boolean tcpNoDelay, @Nullable Duration connectTimeout,
Expand All @@ -141,7 +145,8 @@ private MySqlConnectionConfiguration(
int queryCacheSize, int prepareCacheSize,
Set<CompressionAlgorithm> compressionAlgorithms, int zstdCompressionLevel,
@Nullable LoopResources loopResources,
Extensions extensions, @Nullable Publisher<String> passwordPublisher
Extensions extensions, @Nullable Publisher<String> passwordPublisher,
@Nullable AddressResolverGroup<?> resolver
) {
this.isHost = isHost;
this.domain = domain;
Expand Down Expand Up @@ -171,6 +176,7 @@ private MySqlConnectionConfiguration(
this.loopResources = loopResources == null ? TcpResources.get() : loopResources;
this.extensions = extensions;
this.passwordPublisher = passwordPublisher;
this.resolver = resolver;
}

/**
Expand Down Expand Up @@ -301,6 +307,11 @@ Publisher<String> getPasswordPublisher() {
return passwordPublisher;
}

@Nullable
AddressResolverGroup<?> getResolver() {
return resolver;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand Down Expand Up @@ -337,7 +348,8 @@ public boolean equals(Object o) {
zstdCompressionLevel == that.zstdCompressionLevel &&
Objects.equals(loopResources, that.loopResources) &&
extensions.equals(that.extensions) &&
Objects.equals(passwordPublisher, that.passwordPublisher);
Objects.equals(passwordPublisher, that.passwordPublisher) &&
Objects.equals(resolver, that.resolver);
}

@Override
Expand All @@ -352,52 +364,41 @@ public int hashCode() {
loadLocalInfilePath, localInfileBufferSize,
queryCacheSize, prepareCacheSize,
compressionAlgorithms, zstdCompressionLevel,
loopResources, extensions, passwordPublisher);
loopResources, extensions, passwordPublisher, resolver);
}

@Override
public String toString() {
if (isHost) {
return "MySqlConnectionConfiguration{host='" + domain + "', port=" + port + ", ssl=" + ssl +
", tcpNoDelay=" + tcpNoDelay + ", tcpKeepAlive=" + tcpKeepAlive +
", connectTimeout=" + connectTimeout +
return "MySqlConnectionConfiguration{" +
(isHost ? "host='" + domain + "', port=" + port + ", ssl=" + ssl +
", tcpNoDelay=" + tcpNoDelay + ", tcpKeepAlive=" + tcpKeepAlive :
"unixSocket='" + domain + "'") +
buildCommonToStringPart() +
'}';
}

private String buildCommonToStringPart() {
return ", connectTimeout=" + connectTimeout +
", preserveInstants=" + preserveInstants +
", connectionTimeZone=" + connectionTimeZone +
", forceConnectionTimeZoneToSession=" + forceConnectionTimeZoneToSession +
", zeroDateOption=" + zeroDateOption + ", user='" + user + "', password=" + password +
", zeroDateOption=" + zeroDateOption +
", user='" + user + "', password=" + password +
", database='" + database + "', createDatabaseIfNotExist=" + createDatabaseIfNotExist +
", preferPrepareStatement=" + preferPrepareStatement +
", sessionVariables=" + sessionVariables +
", lockWaitTimeout=" + lockWaitTimeout +
", statementTimeout=" + statementTimeout +
", loadLocalInfilePath=" + loadLocalInfilePath +
", localInfileBufferSize=" + localInfileBufferSize +
", queryCacheSize=" + queryCacheSize + ", prepareCacheSize=" + prepareCacheSize +
", queryCacheSize=" + queryCacheSize +
", prepareCacheSize=" + prepareCacheSize +
", compressionAlgorithms=" + compressionAlgorithms +
", zstdCompressionLevel=" + zstdCompressionLevel +
", loopResources=" + loopResources +
", extensions=" + extensions + ", passwordPublisher=" + passwordPublisher + '}';
}

return "MySqlConnectionConfiguration{unixSocket='" + domain +
"', connectTimeout=" + connectTimeout +
", preserveInstants=" + preserveInstants +
", connectionTimeZone=" + connectionTimeZone +
", forceConnectionTimeZoneToSession=" + forceConnectionTimeZoneToSession +
", zeroDateOption=" + zeroDateOption + ", user='" + user + "', password=" + password +
", database='" + database + "', createDatabaseIfNotExist=" + createDatabaseIfNotExist +
", preferPrepareStatement=" + preferPrepareStatement +
", sessionVariables=" + sessionVariables +
", lockWaitTimeout=" + lockWaitTimeout +
", statementTimeout=" + statementTimeout +
", loadLocalInfilePath=" + loadLocalInfilePath +
", localInfileBufferSize=" + localInfileBufferSize +
", queryCacheSize=" + queryCacheSize +
", prepareCacheSize=" + prepareCacheSize +
", compressionAlgorithms=" + compressionAlgorithms +
", zstdCompressionLevel=" + zstdCompressionLevel +
", loopResources=" + loopResources +
", extensions=" + extensions + ", passwordPublisher=" + passwordPublisher + '}';
", extensions=" + extensions +
", passwordPublisher=" + passwordPublisher +
", resolver=" + resolver;
}

/**
Expand Down Expand Up @@ -494,6 +495,9 @@ public static final class Builder {
@Nullable
private Publisher<String> passwordPublisher;

@Nullable
private AddressResolverGroup<?> resolver;

/**
* Builds an immutable {@link MySqlConnectionConfiguration} with current options.
*
Expand Down Expand Up @@ -528,7 +532,7 @@ public MySqlConnectionConfiguration build() {
loadLocalInfilePath,
localInfileBufferSize, queryCacheSize, prepareCacheSize,
compressionAlgorithms, zstdCompressionLevel, loopResources,
Extensions.from(extensions, autodetectExtensions), passwordPublisher);
Extensions.from(extensions, autodetectExtensions), passwordPublisher, resolver);
}

/**
Expand Down Expand Up @@ -1156,6 +1160,21 @@ public Builder passwordPublisher(Publisher<String> passwordPublisher) {
return this;
}

/**
* Sets the {@link AddressResolverGroup} for resolving host addresses.
* <p>
* This can be used to customize the DNS resolution mechanism, which is particularly useful in environments
* with specific DNS configuration needs or where a custom DNS resolver is required.
*
* @param resolver the resolver group to use for host address resolution.
* @return this {@link Builder}.
* @since 1.2.0
*/
public Builder resolver(AddressResolverGroup<?> resolver) {
this.resolver = resolver;
return this;
}

private SslMode requireSslMode() {
SslMode sslMode = this.sslMode;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ private static Mono<MySqlConnection> getMySqlConnection(
configuration.isTcpNoDelay(),
context,
configuration.getConnectTimeout(),
configuration.getLoopResources()
configuration.getLoopResources(),
configuration.getResolver()
)).flatMap(client -> {
// Lazy init database after handshake/login
boolean deferDatabase = configuration.isCreateDatabaseIfNotExist();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.asyncer.r2dbc.mysql.constant.SslMode;
import io.asyncer.r2dbc.mysql.constant.ZeroDateOption;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.resolver.AddressResolverGroup;
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.ConnectionFactoryOptions;
import io.r2dbc.spi.ConnectionFactoryProvider;
Expand Down Expand Up @@ -308,6 +309,17 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr
*/
public static final Option<Publisher<String>> PASSWORD_PUBLISHER = Option.valueOf("passwordPublisher");

/**
* Option to set the {@link AddressResolverGroup} for resolving host addresses.
* <p>
* This can be used to customize the DNS resolution mechanism, which is particularly useful in environments
* with specific DNS configuration needs or where a custom DNS resolver is required.
* <p>
*
* @since 1.2.0
*/
public static final Option<AddressResolverGroup<?>> RESOLVER = Option.valueOf("resolver");

@Override
public ConnectionFactory create(ConnectionFactoryOptions options) {
requireNonNull(options, "connectionFactoryOptions must not be null");
Expand Down Expand Up @@ -389,6 +401,8 @@ static MySqlConnectionConfiguration setup(ConnectionFactoryOptions options) {
.to(builder::loopResources);
mapper.optional(PASSWORD_PUBLISHER).as(Publisher.class)
.to(builder::passwordPublisher);
mapper.optional(RESOLVER).as(AddressResolverGroup.class)
.to(builder::resolver);
mapper.optional(SESSION_VARIABLES).asArray(
String[].class,
Function.identity(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.asyncer.r2dbc.mysql.message.server.ServerMessage;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelOption;
import io.netty.resolver.AddressResolverGroup;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import org.jetbrains.annotations.Nullable;
Expand Down Expand Up @@ -132,7 +133,7 @@ public interface Client {
*/
static Mono<Client> connect(MySqlSslConfiguration ssl, SocketAddress address, boolean tcpKeepAlive,
boolean tcpNoDelay, ConnectionContext context, @Nullable Duration connectTimeout,
LoopResources loopResources) {
LoopResources loopResources, @Nullable AddressResolverGroup<?> resolver) {
requireNonNull(ssl, "ssl must not be null");
requireNonNull(address, "address must not be null");
requireNonNull(context, "context must not be null");
Expand All @@ -150,6 +151,10 @@ static Mono<Client> connect(MySqlSslConfiguration ssl, SocketAddress address, bo
tcpClient = tcpClient.option(ChannelOption.TCP_NODELAY, tcpNoDelay);
}

if (resolver != null) {
tcpClient = tcpClient.resolver(resolver);
}

return tcpClient.remoteAddress(() -> address).connect()
.map(conn -> new ReactorNettyClient(conn, ssl, context));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import io.asyncer.r2dbc.mysql.constant.ZeroDateOption;
import io.asyncer.r2dbc.mysql.extension.Extension;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.resolver.AddressResolverGroup;
import io.netty.resolver.DefaultAddressResolverGroup;
import org.assertj.core.api.ObjectAssert;
import org.assertj.core.api.ThrowableTypeAssert;
import org.jetbrains.annotations.Nullable;
Expand Down Expand Up @@ -207,6 +209,19 @@ void validPasswordSupplier() {
.verifyComplete();
}

@Test
void validResolver() {
final AddressResolverGroup<?> resolver = DefaultAddressResolverGroup.INSTANCE;
AddressResolverGroup<?> resolverGroup = MySqlConnectionConfiguration.builder()
.host(HOST)
.user(USER)
.resolver(resolver)
.autodetectExtensions(false)
.build()
.getResolver();
assertThat(resolverGroup).isSameAs(resolver);
}

private static MySqlConnectionConfiguration unixSocketSslMode(SslMode sslMode) {
return MySqlConnectionConfiguration.builder()
.unixSocket(UNIX_SOCKET)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import io.asyncer.r2dbc.mysql.constant.SslMode;
import io.asyncer.r2dbc.mysql.constant.ZeroDateOption;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.resolver.AddressResolverGroup;
import io.netty.resolver.DefaultAddressResolverGroup;
import io.r2dbc.spi.ConnectionFactories;
import io.r2dbc.spi.ConnectionFactoryOptions;
import io.r2dbc.spi.Option;
Expand Down Expand Up @@ -50,6 +52,7 @@
import java.util.stream.Stream;

import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.PASSWORD_PUBLISHER;
import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.RESOLVER;
import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.USE_SERVER_PREPARE_STATEMENT;
import static io.r2dbc.spi.ConnectionFactoryOptions.CONNECT_TIMEOUT;
import static io.r2dbc.spi.ConnectionFactoryOptions.DATABASE;
Expand Down Expand Up @@ -453,6 +456,19 @@ void validPasswordSupplier() {
assertThat(ConnectionFactories.get(options)).isExactlyInstanceOf(MySqlConnectionFactory.class);
}

@Test
void validResolver() {
final AddressResolverGroup<?> resolver = DefaultAddressResolverGroup.INSTANCE;
ConnectionFactoryOptions options = ConnectionFactoryOptions.builder()
.option(DRIVER, "mysql")
.option(HOST, "127.0.0.1")
.option(USER, "root")
.option(RESOLVER, resolver)
.build();

assertThat(ConnectionFactories.get(options)).isExactlyInstanceOf(MySqlConnectionFactory.class);
}

@Test
void allConfigurationOptions() {
List<String> exceptConfigs = Arrays.asList(
Expand Down

0 comments on commit e37cbdd

Please sign in to comment.