Skip to content

Commit

Permalink
Merge pull request #688 from supabase-community/third-party-auth
Browse files Browse the repository at this point in the history
Add support for third party auth and add Slack OIDC provider
  • Loading branch information
jan-tennert authored Aug 15, 2024
2 parents 552a42a + 0459c0d commit 446d3d7
Show file tree
Hide file tree
Showing 12 changed files with 195 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package io.github.jan.supabase.gotrue

import io.github.jan.supabase.SupabaseClient
import io.github.jan.supabase.annotations.SupabaseInternal
import io.github.jan.supabase.plugins.MainConfig
import io.github.jan.supabase.plugins.MainPlugin

/**
* Returns the access token used for requests. The token is resolved in the following order:
* 1. [jwtToken] if not null
* 2. [SupabaseClient.resolveAccessToken] if not null
* 3. [Auth.currentAccessTokenOrNull] if the Auth plugin is installed
* 4. [SupabaseClient.supabaseKey] if [keyAsFallback] is true
*/
@SupabaseInternal
suspend fun SupabaseClient.resolveAccessToken(
jwtToken: String? = null,
keyAsFallback: Boolean = true
): String? {
val key = if(keyAsFallback) supabaseKey else null
return jwtToken ?: accessToken?.invoke()
?: pluginManager.getPluginOrNull(Auth)?.currentAccessTokenOrNull() ?: key
}

/**
* Returns the access token used for requests. The token is resolved in the following order:
* 1. [MainConfig.jwtToken] if not null
* 2. [SupabaseClient.resolveAccessToken] if not null
* 3. [Auth.currentAccessTokenOrNull] if the Auth plugin is installed
* 4. [SupabaseClient.supabaseKey] if [keyAsFallback] is true
*/
@SupabaseInternal
suspend fun <C : MainConfig> SupabaseClient.resolveAccessToken(
plugin: MainPlugin<C>,
keyAsFallback: Boolean = true
) = resolveAccessToken(plugin.config.jwtToken, keyAsFallback)
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ internal class AuthImpl(
override val pluginKey: String
get() = Auth.key

init {
if(supabaseClient.accessToken != null) error("The Auth plugin is not available when using a custom access token provider. Please uninstall the Auth plugin.")
}

override fun init() {
setupPlatform()
if (config.autoLoadFromStorage) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ class AuthenticatedSupabaseApi @SupabaseInternal constructor(
private val jwtToken: String? = null // Can be configured plugin-wide. By default, all plugins use the token from the current session
): SupabaseApi(resolveUrl, parseErrorResponse, supabaseClient) {

override suspend fun rawRequest(url: String, builder: HttpRequestBuilder.() -> Unit): HttpResponse = super.rawRequest(url) {
val jwtToken = jwtToken ?: supabaseClient.pluginManager.getPluginOrNull(Auth)?.currentAccessTokenOrNull() ?: supabaseClient.supabaseKey
bearerAuth(jwtToken)
builder()
defaultRequest?.invoke(this)
override suspend fun rawRequest(url: String, builder: HttpRequestBuilder.() -> Unit): HttpResponse {
val accessToken = supabaseClient.resolveAccessToken(jwtToken) ?: error("No access token available")
return super.rawRequest(url) {
bearerAuth(accessToken)
builder()
defaultRequest?.invoke(this)
}
}

suspend fun rawRequest(builder: HttpRequestBuilder.() -> Unit): HttpResponse = rawRequest("", builder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ data object Slack : OAuthProvider() {

}

data object SlackOIDC : OAuthProvider() {

override val name = "slack_oidc"

}

data object Twitch : OAuthProvider() {

override val name = "twitch"
Expand Down
73 changes: 73 additions & 0 deletions GoTrue/src/commonTest/kotlin/AccessTokenTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import io.github.jan.supabase.gotrue.Auth
import io.github.jan.supabase.gotrue.auth
import io.github.jan.supabase.gotrue.minimalSettings
import io.github.jan.supabase.gotrue.resolveAccessToken
import io.github.jan.supabase.testing.createMockedSupabaseClient
import kotlinx.coroutines.test.runTest
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertNull

class AccessTokenTest {

@Test
fun testAccessTokenWithJwtToken() {
runTest {
val client = createMockedSupabaseClient(
configuration = {
install(Auth) {
minimalSettings()
}
}
)
client.auth.importAuthToken("myAuth") //this should be ignored as per plugin tokens override the used access token
assertEquals("myJwtToken", client.resolveAccessToken("myJwtToken"))
}
}

@Test
fun testAccessTokenWithKeyAsFallback() {
runTest {
val client = createMockedSupabaseClient(supabaseKey = "myKey")
assertEquals("myKey", client.resolveAccessToken())
}
}

@Test
fun testAccessTokenWithoutKey() {
runTest {
val client = createMockedSupabaseClient()
assertNull(client.resolveAccessToken(keyAsFallback = false))
}
}

@Test
fun testAccessTokenWithCustomAccessToken() {
runTest {
val client = createMockedSupabaseClient(
configuration = {
accessToken = {
"myCustomToken"
}
}
)
assertEquals("myCustomToken", client.resolveAccessToken())
}
}

@Test
fun testAccessTokenWithAuth() {
runTest {
val client = createMockedSupabaseClient(
configuration = {
install(Auth) {
minimalSettings()
}
}
)
client.auth.importAuthToken("myAuth")
assertEquals("myAuth", client.resolveAccessToken())
}
}

}
17 changes: 17 additions & 0 deletions GoTrue/src/commonTest/kotlin/AuthTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import kotlinx.coroutines.test.runTest
import kotlinx.serialization.json.buildJsonObject
import kotlin.test.Test
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertIs
import kotlin.test.assertNull

Expand Down Expand Up @@ -46,6 +47,22 @@ class AuthTest {
}
}

@Test
fun testErrorWhenUsingAccessToken() {
runTest {
assertFailsWith<IllegalStateException> {
createMockedSupabaseClient(
configuration = {
accessToken = {
"myToken"
}
install(Auth)
}
)
}
}
}

@Test
fun testSavingSessionToStorage() {
runTest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package io.github.jan.supabase.realtime
import io.github.jan.supabase.annotations.SupabaseInternal
import io.github.jan.supabase.collections.AtomicMutableList
import io.github.jan.supabase.decodeIfNotEmptyOrDefault
import io.github.jan.supabase.gotrue.Auth
import io.github.jan.supabase.gotrue.resolveAccessToken
import io.github.jan.supabase.logging.d
import io.github.jan.supabase.logging.e
import io.github.jan.supabase.logging.w
Expand All @@ -17,7 +17,6 @@ import io.ktor.http.headers
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.first
import kotlinx.datetime.Clock
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.buildJsonObject
Expand Down Expand Up @@ -61,9 +60,7 @@ internal class RealtimeChannelImpl(
}
_status.value = RealtimeChannel.Status.SUBSCRIBING
Realtime.logger.d { "Subscribing to channel $topic" }
val currentJwt = realtimeImpl.config.jwtToken ?: supabaseClient.pluginManager.getPluginOrNull(Auth)?.currentSessionOrNull()?.let {
if(it.expiresAt > Clock.System.now()) it.accessToken else null
}
val currentJwt = supabaseClient.resolveAccessToken(realtimeImpl, keyAsFallback = false)
val postgrestChanges = clientChanges.toList()
val joinConfig = RealtimeJoinPayload(RealtimeJoinConfig(broadcastJoinConfig, presenceJoinConfig, postgrestChanges, isPrivate))
val joinConfigObject = buildJsonObject {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package io.github.jan.supabase.storage
import io.github.jan.supabase.SupabaseClient
import io.github.jan.supabase.exceptions.HttpRequestException
import io.github.jan.supabase.exceptions.RestException
import io.github.jan.supabase.gotrue.Auth
import io.github.jan.supabase.gotrue.resolveAccessToken
import io.github.jan.supabase.storage.resumable.ResumableClient
import io.ktor.client.plugins.HttpRequestTimeoutException
import io.ktor.utils.io.ByteReadChannel
Expand Down Expand Up @@ -314,8 +314,8 @@ sealed interface BucketApi {
* **Authentication: Bearer <your_access_token>**
* @param path The path to download
*/
fun BucketApi.authenticatedRequest(path: String): Pair<String, String> {
suspend fun BucketApi.authenticatedRequest(path: String): Pair<String, String> {
val url = authenticatedUrl(path)
val token = supabaseClient.storage.config.jwtToken ?: supabaseClient.pluginManager.getPluginOrNull(Auth)?.currentAccessTokenOrNull() ?: supabaseClient.supabaseKey
val token = supabaseClient.resolveAccessToken(supabaseClient.storage) ?: error("No access token available")
return token to url
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package io.github.jan.supabase

/**
* Optional function for using a third-party authentication system with
* Supabase. The function should return an access token or ID token (JWT) by
* obtaining it from the third-party auth client library. Note that this
* function may be called concurrently and many times. Use memoization and
* locking techniques if this is not supported by the client libraries.
*/
typealias AccessTokenProvider = suspend () -> String
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ sealed interface SupabaseClient {
val pluginManager: PluginManager

/**
* The http client used to interact with the supabase api
* The http client used to interact with the Supabase api
*/
val httpClient: KtorSupabaseHttpClient

Expand All @@ -53,6 +53,12 @@ sealed interface SupabaseClient {
*/
val defaultSerializer: SupabaseSerializer

/**
* The custom access token provider used to provide custom access tokens for requests. Configured within the [SupabaseClientBuilder]
*/
@SupabaseInternal
val accessToken: AccessTokenProvider?

/**
* Releases all resources held by the [httpClient] and all plugins the [pluginManager]
*/
Expand Down Expand Up @@ -88,6 +94,7 @@ internal class SupabaseClientImpl(
requestTimeout: Long,
httpEngine: HttpClientEngine?,
override val defaultSerializer: SupabaseSerializer,
override val accessToken: AccessTokenProvider?,
) : SupabaseClient {

init {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,19 @@ class SupabaseClientBuilder @PublishedApi internal constructor(private val supab
*/
var defaultSerializer: SupabaseSerializer = KotlinXSerializer(Json { ignoreUnknownKeys = true })

/**
* Optional function for using a third-party authentication system with
* Supabase. The function should return an access token or ID token (JWT) by
* obtaining it from the third-party auth client library. Note that this
* function may be called concurrently and many times. Use memoization and
* locking techniques if this is not supported by the client libraries.
*
* When set, the Auth plugin from `auth-kt` cannot be used.
* Create another client if you wish to use Supabase Auth and third-party
* authentications concurrently in the same application.
*/
var accessToken: AccessTokenProvider? = null

private val httpConfigOverrides = mutableListOf<HttpClientConfig<*>.() -> Unit>()
private val plugins = mutableMapOf<String, ((SupabaseClient) -> SupabasePlugin<*>)>()

Expand Down Expand Up @@ -95,7 +108,8 @@ class SupabaseClientBuilder @PublishedApi internal constructor(private val supab
useHTTPS,
requestTimeout.inWholeMilliseconds,
httpEngine,
defaultSerializer
defaultSerializer,
accessToken
)
}

Expand Down
14 changes: 14 additions & 0 deletions test-common/src/commonTest/kotlin/SupabaseClientTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ class SupabaseClientTest {
}
}

@Test
fun testAccessTokenProvider() {
runTest {
val client = createMockedSupabaseClient(
configuration = {
accessToken = {
"myToken"
}
}
)
assertEquals("myToken", client.accessToken?.invoke())
}
}

@Test
fun testDefaultLogLevel() {
createMockedSupabaseClient(
Expand Down

0 comments on commit 446d3d7

Please sign in to comment.