From a652f8db0447da1bae86a0d2514fe7d29917f3de Mon Sep 17 00:00:00 2001 From: nadoo <287492+nadoo@users.noreply.github.com> Date: Mon, 8 Jan 2018 23:37:58 +0800 Subject: [PATCH] dns: add a parameter in dns struct to identify tunnel --- dns.go | 43 ++++++++++++++++++++++++++++++------------- dnstun.go | 2 +- main.go | 10 +++++----- 3 files changed, 36 insertions(+), 19 deletions(-) diff --git a/dns.go b/dns.go index 859908d..f1fee8a 100644 --- a/dns.go +++ b/dns.go @@ -4,6 +4,8 @@ package main import ( "encoding/binary" + "encoding/hex" + "fmt" "io" "net" "strings" @@ -56,6 +58,8 @@ type DNS struct { *Forwarder // as proxy client sDialer Dialer // dialer for server + tunnel bool + dnsServer string dnsServerMap map[string]string @@ -63,11 +67,13 @@ type DNS struct { } // NewDNS returns a dns forwarder. client[dns.udp] -> glider[tcp] -> forwarder[dns.tcp] -> remote dns addr -func NewDNS(addr, raddr string, sDialer Dialer) (*DNS, error) { +func NewDNS(addr, raddr string, sDialer Dialer, tunnel bool) (*DNS, error) { s := &DNS{ Forwarder: NewForwarder(addr, nil), sDialer: sDialer, + tunnel: tunnel, + dnsServer: raddr, dnsServerMap: make(map[string]string), } @@ -90,7 +96,7 @@ func (s *DNS) ListenAndServeUDP() { } defer c.Close() - logf("proxy-dns listening UDP on %s", s.addr) + logf("proxy-dns listening on udp:%s", s.addr) for { data := make([]byte, DNSUDPMaxLen) @@ -107,9 +113,9 @@ func (s *DNS) ListenAndServeUDP() { query := parseQuery(data) domain := query.DomainName - dnsServer := s.dnsServer - if dnsServer == "" { - dnsServer = s.GetServer(domain) + dnsServer := s.GetServer(domain) + if s.tunnel { + dnsServer = s.dnsServer } rc, err := s.sDialer.NextDialer(domain+":53").Dial("tcp", dnsServer) @@ -122,6 +128,7 @@ func (s *DNS) ListenAndServeUDP() { // 2 bytes length after tcp header, before dns message reqLen := make([]byte, 2) binary.BigEndian.PutUint16(reqLen, uint16(len(data))) + rc.Write(reqLen) rc.Write(data) @@ -187,7 +194,7 @@ func (s *DNS) ListenAndServeTCP() { return } - logf("proxy-dns-tcp listening TCP on %s", s.addr) + logf("proxy-dns-tcp listening on tcp:%s", s.addr) for { c, err := l.Accept() @@ -223,9 +230,9 @@ func (s *DNS) ServeTCP(c net.Conn) { query := parseQuery(reqMsg) domain := query.DomainName - dnsServer := s.dnsServer - if dnsServer == "" { - dnsServer = s.GetServer(domain) + dnsServer := s.GetServer(domain) + if s.tunnel { + dnsServer = s.dnsServer } rc, err := s.sDialer.NextDialer(domain+":53").Dial("tcp", dnsServer) @@ -235,8 +242,12 @@ func (s *DNS) ServeTCP(c net.Conn) { } defer rc.Close() - binary.Write(rc, binary.BigEndian, reqLen) - binary.Write(rc, binary.BigEndian, reqMsg) + if err := binary.Write(rc, binary.BigEndian, reqLen); err != nil { + + } + if err := binary.Write(rc, binary.BigEndian, reqMsg); err != nil { + + } var respLen uint16 if err := binary.Read(rc, binary.BigEndian, &respLen); err != nil { @@ -251,6 +262,8 @@ func (s *DNS) ServeTCP(c net.Conn) { return } + fmt.Printf("dns resp len %d:\n%s\n\n", respLen, hex.Dump(respMsg[:])) + var ip string if respLen > 0 { query := parseQuery(respMsg) @@ -270,8 +283,12 @@ func (s *DNS) ServeTCP(c net.Conn) { } } - binary.Write(c, binary.BigEndian, respLen) - binary.Write(c, binary.BigEndian, respMsg) + if err := binary.Write(c, binary.BigEndian, respLen); err != nil { + logf("proxy-dns-tcp error in local write respLen: %s\n", err) + } + if err := binary.Write(c, binary.BigEndian, respMsg); err != nil { + logf("proxy-dns-tcp error in local write respMsg: %s\n", err) + } } logf("proxy-dns-tcp %s <-> %s, type: %d, %s: %s", c.RemoteAddr(), dnsServer, query.QueryType, domain, ip) diff --git a/dnstun.go b/dnstun.go index be6de21..27dad71 100644 --- a/dnstun.go +++ b/dnstun.go @@ -22,7 +22,7 @@ func NewDNSTun(addr, raddr string, sDialer Dialer) (*DNSTun, error) { raddr: raddr, } - s.dns, _ = NewDNS(addr, raddr, sDialer) + s.dns, _ = NewDNS(addr, raddr, sDialer, true) return s, nil } diff --git a/main.go b/main.go index 909ce8e..8a9193c 100644 --- a/main.go +++ b/main.go @@ -49,16 +49,16 @@ func main() { } if conf.DNS != "" { - dns, err := NewDNS(conf.DNS, conf.DNSServer[0], sDialer) + dns, err := NewDNS(conf.DNS, conf.DNSServer[0], sDialer, false) if err != nil { log.Fatal(err) } // rule - for _, fwdr := range conf.rules { - for _, domain := range fwdr.Domain { - if len(fwdr.DNSServer) > 0 { - dns.SetServer(domain, fwdr.DNSServer[0]) + for _, r := range conf.rules { + for _, domain := range r.Domain { + if len(r.DNSServer) > 0 { + dns.SetServer(domain, r.DNSServer[0]) } } }