From 4781e7b4725c230a6d844a3d5b29e629dcfceb0b Mon Sep 17 00:00:00 2001 From: nadoo <287492+nadoo@users.noreply.github.com> Date: Tue, 31 Jul 2018 00:03:36 +0800 Subject: [PATCH] dns: add tcp server support --- dns/message.go | 75 ++++++++++++++++++++++++++++++++++++++------------ dns/server.go | 64 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 18 deletions(-) diff --git a/dns/message.go b/dns/message.go index 67ec68b..e9b0b6d 100644 --- a/dns/message.go +++ b/dns/message.go @@ -48,7 +48,6 @@ const ( // +---------------------+ // | Additional | RRs holding additional information type Message struct { - // all dns messages should start with a 12 byte dns header *Header // most dns implementation only support 1 question Question *Question @@ -99,13 +98,13 @@ func (m *Message) Marshal() ([]byte, error) { } buf.Write(b) - // for _, answer := range m.Answers { - // b, err := answer.Marshal() - // if err != nil { - // return nil, err - // } - // buf.Write(b) - // } + for _, answer := range m.Answers { + b, err := answer.Marshal() + if err != nil { + return nil, err + } + buf.Write(b) + } return buf.Bytes(), nil } @@ -191,6 +190,11 @@ func (h *Header) SetQR(qr int) { h.Bits |= uint16(qr) << 15 } +// SetTC . +func (h *Header) SetTC(tc int) { + h.Bits |= uint16(tc) << 9 +} + // SetQdcount sets query count, most dns servers only support 1 query per request func (h *Header) SetQdcount(qdcount int) { h.QDCOUNT = uint16(qdcount) @@ -282,7 +286,11 @@ func (m *Message) UnmarshalQuestion(b []byte, q *Question) (n int, err error) { return 0, errors.New("unmarshal question must not be nil") } - domain, idx := m.UnmarshalDomain(b) + domain, idx, err := m.UnmarshalDomain(b) + if err != nil { + return 0, err + } + q.QNAME = domain q.QTYPE = binary.BigEndian.Uint16(b[idx : idx+2]) q.QCLASS = binary.BigEndian.Uint16(b[idx+2 : idx+4]) @@ -335,19 +343,36 @@ func NewRR() *RR { return rr } +// Marshal marshals RR struct to []byte +func (rr *RR) Marshal() ([]byte, error) { + var buf bytes.Buffer + + buf.Write(MarshalDomain(rr.NAME)) + binary.Write(&buf, binary.BigEndian, rr.TYPE) + binary.Write(&buf, binary.BigEndian, rr.CLASS) + binary.Write(&buf, binary.BigEndian, rr.TTL) + binary.Write(&buf, binary.BigEndian, rr.RDLENGTH) + buf.Write(rr.RDATA) + + return buf.Bytes(), nil +} + // UnmarshalRR unmarshals []bytes to RR func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) { if rr == nil { - return 0, errors.New("unmarshal question must not be nil") + return 0, errors.New("unmarshal rr must not be nil") } p := m.unMarshaled[start:] - domain, n := m.UnmarshalDomain(p) + domain, n, err := m.UnmarshalDomain(p) + if err != nil { + return 0, err + } rr.NAME = domain if len(p) <= n+10 { - return 0, errors.New("not enough data") + return 0, errors.New("UnmarshalRR: not enough data") } rr.TYPE = binary.BigEndian.Uint16(p[n:]) @@ -381,7 +406,7 @@ func MarshalDomain(domain string) []byte { } // UnmarshalDomain gets domain from bytes -func (m *Message) UnmarshalDomain(b []byte) (string, int) { +func (m *Message) UnmarshalDomain(b []byte) (string, int, error) { var idx, size int var labels = []string{} @@ -393,7 +418,11 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int) { // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ if b[idx]&0xC0 == 0xC0 { offset := binary.BigEndian.Uint16(b[idx : idx+2]) - lable := m.UnmarshalDomainPoint(int(offset & 0x3FFF)) + lable, err := m.UnmarshalDomainPoint(int(offset & 0x3FFF)) + if err != nil { + return "", 0, err + } + labels = append(labels, lable) idx += 2 break @@ -402,18 +431,28 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int) { if size == 0 { idx++ break + } else if size > 63 { + return "", 0, errors.New("UnmarshalDomain: label length more than 63") } + + if idx+size+1 > len(b) { + return "", 0, errors.New("UnmarshalDomain: label length more than 63") + } + labels = append(labels, string(b[idx+1:idx+size+1])) idx += (size + 1) } } domain := strings.Join(labels, ".") - return domain, idx + return domain, idx, nil } // UnmarshalDomainPoint gets domain from offset point -func (m *Message) UnmarshalDomainPoint(offset int) string { - domain, _ := m.UnmarshalDomain(m.unMarshaled[offset:]) - return domain +func (m *Message) UnmarshalDomainPoint(offset int) (string, error) { + if offset > len(m.unMarshaled) { + return "", errors.New("UnmarshalDomainPoint: offset larger than msg length") + } + domain, _, err := m.UnmarshalDomain(m.unMarshaled[offset:]) + return domain, err } diff --git a/dns/server.go b/dns/server.go index 4f824bf..0806c0b 100644 --- a/dns/server.go +++ b/dns/server.go @@ -2,12 +2,17 @@ package dns import ( "encoding/binary" + "io" "net" + "time" "github.com/nadoo/glider/common/log" "github.com/nadoo/glider/proxy" ) +// conn timeout, seconds +const timeout = 30 + // Server is a dns server struct type Server struct { addr string @@ -28,6 +33,12 @@ func NewServer(addr string, dialer proxy.Dialer, upServers ...string) (*Server, // ListenAndServe . func (s *Server) ListenAndServe() { + go s.ListenAndServeTCP() + s.ListenAndServeUDP() +} + +// ListenAndServeUDP . +func (s *Server) ListenAndServeUDP() { c, err := net.ListenPacket("udp", s.addr) if err != nil { log.F("[dns] failed to listen on %s, error: %v", s.addr, err) @@ -70,3 +81,56 @@ func (s *Server) ListenAndServe() { } } + +// ListenAndServeTCP . +func (s *Server) ListenAndServeTCP() { + l, err := net.Listen("tcp", s.addr) + if err != nil { + log.F("[dns]-tcp error: %v", err) + return + } + + log.F("[dns]-tcp listening TCP on %s", s.addr) + + for { + c, err := l.Accept() + if err != nil { + log.F("[dns]-tcp error: failed to accept: %v", err) + continue + } + go s.ServeTCP(c) + } +} + +// ServeTCP . +func (s *Server) ServeTCP(c net.Conn) { + defer c.Close() + + c.SetDeadline(time.Now().Add(time.Duration(timeout) * time.Second)) + + var reqLen uint16 + if err := binary.Read(c, binary.BigEndian, &reqLen); err != nil { + log.F("[dns]-tcp failed to get request length: %v", err) + return + } + + reqBytes := make([]byte, reqLen+2) + _, err := io.ReadFull(c, reqBytes[2:]) + if err != nil { + log.F("[dns]-tcp error in read reqBytes %s", err) + return + } + + binary.BigEndian.PutUint16(reqBytes[:2], reqLen) + + respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String()) + if err != nil { + log.F("[dns]-tcp error in exchange: %s", err) + return + } + + if err := binary.Write(c, binary.BigEndian, respBytes); err != nil { + log.F("[dns]-tcp error in local write respBytes: %s", err) + return + } +}