From 24db6b931174bacdc2fa01d51ab4c4b534c467b6 Mon Sep 17 00:00:00 2001 From: Nikos Date: Thu, 17 Oct 2024 11:00:11 +0300 Subject: [PATCH] refactor: simplify handler and test logic --- device_request_handler.go | 17 +++---- device_request_handler_test.go | 88 ++++++++++++++++++---------------- device_request_test.go | 18 ------- fosite_test.go | 11 +++++ 4 files changed, 68 insertions(+), 66 deletions(-) delete mode 100644 device_request_test.go diff --git a/device_request_handler.go b/device_request_handler.go index 2f351083..a3421b30 100644 --- a/device_request_handler.go +++ b/device_request_handler.go @@ -27,7 +27,7 @@ func (f *Fosite) NewDeviceRequest(ctx context.Context, r *http.Request) (_ Devic return request, errorsx.WithStack(ErrInvalidRequest.WithHintf("HTTP method is '%s', expected 'POST'.", r.Method)) } if err := r.ParseForm(); err != nil { - return nil, errorsx.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithWrap(err).WithDebug(err.Error())) + return request, errorsx.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithWrap(err).WithDebug(err.Error())) } if len(r.PostForm) == 0 { return request, errorsx.WithStack(ErrInvalidRequest.WithHint("The POST body can not be empty.")) @@ -44,11 +44,11 @@ func (f *Fosite) NewDeviceRequest(ctx context.Context, r *http.Request) (_ Devic request.Client = client if !client.GetGrantTypes().Has(string(GrantTypeDeviceCode)) { - return nil, errorsx.WithStack(ErrInvalidGrant.WithHint("The requested OAuth 2.0 Client does not have the 'urn:ietf:params:oauth:grant-type:device_code' grant.")) + return request, errorsx.WithStack(ErrInvalidGrant.WithHint("The requested OAuth 2.0 Client does not have the 'urn:ietf:params:oauth:grant-type:device_code' grant.")) } if err := f.validateDeviceScope(ctx, r, request); err != nil { - return nil, err + return request, err } if err := f.validateAudience(ctx, r, request); err != nil { @@ -59,12 +59,13 @@ func (f *Fosite) NewDeviceRequest(ctx context.Context, r *http.Request) (_ Devic } func (f *Fosite) validateDeviceScope(ctx context.Context, r *http.Request, request *DeviceRequest) error { - scope := RemoveEmpty(strings.Split(request.Form.Get("scope"), " ")) - for _, permission := range scope { - if !f.Config.GetScopeStrategy(ctx)(request.Client.GetScopes(), permission) { - return errorsx.WithStack(ErrInvalidScope.WithHintf("The OAuth 2.0 Client is not allowed to request scope '%s'.", permission)) + scopes := RemoveEmpty(strings.Split(request.Form.Get("scope"), " ")) + scopeStrategy := f.Config.GetScopeStrategy(ctx) + for _, scope := range scopes { + if !scopeStrategy(request.Client.GetScopes(), scope) { + return errorsx.WithStack(ErrInvalidScope.WithHintf("The OAuth 2.0 Client is not allowed to request scope '%s'.", scope)) } } - request.SetRequestedScopes(scope) + request.SetRequestedScopes(scopes) return nil } diff --git a/device_request_handler_test.go b/device_request_handler_test.go index 38cb336f..6b0d4d11 100644 --- a/device_request_handler_test.go +++ b/device_request_handler_test.go @@ -22,7 +22,17 @@ import ( func TestNewDeviceRequestWithPublicClient(t *testing.T) { ctrl := gomock.NewController(t) store := internal.NewMockStorage(ctrl) - client := &DefaultClient{ID: "client_id"} + deviceClient := &DefaultClient{ID: "client_id"} + deviceClient.Public = true + deviceClient.Scopes = []string{"17", "42"} + deviceClient.Audience = []string{"aud2"} + deviceClient.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"} + + authCodeClient := &DefaultClient{ID: "client_id_2"} + authCodeClient.Public = true + authCodeClient.Scopes = []string{"17", "42"} + authCodeClient.GrantTypes = []string{"authorization_code"} + defer ctrl.Finish() config := &Config{ScopeStrategy: ExactScopeStrategy, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy} fosite := &Fosite{Store: store, Config: config} @@ -63,10 +73,7 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) { }, method: "POST", mock: func() { - store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) - client.Public = true - client.Scopes = []string{"17", "42"} - client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"} + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(deviceClient, nil) }, expectedError: ErrInvalidScope, }, { @@ -74,29 +81,22 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) { form: url.Values{ "client_id": {"client_id"}, "scope": {"17 42"}, - "audience": {"aud"}, + "audience": {"random_aud"}, }, method: "POST", mock: func() { - store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) - client.Public = true - client.Scopes = []string{"17", "42"} - client.Audience = []string{"aud2"} - client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"} + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(deviceClient, nil) }, expectedError: ErrInvalidRequest, }, { description: "fails because it doesn't have the proper grant", form: url.Values{ - "client_id": {"client_id"}, + "client_id": {"client_id_2"}, "scope": {"17 42"}, }, method: "POST", mock: func() { - store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) - client.Public = true - client.Scopes = []string{"17", "42"} - client.GrantTypes = []string{"authorization_code"} + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id_2")).Return(authCodeClient, nil) }, expectedError: ErrInvalidGrant, }, { @@ -107,10 +107,7 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) { }, method: "POST", mock: func() { - store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) - client.Public = true - client.Scopes = []string{"17", "42"} - client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"} + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(deviceClient, nil) }, }} { t.Run(fmt.Sprintf("case=%d description=%s", k, c.description), func(t *testing.T) { @@ -123,10 +120,8 @@ func TestNewDeviceRequestWithPublicClient(t *testing.T) { } ar, err := fosite.NewDeviceRequest(context.Background(), r) - if c.expectedError != nil { - assert.EqualError(t, err, c.expectedError.Error()) - } else { - require.NoError(t, err) + require.ErrorIs(t, err, c.expectedError) + if c.expectedError == nil { assert.NotNil(t, ar.GetRequestedAt()) } }) @@ -141,6 +136,12 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) { defer ctrl.Finish() config := &Config{ClientSecretsHasher: hasher, ScopeStrategy: ExactScopeStrategy, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy} fosite := &Fosite{Store: store, Config: config} + + client.Public = false + client.Secret = []byte("client_secret") + client.Scopes = []string{"foo", "bar"} + client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"} + for k, c := range []struct { header http.Header form url.Values @@ -148,8 +149,8 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) { expectedError error mock func() expect DeviceRequester + description string }{ - // No client authn provided { form: url.Values{ "client_id": {"client_id"}, @@ -159,14 +160,26 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) { method: "POST", mock: func() { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) - client.Public = false - client.Secret = []byte("client_secret") - client.Scopes = []string{"foo", "bar"} - client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"} hasher.EXPECT().Compare(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("")) }, + description: "Should failed becaue no client authn provided.", + }, + { + form: url.Values{ + "client_id": {"client_id2"}, + "scope": {"foo bar"}, + }, + header: http.Header{ + "Authorization": {basicAuth("client_id", "client_secret")}, + }, + expectedError: ErrInvalidRequest, + method: "POST", + mock: func() { + store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) + hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("client_secret")), gomock.Eq([]byte("client_secret"))).Return(nil) + }, + description: "should fail because different client is used in authn than in form", }, - // success { form: url.Values{ "client_id": {"client_id"}, @@ -178,15 +191,12 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) { method: "POST", mock: func() { store.EXPECT().GetClient(gomock.Any(), gomock.Eq("client_id")).Return(client, nil) - client.Public = false - client.Secret = []byte("client_secret") - client.Scopes = []string{"foo", "bar"} - client.GrantTypes = []string{"urn:ietf:params:oauth:grant-type:device_code"} hasher.EXPECT().Compare(gomock.Any(), gomock.Eq([]byte("client_secret")), gomock.Eq([]byte("client_secret"))).Return(nil) }, + description: "should succeed", }, } { - t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + t.Run(fmt.Sprintf("case=%d description=%s", k, c.description), func(t *testing.T) { c.mock() r := &http.Request{ Header: c.header, @@ -196,11 +206,9 @@ func TestNewDeviceRequestWithClientAuthn(t *testing.T) { } req, err := fosite.NewDeviceRequest(context.Background(), r) - if c.expectedError != nil { - assert.EqualError(t, err, c.expectedError.Error()) - } else { - require.NoError(t, err) - assert.NotNil(t, req.GetRequestedAt()) + require.ErrorIs(t, err, c.expectedError) + if c.expectedError == nil { + assert.NotZero(t, req.GetRequestedAt()) } }) } diff --git a/device_request_test.go b/device_request_test.go deleted file mode 100644 index 7e67c052..00000000 --- a/device_request_test.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright © 2024 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package fosite - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestDeviceRequest(t *testing.T) { - r := NewDeviceRequest() - r.Client = &DefaultClient{} - r.SetRequestedScopes([]string{"17", "42"}) - assert.True(t, r.GetRequestedScopes().Has("17", "42")) - assert.Equal(t, r.Client, r.GetClient()) -} diff --git a/fosite_test.go b/fosite_test.go index 2c86b498..0a273fdc 100644 --- a/fosite_test.go +++ b/fosite_test.go @@ -13,6 +13,7 @@ import ( . "github.com/ory/fosite" "github.com/ory/fosite/handler/oauth2" "github.com/ory/fosite/handler/par" + "github.com/ory/fosite/handler/rfc8628" ) func TestAuthorizeEndpointHandlers(t *testing.T) { @@ -25,6 +26,16 @@ func TestAuthorizeEndpointHandlers(t *testing.T) { assert.Equal(t, hs[0], h) } +func TestDeviceAuthorizeEndpointHandlers(t *testing.T) { + h := &rfc8628.DeviceAuthHandler{} + hs := DeviceEndpointHandlers{} + hs.Append(h) + hs.Append(h) + hs.Append(&rfc8628.DeviceAuthHandler{}) + assert.Len(t, hs, 1) + assert.Equal(t, hs[0], h) +} + func TestTokenEndpointHandlers(t *testing.T) { h := &oauth2.AuthorizeExplicitGrantHandler{} hs := TokenEndpointHandlers{}