diff --git a/cache.go b/cache.go index 319c653..2499e01 100644 --- a/cache.go +++ b/cache.go @@ -22,6 +22,8 @@ type Cache struct { segments [segmentCount]segment } +type Updater func(value []byte, found bool) (newValue []byte, replace bool, expireSeconds int) + func hashFunc(data []byte) uint64 { return xxhash.Sum64(data) } @@ -136,6 +138,32 @@ func (cache *Cache) SetAndGet(key, value []byte, expireSeconds int) (retValue [] return } +// Update gets value for a key, passes it to updater function that decides if set should be called as well +// This allows for an atomic Get plus Set call using the existing value to decide on whether to call Set. +// If the key is larger than 65535 or value is larger than 1/1024 of the cache size, +// the entry will not be written to the cache. expireSeconds <= 0 means no expire, +// but it can be evicted when cache is full. Returns bool value to indicate if existing record was found along with bool +// value indicating the value was replaced and error if any +func (cache *Cache) Update(key []byte, updater Updater) (found bool, replaced bool, err error) { + hashVal := hashFunc(key) + segID := hashVal & segmentAndOpVal + cache.locks[segID].Lock() + defer cache.locks[segID].Unlock() + + retValue, _, err := cache.segments[segID].get(key, nil, hashVal, false) + if err == nil { + found = true + } else { + err = nil // Clear ErrNotFound error since we're returning found flag + } + value, replaced, expireSeconds := updater(retValue, found) + if !replaced { + return + } + err = cache.segments[segID].set(key, value, hashVal, expireSeconds) + return +} + // Peek returns the value or not found error, without updating access time or counters. func (cache *Cache) Peek(key []byte) (value []byte, err error) { hashVal := hashFunc(key) diff --git a/cache_test.go b/cache_test.go index cea110b..6082299 100644 --- a/cache_test.go +++ b/cache_test.go @@ -668,7 +668,7 @@ func TestRace(t *testing.T) { getFunc := func() { var i int64 for i = 0; i < iters; i++ { - _, _ = cache.GetInt(int64(mrand.Intn(inUse))) //it will likely error w/ delFunc running too + _, _ = cache.GetInt(int64(mrand.Intn(inUse))) // it will likely error w/ delFunc running too } wg.Done() } @@ -1017,3 +1017,72 @@ func TestSetAndGet(t *testing.T) { t.Fatalf("SetAndGet expected SetAndGet %s: got %s", string(val1), string(rval)) } } + +func TestUpdate(t *testing.T) { + testName := "Update" + cache := NewCache(1024) + key := []byte("abcd") + val1 := []byte("efgh") + val2 := []byte("ijkl") + + var found, replaced bool + var err error + var prevVal, updaterVal []byte + updaterReplace := false + expireSeconds := 123 + + updater := func(value []byte, found bool) ([]byte, bool, int) { + prevVal = value + return updaterVal, updaterReplace, expireSeconds + } + + setUpdaterResponse := func(value []byte, replace bool) { + updaterVal = value + updaterReplace = replace + } + + assertExpectations := func(testCase int, expectedFound, expectedReplaced bool, expectedPrevVal []byte, expectedVal []byte) { + failPrefix := fmt.Sprintf("%s(%d)", testName, testCase) + + if expectedFound != found { + t.Fatalf("%s found should be %v", failPrefix, expectedFound) + } + if expectedReplaced != replaced { + t.Fatalf("%s found should be %v", failPrefix, expectedReplaced) + } + if err != nil { + t.Fatalf("%s unexpected err %v", failPrefix, err) + } + if string(prevVal) != string(expectedPrevVal) { + t.Fatalf("%s previous value expected %s instead of %s", failPrefix, string(expectedPrevVal), string(prevVal)) + } + + // Check value + value, err := cache.Get(key) + if err == ErrNotFound && expectedVal != nil { + t.Fatalf("%s previous value expected %s instead of nil", failPrefix, string(expectedVal)) + } + if string(value) != string(expectedVal) { + t.Fatalf("%s previous value expected %s instead of %s", failPrefix, string(expectedVal), string(value)) + } + } + + // Doesn't exist yet, decide not to update, set should not be called + found, replaced, err = cache.Update(key, updater) + assertExpectations(1, false, false, nil, nil) + + // Doesn't exist yet, decide to update, set should be called with new value + setUpdaterResponse(val1, true) + found, replaced, err = cache.Update(key, updater) + assertExpectations(2, false, true, nil, val1) + + // Key exists, decide to update, updater is given old value and set should be called with new value + setUpdaterResponse(val2, true) + found, replaced, err = cache.Update(key, updater) + assertExpectations(3, true, true, val1, val2) + + // Key exists, decide not to update, updater is given old value and set should not be called + setUpdaterResponse(val1, false) + found, replaced, err = cache.Update(key, updater) + assertExpectations(4, true, false, val2, val2) +}