diff --git a/conv/conv.go b/conv/conv.go index acb0f94..c78783d 100644 --- a/conv/conv.go +++ b/conv/conv.go @@ -50,6 +50,8 @@ func Int(r interface{}, err error) (int, error) { return int(n), err case nil: return 0, ErrNil + case error: + return 0, r } return 0, ErrAssertType @@ -74,6 +76,8 @@ func Int64(r interface{}, err error) (int64, error) { return n, err case nil: return 0, ErrNil + case error: + return 0, r } return 0, ErrAssertType @@ -106,6 +110,8 @@ func UInt64(r interface{}, err error) (uint64, error) { return n, err case nil: return 0, ErrNil + case error: + return 0, r } return 0, ErrAssertType @@ -127,6 +133,8 @@ func Float64(r interface{}, err error) (float64, error) { return n, err case nil: return 0, ErrNil + case error: + return 0, r } return 0, ErrAssertType } @@ -143,6 +151,8 @@ func String(r interface{}, err error) (string, error) { return r, nil case nil: return "", ErrNil + case error: + return "", r } return "", ErrAssertType } @@ -159,6 +169,8 @@ func Bytes(r interface{}, err error) ([]byte, error) { return []byte(r), nil case nil: return nil, ErrNil + case error: + return nil, r } return nil, ErrAssertType } @@ -182,6 +194,8 @@ func Bool(r interface{}, err error) (bool, error) { return strconv.ParseBool(r) case nil: return false, ErrNil + case error: + return false, r } return false, ErrAssertType } diff --git a/stores/goredis/store.go b/stores/goredis/store.go index 2e4de30..82c159f 100644 --- a/stores/goredis/store.go +++ b/stores/goredis/store.go @@ -3,7 +3,6 @@ package goredis import ( "context" "crypto/rand" - "sync" "time" "unicode" @@ -42,10 +41,6 @@ type Store struct { // Prefix for session id. prefix string - // Temp map to store values before commit. - tempSetMap map[string]map[string]interface{} - mu sync.RWMutex - // Redis client client redis.UniversalClient clientCtx context.Context @@ -60,10 +55,9 @@ const ( // New creates a new Redis store instance. func New(ctx context.Context, client redis.UniversalClient) *Store { return &Store{ - clientCtx: ctx, - client: client, - prefix: defaultPrefix, - tempSetMap: make(map[string]map[string]interface{}), + clientCtx: ctx, + client: client, + prefix: defaultPrefix, } } @@ -93,12 +87,29 @@ func (s *Store) Get(id, key string) (interface{}, error) { return nil, ErrInvalidSession } - v, err := s.client.HGet(s.clientCtx, s.prefix+id, key).Result() - if err == redis.Nil { + pipe := s.client.TxPipeline() + exists := pipe.Exists(s.clientCtx, s.prefix+id) + get := pipe.HGet(s.clientCtx, s.prefix+id, key) + _, err := pipe.Exec(s.clientCtx) + // redis.Nil is returned if a field does not exist. + // Ignore the error and check for key existence check. + if err != nil && err != redis.Nil { + return nil, err + } + + // Check if key exists and return ErrInvalidSession if not. + if ex, err := exists.Result(); err != nil { + return nil, err + } else if ex == 0 { + return nil, ErrInvalidSession + } + + v, err := get.Result() + if err != nil && err == redis.Nil { return nil, ErrFieldNotFound } - return v, err + return v, nil } // GetMulti gets a map for values for multiple keys. If key is not found then its set as nil. @@ -107,16 +118,36 @@ func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, err return nil, ErrInvalidSession } - v, err := s.client.HMGet(s.clientCtx, s.prefix+id, keys...).Result() - // If field is not found then return map with fields as nil - if len(v) == 0 || err == redis.Nil { - v = make([]interface{}, len(keys)) + pipe := s.client.TxPipeline() + exists := pipe.Exists(s.clientCtx, s.prefix+id) + get := pipe.HMGet(s.clientCtx, s.prefix+id, keys...) + _, err := pipe.Exec(s.clientCtx) + // redis.Nil is returned if a field does not exist. + // Ignore the error and check for key existence check. + if err != nil && err != redis.Nil { + return nil, err + } + + // Check if key exists and return ErrInvalidSession if not. + if ex, err := exists.Result(); err != nil { + return nil, err + } else if ex == 0 { + return nil, ErrInvalidSession + } + + v, err := get.Result() + if err != nil { + return nil, err } // Form a map with returned results res := make(map[string]interface{}) for i, k := range keys { - res[k] = v[i] + if v[i] == nil { + res[k] = ErrFieldNotFound + } else { + res[k] = v[i] + } } return res, err @@ -128,10 +159,25 @@ func (s *Store) GetAll(id string) (map[string]interface{}, error) { return nil, ErrInvalidSession } - res, err := s.client.HGetAll(s.clientCtx, s.prefix+id).Result() - if res == nil || err == redis.Nil { - return map[string]interface{}{}, nil - } else if err != nil { + pipe := s.client.TxPipeline() + exists := pipe.Exists(s.clientCtx, s.prefix+id) + get := pipe.HGetAll(s.clientCtx, s.prefix+id) + _, err := pipe.Exec(s.clientCtx) + // redis.Nil is returned if a field does not exist. + // Ignore the error and check for key existence check. + if err != nil && err != redis.Nil { + return nil, err + } + + // Check if key exists and return ErrInvalidSession if not. + if ex, err := exists.Result(); err != nil { + return nil, err + } else if ex == 0 { + return nil, ErrInvalidSession + } + + res, err := get.Result() + if err != nil { return nil, err } @@ -144,54 +190,37 @@ func (s *Store) GetAll(id string) (map[string]interface{}, error) { return out, nil } -// Set sets a value to given session but stored only on commit +// Set sets a value to given session. func (s *Store) Set(id, key string, val interface{}) error { if !validateID(id) { return ErrInvalidSession } - s.mu.Lock() - defer s.mu.Unlock() + pipe := s.client.TxPipeline() + pipe.HSet(s.clientCtx, s.prefix+id, key, val) - // Create session map if doesn't exist - if _, ok := s.tempSetMap[id]; !ok { - s.tempSetMap[id] = make(map[string]interface{}) + // Set expiry of key only if 'ttl' is set, this is to + // ensure that the key remains valid indefinitely like + // how redis handles it by default + if s.ttl > 0 { + pipe.Expire(s.clientCtx, s.prefix+id, s.ttl) } - // set value to map - s.tempSetMap[id][key] = val - - return nil + _, err := pipe.Exec(s.clientCtx) + return err } -// Commit sets all set values. -func (s *Store) Commit(id string) error { +// Set sets a value to given session. +func (s *Store) SetMulti(id string, data map[string]interface{}) error { if !validateID(id) { return ErrInvalidSession } - s.mu.RLock() - vals, ok := s.tempSetMap[id] - if !ok { - // Nothing to commit - s.mu.RUnlock() - return nil - } - // Make slice of arguments to be passed in HGETALL command - args := make([]interface{}, len(vals)*2, len(vals)*2) - c := 0 - for k, v := range s.tempSetMap[id] { - args[c] = k - args[c+1] = v - c += 2 + args := []interface{}{} + for k, v := range data { + args = append(args, k, v) } - s.mu.RUnlock() - - // Clear temp map for given session id - s.mu.Lock() - delete(s.tempSetMap, id) - s.mu.Unlock() pipe := s.client.TxPipeline() pipe.HMSet(s.clientCtx, s.prefix+id, args...) @@ -212,16 +241,30 @@ func (s *Store) Delete(id string, key string) error { return ErrInvalidSession } - // Clear temp map for given session id - s.mu.Lock() - delete(s.tempSetMap, id) - s.mu.Unlock() + pipe := s.client.TxPipeline() + exists := pipe.Exists(s.clientCtx, s.prefix+id) + del := pipe.HDel(s.clientCtx, s.prefix+id, key) + _, err := pipe.Exec(s.clientCtx) + // redis.Nil is returned if a field does not exist. + // Ignore the error and check for key existence check. + if err != nil && err != redis.Nil { + return err + } - err := s.client.HDel(s.clientCtx, s.prefix+id, key).Err() - if err == redis.Nil { + // Check if key exists and return ErrInvalidSession if not. + if ex, err := exists.Result(); err != nil { + return err + } else if ex == 0 { + return ErrInvalidSession + } + + if v, err := del.Result(); err != nil { + return err + } else if v == 0 { return ErrFieldNotFound } - return err + + return nil } // Clear clears session in redis. diff --git a/stores/goredis/store_test.go b/stores/goredis/store_test.go index 0f51d94..332b54e 100644 --- a/stores/goredis/store_test.go +++ b/stores/goredis/store_test.go @@ -30,51 +30,35 @@ func getRedisClient() redis.UniversalClient { } func TestNew(t *testing.T) { - assert := assert.New(t) client := getRedisClient() ctx := context.Background() str := New(ctx, client) - assert.Equal(str.prefix, defaultPrefix) - assert.Equal(str.client, client) - assert.Equal(str.clientCtx, ctx) - assert.NotNil(str.tempSetMap) + assert.Equal(t, str.prefix, defaultPrefix) + assert.Equal(t, str.client, client) + assert.Equal(t, str.clientCtx, ctx) } func TestSetPrefix(t *testing.T) { - assert := assert.New(t) str := New(context.TODO(), getRedisClient()) str.SetPrefix("test") - assert.Equal(str.prefix, "test") + assert.Equal(t, str.prefix, "test") } func TestSetTTL(t *testing.T) { - assert := assert.New(t) testDur := time.Second * 10 str := New(context.TODO(), getRedisClient()) str.SetTTL(testDur) - assert.Equal(str.ttl, testDur) + assert.Equal(t, str.ttl, testDur) } func TestCreate(t *testing.T) { - assert := assert.New(t) str := New(context.TODO(), getRedisClient()) - id, err := str.Create() - assert.Nil(err) - assert.Equal(len(id), sessionIDLen) -} - -func TestGetInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) - - val, err := str.Get("invalidkey", "invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession) + assert.Nil(t, err) + assert.Equal(t, len(id), sessionIDLen) } func TestGet(t *testing.T) { - assert := assert.New(t) key := "4dIHy6S2uBuKaNnTUszB218L898ikGY1" field := "somekey" value := 100 @@ -82,87 +66,93 @@ func TestGet(t *testing.T) { // Set a key err := client.HSet(context.TODO(), defaultPrefix+key, field, value).Err() - assert.NoError(err) + assert.NoError(t, err) str := New(context.TODO(), client) val, err := str.Int(str.Get(key, field)) - assert.NoError(err) - assert.Equal(val, value) -} + assert.NoError(t, err) + assert.Equal(t, val, value) -func TestGetFieldNotFoundError(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) - - key := "10IHy6S2uBuKaNnTUszB218L898ikGY1" - val, err := str.Get(key, "invalidkey") - assert.Nil(val) - assert.Error(err, ErrFieldNotFound.Error()) + // Check for invalid key. + _, err = str.Int(str.Get(key, "invalidfield")) + assert.ErrorIs(t, ErrFieldNotFound, err) } -func TestGetMultiInvalidSessionError(t *testing.T) { - assert := assert.New(t) +func TestGetInvalidSession(t *testing.T) { str := New(context.TODO(), getRedisClient()) + val, err := str.Get("invalidkey", "invalidkey") + assert.Nil(t, val) + assert.ErrorIs(t, err, ErrInvalidSession) - val, err := str.GetMulti("invalidkey", "invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession.Error()) + id := "10IHy6S2uBuKaNnTUszB218L898ikGY1" + val, err = str.Get(id, "invalidkey") + assert.Nil(t, val) + assert.ErrorIs(t, ErrInvalidSession, err) } -func TestGetMultiFieldEmptySession(t *testing.T) { - assert := assert.New(t) +func TestGetMultiInvalidSession(t *testing.T) { str := New(context.TODO(), getRedisClient()) + val, err := str.GetMulti("invalidkey", "invalidkey") + assert.Nil(t, val) + assert.ErrorIs(t, ErrInvalidSession, err) key := "11IHy6S2uBuKaNnTUszB218L898ikGY1" field := "somefield" - _, err := str.GetMulti(key, field) - assert.Nil(err) + _, err = str.GetMulti(key, field) + assert.ErrorIs(t, err, ErrInvalidSession) } func TestGetMulti(t *testing.T) { - assert := assert.New(t) - key := "5dIHy6S2uBuKaNnTUszB218L898ikGY1" - field1 := "somekey" - value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - field3 := "thishouldntbethere" - value3 := 100.10 - client := getRedisClient() + var ( + key = "5dIHy6S2uBuKaNnTUszB218L898ikGY1" + field1 = "somekey" + value1 = 100 + field2 = "someotherkey" + value2 = "abc123" + field3 = "thishouldntbethere" + value3 = 100.10 + invalidField = "foo" + client = getRedisClient() + ) // Set a key err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2, field3, value3).Err() - assert.NoError(err) + assert.NoError(t, err) str := New(context.TODO(), client) - - vals, err := str.GetMulti(key, field1, field2) - assert.NoError(err) - assert.Contains(vals, field1) - assert.Contains(vals, field2) - assert.NotContains(vals, field3) + vals, err := str.GetMulti(key, field1, field2, invalidField) + assert.NoError(t, err) + assert.Contains(t, vals, field1) + assert.Contains(t, vals, field2) + assert.NotContains(t, vals, field3) val1, err := str.Int(vals[field1], nil) - assert.NoError(err) - assert.Equal(val1, value1) + assert.NoError(t, err) + assert.Equal(t, val1, value1) val2, err := str.String(vals[field2], nil) - assert.NoError(err) - assert.Equal(val2, value2) + assert.NoError(t, err) + assert.Equal(t, val2, value2) + + // Check for invalid key. + _, err = str.String(vals[invalidField], nil) + assert.ErrorIs(t, ErrFieldNotFound, err) } -func TestGetAllInvalidSessionError(t *testing.T) { - assert := assert.New(t) +func TestGetAllInvalidSession(t *testing.T) { str := New(context.TODO(), getRedisClient()) - val, err := str.GetAll("invalidkey") - assert.Nil(val) - assert.Error(err, ErrInvalidSession.Error()) + assert.Nil(t, val) + assert.ErrorIs(t, ErrInvalidSession, err) + + key := "11IHy6S2uBuKaNnTUszB218L898ikGY1" + val, err = str.GetAll(key) + assert.Nil(t, val) + assert.ErrorIs(t, ErrInvalidSession, err) } func TestGetAll(t *testing.T) { - assert := assert.New(t) key := "6dIHy6S2uBuKaNnTUszB218L898ikGY1" field1 := "somekey" value1 := 100 @@ -174,122 +164,110 @@ func TestGetAll(t *testing.T) { // Set a key err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2, field3, value3).Err() - assert.NoError(err) + assert.NoError(t, err) str := New(context.TODO(), client) vals, err := str.GetAll(key) - assert.NoError(err) - assert.Contains(vals, field1) - assert.Contains(vals, field2) - assert.Contains(vals, field3) + assert.NoError(t, err) + assert.Contains(t, vals, field1) + assert.Contains(t, vals, field2) + assert.Contains(t, vals, field3) val1, err := str.Int(vals[field1], nil) - assert.NoError(err) - assert.Equal(val1, value1) + assert.NoError(t, err) + assert.Equal(t, val1, value1) val2, err := str.String(vals[field2], nil) - assert.NoError(err) - assert.Equal(val2, value2) + assert.NoError(t, err) + assert.Equal(t, val2, value2) val3, err := str.Float64(vals[field3], nil) - assert.NoError(err) - assert.Equal(val3, value3) + assert.NoError(t, err) + assert.Equal(t, val3, value3) } func TestSetInvalidSessionError(t *testing.T) { - assert := assert.New(t) str := New(context.TODO(), getRedisClient()) - err := str.Set("invalidid", "key", "value") - assert.Error(err, ErrInvalidSession.Error()) + assert.ErrorIs(t, ErrInvalidSession, err) } func TestSet(t *testing.T) { // Test should only set in internal map and not in redis - assert := assert.New(t) client := getRedisClient() str := New(context.TODO(), client) + ttl := time.Second * 10 + str.SetTTL(ttl) // this key is unique across all tests key := "7dIHy6S2uBuKaNnTUszB218L898ikGY9" field := "somekey" value := 100 - assert.NotNil(str.tempSetMap) - assert.NotContains(str.tempSetMap, key) - err := str.Set(key, field, value) - assert.NoError(err) - assert.Contains(str.tempSetMap, key) - assert.Contains(str.tempSetMap[key], field) - assert.Equal(str.tempSetMap[key][field], value) + assert.NoError(t, err) // Check ifs not commited to redis - val, err := client.Exists(context.TODO(), defaultPrefix+key).Result() - assert.NoError(err) - assert.Equal(val, int64(0)) -} - -func TestCommitInvalidSessionError(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) + v1, err := client.Exists(context.TODO(), defaultPrefix+key).Result() + assert.NoError(t, err) + assert.Equal(t, int64(1), v1) - err := str.Commit("invalidkey") - assert.Error(err, ErrInvalidSession.Error()) -} - -func TestEmptyCommit(t *testing.T) { - assert := assert.New(t) - str := New(context.TODO(), getRedisClient()) + v2, err := str.Int(client.HGet(context.TODO(), defaultPrefix+key, field).Result()) + assert.NoError(t, err) + assert.Equal(t, value, v2) - err := str.Commit("15IHy6S2uBuKaNnTUszB2180898ikGY1") - assert.NoError(err) + dur, err := client.TTL(context.TODO(), defaultPrefix+key).Result() + assert.NoError(t, err) + assert.Equal(t, dur, ttl) } -func TestCommit(t *testing.T) { - // Test should commit in redis with expiry on key - assert := assert.New(t) +func TestSetMulti(t *testing.T) { + // Test should only set in internal map and not in redis client := getRedisClient() str := New(context.TODO(), client) - - str.SetTTL(10 * time.Second) + ttl := time.Second * 10 + str.SetTTL(ttl) // this key is unique across all tests - key := "8dIHy6S2uBuKaNnTUszB2180898ikGY1" - field1 := "somekey" + key := "7dIHy6S2uBuKaNnTUszB218L898ikGY9" + field1 := "somekey1" value1 := 100 - field2 := "someotherkey" - value2 := "abc123" - - err := str.Set(key, field1, value1) - assert.NoError(err) + field2 := "somekey2" + value2 := "somevalue" - err = str.Set(key, field2, value2) - assert.NoError(err) + err := str.SetMulti(key, map[string]interface{}{ + field1: value1, + field2: value2, + }) + assert.NoError(t, err) - err = str.Commit(key) - assert.NoError(err) + // Check ifs not commited to redis + v1, err := client.Exists(context.TODO(), defaultPrefix+key).Result() + assert.NoError(t, err) + assert.Equal(t, int64(1), v1) - vals, err := client.HGetAll(context.TODO(), defaultPrefix+key).Result() - assert.Equal(2, len(vals)) + v2, err := str.Int(client.HGet(context.TODO(), defaultPrefix+key, field1).Result()) + assert.NoError(t, err) + assert.Equal(t, value1, v2) - ttl, err := client.TTL(context.TODO(), defaultPrefix+key).Result() - assert.NoError(err) - assert.Equal(true, ttl.Seconds() > 0 && ttl.Seconds() <= 10) + dur, err := client.TTL(context.TODO(), defaultPrefix+key).Result() + assert.NoError(t, err) + assert.Equal(t, dur, ttl) } func TestDeleteInvalidSessionError(t *testing.T) { - assert := assert.New(t) str := New(context.TODO(), getRedisClient()) - err := str.Delete("invalidkey", "somefield") - assert.Error(err, ErrInvalidSession.Error()) + assert.ErrorIs(t, ErrInvalidSession, err) + + str = New(context.TODO(), getRedisClient()) + err = str.Delete("8dIHy6S2uBuKaNnTUszB2180898ikGY1", "somefield") + assert.ErrorIs(t, ErrInvalidSession, err) } func TestDelete(t *testing.T) { // Test should only set in internal map and not in redis - assert := assert.New(t) client := getRedisClient() str := New(context.TODO(), client) @@ -301,29 +279,31 @@ func TestDelete(t *testing.T) { value2 := "abc123" err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2).Err() - assert.NoError(err) + assert.NoError(t, err) err = str.Delete(key, field1) - assert.NoError(err) + assert.NoError(t, err) val, err := client.HExists(context.TODO(), defaultPrefix+key, field1).Result() - assert.False(val) + assert.False(t, val) + assert.NoError(t, err) val, err = client.HExists(context.TODO(), defaultPrefix+key, field2).Result() - assert.True(val) + assert.True(t, val) + assert.NoError(t, err) + + err = str.Delete(key, "xxxxx") + assert.ErrorIs(t, err, ErrFieldNotFound) } func TestClearInvalidSessionError(t *testing.T) { - assert := assert.New(t) str := New(context.TODO(), getRedisClient()) - err := str.Clear("invalidkey") - assert.Error(err, ErrInvalidSession.Error()) + assert.ErrorIs(t, ErrInvalidSession, err) } func TestClear(t *testing.T) { // Test should only set in internal map and not in redis - assert := assert.New(t) client := getRedisClient() str := New(context.TODO(), client) @@ -335,23 +315,22 @@ func TestClear(t *testing.T) { value2 := "abc123" err := client.HMSet(context.TODO(), defaultPrefix+key, field1, value1, field2, value2).Err() - assert.NoError(err) + assert.NoError(t, err) // Check if its set val, err := client.Exists(context.TODO(), defaultPrefix+key).Result() - assert.NoError(err) - assert.NotEqual(val, int64(0)) + assert.NoError(t, err) + assert.NotEqual(t, val, int64(0)) err = str.Clear(key) - assert.NoError(err) + assert.NoError(t, err) val, err = client.Exists(context.TODO(), defaultPrefix+key).Result() - assert.NoError(err) - assert.Equal(val, int64(0)) + assert.NoError(t, err) + assert.Equal(t, val, int64(0)) } func TestInt(t *testing.T) { - assert := assert.New(t) client := getRedisClient() str := New(context.TODO(), client) @@ -359,19 +338,18 @@ func TestInt(t *testing.T) { value := 100 err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(err) + assert.NoError(t, err) val, err := str.Int(client.Get(context.TODO(), field).Result()) - assert.NoError(err) - assert.Equal(value, val) + assert.NoError(t, err) + assert.Equal(t, value, val) testError := errors.New("test error") - val, err = str.Int(value, testError) - assert.Error(err, testError.Error()) + _, err = str.Int(value, testError) + assert.ErrorIs(t, testError, err) } func TestInt64(t *testing.T) { - assert := assert.New(t) client := getRedisClient() str := New(context.TODO(), client) @@ -379,19 +357,18 @@ func TestInt64(t *testing.T) { var value int64 = 100 err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(err) + assert.NoError(t, err) val, err := str.Int64(client.Get(context.TODO(), field).Result()) - assert.NoError(err) - assert.Equal(value, val) + assert.NoError(t, err) + assert.Equal(t, value, val) testError := errors.New("test error") - val, err = str.Int64(value, testError) - assert.Error(err, testError.Error()) + _, err = str.Int64(value, testError) + assert.ErrorIs(t, testError, err) } func TestUInt64(t *testing.T) { - assert := assert.New(t) client := getRedisClient() str := New(context.TODO(), client) @@ -399,19 +376,18 @@ func TestUInt64(t *testing.T) { var value uint64 = 100 err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(err) + assert.NoError(t, err) val, err := str.UInt64(client.Get(context.TODO(), field).Result()) - assert.NoError(err) - assert.Equal(value, val) + assert.NoError(t, err) + assert.Equal(t, value, val) testError := errors.New("test error") - val, err = str.UInt64(value, testError) - assert.Error(err, testError.Error()) + _, err = str.UInt64(value, testError) + assert.ErrorIs(t, testError, err) } func TestFloat64(t *testing.T) { - assert := assert.New(t) client := getRedisClient() str := New(context.TODO(), client) @@ -419,19 +395,18 @@ func TestFloat64(t *testing.T) { var value float64 = 100 err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(err) + assert.NoError(t, err) val, err := str.Float64(client.Get(context.TODO(), field).Result()) - assert.NoError(err) - assert.Equal(value, val) + assert.NoError(t, err) + assert.Equal(t, value, val) testError := errors.New("test error") - val, err = str.Float64(value, testError) - assert.Error(err, testError.Error()) + _, err = str.Float64(value, testError) + assert.ErrorIs(t, testError, err) } func TestString(t *testing.T) { - assert := assert.New(t) client := getRedisClient() str := New(context.TODO(), client) @@ -439,19 +414,18 @@ func TestString(t *testing.T) { value := "abc123" err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(err) + assert.NoError(t, err) val, err := str.String(client.Get(context.TODO(), field).Result()) - assert.NoError(err) - assert.Equal(value, val) + assert.NoError(t, err) + assert.Equal(t, value, val) testError := errors.New("test error") - val, err = str.String(value, testError) - assert.Error(err, testError.Error()) + _, err = str.String(value, testError) + assert.ErrorIs(t, testError, err) } func TestBytes(t *testing.T) { - assert := assert.New(t) client := getRedisClient() str := New(context.TODO(), client) @@ -459,19 +433,18 @@ func TestBytes(t *testing.T) { var value []byte = []byte("abc123") err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(err) + assert.NoError(t, err) val, err := str.Bytes(client.Get(context.TODO(), field).Result()) - assert.NoError(err) - assert.Equal(value, val) + assert.NoError(t, err) + assert.Equal(t, value, val) testError := errors.New("test error") - val, err = str.Bytes(value, testError) - assert.Error(err, testError.Error()) + _, err = str.Bytes(value, testError) + assert.ErrorIs(t, testError, err) } func TestBool(t *testing.T) { - assert := assert.New(t) client := getRedisClient() str := New(context.TODO(), client) @@ -479,13 +452,26 @@ func TestBool(t *testing.T) { value := true err := client.Set(context.TODO(), field, value, 0).Err() - assert.NoError(err) + assert.NoError(t, err) val, err := str.Bool(client.Get(context.TODO(), field).Result()) - assert.NoError(err) - assert.Equal(value, val) + assert.NoError(t, err) + assert.Equal(t, value, val) testError := errors.New("test error") - val, err = str.Bool(value, testError) - assert.Error(err, testError.Error()) + _, err = str.Bool(value, testError) + assert.ErrorIs(t, testError, err) +} + +func TestValidateID(t *testing.T) { + ok := validateID("xxxx") + assert.False(t, ok) + + ok = validateID("8dIHy6S2uBuKaNnTUszB2180898ikGY&") + assert.False(t, ok) + + id, err := generateID(sessionIDLen) + assert.NoError(t, err) + ok = validateID(id) + assert.True(t, ok) }