From a226637bfba60a28d3a77e5def9d93dc3e594958 Mon Sep 17 00:00:00 2001 From: nadoo <287492+nadoo@users.noreply.github.com> Date: Wed, 1 Aug 2018 00:09:55 +0800 Subject: [PATCH] dns: 1. support cache; 2. support custom records; #35 --- conf.go | 2 + dns/cache.go | 69 ++++++++++++++++++++++++ dns/client.go | 139 ++++++++++++++++++++++++++++++++++++++----------- dns/message.go | 47 ++++++++--------- dns/server.go | 4 +- main.go | 7 ++- 6 files changed, 210 insertions(+), 58 deletions(-) create mode 100644 dns/cache.go diff --git a/conf.go b/conf.go index 936cc76..3069dae 100644 --- a/conf.go +++ b/conf.go @@ -25,6 +25,7 @@ var conf struct { DNS string DNSServer []string + DNSRecord []string IPSet string @@ -43,6 +44,7 @@ func confInit() { flag.StringVar(&conf.DNS, "dns", "", "dns forwarder server listen address") flag.StringSliceUniqVar(&conf.DNSServer, "dnsserver", []string{"8.8.8.8:53"}, "remote dns server") + flag.StringSliceUniqVar(&conf.DNSRecord, "dnsrecord", nil, "custom dns record") flag.StringVar(&conf.IPSet, "ipset", "", "ipset name") diff --git a/dns/cache.go b/dns/cache.go new file mode 100644 index 0000000..50b371a --- /dev/null +++ b/dns/cache.go @@ -0,0 +1,69 @@ +package dns + +import ( + "sync" + "time" +) + +// HundredYears is one hundred years duration in seconds, used for none-expired items +const HundredYears = 100 * 365 * 24 * 3600 + +type item struct { + value []byte + expire time.Time +} + +// Cache is the struct of cache +type Cache struct { + m map[string]*item + l sync.RWMutex +} + +// NewCache returns a new cache +func NewCache() (c *Cache) { + c = &Cache{m: make(map[string]*item)} + go func() { + for now := range time.Tick(time.Second) { + c.l.Lock() + for k, v := range c.m { + if now.After(v.expire) { + delete(c.m, k) + } + } + c.l.Unlock() + } + }() + + return +} + +// Len returns the length of cache +func (c *Cache) Len() int { + return len(c.m) +} + +// Put an item into cache, invalid after ttl seconds, never invalid if ttl=0 +func (c *Cache) Put(k string, v []byte, ttl int) { + if len(v) == 0 { + return + } + + c.l.Lock() + it, ok := c.m[k] + if !ok { + it = &item{value: v} + c.m[k] = it + } + it.expire = time.Now().Add(time.Duration(ttl) * time.Second) + c.l.Unlock() +} + +// Get an item from cache +func (c *Cache) Get(k string) (v []byte) { + c.l.RLock() + if it, ok := c.m[k]; ok { + v = it.value + } + c.l.RUnlock() + return +} diff --git a/dns/client.go b/dns/client.go index a5ae487..3789ec2 100644 --- a/dns/client.go +++ b/dns/client.go @@ -1,33 +1,39 @@ package dns import ( + "bytes" "encoding/binary" + "errors" "io" + "net" "strings" "github.com/nadoo/glider/common/log" "github.com/nadoo/glider/proxy" ) +// DefaultTTL is default ttl in seconds +const DefaultTTL = 600 + // HandleFunc function handles the dns TypeA or TypeAAAA answer type HandleFunc func(Domain, ip string) error // Client is a dns client struct type Client struct { dialer proxy.Dialer - UPServers []string - UPServerMap map[string][]string - Handlers []HandleFunc - - tcp bool + cache *Cache + upServers []string + upServerMap map[string][]string + handlers []HandleFunc } // NewClient returns a new dns client -func NewClient(dialer proxy.Dialer, upServers ...string) (*Client, error) { +func NewClient(dialer proxy.Dialer, upServers []string) (*Client, error) { c := &Client{ dialer: dialer, - UPServers: upServers, - UPServerMap: make(map[string][]string), + cache: NewCache(), + upServers: upServers, + upServerMap: make(map[string][]string), } return c, nil @@ -35,82 +41,96 @@ func NewClient(dialer proxy.Dialer, upServers ...string) (*Client, error) { // Exchange handles request msg and returns response msg // reqBytes = reqLen + reqMsg -func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte, err error) { +func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) { req, err := UnmarshalMessage(reqBytes[2:]) if err != nil { - return + return nil, err } if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA { - // TODO: if query.QNAME in cache - // get respMsg from cache - // set msg id - // return respMsg, nil + v := c.cache.Get(getKey(req.Question)) + if v != nil { + binary.BigEndian.PutUint16(v[2:4], req.ID) + log.F("[dns] %s <-> cache, type: %d, %s", + clientAddr, req.Question.QTYPE, req.Question.QNAME) + + return v, nil + } } dnsServer := c.GetServer(req.Question.QNAME) rc, err := c.dialer.NextDialer(req.Question.QNAME+":53").Dial("tcp", dnsServer) if err != nil { log.F("[dns] failed to connect to server %v: %v", dnsServer, err) - return + return nil, err } defer rc.Close() if err = binary.Write(rc, binary.BigEndian, reqBytes); err != nil { log.F("[dns] failed to write req message: %v", err) - return + return nil, err } var respLen uint16 if err = binary.Read(rc, binary.BigEndian, &respLen); err != nil { log.F("[dns] failed to read response length: %v", err) - return + return nil, err } - respBytes = make([]byte, respLen+2) + respBytes := make([]byte, respLen+2) binary.BigEndian.PutUint16(respBytes[:2], respLen) - respMsg := respBytes[2:] - _, err = io.ReadFull(rc, respMsg) + _, err = io.ReadFull(rc, respBytes[2:]) if err != nil { log.F("[dns] error in read respMsg %s\n", err) - return + return nil, err } if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA { - return + log.F("[dns] %s <-> %s, type: %d, %s", + clientAddr, dnsServer, req.Question.QTYPE, req.Question.QNAME) + return respBytes, nil } - resp, err := UnmarshalMessage(respMsg) + resp, err := UnmarshalMessage(respBytes[2:]) if err != nil { - return + return respBytes, err } + ttl := 0 ips := []string{} for _, answer := range resp.Answers { if answer.TYPE == QTypeA || answer.TYPE == QTypeAAAA { - for _, h := range c.Handlers { + for _, h := range c.handlers { h(resp.Question.QNAME, answer.IP) } if answer.IP != "" { ips = append(ips, answer.IP) } + + ttl = int(answer.TTL) } } + // if ttl in packet is 0, set it to default value + if ttl == 0 { + ttl = DefaultTTL + } + // add to cache + c.cache.Put(getKey(resp.Question), respBytes, ttl) log.F("[dns] %s <-> %s, type: %d, %s: %s", clientAddr, dnsServer, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ",")) - return + return respBytes, nil } // SetServer . func (c *Client) SetServer(domain string, servers ...string) { - c.UPServerMap[domain] = append(c.UPServerMap[domain], servers...) + c.upServerMap[domain] = append(c.upServerMap[domain], servers...) } // GetServer . @@ -120,16 +140,75 @@ func (c *Client) GetServer(domain string) string { for i := length - 2; i >= 0; i-- { domain := strings.Join(domainParts[i:length], ".") - if servers, ok := c.UPServerMap[domain]; ok { + if servers, ok := c.upServerMap[domain]; ok { return servers[0] } } // TODO: - return c.UPServers[0] + return c.upServers[0] } // AddHandler . func (c *Client) AddHandler(h HandleFunc) { - c.Handlers = append(c.Handlers, h) + c.handlers = append(c.handlers, h) +} + +// AddRecord adds custom record to dns cache, format: +// www.example.com/1.2.3.4 or www.example.com/2606:2800:220:1:248:1893:25c8:1946 +func (c *Client) AddRecord(record string) error { + r := strings.Split(record, "/") + domain, ip := r[0], r[1] + m, err := c.GenResponse(domain, ip) + if err != nil { + return err + } + + b, _ := m.Marshal() + + var buf bytes.Buffer + binary.Write(&buf, binary.BigEndian, uint16(len(b))) + buf.Write(b) + + c.cache.Put(getKey(m.Question), buf.Bytes(), HundredYears) + + return nil +} + +// GenResponse . +func (c *Client) GenResponse(domain string, ip string) (*Message, error) { + ipb := net.ParseIP(ip) + if ipb == nil { + return nil, errors.New("GenResponse: invalid ip format") + } + + var rdata []byte + var qtype, rdlen uint16 + if rdata = ipb.To4(); rdata != nil { + qtype = QTypeA + rdlen = net.IPv4len + } else { + qtype = QTypeAAAA + rdlen = net.IPv6len + rdata = ipb + } + + m := NewMessage(0, Response) + m.SetQuestion(NewQuestion(qtype, domain)) + rr := &RR{NAME: domain, TYPE: qtype, CLASS: CLASSIN, + RDLENGTH: rdlen, RDATA: rdata} + m.AddAnswer(rr) + + return m, nil +} + +func getKey(q *Question) string { + qtype := "" + switch q.QTYPE { + case QTypeA: + qtype = "A" + case QTypeAAAA: + qtype = "AAAA" + } + return q.QNAME + "/" + qtype } diff --git a/dns/message.go b/dns/message.go index e9b0b6d..8730863 100644 --- a/dns/message.go +++ b/dns/message.go @@ -19,10 +19,10 @@ const UDPMaxLen = 512 // HeaderLen is the length of dns msg header const HeaderLen = 12 -// QR types +// Message types const ( - QRQuery = 0 - QRResponse = 1 + Query = 0 + Response = 1 ) // QType . @@ -31,6 +31,9 @@ const ( QTypeAAAA uint16 = 28 ///ipv6 ) +// CLASSIN . +const CLASSIN uint16 = 1 + // Message format // https://tools.ietf.org/html/rfc1035#section-4.1 // All communications inside of the domain protocol are carried in a single @@ -60,10 +63,15 @@ type Message struct { } // NewMessage returns a new message -func NewMessage() *Message { - return &Message{ - Header: &Header{}, +func NewMessage(id uint16, msgType int) *Message { + if id == 0 { + id = uint16(rand.Uint32()) } + + h := &Header{ID: id} + h.SetMsgType(msgType) + + return &Message{Header: h} } // SetQuestion sets a question to dns message, @@ -111,7 +119,7 @@ func (m *Message) Marshal() ([]byte, error) { // UnmarshalMessage unmarshals []bytes to Message func UnmarshalMessage(b []byte) (*Message, error) { - msg := NewMessage() + msg := &Message{Header: &Header{}} msg.unMarshaled = b err := UnmarshalHeader(b[:HeaderLen], msg.Header) @@ -174,19 +182,8 @@ type Header struct { ARCOUNT uint16 } -// NewHeader returns a new dns header -func NewHeader(id uint16, qr int) *Header { - if id == 0 { - id = uint16(rand.Uint32()) - } - - h := &Header{ID: id} - h.SetQR(qr) - return h -} - -// SetQR . -func (h *Header) SetQR(qr int) { +// SetMsgType . +func (h *Header) SetMsgType(qr int) { h.Bits |= uint16(qr) << 15 } @@ -265,7 +262,7 @@ func NewQuestion(qtype uint16, domain string) *Question { return &Question{ QNAME: domain, QTYPE: qtype, - QCLASS: 1, + QCLASS: CLASSIN, } } @@ -418,12 +415,12 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int, error) { // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ if b[idx]&0xC0 == 0xC0 { offset := binary.BigEndian.Uint16(b[idx : idx+2]) - lable, err := m.UnmarshalDomainPoint(int(offset & 0x3FFF)) + label, err := m.UnmarshalDomainPoint(int(offset & 0x3FFF)) if err != nil { return "", 0, err } - labels = append(labels, lable) + labels = append(labels, label) idx += 2 break } else { @@ -432,11 +429,11 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int, error) { idx++ break } else if size > 63 { - return "", 0, errors.New("UnmarshalDomain: label length more than 63") + return "", 0, errors.New("UnmarshalDomain: label size larger than 63") } if idx+size+1 > len(b) { - return "", 0, errors.New("UnmarshalDomain: label length more than 63") + return "", 0, errors.New("UnmarshalDomain: label size larger than msg length") } labels = append(labels, string(b[idx+1:idx+size+1])) diff --git a/dns/server.go b/dns/server.go index 0806c0b..77717ed 100644 --- a/dns/server.go +++ b/dns/server.go @@ -21,8 +21,8 @@ type Server struct { } // NewServer returns a new dns server -func NewServer(addr string, dialer proxy.Dialer, upServers ...string) (*Server, error) { - c, err := NewClient(dialer, upServers...) +func NewServer(addr string, dialer proxy.Dialer, upServers []string) (*Server, error) { + c, err := NewClient(dialer, upServers) s := &Server{ addr: addr, Client: c, diff --git a/main.go b/main.go index d84b352..ad41ad7 100644 --- a/main.go +++ b/main.go @@ -57,11 +57,16 @@ func main() { dialer := NewRuleDialer(conf.rules, dialerFromConf()) ipsetM, _ := NewIPSetManager(conf.IPSet, conf.rules) if conf.DNS != "" { - d, err := dns.NewServer(conf.DNS, dialer, conf.DNSServer...) + d, err := dns.NewServer(conf.DNS, dialer, conf.DNSServer) if err != nil { log.Fatal(err) } + // custom records + for _, record := range conf.DNSRecord { + d.AddRecord(record) + } + // rule for _, r := range conf.rules { for _, domain := range r.Domain {