Skip to content

Commit

Permalink
refactor: enhance deviceRequest struct
Browse files Browse the repository at this point in the history
  • Loading branch information
nsklikas committed Dec 19, 2024
1 parent 45462ac commit b30ae67
Show file tree
Hide file tree
Showing 13 changed files with 129 additions and 183 deletions.
29 changes: 28 additions & 1 deletion device_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,41 @@

package fosite

type UserCodeState int16

const (
// User code is active
UserCodeUnused = UserCodeState(0)
// User code has been accepted
UserCodeAccepted = UserCodeState(1)
// User code has been rejected
UserCodeRejected = UserCodeState(2)
)

// DeviceRequest is an implementation of DeviceRequester
type DeviceRequest struct {
UserCodeState UserCodeState
Request
}

func (d *DeviceRequest) GetUserCodeState() UserCodeState {
return d.UserCodeState
}

func (d *DeviceRequest) SetUserCodeState(state UserCodeState) {
d.UserCodeState = state
}

func (d *DeviceRequest) Sanitize(allowedParameters []string) Requester {
r, _ := d.Request.Sanitize(allowedParameters).(*Request)
d.Request = *r
return d
}

// NewDeviceRequest returns a new device request
func NewDeviceRequest() *DeviceRequest {
return &DeviceRequest{
Request: *NewRequest(),
UserCodeState: UserCodeUnused,
Request: *NewRequest(),
}
}
2 changes: 1 addition & 1 deletion device_write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func TestWriteDeviceUserResponse(t *testing.T) {
ctx := context.Background()

rw := httptest.NewRecorder()
ar := &Request{}
ar := &DeviceRequest{}
resp := &DeviceResponse{}
resp.SetUserCode("AAAA")
resp.SetDeviceCode("BBBB")
Expand Down
2 changes: 1 addition & 1 deletion handler/rfc8628/auth_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (d *DeviceAuthHandler) handleDeviceAuthSession(ctx context.Context, dar fos
return "", "", err
}

if err = d.Storage.CreateDeviceAuthSession(ctx, deviceCodeSignature, userCodeSignature, dar.Sanitize(nil)); err == nil {
if err = d.Storage.CreateDeviceAuthSession(ctx, deviceCodeSignature, userCodeSignature, dar.Sanitize(nil).(fosite.DeviceRequester)); err == nil {
return deviceCode, userCode, nil
}
}
Expand Down
4 changes: 2 additions & 2 deletions handler/rfc8628/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ type RFC8628CoreStorage interface {
// DeviceAuthStorage handles the device auth session storage
type DeviceAuthStorage interface {
// CreateDeviceAuthSession stores the device auth request session.
CreateDeviceAuthSession(ctx context.Context, deviceCodeSignature, userCodeSignature string, request fosite.Requester) (err error)
CreateDeviceAuthSession(ctx context.Context, deviceCodeSignature, userCodeSignature string, request fosite.DeviceRequester) (err error)

// GetDeviceCodeSession hydrates the session based on the given device code and returns the device request.
// If the device code has been invalidated with `InvalidateDeviceCodeSession`, this
// method should return the ErrInvalidatedDeviceCode error.
//
// Make sure to also return the fosite.Requester value when returning the fosite.ErrInvalidatedDeviceCode error!
GetDeviceCodeSession(ctx context.Context, signature string, session fosite.Session) (request fosite.Requester, err error)
GetDeviceCodeSession(ctx context.Context, signature string, session fosite.Session) (request fosite.DeviceRequester, err error)

// InvalidateDeviceCodeSession is called when a device code is being used. The state of the device
// code should be set to invalid and consecutive requests to GetDeviceCodeSession should return the
Expand Down
15 changes: 13 additions & 2 deletions handler/rfc8628/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,30 @@ type RFC8628CodeStrategy interface {

// DeviceRateLimitStrategy handles the rate limiting strategy
type DeviceRateLimitStrategy interface {
// ShouldRateLimit checks whether the token request should be rate-limited
ShouldRateLimit(ctx context.Context, code string) (bool, error)
}

// DeviceCodeStrategy handles the device_code strategy
type DeviceCodeStrategy interface {
// DeviceCodeSignature calculates the signature of a device_code
DeviceCodeSignature(ctx context.Context, code string) (signature string, err error)

// GenerateDeviceCode generates a new device code and signature
GenerateDeviceCode(ctx context.Context) (code string, signature string, err error)
ValidateDeviceCode(ctx context.Context, r fosite.Requester, code string) (err error)

// ValidateDeviceCode validates the device_code
ValidateDeviceCode(ctx context.Context, r fosite.DeviceRequester, code string) (err error)
}

// UserCodeStrategy handles the user_code strategy
type UserCodeStrategy interface {
// UserCodeSignature calculates the signature of a user_code
UserCodeSignature(ctx context.Context, code string) (signature string, err error)

// GenerateUserCode generates a new user code and signature
GenerateUserCode(ctx context.Context) (code string, signature string, err error)
ValidateUserCode(ctx context.Context, r fosite.Requester, code string) (err error)

// ValidateUserCode validates the user_code
ValidateUserCode(ctx context.Context, r fosite.DeviceRequester, code string) (err error)
}
83 changes: 2 additions & 81 deletions handler/rfc8628/strategy_hmacsha.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import (
"strings"
"time"

"github.com/mohae/deepcopy"

"github.com/ory/x/errorsx"

"github.com/ory/x/randx"
Expand All @@ -20,83 +18,6 @@ import (

const POLLING_RATE_LIMITING_LEEWAY = 200 * time.Millisecond

// DeviceFlowSession is a fosite.Session container specific for the device flow.
type DeviceFlowSession interface {
// GetBrowserFlowCompleted returns the flag indicating whether user has completed the browser flow or not.
GetBrowserFlowCompleted() bool

// SetBrowserFlowCompleted allows client to mark user has completed the browser flow.
SetBrowserFlowCompleted(flag bool)

fosite.Session
}

// DefaultDeviceFlowSession is a DeviceFlowSession implementation for the device flow.
type DefaultDeviceFlowSession struct {
ExpiresAt map[fosite.TokenType]time.Time `json:"expires_at"`
Username string `json:"username"`
Subject string `json:"subject"`
Extra map[string]interface{} `json:"extra"`
BrowserFlowCompleted bool `json:"browser_flow_completed"`
}

func (s *DefaultDeviceFlowSession) SetExpiresAt(key fosite.TokenType, exp time.Time) {
if s.ExpiresAt == nil {
s.ExpiresAt = make(map[fosite.TokenType]time.Time)
}
s.ExpiresAt[key] = exp
}

func (s *DefaultDeviceFlowSession) GetExpiresAt(key fosite.TokenType) time.Time {
if s.ExpiresAt == nil {
s.ExpiresAt = make(map[fosite.TokenType]time.Time)
}

if _, ok := s.ExpiresAt[key]; !ok {
return time.Time{}
}
return s.ExpiresAt[key]
}

func (s *DefaultDeviceFlowSession) GetUsername() string {
if s == nil {
return ""
}
return s.Username
}

func (s *DefaultDeviceFlowSession) SetSubject(subject string) {
s.Subject = subject
}

func (s *DefaultDeviceFlowSession) GetSubject() string {
if s == nil {
return ""
}

return s.Subject
}

func (s *DefaultDeviceFlowSession) Clone() fosite.Session {
if s == nil {
return nil
}

return deepcopy.Copy(s).(fosite.Session)
}

func (s *DefaultDeviceFlowSession) GetBrowserFlowCompleted() bool {
if s == nil {
return false
}

return s.BrowserFlowCompleted
}

func (s *DefaultDeviceFlowSession) SetBrowserFlowCompleted(flag bool) {
s.BrowserFlowCompleted = flag
}

// DefaultDeviceStrategy implements the default device strategy
type DefaultDeviceStrategy struct {
Enigma *enigma.HMACStrategy
Expand Down Expand Up @@ -129,7 +50,7 @@ func (h *DefaultDeviceStrategy) UserCodeSignature(ctx context.Context, token str
}

// ValidateUserCode validates a user_code
func (h *DefaultDeviceStrategy) ValidateUserCode(ctx context.Context, r fosite.Requester, code string) error {
func (h *DefaultDeviceStrategy) ValidateUserCode(ctx context.Context, r fosite.DeviceRequester, code string) error {
exp := r.GetSession().GetExpiresAt(fosite.UserCode)
if exp.IsZero() && r.GetRequestedAt().Add(h.Config.GetDeviceAndUserCodeLifespan(ctx)).Before(time.Now().UTC()) {
return errorsx.WithStack(fosite.ErrDeviceExpiredToken.WithHintf("User code expired at '%s'.", r.GetRequestedAt().Add(h.Config.GetDeviceAndUserCodeLifespan(ctx))))
Expand All @@ -156,7 +77,7 @@ func (h *DefaultDeviceStrategy) DeviceCodeSignature(ctx context.Context, token s
}

// ValidateDeviceCode validates a device_code
func (h *DefaultDeviceStrategy) ValidateDeviceCode(ctx context.Context, r fosite.Requester, code string) error {
func (h *DefaultDeviceStrategy) ValidateDeviceCode(ctx context.Context, r fosite.DeviceRequester, code string) error {
exp := r.GetSession().GetExpiresAt(fosite.DeviceCode)
if exp.IsZero() && r.GetRequestedAt().Add(h.Config.GetDeviceAndUserCodeLifespan(ctx)).Before(time.Now().UTC()) {
return errorsx.WithStack(fosite.ErrDeviceExpiredToken.WithHintf("Device code expired at '%s'.", r.GetRequestedAt().Add(h.Config.GetDeviceAndUserCodeLifespan(ctx))))
Expand Down
30 changes: 16 additions & 14 deletions handler/rfc8628/strategy_hmacsha_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,27 @@ var hmacshaStrategy = DefaultDeviceStrategy{
},
}

var hmacValidCase = fosite.Request{
Client: &fosite.DefaultClient{
Secret: []byte("foobarfoobarfoobarfoobar"),
},
Session: &fosite.DefaultSession{
ExpiresAt: map[fosite.TokenType]time.Time{
fosite.UserCode: time.Now().UTC().Add(time.Hour),
fosite.DeviceCode: time.Now().UTC().Add(time.Hour),
var hmacValidCase = fosite.DeviceRequest{
Request: fosite.Request{
Client: &fosite.DefaultClient{
Secret: []byte("foobarfoobarfoobarfoobar"),
},
Session: &fosite.DefaultSession{
ExpiresAt: map[fosite.TokenType]time.Time{
fosite.UserCode: time.Now().UTC().Add(time.Hour),
fosite.DeviceCode: time.Now().UTC().Add(time.Hour),
},
},
},
}

func TestHMACUserCode(t *testing.T) {
for k, c := range []struct {
r fosite.Request
r fosite.DeviceRequester
pass bool
}{
{
r: hmacValidCase,
r: &hmacValidCase,
pass: true,
},
} {
Expand All @@ -56,7 +58,7 @@ func TestHMACUserCode(t *testing.T) {
regex := regexp.MustCompile("[ABCDEFGHIJKLMNOPQRSTUVWXYZ]{8}")
assert.Equal(t, len(regex.FindString(userCode)), len(userCode))

err = hmacshaStrategy.ValidateUserCode(context.TODO(), &c.r, userCode)
err = hmacshaStrategy.ValidateUserCode(context.TODO(), c.r, userCode)
if c.pass {
assert.NoError(t, err)
validate, _ := hmacshaStrategy.Enigma.GenerateHMACForString(context.TODO(), userCode)
Expand All @@ -73,11 +75,11 @@ func TestHMACUserCode(t *testing.T) {

func TestHMACDeviceCode(t *testing.T) {
for k, c := range []struct {
r fosite.Request
r fosite.DeviceRequester
pass bool
}{
{
r: hmacValidCase,
r: &hmacValidCase,
pass: true,
},
} {
Expand All @@ -92,7 +94,7 @@ func TestHMACDeviceCode(t *testing.T) {
strings.TrimPrefix(token, "ory_dc_"),
} {
t.Run(fmt.Sprintf("prefix=%v", k == 0), func(t *testing.T) {
err = hmacshaStrategy.ValidateDeviceCode(context.TODO(), &c.r, token)
err = hmacshaStrategy.ValidateDeviceCode(context.TODO(), c.r, token)
if c.pass {
assert.NoError(t, err)
validate := hmacshaStrategy.Enigma.Signature(token)
Expand Down
22 changes: 9 additions & 13 deletions handler/rfc8628/token_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ func (c *DeviceCodeTokenEndpointHandler) PopulateTokenEndpointResponse(ctx conte
return err
}

var ar fosite.Requester
var ar fosite.DeviceRequester
if ar, err = c.session(ctx, requester, signature); err != nil {
return errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}

if err = c.DeviceCodeStrategy.ValidateDeviceCode(ctx, requester, code); err != nil {
if err = c.DeviceCodeStrategy.ValidateDeviceCode(ctx, ar, code); err != nil {
return errorsx.WithStack(err)
}

Expand Down Expand Up @@ -154,7 +154,7 @@ func (c *DeviceCodeTokenEndpointHandler) HandleTokenEndpointRequest(ctx context.
return errorsx.WithStack(err)
}

var ar fosite.Requester
var ar fosite.DeviceRequester
if ar, err = c.session(ctx, requester, signature); err != nil {
if ar != nil && (errors.Is(err, fosite.ErrInvalidatedAuthorizeCode) || errors.Is(err, fosite.ErrInvalidatedDeviceCode)) {
return c.revokeTokens(ctx, requester.GetID())
Expand Down Expand Up @@ -252,7 +252,7 @@ func (c DeviceCodeTokenEndpointHandler) validateCode(ctx context.Context, reques
return nil
}

func (s DeviceCodeTokenEndpointHandler) session(ctx context.Context, requester fosite.AccessRequester, codeSignature string) (fosite.Requester, error) {
func (s DeviceCodeTokenEndpointHandler) session(ctx context.Context, requester fosite.AccessRequester, codeSignature string) (fosite.DeviceRequester, error) {
req, err := s.CoreStorage.GetDeviceCodeSession(ctx, codeSignature, requester.GetSession())

if err != nil && errors.Is(err, fosite.ErrInvalidatedDeviceCode) {
Expand All @@ -265,10 +265,6 @@ func (s DeviceCodeTokenEndpointHandler) session(ctx context.Context, requester f
WithDebug("\"GetDeviceCodeSession\" must return a value for \"fosite.Requester\" when returning \"ErrInvalidatedDeviceCode\".")
}

if err != nil && errors.Is(err, fosite.ErrAuthorizationPending) {
return nil, err
}

if err != nil && errors.Is(err, fosite.ErrNotFound) {
return nil, errorsx.WithStack(fosite.ErrInvalidGrant.WithWrap(err).WithDebug(err.Error()))
}
Expand All @@ -277,14 +273,14 @@ func (s DeviceCodeTokenEndpointHandler) session(ctx context.Context, requester f
return nil, errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebug(err.Error()))
}

session, ok := req.GetSession().(DeviceFlowSession)
if !ok {
return nil, fosite.ErrServerError.WithHint("Wrong authorization request session.")
}
state := req.GetUserCodeState()

if !session.GetBrowserFlowCompleted() {
if state == fosite.UserCodeUnused {
return nil, fosite.ErrAuthorizationPending
}
if state == fosite.UserCodeRejected {
return nil, fosite.ErrAccessDenied
}

return req, err
}
Expand Down
Loading

0 comments on commit b30ae67

Please sign in to comment.