diff --git a/dns.go b/dns.go index 1d689f9..b47c861 100644 --- a/dns.go +++ b/dns.go @@ -9,20 +9,44 @@ import ( "strings" ) -// UDPDNSHeaderLen is the length of UDP dns msg header -const UDPDNSHeaderLen = 12 +// DNSUDPHeaderLen is the length of UDP dns msg header +const DNSUDPHeaderLen = 12 -// TCPDNSHEADERLen is the length of TCP dns msg header -const TCPDNSHEADERLen = 2 + UDPDNSHeaderLen +// DNSTCPHeaderLen is the length of TCP dns msg header +const DNSTCPHeaderLen = 2 + DNSUDPHeaderLen -// MaxUDPDNSLen is the max size of udp dns request. +// DNSUDPMaxLen is the max size of udp dns request. // https://tools.ietf.org/html/rfc1035#section-4.2.1 // Messages carried by UDP are restricted to 512 bytes (not counting the IP // or UDP headers). Longer messages are truncated and the TC bit is set in // the header. // TODO: If the request length > 512 then the client will send TCP packets instead, // so we should also serve tcp requests. -const MaxUDPDNSLen = 512 +const DNSUDPMaxLen = 512 + +// DNSQueryTypeA ipv4 +const DNSQueryTypeA = 1 + +// DNSQueryTypeAAAA ipv6 +const DNSQueryTypeAAAA = 28 + +type dnsQuery struct { + DomainName string + QueryType uint16 + QueryClass uint16 + Offset int +} + +type dnsAnswer struct { + DomainName string + QueryType uint16 + QueryClass uint16 + TTL uint32 + DataLength uint16 + Data []byte + + IP string +} // DNS . type DNS struct { @@ -55,7 +79,7 @@ func (s *DNS) ListenAndServe() { logf("listening UDP on %s", s.addr) for { - data := make([]byte, MaxUDPDNSLen) + data := make([]byte, DNSUDPMaxLen) n, clientAddr, err := l.ReadFrom(data) if err != nil { @@ -67,7 +91,8 @@ func (s *DNS) ListenAndServe() { go func() { // TODO: check domain rules and get a proper upstream name server. - domain := string(getDomain(data)) + query := parseQuery(data) + domain := query.DomainName dnsServer := s.GetServer(domain) // TODO: check here @@ -78,8 +103,6 @@ func (s *DNS) ListenAndServe() { } defer rc.Close() - logf("proxy-dns %s, %s <-> %s", domain, clientAddr.String(), dnsServer) - // 2 bytes length after tcp header, before dns message length := make([]byte, 2) binary.BigEndian.PutUint16(length, uint16(len(data))) @@ -92,16 +115,32 @@ func (s *DNS) ListenAndServe() { return } + var ip string // length is not needed in udp dns response. (2 bytes) // SEE RFC1035, section 4.2.2 TCP: The message is prefixed with a two byte length field which gives the message length, excluding the two byte length field. if len(resp) > 2 { msg := resp[2:] + // TODO: Get IP from response, check and add to ipset + query := parseQuery(msg) + if len(msg) > query.Offset { + answers := parseAnswers(msg[query.Offset:]) + for _, answer := range answers { + if answer.IP != "" { + ip += answer.IP + "," + } + + } + + } + _, err = l.WriteTo(msg, clientAddr) if err != nil { logf("error in local write: %s\n", err) } } + logf("proxy-dns %s, %s <-> %s, ip: %s", domain, clientAddr.String(), dnsServer, ip) + }() } } @@ -127,26 +166,61 @@ func (s *DNS) GetServer(domain string) string { return s.dnsServer } -// getDomain from dns request playload, return []byte like: -// []byte{'w', 'w', 'w', '.', 'm', 's', 'n', '.', 'c', 'o', 'm', '.'} -// []byte("www.msn.com.") -func getDomain(p []byte) []byte { - var ret []byte +func parseQuery(p []byte) *dnsQuery { + q := &dnsQuery{} - for i := UDPDNSHeaderLen; i < len(p); { + var i int + var domain []byte + for i = DNSUDPHeaderLen; i < len(p); { l := int(p[i]) if l == 0 { + i++ break } - ret = append(ret, p[i+1:i+l+1]...) - ret = append(ret, '.') + domain = append(domain, p[i+1:i+l+1]...) + domain = append(domain, '.') i = i + l + 1 } - // TODO: check here - // domain name could not be null, so the length of ret always >= 1? - return ret[:len(ret)-1] + q.DomainName = string(domain[:len(domain)-1]) + q.QueryType = binary.BigEndian.Uint16(p[i:]) + q.QueryClass = binary.BigEndian.Uint16(p[i+2:]) + q.Offset = i + 4 + + return q +} + +func parseAnswers(p []byte) []*dnsAnswer { + var answers []*dnsAnswer + + for i := 0; i < len(p); { + l := int(p[i]) + + if l == 0 { + i++ + break + } + + answer := &dnsAnswer{} + answer.QueryType = binary.BigEndian.Uint16(p[i+2:]) + answer.QueryClass = binary.BigEndian.Uint16(p[i+4:]) + answer.TTL = binary.BigEndian.Uint32(p[i+6:]) + answer.DataLength = binary.BigEndian.Uint16(p[i+10:]) + answer.Data = p[i+12 : i+12+int(answer.DataLength)] + + if answer.QueryType == DNSQueryTypeA { + answer.IP = net.IP(answer.Data[:net.IPv4len]).String() + } else if answer.QueryType == DNSQueryTypeAAAA { + answer.IP = net.IP(answer.Data[:net.IPv6len]).String() + } + + answers = append(answers, answer) + + i = i + 12 + int(answer.DataLength) + } + + return answers }