diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4afcf19 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.idea +build \ No newline at end of file diff --git a/common/common.go b/common/common.go new file mode 100644 index 0000000..9df0518 --- /dev/null +++ b/common/common.go @@ -0,0 +1,224 @@ +package common + +import ( + "errors" + "fmt" + "github.com/phachon/go-logger" + "golang.org/x/net/dns/dnsmessage" + "gopkg.in/ini.v1" + "net" + "os" + "strconv" + "strings" +) + +var Logger = go_logger.NewLogger() +var Config = &ConfigStruct{ + Service: &ServiceConfig{ + ListenAddr: "127.0.0.1:53", + ListenUDP: true, + ListenTCP: false, + }, + Upstream: &UpstreamConfig{ + UseUDP: true, + UseTCP: false, + DefaultUpstreams: make([]string, 0), + ARecordUpstreams: make([]string, 0), + AAAARecordUpstreams: make([]string, 0), + CNAMERecordUpstreams: make([]string, 0), + TXTRecordUpstreams: make([]string, 0), + PTRRecordUpstreams: make([]string, 0), + CustomRecordUpstream: make([]string, 0), + }, + Log: &LogConfig{ + LogFilePath: "/dev/null", + LogFileMaxSize: 4096, + LogLevelForFile: "info", + LogLevelForConsole: "info", + }, + Advanced: &AdvancedConfig{ + NSLookupTimeoutMs: 10000, + MaxReceivedPacketSize: 512, + }, +} + +var UpstreamsList [256][]*SocketAddr + +func Init(configFilePath string) error { + + if configFilePath != "" { + configFile, err := os.Open(configFilePath) + if err != nil { + return err + } + defer func() { + if err := configFile.Close(); err != nil { + Warning(err) + } + }() + cfg, err := ini.Load(configFilePath) + if err != nil { + return err + } + if err := cfg.MapTo(Config); err != nil { + return err + } + } + switch Config.Log.LogLevelForConsole { + case "debug": + _ = Logger.Detach("console") + _ = Logger.Attach("console", go_logger.LOGGER_LEVEL_DEBUG, &go_logger.ConsoleConfig{}) + case "info": + _ = Logger.Detach("console") + _ = Logger.Attach("console", go_logger.LOGGER_LEVEL_INFO, &go_logger.ConsoleConfig{}) + case "warning": + _ = Logger.Detach("console") + _ = Logger.Attach("console", go_logger.LOGGER_LEVEL_WARNING, &go_logger.ConsoleConfig{}) + case "error": + _ = Logger.Detach("console") + _ = Logger.Attach("console", go_logger.LOGGER_LEVEL_ERROR, &go_logger.ConsoleConfig{}) + case "none": + _ = Logger.Detach("console") + default: + Error("[COMMON]", "{Set Log Level}", "unknown log level", Config.Log.LogLevelForConsole) + } + for typeCode := range UpstreamsList { + switch dnsmessage.Type(typeCode) { + case dnsmessage.Type(0): + UpstreamsList[typeCode] = make([]*SocketAddr, len(Config.Upstream.DefaultUpstreams)) + for i, upstreamStr := range Config.Upstream.DefaultUpstreams { + socketAddr, err := NewSocketAddr(upstreamStr) + if err != nil { + return err + } + Info("[COMMON]", "{Load Default Upstream}", socketAddr.UDPAddr.String()) + UpstreamsList[typeCode][i] = socketAddr + } + case dnsmessage.TypeA: + UpstreamsList[typeCode] = make([]*SocketAddr, len(Config.Upstream.ARecordUpstreams)) + for i, upstreamStr := range Config.Upstream.ARecordUpstreams { + socketAddr, err := NewSocketAddr(upstreamStr) + if err != nil { + return err + } + Info("[COMMON]", "{Load Upstream For A Record}", socketAddr.UDPAddr.String()) + UpstreamsList[typeCode][i] = socketAddr + } + case dnsmessage.TypeAAAA: + UpstreamsList[typeCode] = make([]*SocketAddr, len(Config.Upstream.AAAARecordUpstreams)) + for i, upstreamStr := range Config.Upstream.AAAARecordUpstreams { + socketAddr, err := NewSocketAddr(upstreamStr) + if err != nil { + return err + } + Info("[COMMON]", "{Load Upstream For AAAA Record}", socketAddr.UDPAddr.String()) + UpstreamsList[typeCode][i] = socketAddr + } + case dnsmessage.TypeCNAME: + UpstreamsList[typeCode] = make([]*SocketAddr, len(Config.Upstream.CNAMERecordUpstreams)) + for i, upstreamStr := range Config.Upstream.CNAMERecordUpstreams { + socketAddr, err := NewSocketAddr(upstreamStr) + if err != nil { + return err + } + Info("[COMMON]", "{Load Upstream For CNAME Record}", socketAddr.UDPAddr.String()) + UpstreamsList[typeCode][i] = socketAddr + } + case dnsmessage.TypeTXT: + UpstreamsList[typeCode] = make([]*SocketAddr, len(Config.Upstream.TXTRecordUpstreams)) + for i, upstreamStr := range Config.Upstream.TXTRecordUpstreams { + socketAddr, err := NewSocketAddr(upstreamStr) + if err != nil { + return err + } + Debug("[COMMON]", "{Load Upstream For TXT Record}", socketAddr.UDPAddr.String()) + UpstreamsList[typeCode][i] = socketAddr + } + case dnsmessage.TypePTR: + UpstreamsList[typeCode] = make([]*SocketAddr, len(Config.Upstream.PTRRecordUpstreams)) + for i, upstreamStr := range Config.Upstream.PTRRecordUpstreams { + socketAddr, err := NewSocketAddr(upstreamStr) + if err != nil { + return err + } + Debug("[COMMON]", "{Load Upstream For PTR Record}", socketAddr.UDPAddr.String()) + UpstreamsList[typeCode][i] = socketAddr + } + default: + UpstreamsList[typeCode] = make([]*SocketAddr, 0) + } + + } + for _, kvPair := range Config.Upstream.CustomRecordUpstream { + typeCodeStr, addr, err := ParseKVPair(kvPair) + if err != nil { + return err + } + typeCode, err := strconv.Atoi(typeCodeStr) + if err != nil { + return err + } + if typeCode < 0 || typeCode > 255 { + return errors.New("type code is not correct") + } + socketAddr, err := NewSocketAddr(addr) + if err != nil { + return err + } + UpstreamsList[typeCode] = append(UpstreamsList[typeCode], socketAddr) + } + Logger.Alert("DnsDiversion Started") + return nil +} + +func ParseKVPair(kvPair string) (key, value string, err error) { + index := strings.Index(kvPair, ":") + if index < 0 { + return "", "", errors.New("key-value pair \"" + kvPair + "\" is not correct") + } + return kvPair[:index], kvPair[index+1:], nil +} + +func NewSocketAddr(addr string) (*SocketAddr, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + tcpAddr, err := net.ResolveTCPAddr("tcp", addr) + if err != nil { + return nil, err + } + return &SocketAddr{ + UDPAddr: udpAddr, + TCPAddr: tcpAddr, + }, nil +} + +func Error(objs ...interface{}) { + msg := "" + for _, obj := range objs { + msg += fmt.Sprint(obj) + " " + } + Logger.Error(strings.TrimSpace(msg)) +} +func Warning(objs ...interface{}) { + msg := "" + for _, obj := range objs { + msg += fmt.Sprint(obj) + " " + } + Logger.Warning(strings.TrimSpace(msg)) +} +func Info(objs ...interface{}) { + msg := "" + for _, obj := range objs { + msg += fmt.Sprint(obj) + " " + } + Logger.Info(strings.TrimSpace(msg)) +} +func Debug(objs ...interface{}) { + msg := "" + for _, obj := range objs { + msg += fmt.Sprint(obj) + " " + } + Logger.Debug(strings.TrimSpace(msg)) +} diff --git a/common/type.go b/common/type.go new file mode 100644 index 0000000..c2bc66c --- /dev/null +++ b/common/type.go @@ -0,0 +1,53 @@ +package common + +import ( + "net" + "sync" +) + +type ConfigStruct struct { + Service *ServiceConfig + Upstream *UpstreamConfig + Log *LogConfig + Advanced *AdvancedConfig +} + +type ServiceConfig struct { + ListenAddr string `comment:"Listen Address (Example: [::]:53)"` + ListenUDP bool + ListenTCP bool +} + +type UpstreamConfig struct { + UseUDP bool + UseTCP bool + DefaultUpstreams []string `comment:"Upstream List for Non-specific Record (Example: 223.5.5.5:53,223.6.6.6,[2001:da8::666]:53)"` + ARecordUpstreams []string `comment:"Upstream List for A Record (Example: 223.5.5.5,223.6.6.6,2001:da8::666)"` + AAAARecordUpstreams []string `comment:"Upstream List for AAAA Record (Example: 223.5.5.5,223.6.6.6,2001:da8::666)"` + CNAMERecordUpstreams []string `comment:"Upstream List for CNAME Record (Example: 223.5.5.5,223.6.6.6,2001:da8::666)"` + TXTRecordUpstreams []string `comment:"Upstream List for TXT Record (Example: 223.5.5.5,223.6.6.6,2001:da8::666)"` + PTRRecordUpstreams []string + CustomRecordUpstream []string `comment:"Upstream List for Custom Record (Example: 1:223.5.5.5:53,1:223.6.6.6,28:[2001:da8::666]:53)"` +} + +type LogConfig struct { + LogFilePath string + LogFileMaxSize int + LogLevelForFile string + LogLevelForConsole string +} + +type AdvancedConfig struct { + NSLookupTimeoutMs int + MaxReceivedPacketSize int +} + +type SocketAddr struct { + *net.UDPAddr + *net.TCPAddr +} + +type SafeQueue struct { + mutex sync.Mutex + contents []interface{} +} diff --git a/diversion/diversion.go b/diversion/diversion.go new file mode 100644 index 0000000..4b30150 --- /dev/null +++ b/diversion/diversion.go @@ -0,0 +1,116 @@ +package diversion + +import ( + "DnsDiversion/common" + "golang.org/x/net/dns/dnsmessage" + "net" + "sync/atomic" + "time" +) + +var totalQueryCount uint64 + +func HandlePacket(bytes []byte, respCall func([]byte)) error { + msg := dnsmessage.Message{} + if err := msg.Unpack(bytes); err != nil { + return err + } + common.Debug("[DIVERSION]", "{Unpack DNS Message}", msg.GoString()) + answers := make([]dnsmessage.Resource, 0) + numOfQueries := 0 + for _, question := range msg.Questions { + queryType := dnsmessage.Type(0) + if len(common.UpstreamsList[question.Type]) != 0 { + queryType = question.Type + } + numOfQueries += len(common.UpstreamsList[queryType]) + } + answerChan := make(chan []dnsmessage.Resource, numOfQueries) + idChan := make(chan int, len(msg.Questions)) + receivedList := make([]bool, len(msg.Questions)) + for id, question := range msg.Questions { + common.Debug("[DIVERSION]", "{Question}", question.Name, question.Type, question.Class) + queryType := dnsmessage.Type(0) + if len(common.UpstreamsList[question.Type]) != 0 { + queryType = question.Type + } else { + common.Debug("[DIVERSION]", "{Request Default Upstraeams}", question.Name, question.Type, question.Class) + } + for _, upstream := range common.UpstreamsList[queryType] { + newMsg := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: uint16(atomic.AddUint64(&totalQueryCount, 1) % 65536), + RCode: dnsmessage.RCodeSuccess, + }, + Questions: make([]dnsmessage.Question, 1), + } + newMsg.Questions[0] = question + go requestUpstream(&newMsg, upstream, answerChan, id, idChan) + } + } + timer := time.NewTimer(time.Duration(common.Config.Advanced.NSLookupTimeoutMs) * time.Millisecond) +loop: + for { + select { + case id := <-idChan: + receivedList[id] = true + allReceived := true + for _, received := range receivedList { + if !received { + allReceived = false + break + } + } + if allReceived { + break loop + } + case myAnswers := <-answerChan: + answers = append(answers, myAnswers...) + case <-timer.C: + break loop + } + } + msg.Header.Response = true + msg.Answers = answers + bytes, err := msg.Pack() + if err != nil { + return err + } + respCall(bytes) + return nil +} + +func requestUpstream(msg *dnsmessage.Message, upstream *common.SocketAddr, answerChan chan []dnsmessage.Resource, questionId int, idChan chan int) { + if common.Config.Upstream.UseUDP { + conn, err := net.DialUDP("udp", nil, upstream.UDPAddr) + if err != nil { + common.Warning("[DIVERSION]", "{Dial UDP}", upstream.UDPAddr, err) + } + defer func() { _ = conn.Close() }() + if err := conn.SetDeadline(time.Now().Add(time.Duration(common.Config.Advanced.NSLookupTimeoutMs) * time.Millisecond)); err != nil { + common.Warning("[DIVERSION]", "{Set UDP Timeout}", upstream.UDPAddr, err) + } + bytes, err := msg.Pack() + if err != nil { + common.Warning("[DIVERSION]", "{DNS Message Pack}", upstream.UDPAddr, err) + } + n, err := conn.Write(bytes) + if err != nil { + common.Warning("[DIVERSION]", "{Write UDP Packet}", upstream.UDPAddr, err) + } + common.Debug("[DIVERSION]", "{Write UDP Packet}", "Write", n, "bytes to", upstream.UDPAddr) + buffer := make([]byte, common.Config.Advanced.MaxReceivedPacketSize) + n, err = conn.Read(buffer) + if err != nil { + common.Warning("[DIVERSION]", "{Read UDP Packet}", upstream.UDPAddr, err) + } + common.Debug("[DIVERSION]", "{Read UDP Packet}", "Read", n, "bytes from", upstream.UDPAddr) + receivedMsg := dnsmessage.Message{} + if err := receivedMsg.Unpack(buffer); err != nil { + common.Warning("[DIVERSION]", "{DNS Unpack}", upstream.UDPAddr, err) + } + common.Debug("[DIVERSION]", "{Unpack DNS Message}", upstream.UDPAddr.String(), receivedMsg.GoString()) + answerChan <- receivedMsg.Answers + idChan <- questionId + } +} diff --git a/main/main.go b/main/main.go new file mode 100644 index 0000000..84199c0 --- /dev/null +++ b/main/main.go @@ -0,0 +1,91 @@ +package main + +import ( + "DnsDiversion/common" + "DnsDiversion/diversion" + "flag" + "net" + "sync" +) + +var configFilePath = flag.String("c", "", "Config File Path") + +func main() { + flag.Parse() + if err := common.Init(*configFilePath); err != nil { + common.Error(err) + return + } + + waitGroup := sync.WaitGroup{} + if common.Config.Service.ListenUDP { + func() { + waitGroup.Add(1) + defer waitGroup.Done() + udpAddr, err := net.ResolveUDPAddr("udp", common.Config.Service.ListenAddr) + listener, err := net.ListenUDP("udp", udpAddr) + if err != nil { + common.Error("[MAIN]", "{Listen UDP}", err) + return + } + for true { + buffer := make([]byte, common.Config.Advanced.MaxReceivedPacketSize) + n, addr, err := listener.ReadFromUDP(buffer) + if err != nil { + common.Warning("[MAIN]", "{Read UDP Packet}", addr, err) + continue + } + common.Debug("[MAIN]", "{Read UDP Packet}", "Read", n, "bytes from", addr) + go func() { + if err = diversion.HandlePacket(buffer, func(bytes []byte) { + n, err := listener.WriteToUDP(bytes, addr) + if err != nil { + common.Warning("[MAIN]", "{Write UDP Packet}", addr, err) + } + common.Debug("[MAIN]", "{Write UDP Packet}", "Write", n, "bytes to", addr) + }); err != nil { + common.Warning("[MAIN]", "{Handle DNS Packet}", addr, err) + } + }() + } + }() + } + + if common.Config.Service.ListenTCP { + func() { + waitGroup.Add(1) + defer waitGroup.Done() + tcpAddr, err := net.ResolveTCPAddr("tcp", common.Config.Service.ListenAddr) + listener, err := net.ListenTCP("tcp", tcpAddr) + if err != nil { + common.Error(err) + return + } + for true { + conn, err := listener.AcceptTCP() + if err != nil { + common.Error(err) + continue + } + go func() { + buffer := make([]byte, common.Config.Advanced.MaxReceivedPacketSize) + n, err := conn.Read(buffer) + if err != nil { + common.Error("[MAIN]", "{Read DNS Packet From TCP Connection}", conn.RemoteAddr(), err) + if err := conn.Close(); err != nil { + common.Error("[MAIN]", "{Close TCP Connection}", conn.RemoteAddr(), err) + } + return + } + common.Debug("[MAIN]", "{Read DNS Packet From TCP Connection}", "Read", n, "bytes from", conn.RemoteAddr()) + if err = diversion.HandlePacket(buffer, func(bytes []byte) { + + }); err != nil { + common.Error(err) + } + }() + } + }() + } + waitGroup.Wait() +}