dns: add tcp server support

This commit is contained in:
nadoo 2018-07-31 00:03:36 +08:00
parent 41ddbb1168
commit 4781e7b472
2 changed files with 121 additions and 18 deletions

View File

@ -48,7 +48,6 @@ const (
// +---------------------+ // +---------------------+
// | Additional | RRs holding additional information // | Additional | RRs holding additional information
type Message struct { type Message struct {
// all dns messages should start with a 12 byte dns header
*Header *Header
// most dns implementation only support 1 question // most dns implementation only support 1 question
Question *Question Question *Question
@ -99,13 +98,13 @@ func (m *Message) Marshal() ([]byte, error) {
} }
buf.Write(b) buf.Write(b)
// for _, answer := range m.Answers { for _, answer := range m.Answers {
// b, err := answer.Marshal() b, err := answer.Marshal()
// if err != nil { if err != nil {
// return nil, err return nil, err
// } }
// buf.Write(b) buf.Write(b)
// } }
return buf.Bytes(), nil return buf.Bytes(), nil
} }
@ -191,6 +190,11 @@ func (h *Header) SetQR(qr int) {
h.Bits |= uint16(qr) << 15 h.Bits |= uint16(qr) << 15
} }
// SetTC .
func (h *Header) SetTC(tc int) {
h.Bits |= uint16(tc) << 9
}
// SetQdcount sets query count, most dns servers only support 1 query per request // SetQdcount sets query count, most dns servers only support 1 query per request
func (h *Header) SetQdcount(qdcount int) { func (h *Header) SetQdcount(qdcount int) {
h.QDCOUNT = uint16(qdcount) h.QDCOUNT = uint16(qdcount)
@ -282,7 +286,11 @@ func (m *Message) UnmarshalQuestion(b []byte, q *Question) (n int, err error) {
return 0, errors.New("unmarshal question must not be nil") return 0, errors.New("unmarshal question must not be nil")
} }
domain, idx := m.UnmarshalDomain(b) domain, idx, err := m.UnmarshalDomain(b)
if err != nil {
return 0, err
}
q.QNAME = domain q.QNAME = domain
q.QTYPE = binary.BigEndian.Uint16(b[idx : idx+2]) q.QTYPE = binary.BigEndian.Uint16(b[idx : idx+2])
q.QCLASS = binary.BigEndian.Uint16(b[idx+2 : idx+4]) q.QCLASS = binary.BigEndian.Uint16(b[idx+2 : idx+4])
@ -335,19 +343,36 @@ func NewRR() *RR {
return rr return rr
} }
// Marshal marshals RR struct to []byte
func (rr *RR) Marshal() ([]byte, error) {
var buf bytes.Buffer
buf.Write(MarshalDomain(rr.NAME))
binary.Write(&buf, binary.BigEndian, rr.TYPE)
binary.Write(&buf, binary.BigEndian, rr.CLASS)
binary.Write(&buf, binary.BigEndian, rr.TTL)
binary.Write(&buf, binary.BigEndian, rr.RDLENGTH)
buf.Write(rr.RDATA)
return buf.Bytes(), nil
}
// UnmarshalRR unmarshals []bytes to RR // UnmarshalRR unmarshals []bytes to RR
func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) { func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) {
if rr == nil { if rr == nil {
return 0, errors.New("unmarshal question must not be nil") return 0, errors.New("unmarshal rr must not be nil")
} }
p := m.unMarshaled[start:] p := m.unMarshaled[start:]
domain, n := m.UnmarshalDomain(p) domain, n, err := m.UnmarshalDomain(p)
if err != nil {
return 0, err
}
rr.NAME = domain rr.NAME = domain
if len(p) <= n+10 { if len(p) <= n+10 {
return 0, errors.New("not enough data") return 0, errors.New("UnmarshalRR: not enough data")
} }
rr.TYPE = binary.BigEndian.Uint16(p[n:]) rr.TYPE = binary.BigEndian.Uint16(p[n:])
@ -381,7 +406,7 @@ func MarshalDomain(domain string) []byte {
} }
// UnmarshalDomain gets domain from bytes // UnmarshalDomain gets domain from bytes
func (m *Message) UnmarshalDomain(b []byte) (string, int) { func (m *Message) UnmarshalDomain(b []byte) (string, int, error) {
var idx, size int var idx, size int
var labels = []string{} var labels = []string{}
@ -393,7 +418,11 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int) {
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
if b[idx]&0xC0 == 0xC0 { if b[idx]&0xC0 == 0xC0 {
offset := binary.BigEndian.Uint16(b[idx : idx+2]) offset := binary.BigEndian.Uint16(b[idx : idx+2])
lable := m.UnmarshalDomainPoint(int(offset & 0x3FFF)) lable, err := m.UnmarshalDomainPoint(int(offset & 0x3FFF))
if err != nil {
return "", 0, err
}
labels = append(labels, lable) labels = append(labels, lable)
idx += 2 idx += 2
break break
@ -402,18 +431,28 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int) {
if size == 0 { if size == 0 {
idx++ idx++
break break
} else if size > 63 {
return "", 0, errors.New("UnmarshalDomain: label length more than 63")
} }
if idx+size+1 > len(b) {
return "", 0, errors.New("UnmarshalDomain: label length more than 63")
}
labels = append(labels, string(b[idx+1:idx+size+1])) labels = append(labels, string(b[idx+1:idx+size+1]))
idx += (size + 1) idx += (size + 1)
} }
} }
domain := strings.Join(labels, ".") domain := strings.Join(labels, ".")
return domain, idx return domain, idx, nil
} }
// UnmarshalDomainPoint gets domain from offset point // UnmarshalDomainPoint gets domain from offset point
func (m *Message) UnmarshalDomainPoint(offset int) string { func (m *Message) UnmarshalDomainPoint(offset int) (string, error) {
domain, _ := m.UnmarshalDomain(m.unMarshaled[offset:]) if offset > len(m.unMarshaled) {
return domain return "", errors.New("UnmarshalDomainPoint: offset larger than msg length")
}
domain, _, err := m.UnmarshalDomain(m.unMarshaled[offset:])
return domain, err
} }

View File

@ -2,12 +2,17 @@ package dns
import ( import (
"encoding/binary" "encoding/binary"
"io"
"net" "net"
"time"
"github.com/nadoo/glider/common/log" "github.com/nadoo/glider/common/log"
"github.com/nadoo/glider/proxy" "github.com/nadoo/glider/proxy"
) )
// conn timeout, seconds
const timeout = 30
// Server is a dns server struct // Server is a dns server struct
type Server struct { type Server struct {
addr string addr string
@ -28,6 +33,12 @@ func NewServer(addr string, dialer proxy.Dialer, upServers ...string) (*Server,
// ListenAndServe . // ListenAndServe .
func (s *Server) ListenAndServe() { func (s *Server) ListenAndServe() {
go s.ListenAndServeTCP()
s.ListenAndServeUDP()
}
// ListenAndServeUDP .
func (s *Server) ListenAndServeUDP() {
c, err := net.ListenPacket("udp", s.addr) c, err := net.ListenPacket("udp", s.addr)
if err != nil { if err != nil {
log.F("[dns] failed to listen on %s, error: %v", s.addr, err) log.F("[dns] failed to listen on %s, error: %v", s.addr, err)
@ -70,3 +81,56 @@ func (s *Server) ListenAndServe() {
} }
} }
// ListenAndServeTCP .
func (s *Server) ListenAndServeTCP() {
l, err := net.Listen("tcp", s.addr)
if err != nil {
log.F("[dns]-tcp error: %v", err)
return
}
log.F("[dns]-tcp listening TCP on %s", s.addr)
for {
c, err := l.Accept()
if err != nil {
log.F("[dns]-tcp error: failed to accept: %v", err)
continue
}
go s.ServeTCP(c)
}
}
// ServeTCP .
func (s *Server) ServeTCP(c net.Conn) {
defer c.Close()
c.SetDeadline(time.Now().Add(time.Duration(timeout) * time.Second))
var reqLen uint16
if err := binary.Read(c, binary.BigEndian, &reqLen); err != nil {
log.F("[dns]-tcp failed to get request length: %v", err)
return
}
reqBytes := make([]byte, reqLen+2)
_, err := io.ReadFull(c, reqBytes[2:])
if err != nil {
log.F("[dns]-tcp error in read reqBytes %s", err)
return
}
binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String())
if err != nil {
log.F("[dns]-tcp error in exchange: %s", err)
return
}
if err := binary.Write(c, binary.BigEndian, respBytes); err != nil {
log.F("[dns]-tcp error in local write respBytes: %s", err)
return
}
}