dns: query in udp when client requests in udp and no forwarder specified

This commit is contained in:
nadoo 2018-08-02 00:11:22 +08:00
parent 8d20331096
commit 9acaff5b4a
4 changed files with 118 additions and 72 deletions

View File

@ -42,20 +42,18 @@ func (c *Cache) Len() int {
return len(c.m)
}
// Put an item into cache, invalid after ttl seconds, never invalid if ttl=0
// Put an item into cache, invalid after ttl seconds
func (c *Cache) Put(k string, v []byte, ttl int) {
if len(v) == 0 {
return
if len(v) != 0 {
c.l.Lock()
it, ok := c.m[k]
if !ok {
it = &item{value: v}
c.m[k] = it
}
it.expire = time.Now().Add(time.Duration(ttl) * time.Second)
c.l.Unlock()
}
c.l.Lock()
it, ok := c.m[k]
if !ok {
it = &item{value: v}
c.m[k] = it
}
it.expire = time.Now().Add(time.Duration(ttl) * time.Second)
c.l.Unlock()
}
// Get an item from cache

View File

@ -41,7 +41,7 @@ func NewClient(dialer proxy.Dialer, upServers []string) (*Client, error) {
// Exchange handles request msg and returns response msg
// reqBytes = reqLen + reqMsg
func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) {
func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([]byte, error) {
req, err := UnmarshalMessage(reqBytes[2:])
if err != nil {
return nil, err
@ -58,37 +58,14 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) {
}
}
dnsServer := c.GetServer(req.Question.QNAME)
rc, err := c.dialer.NextDialer(req.Question.QNAME+":53").Dial("tcp", dnsServer)
dnsServer, network, respBytes, err := c.exchange(req.Question.QNAME, reqBytes, preferTCP)
if err != nil {
log.F("[dns] failed to connect to server %v: %v", dnsServer, err)
return nil, err
}
defer rc.Close()
if err = binary.Write(rc, binary.BigEndian, reqBytes); err != nil {
log.F("[dns] failed to write req message: %v", err)
return nil, err
}
var respLen uint16
if err = binary.Read(rc, binary.BigEndian, &respLen); err != nil {
log.F("[dns] failed to read response length: %v", err)
return nil, err
}
respBytes := make([]byte, respLen+2)
binary.BigEndian.PutUint16(respBytes[:2], respLen)
_, err = io.ReadFull(rc, respBytes[2:])
if err != nil {
log.F("[dns] error in read respMsg %s\n", err)
return nil, err
}
if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA {
log.F("[dns] %s <-> %s, type: %d, %s",
clientAddr, dnsServer, req.Question.QTYPE, req.Question.QNAME)
log.F("[dns] %s <-> %s(%s), type: %d, %s",
clientAddr, dnsServer, network, req.Question.QTYPE, req.Question.QNAME)
return respBytes, nil
}
@ -104,28 +81,101 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) {
for _, h := range c.handlers {
h(resp.Question.QNAME, answer.IP)
}
if answer.IP != "" {
ips = append(ips, answer.IP)
}
if answer.TTL != 0 {
ttl = int(answer.TTL)
}
}
}
// add to cache
c.cache.Put(getKey(resp.Question), respBytes, ttl)
if len(ips) != 0 {
c.cache.Put(getKey(resp.Question), respBytes, ttl)
}
log.F("[dns] %s <-> %s, type: %d, %s: %s",
clientAddr, dnsServer, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
log.F("[dns] %s <-> %s(%s), type: %d, %s: %s",
clientAddr, dnsServer, network, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
return respBytes, nil
}
// exchange choose a upstream dns server based on qname, communicate with it on the network
func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (server, network string, respBytes []byte, err error) {
// use tcp to connect upstream server default
network = "tcp"
dialer := c.dialer.NextDialer(qname + ":53")
// If client uses udp and no forwarders specified, use udp
if !preferTCP && dialer.Addr() == "DIRECT" {
network = "udp"
}
server = c.GetServer(qname)
rc, err := dialer.Dial(network, server)
if err != nil {
log.F("[dns] failed to connect to server %v: %v", server, err)
return
}
switch network {
case "tcp":
respBytes, err = c.exchangeTCP(rc, reqBytes)
case "udp":
respBytes, err = c.exchangeUDP(rc, reqBytes)
}
return
}
// exchangeTCP exchange with server over tcp
func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) {
defer rc.Close()
if _, err := rc.Write(reqBytes); err != nil {
log.F("[dns] failed to write req message: %v", err)
return nil, err
}
var respLen uint16
if err := binary.Read(rc, binary.BigEndian, &respLen); err != nil {
log.F("[dns] failed to read response length: %v", err)
return nil, err
}
respBytes := make([]byte, respLen+2)
binary.BigEndian.PutUint16(respBytes[:2], respLen)
_, err := io.ReadFull(rc, respBytes[2:])
if err != nil {
log.F("[dns] error in read respMsg %s\n", err)
return nil, err
}
return respBytes, nil
}
// exchangeUDP exchange with server over udp
func (c *Client) exchangeUDP(rc net.Conn, reqBytes []byte) ([]byte, error) {
defer rc.Close()
if _, err := rc.Write(reqBytes[2:]); err != nil {
log.F("[dns] failed to write req message: %v", err)
return nil, err
}
reqBytes = make([]byte, 2+UDPMaxLen)
n, err := rc.Read(reqBytes[2:])
if err != nil {
return nil, err
}
binary.BigEndian.PutUint16(reqBytes[:2], uint16(n))
return reqBytes[:2+n], nil
}
// SetServer .
func (c *Client) SetServer(domain string, servers ...string) {
c.upServerMap[domain] = append(c.upServerMap[domain], servers...)
@ -193,7 +243,7 @@ func (c *Client) GenResponse(domain string, ip string) (*Message, error) {
m := NewMessage(0, Response)
m.SetQuestion(NewQuestion(qtype, domain))
rr := &RR{NAME: domain, TYPE: qtype, CLASS: CLASSIN,
rr := &RR{NAME: domain, TYPE: qtype, CLASS: ClassINET,
RDLENGTH: rdlen, RDATA: rdata}
m.AddAnswer(rr)

View File

@ -31,8 +31,8 @@ const (
QTypeAAAA uint16 = 28 ///ipv6
)
// CLASSIN .
const CLASSIN uint16 = 1
// ClassINET .
const ClassINET uint16 = 1
// Message format
// https://tools.ietf.org/html/rfc1035#section-4.1
@ -51,7 +51,7 @@ const CLASSIN uint16 = 1
// +---------------------+
// | Additional | RRs holding additional information
type Message struct {
*Header
Header
// most dns implementation only support 1 question
Question *Question
Answers []*RR
@ -68,10 +68,10 @@ func NewMessage(id uint16, msgType int) *Message {
id = uint16(rand.Uint32())
}
h := &Header{ID: id}
h.SetMsgType(msgType)
m := &Message{Header: Header{ID: id}}
m.SetMsgType(msgType)
return &Message{Header: h}
return m
}
// SetQuestion sets a question to dns message,
@ -119,38 +119,35 @@ func (m *Message) Marshal() ([]byte, error) {
// UnmarshalMessage unmarshals []bytes to Message
func UnmarshalMessage(b []byte) (*Message, error) {
msg := &Message{Header: &Header{}}
msg.unMarshaled = b
err := UnmarshalHeader(b[:HeaderLen], msg.Header)
m := &Message{unMarshaled: b}
err := UnmarshalHeader(b[:HeaderLen], &m.Header)
if err != nil {
return nil, err
}
q := &Question{}
qLen, err := msg.UnmarshalQuestion(b[HeaderLen:], q)
qLen, err := m.UnmarshalQuestion(b[HeaderLen:], q)
if err != nil {
return nil, err
}
msg.SetQuestion(q)
m.SetQuestion(q)
// resp answers
rridx := HeaderLen + qLen
for i := 0; i < int(msg.Header.ANCOUNT); i++ {
rrIdx := HeaderLen + qLen
for i := 0; i < int(m.Header.ANCOUNT); i++ {
rr := &RR{}
rrLen, err := msg.UnmarshalRR(rridx, rr)
rrLen, err := m.UnmarshalRR(rrIdx, rr)
if err != nil {
return nil, err
}
msg.AddAnswer(rr)
m.AddAnswer(rr)
rridx += rrLen
rrIdx += rrLen
}
msg.Header.SetAncount(len(msg.Answers))
m.Header.SetAncount(len(m.Answers))
return msg, nil
return m, nil
}
// Header format
@ -262,7 +259,7 @@ func NewQuestion(qtype uint16, domain string) *Question {
return &Question{
QNAME: domain,
QTYPE: qtype,
QCLASS: CLASSIN,
QCLASS: ClassINET,
}
}
@ -428,7 +425,9 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int, error) {
if size == 0 {
idx++
break
} else if size > 63 {
}
if size > 63 {
return "", 0, errors.New("UnmarshalDomain: label size larger than 63")
}

View File

@ -61,11 +61,10 @@ func (s *Server) ListenAndServeUDP() {
log.F("[dns] not enough message data")
continue
}
binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
go func() {
respBytes, err := s.Client.Exchange(reqBytes[:2+n], caddr.String())
respBytes, err := s.Client.Exchange(reqBytes[:2+n], caddr.String(), false)
if err != nil {
log.F("[dns] error in exchange: %s", err)
return
@ -123,7 +122,7 @@ func (s *Server) ServeTCP(c net.Conn) {
binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String())
respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String(), true)
if err != nil {
log.F("[dns]-tcp error in exchange: %s", err)
return