From 916b64feddcecf5c6821f0d207868b03e4fc5b95 Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Wed, 15 Sep 2021 11:24:01 -0700 Subject: [PATCH] oidc: add option to override discovered issuer URL --- oidc/oidc.go | 31 +++++++++++++++++++++++++++++-- oidc/oidc_test.go | 42 ++++++++++++++++++++++++++++++++++-------- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/oidc/oidc.go b/oidc/oidc.go index 62898d1b..3e1d80e0 100644 --- a/oidc/oidc.go +++ b/oidc/oidc.go @@ -39,6 +39,10 @@ var ( errInvalidAtHash = errors.New("access token hash does not match value in ID token") ) +type contextKey int + +var issuerURLKey contextKey + // ClientContext returns a new Context that carries the provided HTTP client. // // This method sets the same context key used by the golang.org/x/oauth2 package, @@ -65,6 +69,25 @@ func cloneContext(ctx context.Context) context.Context { return cp } +// InsecureIssuerURLContext allows discovery to work when the issuer_url reported +// by upstream is mismatched with the discovery URL. This is meant for integration +// with off-spec providers such as Azure. +// +// discoveryBaseURL := "https://login.microsoftonline.com/organizations/v2.0" +// issuerURL := "https://login.microsoftonline.com/my-tenantid/v2.0" +// +// ctx := oidc.InsecureIssuerURLContext(parentContext, issuerURL) +// +// // Provider will be discovered with the discoveryBaseURL, but use issuerURL +// // for future issuer validation. +// provider, err := oidc.NewProvider(ctx, discoveryBaseURL) +// +// This is insecure because validating the correct issuer is critical for multi-tenant +// proivders. Any overrides here MUST be carefully reviewed. +func InsecureIssuerURLContext(ctx context.Context, issuerURL string) context.Context { + return context.WithValue(ctx, issuerURLKey, issuerURL) +} + func doRequest(ctx context.Context, req *http.Request) (*http.Response, error) { client := http.DefaultClient if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { @@ -142,7 +165,11 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) { return nil, fmt.Errorf("oidc: failed to decode provider discovery object: %v", err) } - if p.Issuer != issuer { + issuerURL, skipIssuerValidation := ctx.Value(issuerURLKey).(string) + if !skipIssuerValidation { + issuerURL = issuer + } + if p.Issuer != issuerURL && !skipIssuerValidation { return nil, fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", issuer, p.Issuer) } var algs []string @@ -152,7 +179,7 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) { } } return &Provider{ - issuer: p.Issuer, + issuer: issuerURL, authURL: p.AuthURL, tokenURL: p.TokenURL, userInfoURL: p.UserInfoURL, diff --git a/oidc/oidc_test.go b/oidc/oidc_test.go index 554f3da8..26352082 100644 --- a/oidc/oidc_test.go +++ b/oidc/oidc_test.go @@ -104,14 +104,16 @@ func TestAccessTokenVerification(t *testing.T) { func TestNewProvider(t *testing.T) { tests := []struct { - name string - data string - trailingSlash bool - wantAuthURL string - wantTokenURL string - wantUserInfoURL string - wantAlgorithms []string - wantErr bool + name string + data string + issuerURLOverride string + trailingSlash bool + wantAuthURL string + wantTokenURL string + wantUserInfoURL string + wantIssuerURL string + wantAlgorithms []string + wantErr bool }{ { name: "basic_case", @@ -165,6 +167,21 @@ func TestNewProvider(t *testing.T) { }`, wantErr: true, }, + { + name: "mismatched_issuer_discovery_override", + issuerURLOverride: "https://example.com", + data: `{ + "issuer": "ISSUER", + "authorization_endpoint": "https://example.com/auth", + "token_endpoint": "https://example.com/token", + "jwks_uri": "https://example.com/keys", + "id_token_signing_alg_values_supported": ["RS256"] + }`, + wantIssuerURL: "https://example.com", + wantAuthURL: "https://example.com/auth", + wantTokenURL: "https://example.com/token", + wantAlgorithms: []string{"RS256"}, + }, { name: "issuer_with_trailing_slash", data: `{ @@ -269,6 +286,10 @@ func TestNewProvider(t *testing.T) { issuer += "/" } + if test.issuerURLOverride != "" { + ctx = InsecureIssuerURLContext(ctx, test.issuerURLOverride) + } + p, err := NewProvider(ctx, issuer) if err != nil { if !test.wantErr { @@ -280,6 +301,11 @@ func TestNewProvider(t *testing.T) { t.Fatalf("NewProvider(): expected error") } + if test.wantIssuerURL != "" && p.issuer != test.wantIssuerURL { + t.Errorf("NewProvider() unexpected issuer value, got=%s, want=%s", + p.issuer, test.wantIssuerURL) + } + if p.authURL != test.wantAuthURL { t.Errorf("NewProvider() unexpected authURL value, got=%s, want=%s", p.authURL, test.wantAuthURL)