Skip to content

Commit

Permalink
fix: update goredis package to implement latest store interface
Browse files Browse the repository at this point in the history
  • Loading branch information
vividvilla committed May 27, 2024
1 parent 7f32488 commit 34a213a
Show file tree
Hide file tree
Showing 3 changed files with 293 additions and 250 deletions.
14 changes: 14 additions & 0 deletions conv/conv.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ func Int(r interface{}, err error) (int, error) {
return int(n), err
case nil:
return 0, ErrNil
case error:
return 0, r
}

return 0, ErrAssertType
Expand All @@ -74,6 +76,8 @@ func Int64(r interface{}, err error) (int64, error) {
return n, err
case nil:
return 0, ErrNil
case error:
return 0, r
}

return 0, ErrAssertType
Expand Down Expand Up @@ -106,6 +110,8 @@ func UInt64(r interface{}, err error) (uint64, error) {
return n, err
case nil:
return 0, ErrNil
case error:
return 0, r
}

return 0, ErrAssertType
Expand All @@ -127,6 +133,8 @@ func Float64(r interface{}, err error) (float64, error) {
return n, err
case nil:
return 0, ErrNil
case error:
return 0, r
}
return 0, ErrAssertType
}
Expand All @@ -143,6 +151,8 @@ func String(r interface{}, err error) (string, error) {
return r, nil
case nil:
return "", ErrNil
case error:
return "", r
}
return "", ErrAssertType
}
Expand All @@ -159,6 +169,8 @@ func Bytes(r interface{}, err error) ([]byte, error) {
return []byte(r), nil
case nil:
return nil, ErrNil
case error:
return nil, r
}
return nil, ErrAssertType
}
Expand All @@ -182,6 +194,8 @@ func Bool(r interface{}, err error) (bool, error) {
return strconv.ParseBool(r)
case nil:
return false, ErrNil
case error:
return false, r
}
return false, ErrAssertType
}
163 changes: 103 additions & 60 deletions stores/goredis/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package goredis
import (
"context"
"crypto/rand"
"sync"
"time"
"unicode"

Expand Down Expand Up @@ -42,10 +41,6 @@ type Store struct {
// Prefix for session id.
prefix string

// Temp map to store values before commit.
tempSetMap map[string]map[string]interface{}
mu sync.RWMutex

// Redis client
client redis.UniversalClient
clientCtx context.Context
Expand All @@ -60,10 +55,9 @@ const (
// New creates a new Redis store instance.
func New(ctx context.Context, client redis.UniversalClient) *Store {
return &Store{
clientCtx: ctx,
client: client,
prefix: defaultPrefix,
tempSetMap: make(map[string]map[string]interface{}),
clientCtx: ctx,
client: client,
prefix: defaultPrefix,
}
}

Expand Down Expand Up @@ -93,12 +87,29 @@ func (s *Store) Get(id, key string) (interface{}, error) {
return nil, ErrInvalidSession
}

v, err := s.client.HGet(s.clientCtx, s.prefix+id, key).Result()
if err == redis.Nil {
pipe := s.client.TxPipeline()
exists := pipe.Exists(s.clientCtx, s.prefix+id)
get := pipe.HGet(s.clientCtx, s.prefix+id, key)
_, err := pipe.Exec(s.clientCtx)
// redis.Nil is returned if a field does not exist.
// Ignore the error and check for key existence check.
if err != nil && err != redis.Nil {
return nil, err
}

// Check if key exists and return ErrInvalidSession if not.
if ex, err := exists.Result(); err != nil {
return nil, err
} else if ex == 0 {
return nil, ErrInvalidSession
}

v, err := get.Result()
if err != nil && err == redis.Nil {
return nil, ErrFieldNotFound
}

return v, err
return v, nil
}

// GetMulti gets a map for values for multiple keys. If key is not found then its set as nil.
Expand All @@ -107,16 +118,36 @@ func (s *Store) GetMulti(id string, keys ...string) (map[string]interface{}, err
return nil, ErrInvalidSession
}

v, err := s.client.HMGet(s.clientCtx, s.prefix+id, keys...).Result()
// If field is not found then return map with fields as nil
if len(v) == 0 || err == redis.Nil {
v = make([]interface{}, len(keys))
pipe := s.client.TxPipeline()
exists := pipe.Exists(s.clientCtx, s.prefix+id)
get := pipe.HMGet(s.clientCtx, s.prefix+id, keys...)
_, err := pipe.Exec(s.clientCtx)
// redis.Nil is returned if a field does not exist.
// Ignore the error and check for key existence check.
if err != nil && err != redis.Nil {
return nil, err
}

// Check if key exists and return ErrInvalidSession if not.
if ex, err := exists.Result(); err != nil {
return nil, err
} else if ex == 0 {
return nil, ErrInvalidSession
}

v, err := get.Result()
if err != nil {
return nil, err
}

// Form a map with returned results
res := make(map[string]interface{})
for i, k := range keys {
res[k] = v[i]
if v[i] == nil {
res[k] = ErrFieldNotFound
} else {
res[k] = v[i]
}
}

return res, err
Expand All @@ -128,10 +159,25 @@ func (s *Store) GetAll(id string) (map[string]interface{}, error) {
return nil, ErrInvalidSession
}

res, err := s.client.HGetAll(s.clientCtx, s.prefix+id).Result()
if res == nil || err == redis.Nil {
return map[string]interface{}{}, nil
} else if err != nil {
pipe := s.client.TxPipeline()
exists := pipe.Exists(s.clientCtx, s.prefix+id)
get := pipe.HGetAll(s.clientCtx, s.prefix+id)
_, err := pipe.Exec(s.clientCtx)
// redis.Nil is returned if a field does not exist.
// Ignore the error and check for key existence check.
if err != nil && err != redis.Nil {
return nil, err
}

// Check if key exists and return ErrInvalidSession if not.
if ex, err := exists.Result(); err != nil {
return nil, err
} else if ex == 0 {
return nil, ErrInvalidSession
}

res, err := get.Result()
if err != nil {
return nil, err
}

Expand All @@ -144,54 +190,37 @@ func (s *Store) GetAll(id string) (map[string]interface{}, error) {
return out, nil
}

// Set sets a value to given session but stored only on commit
// Set sets a value to given session.
func (s *Store) Set(id, key string, val interface{}) error {
if !validateID(id) {
return ErrInvalidSession
}

s.mu.Lock()
defer s.mu.Unlock()
pipe := s.client.TxPipeline()
pipe.HSet(s.clientCtx, s.prefix+id, key, val)

// Create session map if doesn't exist
if _, ok := s.tempSetMap[id]; !ok {
s.tempSetMap[id] = make(map[string]interface{})
// Set expiry of key only if 'ttl' is set, this is to
// ensure that the key remains valid indefinitely like
// how redis handles it by default
if s.ttl > 0 {
pipe.Expire(s.clientCtx, s.prefix+id, s.ttl)
}

// set value to map
s.tempSetMap[id][key] = val

return nil
_, err := pipe.Exec(s.clientCtx)
return err
}

// Commit sets all set values.
func (s *Store) Commit(id string) error {
// Set sets a value to given session.
func (s *Store) SetMulti(id string, data map[string]interface{}) error {
if !validateID(id) {
return ErrInvalidSession
}

s.mu.RLock()
vals, ok := s.tempSetMap[id]
if !ok {
// Nothing to commit
s.mu.RUnlock()
return nil
}

// Make slice of arguments to be passed in HGETALL command
args := make([]interface{}, len(vals)*2, len(vals)*2)
c := 0
for k, v := range s.tempSetMap[id] {
args[c] = k
args[c+1] = v
c += 2
args := []interface{}{}
for k, v := range data {
args = append(args, k, v)
}
s.mu.RUnlock()

// Clear temp map for given session id
s.mu.Lock()
delete(s.tempSetMap, id)
s.mu.Unlock()

pipe := s.client.TxPipeline()
pipe.HMSet(s.clientCtx, s.prefix+id, args...)
Expand All @@ -212,16 +241,30 @@ func (s *Store) Delete(id string, key string) error {
return ErrInvalidSession
}

// Clear temp map for given session id
s.mu.Lock()
delete(s.tempSetMap, id)
s.mu.Unlock()
pipe := s.client.TxPipeline()
exists := pipe.Exists(s.clientCtx, s.prefix+id)
del := pipe.HDel(s.clientCtx, s.prefix+id, key)
_, err := pipe.Exec(s.clientCtx)
// redis.Nil is returned if a field does not exist.
// Ignore the error and check for key existence check.
if err != nil && err != redis.Nil {
return err
}

err := s.client.HDel(s.clientCtx, s.prefix+id, key).Err()
if err == redis.Nil {
// Check if key exists and return ErrInvalidSession if not.
if ex, err := exists.Result(); err != nil {
return err
} else if ex == 0 {
return ErrInvalidSession
}

if v, err := del.Result(); err != nil {
return err
} else if v == 0 {
return ErrFieldNotFound
}
return err

return nil
}

// Clear clears session in redis.
Expand Down
Loading

0 comments on commit 34a213a

Please sign in to comment.