Skip to content

Commit

Permalink
Merge pull request #315 from ericchiang/issuer
Browse files Browse the repository at this point in the history
oidc: add option to override discovered issuer URL
  • Loading branch information
ericchiang authored Sep 17, 2021
2 parents 15b94d9 + 916b64f commit d42db69
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 10 deletions.
31 changes: 29 additions & 2 deletions oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ var (
errInvalidAtHash = errors.New("access token hash does not match value in ID token")
)

type contextKey int

var issuerURLKey contextKey

// ClientContext returns a new Context that carries the provided HTTP client.
//
// This method sets the same context key used by the golang.org/x/oauth2 package,
Expand All @@ -65,6 +69,25 @@ func cloneContext(ctx context.Context) context.Context {
return cp
}

// InsecureIssuerURLContext allows discovery to work when the issuer_url reported
// by upstream is mismatched with the discovery URL. This is meant for integration
// with off-spec providers such as Azure.
//
// discoveryBaseURL := "https://login.microsoftonline.com/organizations/v2.0"
// issuerURL := "https://login.microsoftonline.com/my-tenantid/v2.0"
//
// ctx := oidc.InsecureIssuerURLContext(parentContext, issuerURL)
//
// // Provider will be discovered with the discoveryBaseURL, but use issuerURL
// // for future issuer validation.
// provider, err := oidc.NewProvider(ctx, discoveryBaseURL)
//
// This is insecure because validating the correct issuer is critical for multi-tenant
// proivders. Any overrides here MUST be carefully reviewed.
func InsecureIssuerURLContext(ctx context.Context, issuerURL string) context.Context {
return context.WithValue(ctx, issuerURLKey, issuerURL)
}

func doRequest(ctx context.Context, req *http.Request) (*http.Response, error) {
client := http.DefaultClient
if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
Expand Down Expand Up @@ -142,7 +165,11 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) {
return nil, fmt.Errorf("oidc: failed to decode provider discovery object: %v", err)
}

if p.Issuer != issuer {
issuerURL, skipIssuerValidation := ctx.Value(issuerURLKey).(string)
if !skipIssuerValidation {
issuerURL = issuer
}
if p.Issuer != issuerURL && !skipIssuerValidation {
return nil, fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", issuer, p.Issuer)
}
var algs []string
Expand All @@ -152,7 +179,7 @@ func NewProvider(ctx context.Context, issuer string) (*Provider, error) {
}
}
return &Provider{
issuer: p.Issuer,
issuer: issuerURL,
authURL: p.AuthURL,
tokenURL: p.TokenURL,
userInfoURL: p.UserInfoURL,
Expand Down
42 changes: 34 additions & 8 deletions oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,16 @@ func TestAccessTokenVerification(t *testing.T) {

func TestNewProvider(t *testing.T) {
tests := []struct {
name string
data string
trailingSlash bool
wantAuthURL string
wantTokenURL string
wantUserInfoURL string
wantAlgorithms []string
wantErr bool
name string
data string
issuerURLOverride string
trailingSlash bool
wantAuthURL string
wantTokenURL string
wantUserInfoURL string
wantIssuerURL string
wantAlgorithms []string
wantErr bool
}{
{
name: "basic_case",
Expand Down Expand Up @@ -165,6 +167,21 @@ func TestNewProvider(t *testing.T) {
}`,
wantErr: true,
},
{
name: "mismatched_issuer_discovery_override",
issuerURLOverride: "https://example.com",
data: `{
"issuer": "ISSUER",
"authorization_endpoint": "https://example.com/auth",
"token_endpoint": "https://example.com/token",
"jwks_uri": "https://example.com/keys",
"id_token_signing_alg_values_supported": ["RS256"]
}`,
wantIssuerURL: "https://example.com",
wantAuthURL: "https://example.com/auth",
wantTokenURL: "https://example.com/token",
wantAlgorithms: []string{"RS256"},
},
{
name: "issuer_with_trailing_slash",
data: `{
Expand Down Expand Up @@ -269,6 +286,10 @@ func TestNewProvider(t *testing.T) {
issuer += "/"
}

if test.issuerURLOverride != "" {
ctx = InsecureIssuerURLContext(ctx, test.issuerURLOverride)
}

p, err := NewProvider(ctx, issuer)
if err != nil {
if !test.wantErr {
Expand All @@ -280,6 +301,11 @@ func TestNewProvider(t *testing.T) {
t.Fatalf("NewProvider(): expected error")
}

if test.wantIssuerURL != "" && p.issuer != test.wantIssuerURL {
t.Errorf("NewProvider() unexpected issuer value, got=%s, want=%s",
p.issuer, test.wantIssuerURL)
}

if p.authURL != test.wantAuthURL {
t.Errorf("NewProvider() unexpected authURL value, got=%s, want=%s",
p.authURL, test.wantAuthURL)
Expand Down

0 comments on commit d42db69

Please sign in to comment.