From d5e3ea539ac992e31ccd1fdee051e27105598ded Mon Sep 17 00:00:00 2001 From: nadoo <287492+nadoo@users.noreply.github.com> Date: Sun, 29 Jul 2018 23:44:23 +0800 Subject: [PATCH] dns: rewrite codes --- conf.go | 6 +- dns/client.go | 140 ++++++++++++++ dns/dns.go | 474 ------------------------------------------------ dns/message.go | 425 +++++++++++++++++++++++++++++++++++++++++++ dns/server.go | 72 ++++++++ main.go | 8 +- proxy/direct.go | 1 - 7 files changed, 644 insertions(+), 482 deletions(-) create mode 100644 dns/client.go delete mode 100644 dns/dns.go create mode 100644 dns/message.go create mode 100644 dns/server.go diff --git a/conf.go b/conf.go index df2e847..da15edf 100644 --- a/conf.go +++ b/conf.go @@ -71,9 +71,9 @@ func confInit() { } if conf.RulesDir != "" { - if !path.IsAbs(conf.RulesDir) { - conf.RulesDir = path.Join(flag.ConfDir(), conf.RulesDir) - } + if !path.IsAbs(conf.RulesDir) { + conf.RulesDir = path.Join(flag.ConfDir(), conf.RulesDir) + } ruleFolderFiles, _ := listDir(conf.RulesDir, ".rule") for _, ruleFile := range ruleFolderFiles { diff --git a/dns/client.go b/dns/client.go new file mode 100644 index 0000000..1f7d201 --- /dev/null +++ b/dns/client.go @@ -0,0 +1,140 @@ +package dns + +import ( + "encoding/binary" + "io" + "strings" + + "github.com/nadoo/glider/common/log" + "github.com/nadoo/glider/proxy" +) + +// HandleFunc function handles the dns TypeA or TypeAAAA answer +type HandleFunc func(Domain, ip string) error + +// Client is a dns client struct +type Client struct { + dialer proxy.Dialer + UPServers []string + UPServerMap map[string][]string + Handlers []HandleFunc + + tcp bool +} + +// NewClient returns a new dns client +func NewClient(dialer proxy.Dialer, upServers ...string) (*Client, error) { + c := &Client{ + dialer: dialer, + UPServers: upServers, + UPServerMap: make(map[string][]string), + } + + return c, nil +} + +// Exchange handles request msg and returns response msg +// reqBytes = reqLen + reqMsg +func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte, err error) { + reqMsg := reqBytes[2:] + + reqM := NewMessage() + err = UnmarshalMessage(reqMsg, reqM) + if err != nil { + return + } + + if reqM.Question.QTYPE == QTypeA || reqM.Question.QTYPE == QTypeAAAA { + // TODO: if query.QNAME in cache + // get respMsg from cache + // set msg id + // return respMsg, nil + } + + dnsServer := c.GetServer(reqM.Question.QNAME) + rc, err := c.dialer.NextDialer(reqM.Question.QNAME+":53").Dial("tcp", dnsServer) + if err != nil { + log.F("[dns] failed to connect to server %v: %v", dnsServer, err) + return + } + defer rc.Close() + + if err = binary.Write(rc, binary.BigEndian, reqBytes); err != nil { + log.F("[dns] failed to write req message: %v", err) + return + } + + var respLen uint16 + if err = binary.Read(rc, binary.BigEndian, &respLen); err != nil { + log.F("[dns] failed to read response length: %v", err) + return + } + + respBytes = make([]byte, respLen+2) + binary.BigEndian.PutUint16(respBytes[:2], respLen) + + respMsg := respBytes[2:] + _, err = io.ReadFull(rc, respMsg) + if err != nil { + log.F("[dns] error in read respMsg %s\n", err) + return + } + + if reqM.Question.QTYPE != QTypeA && reqM.Question.QTYPE != QTypeAAAA { + return + } + + respM := NewMessage() + err = UnmarshalMessage(respMsg, respM) + if err != nil { + return + } + + ip := "" + for _, answer := range respM.Answers { + if answer.TYPE == QTypeA { + for _, h := range c.Handlers { + h(reqM.Question.QNAME, answer.IP) + } + + if answer.IP != "" { + ip += answer.IP + "," + } + } + + log.F("rr: %+v", answer) + } + + // add to cache + + log.F("[dns] %s <-> %s, type: %d, %s: %s", + clientAddr, dnsServer, reqM.Question.QTYPE, reqM.Question.QNAME, ip) + + return +} + +// SetServer . +func (c *Client) SetServer(domain string, servers ...string) { + c.UPServerMap[domain] = append(c.UPServerMap[domain], servers...) +} + +// GetServer . +func (c *Client) GetServer(domain string) string { + domainParts := strings.Split(domain, ".") + length := len(domainParts) + for i := length - 2; i >= 0; i-- { + domain := strings.Join(domainParts[i:length], ".") + + if servers, ok := c.UPServerMap[domain]; ok { + return servers[0] + } + } + + // TODO: + return c.UPServers[0] +} + +// AddHandler . +func (c *Client) AddHandler(h HandleFunc) { + c.Handlers = append(c.Handlers, h) +} diff --git a/dns/dns.go b/dns/dns.go deleted file mode 100644 index d6b1682..0000000 --- a/dns/dns.go +++ /dev/null @@ -1,474 +0,0 @@ -// https://tools.ietf.org/html/rfc1035 - -package dns - -import ( - "encoding/binary" - "errors" - "io" - "net" - "strings" - - "github.com/nadoo/glider/common/log" - "github.com/nadoo/glider/proxy" -) - -// HeaderLen is the length of dns msg header -const HeaderLen = 12 - -// UDPMaxLen is the max size of udp dns request. -// https://tools.ietf.org/html/rfc1035#section-4.2.1 -// Messages carried by UDP are restricted to 512 bytes (not counting the IP -// or UDP headers). Longer messages are truncated and the TC bit is set in -// the header. -// TODO: If the request length > 512 then the client will send TCP packets instead, -// so we should also serve tcp requests. -const UDPMaxLen = 512 - -// QType . -const ( - QTypeA = 1 //ipv4 - QTypeAAAA = 28 ///ipv6 -) - -// Msg format -// https://tools.ietf.org/html/rfc1035#section-4.1 -// All communications inside of the domain protocol are carried in a single -// format called a message. The top level format of message is divided -// into 5 sections (some of which are empty in certain cases) shown below: -// -// +---------------------+ -// | Header | -// +---------------------+ -// | Question | the question for the name server -// +---------------------+ -// | Answer | RRs answering the question -// +---------------------+ -// | Authority | RRs pointing toward an authority -// +---------------------+ -// | Additional | RRs holding additional information -// type Msg struct { -// Header -// Questions []Question -// Answers []RR -// } - -// Header format -// https://tools.ietf.org/html/rfc1035#section-4.1.1 -// The header contains the following fields: -// -// 1 1 1 1 1 1 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | ID | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// |QR| Opcode |AA|TC|RD|RA| Z | RCODE | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | QDCOUNT | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | ANCOUNT | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | NSCOUNT | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | ARCOUNT | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// -// type Header struct { -// ID uint16 -// } - -// Question format -// https://tools.ietf.org/html/rfc1035#section-4.1.2 -// The question section is used to carry the "question" in most queries, -// i.e., the parameters that define what is being asked. The section -// contains QDCOUNT (usually 1) entries, each of the following format: -// -// 1 1 1 1 1 1 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | | -// / QNAME / -// / / -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | QTYPE | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | QCLASS | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -type Question struct { - QNAME string - QTYPE uint16 - QCLASS uint16 - - Offset int -} - -// RR format -// https://tools.ietf.org/html/rfc1035#section-3.2.1 -// https://tools.ietf.org/html/rfc1035#section-4.1.3 -// The answer, authority, and additional sections all share the same -// format: a variable number of resource records, where the number of -// records is specified in the corresponding count field in the header. -// Each resource record has the following format: -// -// 1 1 1 1 1 1 -// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | | -// / / -// / NAME / -// | | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | TYPE | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | CLASS | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | TTL | -// | | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -// | RDLENGTH | -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--| -// / RDATA / -// / / -// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ -type RR struct { - // NAME string - TYPE uint16 - CLASS uint16 - TTL uint32 - RDLENGTH uint16 - RDATA []byte - - IP string -} - -// AnswerHandler function handles the dns TypeA or TypeAAAA answer -type AnswerHandler func(Domain, ip string) error - -// DNS . -type DNS struct { - dialer proxy.Dialer - addr string - - Tunnel bool - - DNSServer string - - DNSServerMap map[string]string - AnswerHandlers []AnswerHandler -} - -// NewDNS returns a dns forwarder. client[dns.udp] -> glider[tcp] -> forwarder[dns.tcp] -> remote dns addr -func NewDNS(addr, raddr string, dialer proxy.Dialer, tunnel bool) (*DNS, error) { - s := &DNS{ - dialer: dialer, - addr: addr, - - Tunnel: tunnel, - - DNSServer: raddr, - DNSServerMap: make(map[string]string), - } - - return s, nil -} - -// ListenAndServe . -func (s *DNS) ListenAndServe() { - go s.ListenAndServeTCP() - s.ListenAndServeUDP() -} - -// ListenAndServeUDP . -func (s *DNS) ListenAndServeUDP() { - c, err := net.ListenPacket("udp", s.addr) - if err != nil { - log.F("[dns] failed to listen on %s, error: %v", s.addr, err) - return - } - defer c.Close() - - log.F("[dns] listening UDP on %s", s.addr) - - for { - b := make([]byte, UDPMaxLen) - n, clientAddr, err := c.ReadFrom(b) - if err != nil { - log.F("[dns] local read error: %v", err) - continue - } - - reqLen := uint16(n) - // TODO: check here - if reqLen <= HeaderLen+2 { - log.F("[dns] not enough data") - continue - } - - reqMsg := b[:n] - go func() { - _, respMsg, err := s.Exchange(reqLen, reqMsg, clientAddr.String()) - if err != nil { - log.F("[dns] error in exchange: %s", err) - return - } - - _, err = c.WriteTo(respMsg, clientAddr) - if err != nil { - log.F("[dns] error in local write: %s", err) - return - } - - }() - } -} - -// ListenAndServeTCP . -func (s *DNS) ListenAndServeTCP() { - l, err := net.Listen("tcp", s.addr) - if err != nil { - log.F("[dns]-tcp error: %v", err) - return - } - - log.F("[dns]-tcp listening TCP on %s", s.addr) - - for { - c, err := l.Accept() - if err != nil { - log.F("[dns]-tcp error: failed to accept: %v", err) - continue - } - go s.ServeTCP(c) - } -} - -// ServeTCP . -func (s *DNS) ServeTCP(c net.Conn) { - defer c.Close() - - if c, ok := c.(*net.TCPConn); ok { - c.SetKeepAlive(true) - } - - var reqLen uint16 - if err := binary.Read(c, binary.BigEndian, &reqLen); err != nil { - log.F("[dns]-tcp failed to get request length: %v", err) - return - } - - // TODO: check here - if reqLen <= HeaderLen+2 { - log.F("[dns]-tcp not enough data") - return - } - - reqMsg := make([]byte, reqLen) - _, err := io.ReadFull(c, reqMsg) - if err != nil { - log.F("[dns]-tcp error in read reqMsg %s", err) - return - } - - respLen, respMsg, err := s.Exchange(reqLen, reqMsg, c.RemoteAddr().String()) - if err != nil { - log.F("[dns]-tcp error in exchange: %s", err) - return - } - - if err := binary.Write(c, binary.BigEndian, respLen); err != nil { - log.F("[dns]-tcp error in local write respLen: %s", err) - return - } - if err := binary.Write(c, binary.BigEndian, respMsg); err != nil { - log.F("[dns]-tcp error in local write respMsg: %s", err) - return - } -} - -// Exchange handles request msg and returns response msg -// TODO: multiple questions support, parse header to get the number of questions -func (s *DNS) Exchange(reqLen uint16, reqMsg []byte, addr string) (respLen uint16, respMsg []byte, err error) { - // fmt.Printf("\ndns req len %d:\n%s\n", reqLen, hex.Dump(reqMsg[:])) - query, err := parseQuestion(reqMsg) - if err != nil { - log.F("[dns] error in parseQuestion reqMsg: %s", err) - return - } - - dnsServer := s.GetServer(query.QNAME) - - rc, err := s.dialer.NextDialer(query.QNAME+":53").Dial("tcp", dnsServer) - if err != nil { - log.F("[dns] failed to connect to server %v: %v", dnsServer, err) - return - } - defer rc.Close() - - if err = binary.Write(rc, binary.BigEndian, reqLen); err != nil { - log.F("[dns] failed to write req length: %v", err) - return - } - if err = binary.Write(rc, binary.BigEndian, reqMsg); err != nil { - log.F("[dns] failed to write req message: %v", err) - return - } - - if err = binary.Read(rc, binary.BigEndian, &respLen); err != nil { - log.F("[dns] failed to read response length: %v", err) - return - } - - respMsg = make([]byte, respLen) - _, err = io.ReadFull(rc, respMsg) - if err != nil { - log.F("[dns] error in read respMsg %s\n", err) - return - } - - // fmt.Printf("\ndns resp len %d:\n%s\n", respLen, hex.Dump(respMsg[:])) - - var ip string - respReq, err := parseQuestion(respMsg) - if err != nil { - log.F("[dns] error in parseQuestion respMsg: %s", err) - return - } - - if (respReq.QTYPE == QTypeA || respReq.QTYPE == QTypeAAAA) && - len(respMsg) > respReq.Offset { - - var answers []*RR - answers, err = parseAnswers(respMsg[respReq.Offset:]) - if err != nil { - log.F("[dns] error in parseAnswers: %s", err) - return - } - - for _, answer := range answers { - for _, h := range s.AnswerHandlers { - h(respReq.QNAME, answer.IP) - } - - if answer.IP != "" { - ip += answer.IP + "," - } - } - - } - - log.F("[dns] %s <-> %s, type: %d, %s: %s", addr, dnsServer, query.QTYPE, query.QNAME, ip) - return -} - -// SetServer . -func (s *DNS) SetServer(domain, server string) { - s.DNSServerMap[domain] = server -} - -// GetServer . -func (s *DNS) GetServer(domain string) string { - if !s.Tunnel { - domainParts := strings.Split(domain, ".") - length := len(domainParts) - for i := length - 2; i >= 0; i-- { - domain := strings.Join(domainParts[i:length], ".") - - if server, ok := s.DNSServerMap[domain]; ok { - return server - } - } - } - - return s.DNSServer -} - -// AddAnswerHandler . -func (s *DNS) AddAnswerHandler(h AnswerHandler) { - s.AnswerHandlers = append(s.AnswerHandlers, h) -} - -func parseQuestion(p []byte) (*Question, error) { - q := &Question{} - lenP := len(p) - - var i int - var domain []byte - for i = HeaderLen; i < lenP; { - l := int(p[i]) - - if l == 0 { - i++ - break - } - - if lenP <= i+l+1 { - return nil, errors.New("not enough data for QNAME") - } - - domain = append(domain, p[i+1:i+l+1]...) - domain = append(domain, '.') - - i = i + l + 1 - } - - if len(domain) == 0 { - return nil, errors.New("no QNAME") - } - - q.QNAME = string(domain[:len(domain)-1]) - - if lenP < i+4 { - return nil, errors.New("not enough data") - } - - q.QTYPE = binary.BigEndian.Uint16(p[i:]) - q.QCLASS = binary.BigEndian.Uint16(p[i+2:]) - q.Offset = i + 4 - - return q, nil -} - -func parseAnswers(p []byte) ([]*RR, error) { - var answers []*RR - lenP := len(p) - - for i := 0; i < lenP; { - - // https://tools.ietf.org/html/rfc1035#section-4.1.4 - // "Message compression", - // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - // | 1 1| OFFSET | - // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - - if p[i]>>6 == 3 { - i += 2 - } else { - // TODO: none compressed query name and Additional records will be ignored - break - } - - if lenP <= i+10 { - return nil, errors.New("not enough data") - } - - answer := &RR{} - - answer.TYPE = binary.BigEndian.Uint16(p[i:]) - answer.CLASS = binary.BigEndian.Uint16(p[i+2:]) - answer.TTL = binary.BigEndian.Uint32(p[i+4:]) - answer.RDLENGTH = binary.BigEndian.Uint16(p[i+8:]) - answer.RDATA = p[i+10 : i+10+int(answer.RDLENGTH)] - - if answer.TYPE == QTypeA { - answer.IP = net.IP(answer.RDATA[:net.IPv4len]).String() - } else if answer.TYPE == QTypeAAAA { - answer.IP = net.IP(answer.RDATA[:net.IPv6len]).String() - } - - answers = append(answers, answer) - - i = i + 10 + int(answer.RDLENGTH) - } - - return answers, nil -} diff --git a/dns/message.go b/dns/message.go new file mode 100644 index 0000000..82566bc --- /dev/null +++ b/dns/message.go @@ -0,0 +1,425 @@ +package dns + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "math/rand" + "net" + "strings" +) + +// UDPMaxLen is the max size of udp dns request. +// https://tools.ietf.org/html/rfc1035#section-4.2.1 +// Messages carried by UDP are restricted to 512 bytes (not counting the IP +// or UDP headers). Longer messages are truncated and the TC bit is set in +// the header. +const UDPMaxLen = 512 + +// HeaderLen is the length of dns msg header +const HeaderLen = 12 + +// QR types +const ( + QRQuery = 0 + QRResponse = 1 +) + +// QType . +const ( + QTypeA uint16 = 1 //ipv4 + QTypeAAAA uint16 = 28 ///ipv6 +) + +// Message format +// https://tools.ietf.org/html/rfc1035#section-4.1 +// All communications inside of the domain protocol are carried in a single +// format called a message. The top level format of message is divided +// into 5 sections (some of which are empty in certain cases) shown below: +// +// +---------------------+ +// | Header | +// +---------------------+ +// | Question | the question for the name server +// +---------------------+ +// | Answer | RRs answering the question +// +---------------------+ +// | Authority | RRs pointing toward an authority +// +---------------------+ +// | Additional | RRs holding additional information +type Message struct { + // all dns messages should start with a 12 byte dns header + *Header + // most dns implementation only support 1 question + Question *Question + Answers []*RR + Authority []*RR + Additional []*RR + + // used in UnmarshalMessage + unMarshaled []byte +} + +// NewMessage returns a new message +func NewMessage() *Message { + return &Message{ + Header: &Header{}, + } +} + +// SetQuestion sets a question to dns message, +func (m *Message) SetQuestion(q *Question) error { + m.Question = q + m.Header.SetQdcount(1) + return nil +} + +// AddAnswer adds an answer to dns message +func (m *Message) AddAnswer(rr *RR) error { + m.Answers = append(m.Answers, rr) + return nil +} + +// Marshal marshals message struct to []byte +func (m *Message) Marshal() ([]byte, error) { + var buf bytes.Buffer + + m.Header.SetQdcount(1) + m.Header.SetAncount(len(m.Answers)) + + b, err := m.Header.Marshal() + if err != nil { + return nil, err + } + buf.Write(b) + + b, err = m.Question.Marshal() + if err != nil { + return nil, err + } + buf.Write(b) + + // for _, answer := range m.Answers { + // b, err := answer.Marshal() + // if err != nil { + // return nil, err + // } + // buf.Write(b) + // } + + return buf.Bytes(), nil +} + +// UnmarshalMessage unmarshals []bytes to Message +func UnmarshalMessage(b []byte, msg *Message) error { + msg.unMarshaled = b + + err := UnmarshalHeader(b[:HeaderLen], msg.Header) + if err != nil { + return err + } + + q := &Question{} + qLen, err := msg.UnmarshalQuestion(b[HeaderLen:], q) + if err != nil { + return err + } + + msg.SetQuestion(q) + + // resp answers + rrLen := 0 + for i := 0; i < int(msg.Header.ANCOUNT); i++ { + rr := &RR{} + rrLen, err = msg.UnmarshalRR(HeaderLen+qLen+rrLen, rr) + if err != nil { + return err + } + msg.AddAnswer(rr) + rrLen += rrLen + } + + msg.Header.SetAncount(len(msg.Answers)) + + return nil +} + +// Header format +// https://tools.ietf.org/html/rfc1035#section-4.1.1 +// The header contains the following fields: +// +// 1 1 1 1 1 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | ID | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// |QR| Opcode |AA|TC|RD|RA| Z | RCODE | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | QDCOUNT | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | ANCOUNT | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | NSCOUNT | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | ARCOUNT | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// +type Header struct { + ID uint16 + Bits uint16 + QDCOUNT uint16 + ANCOUNT uint16 + NSCOUNT uint16 + ARCOUNT uint16 +} + +// NewHeader returns a new dns header +func NewHeader(id uint16, qr int) *Header { + if id == 0 { + id = uint16(rand.Uint32()) + } + + h := &Header{ID: id} + h.SetQR(qr) + return h +} + +// SetQR . +func (h *Header) SetQR(qr int) { + h.Bits |= uint16(qr) << 15 +} + +// SetQdcount sets query count, most dns servers only support 1 query per request +func (h *Header) SetQdcount(qdcount int) { + h.QDCOUNT = uint16(qdcount) +} + +// SetAncount sets answers count +func (h *Header) SetAncount(ancount int) { + h.ANCOUNT = uint16(ancount) +} + +func (h *Header) setFlag(QR uint16, Opcode uint16, AA uint16, + TC uint16, RD uint16, RA uint16, RCODE uint16) { + h.Bits = QR<<15 + Opcode<<11 + AA<<10 + TC<<9 + RD<<8 + RA<<7 + RCODE +} + +// Marshal marshals header struct to []byte +func (h *Header) Marshal() ([]byte, error) { + var buf bytes.Buffer + err := binary.Write(&buf, binary.BigEndian, h) + return buf.Bytes(), err +} + +// UnmarshalHeader unmarshals []bytes to Header +func UnmarshalHeader(b []byte, h *Header) error { + if h == nil { + return errors.New("unmarshal header must not be nil") + } + + if len(b) != HeaderLen { + return errors.New("unmarshal header bytes has an unexpected size") + } + + h.ID = binary.BigEndian.Uint16(b[:2]) + h.Bits = binary.BigEndian.Uint16(b[2:4]) + h.QDCOUNT = binary.BigEndian.Uint16(b[4:6]) + h.ANCOUNT = binary.BigEndian.Uint16(b[6:8]) + h.NSCOUNT = binary.BigEndian.Uint16(b[8:10]) + h.ARCOUNT = binary.BigEndian.Uint16(b[10:]) + + return nil +} + +// Question format +// https://tools.ietf.org/html/rfc1035#section-4.1.2 +// The question section is used to carry the "question" in most queries, +// i.e., the parameters that define what is being asked. The section +// contains QDCOUNT (usually 1) entries, each of the following format: +// +// 1 1 1 1 1 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | | +// / QNAME / +// / / +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | QTYPE | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | QCLASS | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +type Question struct { + QNAME string + QTYPE uint16 + QCLASS uint16 +} + +// NewQuestion returns a new dns question +func NewQuestion(qtype uint16, domain string) *Question { + return &Question{ + QNAME: domain, + QTYPE: qtype, + QCLASS: 1, + } +} + +// Marshal marshals Question struct to []byte +func (q *Question) Marshal() ([]byte, error) { + var buf bytes.Buffer + + buf.Write(MarshalDomain(q.QNAME)) + binary.Write(&buf, binary.BigEndian, q.QTYPE) + binary.Write(&buf, binary.BigEndian, q.QCLASS) + + return buf.Bytes(), nil +} + +// UnmarshalQuestion unmarshals []bytes to Question +func (m *Message) UnmarshalQuestion(b []byte, q *Question) (n int, err error) { + if q == nil { + return 0, errors.New("unmarshal question must not be nil") + } + + domain, idx := m.GetDomain(b) + q.QNAME = domain + q.QTYPE = binary.BigEndian.Uint16(b[idx : idx+2]) + q.QCLASS = binary.BigEndian.Uint16(b[idx+2 : idx+4]) + + return idx + 3 + 1, nil +} + +// RR format +// https://tools.ietf.org/html/rfc1035#section-3.2.1 +// https://tools.ietf.org/html/rfc1035#section-4.1.3 +// The answer, authority, and additional sections all share the same +// format: a variable number of resource records, where the number of +// records is specified in the corresponding count field in the header. +// Each resource record has the following format: +// +// 1 1 1 1 1 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | | +// / / +// / NAME / +// | | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | TYPE | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | CLASS | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | TTL | +// | | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +// | RDLENGTH | +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--| +// / RDATA / +// / / +// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ +type RR struct { + NAME string + TYPE uint16 + CLASS uint16 + TTL uint32 + RDLENGTH uint16 + RDATA []byte + + IP string +} + +// NewRR returns a new dns rr +func NewRR() *RR { + rr := &RR{} + return rr +} + +// UnmarshalRR unmarshals []bytes to RR +func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) { + if rr == nil { + return 0, errors.New("unmarshal question must not be nil") + } + + // https://tools.ietf.org/html/rfc1035#section-4.1.4 + // "Message compression", + // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + // | 1 1| OFFSET | + // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + p := m.unMarshaled[start:] + if p[0]>>6 == 3 { + offset := binary.BigEndian.Uint16(p[:2]) + rr.NAME = m.GetDomainByPoint(int(offset) & 0x3F) + n += 2 + } else { + // TODO: none compressed query name and Additional records will be ignored + return 0, nil + } + // domain, n := m.GetDomain(p) + // rr.NAME = domain + + if len(p) <= n+10 { + return 0, errors.New("not enough data") + } + + rr.TYPE = binary.BigEndian.Uint16(p[n:]) + rr.CLASS = binary.BigEndian.Uint16(p[n+2:]) + rr.TTL = binary.BigEndian.Uint32(p[n+4:]) + rr.RDLENGTH = binary.BigEndian.Uint16(p[n+8:]) + rr.RDATA = p[n+10 : n+10+int(rr.RDLENGTH)] + + if rr.TYPE == QTypeA { + rr.IP = net.IP(rr.RDATA[:net.IPv4len]).String() + } else if rr.TYPE == QTypeAAAA { + rr.IP = net.IP(rr.RDATA[:net.IPv6len]).String() + } + + n = n + 10 + int(rr.RDLENGTH) + + return n, nil +} + +// MarshalDomain marshals domain string struct to []byte +func MarshalDomain(domain string) []byte { + var buf bytes.Buffer + + for _, seg := range strings.Split(domain, ".") { + binary.Write(&buf, binary.BigEndian, byte(len(seg))) + binary.Write(&buf, binary.BigEndian, []byte(seg)) + } + binary.Write(&buf, binary.BigEndian, byte(0x00)) + + return buf.Bytes() +} + +// GetDomain gets domain from bytes +func (m *Message) GetDomain(b []byte) (string, int) { + var idx, size int + var labels = []string{} + + for { + if b[idx]&0xC0 == 0xC0 { + fmt.Println("aaaaaaaaaaaaaa") + offset := binary.BigEndian.Uint16(b[idx : idx+2]) + lable := m.GetDomainByPoint(int(offset) & 0x3F) + labels = append(labels, lable) + idx += 2 + break + } else { + size = int(b[idx]) + if size == 0 { + idx++ + break + } + labels = append(labels, string(b[idx+1:idx+size+1])) + idx += (size + 1) + } + } + + return strings.Join(labels, "."), idx +} + +// GetDomainByPoint gets domain from +func (m *Message) GetDomainByPoint(offset int) string { + domain, _ := m.GetDomain(m.unMarshaled[offset:]) + return domain +} diff --git a/dns/server.go b/dns/server.go new file mode 100644 index 0000000..4f824bf --- /dev/null +++ b/dns/server.go @@ -0,0 +1,72 @@ +package dns + +import ( + "encoding/binary" + "net" + + "github.com/nadoo/glider/common/log" + "github.com/nadoo/glider/proxy" +) + +// Server is a dns server struct +type Server struct { + addr string + // Client is used to communicate with upstream dns servers + *Client +} + +// NewServer returns a new dns server +func NewServer(addr string, dialer proxy.Dialer, upServers ...string) (*Server, error) { + c, err := NewClient(dialer, upServers...) + s := &Server{ + addr: addr, + Client: c, + } + + return s, err +} + +// ListenAndServe . +func (s *Server) ListenAndServe() { + c, err := net.ListenPacket("udp", s.addr) + if err != nil { + log.F("[dns] failed to listen on %s, error: %v", s.addr, err) + return + } + defer c.Close() + + log.F("[dns] listening UDP on %s", s.addr) + + for { + reqBytes := make([]byte, 2+UDPMaxLen) + n, caddr, err := c.ReadFrom(reqBytes[2:]) + if err != nil { + log.F("[dns] local read error: %v", err) + continue + } + + reqLen := uint16(n) + if reqLen <= HeaderLen+2 { + log.F("[dns] not enough message data") + continue + } + + binary.BigEndian.PutUint16(reqBytes[:2], reqLen) + + go func() { + respBytes, err := s.Client.Exchange(reqBytes[:2+n], caddr.String()) + if err != nil { + log.F("[dns] error in exchange: %s", err) + return + } + + _, err = c.WriteTo(respBytes[2:], caddr) + if err != nil { + log.F("[dns] error in local write: %s", err) + return + } + + }() + } + +} diff --git a/main.go b/main.go index aff4303..36bb28f 100644 --- a/main.go +++ b/main.go @@ -11,7 +11,7 @@ import ( "github.com/nadoo/glider/dns" "github.com/nadoo/glider/proxy" - _ "github.com/nadoo/glider/proxy/dnstun" + // _ "github.com/nadoo/glider/proxy/dnstun" _ "github.com/nadoo/glider/proxy/http" _ "github.com/nadoo/glider/proxy/mixed" _ "github.com/nadoo/glider/proxy/socks5" @@ -58,7 +58,7 @@ func main() { dialer := NewRuleDialer(conf.rules, dialerFromConf()) ipsetM, _ := NewIPSetManager(conf.IPSet, conf.rules) if conf.DNS != "" { - d, err := dns.NewDNS(conf.DNS, conf.DNSServer[0], dialer, false) + d, err := dns.NewServer(conf.DNS, dialer, conf.DNSServer...) if err != nil { log.Fatal(err) } @@ -73,9 +73,9 @@ func main() { } // add a handler to update proxy rules when a domain resolved - d.AddAnswerHandler(dialer.AddDomainIP) + d.AddHandler(dialer.AddDomainIP) if ipsetM != nil { - d.AddAnswerHandler(ipsetM.AddDomainIP) + d.AddHandler(ipsetM.AddDomainIP) } go d.ListenAndServe() diff --git a/proxy/direct.go b/proxy/direct.go index 2652e7a..719a731 100644 --- a/proxy/direct.go +++ b/proxy/direct.go @@ -31,7 +31,6 @@ func (d *direct) Dial(network, addr string) (net.Conn, error) { return c, err } -// DialUDP connects to the given address via the proxy func (d *direct) DialUDP(network, addr string) (net.PacketConn, net.Addr, error) { pc, err := net.ListenPacket(network, "") if err != nil {