dns: 1. support cache; 2. support custom records; #35

This commit is contained in:
nadoo 2018-08-01 00:09:55 +08:00
parent a2a67df771
commit a226637bfb
6 changed files with 210 additions and 58 deletions

View File

@ -25,6 +25,7 @@ var conf struct {
DNS string DNS string
DNSServer []string DNSServer []string
DNSRecord []string
IPSet string IPSet string
@ -43,6 +44,7 @@ func confInit() {
flag.StringVar(&conf.DNS, "dns", "", "dns forwarder server listen address") flag.StringVar(&conf.DNS, "dns", "", "dns forwarder server listen address")
flag.StringSliceUniqVar(&conf.DNSServer, "dnsserver", []string{"8.8.8.8:53"}, "remote dns server") flag.StringSliceUniqVar(&conf.DNSServer, "dnsserver", []string{"8.8.8.8:53"}, "remote dns server")
flag.StringSliceUniqVar(&conf.DNSRecord, "dnsrecord", nil, "custom dns record")
flag.StringVar(&conf.IPSet, "ipset", "", "ipset name") flag.StringVar(&conf.IPSet, "ipset", "", "ipset name")

69
dns/cache.go Normal file
View File

@ -0,0 +1,69 @@
package dns
import (
"sync"
"time"
)
// HundredYears is one hundred years duration in seconds, used for none-expired items
const HundredYears = 100 * 365 * 24 * 3600
type item struct {
value []byte
expire time.Time
}
// Cache is the struct of cache
type Cache struct {
m map[string]*item
l sync.RWMutex
}
// NewCache returns a new cache
func NewCache() (c *Cache) {
c = &Cache{m: make(map[string]*item)}
go func() {
for now := range time.Tick(time.Second) {
c.l.Lock()
for k, v := range c.m {
if now.After(v.expire) {
delete(c.m, k)
}
}
c.l.Unlock()
}
}()
return
}
// Len returns the length of cache
func (c *Cache) Len() int {
return len(c.m)
}
// Put an item into cache, invalid after ttl seconds, never invalid if ttl=0
func (c *Cache) Put(k string, v []byte, ttl int) {
if len(v) == 0 {
return
}
c.l.Lock()
it, ok := c.m[k]
if !ok {
it = &item{value: v}
c.m[k] = it
}
it.expire = time.Now().Add(time.Duration(ttl) * time.Second)
c.l.Unlock()
}
// Get an item from cache
func (c *Cache) Get(k string) (v []byte) {
c.l.RLock()
if it, ok := c.m[k]; ok {
v = it.value
}
c.l.RUnlock()
return
}

View File

@ -1,33 +1,39 @@
package dns package dns
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"errors"
"io" "io"
"net"
"strings" "strings"
"github.com/nadoo/glider/common/log" "github.com/nadoo/glider/common/log"
"github.com/nadoo/glider/proxy" "github.com/nadoo/glider/proxy"
) )
// DefaultTTL is default ttl in seconds
const DefaultTTL = 600
// HandleFunc function handles the dns TypeA or TypeAAAA answer // HandleFunc function handles the dns TypeA or TypeAAAA answer
type HandleFunc func(Domain, ip string) error type HandleFunc func(Domain, ip string) error
// Client is a dns client struct // Client is a dns client struct
type Client struct { type Client struct {
dialer proxy.Dialer dialer proxy.Dialer
UPServers []string cache *Cache
UPServerMap map[string][]string upServers []string
Handlers []HandleFunc upServerMap map[string][]string
handlers []HandleFunc
tcp bool
} }
// NewClient returns a new dns client // NewClient returns a new dns client
func NewClient(dialer proxy.Dialer, upServers ...string) (*Client, error) { func NewClient(dialer proxy.Dialer, upServers []string) (*Client, error) {
c := &Client{ c := &Client{
dialer: dialer, dialer: dialer,
UPServers: upServers, cache: NewCache(),
UPServerMap: make(map[string][]string), upServers: upServers,
upServerMap: make(map[string][]string),
} }
return c, nil return c, nil
@ -35,82 +41,96 @@ func NewClient(dialer proxy.Dialer, upServers ...string) (*Client, error) {
// Exchange handles request msg and returns response msg // Exchange handles request msg and returns response msg
// reqBytes = reqLen + reqMsg // reqBytes = reqLen + reqMsg
func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte, err error) { func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) {
req, err := UnmarshalMessage(reqBytes[2:]) req, err := UnmarshalMessage(reqBytes[2:])
if err != nil { if err != nil {
return return nil, err
} }
if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA { if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA {
// TODO: if query.QNAME in cache v := c.cache.Get(getKey(req.Question))
// get respMsg from cache if v != nil {
// set msg id binary.BigEndian.PutUint16(v[2:4], req.ID)
// return respMsg, nil log.F("[dns] %s <-> cache, type: %d, %s",
clientAddr, req.Question.QTYPE, req.Question.QNAME)
return v, nil
}
} }
dnsServer := c.GetServer(req.Question.QNAME) dnsServer := c.GetServer(req.Question.QNAME)
rc, err := c.dialer.NextDialer(req.Question.QNAME+":53").Dial("tcp", dnsServer) rc, err := c.dialer.NextDialer(req.Question.QNAME+":53").Dial("tcp", dnsServer)
if err != nil { if err != nil {
log.F("[dns] failed to connect to server %v: %v", dnsServer, err) log.F("[dns] failed to connect to server %v: %v", dnsServer, err)
return return nil, err
} }
defer rc.Close() defer rc.Close()
if err = binary.Write(rc, binary.BigEndian, reqBytes); err != nil { if err = binary.Write(rc, binary.BigEndian, reqBytes); err != nil {
log.F("[dns] failed to write req message: %v", err) log.F("[dns] failed to write req message: %v", err)
return return nil, err
} }
var respLen uint16 var respLen uint16
if err = binary.Read(rc, binary.BigEndian, &respLen); err != nil { if err = binary.Read(rc, binary.BigEndian, &respLen); err != nil {
log.F("[dns] failed to read response length: %v", err) log.F("[dns] failed to read response length: %v", err)
return return nil, err
} }
respBytes = make([]byte, respLen+2) respBytes := make([]byte, respLen+2)
binary.BigEndian.PutUint16(respBytes[:2], respLen) binary.BigEndian.PutUint16(respBytes[:2], respLen)
respMsg := respBytes[2:] _, err = io.ReadFull(rc, respBytes[2:])
_, err = io.ReadFull(rc, respMsg)
if err != nil { if err != nil {
log.F("[dns] error in read respMsg %s\n", err) log.F("[dns] error in read respMsg %s\n", err)
return return nil, err
} }
if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA { if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA {
return log.F("[dns] %s <-> %s, type: %d, %s",
clientAddr, dnsServer, req.Question.QTYPE, req.Question.QNAME)
return respBytes, nil
} }
resp, err := UnmarshalMessage(respMsg) resp, err := UnmarshalMessage(respBytes[2:])
if err != nil { if err != nil {
return return respBytes, err
} }
ttl := 0
ips := []string{} ips := []string{}
for _, answer := range resp.Answers { for _, answer := range resp.Answers {
if answer.TYPE == QTypeA || answer.TYPE == QTypeAAAA { if answer.TYPE == QTypeA || answer.TYPE == QTypeAAAA {
for _, h := range c.Handlers { for _, h := range c.handlers {
h(resp.Question.QNAME, answer.IP) h(resp.Question.QNAME, answer.IP)
} }
if answer.IP != "" { if answer.IP != "" {
ips = append(ips, answer.IP) ips = append(ips, answer.IP)
} }
ttl = int(answer.TTL)
} }
} }
// if ttl in packet is 0, set it to default value
if ttl == 0 {
ttl = DefaultTTL
}
// add to cache // add to cache
c.cache.Put(getKey(resp.Question), respBytes, ttl)
log.F("[dns] %s <-> %s, type: %d, %s: %s", log.F("[dns] %s <-> %s, type: %d, %s: %s",
clientAddr, dnsServer, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ",")) clientAddr, dnsServer, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
return return respBytes, nil
} }
// SetServer . // SetServer .
func (c *Client) SetServer(domain string, servers ...string) { func (c *Client) SetServer(domain string, servers ...string) {
c.UPServerMap[domain] = append(c.UPServerMap[domain], servers...) c.upServerMap[domain] = append(c.upServerMap[domain], servers...)
} }
// GetServer . // GetServer .
@ -120,16 +140,75 @@ func (c *Client) GetServer(domain string) string {
for i := length - 2; i >= 0; i-- { for i := length - 2; i >= 0; i-- {
domain := strings.Join(domainParts[i:length], ".") domain := strings.Join(domainParts[i:length], ".")
if servers, ok := c.UPServerMap[domain]; ok { if servers, ok := c.upServerMap[domain]; ok {
return servers[0] return servers[0]
} }
} }
// TODO: // TODO:
return c.UPServers[0] return c.upServers[0]
} }
// AddHandler . // AddHandler .
func (c *Client) AddHandler(h HandleFunc) { func (c *Client) AddHandler(h HandleFunc) {
c.Handlers = append(c.Handlers, h) c.handlers = append(c.handlers, h)
}
// AddRecord adds custom record to dns cache, format:
// www.example.com/1.2.3.4 or www.example.com/2606:2800:220:1:248:1893:25c8:1946
func (c *Client) AddRecord(record string) error {
r := strings.Split(record, "/")
domain, ip := r[0], r[1]
m, err := c.GenResponse(domain, ip)
if err != nil {
return err
}
b, _ := m.Marshal()
var buf bytes.Buffer
binary.Write(&buf, binary.BigEndian, uint16(len(b)))
buf.Write(b)
c.cache.Put(getKey(m.Question), buf.Bytes(), HundredYears)
return nil
}
// GenResponse .
func (c *Client) GenResponse(domain string, ip string) (*Message, error) {
ipb := net.ParseIP(ip)
if ipb == nil {
return nil, errors.New("GenResponse: invalid ip format")
}
var rdata []byte
var qtype, rdlen uint16
if rdata = ipb.To4(); rdata != nil {
qtype = QTypeA
rdlen = net.IPv4len
} else {
qtype = QTypeAAAA
rdlen = net.IPv6len
rdata = ipb
}
m := NewMessage(0, Response)
m.SetQuestion(NewQuestion(qtype, domain))
rr := &RR{NAME: domain, TYPE: qtype, CLASS: CLASSIN,
RDLENGTH: rdlen, RDATA: rdata}
m.AddAnswer(rr)
return m, nil
}
func getKey(q *Question) string {
qtype := ""
switch q.QTYPE {
case QTypeA:
qtype = "A"
case QTypeAAAA:
qtype = "AAAA"
}
return q.QNAME + "/" + qtype
} }

View File

@ -19,10 +19,10 @@ const UDPMaxLen = 512
// HeaderLen is the length of dns msg header // HeaderLen is the length of dns msg header
const HeaderLen = 12 const HeaderLen = 12
// QR types // Message types
const ( const (
QRQuery = 0 Query = 0
QRResponse = 1 Response = 1
) )
// QType . // QType .
@ -31,6 +31,9 @@ const (
QTypeAAAA uint16 = 28 ///ipv6 QTypeAAAA uint16 = 28 ///ipv6
) )
// CLASSIN .
const CLASSIN uint16 = 1
// Message format // Message format
// https://tools.ietf.org/html/rfc1035#section-4.1 // https://tools.ietf.org/html/rfc1035#section-4.1
// All communications inside of the domain protocol are carried in a single // All communications inside of the domain protocol are carried in a single
@ -60,10 +63,15 @@ type Message struct {
} }
// NewMessage returns a new message // NewMessage returns a new message
func NewMessage() *Message { func NewMessage(id uint16, msgType int) *Message {
return &Message{ if id == 0 {
Header: &Header{}, id = uint16(rand.Uint32())
} }
h := &Header{ID: id}
h.SetMsgType(msgType)
return &Message{Header: h}
} }
// SetQuestion sets a question to dns message, // SetQuestion sets a question to dns message,
@ -111,7 +119,7 @@ func (m *Message) Marshal() ([]byte, error) {
// UnmarshalMessage unmarshals []bytes to Message // UnmarshalMessage unmarshals []bytes to Message
func UnmarshalMessage(b []byte) (*Message, error) { func UnmarshalMessage(b []byte) (*Message, error) {
msg := NewMessage() msg := &Message{Header: &Header{}}
msg.unMarshaled = b msg.unMarshaled = b
err := UnmarshalHeader(b[:HeaderLen], msg.Header) err := UnmarshalHeader(b[:HeaderLen], msg.Header)
@ -174,19 +182,8 @@ type Header struct {
ARCOUNT uint16 ARCOUNT uint16
} }
// NewHeader returns a new dns header // SetMsgType .
func NewHeader(id uint16, qr int) *Header { func (h *Header) SetMsgType(qr int) {
if id == 0 {
id = uint16(rand.Uint32())
}
h := &Header{ID: id}
h.SetQR(qr)
return h
}
// SetQR .
func (h *Header) SetQR(qr int) {
h.Bits |= uint16(qr) << 15 h.Bits |= uint16(qr) << 15
} }
@ -265,7 +262,7 @@ func NewQuestion(qtype uint16, domain string) *Question {
return &Question{ return &Question{
QNAME: domain, QNAME: domain,
QTYPE: qtype, QTYPE: qtype,
QCLASS: 1, QCLASS: CLASSIN,
} }
} }
@ -418,12 +415,12 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int, error) {
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
if b[idx]&0xC0 == 0xC0 { if b[idx]&0xC0 == 0xC0 {
offset := binary.BigEndian.Uint16(b[idx : idx+2]) offset := binary.BigEndian.Uint16(b[idx : idx+2])
lable, err := m.UnmarshalDomainPoint(int(offset & 0x3FFF)) label, err := m.UnmarshalDomainPoint(int(offset & 0x3FFF))
if err != nil { if err != nil {
return "", 0, err return "", 0, err
} }
labels = append(labels, lable) labels = append(labels, label)
idx += 2 idx += 2
break break
} else { } else {
@ -432,11 +429,11 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int, error) {
idx++ idx++
break break
} else if size > 63 { } else if size > 63 {
return "", 0, errors.New("UnmarshalDomain: label length more than 63") return "", 0, errors.New("UnmarshalDomain: label size larger than 63")
} }
if idx+size+1 > len(b) { if idx+size+1 > len(b) {
return "", 0, errors.New("UnmarshalDomain: label length more than 63") return "", 0, errors.New("UnmarshalDomain: label size larger than msg length")
} }
labels = append(labels, string(b[idx+1:idx+size+1])) labels = append(labels, string(b[idx+1:idx+size+1]))

View File

@ -21,8 +21,8 @@ type Server struct {
} }
// NewServer returns a new dns server // NewServer returns a new dns server
func NewServer(addr string, dialer proxy.Dialer, upServers ...string) (*Server, error) { func NewServer(addr string, dialer proxy.Dialer, upServers []string) (*Server, error) {
c, err := NewClient(dialer, upServers...) c, err := NewClient(dialer, upServers)
s := &Server{ s := &Server{
addr: addr, addr: addr,
Client: c, Client: c,

View File

@ -57,11 +57,16 @@ func main() {
dialer := NewRuleDialer(conf.rules, dialerFromConf()) dialer := NewRuleDialer(conf.rules, dialerFromConf())
ipsetM, _ := NewIPSetManager(conf.IPSet, conf.rules) ipsetM, _ := NewIPSetManager(conf.IPSet, conf.rules)
if conf.DNS != "" { if conf.DNS != "" {
d, err := dns.NewServer(conf.DNS, dialer, conf.DNSServer...) d, err := dns.NewServer(conf.DNS, dialer, conf.DNSServer)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
// custom records
for _, record := range conf.DNSRecord {
d.AddRecord(record)
}
// rule // rule
for _, r := range conf.rules { for _, r := range conf.rules {
for _, domain := range r.Domain { for _, domain := range r.Domain {