From d7c0bf2a1ff0837f5378667349fb023bb28446b7 Mon Sep 17 00:00:00 2001 From: Cottand Date: Wed, 18 Dec 2024 16:14:08 +0000 Subject: [PATCH] refacto handler to use new cache --- grimd_test.go | 12 +++++------- handler.go | 16 +++++++++++++--- server.go | 2 +- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/grimd_test.go b/grimd_test.go index c96c461..a10f025 100644 --- a/grimd_test.go +++ b/grimd_test.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/cottand/leng/internal/metric" "github.com/pelletier/go-toml/v2" + "github.com/stretchr/testify/assert" "io" "net/http" "slices" @@ -389,14 +390,11 @@ func TestConfigReloadForCustomRecords(t *testing.T) { m1 = new(dns.Msg) m1.SetQuestion(dns.Fqdn("old.com_custom"), dns.TypeA) + // no caching available, so this request requires internet! reply, _, err = c.Exchange(m1, testDnsHost) - if err != nil { - fmt.Printf("Err was %v - expected this", err) - t.FailNow() - } - if len(reply.Answer) != 0 { - t.Fatalf("expected old.com_custom DNS to fail, but got %v", reply) - } + // no response but no error + assert.NoError(t, err) + assert.Len(t, reply.Answer, 0) m1 = new(dns.Msg) m1.SetQuestion(dns.Fqdn("boo.org"), dns.TypeA) diff --git a/handler.go b/handler.go index 377ca76..a8d0dec 100644 --- a/handler.go +++ b/handler.go @@ -101,13 +101,23 @@ func (h *EventLoop) responseFor(Net string, req *dns.Msg, _local net.Addr, _remo return custom, true, false, true } - // does not include custom DNS - defer metric.ReportDNSRespond(remote, resp, blocked, cached) + if len(req.Question) < 1 { + return nil, false, false, false + } q := req.Question[0] Q := Question{UnFqdn(q.Name), dns.TypeToString[q.Qtype], dns.ClassToString[q.Qclass]} logger.Infof("%s lookup %s\n", remote, Q.String()) + defer func() { + if resp != nil { + resp.SetReply(req) + } + if resp != nil && remote != nil { + metric.ReportDNSRespond(remote, resp, blocked, cached) + } + }() + IPQuery := h.isIPQuery(q) blocked = IPQuery > 0 && lengActive && h.blockCache.Exists(Q.Qname) @@ -207,6 +217,7 @@ func (h *EventLoop) doRequest(Net string, w dns.ResponseWriter, req *dns.Msg) { if !ok { m := new(dns.Msg) + m.SetReply(req) m.SetRcode(req, dns.RcodeServerFailure) WriteReplyMsg(w, m) metric.ReportDNSResponse(w, m, false) @@ -325,7 +336,6 @@ func (h *EventLoop) isIPQuery(q dns.Question) int { } func (h *EventLoop) blockedResponseFor(req *dns.Msg, IPQuery int) *dns.Msg { m := new(dns.Msg) - m.SetReply(req) q := req.Question[0] if h.config.Blocking.NXDomain { diff --git a/server.go b/server.go index 025101b..afc320c 100644 --- a/server.go +++ b/server.go @@ -119,5 +119,5 @@ func (s *Server) Stop() { func (s *Server) ReloadConfig(config *Config) { newRecords := NewCustomDNSRecordsFromText(config.CustomDNSRecords) s.eventLoop.customDns = NewCustomRecordsResolver(newRecords) - defer metric.CustomDNSConfigReload.Inc() + metric.CustomDNSConfigReload.Inc() }