From fa782822b045fa55af336073a19af509654aa08d Mon Sep 17 00:00:00 2001 From: Liu Liming Date: Sun, 6 Dec 2020 11:52:49 +0800 Subject: [PATCH] Add retries while network failure, and optimize action while server failure. --- cache/cache.go | 4 ++-- common/common.go | 9 ++++---- common/type.go | 1 + diversion/diversion.go | 52 +++++++++++++++++++++++++++--------------- 4 files changed, 41 insertions(+), 25 deletions(-) diff --git a/cache/cache.go b/cache/cache.go index 43abe5b..e272310 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -45,7 +45,7 @@ func (dnsCache *Cache) UpdateItem(item *Item, msg *dnsmessage.Message) { } } -func (dnsCache *Cache) QueryAndUpdate(queryMsg *dnsmessage.Message, upstream *network.SocketAddr, maxPacketSize int, updateFunc func(*dnsmessage.Message, *network.SocketAddr, int) (*dnsmessage.Message, error)) (*dnsmessage.Message, error) { +func (dnsCache *Cache) QueryAndUpdate(queryMsg *dnsmessage.Message, upstream *network.SocketAddr, updateFunc func(*dnsmessage.Message, *network.SocketAddr) (*dnsmessage.Message, error)) (*dnsmessage.Message, error) { if queryMsg == nil || len(queryMsg.Questions) < 1 { return nil, errors.New("wrong dns message") } @@ -71,7 +71,7 @@ func (dnsCache *Cache) QueryAndUpdate(queryMsg *dnsmessage.Message, upstream *ne if common.NeedDebug() { logger.Debug("Cache Miss", question.Name, question.Class, question.Type) } - msg, err := updateFunc(queryMsg, upstream, maxPacketSize) + msg, err := updateFunc(queryMsg, upstream) if err != nil { return nil, err } diff --git a/common/common.go b/common/common.go index 5d1e351..4802e72 100644 --- a/common/common.go +++ b/common/common.go @@ -10,7 +10,7 @@ const StandardMaxDNSPacketSize = 512 var Config = &ConfigStruct{ Service: &ServiceConfig{ - ListenAddr: "127.0.0.1:53", + ListenAddr: "[::]:53", ListenUDP: true, ListenTCP: false, }, @@ -29,16 +29,17 @@ var Config = &ConfigStruct{ MinTTL: 10, }, Log: &LogConfig{ - LogFilePath: "", + LogFilePath: "accdns.log", LogFileMaxSizeKB: 16 * 1024, LogLevelForFile: "info", LogLevelForConsole: "info", }, Advanced: &AdvancedConfig{ - NSLookupTimeoutMs: 10000, - RWTimeoutMs: 6000, + NSLookupTimeoutMs: 20000, + RWTimeoutMs: 8000, MaxReceivedPacketSize: 4096, ConnectionTimeout: 60, + NetworkFailedRetries: 3, }, } diff --git a/common/type.go b/common/type.go index c7f154e..4020bd4 100644 --- a/common/type.go +++ b/common/type.go @@ -36,6 +36,7 @@ type AdvancedConfig struct { RWTimeoutMs int MaxReceivedPacketSize int ConnectionTimeout int + NetworkFailedRetries int } type CacheConfig struct { diff --git a/diversion/diversion.go b/diversion/diversion.go index 867b1f2..ee34743 100644 --- a/diversion/diversion.go +++ b/diversion/diversion.go @@ -83,9 +83,9 @@ func HandlePacket(bytes []byte, respCall func([]byte), dnsCache *cache.Cache) er var receivedMsg *dnsmessage.Message var err error if dnsCache != nil { - receivedMsg, err = dnsCache.QueryAndUpdate(&newMsg, upstream, maxPacketSize, requestUpstreamDNS) + receivedMsg, err = dnsCache.QueryAndUpdate(&newMsg, upstream, requestUpstreamDNS) } else { - receivedMsg, err = requestUpstreamDNS(&newMsg, upstream, maxPacketSize) + receivedMsg, err = requestUpstreamDNS(&newMsg, upstream) } if err != nil { return @@ -139,7 +139,9 @@ loop: select { case myMsg := <-msgChan: appendMsgToResp(myMsg) - receivedList[<-idChan] = true + if myMsg.RCode == dnsmessage.RCodeSuccess { + receivedList[<-idChan] = true + } allReceived := true for _, received := range receivedList { if !received { @@ -183,18 +185,11 @@ loop: return nil } -func requestUpstreamDNS(msg *dnsmessage.Message, upstreamAddr *network.SocketAddr, maxPacketSize int) (*dnsmessage.Message, error) { +func requestUpstreamDNS(msg *dnsmessage.Message, upstreamAddr *network.SocketAddr) (*dnsmessage.Message, error) { + if common.NeedDebug() { logger.Debug("Request Upstream", upstreamAddr) } - conn, err := network.EstablishNewSocketConn(upstreamAddr) - defer func() { - _ = conn.Close() - }() - if err != nil { - logger.Warning("Dial Socket Connection", err) - return nil, err - } bytes, err := msg.Pack() if err != nil { logger.Warning("Pack DNS Packet", err) @@ -203,14 +198,33 @@ func requestUpstreamDNS(msg *dnsmessage.Message, upstreamAddr *network.SocketAdd if common.NeedDebug() { logger.Debug("Pack DNS Message", msg.GoString()) } - if _, err := conn.WritePacket(bytes); err != nil { - logger.Warning("Write DNS Packet", err) - return nil, err + var conn *network.SocketConn + var readBytes []byte + var networkErr error + for i := 0; i < common.Config.Advanced.NetworkFailedRetries; i++ { + func() { + conn, networkErr = network.EstablishNewSocketConn(upstreamAddr) + defer func() { + _ = conn.Close() + }() + if networkErr != nil { + logger.Warning("Dial Socket Connection", networkErr) + return + } + _, networkErr = conn.WritePacket(bytes) + if networkErr != nil { + logger.Warning("Write DNS Packet", networkErr) + return + } + readBytes, _, networkErr = conn.ReadPacket(common.Config.Advanced.MaxReceivedPacketSize) + if networkErr != nil { + logger.Warning("Read DNS Packet", networkErr) + return + } + }() } - readBytes, _, err := conn.ReadPacket(maxPacketSize) - if err != nil { - logger.Warning("Read DNS Packet", err) - return nil, err + if networkErr != nil { + return nil, networkErr } receivedMsg := &dnsmessage.Message{} if err := receivedMsg.Unpack(readBytes); err != nil {