diff --git a/dns.go b/dns.go index f1fee8a..f1c565c 100644 --- a/dns.go +++ b/dns.go @@ -4,8 +4,7 @@ package main import ( "encoding/binary" - "encoding/hex" - "fmt" + "errors" "io" "net" "strings" @@ -110,77 +109,15 @@ func (s *DNS) ListenAndServeUDP() { data = data[:n] go func() { - query := parseQuery(data) - domain := query.DomainName - dnsServer := s.GetServer(domain) - if s.tunnel { - dnsServer = s.dnsServer - } + _, respMsg := s.handleReqMsg(uint16(len(data)), data) - rc, err := s.sDialer.NextDialer(domain+":53").Dial("tcp", dnsServer) + _, err = c.WriteTo(respMsg, clientAddr) if err != nil { - logf("proxy-dns failed to connect to server %v: %v", dnsServer, err) - return - } - defer rc.Close() - - // 2 bytes length after tcp header, before dns message - reqLen := make([]byte, 2) - binary.BigEndian.PutUint16(reqLen, uint16(len(data))) - - rc.Write(reqLen) - rc.Write(data) - - // fmt.Printf("dns req len %d:\n%s\n\n", reqLen, hex.Dump(data[:])) - - var respLen uint16 - err = binary.Read(rc, binary.BigEndian, &respLen) - if err != nil { - logf("proxy-dns error in read respLen %s\n", err) - return + logf("proxy-dns error in local write: %s\n", err) } - respMsg := make([]byte, respLen) - _, err = io.ReadFull(rc, respMsg) - if err != nil { - logf("proxy-dns error in read respMsg %s\n", err) - return - } - - // fmt.Printf("dns resp len %d:\n%s\n\n", respLen, hex.Dump(respMsg[:])) - - var ip string - // length is not needed in udp dns response. (2 bytes) - // SEE RFC1035, section 4.2.2 TCP: The message is prefixed with a two byte length field which gives the message length, excluding the two byte length field. - if respLen > 0 { - - // run handle functions before send to client so RULE and IPSET can take effect - // TODO: add PRE_HANDLERS - query := parseQuery(respMsg) - if (query.QueryType == DNSQueryTypeA || query.QueryType == DNSQueryTypeAAAA) && - len(respMsg) > query.Offset { - - answers := parseAnswers(respMsg[query.Offset:]) - - for _, answer := range answers { - if answer.IP != "" { - ip += answer.IP + "," - } - - for _, h := range s.answerHandlers { - h(query.DomainName, answer.IP) - } - } - } - - _, err = c.WriteTo(respMsg, clientAddr) - if err != nil { - logf("proxy-dns error in local write: %s\n", err) - } - } - - logf("proxy-dns %s <-> %s, type: %d, %s: %s", clientAddr.String(), dnsServer, query.QueryType, domain, ip) + // logf("proxy-dns %s <-> %s, type: %d, %s: %s", clientAddr.String(), dnsServer, query.QueryType, domain, ip) }() } @@ -223,19 +160,37 @@ func (s *DNS) ServeTCP(c net.Conn) { reqMsg := make([]byte, reqLen) _, err := io.ReadFull(c, reqMsg) if err != nil { - logf("proxy-dns-tcp error in read reqMsg %s\n", err) + logf("proxy-dns-tcp error in read reqMsg %s", err) return } - query := parseQuery(reqMsg) - domain := query.DomainName + respLen, respMsg := s.handleReqMsg(reqLen, reqMsg) - dnsServer := s.GetServer(domain) + if err := binary.Write(c, binary.BigEndian, respLen); err != nil { + logf("proxy-dns-tcp error in local write respLen: %s\n", err) + } + if err := binary.Write(c, binary.BigEndian, respMsg); err != nil { + logf("proxy-dns-tcp error in local write respMsg: %s\n", err) + } + + // logf("proxy-dns-tcp %s <-> %s, type: %d, %s: %s", c.RemoteAddr(), dnsServer, query.QueryType, domain, ip) +} + +// handle request msg and return response msg +func (s *DNS) handleReqMsg(reqLen uint16, reqMsg []byte) (respLen uint16, respMsg []byte) { + + query, err := parseQuery(reqMsg) + if err != nil { + logf("proxy-dns-tcp error in parseQuery reqMsg %s", err) + return + } + + dnsServer := s.GetServer(query.DomainName) if s.tunnel { dnsServer = s.dnsServer } - rc, err := s.sDialer.NextDialer(domain+":53").Dial("tcp", dnsServer) + rc, err := s.sDialer.NextDialer(query.DomainName+":53").Dial("tcp", dnsServer) if err != nil { logf("proxy-dns failed to connect to server %v: %v", dnsServer, err) return @@ -243,30 +198,34 @@ func (s *DNS) ServeTCP(c net.Conn) { defer rc.Close() if err := binary.Write(rc, binary.BigEndian, reqLen); err != nil { - + logf("proxy-dns failed to connect to server %v: %v", dnsServer, err) } if err := binary.Write(rc, binary.BigEndian, reqMsg); err != nil { - + logf("proxy-dns failed to connect to server %v: %v", dnsServer, err) } - var respLen uint16 if err := binary.Read(rc, binary.BigEndian, &respLen); err != nil { logf("proxy-dns-tcp failed to read response length: %v", err) return } - respMsg := make([]byte, respLen) + respMsg = make([]byte, respLen) _, err = io.ReadFull(rc, respMsg) if err != nil { logf("proxy-dns-tcp error in read respMsg %s\n", err) return } - fmt.Printf("dns resp len %d:\n%s\n\n", respLen, hex.Dump(respMsg[:])) + // fmt.Printf("dns resp len %d:\n%s\n\n", respLen, hex.Dump(respMsg[:])) var ip string if respLen > 0 { - query := parseQuery(respMsg) + query, err := parseQuery(respMsg) + if err != nil { + logf("proxy-dns error in parseQuery respMsg %s", err) + return + } + if (query.QueryType == DNSQueryTypeA || query.QueryType == DNSQueryTypeAAAA) && len(respMsg) > query.Offset { @@ -283,16 +242,9 @@ func (s *DNS) ServeTCP(c net.Conn) { } } - if err := binary.Write(c, binary.BigEndian, respLen); err != nil { - logf("proxy-dns-tcp error in local write respLen: %s\n", err) - } - if err := binary.Write(c, binary.BigEndian, respMsg); err != nil { - logf("proxy-dns-tcp error in local write respMsg: %s\n", err) - } } - logf("proxy-dns-tcp %s <-> %s, type: %d, %s: %s", c.RemoteAddr(), dnsServer, query.QueryType, domain, ip) - + return } // SetServer . @@ -321,7 +273,7 @@ func (s *DNS) AddAnswerHandler(h DNSAnswerHandler) { s.answerHandlers = append(s.answerHandlers, h) } -func parseQuery(p []byte) *dnsQuery { +func parseQuery(p []byte) (*dnsQuery, error) { q := &dnsQuery{} var i int @@ -341,11 +293,16 @@ func parseQuery(p []byte) *dnsQuery { } q.DomainName = string(domain[:len(domain)-1]) + + if len(p) < i+4 { + return nil, errors.New("parseQuery error, not enough data") + } + q.QueryType = binary.BigEndian.Uint16(p[i:]) q.QueryClass = binary.BigEndian.Uint16(p[i+2:]) q.Offset = i + 4 - return q + return q, nil } func parseAnswers(p []byte) []*dnsAnswer {