diff --git a/dns.go b/dns.go index f1c565c..640a1a6 100644 --- a/dns.go +++ b/dns.go @@ -10,11 +10,8 @@ import ( "strings" ) -// DNSUDPHeaderLen is the length of UDP dns msg header -const DNSUDPHeaderLen = 12 - -// DNSTCPHeaderLen is the length of TCP dns msg header -const DNSTCPHeaderLen = 2 + DNSUDPHeaderLen +// DNSHeaderLen is the length of dns msg header +const DNSHeaderLen = 12 // DNSUDPMaxLen is the max size of udp dns request. // https://tools.ietf.org/html/rfc1035#section-4.2.1 @@ -25,26 +22,114 @@ const DNSTCPHeaderLen = 2 + DNSUDPHeaderLen // so we should also serve tcp requests. const DNSUDPMaxLen = 512 -// DNSQueryTypeA ipv4 -const DNSQueryTypeA = 1 +// DNSQTypeA ipv4 +const DNSQTypeA = 1 -// DNSQueryTypeAAAA ipv6 -const DNSQueryTypeAAAA = 28 +// DNSQTypeAAAA ipv6 +const DNSQTypeAAAA = 28 -type dnsQuery struct { - DomainName string - QueryType uint16 - QueryClass uint16 - Offset int +// DNSMsg format +// https://tools.ietf.org/html/rfc1035#section-4.1 +// All communications inside of the domain protocol are carried in a single +// format called a message. The top level format of message is divided +// into 5 sections (some of which are empty in certain cases) shown below: +// +// +---------------------+ +// | Header | +// +---------------------+ +// | Question | the question for the name server +// +---------------------+ +// | Answer | RRs answering the question +// +---------------------+ +// | Authority | RRs pointing toward an authority +// +---------------------+ +// | Additional | RRs holding additional information +type DNSMsg struct { + DNSHeader + Questions []*DNSQuestion + Answers []*DNSRR } -type dnsAnswer struct { - // DomainName string - QueryType uint16 - QueryClass uint16 - TTL uint32 - DataLength uint16 - Data []byte +// DNSHeader format +// https://tools.ietf.org/html/rfc1035#section-4.1.1 +// The header contains the following fields: +// +// 1 1 1 1 1 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | ID | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// |QR| Opcode |AA|TC|RD|RA| Z | RCODE | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | QDCOUNT | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | ANCOUNT | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | NSCOUNT | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | ARCOUNT | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// +type DNSHeader struct { + ID uint16 +} + +// DNSQuestion format +// https://tools.ietf.org/html/rfc1035#section-4.1.2 +// The question section is used to carry the "question" in most queries, +// i.e., the parameters that define what is being asked. The section +// contains QDCOUNT (usually 1) entries, each of the following format: +// +// 1 1 1 1 1 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | | +// / QNAME / +// / / +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | QTYPE | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | QCLASS | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +type DNSQuestion struct { + QNAME string + QTYPE uint16 + QCLASS uint16 + + Offset int +} + +// DNSRR format +// https://tools.ietf.org/html/rfc1035#section-3.2.1 +// All RRs have the same top level format shown below: +// +// 1 1 1 1 1 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | | +// / / +// / NAME / +// | | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | TYPE | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | CLASS | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | TTL | +// | | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | RDLENGTH | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--| +// / RDATA / +// / / +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +type DNSRR struct { + // NAME string + TYPE uint16 + CLASS uint16 + TTL uint32 + RDLENGTH uint16 + RDATA []byte IP string } @@ -57,12 +142,12 @@ type DNS struct { *Forwarder // as proxy client sDialer Dialer // dialer for server - tunnel bool + Tunnel bool - dnsServer string + DNSServer string - dnsServerMap map[string]string - answerHandlers []DNSAnswerHandler + DNSServerMap map[string]string + AnswerHandlers []DNSAnswerHandler } // NewDNS returns a dns forwarder. client[dns.udp] -> glider[tcp] -> forwarder[dns.tcp] -> remote dns addr @@ -71,10 +156,10 @@ func NewDNS(addr, raddr string, sDialer Dialer, tunnel bool) (*DNS, error) { Forwarder: NewForwarder(addr, nil), sDialer: sDialer, - tunnel: tunnel, + Tunnel: tunnel, - dnsServer: raddr, - dnsServerMap: make(map[string]string), + DNSServer: raddr, + DNSServerMap: make(map[string]string), } return s, nil @@ -90,31 +175,34 @@ func (s *DNS) ListenAndServe() { func (s *DNS) ListenAndServeUDP() { c, err := net.ListenPacket("udp", s.addr) if err != nil { - logf("proxy-dns failed to listen on %s: %v", s.addr, err) + logf("proxy-dns failed to listen on %s, error: %v", s.addr, err) return } defer c.Close() - logf("proxy-dns listening on udp:%s", s.addr) + logf("proxy-dns listening UDP on %s", s.addr) for { - data := make([]byte, DNSUDPMaxLen) - - n, clientAddr, err := c.ReadFrom(data) + b := make([]byte, DNSUDPMaxLen) + n, clientAddr, err := c.ReadFrom(b) if err != nil { - logf("proxy-dns DNS local read error: %v", err) + logf("proxy-dns local read error: %v", err) continue } - data = data[:n] + reqLen := uint16(n) + if reqLen <= DNSHeaderLen { + logf("proxy-dns not enough data") + continue + } + b = b[:n] go func() { - - _, respMsg := s.handleReqMsg(uint16(len(data)), data) - + _, respMsg := s.handleReqMsg(reqLen, b) _, err = c.WriteTo(respMsg, clientAddr) if err != nil { - logf("proxy-dns error in local write: %s\n", err) + logf("proxy-dns error in local write: %s", err) + return } // logf("proxy-dns %s <-> %s, type: %d, %s: %s", clientAddr.String(), dnsServer, query.QueryType, domain, ip) @@ -131,7 +219,7 @@ func (s *DNS) ListenAndServeTCP() { return } - logf("proxy-dns-tcp listening on tcp:%s", s.addr) + logf("proxy-dns-tcp listening TCP on %s", s.addr) for { c, err := l.Accept() @@ -157,6 +245,11 @@ func (s *DNS) ServeTCP(c net.Conn) { return } + if reqLen <= DNSHeaderLen { + logf("proxy-dns not enough data") + return + } + reqMsg := make([]byte, reqLen) _, err := io.ReadFull(c, reqMsg) if err != nil { @@ -165,7 +258,6 @@ func (s *DNS) ServeTCP(c net.Conn) { } respLen, respMsg := s.handleReqMsg(reqLen, reqMsg) - if err := binary.Write(c, binary.BigEndian, respLen); err != nil { logf("proxy-dns-tcp error in local write respLen: %s\n", err) } @@ -178,19 +270,19 @@ func (s *DNS) ServeTCP(c net.Conn) { // handle request msg and return response msg func (s *DNS) handleReqMsg(reqLen uint16, reqMsg []byte) (respLen uint16, respMsg []byte) { - - query, err := parseQuery(reqMsg) + // fmt.Printf("dns req len %d:\n%s\n\n", reqLen, hex.Dump(reqMsg[:])) + query, err := parseQuestion(reqMsg) if err != nil { - logf("proxy-dns-tcp error in parseQuery reqMsg %s", err) + logf("proxy-dns error in parseQuestion reqMsg %s", err) return } - dnsServer := s.GetServer(query.DomainName) - if s.tunnel { - dnsServer = s.dnsServer + dnsServer := s.GetServer(query.QNAME) + if s.Tunnel { + dnsServer = s.DNSServer } - rc, err := s.sDialer.NextDialer(query.DomainName+":53").Dial("tcp", dnsServer) + rc, err := s.sDialer.NextDialer(query.QNAME+":53").Dial("tcp", dnsServer) if err != nil { logf("proxy-dns failed to connect to server %v: %v", dnsServer, err) return @@ -198,21 +290,23 @@ func (s *DNS) handleReqMsg(reqLen uint16, reqMsg []byte) (respLen uint16, respMs defer rc.Close() if err := binary.Write(rc, binary.BigEndian, reqLen); err != nil { - logf("proxy-dns failed to connect to server %v: %v", dnsServer, err) + logf("proxy-dns failed to write req length: %v", err) + return } if err := binary.Write(rc, binary.BigEndian, reqMsg); err != nil { - logf("proxy-dns failed to connect to server %v: %v", dnsServer, err) + logf("proxy-dns failed to write req message: %v", err) + return } if err := binary.Read(rc, binary.BigEndian, &respLen); err != nil { - logf("proxy-dns-tcp failed to read response length: %v", err) + logf("proxy-dns failed to read response length: %v", err) return } respMsg = make([]byte, respLen) _, err = io.ReadFull(rc, respMsg) if err != nil { - logf("proxy-dns-tcp error in read respMsg %s\n", err) + logf("proxy-dns error in read respMsg %s\n", err) return } @@ -220,24 +314,23 @@ func (s *DNS) handleReqMsg(reqLen uint16, reqMsg []byte) (respLen uint16, respMs var ip string if respLen > 0 { - query, err := parseQuery(respMsg) + query, err := parseQuestion(respMsg) if err != nil { - logf("proxy-dns error in parseQuery respMsg %s", err) + logf("proxy-dns error in parseQuestion respMsg %s", err) return } - if (query.QueryType == DNSQueryTypeA || query.QueryType == DNSQueryTypeAAAA) && + if (query.QTYPE == DNSQTypeA || query.QTYPE == DNSQTypeAAAA) && len(respMsg) > query.Offset { answers := parseAnswers(respMsg[query.Offset:]) - for _, answer := range answers { if answer.IP != "" { ip += answer.IP + "," } - for _, h := range s.answerHandlers { - h(query.DomainName, answer.IP) + for _, h := range s.AnswerHandlers { + h(query.QNAME, answer.IP) } } } @@ -249,36 +342,35 @@ func (s *DNS) handleReqMsg(reqLen uint16, reqMsg []byte) (respLen uint16, respMs // SetServer . func (s *DNS) SetServer(domain, server string) { - s.dnsServerMap[domain] = server + s.DNSServerMap[domain] = server } // GetServer . func (s *DNS) GetServer(domain string) string { - domainParts := strings.Split(domain, ".") length := len(domainParts) for i := length - 2; i >= 0; i-- { domain := strings.Join(domainParts[i:length], ".") - if server, ok := s.dnsServerMap[domain]; ok { + if server, ok := s.DNSServerMap[domain]; ok { return server } } - return s.dnsServer + return s.DNSServer } // AddAnswerHandler . func (s *DNS) AddAnswerHandler(h DNSAnswerHandler) { - s.answerHandlers = append(s.answerHandlers, h) + s.AnswerHandlers = append(s.AnswerHandlers, h) } -func parseQuery(p []byte) (*dnsQuery, error) { - q := &dnsQuery{} +func parseQuestion(p []byte) (*DNSQuestion, error) { + q := &DNSQuestion{} var i int var domain []byte - for i = DNSUDPHeaderLen; i < len(p); { + for i = DNSHeaderLen; i < len(p); { l := int(p[i]) if l == 0 { @@ -292,21 +384,21 @@ func parseQuery(p []byte) (*dnsQuery, error) { i = i + l + 1 } - q.DomainName = string(domain[:len(domain)-1]) + q.QNAME = string(domain[:len(domain)-1]) if len(p) < i+4 { - return nil, errors.New("parseQuery error, not enough data") + return nil, errors.New("parseQuestion error, not enough data") } - q.QueryType = binary.BigEndian.Uint16(p[i:]) - q.QueryClass = binary.BigEndian.Uint16(p[i+2:]) + q.QTYPE = binary.BigEndian.Uint16(p[i:]) + q.QCLASS = binary.BigEndian.Uint16(p[i+2:]) q.Offset = i + 4 return q, nil } -func parseAnswers(p []byte) []*dnsAnswer { - var answers []*dnsAnswer +func parseAnswers(p []byte) []*DNSRR { + var answers []*DNSRR for i := 0; i < len(p); { @@ -323,23 +415,23 @@ func parseAnswers(p []byte) []*dnsAnswer { break } - answer := &dnsAnswer{} + answer := &DNSRR{} - answer.QueryType = binary.BigEndian.Uint16(p[i:]) - answer.QueryClass = binary.BigEndian.Uint16(p[i+2:]) + answer.TYPE = binary.BigEndian.Uint16(p[i:]) + answer.CLASS = binary.BigEndian.Uint16(p[i+2:]) answer.TTL = binary.BigEndian.Uint32(p[i+4:]) - answer.DataLength = binary.BigEndian.Uint16(p[i+8:]) - answer.Data = p[i+10 : i+10+int(answer.DataLength)] + answer.RDLENGTH = binary.BigEndian.Uint16(p[i+8:]) + answer.RDATA = p[i+10 : i+10+int(answer.RDLENGTH)] - if answer.QueryType == DNSQueryTypeA { - answer.IP = net.IP(answer.Data[:net.IPv4len]).String() - } else if answer.QueryType == DNSQueryTypeAAAA { - answer.IP = net.IP(answer.Data[:net.IPv6len]).String() + if answer.TYPE == DNSQTypeA { + answer.IP = net.IP(answer.RDATA[:net.IPv4len]).String() + } else if answer.TYPE == DNSQTypeAAAA { + answer.IP = net.IP(answer.RDATA[:net.IPv6len]).String() } answers = append(answers, answer) - i = i + 10 + int(answer.DataLength) + i = i + 10 + int(answer.RDLENGTH) } return answers