Skip to content

Commit

Permalink
feat: new cache implementation (#82)
Browse files Browse the repository at this point in the history
This is a non-functional change as this cache is unused

But the idea is that it may replace the existing MemoryCache and
BlockCache to serve concurrent requests in the future
  • Loading branch information
cottand authored Dec 17, 2024
1 parent b5b2ddb commit 0bb5cc2
Show file tree
Hide file tree
Showing 5 changed files with 602 additions and 4 deletions.
99 changes: 99 additions & 0 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"regexp"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/miekg/dns"
Expand Down Expand Up @@ -57,6 +58,104 @@ type Cache interface {
Exists(key string) bool
Remove(key string)
Length() int
Full() bool
}

type lengCache struct {
backend sync.Map // of string -> *Mesg
size atomic.Int64
full bool
maxSize int64
}

func NewCache(maxSize int64) Cache {
return &lengCache{
backend: sync.Map{},
size: atomic.Int64{},
maxSize: maxSize,
}
}

func (c *lengCache) Get(key string) (Msg *dns.Msg, blocked bool, err error) {
key = strings.ToLower(key)

//Truncate time to the second, so that subsecond queries won't keep moving
//forward the last update time without touching the TTL
now := WallClock.Now().Truncate(time.Second)
expired := false
existing, ok := c.backend.Load(key)
mesg := existing.(*Mesg)
if ok && mesg.Msg == nil {
ok = false
logger.Warningf("Cache: key %s returned nil entry", key)
c.Remove(key)
}

if ok {
elapsed := uint32(now.Sub(mesg.LastUpdateTime).Seconds())
for _, answer := range mesg.Msg.Answer {
if elapsed > answer.Header().Ttl {
logger.Debugf("Cache: Key expired %s", key)
c.Remove(key)
expired = true
}
answer.Header().Ttl -= elapsed
}
}

if !ok {
logger.Debugf("Cache: Cannot find key %s\n", key)
return nil, false, KeyNotFound{key}
}

if expired {
return nil, false, KeyExpired{key}
}

mesg.LastUpdateTime = now

return mesg.Msg, mesg.Blocked, nil
}

func (c *lengCache) Set(key string, msg *dns.Msg, blocked bool) error {
key = strings.ToLower(key)

if c.Full() && !c.Exists(key) {
return CacheIsFull{}
}
if msg == nil {
logger.Debugf("Setting an empty value for key %s", key)
}
c.backend.Store(key, &Mesg{msg, blocked, WallClock.Now().Truncate(time.Second)})
return nil
}

func (c *lengCache) Exists(key string) bool {
_, ok := c.backend.Load(key)
return ok
}

func (c *lengCache) Remove(key string) {
_, loaded := c.backend.LoadAndDelete(key)
if loaded {
newSize := c.size.Add(-1)
if newSize < c.maxSize {
c.full = false
}
}
}

func (c *lengCache) Length() int {
size := c.size.Load()
c.full = size > c.maxSize
return int(size)
}

func (c *lengCache) Full() bool {
if c.maxSize > 0 {
return c.full
}
return false
}

// MemoryCache type
Expand Down
13 changes: 9 additions & 4 deletions cache_test.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package main

import (
"errors"
"fmt"
"net"
"regexp"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -250,8 +252,8 @@ func TestCacheTtlFrequentPolling(t *testing.T) {

}

/*
func TestExpirationRace(t *testing.T) {
t.Skip()
cache := makeCache()
fakeClock := clockwork.NewFakeClock()
WallClock = fakeClock
Expand Down Expand Up @@ -279,22 +281,25 @@ func TestExpirationRace(t *testing.T) {
}

for i := 0; i < 1000; i++ {
wg := &sync.WaitGroup{}
wg.Add(2)
fakeClock.Advance(time.Duration(100) * time.Millisecond)
go func() {
_, _, err := cache.Get(testDomain)
if err != nil {
if err != nil && !errors.Is(err, &KeyNotFound{}) {
t.Error(err)
}
wg.Done()
}()
go func() {
err := cache.Set(testDomain, m, true)
if err != nil {
t.Error(err)
}
wg.Done()
}()
}
}
*/

func BenchmarkSetCache(b *testing.B) {
cache := makeCache()
Expand All @@ -311,7 +316,7 @@ func BenchmarkSetCache(b *testing.B) {
}
}

func BenchmarkGetCache(b *testing.B) {
func BenchmarkGetCacheSingleDomain(b *testing.B) {
cache := makeCache()

m := new(dns.Msg)
Expand Down
147 changes: 147 additions & 0 deletions lcache/cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package lcache

import (
"github.com/jonboulle/clockwork"
"github.com/miekg/dns"
"github.com/op/go-logging"
"math"
"strings"
"sync"
"sync/atomic"
"time"
)

var logger = logging.MustGetLogger("test")

// wallClock is the wall clock
var wallClock = clockwork.NewRealClock()

// entry represents a cache entry
type entry struct {
Msg *dns.Msg
Blocked bool
expiresAt time.Time
mu sync.Mutex
}

// Cache interface
type Cache interface {
Get(key string) (Msg *dns.Msg, blocked bool, err error)
Set(key string, Msg *dns.Msg, blocked bool) error
Exists(key string) bool
Remove(key string)
Length() int
Full() bool
}

type lengCache struct {
backend sync.Map // of string -> *entry
size atomic.Int64
full bool
maxSize int64
}

func New(maxSize int64) Cache {
return &lengCache{
backend: sync.Map{},
size: atomic.Int64{},
maxSize: maxSize,
}
}

func (c *lengCache) Get(key string) (Msg *dns.Msg, blocked bool, err error) {
key = strings.ToLower(key)

existing, ok := c.backend.Load(key)
if !ok {
logger.Debugf("Cache: Cannot find key %s\n", key)
return nil, false, KeyNotFound{key}
}
mesg := existing.(*entry)
if mesg.Msg == nil {
return nil, mesg.Blocked, nil
}
mesg.mu.Lock()
defer mesg.mu.Unlock()
now := wallClock.Now()

// entry expired!
if now.After(mesg.expiresAt) {
c.Remove(key)
return nil, false, KeyExpired{key}
}
newTtl := uint32(mesg.expiresAt.Sub(now).Truncate(time.Second).Seconds())

for _, answer := range mesg.Msg.Answer {
// this can happen concurrently (and it is a concurrent write of shared memory),
// but it's ok because two concurrent modifications usually have the same result
// when rounded to the second
answer.Header().Ttl = newTtl
}

return mesg.Msg, mesg.Blocked, nil
}

func minTtlFor(msg *dns.Msg) time.Duration {
if msg == nil {
return 0
}
// find smallest ttl
minTtl := uint32(math.MaxUint32)
for _, answer := range msg.Answer {
msgTtl := answer.Header().Ttl
if minTtl > msgTtl {
minTtl = msgTtl
}
}
return time.Duration(minTtl) * time.Second
}

func (c *lengCache) Set(key string, msg *dns.Msg, blocked bool) error {
key = strings.ToLower(key)

if c.Full() && !c.Exists(key) {
return CacheIsFull{}
}
if msg == nil {
logger.Debugf("Setting an empty value for key %s", key)
}

now := wallClock.Now()
e := entry{
Msg: msg,
Blocked: blocked,
expiresAt: now.Add(minTtlFor(msg)),
}
c.backend.Store(key, &e)
return nil
}

func (c *lengCache) Exists(key string) bool {
key = strings.ToLower(key)
_, ok := c.backend.Load(key)
return ok
}

func (c *lengCache) Remove(key string) {
_, loaded := c.backend.LoadAndDelete(key)
if loaded {
newSize := c.size.Add(-1)
if newSize < c.maxSize {
c.full = false
}
}
}

func (c *lengCache) Length() int {
size := c.size.Load()
c.full = size > c.maxSize
return int(size)
}

func (c *lengCache) Full() bool {
if c.maxSize > 0 {
return c.full
}
return false
}
Loading

0 comments on commit 0bb5cc2

Please sign in to comment.