diff --git a/dns.go b/dns.go index 25149e1..5f5ebac 100644 --- a/dns.go +++ b/dns.go @@ -4,7 +4,7 @@ package main import ( "encoding/binary" - "io/ioutil" + "io" "net" "strings" ) @@ -98,13 +98,11 @@ func (s *DNS) ListenAndServe() { data = data[:n] go func() { - // TODO: check domain rules and get a proper upstream name server. query := parseQuery(data) domain := query.DomainName dnsServer := s.GetServer(domain) - // TODO: check here; ADD dnsServer to rule ip lists rc, err := s.sDialer.NextDialer(domain+":53").Dial("tcp", dnsServer) if err != nil { logf("failed to connect to server %v: %v", dnsServer, err) @@ -113,26 +111,36 @@ func (s *DNS) ListenAndServe() { defer rc.Close() // 2 bytes length after tcp header, before dns message - length := make([]byte, 2) - binary.BigEndian.PutUint16(length, uint16(len(data))) - rc.Write(length) + reqLen := make([]byte, 2) + binary.BigEndian.PutUint16(reqLen, uint16(len(data))) + rc.Write(reqLen) rc.Write(data) - resp, err := ioutil.ReadAll(rc) + // fmt.Printf("dns req len %d:\n%s\n\n", reqLen, hex.Dump(data[:])) + + var respLen uint16 + err = binary.Read(rc, binary.BigEndian, &respLen) if err != nil { - logf("error in ioutil.ReadAll: %s\n", err) + logf("proxy-dns: error in read respLen %s\n", err) return } + respMsg := make([]byte, respLen) + _, err = io.ReadFull(rc, respMsg) + if err != nil { + logf("proxy-dns: error in read respMsg %s\n", err) + return + } + + // fmt.Printf("dns resp len %d:\n%s\n\n", respLen, hex.Dump(respMsg[:])) + var ip string // length is not needed in udp dns response. (2 bytes) // SEE RFC1035, section 4.2.2 TCP: The message is prefixed with a two byte length field which gives the message length, excluding the two byte length field. - if len(resp) > 2 { - msg := resp[2:] - // TODO: Get IP from response, check and add to ipset - query := parseQuery(msg) - if len(msg) > query.Offset { - answers := parseAnswers(msg[query.Offset:]) + if respLen > 0 { + query := parseQuery(respMsg) + if len(respMsg) > query.Offset { + answers := parseAnswers(respMsg[query.Offset:]) for _, answer := range answers { if answer.IP != "" { ip += answer.IP + "," @@ -146,7 +154,7 @@ func (s *DNS) ListenAndServe() { } - _, err = c.WriteTo(msg, clientAddr) + _, err = c.WriteTo(respMsg, clientAddr) if err != nil { logf("error in local write: %s\n", err) } diff --git a/main.go b/main.go index d985c33..b627fc6 100644 --- a/main.go +++ b/main.go @@ -9,7 +9,7 @@ import ( ) // VERSION . -const VERSION = "0.4.2" +const VERSION = "0.4.3" func dialerFromConf() Dialer { // global forwarders in xx.conf diff --git a/rule.go b/rule.go index 475d63c..4769439 100644 --- a/rule.go +++ b/rule.go @@ -61,7 +61,6 @@ func (rd *RuleDialer) Addr() string { return "RULES" } // NextDialer return next dialer according to rule func (rd *RuleDialer) NextDialer(dstAddr string) Dialer { - // TODO: change to index finders host, _, err := net.SplitHostPort(dstAddr) if err != nil { // TODO: check here