From 399ea3dd076c31be2b0364e689cc448e28973b4b Mon Sep 17 00:00:00 2001 From: Bram Vanbilsen Date: Mon, 5 Feb 2024 11:25:01 -0600 Subject: [PATCH 1/2] Implemented hashes session tokens in store --- data.go | 15 +++++++++++ data_test.go | 75 ++++++++++++++++++++++++++++++++++++++++++++++++++++ session.go | 3 +++ 3 files changed, 93 insertions(+) diff --git a/data.go b/data.go index f5585a9..3ffda75 100644 --- a/data.go +++ b/data.go @@ -3,6 +3,7 @@ package scs import ( "context" "crypto/rand" + "crypto/sha256" "encoding/base64" "fmt" "sort" @@ -623,6 +624,11 @@ func generateToken() (string, error) { return base64.RawURLEncoding.EncodeToString(b), nil } +func hashToken(token string) string { + hash := sha256.Sum256([]byte(token)) + return base64.RawURLEncoding.EncodeToString(hash[:]) +} + type contextKey string var ( @@ -638,6 +644,9 @@ func generateContextKey() contextKey { } func (s *SessionManager) doStoreDelete(ctx context.Context, token string) (err error) { + if s.HashTokenInStore { + token = hashToken(token) + } c, ok := s.Store.(interface { DeleteCtx(context.Context, string) error }) @@ -648,6 +657,9 @@ func (s *SessionManager) doStoreDelete(ctx context.Context, token string) (err e } func (s *SessionManager) doStoreFind(ctx context.Context, token string) (b []byte, found bool, err error) { + if s.HashTokenInStore { + token = hashToken(token) + } c, ok := s.Store.(interface { FindCtx(context.Context, string) ([]byte, bool, error) }) @@ -658,6 +670,9 @@ func (s *SessionManager) doStoreFind(ctx context.Context, token string) (b []byt } func (s *SessionManager) doStoreCommit(ctx context.Context, token string, b []byte, expiry time.Time) (err error) { + if s.HashTokenInStore { + token = hashToken(token) + } c, ok := s.Store.(interface { CommitCtx(context.Context, string, []byte, time.Time) error }) diff --git a/data_test.go b/data_test.go index f91889b..ed48d41 100644 --- a/data_test.go +++ b/data_test.go @@ -320,6 +320,81 @@ func TestSessionManager_Commit(T *testing.T) { t.Error("expected error not returned") } }) + + T.Run("with token hashing", func(t *testing.T) { + s := New() + s.HashTokenInStore = true + s.IdleTimeout = time.Hour * 24 + + expectedToken := "example" + expectedExpiry := time.Now().Add(time.Hour) + + ctx := context.WithValue(context.Background(), s.contextKey, &sessionData{ + deadline: expectedExpiry, + token: expectedToken, + values: map[string]interface{}{ + "blah": "blah", + }, + mu: sync.Mutex{}, + }) + + actualToken, actualExpiry, err := s.Commit(ctx) + if expectedToken != actualToken { + t.Errorf("expected token to equal %q, but received %q", expectedToken, actualToken) + } + if expectedExpiry != actualExpiry { + t.Errorf("expected expiry to equal %v, but received %v", expectedExpiry, actualExpiry) + } + if err != nil { + t.Errorf("unexpected error returned: %v", err) + } + }) +} + +func TestTokenHashing(T *testing.T) { + T.Run("with token hashing", func(t *testing.T) { + s := New() + s.HashTokenInStore = true + s.IdleTimeout = time.Hour * 24 + + expectedToken := "example" + expectedExpiry := time.Now().Add(time.Hour) + + initialCtx := context.WithValue(context.Background(), s.contextKey, &sessionData{ + deadline: expectedExpiry, + token: expectedToken, + values: map[string]interface{}{ + "blah": "blah", + }, + mu: sync.Mutex{}, + }) + + actualToken, actualExpiry, err := s.Commit(initialCtx) + if expectedToken != actualToken { + t.Errorf("expected token to equal %q, but received %q", expectedToken, actualToken) + } + if expectedExpiry != actualExpiry { + t.Errorf("expected expiry to equal %v, but received %v", expectedExpiry, actualExpiry) + } + if err != nil { + t.Errorf("unexpected error returned: %v", err) + } + + retrievedCtx, err := s.Load(context.Background(), expectedToken) + if err != nil { + t.Errorf("unexpected error returned: %v", err) + } + retrievedSessionData, ok := retrievedCtx.Value(s.contextKey).(*sessionData) + if !ok { + t.Errorf("unexpected data in retrieved context") + } else if retrievedSessionData.token != expectedToken { + t.Errorf("expected token in context's session data data to equal %v, but received %v", expectedToken, retrievedSessionData.token) + } + + if err := s.Destroy(retrievedCtx); err != nil { + t.Errorf("unexpected error returned: %v", err) + } + }) } func TestPut(t *testing.T) { diff --git a/session.go b/session.go index 4cc73f4..2825c79 100644 --- a/session.go +++ b/session.go @@ -45,6 +45,9 @@ type SessionManager struct { // a function which logs the error and returns a customized HTML error page. ErrorFunc func(http.ResponseWriter, *http.Request, error) + // HashTokenInStore controls whether or not to store the session token or a hashed version in the store. + HashTokenInStore bool + // contextKey is the key used to set and retrieve the session data from a // context.Context. It's automatically generated to ensure uniqueness. contextKey contextKey From 7134b6f8757330a574d87940112099c4f9129294 Mon Sep 17 00:00:00 2001 From: Bram Vanbilsen Date: Mon, 5 Feb 2024 11:58:09 -0600 Subject: [PATCH 2/2] Moved hashed token test to better location --- data_test.go | 90 +++++++++++++++++++++++++--------------------------- 1 file changed, 44 insertions(+), 46 deletions(-) diff --git a/data_test.go b/data_test.go index ed48d41..2aed85c 100644 --- a/data_test.go +++ b/data_test.go @@ -198,6 +198,50 @@ func TestSessionManager_Load(T *testing.T) { t.Error("returned context is unexpectedly nil") } }) + + T.Run("with token hashing", func(t *testing.T) { + s := New() + s.HashTokenInStore = true + s.IdleTimeout = time.Hour * 24 + + expectedToken := "example" + expectedExpiry := time.Now().Add(time.Hour) + + initialCtx := context.WithValue(context.Background(), s.contextKey, &sessionData{ + deadline: expectedExpiry, + token: expectedToken, + values: map[string]interface{}{ + "blah": "blah", + }, + mu: sync.Mutex{}, + }) + + actualToken, actualExpiry, err := s.Commit(initialCtx) + if expectedToken != actualToken { + t.Errorf("expected token to equal %q, but received %q", expectedToken, actualToken) + } + if expectedExpiry != actualExpiry { + t.Errorf("expected expiry to equal %v, but received %v", expectedExpiry, actualExpiry) + } + if err != nil { + t.Errorf("unexpected error returned: %v", err) + } + + retrievedCtx, err := s.Load(context.Background(), expectedToken) + if err != nil { + t.Errorf("unexpected error returned: %v", err) + } + retrievedSessionData, ok := retrievedCtx.Value(s.contextKey).(*sessionData) + if !ok { + t.Errorf("unexpected data in retrieved context") + } else if retrievedSessionData.token != expectedToken { + t.Errorf("expected token in context's session data data to equal %v, but received %v", expectedToken, retrievedSessionData.token) + } + + if err := s.Destroy(retrievedCtx); err != nil { + t.Errorf("unexpected error returned: %v", err) + } + }) } func TestSessionManager_Commit(T *testing.T) { @@ -351,52 +395,6 @@ func TestSessionManager_Commit(T *testing.T) { }) } -func TestTokenHashing(T *testing.T) { - T.Run("with token hashing", func(t *testing.T) { - s := New() - s.HashTokenInStore = true - s.IdleTimeout = time.Hour * 24 - - expectedToken := "example" - expectedExpiry := time.Now().Add(time.Hour) - - initialCtx := context.WithValue(context.Background(), s.contextKey, &sessionData{ - deadline: expectedExpiry, - token: expectedToken, - values: map[string]interface{}{ - "blah": "blah", - }, - mu: sync.Mutex{}, - }) - - actualToken, actualExpiry, err := s.Commit(initialCtx) - if expectedToken != actualToken { - t.Errorf("expected token to equal %q, but received %q", expectedToken, actualToken) - } - if expectedExpiry != actualExpiry { - t.Errorf("expected expiry to equal %v, but received %v", expectedExpiry, actualExpiry) - } - if err != nil { - t.Errorf("unexpected error returned: %v", err) - } - - retrievedCtx, err := s.Load(context.Background(), expectedToken) - if err != nil { - t.Errorf("unexpected error returned: %v", err) - } - retrievedSessionData, ok := retrievedCtx.Value(s.contextKey).(*sessionData) - if !ok { - t.Errorf("unexpected data in retrieved context") - } else if retrievedSessionData.token != expectedToken { - t.Errorf("expected token in context's session data data to equal %v, but received %v", expectedToken, retrievedSessionData.token) - } - - if err := s.Destroy(retrievedCtx); err != nil { - t.Errorf("unexpected error returned: %v", err) - } - }) -} - func TestPut(t *testing.T) { t.Parallel()