From b30ae67c98645fedfc4c2c8b0b0705360be81672 Mon Sep 17 00:00:00 2001 From: Nikos Date: Fri, 15 Nov 2024 12:11:47 +0200 Subject: [PATCH] refactor: enhance deviceRequest struct --- device_request.go | 29 +++++- device_write_test.go | 2 +- handler/rfc8628/auth_handler.go | 2 +- handler/rfc8628/storage.go | 4 +- handler/rfc8628/strategy.go | 15 ++- handler/rfc8628/strategy_hmacsha.go | 83 +--------------- handler/rfc8628/strategy_hmacsha_test.go | 30 +++--- handler/rfc8628/token_handler.go | 22 ++--- handler/rfc8628/token_handler_test.go | 94 ++++++++----------- .../authorize_device_grant_request_test.go | 13 +-- integration/helper_setup_test.go | 2 +- oauth2.go | 6 ++ storage/memory.go | 10 +- 13 files changed, 129 insertions(+), 183 deletions(-) diff --git a/device_request.go b/device_request.go index 0b243b01..ef26c5d8 100644 --- a/device_request.go +++ b/device_request.go @@ -3,14 +3,41 @@ package fosite +type UserCodeState int16 + +const ( + // User code is active + UserCodeUnused = UserCodeState(0) + // User code has been accepted + UserCodeAccepted = UserCodeState(1) + // User code has been rejected + UserCodeRejected = UserCodeState(2) +) + // DeviceRequest is an implementation of DeviceRequester type DeviceRequest struct { + UserCodeState UserCodeState Request } +func (d *DeviceRequest) GetUserCodeState() UserCodeState { + return d.UserCodeState +} + +func (d *DeviceRequest) SetUserCodeState(state UserCodeState) { + d.UserCodeState = state +} + +func (d *DeviceRequest) Sanitize(allowedParameters []string) Requester { + r, _ := d.Request.Sanitize(allowedParameters).(*Request) + d.Request = *r + return d +} + // NewDeviceRequest returns a new device request func NewDeviceRequest() *DeviceRequest { return &DeviceRequest{ - Request: *NewRequest(), + UserCodeState: UserCodeUnused, + Request: *NewRequest(), } } diff --git a/device_write_test.go b/device_write_test.go index fc8ade09..5eae5e06 100644 --- a/device_write_test.go +++ b/device_write_test.go @@ -26,7 +26,7 @@ func TestWriteDeviceUserResponse(t *testing.T) { ctx := context.Background() rw := httptest.NewRecorder() - ar := &Request{} + ar := &DeviceRequest{} resp := &DeviceResponse{} resp.SetUserCode("AAAA") resp.SetDeviceCode("BBBB") diff --git a/handler/rfc8628/auth_handler.go b/handler/rfc8628/auth_handler.go index b5ebec12..5ae47602 100644 --- a/handler/rfc8628/auth_handler.go +++ b/handler/rfc8628/auth_handler.go @@ -66,7 +66,7 @@ func (d *DeviceAuthHandler) handleDeviceAuthSession(ctx context.Context, dar fos return "", "", err } - if err = d.Storage.CreateDeviceAuthSession(ctx, deviceCodeSignature, userCodeSignature, dar.Sanitize(nil)); err == nil { + if err = d.Storage.CreateDeviceAuthSession(ctx, deviceCodeSignature, userCodeSignature, dar.Sanitize(nil).(fosite.DeviceRequester)); err == nil { return deviceCode, userCode, nil } } diff --git a/handler/rfc8628/storage.go b/handler/rfc8628/storage.go index e15dc97d..dea0b6a1 100644 --- a/handler/rfc8628/storage.go +++ b/handler/rfc8628/storage.go @@ -20,14 +20,14 @@ type RFC8628CoreStorage interface { // DeviceAuthStorage handles the device auth session storage type DeviceAuthStorage interface { // CreateDeviceAuthSession stores the device auth request session. - CreateDeviceAuthSession(ctx context.Context, deviceCodeSignature, userCodeSignature string, request fosite.Requester) (err error) + CreateDeviceAuthSession(ctx context.Context, deviceCodeSignature, userCodeSignature string, request fosite.DeviceRequester) (err error) // GetDeviceCodeSession hydrates the session based on the given device code and returns the device request. // If the device code has been invalidated with `InvalidateDeviceCodeSession`, this // method should return the ErrInvalidatedDeviceCode error. // // Make sure to also return the fosite.Requester value when returning the fosite.ErrInvalidatedDeviceCode error! - GetDeviceCodeSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error) + GetDeviceCodeSession(ctx context.Context, signature string, session fosite.Session) (request fosite.DeviceRequester, err error) // InvalidateDeviceCodeSession is called when a device code is being used. The state of the device // code should be set to invalid and consecutive requests to GetDeviceCodeSession should return the diff --git a/handler/rfc8628/strategy.go b/handler/rfc8628/strategy.go index 67ca2808..cd32bb53 100644 --- a/handler/rfc8628/strategy.go +++ b/handler/rfc8628/strategy.go @@ -18,19 +18,30 @@ type RFC8628CodeStrategy interface { // DeviceRateLimitStrategy handles the rate limiting strategy type DeviceRateLimitStrategy interface { + // ShouldRateLimit checks whether the token request should be rate-limited ShouldRateLimit(ctx context.Context, code string) (bool, error) } // DeviceCodeStrategy handles the device_code strategy type DeviceCodeStrategy interface { + // DeviceCodeSignature calculates the signature of a device_code DeviceCodeSignature(ctx context.Context, code string) (signature string, err error) + + // GenerateDeviceCode generates a new device code and signature GenerateDeviceCode(ctx context.Context) (code string, signature string, err error) - ValidateDeviceCode(ctx context.Context, r fosite.Requester, code string) (err error) + + // ValidateDeviceCode validates the device_code + ValidateDeviceCode(ctx context.Context, r fosite.DeviceRequester, code string) (err error) } // UserCodeStrategy handles the user_code strategy type UserCodeStrategy interface { + // UserCodeSignature calculates the signature of a user_code UserCodeSignature(ctx context.Context, code string) (signature string, err error) + + // GenerateUserCode generates a new user code and signature GenerateUserCode(ctx context.Context) (code string, signature string, err error) - ValidateUserCode(ctx context.Context, r fosite.Requester, code string) (err error) + + // ValidateUserCode validates the user_code + ValidateUserCode(ctx context.Context, r fosite.DeviceRequester, code string) (err error) } diff --git a/handler/rfc8628/strategy_hmacsha.go b/handler/rfc8628/strategy_hmacsha.go index 6315bd92..8671f504 100644 --- a/handler/rfc8628/strategy_hmacsha.go +++ b/handler/rfc8628/strategy_hmacsha.go @@ -8,8 +8,6 @@ import ( "strings" "time" - "github.com/mohae/deepcopy" - "github.com/ory/x/errorsx" "github.com/ory/x/randx" @@ -20,83 +18,6 @@ import ( const POLLING_RATE_LIMITING_LEEWAY = 200 * time.Millisecond -// DeviceFlowSession is a fosite.Session container specific for the device flow. -type DeviceFlowSession interface { - // GetBrowserFlowCompleted returns the flag indicating whether user has completed the browser flow or not. - GetBrowserFlowCompleted() bool - - // SetBrowserFlowCompleted allows client to mark user has completed the browser flow. - SetBrowserFlowCompleted(flag bool) - - fosite.Session -} - -// DefaultDeviceFlowSession is a DeviceFlowSession implementation for the device flow. -type DefaultDeviceFlowSession struct { - ExpiresAt map[fosite.TokenType]time.Time `json:"expires_at"` - Username string `json:"username"` - Subject string `json:"subject"` - Extra map[string]interface{} `json:"extra"` - BrowserFlowCompleted bool `json:"browser_flow_completed"` -} - -func (s *DefaultDeviceFlowSession) SetExpiresAt(key fosite.TokenType, exp time.Time) { - if s.ExpiresAt == nil { - s.ExpiresAt = make(map[fosite.TokenType]time.Time) - } - s.ExpiresAt[key] = exp -} - -func (s *DefaultDeviceFlowSession) GetExpiresAt(key fosite.TokenType) time.Time { - if s.ExpiresAt == nil { - s.ExpiresAt = make(map[fosite.TokenType]time.Time) - } - - if _, ok := s.ExpiresAt[key]; !ok { - return time.Time{} - } - return s.ExpiresAt[key] -} - -func (s *DefaultDeviceFlowSession) GetUsername() string { - if s == nil { - return "" - } - return s.Username -} - -func (s *DefaultDeviceFlowSession) SetSubject(subject string) { - s.Subject = subject -} - -func (s *DefaultDeviceFlowSession) GetSubject() string { - if s == nil { - return "" - } - - return s.Subject -} - -func (s *DefaultDeviceFlowSession) Clone() fosite.Session { - if s == nil { - return nil - } - - return deepcopy.Copy(s).(fosite.Session) -} - -func (s *DefaultDeviceFlowSession) GetBrowserFlowCompleted() bool { - if s == nil { - return false - } - - return s.BrowserFlowCompleted -} - -func (s *DefaultDeviceFlowSession) SetBrowserFlowCompleted(flag bool) { - s.BrowserFlowCompleted = flag -} - // DefaultDeviceStrategy implements the default device strategy type DefaultDeviceStrategy struct { Enigma *enigma.HMACStrategy @@ -129,7 +50,7 @@ func (h *DefaultDeviceStrategy) UserCodeSignature(ctx context.Context, token str } // ValidateUserCode validates a user_code -func (h *DefaultDeviceStrategy) ValidateUserCode(ctx context.Context, r fosite.Requester, code string) error { +func (h *DefaultDeviceStrategy) ValidateUserCode(ctx context.Context, r fosite.DeviceRequester, code string) error { exp := r.GetSession().GetExpiresAt(fosite.UserCode) if exp.IsZero() && r.GetRequestedAt().Add(h.Config.GetDeviceAndUserCodeLifespan(ctx)).Before(time.Now().UTC()) { return errorsx.WithStack(fosite.ErrDeviceExpiredToken.WithHintf("User code expired at '%s'.", r.GetRequestedAt().Add(h.Config.GetDeviceAndUserCodeLifespan(ctx)))) @@ -156,7 +77,7 @@ func (h *DefaultDeviceStrategy) DeviceCodeSignature(ctx context.Context, token s } // ValidateDeviceCode validates a device_code -func (h *DefaultDeviceStrategy) ValidateDeviceCode(ctx context.Context, r fosite.Requester, code string) error { +func (h *DefaultDeviceStrategy) ValidateDeviceCode(ctx context.Context, r fosite.DeviceRequester, code string) error { exp := r.GetSession().GetExpiresAt(fosite.DeviceCode) if exp.IsZero() && r.GetRequestedAt().Add(h.Config.GetDeviceAndUserCodeLifespan(ctx)).Before(time.Now().UTC()) { return errorsx.WithStack(fosite.ErrDeviceExpiredToken.WithHintf("Device code expired at '%s'.", r.GetRequestedAt().Add(h.Config.GetDeviceAndUserCodeLifespan(ctx)))) diff --git a/handler/rfc8628/strategy_hmacsha_test.go b/handler/rfc8628/strategy_hmacsha_test.go index 666a4725..bb625251 100644 --- a/handler/rfc8628/strategy_hmacsha_test.go +++ b/handler/rfc8628/strategy_hmacsha_test.go @@ -28,25 +28,27 @@ var hmacshaStrategy = DefaultDeviceStrategy{ }, } -var hmacValidCase = fosite.Request{ - Client: &fosite.DefaultClient{ - Secret: []byte("foobarfoobarfoobarfoobar"), - }, - Session: &fosite.DefaultSession{ - ExpiresAt: map[fosite.TokenType]time.Time{ - fosite.UserCode: time.Now().UTC().Add(time.Hour), - fosite.DeviceCode: time.Now().UTC().Add(time.Hour), +var hmacValidCase = fosite.DeviceRequest{ + Request: fosite.Request{ + Client: &fosite.DefaultClient{ + Secret: []byte("foobarfoobarfoobarfoobar"), + }, + Session: &fosite.DefaultSession{ + ExpiresAt: map[fosite.TokenType]time.Time{ + fosite.UserCode: time.Now().UTC().Add(time.Hour), + fosite.DeviceCode: time.Now().UTC().Add(time.Hour), + }, }, }, } func TestHMACUserCode(t *testing.T) { for k, c := range []struct { - r fosite.Request + r fosite.DeviceRequester pass bool }{ { - r: hmacValidCase, + r: &hmacValidCase, pass: true, }, } { @@ -56,7 +58,7 @@ func TestHMACUserCode(t *testing.T) { regex := regexp.MustCompile("[ABCDEFGHIJKLMNOPQRSTUVWXYZ]{8}") assert.Equal(t, len(regex.FindString(userCode)), len(userCode)) - err = hmacshaStrategy.ValidateUserCode(context.TODO(), &c.r, userCode) + err = hmacshaStrategy.ValidateUserCode(context.TODO(), c.r, userCode) if c.pass { assert.NoError(t, err) validate, _ := hmacshaStrategy.Enigma.GenerateHMACForString(context.TODO(), userCode) @@ -73,11 +75,11 @@ func TestHMACUserCode(t *testing.T) { func TestHMACDeviceCode(t *testing.T) { for k, c := range []struct { - r fosite.Request + r fosite.DeviceRequester pass bool }{ { - r: hmacValidCase, + r: &hmacValidCase, pass: true, }, } { @@ -92,7 +94,7 @@ func TestHMACDeviceCode(t *testing.T) { strings.TrimPrefix(token, "ory_dc_"), } { t.Run(fmt.Sprintf("prefix=%v", k == 0), func(t *testing.T) { - err = hmacshaStrategy.ValidateDeviceCode(context.TODO(), &c.r, token) + err = hmacshaStrategy.ValidateDeviceCode(context.TODO(), c.r, token) if c.pass { assert.NoError(t, err) validate := hmacshaStrategy.Enigma.Signature(token) diff --git a/handler/rfc8628/token_handler.go b/handler/rfc8628/token_handler.go index 22cc555f..8e832c27 100644 --- a/handler/rfc8628/token_handler.go +++ b/handler/rfc8628/token_handler.go @@ -62,12 +62,12 @@ func (c *DeviceCodeTokenEndpointHandler) PopulateTokenEndpointResponse(ctx conte return err } - var ar fosite.Requester + var ar fosite.DeviceRequester if ar, err = c.session(ctx, requester, signature); err != nil { return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) } - if err = c.DeviceCodeStrategy.ValidateDeviceCode(ctx, requester, code); err != nil { + if err = c.DeviceCodeStrategy.ValidateDeviceCode(ctx, ar, code); err != nil { return errorsx.WithStack(err) } @@ -154,7 +154,7 @@ func (c *DeviceCodeTokenEndpointHandler) HandleTokenEndpointRequest(ctx context. return errorsx.WithStack(err) } - var ar fosite.Requester + var ar fosite.DeviceRequester if ar, err = c.session(ctx, requester, signature); err != nil { if ar != nil && (errors.Is(err, fosite.ErrInvalidatedAuthorizeCode) || errors.Is(err, fosite.ErrInvalidatedDeviceCode)) { return c.revokeTokens(ctx, requester.GetID()) @@ -252,7 +252,7 @@ func (c DeviceCodeTokenEndpointHandler) validateCode(ctx context.Context, reques return nil } -func (s DeviceCodeTokenEndpointHandler) session(ctx context.Context, requester fosite.AccessRequester, codeSignature string) (fosite.Requester, error) { +func (s DeviceCodeTokenEndpointHandler) session(ctx context.Context, requester fosite.AccessRequester, codeSignature string) (fosite.DeviceRequester, error) { req, err := s.CoreStorage.GetDeviceCodeSession(ctx, codeSignature, requester.GetSession()) if err != nil && errors.Is(err, fosite.ErrInvalidatedDeviceCode) { @@ -265,10 +265,6 @@ func (s DeviceCodeTokenEndpointHandler) session(ctx context.Context, requester f WithDebug("\"GetDeviceCodeSession\" must return a value for \"fosite.Requester\" when returning \"ErrInvalidatedDeviceCode\".") } - if err != nil && errors.Is(err, fosite.ErrAuthorizationPending) { - return nil, err - } - if err != nil && errors.Is(err, fosite.ErrNotFound) { return nil, errorsx.WithStack(fosite.ErrInvalidGrant.WithWrap(err).WithDebug(err.Error())) } @@ -277,14 +273,14 @@ func (s DeviceCodeTokenEndpointHandler) session(ctx context.Context, requester f return nil, errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error())) } - session, ok := req.GetSession().(DeviceFlowSession) - if !ok { - return nil, fosite.ErrServerError.WithHint("Wrong authorization request session.") - } + state := req.GetUserCodeState() - if !session.GetBrowserFlowCompleted() { + if state == fosite.UserCodeUnused { return nil, fosite.ErrAuthorizationPending } + if state == fosite.UserCodeRejected { + return nil, fosite.ErrAccessDenied + } return req, err } diff --git a/handler/rfc8628/token_handler_test.go b/handler/rfc8628/token_handler_test.go index c60bb40c..8c0e3f16 100644 --- a/handler/rfc8628/token_handler_test.go +++ b/handler/rfc8628/token_handler_test.go @@ -79,7 +79,7 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) { Client: &fosite.DefaultClient{ GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, }, - Session: &DefaultDeviceFlowSession{}, + Session: &fosite.DefaultSession{}, RequestedAt: time.Now().UTC(), }, }, @@ -94,7 +94,7 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) { ID: "foo", GrantTypes: []string{""}, }, - Session: &DefaultDeviceFlowSession{}, + Session: &fosite.DefaultSession{}, RequestedAt: time.Now().UTC(), }, }, @@ -109,7 +109,7 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) { ID: "foo", GrantTypes: []string{"urn:ietf:params:oauth:grant-type:device_code"}, }, - Session: &DefaultDeviceFlowSession{}, + Session: &fosite.DefaultSession{}, RequestedAt: time.Now().UTC(), }, }, @@ -130,11 +130,12 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) { ID: "foo", GrantTypes: []string{"urn:ietf:params:oauth:grant-type:device_code"}, }, - Session: &DefaultDeviceFlowSession{}, + Session: &fosite.DefaultSession{}, RequestedAt: time.Now().UTC(), }, }, authreq: &fosite.DeviceRequest{ + UserCodeState: fosite.UserCodeUnused, Request: fosite.Request{ Client: &fosite.DefaultClient{ ID: "foo", @@ -142,11 +143,10 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) { }, RequestedScope: fosite.Arguments{"foo"}, GrantedScope: fosite.Arguments{"foo"}, - Session: &DefaultDeviceFlowSession{ + Session: &fosite.DefaultSession{ ExpiresAt: map[fosite.TokenType]time.Time{ fosite.DeviceCode: time.Now().Add(-time.Hour).UTC(), }, - BrowserFlowCompleted: false, }, RequestedAt: time.Now().Add(-2 * time.Hour).UTC(), }, @@ -173,20 +173,20 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) { GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, }, GrantedScope: fosite.Arguments{"foo", "offline"}, - Session: &DefaultDeviceFlowSession{}, + Session: &fosite.DefaultSession{}, RequestedAt: time.Now().UTC(), }, }, authreq: &fosite.DeviceRequest{ + UserCodeState: fosite.UserCodeAccepted, Request: fosite.Request{ Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{"urn:ietf:params:oauth:grant-type:device_code"}}, RequestedScope: fosite.Arguments{"foo"}, GrantedScope: fosite.Arguments{"foo"}, - Session: &DefaultDeviceFlowSession{ + Session: &fosite.DefaultSession{ ExpiresAt: map[fosite.TokenType]time.Time{ fosite.DeviceCode: time.Now().Add(-time.Hour).UTC(), }, - BrowserFlowCompleted: true, }, RequestedAt: time.Now().Add(-2 * time.Hour).UTC(), }, @@ -211,20 +211,20 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) { ID: "foo", GrantTypes: []string{"urn:ietf:params:oauth:grant-type:device_code"}, }, - Session: &DefaultDeviceFlowSession{}, + Session: &fosite.DefaultSession{}, RequestedAt: time.Now().UTC(), }, }, authreq: &fosite.DeviceRequest{ + UserCodeState: fosite.UserCodeAccepted, Request: fosite.Request{ Client: &fosite.DefaultClient{ID: "bar"}, RequestedScope: fosite.Arguments{"foo"}, GrantedScope: fosite.Arguments{"foo"}, - Session: &DefaultDeviceFlowSession{ + Session: &fosite.DefaultSession{ ExpiresAt: map[fosite.TokenType]time.Time{ fosite.DeviceCode: time.Now().Add(time.Hour).UTC(), }, - BrowserFlowCompleted: true, }, }, }, @@ -248,11 +248,12 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) { ID: "foo", GrantTypes: []string{"urn:ietf:params:oauth:grant-type:device_code"}, }, - Session: &DefaultDeviceFlowSession{}, + Session: &fosite.DefaultSession{}, RequestedAt: time.Now().UTC(), }, }, authreq: &fosite.DeviceRequest{ + UserCodeState: fosite.UserCodeAccepted, Request: fosite.Request{ Client: &fosite.DefaultClient{ ID: "foo", @@ -260,10 +261,8 @@ func TestDeviceUserCode_HandleTokenEndpointRequest(t *testing.T) { }, RequestedScope: fosite.Arguments{"foo"}, GrantedScope: fosite.Arguments{"foo"}, - Session: &DefaultDeviceFlowSession{ - BrowserFlowCompleted: true, - }, - RequestedAt: time.Now().UTC(), + Session: &fosite.DefaultSession{}, + RequestedAt: time.Now().UTC(), }, }, setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest) { @@ -332,19 +331,18 @@ func TestDeviceUserCode_HandleTokenEndpointRequest_RateLimiting(t *testing.T) { ID: "foo", GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, }, - Session: &DefaultDeviceFlowSession{}, + Session: &fosite.DefaultSession{}, RequestedAt: time.Now().UTC(), }, } authreq := &fosite.DeviceRequest{ + UserCodeState: fosite.UserCodeAccepted, Request: fosite.Request{ Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{"urn:ietf:params:oauth:grant-type:device_code"}}, RequestedScope: fosite.Arguments{"foo"}, GrantedScope: fosite.Arguments{"foo"}, - Session: &DefaultDeviceFlowSession{ - BrowserFlowCompleted: true, - }, - RequestedAt: time.Now().UTC(), + Session: &fosite.DefaultSession{}, + RequestedAt: time.Now().UTC(), }, } @@ -392,9 +390,7 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) { Client: &fosite.DefaultClient{ GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, }, - Session: &DefaultDeviceFlowSession{ - BrowserFlowCompleted: true, - }, + Session: &fosite.DefaultSession{}, RequestedAt: time.Now().UTC(), }, }, @@ -409,9 +405,7 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) { Client: &fosite.DefaultClient{ GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, }, - Session: &DefaultDeviceFlowSession{ - BrowserFlowCompleted: true, - }, + Session: &fosite.DefaultSession{}, RequestedAt: time.Now().UTC(), }, }, @@ -431,21 +425,18 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) { Client: &fosite.DefaultClient{ GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"}, }, - Session: &DefaultDeviceFlowSession{ - BrowserFlowCompleted: true, - }, + Session: &fosite.DefaultSession{}, RequestedAt: time.Now().UTC(), }, }, authreq: &fosite.DeviceRequest{ + UserCodeState: fosite.UserCodeAccepted, Request: fosite.Request{ Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{"urn:ietf:params:oauth:grant-type:device_code"}}, RequestedScope: fosite.Arguments{"foo", "bar", "offline"}, GrantedScope: fosite.Arguments{"foo", "offline"}, - Session: &DefaultDeviceFlowSession{ - BrowserFlowCompleted: true, - }, - RequestedAt: time.Now().UTC(), + Session: &fosite.DefaultSession{}, + RequestedAt: time.Now().UTC(), }, }, setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest, _ *fosite.Config) { @@ -474,21 +465,18 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) { Client: &fosite.DefaultClient{ GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"}, }, - Session: &DefaultDeviceFlowSession{ - BrowserFlowCompleted: true, - }, + Session: &fosite.DefaultSession{}, RequestedAt: time.Now().UTC(), }, }, authreq: &fosite.DeviceRequest{ + UserCodeState: fosite.UserCodeAccepted, Request: fosite.Request{ Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{"urn:ietf:params:oauth:grant-type:device_code"}}, RequestedScope: fosite.Arguments{"foo", "bar"}, GrantedScope: fosite.Arguments{"foo"}, - Session: &DefaultDeviceFlowSession{ - BrowserFlowCompleted: true, - }, - RequestedAt: time.Now().UTC(), + Session: &fosite.DefaultSession{}, + RequestedAt: time.Now().UTC(), }, }, setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest, config *fosite.Config) { @@ -518,21 +506,18 @@ func TestDeviceUserCode_PopulateTokenEndpointResponse(t *testing.T) { Client: &fosite.DefaultClient{ GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code"}, }, - Session: &DefaultDeviceFlowSession{ - BrowserFlowCompleted: true, - }, + Session: &fosite.DefaultSession{}, RequestedAt: time.Now().UTC(), }, }, authreq: &fosite.DeviceRequest{ + UserCodeState: fosite.UserCodeAccepted, Request: fosite.Request{ Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{"urn:ietf:params:oauth:grant-type:device_code"}}, RequestedScope: fosite.Arguments{"foo", "bar"}, GrantedScope: fosite.Arguments{"foo"}, - Session: &DefaultDeviceFlowSession{ - BrowserFlowCompleted: true, - }, - RequestedAt: time.Now().UTC(), + Session: &fosite.DefaultSession{}, + RequestedAt: time.Now().UTC(), }, }, setup: func(t *testing.T, areq *fosite.AccessRequest, authreq *fosite.DeviceRequest, config *fosite.Config) { @@ -603,14 +588,13 @@ func TestDeviceUserCodeTransactional_HandleTokenEndpointRequest(t *testing.T) { deviceStrategy := RFC8628HMACSHAStrategy authreq := &fosite.DeviceRequest{ + UserCodeState: fosite.UserCodeAccepted, Request: fosite.Request{ Client: &fosite.DefaultClient{ID: "foo", GrantTypes: []string{"urn:ietf:params:oauth:grant-type:device_code"}}, RequestedScope: fosite.Arguments{"foo", "offline"}, GrantedScope: fosite.Arguments{"foo", "offline"}, - Session: &DefaultDeviceFlowSession{ - BrowserFlowCompleted: true, - }, - RequestedAt: time.Now().UTC(), + Session: &fosite.DefaultSession{}, + RequestedAt: time.Now().UTC(), }, } @@ -620,9 +604,7 @@ func TestDeviceUserCodeTransactional_HandleTokenEndpointRequest(t *testing.T) { Client: &fosite.DefaultClient{ GrantTypes: fosite.Arguments{"urn:ietf:params:oauth:grant-type:device_code", "refresh_token"}, }, - Session: &DefaultDeviceFlowSession{ - BrowserFlowCompleted: true, - }, + Session: &fosite.DefaultSession{}, RequestedAt: time.Now().UTC(), }, } diff --git a/integration/authorize_device_grant_request_test.go b/integration/authorize_device_grant_request_test.go index 9eb2460a..20930971 100644 --- a/integration/authorize_device_grant_request_test.go +++ b/integration/authorize_device_grant_request_test.go @@ -8,8 +8,6 @@ import ( "fmt" "testing" - "github.com/ory/fosite/handler/rfc8628" - "github.com/ory/fosite" "github.com/ory/fosite/compose" "github.com/ory/fosite/internal/gen" @@ -24,7 +22,7 @@ func TestDeviceFlow(t *testing.T) { } func runDeviceFlowTest(t *testing.T) { - session := &rfc8628.DefaultDeviceFlowSession{} + session := &fosite.DefaultSession{} fc := &fosite.Config{ DeviceVerificationURL: "https://example.com/", @@ -124,9 +122,7 @@ func runDeviceFlowTest(t *testing.T) { } func runDeviceFlowAccessTokenTest(t *testing.T) { - session := &rfc8628.DefaultDeviceFlowSession{ - BrowserFlowCompleted: true, - } + session := &fosite.DefaultSession{} fc := &fosite.Config{ DeviceVerificationURL: "https://example.com/", @@ -147,6 +143,11 @@ func runDeviceFlowAccessTokenTest(t *testing.T) { }, } resp, _ := oauthClient.DeviceAuth(context.Background()) + deviceCodeSignature, err := compose.NewDeviceStrategy(fc).DeviceCodeSignature(context.Background(), resp.DeviceCode) + require.NoError(t, err) + d := fositeStore.DeviceAuths[deviceCodeSignature] + d.SetUserCodeState(fosite.UserCodeAccepted) + fositeStore.DeviceAuths[deviceCodeSignature] = d for k, c := range []struct { description string diff --git a/integration/helper_setup_test.go b/integration/helper_setup_test.go index f314e885..4c121104 100644 --- a/integration/helper_setup_test.go +++ b/integration/helper_setup_test.go @@ -123,7 +123,7 @@ var fositeStore = &storage.MemoryStore{ AccessTokenRequestIDs: map[string]string{}, RefreshTokenRequestIDs: map[string]string{}, PARSessions: map[string]fosite.AuthorizeRequester{}, - DeviceAuths: map[string]fosite.Requester{}, + DeviceAuths: map[string]fosite.DeviceRequester{}, DeviceCodesRequestIDs: map[string]storage.DeviceAuthPair{}, UserCodesRequestIDs: map[string]string{}, } diff --git a/oauth2.go b/oauth2.go index 9fc689cd..cbae4cfe 100644 --- a/oauth2.go +++ b/oauth2.go @@ -286,6 +286,12 @@ type AccessRequester interface { // DeviceRequester is an device endpoint's request context. type DeviceRequester interface { + // GetUserCodeState returns the state of the user code + GetUserCodeState() UserCodeState + + // SetUserCodeState sets the state of the user code + SetUserCodeState(state UserCodeState) + Requester } diff --git a/storage/memory.go b/storage/memory.go index 31ba4bc6..c1d6ff1a 100644 --- a/storage/memory.go +++ b/storage/memory.go @@ -47,7 +47,7 @@ type MemoryStore struct { IDSessions map[string]fosite.Requester AccessTokens map[string]fosite.Requester RefreshTokens map[string]StoreRefreshToken - DeviceAuths map[string]fosite.Requester + DeviceAuths map[string]fosite.DeviceRequester PKCES map[string]fosite.Requester Users map[string]MemoryUserRelation BlacklistedJTIs map[string]time.Time @@ -83,7 +83,7 @@ func NewMemoryStore() *MemoryStore { IDSessions: make(map[string]fosite.Requester), AccessTokens: make(map[string]fosite.Requester), RefreshTokens: make(map[string]StoreRefreshToken), - DeviceAuths: make(map[string]fosite.Requester), + DeviceAuths: make(map[string]fosite.DeviceRequester), PKCES: make(map[string]fosite.Requester), Users: make(map[string]MemoryUserRelation), AccessTokenRequestIDs: make(map[string]string), @@ -154,7 +154,7 @@ func NewExampleStore() *MemoryStore { AccessTokens: map[string]fosite.Requester{}, RefreshTokens: map[string]StoreRefreshToken{}, PKCES: map[string]fosite.Requester{}, - DeviceAuths: make(map[string]fosite.Requester), + DeviceAuths: make(map[string]fosite.DeviceRequester), AccessTokenRequestIDs: map[string]string{}, RefreshTokenRequestIDs: map[string]string{}, DeviceCodesRequestIDs: make(map[string]DeviceAuthPair), @@ -521,7 +521,7 @@ func (s *MemoryStore) RotateRefreshToken(ctx context.Context, requestID string, } // CreateDeviceAuthSession stores the device auth session -func (s *MemoryStore) CreateDeviceAuthSession(_ context.Context, deviceCodeSignature, userCodeSignature string, req fosite.Requester) error { +func (s *MemoryStore) CreateDeviceAuthSession(_ context.Context, deviceCodeSignature, userCodeSignature string, req fosite.DeviceRequester) error { s.deviceAuthsRequestIDsMutex.Lock() defer s.deviceAuthsRequestIDsMutex.Unlock() s.deviceAuthsMutex.Lock() @@ -534,7 +534,7 @@ func (s *MemoryStore) CreateDeviceAuthSession(_ context.Context, deviceCodeSigna } // GetDeviceCodeSession gets the device code session -func (s *MemoryStore) GetDeviceCodeSession(_ context.Context, signature string, _ fosite.Session) (fosite.Requester, error) { +func (s *MemoryStore) GetDeviceCodeSession(_ context.Context, signature string, _ fosite.Session) (fosite.DeviceRequester, error) { s.deviceAuthsMutex.RLock() defer s.deviceAuthsMutex.RUnlock()