2017-08-13 23:51:08 +08:00
// https://tools.ietf.org/html/rfc1035
package main
import (
"encoding/binary"
2017-12-21 10:10:56 +08:00
"io"
2017-08-13 23:51:08 +08:00
"net"
2017-08-16 13:20:12 +08:00
"strings"
2017-08-13 23:51:08 +08:00
)
2017-08-21 23:57:49 +08:00
// DNSUDPHeaderLen is the length of the fixed DNS message header: six 16-bit
// fields (ID, flags, QDCOUNT, ANCOUNT, NSCOUNT, ARCOUNT), RFC 1035 section 4.1.1.
const DNSUDPHeaderLen = 6 * 2
2017-08-13 23:51:08 +08:00
2017-08-21 23:57:49 +08:00
// DNSTCPHeaderLen is the length of the DNS header as it appears on a TCP
// stream: the 2-byte message length prefix (RFC 1035 section 4.2.2) followed
// by the fixed header.
const DNSTCPHeaderLen = DNSUDPHeaderLen + 2
2017-08-13 23:51:08 +08:00
2017-08-21 23:57:49 +08:00
// DNSUDPMaxLen is the maximum size of a DNS message carried over UDP. It is
// also the size of the buffer each incoming UDP request is read into.
//
// RFC 1035 section 4.2.1: messages carried by UDP are restricted to 512 bytes
// (not counting the IP or UDP headers); longer messages are truncated and the
// TC bit is set in the header.
//
// TODO: a client whose request is longer than 512 bytes sends it over TCP
// instead, so TCP requests should be served as well.
const DNSUDPMaxLen = 512
// DNS record TYPE / query QTYPE codes handled by this forwarder.
const (
	// DNSQueryTypeA is the TYPE of an IPv4 host address record (RFC 1035 section 3.2.2).
	DNSQueryTypeA = 1
	// DNSQueryTypeAAAA is the TYPE of an IPv6 host address record (RFC 3596).
	DNSQueryTypeAAAA = 28
)
// dnsQuery is the first entry of a DNS message's question section, as
// decoded by parseQuery.
type dnsQuery struct {
	DomainName string // dotted name without the trailing root dot
	QueryType  uint16 // QTYPE, e.g. DNSQueryTypeA
	QueryClass uint16 // QCLASS
	Offset     int    // index in the message just past this question
}
// dnsAnswer is one resource record decoded by parseAnswers. The record's owner
// name is skipped during parsing and is not stored.
type dnsAnswer struct {
	QueryType  uint16 // TYPE of the record, e.g. DNSQueryTypeA
	QueryClass uint16 // CLASS of the record
	TTL        uint32 // time to live, in seconds
	DataLength uint16 // RDLENGTH: number of bytes in Data
	Data       []byte // RDATA; a sub-slice of the parsed message, not a copy
	IP         string // textual address for A/AAAA records, "" otherwise
}
2017-08-13 23:51:08 +08:00
2017-08-23 16:35:39 +08:00
// DNSAnswerHandler is a callback registered with (*DNS).AddAnswerHandler.
// ListenAndServe calls it with the queried domain name and the IP string of
// an answer record taken from the upstream response.
type DNSAnswerHandler func(domain string, ip string) error
2017-08-20 21:44:18 +08:00
// DNS .
2017-08-16 13:20:12 +08:00
type DNS struct {
2017-08-23 16:35:39 +08:00
* Forwarder // as proxy client
sDialer Dialer // dialer for server
2017-08-16 13:20:12 +08:00
dnsServer string
2017-08-23 16:35:39 +08:00
dnsServerMap map [ string ] string
answerHandlers [ ] DNSAnswerHandler
2017-08-13 23:51:08 +08:00
}
2017-09-10 20:33:35 +08:00
// NewDNS returns a dns forwarder. client[dns.udp] -> glider[tcp] -> forwarder[dns.tcp] -> remote dns addr
2017-08-23 16:35:39 +08:00
func NewDNS ( addr , raddr string , sDialer Dialer ) ( * DNS , error ) {
2017-08-16 13:20:12 +08:00
s := & DNS {
2017-08-23 16:35:39 +08:00
Forwarder : NewForwarder ( addr , nil ) ,
sDialer : sDialer ,
2017-08-16 13:20:12 +08:00
dnsServer : raddr ,
dnsServerMap : make ( map [ string ] string ) ,
2017-08-13 23:51:08 +08:00
}
return s , nil
}
// ListenAndServe .
2017-08-16 13:20:12 +08:00
func ( s * DNS ) ListenAndServe ( ) {
2017-09-06 16:21:27 +08:00
c , err := net . ListenPacket ( "udp" , s . addr )
2017-08-13 23:51:08 +08:00
if err != nil {
logf ( "failed to listen on %s: %v" , s . addr , err )
return
}
2017-09-06 16:21:27 +08:00
defer c . Close ( )
2017-08-13 23:51:08 +08:00
logf ( "listening UDP on %s" , s . addr )
for {
2017-08-21 23:57:49 +08:00
data := make ( [ ] byte , DNSUDPMaxLen )
2017-08-13 23:51:08 +08:00
2017-09-06 16:21:27 +08:00
n , clientAddr , err := c . ReadFrom ( data )
2017-08-13 23:51:08 +08:00
if err != nil {
logf ( "DNS local read error: %v" , err )
continue
}
data = data [ : n ]
go func ( ) {
2017-08-21 23:57:49 +08:00
query := parseQuery ( data )
domain := query . DomainName
2017-08-13 23:51:08 +08:00
2017-08-16 13:20:12 +08:00
dnsServer := s . GetServer ( domain )
2017-08-23 16:35:39 +08:00
2017-08-23 17:45:57 +08:00
rc , err := s . sDialer . NextDialer ( domain + ":53" ) . Dial ( "tcp" , dnsServer )
2017-08-13 23:51:08 +08:00
if err != nil {
2017-08-16 13:20:12 +08:00
logf ( "failed to connect to server %v: %v" , dnsServer , err )
2017-08-13 23:51:08 +08:00
return
}
defer rc . Close ( )
// 2 bytes length after tcp header, before dns message
2017-12-21 10:10:56 +08:00
reqLen := make ( [ ] byte , 2 )
binary . BigEndian . PutUint16 ( reqLen , uint16 ( len ( data ) ) )
rc . Write ( reqLen )
2017-08-13 23:51:08 +08:00
rc . Write ( data )
2017-12-21 10:10:56 +08:00
// fmt.Printf("dns req len %d:\n%s\n\n", reqLen, hex.Dump(data[:]))
var respLen uint16
err = binary . Read ( rc , binary . BigEndian , & respLen )
if err != nil {
2017-12-24 23:13:53 +08:00
logf ( "proxy-dns error in read respLen %s\n" , err )
2017-12-21 10:10:56 +08:00
return
}
respMsg := make ( [ ] byte , respLen )
_ , err = io . ReadFull ( rc , respMsg )
2017-08-13 23:51:08 +08:00
if err != nil {
2017-12-24 23:13:53 +08:00
logf ( "proxy-dns error in read respMsg %s\n" , err )
2017-08-13 23:51:08 +08:00
return
}
2017-12-21 10:10:56 +08:00
// fmt.Printf("dns resp len %d:\n%s\n\n", respLen, hex.Dump(respMsg[:]))
2017-08-21 23:57:49 +08:00
var ip string
2017-08-13 23:51:08 +08:00
// length is not needed in udp dns response. (2 bytes)
// SEE RFC1035, section 4.2.2 TCP: The message is prefixed with a two byte length field which gives the message length, excluding the two byte length field.
2017-12-21 10:10:56 +08:00
if respLen > 0 {
query := parseQuery ( respMsg )
2018-01-05 12:29:06 +08:00
if ( query . QueryType == DNSQueryTypeA || query . QueryType == DNSQueryTypeAAAA ) &&
len ( respMsg ) > query . Offset {
2017-12-21 10:10:56 +08:00
answers := parseAnswers ( respMsg [ query . Offset : ] )
2017-08-21 23:57:49 +08:00
for _ , answer := range answers {
if answer . IP != "" {
ip += answer . IP + ","
}
2017-08-23 16:35:39 +08:00
for _ , h := range s . answerHandlers {
h ( query . DomainName , answer . IP )
}
2017-08-21 23:57:49 +08:00
}
}
2017-12-21 10:10:56 +08:00
_ , err = c . WriteTo ( respMsg , clientAddr )
2017-08-13 23:51:08 +08:00
if err != nil {
2017-12-24 23:13:53 +08:00
logf ( "proxy-dns error in local write: %s\n" , err )
2017-08-13 23:51:08 +08:00
}
}
2018-01-05 12:29:06 +08:00
logf ( "proxy-dns %s <-> %s, type: %d, %s: %s" , clientAddr . String ( ) , dnsServer , query . QueryType , domain , ip )
2017-08-21 23:57:49 +08:00
2017-08-13 23:51:08 +08:00
} ( )
}
}
2017-08-16 13:20:12 +08:00
// SetServer routes queries for domain, and for every name below it, to the
// upstream DNS server at address server.
//
// Domain names compare case-insensitively (RFC 4343) and may be written with
// or without the trailing root dot. The key is therefore stored in canonical
// form: lower case, with no trailing dot. GetServer canonicalizes its input
// the same way.
func (s *DNS) SetServer(domain, server string) {
	s.dnsServerMap[strings.ToLower(strings.TrimSuffix(domain, "."))] = server
}

// GetServer returns the upstream DNS server for domain.
//
// It tries the suffixes of domain from shortest to longest, starting with the
// last two labels. For "a.b.example.com" it checks "example.com", then
// "b.example.com", then "a.b.example.com", and returns the server of the first
// suffix registered with SetServer. A single top-level label is never matched
// on its own. If nothing matches, the default server s.dnsServer is returned.
func (s *DNS) GetServer(domain string) string {
	labels := strings.Split(strings.ToLower(strings.TrimSuffix(domain, ".")), ".")
	n := len(labels)
	for i := n - 2; i >= 0; i-- {
		if server, ok := s.dnsServerMap[strings.Join(labels[i:], ".")]; ok {
			return server
		}
	}
	return s.dnsServer
}
2017-08-23 16:35:39 +08:00
// AddAnswerHandler appends h to the list of answer callbacks. Handlers are
// kept, and later invoked, in registration order.
func (s *DNS) AddAnswerHandler(h DNSAnswerHandler) {
	updated := append(s.answerHandlers, h)
	s.answerHandlers = updated
}
2017-08-21 23:57:49 +08:00
func parseQuery ( p [ ] byte ) * dnsQuery {
q := & dnsQuery { }
2017-08-13 23:51:08 +08:00
2017-08-21 23:57:49 +08:00
var i int
var domain [ ] byte
for i = DNSUDPHeaderLen ; i < len ( p ) ; {
2017-08-13 23:51:08 +08:00
l := int ( p [ i ] )
if l == 0 {
2017-08-21 23:57:49 +08:00
i ++
2017-08-13 23:51:08 +08:00
break
}
2017-08-21 23:57:49 +08:00
domain = append ( domain , p [ i + 1 : i + l + 1 ] ... )
domain = append ( domain , '.' )
2017-08-13 23:51:08 +08:00
i = i + l + 1
}
2017-08-21 23:57:49 +08:00
q . DomainName = string ( domain [ : len ( domain ) - 1 ] )
q . QueryType = binary . BigEndian . Uint16 ( p [ i : ] )
q . QueryClass = binary . BigEndian . Uint16 ( p [ i + 2 : ] )
q . Offset = i + 4
return q
}
func parseAnswers ( p [ ] byte ) [ ] * dnsAnswer {
var answers [ ] * dnsAnswer
for i := 0 ; i < len ( p ) ; {
l := int ( p [ i ] )
if l == 0 {
i ++
break
}
answer := & dnsAnswer { }
2018-01-05 12:29:06 +08:00
// https://tools.ietf.org/html/rfc1035#section-4.1.4
// i+2 assumes the ANSWER always using "Message compression", start with 2 bytes offset of the query domain.
// TODO: check here
2017-08-21 23:57:49 +08:00
answer . QueryType = binary . BigEndian . Uint16 ( p [ i + 2 : ] )
answer . QueryClass = binary . BigEndian . Uint16 ( p [ i + 4 : ] )
answer . TTL = binary . BigEndian . Uint32 ( p [ i + 6 : ] )
answer . DataLength = binary . BigEndian . Uint16 ( p [ i + 10 : ] )
answer . Data = p [ i + 12 : i + 12 + int ( answer . DataLength ) ]
if answer . QueryType == DNSQueryTypeA {
answer . IP = net . IP ( answer . Data [ : net . IPv4len ] ) . String ( )
} else if answer . QueryType == DNSQueryTypeAAAA {
answer . IP = net . IP ( answer . Data [ : net . IPv6len ] ) . String ( )
}
answers = append ( answers , answer )
i = i + 12 + int ( answer . DataLength )
}
return answers
2017-08-13 23:51:08 +08:00
}