diff --git a/dns/cache.go b/dns/cache.go index 50b371a..03be604 100644 --- a/dns/cache.go +++ b/dns/cache.go @@ -42,20 +42,18 @@ func (c *Cache) Len() int { return len(c.m) } -// Put an item into cache, invalid after ttl seconds, never invalid if ttl=0 +// Put an item into cache, invalid after ttl seconds func (c *Cache) Put(k string, v []byte, ttl int) { - if len(v) == 0 { - return + if len(v) != 0 { + c.l.Lock() + it, ok := c.m[k] + if !ok { + it = &item{value: v} + c.m[k] = it + } + it.expire = time.Now().Add(time.Duration(ttl) * time.Second) + c.l.Unlock() } - - c.l.Lock() - it, ok := c.m[k] - if !ok { - it = &item{value: v} - c.m[k] = it - } - it.expire = time.Now().Add(time.Duration(ttl) * time.Second) - c.l.Unlock() } // Get an item from cache diff --git a/dns/client.go b/dns/client.go index f87da5a..6780549 100644 --- a/dns/client.go +++ b/dns/client.go @@ -41,7 +41,7 @@ func NewClient(dialer proxy.Dialer, upServers []string) (*Client, error) { // Exchange handles request msg and returns response msg // reqBytes = reqLen + reqMsg -func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) { +func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([]byte, error) { req, err := UnmarshalMessage(reqBytes[2:]) if err != nil { return nil, err @@ -58,37 +58,14 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) { } } - dnsServer := c.GetServer(req.Question.QNAME) - rc, err := c.dialer.NextDialer(req.Question.QNAME+":53").Dial("tcp", dnsServer) + dnsServer, network, respBytes, err := c.exchange(req.Question.QNAME, reqBytes, preferTCP) if err != nil { - log.F("[dns] failed to connect to server %v: %v", dnsServer, err) - return nil, err - } - defer rc.Close() - - if err = binary.Write(rc, binary.BigEndian, reqBytes); err != nil { - log.F("[dns] failed to write req message: %v", err) - return nil, err - } - - var respLen uint16 - if err = binary.Read(rc, binary.BigEndian, &respLen); err != nil { - log.F("[dns] failed to read response length: %v", err) - return nil, err - } - - respBytes := make([]byte, respLen+2) - binary.BigEndian.PutUint16(respBytes[:2], respLen) - - _, err = io.ReadFull(rc, respBytes[2:]) - if err != nil { - log.F("[dns] error in read respMsg %s\n", err) return nil, err } if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA { - log.F("[dns] %s <-> %s, type: %d, %s", - clientAddr, dnsServer, req.Question.QTYPE, req.Question.QNAME) + log.F("[dns] %s <-> %s(%s), type: %d, %s", + clientAddr, dnsServer, network, req.Question.QTYPE, req.Question.QNAME) return respBytes, nil } @@ -104,28 +81,101 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) { for _, h := range c.handlers { h(resp.Question.QNAME, answer.IP) } - if answer.IP != "" { ips = append(ips, answer.IP) } - if answer.TTL != 0 { ttl = int(answer.TTL) } - } - } // add to cache - c.cache.Put(getKey(resp.Question), respBytes, ttl) + if len(ips) != 0 { + c.cache.Put(getKey(resp.Question), respBytes, ttl) + } - log.F("[dns] %s <-> %s, type: %d, %s: %s", - clientAddr, dnsServer, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ",")) + log.F("[dns] %s <-> %s(%s), type: %d, %s: %s", + clientAddr, dnsServer, network, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ",")) return respBytes, nil } +// exchange choose a upstream dns server based on qname, communicate with it on the network +func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (server, network string, respBytes []byte, err error) { + // use tcp to connect upstream server default + network = "tcp" + + dialer := c.dialer.NextDialer(qname + ":53") + + // If client uses udp and no forwarders specified, use udp + if !preferTCP && dialer.Addr() == "DIRECT" { + network = "udp" + } + + server = c.GetServer(qname) + rc, err := dialer.Dial(network, server) + if err != nil { + log.F("[dns] failed to connect to server %v: %v", server, err) + return + } + + switch network { + case "tcp": + respBytes, err = c.exchangeTCP(rc, reqBytes) + case "udp": + respBytes, err = c.exchangeUDP(rc, reqBytes) + } + + return +} + +// exchangeTCP exchange with server over tcp +func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) { + defer rc.Close() + + if _, err := rc.Write(reqBytes); err != nil { + log.F("[dns] failed to write req message: %v", err) + return nil, err + } + + var respLen uint16 + if err := binary.Read(rc, binary.BigEndian, &respLen); err != nil { + log.F("[dns] failed to read response length: %v", err) + return nil, err + } + + respBytes := make([]byte, respLen+2) + binary.BigEndian.PutUint16(respBytes[:2], respLen) + + _, err := io.ReadFull(rc, respBytes[2:]) + if err != nil { + log.F("[dns] error in read respMsg %s\n", err) + return nil, err + } + + return respBytes, nil +} + +// exchangeUDP exchange with server over udp +func (c *Client) exchangeUDP(rc net.Conn, reqBytes []byte) ([]byte, error) { + defer rc.Close() + + if _, err := rc.Write(reqBytes[2:]); err != nil { + log.F("[dns] failed to write req message: %v", err) + return nil, err + } + + reqBytes = make([]byte, 2+UDPMaxLen) + n, err := rc.Read(reqBytes[2:]) + if err != nil { + return nil, err + } + binary.BigEndian.PutUint16(reqBytes[:2], uint16(n)) + + return reqBytes[:2+n], nil +} + // SetServer . func (c *Client) SetServer(domain string, servers ...string) { c.upServerMap[domain] = append(c.upServerMap[domain], servers...) @@ -193,7 +243,7 @@ func (c *Client) GenResponse(domain string, ip string) (*Message, error) { m := NewMessage(0, Response) m.SetQuestion(NewQuestion(qtype, domain)) - rr := &RR{NAME: domain, TYPE: qtype, CLASS: CLASSIN, + rr := &RR{NAME: domain, TYPE: qtype, CLASS: ClassINET, RDLENGTH: rdlen, RDATA: rdata} m.AddAnswer(rr) diff --git a/dns/message.go b/dns/message.go index 8730863..c583c2c 100644 --- a/dns/message.go +++ b/dns/message.go @@ -31,8 +31,8 @@ const ( QTypeAAAA uint16 = 28 ///ipv6 ) -// CLASSIN . -const CLASSIN uint16 = 1 +// ClassINET . +const ClassINET uint16 = 1 // Message format // https://tools.ietf.org/html/rfc1035#section-4.1 @@ -51,7 +51,7 @@ const CLASSIN uint16 = 1 // +---------------------+ // | Additional | RRs holding additional information type Message struct { - *Header + Header // most dns implementation only support 1 question Question *Question Answers []*RR @@ -68,10 +68,10 @@ func NewMessage(id uint16, msgType int) *Message { id = uint16(rand.Uint32()) } - h := &Header{ID: id} - h.SetMsgType(msgType) + m := &Message{Header: Header{ID: id}} + m.SetMsgType(msgType) - return &Message{Header: h} + return m } // SetQuestion sets a question to dns message, @@ -119,38 +119,35 @@ func (m *Message) Marshal() ([]byte, error) { // UnmarshalMessage unmarshals []bytes to Message func UnmarshalMessage(b []byte) (*Message, error) { - msg := &Message{Header: &Header{}} - msg.unMarshaled = b - - err := UnmarshalHeader(b[:HeaderLen], msg.Header) + m := &Message{unMarshaled: b} + err := UnmarshalHeader(b[:HeaderLen], &m.Header) if err != nil { return nil, err } q := &Question{} - qLen, err := msg.UnmarshalQuestion(b[HeaderLen:], q) + qLen, err := m.UnmarshalQuestion(b[HeaderLen:], q) if err != nil { return nil, err } - - msg.SetQuestion(q) + m.SetQuestion(q) // resp answers - rridx := HeaderLen + qLen - for i := 0; i < int(msg.Header.ANCOUNT); i++ { + rrIdx := HeaderLen + qLen + for i := 0; i < int(m.Header.ANCOUNT); i++ { rr := &RR{} - rrLen, err := msg.UnmarshalRR(rridx, rr) + rrLen, err := m.UnmarshalRR(rrIdx, rr) if err != nil { return nil, err } - msg.AddAnswer(rr) + m.AddAnswer(rr) - rridx += rrLen + rrIdx += rrLen } - msg.Header.SetAncount(len(msg.Answers)) + m.Header.SetAncount(len(m.Answers)) - return msg, nil + return m, nil } // Header format @@ -262,7 +259,7 @@ func NewQuestion(qtype uint16, domain string) *Question { return &Question{ QNAME: domain, QTYPE: qtype, - QCLASS: CLASSIN, + QCLASS: ClassINET, } } @@ -428,7 +425,9 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int, error) { if size == 0 { idx++ break - } else if size > 63 { + } + + if size > 63 { return "", 0, errors.New("UnmarshalDomain: label size larger than 63") } diff --git a/dns/server.go b/dns/server.go index 77717ed..1b6a998 100644 --- a/dns/server.go +++ b/dns/server.go @@ -61,11 +61,10 @@ func (s *Server) ListenAndServeUDP() { log.F("[dns] not enough message data") continue } - binary.BigEndian.PutUint16(reqBytes[:2], reqLen) go func() { - respBytes, err := s.Client.Exchange(reqBytes[:2+n], caddr.String()) + respBytes, err := s.Client.Exchange(reqBytes[:2+n], caddr.String(), false) if err != nil { log.F("[dns] error in exchange: %s", err) return @@ -123,7 +122,7 @@ func (s *Server) ServeTCP(c net.Conn) { binary.BigEndian.PutUint16(reqBytes[:2], reqLen) - respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String()) + respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String(), true) if err != nil { log.F("[dns]-tcp error in exchange: %s", err) return