Skip to content

Commit

Permalink
Initial Version
Browse files Browse the repository at this point in the history
  • Loading branch information
jetloga committed Sep 20, 2020
0 parents commit 444d75d
Show file tree
Hide file tree
Showing 5 changed files with 486 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.idea
build
224 changes: 224 additions & 0 deletions common/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
package common

import (
"errors"
"fmt"
"github.com/phachon/go-logger"
"golang.org/x/net/dns/dnsmessage"
"gopkg.in/ini.v1"
"net"
"os"
"strconv"
"strings"
)

var Logger = go_logger.NewLogger()
var Config = &ConfigStruct{
Service: &ServiceConfig{
ListenAddr: "127.0.0.1:53",
ListenUDP: true,
ListenTCP: false,
},
Upstream: &UpstreamConfig{
UseUDP: true,
UseTCP: false,
DefaultUpstreams: make([]string, 0),
ARecordUpstreams: make([]string, 0),
AAAARecordUpstreams: make([]string, 0),
CNAMERecordUpstreams: make([]string, 0),
TXTRecordUpstreams: make([]string, 0),
PTRRecordUpstreams: make([]string, 0),
CustomRecordUpstream: make([]string, 0),
},
Log: &LogConfig{
LogFilePath: "/dev/null",
LogFileMaxSize: 4096,
LogLevelForFile: "info",
LogLevelForConsole: "info",
},
Advanced: &AdvancedConfig{
NSLookupTimeoutMs: 10000,
MaxReceivedPacketSize: 512,
},
}

var UpstreamsList [256][]*SocketAddr

func Init(configFilePath string) error {

if configFilePath != "" {
configFile, err := os.Open(configFilePath)
if err != nil {
return err
}
defer func() {
if err := configFile.Close(); err != nil {
Warning(err)
}
}()
cfg, err := ini.Load(configFilePath)
if err != nil {
return err
}
if err := cfg.MapTo(Config); err != nil {
return err
}
}
switch Config.Log.LogLevelForConsole {
case "debug":
_ = Logger.Detach("console")
_ = Logger.Attach("console", go_logger.LOGGER_LEVEL_DEBUG, &go_logger.ConsoleConfig{})
case "info":
_ = Logger.Detach("console")
_ = Logger.Attach("console", go_logger.LOGGER_LEVEL_INFO, &go_logger.ConsoleConfig{})
case "warning":
_ = Logger.Detach("console")
_ = Logger.Attach("console", go_logger.LOGGER_LEVEL_WARNING, &go_logger.ConsoleConfig{})
case "error":
_ = Logger.Detach("console")
_ = Logger.Attach("console", go_logger.LOGGER_LEVEL_ERROR, &go_logger.ConsoleConfig{})
case "none":
_ = Logger.Detach("console")
default:
Error("[COMMON]", "{Set Log Level}", "unknown log level", Config.Log.LogLevelForConsole)
}
for typeCode := range UpstreamsList {
switch dnsmessage.Type(typeCode) {
case dnsmessage.Type(0):
UpstreamsList[typeCode] = make([]*SocketAddr, len(Config.Upstream.DefaultUpstreams))
for i, upstreamStr := range Config.Upstream.DefaultUpstreams {
socketAddr, err := NewSocketAddr(upstreamStr)
if err != nil {
return err
}
Info("[COMMON]", "{Load Default Upstream}", socketAddr.UDPAddr.String())
UpstreamsList[typeCode][i] = socketAddr
}
case dnsmessage.TypeA:
UpstreamsList[typeCode] = make([]*SocketAddr, len(Config.Upstream.ARecordUpstreams))
for i, upstreamStr := range Config.Upstream.ARecordUpstreams {
socketAddr, err := NewSocketAddr(upstreamStr)
if err != nil {
return err
}
Info("[COMMON]", "{Load Upstream For A Record}", socketAddr.UDPAddr.String())
UpstreamsList[typeCode][i] = socketAddr
}
case dnsmessage.TypeAAAA:
UpstreamsList[typeCode] = make([]*SocketAddr, len(Config.Upstream.AAAARecordUpstreams))
for i, upstreamStr := range Config.Upstream.AAAARecordUpstreams {
socketAddr, err := NewSocketAddr(upstreamStr)
if err != nil {
return err
}
Info("[COMMON]", "{Load Upstream For AAAA Record}", socketAddr.UDPAddr.String())
UpstreamsList[typeCode][i] = socketAddr
}
case dnsmessage.TypeCNAME:
UpstreamsList[typeCode] = make([]*SocketAddr, len(Config.Upstream.CNAMERecordUpstreams))
for i, upstreamStr := range Config.Upstream.CNAMERecordUpstreams {
socketAddr, err := NewSocketAddr(upstreamStr)
if err != nil {
return err
}
Info("[COMMON]", "{Load Upstream For CNAME Record}", socketAddr.UDPAddr.String())
UpstreamsList[typeCode][i] = socketAddr
}
case dnsmessage.TypeTXT:
UpstreamsList[typeCode] = make([]*SocketAddr, len(Config.Upstream.TXTRecordUpstreams))
for i, upstreamStr := range Config.Upstream.TXTRecordUpstreams {
socketAddr, err := NewSocketAddr(upstreamStr)
if err != nil {
return err
}
Debug("[COMMON]", "{Load Upstream For TXT Record}", socketAddr.UDPAddr.String())
UpstreamsList[typeCode][i] = socketAddr
}
case dnsmessage.TypePTR:
UpstreamsList[typeCode] = make([]*SocketAddr, len(Config.Upstream.PTRRecordUpstreams))
for i, upstreamStr := range Config.Upstream.PTRRecordUpstreams {
socketAddr, err := NewSocketAddr(upstreamStr)
if err != nil {
return err
}
Debug("[COMMON]", "{Load Upstream For PTR Record}", socketAddr.UDPAddr.String())
UpstreamsList[typeCode][i] = socketAddr
}
default:
UpstreamsList[typeCode] = make([]*SocketAddr, 0)
}

}
for _, kvPair := range Config.Upstream.CustomRecordUpstream {
typeCodeStr, addr, err := ParseKVPair(kvPair)
if err != nil {
return err
}
typeCode, err := strconv.Atoi(typeCodeStr)
if err != nil {
return err
}
if typeCode < 0 || typeCode > 255 {
return errors.New("type code is not correct")
}
socketAddr, err := NewSocketAddr(addr)
if err != nil {
return err
}
UpstreamsList[typeCode] = append(UpstreamsList[typeCode], socketAddr)
}
Logger.Alert("DnsDiversion Started")
return nil
}

func ParseKVPair(kvPair string) (key, value string, err error) {
index := strings.Index(kvPair, ":")
if index < 0 {
return "", "", errors.New("key-value pair \"" + kvPair + "\" is not correct")
}
return kvPair[:index], kvPair[index+1:], nil
}

func NewSocketAddr(addr string) (*SocketAddr, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return nil, err
}
return &SocketAddr{
UDPAddr: udpAddr,
TCPAddr: tcpAddr,
}, nil
}

func Error(objs ...interface{}) {
msg := ""
for _, obj := range objs {
msg += fmt.Sprint(obj) + " "
}
Logger.Error(strings.TrimSpace(msg))
}
func Warning(objs ...interface{}) {
msg := ""
for _, obj := range objs {
msg += fmt.Sprint(obj) + " "
}
Logger.Warning(strings.TrimSpace(msg))
}
func Info(objs ...interface{}) {
msg := ""
for _, obj := range objs {
msg += fmt.Sprint(obj) + " "
}
Logger.Info(strings.TrimSpace(msg))
}
func Debug(objs ...interface{}) {
msg := ""
for _, obj := range objs {
msg += fmt.Sprint(obj) + " "
}
Logger.Debug(strings.TrimSpace(msg))
}
53 changes: 53 additions & 0 deletions common/type.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package common

import (
"net"
"sync"
)

type ConfigStruct struct {
Service *ServiceConfig
Upstream *UpstreamConfig
Log *LogConfig
Advanced *AdvancedConfig
}

type ServiceConfig struct {
ListenAddr string `comment:"Listen Address (Example: [::]:53)"`
ListenUDP bool
ListenTCP bool
}

type UpstreamConfig struct {
UseUDP bool
UseTCP bool
DefaultUpstreams []string `comment:"Upstream List for Non-specific Record (Example: 223.5.5.5:53,223.6.6.6,[2001:da8::666]:53)"`
ARecordUpstreams []string `comment:"Upstream List for A Record (Example: 223.5.5.5,223.6.6.6,2001:da8::666)"`
AAAARecordUpstreams []string `comment:"Upstream List for AAAA Record (Example: 223.5.5.5,223.6.6.6,2001:da8::666)"`
CNAMERecordUpstreams []string `comment:"Upstream List for CNAME Record (Example: 223.5.5.5,223.6.6.6,2001:da8::666)"`
TXTRecordUpstreams []string `comment:"Upstream List for TXT Record (Example: 223.5.5.5,223.6.6.6,2001:da8::666)"`
PTRRecordUpstreams []string
CustomRecordUpstream []string `comment:"Upstream List for Custom Record (Example: 1:223.5.5.5:53,1:223.6.6.6,28:[2001:da8::666]:53)"`
}

type LogConfig struct {
LogFilePath string
LogFileMaxSize int
LogLevelForFile string
LogLevelForConsole string
}

type AdvancedConfig struct {
NSLookupTimeoutMs int
MaxReceivedPacketSize int
}

type SocketAddr struct {
*net.UDPAddr
*net.TCPAddr
}

type SafeQueue struct {
mutex sync.Mutex
contents []interface{}
}
116 changes: 116 additions & 0 deletions diversion/diversion.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package diversion

import (
"DnsDiversion/common"
"golang.org/x/net/dns/dnsmessage"
"net"
"sync/atomic"
"time"
)

var totalQueryCount uint64

func HandlePacket(bytes []byte, respCall func([]byte)) error {
msg := dnsmessage.Message{}
if err := msg.Unpack(bytes); err != nil {
return err
}
common.Debug("[DIVERSION]", "{Unpack DNS Message}", msg.GoString())
answers := make([]dnsmessage.Resource, 0)
numOfQueries := 0
for _, question := range msg.Questions {
queryType := dnsmessage.Type(0)
if len(common.UpstreamsList[question.Type]) != 0 {
queryType = question.Type
}
numOfQueries += len(common.UpstreamsList[queryType])
}
answerChan := make(chan []dnsmessage.Resource, numOfQueries)
idChan := make(chan int, len(msg.Questions))
receivedList := make([]bool, len(msg.Questions))
for id, question := range msg.Questions {
common.Debug("[DIVERSION]", "{Question}", question.Name, question.Type, question.Class)
queryType := dnsmessage.Type(0)
if len(common.UpstreamsList[question.Type]) != 0 {
queryType = question.Type
} else {
common.Debug("[DIVERSION]", "{Request Default Upstraeams}", question.Name, question.Type, question.Class)
}
for _, upstream := range common.UpstreamsList[queryType] {
newMsg := dnsmessage.Message{
Header: dnsmessage.Header{
ID: uint16(atomic.AddUint64(&totalQueryCount, 1) % 65536),
RCode: dnsmessage.RCodeSuccess,
},
Questions: make([]dnsmessage.Question, 1),
}
newMsg.Questions[0] = question
go requestUpstream(&newMsg, upstream, answerChan, id, idChan)
}
}
timer := time.NewTimer(time.Duration(common.Config.Advanced.NSLookupTimeoutMs) * time.Millisecond)
loop:
for {
select {
case id := <-idChan:
receivedList[id] = true
allReceived := true
for _, received := range receivedList {
if !received {
allReceived = false
break
}
}
if allReceived {
break loop
}
case myAnswers := <-answerChan:
answers = append(answers, myAnswers...)
case <-timer.C:
break loop
}
}
msg.Header.Response = true
msg.Answers = answers
bytes, err := msg.Pack()
if err != nil {
return err
}
respCall(bytes)
return nil
}

func requestUpstream(msg *dnsmessage.Message, upstream *common.SocketAddr, answerChan chan []dnsmessage.Resource, questionId int, idChan chan int) {
if common.Config.Upstream.UseUDP {
conn, err := net.DialUDP("udp", nil, upstream.UDPAddr)
if err != nil {
common.Warning("[DIVERSION]", "{Dial UDP}", upstream.UDPAddr, err)
}
defer func() { _ = conn.Close() }()
if err := conn.SetDeadline(time.Now().Add(time.Duration(common.Config.Advanced.NSLookupTimeoutMs) * time.Millisecond)); err != nil {
common.Warning("[DIVERSION]", "{Set UDP Timeout}", upstream.UDPAddr, err)
}
bytes, err := msg.Pack()
if err != nil {
common.Warning("[DIVERSION]", "{DNS Message Pack}", upstream.UDPAddr, err)
}
n, err := conn.Write(bytes)
if err != nil {
common.Warning("[DIVERSION]", "{Write UDP Packet}", upstream.UDPAddr, err)
}
common.Debug("[DIVERSION]", "{Write UDP Packet}", "Write", n, "bytes to", upstream.UDPAddr)
buffer := make([]byte, common.Config.Advanced.MaxReceivedPacketSize)
n, err = conn.Read(buffer)
if err != nil {
common.Warning("[DIVERSION]", "{Read UDP Packet}", upstream.UDPAddr, err)
}
common.Debug("[DIVERSION]", "{Read UDP Packet}", "Read", n, "bytes from", upstream.UDPAddr)
receivedMsg := dnsmessage.Message{}
if err := receivedMsg.Unpack(buffer); err != nil {
common.Warning("[DIVERSION]", "{DNS Unpack}", upstream.UDPAddr, err)
}
common.Debug("[DIVERSION]", "{Unpack DNS Message}", upstream.UDPAddr.String(), receivedMsg.GoString())
answerChan <- receivedMsg.Answers
idChan <- questionId
}
}
Loading

0 comments on commit 444d75d

Please sign in to comment.