Skip to content

Commit

Permalink
Merge pull request #475 from supabase-community/auth-source
Browse files Browse the repository at this point in the history
Add Source to session status & rework compose auth
  • Loading branch information
jan-tennert authored Feb 22, 2024
2 parents a464abe + ab5e3ed commit 14c9555
Show file tree
Hide file tree
Showing 32 changed files with 494 additions and 442 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ sealed interface Auth : MainPlugin<AuthConfig>, CustomSerializationPlugin {
/**
* Imports a user session and starts auto-refreshing if [autoRefresh] is true
*/
suspend fun importSession(session: UserSession, autoRefresh: Boolean = true)
suspend fun importSession(session: UserSession, autoRefresh: Boolean = true, source: SessionSource = SessionSource.Unknown)

/**
* Imports the jwt token and retrieves the user profile.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import io.github.jan.supabase.gotrue.mfa.MfaApiImpl
import io.github.jan.supabase.gotrue.providers.AuthProvider
import io.github.jan.supabase.gotrue.providers.ExternalAuthConfigDefaults
import io.github.jan.supabase.gotrue.providers.OAuthProvider
import io.github.jan.supabase.gotrue.providers.builtin.OTP
import io.github.jan.supabase.gotrue.providers.builtin.SSO
import io.github.jan.supabase.gotrue.user.UserInfo
import io.github.jan.supabase.gotrue.user.UserSession
Expand Down Expand Up @@ -94,11 +95,11 @@ internal class AuthImpl(
Auth.logger.i {
"No session found."
}
_sessionStatus.value = SessionStatus.NotAuthenticated
_sessionStatus.value = SessionStatus.NotAuthenticated(false)
}
}
} else {
_sessionStatus.value = SessionStatus.NotAuthenticated
_sessionStatus.value = SessionStatus.NotAuthenticated(false)
}
}

Expand All @@ -107,15 +108,15 @@ internal class AuthImpl(
redirectUrl: String?,
config: (C.() -> Unit)?
) = provider.login(supabaseClient, {
importSession(it)
importSession(it, source = SessionSource.SignIn(provider))
}, redirectUrl, config)

override suspend fun <C, R, Provider : AuthProvider<C, R>> signUpWith(
provider: Provider,
redirectUrl: String?,
config: (C.() -> Unit)?
): R? = provider.signUp(supabaseClient, {
importSession(it)
importSession(it, source = SessionSource.SignUp(provider))
}, redirectUrl, config)

@SupabaseExperimental
Expand All @@ -134,7 +135,7 @@ internal class AuthImpl(
response.request.url.toString()
},
onSessionSuccess = {
importSession(it)
importSession(it, source = SessionSource.UserIdentitiesChanged(it))
}
)
}
Expand All @@ -146,7 +147,7 @@ internal class AuthImpl(
val session = currentSessionOrNull() ?: return
val newUser = session.user?.copy(identities = session.user.identities?.filter { it.identityId != identityId })
val newSession = session.copy(user = newUser)
_sessionStatus.value = SessionStatus.Authenticated(newSession, sessionStatus.value)
_sessionStatus.value = SessionStatus.Authenticated(newSession, SessionSource.UserIdentitiesChanged(session))
}
}

Expand Down Expand Up @@ -219,7 +220,7 @@ internal class AuthImpl(
if (this.config.autoSaveToStorage) {
sessionManager.saveSession(newSession)
}
_sessionStatus.value = SessionStatus.Authenticated(newSession, sessionStatus.value)
_sessionStatus.value = SessionStatus.Authenticated(newSession, SessionSource.UserChanged(newSession))
}
return userInfo
}
Expand Down Expand Up @@ -322,7 +323,7 @@ internal class AuthImpl(
}
val response = api.postJson("verify", body)
val session = response.body<UserSession>()
importSession(session)
importSession(session, source = SessionSource.SignIn(OTP))
}

override suspend fun verifyEmailOtp(
Expand Down Expand Up @@ -355,7 +356,7 @@ internal class AuthImpl(
val user = retrieveUser(currentAccessTokenOrNull() ?: error("No session found"))
if (updateSession) {
val session = currentSessionOrNull() ?: error("No session found")
val newStatus = SessionStatus.Authenticated(session.copy(user = user), sessionStatus.value)
val newStatus = SessionStatus.Authenticated(session.copy(user = user), SessionSource.UserChanged(currentSessionOrNull() ?: error("Session shouldn't be null")))
_sessionStatus.value = newStatus
if (config.autoSaveToStorage) sessionManager.saveSession(newStatus.session)
}
Expand All @@ -372,7 +373,7 @@ internal class AuthImpl(
}.safeBody<UserSession>()
codeVerifierCache.deleteCodeVerifier()
if (saveSession) {
importSession(session)
importSession(session, source = SessionSource.External)
}
return session
}
Expand All @@ -395,12 +396,16 @@ internal class AuthImpl(
currentSessionOrNull()?.refreshToken
?: error("No refresh token found in current session")
)
importSession(newSession)
importSession(newSession, source = SessionSource.Refresh(currentSessionOrNull() ?: error("No session found")))
}

override suspend fun importSession(session: UserSession, autoRefresh: Boolean) {
override suspend fun importSession(
session: UserSession,
autoRefresh: Boolean,
source: SessionSource
) {
if (!autoRefresh) {
_sessionStatus.value = SessionStatus.Authenticated(session, sessionStatus.value)
_sessionStatus.value = SessionStatus.Authenticated(session, source)
if (session.refreshToken.isNotBlank() && session.expiresIn != 0L && config.autoSaveToStorage) {
sessionManager.saveSession(session)
}
Expand All @@ -412,15 +417,15 @@ internal class AuthImpl(
{ importSession(session) }
)
} else {
_sessionStatus.value = SessionStatus.Authenticated(session, sessionStatus.value)
_sessionStatus.value = SessionStatus.Authenticated(session, source)
if (config.autoSaveToStorage) sessionManager.saveSession(session)
sessionJob?.cancel()
sessionJob = authScope.launch {
delayBeforeExpiry(session)
launch {
tryImportingSession(
{ handleExpiredSession(session) },
{ importSession(session) }
{ importSession(session, source = source) }
)
}
}
Expand Down Expand Up @@ -458,11 +463,11 @@ internal class AuthImpl(
"Session expired. Refreshing session..."
}
val newSession = refreshSession(session.refreshToken)
importSession(newSession, autoRefresh)
importSession(newSession, autoRefresh, SessionSource.Refresh(session))
}

override suspend fun startAutoRefreshForCurrentSession() =
importSession(currentSessionOrNull() ?: error("No session found"), true)
importSession(currentSessionOrNull() ?: error("No session found"), true, (sessionStatus.value as SessionStatus.Authenticated).source)

override fun stopAutoRefreshForCurrentSession() {
sessionJob?.cancel()
Expand All @@ -472,7 +477,7 @@ internal class AuthImpl(
override suspend fun loadFromStorage(autoRefresh: Boolean): Boolean {
val session = sessionManager.loadSession()
session?.let {
importSession(it, autoRefresh)
importSession(it, autoRefresh, SessionSource.Storage)
}
return session != null
}
Expand Down Expand Up @@ -536,7 +541,7 @@ internal class AuthImpl(
codeVerifierCache.deleteCodeVerifier()
sessionManager.deleteSession()
sessionJob?.cancel()
_sessionStatus.value = SessionStatus.NotAuthenticated
_sessionStatus.value = SessionStatus.NotAuthenticated(true)
sessionJob = null
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.github.jan.supabase.gotrue

import io.github.jan.supabase.gotrue.providers.AuthProvider
import io.github.jan.supabase.gotrue.user.UserSession

/**
Expand All @@ -9,8 +10,9 @@ sealed interface SessionStatus {

/**
* This status means that the user is not logged in
* @param isSignOut Whether this status was caused by a sign out
*/
data object NotAuthenticated : SessionStatus
data class NotAuthenticated(val isSignOut: Boolean) : SessionStatus

/**
* This status means that [Auth] is currently loading the session from storage
Expand All @@ -25,7 +27,59 @@ sealed interface SessionStatus {
/**
* This status means that [Auth] holds a valid session
* @param session The session
* @param oldStatus The previous status. Useful for knowing if the user was already authenticated
* @param source The source of the session
*/
data class Authenticated(val session: UserSession, val oldStatus: SessionStatus) : SessionStatus
data class Authenticated(val session: UserSession, val source: SessionSource = SessionSource.Unknown) : SessionStatus

}

/**
* Represents the source of a session
*/
sealed interface SessionSource {

/**
* The session was loaded from storage
*/
data object Storage : SessionSource

/**
* The session was loaded from a sign in
* @param provider The provider that was used to sign in
*/
data class SignIn(val provider: AuthProvider<*, *>) : SessionSource

/**
* The session was loaded from a sign up (only if auto-confirm is enabled)
* @param provider The provider that was used to sign up
*/
data class SignUp(val provider: AuthProvider<*, *>) : SessionSource

/**
* The session comes from an external source, e.g. OAuth via deeplinks.
*/
data object External : SessionSource

/**
* The session comes from an unknown source
*/
data object Unknown : SessionSource

/**
* The session was refreshed
* @param oldSession The old session
*/
data class Refresh(val oldSession: UserSession) : SessionSource

/**
* The session was changed due to a user change (e.g. via [Auth.modifyUser] or [Auth.retrieveUserForCurrentSession])
* @param oldSession The old session
*/
data class UserChanged(val oldSession: UserSession) : SessionSource

/**
* The session was changed due to a user identity change (e.g. via [Auth.linkIdentity] or [Auth.unlinkIdentity])
* @param oldSession The old session
*/
data class UserIdentitiesChanged(val oldSession: UserSession) : SessionSource
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ fun Auth.parseFragmentAndImportSession(fragment: String, onSessionSuccess: (User
val user = retrieveUser(session.accessToken)
val newSession = session.copy(user = user)
onSessionSuccess(newSession)
importSession(newSession)
importSession(newSession, source = SessionSource.External)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,29 +103,29 @@ sealed interface MfaApi {
}

internal class MfaApiImpl(
val gotrue: AuthImpl
val auth: AuthImpl
) : MfaApi {

override val isMfaEnabledFlow: Flow<Boolean> = gotrue.sessionStatus.map {
override val isMfaEnabledFlow: Flow<Boolean> = auth.sessionStatus.map {
when(it) {
is SessionStatus.Authenticated -> isMfaEnabled
SessionStatus.LoadingFromStorage -> false
SessionStatus.NetworkError -> false
SessionStatus.NotAuthenticated -> false
is SessionStatus.NotAuthenticated -> false
}
}
override val loggedInUsingMfaFlow: Flow<Boolean> = gotrue.sessionStatus.map {
override val loggedInUsingMfaFlow: Flow<Boolean> = auth.sessionStatus.map {
when(it) {
is SessionStatus.Authenticated -> loggedInUsingMfa
SessionStatus.LoadingFromStorage -> false
SessionStatus.NetworkError -> false
SessionStatus.NotAuthenticated -> false
is SessionStatus.NotAuthenticated -> false
}
}
override val verifiedFactors: List<UserMfaFactor>
get() = (gotrue.sessionStatus.value as? SessionStatus.Authenticated)?.session?.user?.factors?.filter(UserMfaFactor::isVerified) ?: emptyList()
get() = (auth.sessionStatus.value as? SessionStatus.Authenticated)?.session?.user?.factors?.filter(UserMfaFactor::isVerified) ?: emptyList()

val api = gotrue.api
val api = auth.api

override suspend fun <Response> enroll(
factorType: FactorType<Response>,
Expand Down Expand Up @@ -164,7 +164,7 @@ internal class MfaApiImpl(
})
val session = result.body<UserSession>()
if(saveSession) {
gotrue.importSession(session)
auth.importSession(session)
}
return session
}
Expand All @@ -174,7 +174,7 @@ internal class MfaApiImpl(
}

override fun getAuthenticatorAssuranceLevel(): MfaLevel {
val jwt = gotrue.currentAccessTokenOrNull() ?: error("Current session is null")
val jwt = auth.currentAccessTokenOrNull() ?: error("Current session is null")
val parts = jwt.split(".")
val decodedJwt = Json.decodeFromString<JsonObject>(parts[1].decodeBase64String())
val aal = AuthenticatorAssuranceLevel.from(decodedJwt["aal"]?.jsonPrimitive?.content ?: error("No 'aal' claim found in JWT"))
Expand All @@ -184,7 +184,7 @@ internal class MfaApiImpl(


override suspend fun retrieveFactorsForCurrentUser(): List<UserMfaFactor> {
return gotrue.retrieveUser(gotrue.currentAccessTokenOrNull() ?: error("Current session is null")).factors
return auth.retrieveUser(auth.currentAccessTokenOrNull() ?: error("Current session is null")).factors
}

}
Loading

0 comments on commit 14c9555

Please sign in to comment.