Skip to content

Commit

Permalink
use new cache in handle
Browse files Browse the repository at this point in the history
  • Loading branch information
cottand committed Dec 18, 2024
1 parent 331f23a commit c88d607
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 99 deletions.
177 changes: 81 additions & 96 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"github.com/cottand/leng/internal/metric"
"github.com/cottand/leng/lcache"
"github.com/miekg/dns"
"net"
"slices"
Expand Down Expand Up @@ -30,9 +31,9 @@ func (q *Question) String() string {
type EventLoop struct {
requestChannel chan DNSOperationData
resolver *Resolver
cache Cache
cache lcache.Cache[lcache.DefaultEntry]
// negCache caches failures
negCache Cache
negCache lcache.Cache[lcache.DefaultEntry]
active bool
muActive sync.RWMutex
config *Config
Expand All @@ -52,17 +53,12 @@ func NewEventLoop(config *Config, blockCache *MemoryBlockCache) *EventLoop {
var (
clientConfig *dns.ClientConfig
resolver *Resolver
cache Cache
negCache Cache
)

resolver = &Resolver{clientConfig}

//cache = lcache.NewDefault(int64(config.Upstream.Maxcount))
negCache = &MemoryCache{
Backend: make(map[string]*Mesg),
Maxcount: config.Upstream.Maxcount,
}
cache := lcache.NewDefault(config.Upstream.Maxcount)
negCache := lcache.NewDefault(config.Upstream.Maxcount)

handler := &EventLoop{
requestChannel: make(chan DNSOperationData),
Expand Down Expand Up @@ -91,7 +87,7 @@ func (h *EventLoop) do() {
}

// responseFor has side-effects, like writing to h's caches, so avoid calling it concurrently
func (h *EventLoop) responseFor(Net string, req *dns.Msg, _local net.Addr, _remote net.Addr) (_ *dns.Msg, success bool) {
func (h *EventLoop) responseFor(Net string, req *dns.Msg, _local net.Addr, _remote net.Addr) (resp *dns.Msg, success bool, blocked bool, cached bool) {

var remote net.IP
if Net == "tcp" || Net == "http" {
Expand All @@ -102,157 +98,112 @@ func (h *EventLoop) responseFor(Net string, req *dns.Msg, _local net.Addr, _remo

// first of all, check custom DNS. No need to cache it because it is already in-mem and precedes the blocking
if custom := h.customDns.Resolve(req, _local, _remote); custom != nil {
return custom, true
return custom, true, false, true
}

// does not include custom DNS
defer metric.ReportDNSRespond(remote, resp, blocked, cached)

q := req.Question[0]
Q := Question{UnFqdn(q.Name), dns.TypeToString[q.Qtype], dns.ClassToString[q.Qclass]}
logger.Infof("%s lookup %s\n", remote, Q.String())

IPQuery := h.isIPQuery(q)

blocked = IPQuery > 0 && lengActive && h.blockCache.Exists(Q.Qname)
if blocked {
resp = h.blockedResponseFor(req, IPQuery)

logger.Noticef("%s found in blocklist\n", Q.Qname)
return resp, true, blocked, false
}

// Only query cache when qtype == 'A'|'AAAA' , qclass == 'IN'
key := KeyGen(Q)
if IPQuery > 0 {
mesg, blocked, err := h.cache.Get(key)
mesg, err := h.cache.Get(key)
if err != nil {
if mesg, blocked, err = h.negCache.Get(key); err != nil {
if mesg, err = h.negCache.Get(key); err != nil {
logger.Debugf("%s didn't hit cache\n", Q.String())
} else {
logger.Debugf("%s hit negative cache\n", Q.String())
return nil, false
return nil, false, true, false
}
} else {
if blocked && !lengActive {
logger.Debugf("%s hit cache and was blocked: forwarding request\n", Q.String())
} else {
logger.Debugf("%s hit cache\n", Q.String())
cached = true
logger.Debugf("%s hit cache\n", Q.String())

// we need this copy against concurrent modification of ID
msg := *mesg
msg.Id = req.Id
// we need this copy against concurrent modification of ID
msg := *mesg
msg.Id = req.Id

defer metric.ReportDNSRespond(remote, &msg, blocked, true)
return &msg, true
}
return &msg.Msg, true, blocked, cached
}
}
// Check blocklist
var blacklisted = false

if IPQuery > 0 {
blacklisted = h.blockCache.Exists(Q.Qname)
cached = false

if lengActive && blacklisted {
m := new(dns.Msg)
m.SetReply(req)

if h.config.Blocking.NXDomain {
m.SetRcode(req, dns.RcodeNameError)
} else {
nullroute := net.ParseIP(h.config.Blocking.Nullroute)
nullroutev6 := net.ParseIP(h.config.Blocking.Nullroutev6)

switch IPQuery {
case _IP4Query:
rrHeader := dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: h.config.TTL,
}
a := &dns.A{Hdr: rrHeader, A: nullroute}
m.Answer = append(m.Answer, a)
case _IP6Query:
rrHeader := dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: h.config.TTL,
}
a := &dns.AAAA{Hdr: rrHeader, AAAA: nullroutev6}
m.Answer = append(m.Answer, a)
}
}

defer metric.ReportDNSRespond(remote, m, true, false)

logger.Noticef("%s found in blocklist\n", Q.Qname)

// cache the block; we don't know the true TTL for blocked entries: we just enforce our config
err := h.cache.Set(key, m, true)
if err != nil {
logger.Errorf("Set %s block cache failed: %s\n", Q.String(), err.Error())
}

return m, true
}
logger.Debugf("%s not found in blocklist\n", Q.Qname)
}

mesg, err := h.resolver.Lookup(Net, req, h.config.Timeout, h.config.Interval, h.config.Upstream.Nameservers, h.config.Upstream.DoH)
resp, err := h.resolver.Lookup(Net, req, h.config.Timeout, h.config.Interval, h.config.Upstream.Nameservers, h.config.Upstream.DoH)

if err != nil {
logger.Errorf("resolve query error %s\n", err)

// cache the failure, too!
if err = h.negCache.Set(key, nil, false); err != nil {
// TODO set TTL for failed errors
if err = h.negCache.Set(key, &lcache.DefaultEntry{}); err != nil {
logger.Errorf("set %s negative cache failed: %v\n", Q.String(), err)
}
return nil, false
return nil, false, blocked, cached
}

// if we were doing DNS over UDP, and we got a truncated response,
// we retry in TCP in hopes that we do not get a truncated one again.
if mesg.Truncated && Net == "udp" {
mesg, err = h.resolver.Lookup("tcp", req, h.config.Timeout, h.config.Interval, h.config.Upstream.Nameservers, h.config.Upstream.DoH)
if resp.Truncated && Net == "udp" {
resp, err = h.resolver.Lookup("tcp", req, h.config.Timeout, h.config.Interval, h.config.Upstream.Nameservers, h.config.Upstream.DoH)
if err != nil {
logger.Errorf("resolve tcp query error %s\n", err)

// cache the failure, too!
if err = h.negCache.Set(key, nil, false); err != nil {
// TODO set TTL for failed errors
if err = h.negCache.Set(key, &lcache.DefaultEntry{}); err != nil {
logger.Errorf("set %s negative cache failed: %v\n", Q.String(), err)
}
return nil, false
return nil, false, blocked, cached
}
}

//find the smallest ttl
ttl := h.config.Upstream.Expire
var candidateTTL uint32

for index, answer := range mesg.Answer {
for index, answer := range resp.Answer {
logger.Debugf("Answer %d - %s\n", index, answer.String())

candidateTTL = answer.Header().Ttl

// TODO is a zero TTL a forever TTL??
if candidateTTL > 0 && candidateTTL < ttl {
ttl = candidateTTL
}
}

defer metric.ReportDNSRespond(remote, mesg, false, false)

if IPQuery > 0 && len(mesg.Answer) > 0 {
if !lengActive && blacklisted {
logger.Debugf("%s is blacklisted and leng not active: not caching\n", Q.String())
} else {
err = h.cache.Set(key, mesg, false)
if IPQuery > 0 && len(resp.Answer) > 0 {
go func() {
err := h.cache.Set(key, &lcache.DefaultEntry{Msg: *resp})
if err != nil {
logger.Errorf("set %s cache failed: %s\n", Q.String(), err.Error())
logger.Warningf("set %s cache failed: %v\n", Q.String(), err)
}
logger.Debugf("insert %s into cache with ttl %d\n", Q.String(), ttl)
}
}()
}
return mesg, true
return resp, true, blocked, cached
}

func (h *EventLoop) doRequest(Net string, w dns.ResponseWriter, req *dns.Msg) {
defer func(w dns.ResponseWriter) {
_ = w.Close()
}(w)

resp, ok := h.responseFor(Net, req, w.LocalAddr(), w.RemoteAddr())
resp, ok, _, _ := h.responseFor(Net, req, w.LocalAddr(), w.RemoteAddr())

if !ok {
m := new(dns.Msg)
Expand All @@ -272,7 +223,7 @@ func (h *EventLoop) doRequest(Net string, w dns.ResponseWriter, req *dns.Msg) {
for _, cname := range cnames {
r := dns.Msg{}
r.SetQuestion(cname.Target, req.Question[0].Qtype)
followed, ok := h.responseFor(Net, &r, w.LocalAddr(), w.RemoteAddr())
followed, ok, _, _ := h.responseFor(Net, &r, w.LocalAddr(), w.RemoteAddr())
for _, fAnswer := range followed.Answer {
containsNewAnswer := func(rr dns.RR) bool {
return rr.String() == fAnswer.String()
Expand Down Expand Up @@ -372,6 +323,40 @@ func (h *EventLoop) isIPQuery(q dns.Question) int {
return notIPQuery
}
}
func (h *EventLoop) blockedResponseFor(req *dns.Msg, IPQuery int) *dns.Msg {
m := new(dns.Msg)
m.SetReply(req)
q := req.Question[0]

if h.config.Blocking.NXDomain {
m.SetRcode(req, dns.RcodeNameError)
} else {
nullroute := net.ParseIP(h.config.Blocking.Nullroute)
nullroutev6 := net.ParseIP(h.config.Blocking.Nullroutev6)

switch IPQuery {
case _IP4Query:
rrHeader := dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: h.config.TTL,
}
a := &dns.A{Hdr: rrHeader, A: nullroute}
m.Answer = append(m.Answer, a)
case _IP6Query:
rrHeader := dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: h.config.TTL,
}
a := &dns.AAAA{Hdr: rrHeader, AAAA: nullroutev6}
m.Answer = append(m.Answer, a)
}
}
return m
}

// UnFqdn function
func UnFqdn(s string) string {
Expand Down
4 changes: 2 additions & 2 deletions lcache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ type lengCache[E Entry] struct {

// NewGeneric creates a new Cache
// maxSize <= 0 means the cache is unbounded
func NewGeneric[E Entry](maxSize int64) Cache[E] {
func NewGeneric[E Entry](maxSize int) Cache[E] {
return &lengCache[E]{
backend: sync.Map{},
size: atomic.Int64{},
maxSize: maxSize,
maxSize: int64(maxSize),
}
}

Expand Down
2 changes: 1 addition & 1 deletion lcache/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ func (dnsEntry DefaultEntry) RRs() []dns.RR {
return dnsEntry.Answer
}

func NewDefault(maxSize int64) Cache[DefaultEntry] {
func NewDefault(maxSize int) Cache[DefaultEntry] {
return NewGeneric[DefaultEntry](maxSize)
}

0 comments on commit c88d607

Please sign in to comment.