Skip to content

Commit

Permalink
Add retries while network failure, and optimize action while server f…
Browse files Browse the repository at this point in the history
…ailure.
  • Loading branch information
jetloga committed Dec 6, 2020
1 parent 0f85620 commit fa78282
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 25 deletions.
4 changes: 2 additions & 2 deletions cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (dnsCache *Cache) UpdateItem(item *Item, msg *dnsmessage.Message) {
}
}

func (dnsCache *Cache) QueryAndUpdate(queryMsg *dnsmessage.Message, upstream *network.SocketAddr, maxPacketSize int, updateFunc func(*dnsmessage.Message, *network.SocketAddr, int) (*dnsmessage.Message, error)) (*dnsmessage.Message, error) {
func (dnsCache *Cache) QueryAndUpdate(queryMsg *dnsmessage.Message, upstream *network.SocketAddr, updateFunc func(*dnsmessage.Message, *network.SocketAddr) (*dnsmessage.Message, error)) (*dnsmessage.Message, error) {
if queryMsg == nil || len(queryMsg.Questions) < 1 {
return nil, errors.New("wrong dns message")
}
Expand All @@ -71,7 +71,7 @@ func (dnsCache *Cache) QueryAndUpdate(queryMsg *dnsmessage.Message, upstream *ne
if common.NeedDebug() {
logger.Debug("Cache Miss", question.Name, question.Class, question.Type)
}
msg, err := updateFunc(queryMsg, upstream, maxPacketSize)
msg, err := updateFunc(queryMsg, upstream)
if err != nil {
return nil, err
}
Expand Down
9 changes: 5 additions & 4 deletions common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ const StandardMaxDNSPacketSize = 512

var Config = &ConfigStruct{
Service: &ServiceConfig{
ListenAddr: "127.0.0.1:53",
ListenAddr: "[::]:53",
ListenUDP: true,
ListenTCP: false,
},
Expand All @@ -29,16 +29,17 @@ var Config = &ConfigStruct{
MinTTL: 10,
},
Log: &LogConfig{
LogFilePath: "",
LogFilePath: "accdns.log",
LogFileMaxSizeKB: 16 * 1024,
LogLevelForFile: "info",
LogLevelForConsole: "info",
},
Advanced: &AdvancedConfig{
NSLookupTimeoutMs: 10000,
RWTimeoutMs: 6000,
NSLookupTimeoutMs: 20000,
RWTimeoutMs: 8000,
MaxReceivedPacketSize: 4096,
ConnectionTimeout: 60,
NetworkFailedRetries: 3,
},
}

Expand Down
1 change: 1 addition & 0 deletions common/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type AdvancedConfig struct {
RWTimeoutMs int
MaxReceivedPacketSize int
ConnectionTimeout int
NetworkFailedRetries int
}

type CacheConfig struct {
Expand Down
52 changes: 33 additions & 19 deletions diversion/diversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ func HandlePacket(bytes []byte, respCall func([]byte), dnsCache *cache.Cache) er
var receivedMsg *dnsmessage.Message
var err error
if dnsCache != nil {
receivedMsg, err = dnsCache.QueryAndUpdate(&newMsg, upstream, maxPacketSize, requestUpstreamDNS)
receivedMsg, err = dnsCache.QueryAndUpdate(&newMsg, upstream, requestUpstreamDNS)
} else {
receivedMsg, err = requestUpstreamDNS(&newMsg, upstream, maxPacketSize)
receivedMsg, err = requestUpstreamDNS(&newMsg, upstream)
}
if err != nil {
return
Expand Down Expand Up @@ -139,7 +139,9 @@ loop:
select {
case myMsg := <-msgChan:
appendMsgToResp(myMsg)
receivedList[<-idChan] = true
if myMsg.RCode == dnsmessage.RCodeSuccess {
receivedList[<-idChan] = true
}
allReceived := true
for _, received := range receivedList {
if !received {
Expand Down Expand Up @@ -183,18 +185,11 @@ loop:
return nil
}

func requestUpstreamDNS(msg *dnsmessage.Message, upstreamAddr *network.SocketAddr, maxPacketSize int) (*dnsmessage.Message, error) {
func requestUpstreamDNS(msg *dnsmessage.Message, upstreamAddr *network.SocketAddr) (*dnsmessage.Message, error) {

if common.NeedDebug() {
logger.Debug("Request Upstream", upstreamAddr)
}
conn, err := network.EstablishNewSocketConn(upstreamAddr)
defer func() {
_ = conn.Close()
}()
if err != nil {
logger.Warning("Dial Socket Connection", err)
return nil, err
}
bytes, err := msg.Pack()
if err != nil {
logger.Warning("Pack DNS Packet", err)
Expand All @@ -203,14 +198,33 @@ func requestUpstreamDNS(msg *dnsmessage.Message, upstreamAddr *network.SocketAdd
if common.NeedDebug() {
logger.Debug("Pack DNS Message", msg.GoString())
}
if _, err := conn.WritePacket(bytes); err != nil {
logger.Warning("Write DNS Packet", err)
return nil, err
var conn *network.SocketConn
var readBytes []byte
var networkErr error
for i := 0; i < common.Config.Advanced.NetworkFailedRetries; i++ {
func() {
conn, networkErr = network.EstablishNewSocketConn(upstreamAddr)
defer func() {
_ = conn.Close()
}()
if networkErr != nil {
logger.Warning("Dial Socket Connection", networkErr)
return
}
_, networkErr = conn.WritePacket(bytes)
if networkErr != nil {
logger.Warning("Write DNS Packet", networkErr)
return
}
readBytes, _, networkErr = conn.ReadPacket(common.Config.Advanced.MaxReceivedPacketSize)
if networkErr != nil {
logger.Warning("Read DNS Packet", networkErr)
return
}
}()
}
readBytes, _, err := conn.ReadPacket(maxPacketSize)
if err != nil {
logger.Warning("Read DNS Packet", err)
return nil, err
if networkErr != nil {
return nil, networkErr
}
receivedMsg := &dnsmessage.Message{}
if err := receivedMsg.Unpack(readBytes); err != nil {
Expand Down

0 comments on commit fa78282

Please sign in to comment.