From eff243d3ee31253377abe46f3debd7cae4dbc619 Mon Sep 17 00:00:00 2001 From: ggivo Date: Fri, 20 Dec 2024 21:39:11 +0200 Subject: [PATCH] Token based authentication integration with core extension (#3063) * Support for StreamingCredentials This enables use cases like credential rotation and token based auth without client disconnect. Especially with Pub/Sub clients will reduce the chnance of missing events. * Tests & publish ReauthEvent * Clean up & Format & Add ReauthenticateEvent test * Conditionally enable connection reauthentication based on client setting DEFAULT_REAUTHENTICATE_BEHAVIOUR * Client setting for enabling reauthentication - Moved Authentication handler to DefaultEndpoint - updated since 6.6.0 * formating * resolve conflict with main * format * dispath using connection handler * Support multi with re-auth Defer the re-auth operation in case there is on-going multi Tx in lettuce need to be externally synchronised when used in multithreaded env. Since re-auth happens from different thread we need to make sure it does not happen while there is ongoing transaction. * Fix EndpointId missing in events * format * Add unit tests for setCredenatials * Skip preProcessing of auth command to avoid replacing the credential provider with static one provider Add unit tests for setCredentials * clean up - remove dead code * Moved almost all code inside the new handler * fix inTransaction lock with dispatch command batch * Remove StreamingCredentialsProvider interface. move credentials() method to RedisCredentialsProvider. Resolve issue with unsafe cast after extending RedisCredentialsProvider with supportsStreaming() method * Add authentication handler to ClusterPubSub connections * Token based auth integration with core extension Provide a way for lettuce clients to use token-based authentication. TOKENs come with a TTL. After a Redis client authenticates with a TOKEN, if they didn't renew their authentication we need to evict (close) them. The suggested approach is to leverage the existing CredentialsProvider and add support for streaming credentials to handle token refresh scenarios. Each time a new token is received connection is reauthenticated. * rebase to address "oid" core-autx lib change formating * Add EntraId integration tests Verify authentication using Azure AD with service principals * StreamingCredentialsProvider replaced with RedisCredentialsProvider.supportsStreaming() * pub/sub test basic functionality with entraid auth * Update src/main/java/io/lettuce/authx/TokenBasedRedisCredentialsProvider.java Co-authored-by: Tihomir Krasimirov Mateev * Addressing review comments from @tishun * Bump redis-authx-core & redis-authx-entraid from 0.1.0-SNAPSHOT to 0.1.1-beta1 * add java doc for TokenBasedRedisCredentialsProvider --------- Co-authored-by: Tihomir Mateev Co-authored-by: Tihomir Krasimirov Mateev --- pom.xml | 18 +- .../TokenBasedRedisCredentialsProvider.java | 137 ++++++ .../java/io/lettuce/core/ClientOptions.java | 83 +++- .../core/RedisAuthenticationHandler.java | 406 ++++++++++++++++++ .../java/io/lettuce/core/RedisClient.java | 12 +- .../core/RedisCredentialsProvider.java | 29 ++ .../core/StatefulRedisConnectionImpl.java | 52 ++- .../core/cluster/ClusterClientOptions.java | 19 +- .../core/cluster/RedisClusterClient.java | 12 + .../event/connection/AuthenticationEvent.java | 24 ++ .../core/event/connection/JfrReauthEvent.java | 43 ++ .../connection/JfrReauthFailedEvent.java | 43 ++ .../connection/ReauthenticationEvent.java | 34 ++ .../ReauthenticationFailedEvent.java | 56 +++ .../authx/EntraIdIntegrationTests.java | 195 +++++++++ .../io/lettuce/authx/EntraIdTestContext.java | 111 +++++ ...okenBasedRedisCredentialsProviderTest.java | 158 +++++++ .../core/AuthenticationIntegrationTests.java | 122 ++++++ .../MyStreamingRedisCredentialsProvider.java | 49 +++ .../RedisAuthenticationHandlerUnitTests.java | 214 +++++++++ .../RedisClientConnectIntegrationTests.java | 18 + .../io/lettuce/core/TestTokenManager.java | 50 +++ .../ClusterClientOptionsIntegrationTests.java | 23 + ...gCredentialsProviderlIntegrationTests.java | 172 ++++++++ ...ectionEventsTriggeredIntegrationTests.java | 38 ++ .../examples/TokenBasedAuthExample.java | 138 ++++++ src/test/resources/.env.entraid | 11 + 27 files changed, 2245 insertions(+), 22 deletions(-) create mode 100644 src/main/java/io/lettuce/authx/TokenBasedRedisCredentialsProvider.java create mode 100644 src/main/java/io/lettuce/core/RedisAuthenticationHandler.java create mode 100644 src/main/java/io/lettuce/core/event/connection/AuthenticationEvent.java create mode 100644 src/main/java/io/lettuce/core/event/connection/JfrReauthEvent.java create mode 100644 src/main/java/io/lettuce/core/event/connection/JfrReauthFailedEvent.java create mode 100644 src/main/java/io/lettuce/core/event/connection/ReauthenticationEvent.java create mode 100644 src/main/java/io/lettuce/core/event/connection/ReauthenticationFailedEvent.java create mode 100644 src/test/java/io/lettuce/authx/EntraIdIntegrationTests.java create mode 100644 src/test/java/io/lettuce/authx/EntraIdTestContext.java create mode 100644 src/test/java/io/lettuce/authx/TokenBasedRedisCredentialsProviderTest.java create mode 100644 src/test/java/io/lettuce/core/MyStreamingRedisCredentialsProvider.java create mode 100644 src/test/java/io/lettuce/core/RedisAuthenticationHandlerUnitTests.java create mode 100644 src/test/java/io/lettuce/core/TestTokenManager.java create mode 100644 src/test/java/io/lettuce/core/cluster/RedisClusterStreamingCredentialsProviderlIntegrationTests.java create mode 100644 src/test/java/io/lettuce/examples/TokenBasedAuthExample.java create mode 100644 src/test/resources/.env.entraid diff --git a/pom.xml b/pom.xml index e92ed704ea..a7334e25ad 100644 --- a/pom.xml +++ b/pom.xml @@ -178,7 +178,23 @@ - + + redis.clients.authentication + redis-authx-core + 0.1.1-beta1 + + + redis.clients.authentication + redis-authx-entraid + 0.1.1-beta1 + test + + + io.github.cdimascio + dotenv-java + 2.2.0 + test + diff --git a/src/main/java/io/lettuce/authx/TokenBasedRedisCredentialsProvider.java b/src/main/java/io/lettuce/authx/TokenBasedRedisCredentialsProvider.java new file mode 100644 index 0000000000..8eb12e7bc2 --- /dev/null +++ b/src/main/java/io/lettuce/authx/TokenBasedRedisCredentialsProvider.java @@ -0,0 +1,137 @@ +/* + * Copyright 2024, Redis Ltd. and Contributors + * All rights reserved. + * + * Licensed under the MIT License. + */ +package io.lettuce.authx; + +import io.lettuce.core.RedisCredentials; +import io.lettuce.core.RedisCredentialsProvider; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import redis.clients.authentication.core.Token; +import redis.clients.authentication.core.TokenAuthConfig; +import redis.clients.authentication.core.TokenListener; +import redis.clients.authentication.core.TokenManager; + +/** + * A {@link RedisCredentialsProvider} implementation that supports token-based authentication for Redis. + *

+ * This provider uses a {@link TokenManager} to manage and renew tokens, ensuring that the Redis client can authenticate with + * Redis using a dynamically updated token. This is particularly useful in scenarios where Redis access is controlled via + * token-based authentication, such as when Redis is integrated with an identity provider like EntraID. + *

+ *

+ * The provider supports streaming of credentials and automatically emits new credentials whenever a token is renewed. It must + * be used with {@link io.lettuce.core.ClientOptions.ReauthenticateBehavior#ON_NEW_CREDENTIALS} to automatically re-authenticate + * connections whenever new tokens are emitted by the provider. + *

+ *

+ * The lifecycle of this provider is externally managed. It should be closed when there are no longer any connections using it, + * to stop the token management process and release resources. + *

+ * + * @since 6.6 + */ +public class TokenBasedRedisCredentialsProvider implements RedisCredentialsProvider, AutoCloseable { + + private static final Logger log = LoggerFactory.getLogger(TokenBasedRedisCredentialsProvider.class); + + private final TokenManager tokenManager; + + private final Sinks.Many credentialsSink = Sinks.many().replay().latest(); + + private TokenBasedRedisCredentialsProvider(TokenManager tokenManager) { + this.tokenManager = tokenManager; + } + + private void init() { + + TokenListener listener = new TokenListener() { + + @Override + public void onTokenRenewed(Token token) { + String username = token.getUser(); + char[] pass = token.getValue().toCharArray(); + RedisCredentials credentials = RedisCredentials.just(username, pass); + credentialsSink.tryEmitNext(credentials); + } + + @Override + public void onError(Exception exception) { + log.error("Token renew failed!", exception); + } + + }; + + try { + tokenManager.start(listener, false); + } catch (Exception e) { + credentialsSink.tryEmitError(e); + tokenManager.stop(); + throw new RuntimeException("Failed to start TokenManager", e); + } + } + + /** + * Resolve the latest available credentials as a Mono. + *

+ * This method returns a Mono that emits the most recent set of Redis credentials. The Mono will complete once the + * credentials are emitted. If no credentials are available at the time of subscription, the Mono will wait until + * credentials are available. + * + * @return a Mono that emits the latest Redis credentials + */ + @Override + public Mono resolveCredentials() { + + return credentialsSink.asFlux().next(); + } + + /** + * Expose the Flux for all credential updates. + *

+ * This method returns a Flux that emits all updates to the Redis credentials. Subscribers will receive the latest + * credentials whenever they are updated. The Flux will continue to emit updates until the provider is shut down. + * + * @return a Flux that emits all updates to the Redis credentials + */ + @Override + public Flux credentials() { + + return credentialsSink.asFlux().onBackpressureLatest(); // Provide a continuous stream of credentials + } + + @Override + public boolean supportsStreaming() { + return true; + } + + /** + * Stop the credentials provider and clean up resources. + *

+ * This method stops the TokenManager and completes the credentials sink, ensuring that all resources are properly released. + * It should be called when the credentials provider is no longer needed. + */ + @Override + public void close() { + credentialsSink.tryEmitComplete(); + tokenManager.stop(); + } + + public static TokenBasedRedisCredentialsProvider create(TokenAuthConfig tokenAuthConfig) { + return create(new TokenManager(tokenAuthConfig.getIdentityProviderConfig().getProvider(), + tokenAuthConfig.getTokenManagerConfig())); + } + + public static TokenBasedRedisCredentialsProvider create(TokenManager tokenManager) { + TokenBasedRedisCredentialsProvider credentialManager = new TokenBasedRedisCredentialsProvider(tokenManager); + credentialManager.init(); + return credentialManager; + } + +} diff --git a/src/main/java/io/lettuce/core/ClientOptions.java b/src/main/java/io/lettuce/core/ClientOptions.java index 52a5a950da..3fd635e4ee 100644 --- a/src/main/java/io/lettuce/core/ClientOptions.java +++ b/src/main/java/io/lettuce/core/ClientOptions.java @@ -55,6 +55,8 @@ public class ClientOptions implements Serializable { public static final DisconnectedBehavior DEFAULT_DISCONNECTED_BEHAVIOR = DisconnectedBehavior.DEFAULT; + public static final ReauthenticateBehavior DEFAULT_REAUTHENTICATE_BEHAVIOUR = ReauthenticateBehavior.DEFAULT; + public static final boolean DEFAULT_PUBLISH_ON_SCHEDULER = false; public static final boolean DEFAULT_PING_BEFORE_ACTIVATE_CONNECTION = true; @@ -95,6 +97,8 @@ public class ClientOptions implements Serializable { private final DisconnectedBehavior disconnectedBehavior; + private final ReauthenticateBehavior reauthenticateBehavior; + private final boolean publishOnScheduler; private final boolean pingBeforeActivateConnection; @@ -124,6 +128,7 @@ protected ClientOptions(Builder builder) { this.cancelCommandsOnReconnectFailure = builder.cancelCommandsOnReconnectFailure; this.decodeBufferPolicy = builder.decodeBufferPolicy; this.disconnectedBehavior = builder.disconnectedBehavior; + this.reauthenticateBehavior = builder.reauthenticateBehavior; this.publishOnScheduler = builder.publishOnScheduler; this.pingBeforeActivateConnection = builder.pingBeforeActivateConnection; this.protocolVersion = builder.protocolVersion; @@ -143,6 +148,7 @@ protected ClientOptions(ClientOptions original) { this.cancelCommandsOnReconnectFailure = original.isCancelCommandsOnReconnectFailure(); this.decodeBufferPolicy = original.getDecodeBufferPolicy(); this.disconnectedBehavior = original.getDisconnectedBehavior(); + this.reauthenticateBehavior = original.getReauthenticateBehaviour(); this.publishOnScheduler = original.isPublishOnScheduler(); this.pingBeforeActivateConnection = original.isPingBeforeActivateConnection(); this.protocolVersion = original.getConfiguredProtocolVersion(); @@ -220,6 +226,8 @@ public static class Builder { private TimeoutOptions timeoutOptions = DEFAULT_TIMEOUT_OPTIONS; + private ReauthenticateBehavior reauthenticateBehavior = DEFAULT_REAUTHENTICATE_BEHAVIOUR; + private boolean useHashIndexedQueue = DEFAULT_USE_HASH_INDEX_QUEUE; protected Builder() { @@ -301,6 +309,20 @@ public Builder disconnectedBehavior(DisconnectedBehavior disconnectedBehavior) { return this; } + /** + * Configure the {@link ReauthenticateBehavior} of the Lettuce driver. Defaults to + * {@link ReauthenticateBehavior#DEFAULT}. + * + * @param reauthenticateBehavior the {@link ReauthenticateBehavior} to use. Must not be {@code null}. + * @return {@code this} + */ + public Builder reauthenticateBehavior(ReauthenticateBehavior reauthenticateBehavior) { + + LettuceAssert.notNull(reauthenticateBehavior, "ReuthenticatBehavior must not be null"); + this.reauthenticateBehavior = reauthenticateBehavior; + return this; + } + /** * Perform a lightweight {@literal PING} connection handshake when establishing a Redis connection. If {@code true} * (default is {@code true}, {@link #DEFAULT_PING_BEFORE_ACTIVATE_CONNECTION}), every connection and reconnect will @@ -505,11 +527,12 @@ public ClientOptions.Builder mutate() { builder.autoReconnect(isAutoReconnect()).cancelCommandsOnReconnectFailure(isCancelCommandsOnReconnectFailure()) .decodeBufferPolicy(getDecodeBufferPolicy()).disconnectedBehavior(getDisconnectedBehavior()) - .readOnlyCommands(getReadOnlyCommands()).publishOnScheduler(isPublishOnScheduler()) - .pingBeforeActivateConnection(isPingBeforeActivateConnection()).protocolVersion(getConfiguredProtocolVersion()) - .requestQueueSize(getRequestQueueSize()).scriptCharset(getScriptCharset()).jsonParser(getJsonParser()) - .socketOptions(getSocketOptions()).sslOptions(getSslOptions()) - .suspendReconnectOnProtocolFailure(isSuspendReconnectOnProtocolFailure()).timeoutOptions(getTimeoutOptions()); + .reauthenticateBehavior(getReauthenticateBehaviour()).readOnlyCommands(getReadOnlyCommands()) + .publishOnScheduler(isPublishOnScheduler()).pingBeforeActivateConnection(isPingBeforeActivateConnection()) + .protocolVersion(getConfiguredProtocolVersion()).requestQueueSize(getRequestQueueSize()) + .scriptCharset(getScriptCharset()).jsonParser(getJsonParser()).socketOptions(getSocketOptions()) + .sslOptions(getSslOptions()).suspendReconnectOnProtocolFailure(isSuspendReconnectOnProtocolFailure()) + .timeoutOptions(getTimeoutOptions()); return builder; } @@ -573,6 +596,16 @@ public DisconnectedBehavior getDisconnectedBehavior() { return disconnectedBehavior; } + /** + * Behavior for re-authentication when the {@link RedisCredentialsProvider} emits new credentials. Defaults to + * {@link ReauthenticateBehavior#DEFAULT}. + * + * @return the currently set {@link ReauthenticateBehavior}. + */ + public ReauthenticateBehavior getReauthenticateBehaviour() { + return reauthenticateBehavior; + } + /** * Predicate to identify commands as read-only. Defaults to {@link #DEFAULT_READ_ONLY_COMMANDS}. * @@ -704,6 +737,46 @@ public TimeoutOptions getTimeoutOptions() { return timeoutOptions; } + /** + * Defines the re-authentication behavior of the Redis client. + *

+ * Certain implementations of the {@link RedisCredentialsProvider} could emit new credentials at runtime. This setting + * controls how the driver reacts to these newly emitted credentials. + */ + public enum ReauthenticateBehavior { + + /** + * This is the default behavior. The client will fetch current credentials from the underlying + * {@link RedisCredentialsProvider} only when the driver needs to, e.g. when the connection is first established or when + * it is re-established after a disconnect. + *

+ *

+ * No re-authentication is performed when new credentials are emitted by a {@link RedisCredentialsProvider} that + * supports streaming. The client does not subscribe to or react to any updates in the credential stream provided by + * {@link RedisCredentialsProvider#credentials()}. + *

+ */ + DEFAULT, + + /** + * Automatically triggers re-authentication whenever new credentials are emitted by a {@link RedisCredentialsProvider} + * that supports streaming, as indicated by {@link RedisCredentialsProvider#supportsStreaming()}. + * + *

+ * When this behavior is enabled, the client subscribes to the credential stream provided by + * {@link RedisCredentialsProvider#credentials()} and issues an {@code AUTH} command to the Redis server each time new + * credentials are received. This behavior supports dynamic credential scenarios, such as token-based authentication, or + * credential rotation where credentials are refreshed periodically to maintain access. + *

+ * + *

+ * Note: {@code AUTH} commands issued as part of this behavior may interleave with user-submitted commands, as the + * client performs re-authentication independently of user command flow. + *

+ */ + ON_NEW_CREDENTIALS + } + /** * Whether we should use hash indexed queue, which provides O(1) remove(Object) * diff --git a/src/main/java/io/lettuce/core/RedisAuthenticationHandler.java b/src/main/java/io/lettuce/core/RedisAuthenticationHandler.java new file mode 100644 index 0000000000..81b0bc9dd5 --- /dev/null +++ b/src/main/java/io/lettuce/core/RedisAuthenticationHandler.java @@ -0,0 +1,406 @@ +/* + * Copyright 2024, Redis Ltd. and Contributors + * All rights reserved. + * + * Licensed under the MIT License. + */ +package io.lettuce.core; + +import io.lettuce.core.api.async.RedisAsyncCommands; +import io.lettuce.core.codec.RedisCodec; +import io.lettuce.core.event.connection.ReauthenticationEvent; +import io.lettuce.core.event.connection.ReauthenticationFailedEvent; +import io.lettuce.core.internal.LettuceAssert; +import io.lettuce.core.output.StatusOutput; +import io.lettuce.core.protocol.AsyncCommand; +import io.lettuce.core.protocol.Command; +import io.lettuce.core.protocol.CommandArgs; +import io.lettuce.core.protocol.CommandExpiryWriter; +import io.lettuce.core.protocol.CompleteableCommand; +import io.lettuce.core.protocol.Endpoint; +import io.lettuce.core.protocol.ProtocolVersion; +import io.lettuce.core.protocol.RedisCommand; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; + +import java.util.Collection; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; + +import static io.lettuce.core.protocol.CommandType.AUTH; +import static io.lettuce.core.protocol.CommandType.DISCARD; +import static io.lettuce.core.protocol.CommandType.EXEC; +import static io.lettuce.core.protocol.CommandType.MULTI; + +/** + * Redis authentication handler. Internally used to authenticate a Redis connection. This class is part of the internal API. + * + * @author Ivo Gaydazhiev + * @since 6.6.0 + */ +public class RedisAuthenticationHandler { + + private static final InternalLogger log = InternalLoggerFactory.getInstance(RedisAuthenticationHandler.class); + + private final StatefulRedisConnectionImpl connection; + + private final RedisCredentialsProvider credentialsProvider; + + private final AtomicReference credentialsSubscription = new AtomicReference<>(); + + private final Boolean isPubSubConnection; + + private final AtomicReference credentialsRef = new AtomicReference<>(); + + private final ReentrantLock reAuthSafety = new ReentrantLock(); + + private final AtomicBoolean inTransaction = new AtomicBoolean(false); + + /** + * Creates a new {@link RedisAuthenticationHandler}. + * + * @param connection the connection to authenticate + * @param credentialsProvider the implementation of {@link RedisCredentialsProvider} to use + * @param isPubSubConnection {@code true} if the connection is a pub/sub connection + */ + public RedisAuthenticationHandler(StatefulRedisConnectionImpl connection, + RedisCredentialsProvider credentialsProvider, Boolean isPubSubConnection) { + this.connection = connection; + this.credentialsProvider = credentialsProvider; + this.isPubSubConnection = isPubSubConnection; + } + + /** + * Creates a new {@link RedisAuthenticationHandler} if the connection supports re-authentication. + * + * @param connection the connection to authenticate + * @param credentialsProvider the implementation of {@link RedisCredentialsProvider} to use + * @param isPubSubConnection {@code true} if the connection is a pub/sub connection + * @param options the {@link ClientOptions} to use + * @return a new {@link RedisAuthenticationHandler} if the connection supports re-authentication, otherwise an + * implementation of the {@link RedisAuthenticationHandler} that does nothing + * @since 6.6.0 + * @see RedisCredentialsProvider + */ + public static RedisAuthenticationHandler createHandler(StatefulRedisConnectionImpl connection, + RedisCredentialsProvider credentialsProvider, Boolean isPubSubConnection, ClientOptions options) { + + if (isSupported(options)) { + + if (isPubSubConnection && options.getConfiguredProtocolVersion() == ProtocolVersion.RESP2) { + throw new RedisConnectionException( + "Renewable credentials are not supported with RESP2 protocol on a pub/sub connection."); + } + + return new RedisAuthenticationHandler<>(connection, credentialsProvider, isPubSubConnection); + } + + return null; + } + + /** + * Creates a new default {@link RedisAuthenticationHandler}. + *

+ * The default {@link RedisAuthenticationHandler} is used when re-authentication is not supported. + * + * @return a new {@link RedisAuthenticationHandler} + * @since 6.6.0 + * @see RedisCredentialsProvider + */ + public static RedisAuthenticationHandler createDefaultAuthenticationHandler() { + return new DisabledAuthenticationHandler<>(); + } + + /** + * This method subscribes to a stream of credentials provided by the `StreamingCredentialsProvider`. + *

+ * Each time new credentials are received, the client is re-authenticated. The previous subscription, if any, is disposed of + * before setting the new subscription. + */ + public void subscribe() { + if (credentialsProvider == null || !credentialsProvider.supportsStreaming()) { + return; + } + + if (!isSupportedConnection()) { + return; + } + + Flux credentialsFlux = credentialsProvider.credentials(); + + Disposable subscription = credentialsFlux.subscribe(this::onNext, this::onError, this::complete); + + Disposable oldSubscription = credentialsSubscription.getAndSet(subscription); + if (oldSubscription != null && !oldSubscription.isDisposed()) { + oldSubscription.dispose(); + } + } + + /** + * Unsubscribes from the current credentials stream. + */ + public void unsubscribe() { + Disposable subscription = credentialsSubscription.getAndSet(null); + if (subscription != null && !subscription.isDisposed()) { + subscription.dispose(); + } + } + + protected void complete() { + log.debug("Credentials stream completed"); + } + + protected void onNext(RedisCredentials credentials) { + reauthenticate(credentials); + } + + protected void onError(Throwable e) { + log.error("Credentials renew failed.", e); + publishReauthFailedEvent(e); + } + + /** + * Performs re-authentication with the provided credentials. + * + * @param credentials the new credentials + */ + protected void reauthenticate(RedisCredentials credentials) { + setCredentials(credentials); + } + + boolean isSupportedConnection() { + if (isPubSubConnection && ProtocolVersion.RESP2 == connection.getConnectionState().getNegotiatedProtocolVersion()) { + log.warn("Renewable credentials are not supported with RESP2 protocol on a pub/sub connection."); + return false; + } + return true; + } + + private static boolean isSupported(ClientOptions clientOptions) { + LettuceAssert.notNull(clientOptions, "ClientOptions must not be null"); + switch (clientOptions.getReauthenticateBehaviour()) { + case ON_NEW_CREDENTIALS: + return true; + case DEFAULT: + default: + return false; + } + } + + /** + * Post-processes the command after it is sent to the server. + *

+ * If the command type is either {@link RedisCommand.Type#EXEC} or {@link RedisCommand.Type#DISCARD}, the transaction state + * is cleared and a check for deferred credentials is initiated. + *

+ * + * @param toSend the command to post-process + */ + protected void postProcess(RedisCommand toSend) { + if (toSend.getType() == EXEC || toSend.getType() == DISCARD) { + inTransaction.set(false); + setCredentials(credentialsRef.getAndSet(null)); + } + } + + /** + * Post-processes a collection of dispatched commands after they are sent to the server. + *

+ * This method checks if any of the dispatched commands indicate the completion of a transaction (via + * {@link RedisCommand.Type#EXEC} or {@link RedisCommand.Type#DISCARD}). If the transaction is complete, it clears the + * transaction state and initiates a check for deferred credentials. + *

+ * + * @param dispatched the collection of dispatched commands to post-process + */ + protected void postProcess(Collection> dispatched) { + Boolean transactionComplete = null; + for (RedisCommand command : dispatched) { + if (command.getType() == EXEC || command.getType() == DISCARD) { + transactionComplete = true; + } + if (command.getType() == MULTI) { + transactionComplete = false; + } + } + + if (transactionComplete != null) { + if (transactionComplete) { + inTransaction.set(false); + setCredentials(credentialsRef.getAndSet(null)); + } + } + } + + /** + * Marks that the current connection has started a transaction. + *

+ * During transactions, any re-authentication attempts are deferred until the transaction ends. + */ + public void startTransaction() { + reAuthSafety.lock(); + try { + inTransaction.set(true); + } finally { + reAuthSafety.unlock(); + } + } + + /** + * Marks that the current connection has ended the transaction. + *

+ * After a transaction is completed, any deferred re-authentication attempts are dispatched. + */ + public void endTransaction() { + inTransaction.set(false); + setCredentials(credentialsRef.getAndSet(null)); + } + + /** + * Authenticates the current connection using the provided credentials. + *

+ * Unlike using dispatch of {@link RedisAsyncCommands#auth}, this method defers the {@code AUTH} command if the connection + * is within an active transaction. The authentication command will only be dispatched after the enclosing {@code DISCARD} + * or {@code EXEC} command is executed, ensuring that authentication does not interfere with ongoing transactions. + *

+ * + * @param credentials the {@link RedisCredentials} to authenticate the connection. If {@code null}, no action is performed. + * + *

+ * Behavior: + *

    + *
  • If the provided credentials are {@code null}, the method exits immediately.
  • + *
  • If a transaction is active (as indicated by {@code inTransaction}), the {@code AUTH} command is not dispatched + * immediately but deferred until the transaction ends.
  • + *
  • If no transaction is active, the {@code AUTH} command is dispatched immediately using the provided + * credentials.
  • + *
+ *

+ * + * @see RedisAsyncCommands#auth + */ + public void setCredentials(RedisCredentials credentials) { + if (credentials == null) { + return; + } + reAuthSafety.lock(); + try { + credentialsRef.set(credentials); + if (!inTransaction.get()) { + dispatchAuth(credentialsRef.getAndSet(null)); + } + } finally { + reAuthSafety.unlock(); + } + } + + protected void dispatchAuth(RedisCredentials credentials) { + if (credentials == null) { + return; + } + + // dispatch directly to avoid AUTH preprocessing overrides credentials provider + RedisCommand auth = connection.getChannelWriter().write(authCommand(credentials)); + if (auth instanceof CompleteableCommand) { + ((CompleteableCommand) auth).onComplete((status, throwable) -> { + if (throwable != null) { + log.error("Re-authentication failed {}.", getEpid(), throwable); + publishReauthFailedEvent(throwable); + } else { + log.info("Re-authentication succeeded {}.", getEpid()); + publishReauthEvent(); + } + }); + } + } + + private AsyncCommand authCommand(RedisCredentials credentials) { + RedisCodec codec = connection.getCodec(); + CommandArgs args = new CommandArgs<>(codec); + if (credentials.getUsername() != null) { + args.add(credentials.getUsername()).add(credentials.getPassword()); + } else { + args.add(credentials.getPassword()); + } + return new AsyncCommand<>(new Command<>(AUTH, new StatusOutput<>(codec), args)); + } + + private void publishReauthEvent() { + connection.getResources().eventBus().publish(new ReauthenticationEvent(getEpid())); + } + + private void publishReauthFailedEvent(Throwable throwable) { + connection.getResources().eventBus().publish(new ReauthenticationFailedEvent(getEpid(), throwable)); + } + + private String getEpid() { + RedisChannelWriter writer = connection.getChannelWriter(); + while (!(writer instanceof Endpoint)) { + + if (writer instanceof CommandListenerWriter) { + writer = ((CommandListenerWriter) writer).getDelegate(); + continue; + } + + if (writer instanceof CommandExpiryWriter) { + writer = ((CommandExpiryWriter) writer).getDelegate(); + continue; + } + return null; + } + + return ((Endpoint) writer).getId(); + } + + private static final class DisabledAuthenticationHandler extends RedisAuthenticationHandler { + + public DisabledAuthenticationHandler(StatefulRedisConnectionImpl connection, + RedisCredentialsProvider credentialsProvider, Boolean isPubSubConnection) { + super(null, null, null); + } + + public DisabledAuthenticationHandler() { + super(null, null, null); + } + + @Override + protected void postProcess(RedisCommand toSend) { + // No-op + } + + @Override + protected void postProcess(Collection> dispatched) { + // No-op + } + + @Override + public void startTransaction() { + // No-op + } + + @Override + public void endTransaction() { + // No-op + } + + @Override + public void setCredentials(RedisCredentials credentials) { + // No-op + } + + @Override + public void unsubscribe() { + // No-op + } + + @Override + public void subscribe() { + // No-op + } + + } + +} diff --git a/src/main/java/io/lettuce/core/RedisClient.java b/src/main/java/io/lettuce/core/RedisClient.java index 4a2c3e7bd3..78b28d3a10 100644 --- a/src/main/java/io/lettuce/core/RedisClient.java +++ b/src/main/java/io/lettuce/core/RedisClient.java @@ -19,6 +19,7 @@ */ package io.lettuce.core; +import static io.lettuce.core.RedisAuthenticationHandler.createHandler; import static io.lettuce.core.internal.LettuceStrings.*; import java.net.InetSocketAddress; @@ -38,7 +39,6 @@ import io.lettuce.core.internal.ExceptionFactory; import io.lettuce.core.internal.Futures; import io.lettuce.core.internal.LettuceAssert; -import io.lettuce.core.json.JsonParser; import io.lettuce.core.masterreplica.MasterReplica; import io.lettuce.core.protocol.CommandExpiryWriter; import io.lettuce.core.protocol.CommandHandler; @@ -288,8 +288,9 @@ private ConnectionFuture> connectStandalone } StatefulRedisConnectionImpl connection = newStatefulRedisConnection(writer, endpoint, codec, timeout); + ConnectionFuture> future = connectStatefulAsync(connection, endpoint, redisURI, - () -> new CommandHandler(getOptions(), getResources(), endpoint)); + () -> new CommandHandler(getOptions(), getResources(), endpoint), false); future.whenComplete((channelHandler, throwable) -> { @@ -303,7 +304,7 @@ private ConnectionFuture> connectStandalone @SuppressWarnings("unchecked") private ConnectionFuture connectStatefulAsync(StatefulRedisConnectionImpl connection, Endpoint endpoint, - RedisURI redisURI, Supplier commandHandlerSupplier) { + RedisURI redisURI, Supplier commandHandlerSupplier, Boolean isPubSub) { ConnectionBuilder connectionBuilder; if (redisURI.isSsl()) { @@ -317,7 +318,8 @@ private ConnectionFuture connectStatefulAsync(StatefulRedisConnecti ConnectionState state = connection.getConnectionState(); state.apply(redisURI); state.setDb(redisURI.getDatabase()); - + connection + .setAuthenticationHandler(createHandler(connection, redisURI.getCredentialsProvider(), isPubSub, getOptions())); connectionBuilder.connection(connection); connectionBuilder.clientOptions(getOptions()); connectionBuilder.clientResources(getResources()); @@ -421,7 +423,7 @@ private ConnectionFuture> connectPubS StatefulRedisPubSubConnectionImpl connection = newStatefulRedisPubSubConnection(endpoint, writer, codec, timeout); ConnectionFuture> future = connectStatefulAsync(connection, endpoint, redisURI, - () -> new PubSubCommandHandler<>(getOptions(), getResources(), codec, endpoint)); + () -> new PubSubCommandHandler<>(getOptions(), getResources(), codec, endpoint), true); return future.whenComplete((conn, throwable) -> { diff --git a/src/main/java/io/lettuce/core/RedisCredentialsProvider.java b/src/main/java/io/lettuce/core/RedisCredentialsProvider.java index afaef0ae7c..9c57a280af 100644 --- a/src/main/java/io/lettuce/core/RedisCredentialsProvider.java +++ b/src/main/java/io/lettuce/core/RedisCredentialsProvider.java @@ -2,6 +2,7 @@ import java.util.function.Supplier; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import io.lettuce.core.internal.LettuceAssert; @@ -41,6 +42,34 @@ static RedisCredentialsProvider from(Supplier supplier) { return () -> Mono.fromSupplier(supplier); } + /** + * Some implementations of the {@link RedisCredentialsProvider} may support streaming new credentials, based on some event + * that originates outside the driver. In this case they should indicate that so the {@link RedisAuthenticationHandler} is + * able to process these new credentials. + * + * @return whether the {@link RedisCredentialsProvider} supports streaming credentials. + */ + default boolean supportsStreaming() { + return false; + } + + /** + * Returns a {@link Flux} emitting {@link RedisCredentials} that can be used to authorize a Redis connection. + * + * For implementations that support streaming credentials (as indicated by {@link #supportsStreaming()} returning + * {@code true}), this method can emit multiple credentials over time, typically based on external events like token renewal + * or rotation. + * + * For implementations that do not support streaming credentials (where {@link #supportsStreaming()} returns {@code false}), + * this method throws an {@link UnsupportedOperationException} by default. + * + * @return a {@link Flux} emitting {@link RedisCredentials}, or throws an exception if streaming is not supported. + * @throws UnsupportedOperationException if the provider does not support streaming credentials. + */ + default Flux credentials() { + throw new UnsupportedOperationException("Streaming credentials are not supported by this provider."); + } + /** * Extension to {@link RedisCredentialsProvider} that resolves credentials immediately without the need to defer the * credential resolution. diff --git a/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java b/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java index 14ba7b5701..b51ee8ffae 100644 --- a/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java +++ b/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java @@ -71,6 +71,8 @@ public class StatefulRedisConnectionImpl extends RedisChannelHandler protected MultiOutput multi; + private RedisAuthenticationHandler authHandler = RedisAuthenticationHandler.createDefaultAuthenticationHandler(); + /** * Initialize a new connection. * @@ -181,20 +183,41 @@ public boolean isMulti() { public RedisCommand dispatch(RedisCommand command) { RedisCommand toSend = preProcessCommand(command); - return super.dispatch(toSend); + RedisCommand result = super.dispatch(toSend); + RedisCommand finalCommand = postProcessCommand(result); + + return finalCommand; } @Override public Collection> dispatch(Collection> commands) { + Collection> sentCommands = preProcessCommands(commands); + + Collection> dispatchedCommands = super.dispatch(sentCommands); + + return this.postProcessCommands(dispatchedCommands); + } + + protected Collection> postProcessCommands(Collection> commands) { + authHandler.postProcess(commands); + return commands; + } + + protected RedisCommand postProcessCommand(RedisCommand command) { + authHandler.postProcess(command); + return command; + } + + protected Collection> preProcessCommands(Collection> commands) { List> sentCommands = new ArrayList<>(commands.size()); commands.forEach(o -> { - RedisCommand command = preProcessCommand(o); - sentCommands.add(command); + RedisCommand preprocessed = preProcessCommand(o); + sentCommands.add(preprocessed); }); - return super.dispatch(sentCommands); + return sentCommands; } // TODO [tihomir.mateev] Refactor to include as part of the Command interface @@ -272,13 +295,14 @@ protected RedisCommand preProcessCommand(RedisCommand comm } if (commandType.equals(MULTI.name())) { - + authHandler.startTransaction(); multi = (multi == null ? new MultiOutput<>(codec) : multi); if (command instanceof CompleteableCommand) { ((CompleteableCommand) command).onComplete((ignored, e) -> { if (e != null) { multi = null; + authHandler.endTransaction(); } }); } @@ -315,4 +339,22 @@ public ConnectionState getConnectionState() { return state; } + @Override + public void activated() { + super.activated(); + authHandler.subscribe(); + } + + @Override + public void deactivated() { + authHandler.unsubscribe(); + super.deactivated(); + } + + public void setAuthenticationHandler(RedisAuthenticationHandler handler) { + if (handler != null) { + authHandler = handler; + } + } + } diff --git a/src/main/java/io/lettuce/core/cluster/ClusterClientOptions.java b/src/main/java/io/lettuce/core/cluster/ClusterClientOptions.java index 11b90828fb..832c310f54 100644 --- a/src/main/java/io/lettuce/core/cluster/ClusterClientOptions.java +++ b/src/main/java/io/lettuce/core/cluster/ClusterClientOptions.java @@ -134,6 +134,7 @@ public static ClusterClientOptions.Builder builder(ClientOptions clientOptions) .cancelCommandsOnReconnectFailure(clientOptions.isCancelCommandsOnReconnectFailure()) .decodeBufferPolicy(clientOptions.getDecodeBufferPolicy()) .disconnectedBehavior(clientOptions.getDisconnectedBehavior()) + .reauthenticateBehavior(clientOptions.getReauthenticateBehaviour()) .pingBeforeActivateConnection(clientOptions.isPingBeforeActivateConnection()) .publishOnScheduler(clientOptions.isPublishOnScheduler()) .protocolVersion(clientOptions.getConfiguredProtocolVersion()) @@ -218,6 +219,12 @@ public Builder disconnectedBehavior(DisconnectedBehavior disconnectedBehavior) { return this; } + @Override + public Builder reauthenticateBehavior(ReauthenticateBehavior reauthenticateBehavior) { + super.reauthenticateBehavior(reauthenticateBehavior); + return this; + } + /** * Number of maximal cluster redirects ({@literal -MOVED} and {@literal -ASK}) to follow in case a key was moved from * one node to another node. Defaults to {@literal 5}. See {@link ClusterClientOptions#DEFAULT_MAX_REDIRECTS}. @@ -355,12 +362,12 @@ public ClusterClientOptions.Builder mutate() { builder.autoReconnect(isAutoReconnect()).cancelCommandsOnReconnectFailure(isCancelCommandsOnReconnectFailure()) .decodeBufferPolicy(getDecodeBufferPolicy()).disconnectedBehavior(getDisconnectedBehavior()) - .maxRedirects(getMaxRedirects()).publishOnScheduler(isPublishOnScheduler()) - .pingBeforeActivateConnection(isPingBeforeActivateConnection()).protocolVersion(getConfiguredProtocolVersion()) - .readOnlyCommands(getReadOnlyCommands()).requestQueueSize(getRequestQueueSize()) - .scriptCharset(getScriptCharset()).socketOptions(getSocketOptions()).sslOptions(getSslOptions()) - .suspendReconnectOnProtocolFailure(isSuspendReconnectOnProtocolFailure()).timeoutOptions(getTimeoutOptions()) - .topologyRefreshOptions(getTopologyRefreshOptions()) + .reauthenticateBehavior(getReauthenticateBehaviour()).maxRedirects(getMaxRedirects()) + .publishOnScheduler(isPublishOnScheduler()).pingBeforeActivateConnection(isPingBeforeActivateConnection()) + .protocolVersion(getConfiguredProtocolVersion()).readOnlyCommands(getReadOnlyCommands()) + .requestQueueSize(getRequestQueueSize()).scriptCharset(getScriptCharset()).socketOptions(getSocketOptions()) + .sslOptions(getSslOptions()).suspendReconnectOnProtocolFailure(isSuspendReconnectOnProtocolFailure()) + .timeoutOptions(getTimeoutOptions()).topologyRefreshOptions(getTopologyRefreshOptions()) .validateClusterNodeMembership(isValidateClusterNodeMembership()).nodeFilter(getNodeFilter()); return builder; diff --git a/src/main/java/io/lettuce/core/cluster/RedisClusterClient.java b/src/main/java/io/lettuce/core/cluster/RedisClusterClient.java index 577689cecc..e9e8a2ede0 100644 --- a/src/main/java/io/lettuce/core/cluster/RedisClusterClient.java +++ b/src/main/java/io/lettuce/core/cluster/RedisClusterClient.java @@ -74,6 +74,8 @@ import io.netty.util.internal.logging.InternalLoggerFactory; import reactor.core.publisher.Mono; +import static io.lettuce.core.RedisAuthenticationHandler.createHandler; + /** * A scalable and thread-safe Redis cluster client supporting synchronous, asynchronous and * reactive execution models. Multiple threads may share one connection. The cluster client handles command routing based on the @@ -556,6 +558,9 @@ ConnectionFuture> connectToNodeAsync(RedisC StatefulRedisConnectionImpl connection = newStatefulRedisConnection(writer, endpoint, codec, getFirstUri().getTimeout(), getClusterClientOptions().getJsonParser()); + connection.setAuthenticationHandler( + createHandler(connection, getFirstUri().getCredentialsProvider(), false, getOptions())); + ConnectionFuture> connectionFuture = connectStatefulAsync(connection, endpoint, getFirstUri(), socketAddressSupplier, () -> new CommandHandler(getClusterClientOptions(), getResources(), endpoint)); @@ -620,10 +625,13 @@ ConnectionFuture> connectPubSubToNode StatefulRedisPubSubConnectionImpl connection = new StatefulRedisPubSubConnectionImpl<>(endpoint, writer, codec, getFirstUri().getTimeout()); + connection.setAuthenticationHandler( + createHandler(connection, getFirstUri().getCredentialsProvider(), true, getOptions())); ConnectionFuture> connectionFuture = connectStatefulAsync(connection, endpoint, getFirstUri(), socketAddressSupplier, () -> new PubSubCommandHandler<>(getClusterClientOptions(), getResources(), codec, endpoint)); + return connectionFuture.whenComplete((conn, throwable) -> { if (throwable != null) { connection.closeAsync(); @@ -772,6 +780,8 @@ private CompletableFuture> con clusterWriter.setClusterConnectionProvider(pooledClusterConnectionProvider); connection.setPartitions(partitions); + connection.setAuthenticationHandler( + createHandler(connection, getFirstUri().getCredentialsProvider(), true, getOptions())); Supplier commandHandlerSupplier = () -> new PubSubCommandHandler<>(getClusterClientOptions(), getResources(), codec, endpoint); @@ -843,6 +853,7 @@ private ConnectionBuilder createConnectionBuilder(RedisChannelHandler ConnectionBuilder createConnectionBuilder(RedisChannelHandler connection = client.connect()) { + assertThat(connection.sync().aclWhoami()).isEqualTo(testCtx.getSpOID()); + assertThat(connection.async().aclWhoami().get()).isEqualTo(testCtx.getSpOID()); + assertThat(connection.reactive().aclWhoami().block()).isEqualTo(testCtx.getSpOID()); + } + } + + // T.1.1 + // Verify authentication using Azure AD with service principals using Redis Cluster Client + @Test + public void clusterWithSecret_azureServicePrincipalIntegrationTest() throws ExecutionException, InterruptedException { + + try (StatefulRedisClusterConnection connection = clusterClient.connect()) { + assertThat(connection.sync().aclWhoami()).isEqualTo(testCtx.getSpOID()); + assertThat(connection.async().aclWhoami().get()).isEqualTo(testCtx.getSpOID()); + assertThat(connection.reactive().aclWhoami().block()).isEqualTo(testCtx.getSpOID()); + + connection.getPartitions().forEach((partition) -> { + try (StatefulRedisConnection nodeConnection = connection.getConnection(partition.getNodeId())) { + assertThat(nodeConnection.sync().aclWhoami()).isEqualTo(testCtx.getSpOID()); + } + }); + } + } + + // T.2.2 + // Test that the Redis client is not blocked/interrupted during token renewal. + @Test + public void renewalDuringOperationsTest() throws InterruptedException { + + // Counter to track the number of command cycles + AtomicInteger commandCycleCount = new AtomicInteger(0); + + // Start a thread to continuously send Redis commands + Thread commandThread = new Thread(() -> { + try (StatefulRedisConnection connection = client.connect()) { + RedisAsyncCommands async = connection.async(); + for (int i = 1; i <= 10; i++) { + // Start a transaction with SET and INCRBY commands + RedisFuture multi = async.multi(); + RedisFuture set = async.set("key", "1"); + RedisFuture incrby = async.incrby("key", 1); + RedisFuture exec = async.exec(); + TransactionResult results = exec.get(1, TimeUnit.SECONDS); + + // Increment the command cycle count after each execution + commandCycleCount.incrementAndGet(); + + // Verify the results from EXEC + assertThat(results).hasSize(2); // We expect 2 responses: SET and INCRBY + + // Check the response from each command in the transaction + assertThat((String) results.get(0)).isEqualTo("OK"); // SET "key" = "1" + assertThat((Long) results.get(1)).isEqualTo(2L); // INCRBY "key" by 1, expected result is 2 + } + } catch (Exception e) { + fail("Command execution failed during token refresh", e); + } + }); + + commandThread.start(); + + CountDownLatch latch = new CountDownLatch(10); // Wait for at least 10 token renewals + + credentialsProvider.credentials().subscribe(cred -> { + latch.countDown(); // Signal each renewal as it's received + }); + + assertThat(latch.await(1, TimeUnit.SECONDS)).isTrue(); // Wait to reach 10 renewals + commandThread.join(); // Wait for the command thread to finish + + // Verify that at least 10 command cycles were executed during the test + assertThat(commandCycleCount.get()).isGreaterThanOrEqualTo(10); + } + + // T.2.2 + // Test basic Pub/Sub functionality is not blocked/interrupted during token renewal. + @Test + public void renewalDuringPubSubOperationsTest() throws InterruptedException { + try (StatefulRedisPubSubConnection connectionPubSub = client.connectPubSub(); + StatefulRedisPubSubConnection connectionPubSub1 = client.connectPubSub()) { + + PubSubTestListener listener = new PubSubTestListener(); + connectionPubSub.addListener(listener); + connectionPubSub.sync().subscribe("channel"); + + // Start a thread to continuously send Redis commands + Thread pubsubThread = new Thread(() -> { + for (int i = 1; i <= 100; i++) { + connectionPubSub1.sync().publish("channel", "message"); + } + }); + + pubsubThread.start(); + + CountDownLatch latch = new CountDownLatch(10); + credentialsProvider.credentials().subscribe(cred -> { + latch.countDown(); + }); + + assertThat(latch.await(1, TimeUnit.SECONDS)).isTrue(); // Wait for at least 10 token renewals + pubsubThread.join(); // Wait for the pub/sub thread to finish + + // Verify that all messages were received + Wait.untilEquals(100, () -> listener.getMessages().size()).waitOrTimeout(); + assertThat(listener.getMessages()).allMatch(msg -> msg.equals("message")); + } + } + +} diff --git a/src/test/java/io/lettuce/authx/EntraIdTestContext.java b/src/test/java/io/lettuce/authx/EntraIdTestContext.java new file mode 100644 index 0000000000..7abfac0fe8 --- /dev/null +++ b/src/test/java/io/lettuce/authx/EntraIdTestContext.java @@ -0,0 +1,111 @@ +package io.lettuce.authx; + +import io.github.cdimascio.dotenv.Dotenv; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +public class EntraIdTestContext { + + private static final String AZURE_CLIENT_ID = "AZURE_CLIENT_ID"; + + private static final String AZURE_CLIENT_SECRET = "AZURE_CLIENT_SECRET"; + + private static final String AZURE_SP_OID = "AZURE_SP_OID"; + + private static final String AZURE_AUTHORITY = "AZURE_AUTHORITY"; + + private static final String AZURE_REDIS_SCOPES = "AZURE_REDIS_SCOPES"; + + private static final String REDIS_AZURE_HOST = "REDIS_AZURE_HOST"; + + private static final String REDIS_AZURE_PORT = "REDIS_AZURE_PORT"; + + private static final String REDIS_AZURE_CLUSTER_HOST = "REDIS_AZURE_CLUSTER_HOST"; + + private static final String REDIS_AZURE_CLUSTER_PORT = "REDIS_AZURE_CLUSTER_PORT"; + + private static final String REDIS_AZURE_DB = "REDIS_AZURE_DB"; + + private final String clientId; + + private final String authority; + + private final String clientSecret; + + private final String spOID; + + private final Set redisScopes; + + private final String redisHost; + + private final int redisPort; + + private final List redisClusterHost; + + private final int redisClusterPort; + + private static Dotenv dotenv; + static { + dotenv = Dotenv.configure().directory("src/test/resources").filename(".env.entraid").load(); + } + + public static final EntraIdTestContext DEFAULT = new EntraIdTestContext(); + + private EntraIdTestContext() { + // Using Dotenv directly here + clientId = dotenv.get(AZURE_CLIENT_ID, ""); + clientSecret = dotenv.get(AZURE_CLIENT_SECRET, ""); + spOID = dotenv.get(AZURE_SP_OID, ""); + authority = dotenv.get(AZURE_AUTHORITY, "https://login.microsoftonline.com/your-tenant-id"); + redisHost = dotenv.get(REDIS_AZURE_HOST); + redisPort = Integer.parseInt(dotenv.get(REDIS_AZURE_PORT, "6379")); + redisClusterHost = Arrays.asList(dotenv.get(REDIS_AZURE_CLUSTER_HOST, "").split(",")); + redisClusterPort = Integer.parseInt(dotenv.get(REDIS_AZURE_CLUSTER_PORT, "6379")); + String redisScopesEnv = dotenv.get(AZURE_REDIS_SCOPES, "https://redis.azure.com/.default"); + if (redisScopesEnv != null && !redisScopesEnv.isEmpty()) { + this.redisScopes = new HashSet<>(Arrays.asList(redisScopesEnv.split(";"))); + } else { + this.redisScopes = new HashSet<>(); + } + } + + public String host() { + return redisHost; + } + + public int port() { + return redisPort; + } + + public List clusterHost() { + return redisClusterHost; + } + + public int clusterPort() { + return redisClusterPort; + } + + public String getClientId() { + return clientId; + } + + public String getSpOID() { + return spOID; + } + + public String getAuthority() { + return authority; + } + + public String getClientSecret() { + return clientSecret; + } + + public Set getRedisScopes() { + return redisScopes; + } + +} diff --git a/src/test/java/io/lettuce/authx/TokenBasedRedisCredentialsProviderTest.java b/src/test/java/io/lettuce/authx/TokenBasedRedisCredentialsProviderTest.java new file mode 100644 index 0000000000..7ce58d0a65 --- /dev/null +++ b/src/test/java/io/lettuce/authx/TokenBasedRedisCredentialsProviderTest.java @@ -0,0 +1,158 @@ +package io.lettuce.authx; + +import io.lettuce.core.RedisCredentials; +import io.lettuce.core.TestTokenManager; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import redis.clients.authentication.core.SimpleToken; +import redis.clients.authentication.core.TokenManagerConfig; + +import java.time.Duration; +import java.util.Collections; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TokenBasedRedisCredentialsProviderTest { + + private TestTokenManager tokenManager; + + private TokenBasedRedisCredentialsProvider credentialsProvider; + + @BeforeEach + public void setUp() { + // Use TestToken manager to emit tokens/errors on request + TokenManagerConfig tokenManagerConfig = mock(TokenManagerConfig.class); + when(tokenManagerConfig.getRetryPolicy()).thenReturn(mock(TokenManagerConfig.RetryPolicy.class)); + tokenManager = new TestTokenManager(null, tokenManagerConfig); + credentialsProvider = TokenBasedRedisCredentialsProvider.create(tokenManager); + } + + @Test + public void shouldReturnPreviouslyEmittedTokenWhenResolved() { + tokenManager.emitToken(testToken("test-user", "token-1")); + + Mono credentials = credentialsProvider.resolveCredentials(); + + StepVerifier.create(credentials).assertNext(actual -> { + assertThat(actual.getUsername()).isEqualTo("test-user"); + assertThat(new String(actual.getPassword())).isEqualTo("token-1"); + }).verifyComplete(); + } + + @Test + public void shouldReturnLatestEmittedTokenWhenResolved() { + tokenManager.emitToken(testToken("test-user", "token-2")); + tokenManager.emitToken(testToken("test-user", "token-3")); // Latest token + + Mono credentials = credentialsProvider.resolveCredentials(); + + StepVerifier.create(credentials).assertNext(actual -> { + assertThat(actual.getUsername()).isEqualTo("test-user"); + assertThat(new String(actual.getPassword())).isEqualTo("token-3"); + }).verifyComplete(); + } + + @Test + public void shouldReturnTokenEmittedBeforeSubscription() { + + tokenManager.emitToken(testToken("test-user", "token-1")); + + // Test resolveCredentials + Mono credentials1 = credentialsProvider.resolveCredentials(); + + StepVerifier.create(credentials1).assertNext(actual -> { + assertThat(actual.getUsername()).isEqualTo("test-user"); + assertThat(new String(actual.getPassword())).isEqualTo("token-1"); + }).verifyComplete(); + + // Emit second token and subscribe another + tokenManager.emitToken(testToken("test-user", "token-2")); + tokenManager.emitToken(testToken("test-user", "token-3")); + Mono credentials2 = credentialsProvider.resolveCredentials(); + StepVerifier.create(credentials2).assertNext(actual -> { + assertThat(actual.getUsername()).isEqualTo("test-user"); + assertThat(new String(actual.getPassword())).isEqualTo("token-3"); + }).verifyComplete(); + } + + @Test + public void shouldWaitForAndReturnTokenWhenEmittedLater() { + Mono result = credentialsProvider.resolveCredentials(); + + tokenManager.emitTokenWithDelay(testToken("test-user", "delayed-token"), 100); // Emit token after 100ms + StepVerifier.create(result) + .assertNext(credentials -> assertThat(String.valueOf(credentials.getPassword())).isEqualTo("delayed-token")) + .verifyComplete(); + } + + @Test + public void shouldCompleteAllSubscribersOnStop() { + Flux credentialsFlux1 = credentialsProvider.credentials(); + Flux credentialsFlux2 = credentialsProvider.credentials(); + + Disposable subscription1 = credentialsFlux1.subscribe(); + Disposable subscription2 = credentialsFlux2.subscribe(); + + tokenManager.emitToken(testToken("test-user", "token-1")); + + new Thread(() -> { + try { + Thread.sleep(100); // Delay of 100 milliseconds + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + credentialsProvider.close(); + }).start(); + + StepVerifier.create(credentialsFlux1) + .assertNext(credentials -> assertThat(String.valueOf(credentials.getPassword())).isEqualTo("token-1")) + .verifyComplete(); + + StepVerifier.create(credentialsFlux2) + .assertNext(credentials -> assertThat(String.valueOf(credentials.getPassword())).isEqualTo("token-1")) + .verifyComplete(); + + assertThat(subscription1.isDisposed()).isTrue(); + assertThat(subscription2.isDisposed()).isTrue(); + + } + + @Test + public void shouldPropagateMultipleTokensOnStream() { + + Flux result = credentialsProvider.credentials(); + StepVerifier.create(result).then(() -> tokenManager.emitToken(testToken("test-user", "token1"))) + .then(() -> tokenManager.emitToken(testToken("test-user", "token2"))) + .assertNext(credentials -> assertThat(String.valueOf(credentials.getPassword())).isEqualTo("token1")) + .assertNext(credentials -> assertThat(String.valueOf(credentials.getPassword())).isEqualTo("token2")) + .thenCancel().verify(Duration.ofMillis(100)); + } + + @Test + public void shouldHandleTokenRequestErrorGracefully() { + Exception simulatedError = new RuntimeException("Token request failed"); + + Flux result = credentialsProvider.credentials(); + + StepVerifier.create(result).then(() -> { + tokenManager.emitToken(testToken("test-user", "token1")); + tokenManager.emitError(simulatedError); + tokenManager.emitToken(testToken("test-user", "token2")); + }).assertNext(credentials -> assertThat(String.valueOf(credentials.getPassword())).isEqualTo("token1")) + .assertNext(credentials -> assertThat(String.valueOf(credentials.getPassword())).isEqualTo("token2")) + .thenCancel().verify(Duration.ofMillis(100)); + + } + + private SimpleToken testToken(String username, String value) { + return new SimpleToken(username, value, System.currentTimeMillis() + 5000, // expires in 5 seconds + System.currentTimeMillis(), Collections.emptyMap()); + } + +} diff --git a/src/test/java/io/lettuce/core/AuthenticationIntegrationTests.java b/src/test/java/io/lettuce/core/AuthenticationIntegrationTests.java index 864a2103b0..9914d21896 100644 --- a/src/test/java/io/lettuce/core/AuthenticationIntegrationTests.java +++ b/src/test/java/io/lettuce/core/AuthenticationIntegrationTests.java @@ -2,9 +2,16 @@ import static io.lettuce.TestTags.INTEGRATION_TEST; import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import javax.inject.Inject; +import io.lettuce.authx.TokenBasedRedisCredentialsProvider; +import io.lettuce.core.event.command.CommandListener; +import io.lettuce.core.event.command.CommandSucceededEvent; +import io.lettuce.core.protocol.RedisCommand; +import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; @@ -20,11 +27,20 @@ import io.lettuce.test.condition.EnabledOnCommand; import io.lettuce.test.settings.TestSettings; import reactor.core.publisher.Mono; +import redis.clients.authentication.core.SimpleToken; +import redis.clients.authentication.core.TokenManagerConfig; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; /** * Integration test for authentication. * * @author Mark Paluch + * @author Ivo Gaydajiev */ @Tag(INTEGRATION_TEST) @ExtendWith(LettuceExtension.class) @@ -37,6 +53,9 @@ void setUp(StatefulRedisConnection connection) { connection.sync().dispatch(CommandType.ACL, new StatusOutput<>(StringCodec.UTF8), new CommandArgs<>(StringCodec.UTF8).add("SETUSER").add("john").add("on").add(">foobared").add("-@all")); + + connection.sync().dispatch(CommandType.ACL, new StatusOutput<>(StringCodec.UTF8), + new CommandArgs<>(StringCodec.UTF8).add("SETUSER").add("steave").add("on").add(">foobared").add("+@all")); } @Test @@ -71,4 +90,107 @@ void ownCredentialProvider(RedisClient client) { }); } + // Simulate test user credential rotation, and verify that re-authentication is successful + @Test + @Inject + void streamingCredentialProvider(RedisClient client) { + + TestCommandListener listener = new TestCommandListener(); + client.addListener(listener); + client.setOptions(client.getOptions().mutate() + .reauthenticateBehavior(ClientOptions.ReauthenticateBehavior.ON_NEW_CREDENTIALS).build()); + + // Build RedisURI with streaming credentials provider + MyStreamingRedisCredentialsProvider credentialsProvider = new MyStreamingRedisCredentialsProvider(); + RedisURI uri = RedisURI.builder().withHost(TestSettings.host()).withPort(TestSettings.port()) + .withClientName("streaming_cred_test").withAuthentication(credentialsProvider) + .withTimeout(Duration.ofSeconds(5)).build(); + + credentialsProvider.emitCredentials(TestSettings.username(), TestSettings.password().toString().toCharArray()); + + // verify that the initial connection is successful with default user credentials + StatefulRedisConnection connection = client.connect(uri); + assertThat(connection.sync().aclWhoami()).isEqualTo(TestSettings.username()); + + // rotate the credentials + credentialsProvider.emitCredentials("steave", "foobared".toCharArray()); + + Awaitility.await().atMost(Duration.ofSeconds(1)).until(() -> listener.succeeded.stream() + .anyMatch(command -> isAuthCommandWithCredentials(command, "steave", "foobared".toCharArray()))); + + // verify that the connection is re-authenticated with the new user credentials + assertThat(connection.sync().aclWhoami()).isEqualTo("steave"); + + credentialsProvider.shutdown(); + connection.close(); + client.removeListener(listener); + client.setOptions( + client.getOptions().mutate().reauthenticateBehavior(ClientOptions.ReauthenticateBehavior.DEFAULT).build()); + } + + @Test + @Inject + void tokenBasedCredentialProvider(RedisClient client) { + + TestCommandListener listener = new TestCommandListener(); + client.addListener(listener); + client.setOptions(client.getOptions().mutate() + .reauthenticateBehavior(ClientOptions.ReauthenticateBehavior.ON_NEW_CREDENTIALS).build()); + + TokenManagerConfig tokenManagerConfig = mock(TokenManagerConfig.class); + when(tokenManagerConfig.getRetryPolicy()).thenReturn(mock(TokenManagerConfig.RetryPolicy.class)); + TestTokenManager tokenManager = new TestTokenManager(null, tokenManagerConfig); + TokenBasedRedisCredentialsProvider credentialsProvider = TokenBasedRedisCredentialsProvider.create(tokenManager); + + // Build RedisURI with streaming credentials provider + RedisURI uri = RedisURI.builder().withHost(TestSettings.host()).withPort(TestSettings.port()) + .withClientName("streaming_cred_test").withAuthentication(credentialsProvider) + .withTimeout(Duration.ofSeconds(5)).build(); + tokenManager.emitToken(testToken(TestSettings.username(), TestSettings.password().toString().toCharArray())); + + StatefulRedisConnection connection = client.connect(StringCodec.UTF8, uri); + assertThat(connection.sync().aclWhoami()).isEqualTo(TestSettings.username()); + + // rotate the credentials + tokenManager.emitToken(testToken("steave", "foobared".toCharArray())); + + Awaitility.await().atMost(Duration.ofSeconds(1)).until(() -> listener.succeeded.stream() + .anyMatch(command -> isAuthCommandWithCredentials(command, "steave", "foobared".toCharArray()))); + + // verify that the connection is re-authenticated with the new user credentials + assertThat(connection.sync().aclWhoami()).isEqualTo("steave"); + + credentialsProvider.close(); + connection.close(); + client.removeListener(listener); + client.setOptions( + client.getOptions().mutate().reauthenticateBehavior(ClientOptions.ReauthenticateBehavior.DEFAULT).build()); + } + + static class TestCommandListener implements CommandListener { + + final List> succeeded = new ArrayList<>(); + + @Override + public void commandSucceeded(CommandSucceededEvent event) { + synchronized (succeeded) { + succeeded.add(event.getCommand()); + } + } + + } + + private boolean isAuthCommandWithCredentials(RedisCommand command, String username, char[] password) { + if (command.getType() == CommandType.AUTH) { + CommandArgs args = command.getArgs(); + return args.toCommandString().contains(username) && args.toCommandString().contains(String.valueOf(password)); + } + return false; + } + + private SimpleToken testToken(String username, char[] password) { + return new SimpleToken(username, String.valueOf(password), Instant.now().plusMillis(500).toEpochMilli(), + Instant.now().toEpochMilli(), Collections.emptyMap()); + } + } diff --git a/src/test/java/io/lettuce/core/MyStreamingRedisCredentialsProvider.java b/src/test/java/io/lettuce/core/MyStreamingRedisCredentialsProvider.java new file mode 100644 index 0000000000..12e9e37d15 --- /dev/null +++ b/src/test/java/io/lettuce/core/MyStreamingRedisCredentialsProvider.java @@ -0,0 +1,49 @@ +package io.lettuce.core; + +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +/** + * A provider for streaming credentials that can be used to authorize a Redis connection + * + * @author Ivo Gaydajiev + * @since 6.6.0 + */ +public class MyStreamingRedisCredentialsProvider implements RedisCredentialsProvider { + + private final Sinks.Many credentialsSink = Sinks.many().replay().latest(); + + @Override + public boolean supportsStreaming() { + return true; + } + + @Override + public Mono resolveCredentials() { + + return credentialsSink.asFlux().next(); + } + + public Flux credentials() { + + return credentialsSink.asFlux().onBackpressureLatest(); // Provide a continuous stream of credentials + } + + public void shutdown() { + credentialsSink.tryEmitComplete(); + } + + public void emitCredentials(RedisCredentials credentials) { + credentialsSink.tryEmitNext(credentials); + } + + public void emitCredentials(String username, char[] password) { + credentialsSink.tryEmitNext(new StaticRedisCredentials(username, password)); + } + + public void tryEmitError(RuntimeException testError) { + credentialsSink.tryEmitError(testError); + } + +} diff --git a/src/test/java/io/lettuce/core/RedisAuthenticationHandlerUnitTests.java b/src/test/java/io/lettuce/core/RedisAuthenticationHandlerUnitTests.java new file mode 100644 index 0000000000..07cea7a15e --- /dev/null +++ b/src/test/java/io/lettuce/core/RedisAuthenticationHandlerUnitTests.java @@ -0,0 +1,214 @@ +package io.lettuce.core; + +import io.lettuce.core.codec.StringCodec; +import io.lettuce.core.event.DefaultEventBus; +import io.lettuce.core.event.EventBus; +import io.lettuce.core.event.connection.ReauthenticationFailedEvent; +import io.lettuce.core.protocol.AsyncCommand; +import io.lettuce.core.protocol.CommandType; +import io.lettuce.core.protocol.ProtocolVersion; +import io.lettuce.core.protocol.RedisCommand; +import io.lettuce.core.resource.ClientResources; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatcher; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; + +import java.time.Duration; + +import static io.lettuce.TestTags.UNIT_TEST; +import static io.lettuce.core.protocol.CommandType.AUTH; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Unit tests for the {@link RedisAuthenticationHandler} + */ +@Tag(UNIT_TEST) +public class RedisAuthenticationHandlerUnitTests { + + private StatefulRedisConnectionImpl connection; + + RedisChannelWriter writer; + + ClientResources resources; + + EventBus eventBus; + + ConnectionState connectionState; + + @BeforeEach + void setUp() { + eventBus = new DefaultEventBus(Schedulers.immediate()); + writer = mock(RedisChannelWriter.class); + connection = mock(StatefulRedisConnectionImpl.class); + resources = mock(ClientResources.class); + when(resources.eventBus()).thenReturn(eventBus); + + connectionState = mock(ConnectionState.class); + when(connection.getResources()).thenReturn(resources); + when(connection.getCodec()).thenReturn(StringCodec.UTF8); + when(connection.getConnectionState()).thenReturn(connectionState); + when(connection.getChannelWriter()).thenReturn(writer); + } + + @SuppressWarnings("unchecked") + @Test + void subscribeWithStreamingCredentialsProviderInvokesReauth() { + MyStreamingRedisCredentialsProvider credentialsProvider = new MyStreamingRedisCredentialsProvider(); + + RedisAuthenticationHandler handler = new RedisAuthenticationHandler<>(connection, credentialsProvider, + false); + + // Subscribe to the provider + handler.subscribe(); + credentialsProvider.emitCredentials("newuser", "newpassword".toCharArray()); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(AsyncCommand.class); + verify(writer).write(captor.capture()); + + AsyncCommand credentialsCommand = captor.getValue(); + assertThat(credentialsCommand.getType()).isEqualTo(AUTH); + assertThat(credentialsCommand.getArgs().count()).isEqualTo(2); + assertThat(credentialsCommand.getArgs().toCommandString()).isEqualTo("newuser" + " " + "newpassword"); + + credentialsProvider.shutdown(); + } + + @Test + void shouldHandleErrorInCredentialsStream() { + MyStreamingRedisCredentialsProvider credentialsProvider = new MyStreamingRedisCredentialsProvider(); + + RedisAuthenticationHandler handler = new RedisAuthenticationHandler<>(connection, credentialsProvider, false); + + verify(connection, times(0)).dispatch(any(RedisCommand.class)); // No command should be sent + + // Verify the event was published + StepVerifier.create(eventBus.get()).then(() -> { + handler.subscribe(); + credentialsProvider.tryEmitError(new RuntimeException("Test error")); + }).expectNextMatches(event -> event instanceof ReauthenticationFailedEvent).thenCancel().verify(Duration.ofSeconds(1)); + + credentialsProvider.shutdown(); + } + + @Test + void shouldNotSubscribeIfConnectionIsNotSupported() { + MyStreamingRedisCredentialsProvider credentialsProvider = new MyStreamingRedisCredentialsProvider(); + when(connectionState.getNegotiatedProtocolVersion()).thenReturn(ProtocolVersion.RESP2); + + RedisAuthenticationHandler handler = new RedisAuthenticationHandler<>(connection, credentialsProvider, true); + + // Subscribe to the provider (it should not subscribe due to unsupported connection) + handler.subscribe(); + credentialsProvider.emitCredentials("newuser", "newpassword".toCharArray()); + + // Ensure credentials() was not called + verify(connection, times(0)).dispatch(any(RedisCommand.class)); // No command should be sent + } + + @Test + void testIsSupportedConnectionWithRESP2ProtocolOnPubSubConnection() { + + when(connectionState.getNegotiatedProtocolVersion()).thenReturn(ProtocolVersion.RESP2); + + RedisAuthenticationHandler handler = new RedisAuthenticationHandler<>(connection, + mock(RedisCredentialsProvider.class), true); + + assertFalse(handler.isSupportedConnection()); + } + + @Test + void testIsSupportedConnectionWithNonPubSubConnection() { + + when(connectionState.getNegotiatedProtocolVersion()).thenReturn(ProtocolVersion.RESP2); + + RedisAuthenticationHandler handler = new RedisAuthenticationHandler<>(connection, + mock(RedisCredentialsProvider.class), false); + + assertTrue(handler.isSupportedConnection()); + } + + @Test + void testIsSupportedConnectionWithRESP3ProtocolOnPubSubConnection() { + + when(connectionState.getNegotiatedProtocolVersion()).thenReturn(ProtocolVersion.RESP3); + + RedisAuthenticationHandler handler = new RedisAuthenticationHandler<>(connection, + mock(RedisCredentialsProvider.class), true); + + assertTrue(handler.isSupportedConnection()); + } + + @Test + public void testSetCredentialsWhenCredentialsAreNull() { + RedisAuthenticationHandler handler = new RedisAuthenticationHandler<>(connection, + mock(RedisCredentialsProvider.class), false); + + handler.setCredentials(null); + + verify(connection, times(0)).dispatch(any(RedisCommand.class)); // No command should be sent + } + + @Test + void testSetCredentialsDoesNotDispatchAuthIfInTransaction() { + MyStreamingRedisCredentialsProvider credentialsProvider = new MyStreamingRedisCredentialsProvider(); + RedisAuthenticationHandler handler = new RedisAuthenticationHandler<>(connection, credentialsProvider, false); + + // Subscribe to the provider + handler.subscribe(); + + // Indicate a transaction is ongoing + handler.startTransaction(); + + // Attempt to authenticate + credentialsProvider.emitCredentials("newuser", "newpassword".toCharArray()); + + // verify that the AUTH command was not sent + verify(connection, times(0)).dispatch(any(RedisCommand.class)); + + // Indicate a transaction is ongoing + handler.endTransaction(); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(AsyncCommand.class); + verify(writer).write(captor.capture()); + + AsyncCommand credentialsCommand = captor.getValue(); + assertThat(credentialsCommand.getType()).isEqualTo(AUTH); + assertThat(credentialsCommand.getArgs().count()).isEqualTo(2); + assertThat(credentialsCommand.getArgs().toCommandString()).isEqualTo("newuser" + " " + "newpassword"); + } + + public static ArgumentMatcher> isAuthCommand(String expectedUsername, + String expectedPassword) { + return new ArgumentMatcher>() { + + @Override + public boolean matches(RedisCommand command) { + if (command.getType() != CommandType.AUTH) { + return false; + } + + // Retrieve arguments (adjust based on your RedisCommand implementation) + return command.getArgs().toCommandString().equals(expectedUsername + " " + expectedPassword); + } + + @Override + public String toString() { + return String.format("Expected AUTH command with username=%s and password=%s", expectedUsername, + expectedPassword); + } + + }; + } + +} diff --git a/src/test/java/io/lettuce/core/RedisClientConnectIntegrationTests.java b/src/test/java/io/lettuce/core/RedisClientConnectIntegrationTests.java index 4e7c281e40..416ffa3a44 100644 --- a/src/test/java/io/lettuce/core/RedisClientConnectIntegrationTests.java +++ b/src/test/java/io/lettuce/core/RedisClientConnectIntegrationTests.java @@ -32,6 +32,7 @@ import javax.inject.Inject; +import io.lettuce.core.protocol.ProtocolVersion; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Tag; @@ -219,6 +220,23 @@ void connectPubSubCodecSentinelMissingHostAndSocketUri() { assertThatThrownBy(() -> client.connectPubSub(UTF8, invalidSentinel())).isInstanceOf(IllegalArgumentException.class); } + @Test + void connectPubSubAsyncReauthNotSupportedWithRESP2() { + ClientOptions.ReauthenticateBehavior reauth = client.getOptions().getReauthenticateBehaviour(); + ProtocolVersion protocolVersion = client.getOptions().getConfiguredProtocolVersion(); + try { + client.setOptions(client.getOptions().mutate().protocolVersion(ProtocolVersion.RESP2) + .reauthenticateBehavior(ClientOptions.ReauthenticateBehavior.ON_NEW_CREDENTIALS).build()); + + RedisURI redisURI = redis(host, port).build(); + assertThatThrownBy(() -> client.connectPubSubAsync(UTF8, redisURI)).isInstanceOf(RedisConnectionException.class); + + } finally { + client.setOptions( + client.getOptions().mutate().protocolVersion(protocolVersion).reauthenticateBehavior(reauth).build()); + } + } + /* * Sentinel Stateful */ diff --git a/src/test/java/io/lettuce/core/TestTokenManager.java b/src/test/java/io/lettuce/core/TestTokenManager.java new file mode 100644 index 0000000000..391b6302b9 --- /dev/null +++ b/src/test/java/io/lettuce/core/TestTokenManager.java @@ -0,0 +1,50 @@ +package io.lettuce.core; + +import redis.clients.authentication.core.IdentityProvider; +import redis.clients.authentication.core.SimpleToken; +import redis.clients.authentication.core.TokenListener; +import redis.clients.authentication.core.TokenManager; +import redis.clients.authentication.core.TokenManagerConfig; + +public class TestTokenManager extends TokenManager { + + private TokenListener listener; + + public TestTokenManager(IdentityProvider identityProvider, TokenManagerConfig tokenManagerConfig) { + super(identityProvider, tokenManagerConfig); + } + + @Override + public void start(TokenListener listener, boolean waitForToken) { + this.listener = listener; + } + + @Override + public void stop() { + // Cleanup logic if needed + } + + public void emitToken(SimpleToken token) { + if (listener != null) { + listener.onTokenRenewed(token); + } + } + + public void emitError(Exception exception) { + if (listener != null) { + listener.onError(exception); + } + } + + public void emitTokenWithDelay(SimpleToken token, long delayMillis) { + new Thread(() -> { + try { + Thread.sleep(delayMillis); + emitToken(token); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + }).start(); + } + +} diff --git a/src/test/java/io/lettuce/core/cluster/ClusterClientOptionsIntegrationTests.java b/src/test/java/io/lettuce/core/cluster/ClusterClientOptionsIntegrationTests.java index 6eddfa2e0d..94cbbe76bb 100644 --- a/src/test/java/io/lettuce/core/cluster/ClusterClientOptionsIntegrationTests.java +++ b/src/test/java/io/lettuce/core/cluster/ClusterClientOptionsIntegrationTests.java @@ -1,6 +1,7 @@ package io.lettuce.core.cluster; import static io.lettuce.TestTags.INTEGRATION_TEST; +import static io.lettuce.core.codec.StringCodec.UTF8; import static org.assertj.core.api.Assertions.*; import java.time.Duration; @@ -8,6 +9,9 @@ import javax.inject.Inject; +import io.lettuce.core.ClientOptions; +import io.lettuce.core.RedisConnectionException; +import io.lettuce.core.protocol.ProtocolVersion; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; @@ -80,4 +84,23 @@ void shouldApplyTimeoutOptionsToPubSubClusterConnection() throws InterruptedExce Thread.sleep(300); } + @Test + void connectPubSubAsyncReauthNotSupportedWithRESP2() { + + ClientOptions.ReauthenticateBehavior reauth = clusterClient.getClusterClientOptions().getReauthenticateBehaviour(); + ProtocolVersion protocolVersion = clusterClient.getClusterClientOptions().getConfiguredProtocolVersion(); + + try { + clusterClient.setOptions(clusterClient.getClusterClientOptions().mutate().protocolVersion(ProtocolVersion.RESP2) + .reauthenticateBehavior(ClientOptions.ReauthenticateBehavior.ON_NEW_CREDENTIALS).build()); + assertThatThrownBy(() -> clusterClient.connectPubSub(UTF8)).isInstanceOf(RedisConnectionException.class); + + } finally { + + clusterClient.setOptions(clusterClient.getClusterClientOptions().mutate().protocolVersion(protocolVersion) + .reauthenticateBehavior(reauth).build()); + } + + } + } diff --git a/src/test/java/io/lettuce/core/cluster/RedisClusterStreamingCredentialsProviderlIntegrationTests.java b/src/test/java/io/lettuce/core/cluster/RedisClusterStreamingCredentialsProviderlIntegrationTests.java new file mode 100644 index 0000000000..908ec7583e --- /dev/null +++ b/src/test/java/io/lettuce/core/cluster/RedisClusterStreamingCredentialsProviderlIntegrationTests.java @@ -0,0 +1,172 @@ +package io.lettuce.core.cluster; + +import io.lettuce.core.AclSetuserArgs; +import io.lettuce.core.ClientOptions; +import io.lettuce.core.MyStreamingRedisCredentialsProvider; +import io.lettuce.core.RedisCommandExecutionException; +import io.lettuce.core.RedisURI; +import io.lettuce.core.TestSupport; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.api.sync.RedisCommands; +import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; +import io.lettuce.core.cluster.api.sync.Executions; +import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; +import io.lettuce.test.CanConnect; +import io.lettuce.test.resource.FastShutdown; +import io.lettuce.test.resource.TestClientResources; +import io.lettuce.test.settings.TestSettings; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.stream.Collectors; + +import static io.lettuce.TestTags.INTEGRATION_TEST; +import static io.lettuce.test.settings.TestSettings.host; +import static io.lettuce.test.settings.TestSettings.hostAddr; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +/** + * @author Ivo Gaydajiev + */ +@Tag(INTEGRATION_TEST) +class RedisClusterStreamingCredentialsProviderIntegrationTests extends TestSupport { + + private static final int CLUSTER_PORT_SSL_1 = 7442; + + private static final int CLUSTER_PORT_SSL_2 = 7444; // replica cannot replicate properly with upstream + + private static final int CLUSTER_PORT_SSL_3 = 7445; + + private static final String SLOT_1_KEY = "8HMdi"; + + private static final String SLOT_16352_KEY = "UyAa4KqoWgPGKa"; + + private static MyStreamingRedisCredentialsProvider credentialsProvider = new MyStreamingRedisCredentialsProvider(); + + private static RedisURI redisURI = RedisURI.Builder.redis(host(), CLUSTER_PORT_SSL_1).withSsl(true) + .withAuthentication(credentialsProvider).withVerifyPeer(false).build(); + + private static RedisClusterClient redisClient = RedisClusterClient.create(TestClientResources.get(), redisURI); + + @BeforeEach + void before() { + assumeTrue(CanConnect.to(host(), CLUSTER_PORT_SSL_1), "Assume that stunnel runs on port 7442"); + assumeTrue(CanConnect.to(host(), CLUSTER_PORT_SSL_2), "Assume that stunnel runs on port 7444"); + assumeTrue(CanConnect.to(host(), CLUSTER_PORT_SSL_3), "Assume that stunnel runs on port 7445"); + assumeTrue(CanConnect.to(host(), 7479), "Assume that Redis runs on port 7479"); + assumeTrue(CanConnect.to(host(), 7480), "Assume that Redis runs on port 7480"); + assumeTrue(CanConnect.to(host(), 7481), "Assume that Redis runs on port 7481"); + } + + @BeforeAll + static void beforeClass() { + credentialsProvider.emitCredentials(TestSettings.username(), TestSettings.password().toString().toCharArray()); + } + + @AfterAll + static void afterClass() { + credentialsProvider.shutdown(); + FastShutdown.shutdown(redisClient); + } + + @Test + void defaultClusterConnectionShouldWork() { + + StatefulRedisClusterConnection connection = redisClient.connect(); + assertThat(connection.sync().ping()).isEqualTo("PONG"); + + connection.close(); + } + + @Test + void partitionViewShouldContainClusterPorts() { + + StatefulRedisClusterConnection connection = redisClient.connect(); + List ports = connection.getPartitions().stream().map(redisClusterNode -> redisClusterNode.getUri().getPort()) + .collect(Collectors.toList()); + connection.close(); + + assertThat(ports).contains(CLUSTER_PORT_SSL_1, CLUSTER_PORT_SSL_3); + } + + @Test + void routedOperationsAreWorking() { + + StatefulRedisClusterConnection connection = redisClient.connect(); + RedisAdvancedClusterCommands sync = connection.sync(); + + sync.set(SLOT_1_KEY, "value1"); + sync.set(SLOT_16352_KEY, "value2"); + + assertThat(sync.get(SLOT_1_KEY)).isEqualTo("value1"); + assertThat(sync.get(SLOT_16352_KEY)).isEqualTo("value2"); + + connection.close(); + } + + @Test + void nodeConnectionsShouldWork() { + + StatefulRedisClusterConnection connection = redisClient.connect(); + + // master 2 + StatefulRedisConnection node2Connection = connection.getConnection(hostAddr(), 7445); + + try { + node2Connection.sync().get(SLOT_1_KEY); + } catch (RedisCommandExecutionException e) { + assertThat(e).hasMessage("MOVED 1 127.0.0.1:7442"); + } + + connection.close(); + } + + @Test + void nodeSelectionApiShouldWork() { + + StatefulRedisClusterConnection connection = redisClient.connect(); + + Executions ping = connection.sync().all().commands().ping(); + assertThat(ping).hasSize(3).contains("PONG"); + + connection.close(); + } + + @Test + void shouldPerformNodeConnectionReauth() { + ClusterClientOptions origClientOptions = redisClient.getClusterClientOptions(); + redisClient.setOptions(origClientOptions.mutate() + .reauthenticateBehavior(ClientOptions.ReauthenticateBehavior.ON_NEW_CREDENTIALS).build()); + + StatefulRedisClusterConnection connection = redisClient.connect(); + connection.getPartitions().forEach( + partition -> createTestUser(connection.getConnection(partition.getNodeId()).sync(), "steave", "foobared")); + + credentialsProvider.emitCredentials("steave", "foobared".toCharArray()); + + // Verify each node's authenticated username matches the updated credentials + connection.getPartitions().forEach(partition -> { + StatefulRedisConnection userConn = connection.getConnection(partition.getNodeId()); + assertThat(userConn.sync().aclWhoami()).isEqualTo("steave"); + }); + + // re-auth with the default credentials + credentialsProvider.emitCredentials(TestSettings.username(), TestSettings.password().toString().toCharArray()); + + connection.getPartitions().forEach(partition -> { + connection.getConnection(partition.getNodeId()).sync().aclDeluser("steave"); + }); + + connection.close(); + } + + public static void createTestUser(RedisCommands commands, String username, String password) { + commands.aclSetuser(username, AclSetuserArgs.Builder.on().allCommands().addPassword(password)); + } + +} diff --git a/src/test/java/io/lettuce/core/event/ConnectionEventsTriggeredIntegrationTests.java b/src/test/java/io/lettuce/core/event/ConnectionEventsTriggeredIntegrationTests.java index 21d9eb5e83..12bbbeddaf 100644 --- a/src/test/java/io/lettuce/core/event/ConnectionEventsTriggeredIntegrationTests.java +++ b/src/test/java/io/lettuce/core/event/ConnectionEventsTriggeredIntegrationTests.java @@ -6,9 +6,19 @@ import java.time.Duration; import java.time.temporal.ChronoUnit; +import io.lettuce.core.ClientOptions; +import io.lettuce.core.MyStreamingRedisCredentialsProvider; +import io.lettuce.core.event.connection.AuthenticationEvent; +import io.lettuce.core.event.connection.ReauthenticationEvent; +import io.lettuce.core.event.connection.ReauthenticationFailedEvent; +import io.lettuce.test.LettuceExtension; +import io.lettuce.test.WithPassword; +import io.lettuce.test.settings.TestSettings; +import org.assertj.core.api.InstanceOfAssertFactories; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; import reactor.core.publisher.Flux; import reactor.test.StepVerifier; import io.lettuce.core.RedisClient; @@ -20,8 +30,10 @@ /** * @author Mark Paluch + * @author Ivo Gaydajiev */ @Tag(INTEGRATION_TEST) +@ExtendWith(LettuceExtension.class) class ConnectionEventsTriggeredIntegrationTests extends TestSupport { @Test @@ -41,4 +53,30 @@ void testConnectionEvents() { FastShutdown.shutdown(client); } + @Test + void testReauthenticateEvents() { + + MyStreamingRedisCredentialsProvider credentialsProvider = new MyStreamingRedisCredentialsProvider(); + credentialsProvider.emitCredentials(TestSettings.username(), TestSettings.password().toString().toCharArray()); + + RedisClient client = RedisClient.create(RedisURI.create(TestSettings.host(), TestSettings.port())); + client.setOptions(ClientOptions.builder() + .reauthenticateBehavior(ClientOptions.ReauthenticateBehavior.ON_NEW_CREDENTIALS).build()); + RedisURI uri = RedisURI.Builder.redis(host, port).withAuthentication(credentialsProvider).build(); + + Flux publisher = client.getResources().eventBus().get() + .filter(event -> event instanceof AuthenticationEvent).cast(AuthenticationEvent.class); + + WithPassword.run(client, () -> StepVerifier.create(publisher).then(() -> client.connect(uri)) + .assertNext(event -> assertThat(event).asInstanceOf(InstanceOfAssertFactories.type(ReauthenticationEvent.class)) + .extracting(ReauthenticationEvent::getEpId).isNotNull()) + .then(() -> credentialsProvider.emitCredentials(TestSettings.username(), "invalid".toCharArray())) + .assertNext(event -> assertThat(event) + .asInstanceOf(InstanceOfAssertFactories.type(ReauthenticationFailedEvent.class)) + .extracting(ReauthenticationFailedEvent::getEpId).isNotNull()) + .thenCancel().verify(Duration.of(1, ChronoUnit.SECONDS))); + + FastShutdown.shutdown(client); + } + } diff --git a/src/test/java/io/lettuce/examples/TokenBasedAuthExample.java b/src/test/java/io/lettuce/examples/TokenBasedAuthExample.java new file mode 100644 index 0000000000..e8bd9f9a3b --- /dev/null +++ b/src/test/java/io/lettuce/examples/TokenBasedAuthExample.java @@ -0,0 +1,138 @@ +package io.lettuce.examples; + +import io.lettuce.authx.TokenBasedRedisCredentialsProvider; +import io.lettuce.core.ClientOptions; +import io.lettuce.core.RedisClient; +import io.lettuce.core.RedisURI; +import io.lettuce.core.SocketOptions; +import io.lettuce.core.TimeoutOptions; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.cluster.ClusterClientOptions; +import io.lettuce.core.cluster.RedisClusterClient; +import io.lettuce.core.cluster.api.StatefulRedisClusterConnection; +import io.lettuce.core.cluster.api.sync.NodeSelection; +import io.lettuce.core.cluster.api.sync.RedisAdvancedClusterCommands; +import io.lettuce.core.codec.StringCodec; +import redis.clients.authentication.core.IdentityProviderConfig; +import redis.clients.authentication.core.TokenAuthConfig; +import redis.clients.authentication.entraid.EntraIDTokenAuthConfigBuilder; + +import java.time.Duration; +import java.util.Collections; +import java.util.Set; + +public class TokenBasedAuthExample { + + public static final String REDIS_URI = "redis://108.143.40.70:12002"; + + public static void main(String[] args) throws Exception { + // Configure TokenManager + String authority = "https://login.microsoftonline.com/562f7bf2-f594-47bf-8ac3-a06514b5d434"; + Set scopes = Collections.singleton("https://redis.azure.com/.default"); + + String User1_clientId = System.getenv("USER1_CLIENT_ID"); + String User1_secret = System.getenv("USER1_SECRET"); + + String User2_clientId = System.getenv("USER2_CLIENT_ID"); + String User2_secret = System.getenv("USER2_SECRET"); + // User 1 + // from redis-authx-entraind + IdentityProviderConfig config1; + try (EntraIDTokenAuthConfigBuilder builder = EntraIDTokenAuthConfigBuilder.builder()) { + config1 = builder.authority(authority).clientId(User1_clientId).secret(User1_secret).scopes(scopes) + .tokenRequestExecTimeoutInMs(10000).build().getIdentityProviderConfig(); + } + + // from redis-authx-core + TokenAuthConfig tokenAuthConfigUser1 = TokenAuthConfig.builder().tokenRequestExecTimeoutInMs(10000) + .expirationRefreshRatio(0.1f).identityProviderConfig(config1).build(); + // Create credentials provider user1 + TokenBasedRedisCredentialsProvider credentialsUser1 = TokenBasedRedisCredentialsProvider.create(tokenAuthConfigUser1); + + // User2 + // from redis-authx-entraind + IdentityProviderConfig config2 = EntraIDTokenAuthConfigBuilder.builder().authority(authority).clientId(User2_clientId) + .secret(User2_secret).scopes(scopes).tokenRequestExecTimeoutInMs(10000).build().getIdentityProviderConfig(); + // from redis-authx-core + TokenAuthConfig tokenAuthConfigUser2 = TokenAuthConfig.builder().tokenRequestExecTimeoutInMs(10000) + .expirationRefreshRatio(0.1f).identityProviderConfig(config2).build(); + // Create credentials provider user2 + // TODO: lettuce-autx-tba ( TokenBasedRedisCredentialsProvider & Example there) + TokenBasedRedisCredentialsProvider credentialsUser2 = TokenBasedRedisCredentialsProvider.create(tokenAuthConfigUser2); + + // lettuce-core + RedisURI redisURI1 = RedisURI.create(REDIS_URI); + redisURI1.setCredentialsProvider(credentialsUser1); + + RedisURI redisURI2 = RedisURI.create(REDIS_URI); + redisURI2.setCredentialsProvider(credentialsUser2); + + // Create RedisClient + ClientOptions clientOptions = ClientOptions.builder() + .socketOptions(SocketOptions.builder().connectTimeout(Duration.ofSeconds(5)).build()) + .disconnectedBehavior(ClientOptions.DisconnectedBehavior.REJECT_COMMANDS) + .timeoutOptions(TimeoutOptions.enabled(Duration.ofSeconds(1))) + .reauthenticateBehavior(ClientOptions.ReauthenticateBehavior.ON_NEW_CREDENTIALS).build(); + try { + + // RedisClient using user1 credentials by default + RedisClient redisClient = RedisClient.create(redisURI1); + redisClient.setOptions(clientOptions); + + // create connection using default URI (authorised as user1) + try (StatefulRedisConnection user1 = redisClient.connect(StringCodec.UTF8)) { + + user1.reactive().aclWhoami().doOnNext(System.out::println).block(); + } + + // another connection using different authorizations (user2 credentials provider) + try (StatefulRedisConnection user2 = redisClient.connect(StringCodec.UTF8, redisURI2);) { + user2.reactive().aclWhoami().doOnNext(System.out::println).block(); + } + + // Shutdown Redis client and close connections + redisClient.shutdown(); + + ClusterClientOptions clusterClientOptions = ClusterClientOptions.builder() + .socketOptions(SocketOptions.builder().connectTimeout(Duration.ofSeconds(5)).build()) + .disconnectedBehavior(ClientOptions.DisconnectedBehavior.REJECT_COMMANDS) + .timeoutOptions(TimeoutOptions.enabled(Duration.ofSeconds(1))) + .reauthenticateBehavior(ClientOptions.ReauthenticateBehavior.ON_NEW_CREDENTIALS).build(); + + // RedisClient using user1 credentials by default + RedisClusterClient redisClusterClient = RedisClusterClient.create(redisURI1); + redisClusterClient.setOptions(clusterClientOptions); + + // create connection using default URI (authorised as user1) + try (StatefulRedisClusterConnection clusterConnection = redisClusterClient + .connect(StringCodec.UTF8)) { + + String info = clusterConnection.sync().clusterInfo(); + System.out.println("Cluster Info :" + info); + + String nodes = clusterConnection.sync().clusterNodes(); + System.out.println("Cluster Nodes :" + nodes); + + clusterConnection.sync().set("cluster-key", "cluster-value"); + System.out.println("set " + clusterConnection.sync().get("cluster-key")); + + RedisAdvancedClusterCommands sync = clusterConnection.sync(); + NodeSelection upstream = sync.upstream(); + + upstream.commands().clientId().forEach((v) -> { + System.out.println("Client Id : " + v); + }); + + System.out.println(" whoami :" + clusterConnection + .getConnection(clusterConnection.getPartitions().getPartition(0).getNodeId()).sync().aclWhoami()); + } + // Shutdown Redis client and close connections + redisClusterClient.shutdown(); + } finally { + credentialsUser1.close(); + credentialsUser2.close(); + } + + } + +} diff --git a/src/test/resources/.env.entraid b/src/test/resources/.env.entraid new file mode 100644 index 0000000000..016449e929 --- /dev/null +++ b/src/test/resources/.env.entraid @@ -0,0 +1,11 @@ +AZURE_SP_OID= +AZURE_CLIENT_ID= +AZURE_CLIENT_SECRET= +AZURE_REDIS_SCOPES=https://redis.azure.com/.default +AZURE_AUTHORITY=https://login.microsoftonline.com/ +# Redis standalone db with Azure enabled authentication +REDIS_AZURE_HOST= +REDIS_AZURE_PORT=6379 +# Redis cluster db with Azure enabled authentication & osscluster API enabled +REDIS_AZURE_CLUSTER_HOST= +REDIS_AZURE_CLUSTER_PORT=6379