glider/dns/client.go

296 lines
7.1 KiB
Go
Raw Normal View History

2018-07-29 23:44:23 +08:00
package dns
import (
"bytes"
2018-07-29 23:44:23 +08:00
"encoding/binary"
"errors"
2018-07-29 23:44:23 +08:00
"io"
"net"
2018-07-29 23:44:23 +08:00
"strings"
2018-08-05 23:41:34 +08:00
"time"
2018-07-29 23:44:23 +08:00
"github.com/nadoo/glider/common/log"
"github.com/nadoo/glider/proxy"
)
// HandleFunc function handles the dns TypeA or TypeAAAA answer.
// domain is the queried name and ip is the textual address taken from one
// answer record; the returned error is the handler's own failure report.
type HandleFunc func(domain, ip string) error
// Config for dns: the settings a Client is built from.
type Config struct {
	// Servers is the default upstream server list, tried in order.
	Servers []string
	// Timeout is the I/O deadline applied to each upstream connection, in seconds.
	Timeout int
	// MaxTTL and MinTTL clamp the TTL used when caching A/AAAA responses
	// (they are compared against answer TTLs).
	MaxTTL, MinTTL int
	// Records are static "domain/ip" entries preloaded into the cache.
	Records []string
	// AlwaysTCP forces TCP to upstream servers even when the client used UDP.
	AlwaysTCP bool
}
// Client is a dns client struct
type Client struct {
dialer proxy.Dialer
cache *Cache
config *Config
upServers []string
upServerMap map[string][]string
handlers []HandleFunc
2018-07-29 23:44:23 +08:00
}
// NewClient returns a new dns client
func NewClient(dialer proxy.Dialer, config *Config) (*Client, error) {
2018-07-29 23:44:23 +08:00
c := &Client{
dialer: dialer,
cache: NewCache(),
config: config,
upServers: config.Servers,
upServerMap: make(map[string][]string),
2018-07-29 23:44:23 +08:00
}
// custom records
for _, record := range config.Records {
c.AddRecord(record)
}
2018-07-29 23:44:23 +08:00
return c, nil
}
// Exchange handles request msg and returns response msg
// reqBytes = reqLen + reqMsg
func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([]byte, error) {
req, err := UnmarshalMessage(reqBytes[2:])
2018-07-29 23:44:23 +08:00
if err != nil {
return nil, err
2018-07-29 23:44:23 +08:00
}
if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA {
v := c.cache.Get(getKey(req.Question))
if v != nil {
binary.BigEndian.PutUint16(v[2:4], req.ID)
log.F("[dns] %s <-> cache, type: %d, %s",
clientAddr, req.Question.QTYPE, req.Question.QNAME)
return v, nil
}
2018-07-29 23:44:23 +08:00
}
dnsServer, network, respBytes, err := c.exchange(req.Question.QNAME, reqBytes, preferTCP)
2018-07-29 23:44:23 +08:00
if err != nil {
return nil, err
2018-07-29 23:44:23 +08:00
}
if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA {
log.F("[dns] %s <-> %s(%s), type: %d, %s",
clientAddr, dnsServer, network, req.Question.QTYPE, req.Question.QNAME)
return respBytes, nil
2018-07-29 23:44:23 +08:00
}
resp, err := UnmarshalMessage(respBytes[2:])
2018-07-29 23:44:23 +08:00
if err != nil {
return respBytes, err
2018-07-29 23:44:23 +08:00
}
ttl := c.config.MinTTL
2018-07-30 00:18:10 +08:00
ips := []string{}
for _, answer := range resp.Answers {
if answer.TYPE == QTypeA || answer.TYPE == QTypeAAAA {
for _, h := range c.handlers {
h(resp.Question.QNAME, answer.IP)
2018-07-29 23:44:23 +08:00
}
if answer.IP != "" {
2018-07-30 00:18:10 +08:00
ips = append(ips, answer.IP)
2018-07-29 23:44:23 +08:00
}
if answer.TTL != 0 {
ttl = int(answer.TTL)
}
2018-08-01 00:36:11 +08:00
}
}
if ttl > c.config.MaxTTL {
ttl = c.config.MaxTTL
} else if ttl < c.config.MinTTL {
ttl = c.config.MinTTL
}
2018-08-02 13:02:04 +08:00
// add to cache only when there's a valid ip address
if len(ips) != 0 && ttl > 0 {
c.cache.Put(getKey(resp.Question), respBytes, ttl)
}
log.F("[dns] %s <-> %s(%s), type: %d, %s: %s",
clientAddr, dnsServer, network, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
return respBytes, nil
}
// exchange choose a upstream dns server based on qname, communicate with it on the network
func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (server, network string, respBytes []byte, err error) {
// use tcp to connect upstream server default
network = "tcp"
dialer := c.dialer.NextDialer(qname + ":53")
// if we are resolving the dialer's domain, then use Direct to avoid denpency loop
if strings.Contains(dialer.Addr(), qname) {
dialer = proxy.Default
}
// If client uses udp and no forwarders specified, use udp
2018-08-26 22:36:14 +08:00
if !preferTCP && !c.config.AlwaysTCP && dialer.Addr() == "DIRECT" {
network = "udp"
}
2018-08-05 23:41:34 +08:00
servers := c.GetServers(qname)
for _, server = range servers {
var rc net.Conn
2018-08-06 08:13:23 +08:00
rc, err = dialer.Dial(network, server)
2018-08-05 23:41:34 +08:00
if err != nil {
log.F("[dns] failed to connect to server %v: %v", server, err)
continue
}
2018-08-06 00:46:07 +08:00
defer rc.Close()
2018-07-29 23:44:23 +08:00
2018-08-06 08:48:13 +08:00
// TODO: support timeout setting for different upstream server
rc.SetDeadline(time.Now().Add(time.Duration(c.config.Timeout) * time.Second))
2018-08-06 08:48:13 +08:00
2018-08-05 23:41:34 +08:00
switch network {
case "tcp":
respBytes, err = c.exchangeTCP(rc, reqBytes)
case "udp":
respBytes, err = c.exchangeUDP(rc, reqBytes)
}
if err == nil {
break
}
2018-08-05 23:41:34 +08:00
log.F("[dns] failed to exchange with server %v: %v", server, err)
}
2018-08-05 23:41:34 +08:00
return server, network, respBytes, err
}
// exchangeTCP exchange with server over tcp
func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) {
if _, err := rc.Write(reqBytes); err != nil {
log.F("[dns] failed to write req message: %v", err)
return nil, err
}
var respLen uint16
if err := binary.Read(rc, binary.BigEndian, &respLen); err != nil {
log.F("[dns] failed to read response length: %v", err)
return nil, err
}
respBytes := make([]byte, respLen+2)
binary.BigEndian.PutUint16(respBytes[:2], respLen)
_, err := io.ReadFull(rc, respBytes[2:])
if err != nil {
log.F("[dns] error in read respMsg %s\n", err)
return nil, err
}
2018-07-29 23:44:23 +08:00
return respBytes, nil
2018-07-29 23:44:23 +08:00
}
// exchangeUDP exchange with server over udp.
//
// A UDP datagram carries the bare dns message, so the 2-byte length prefix of
// reqBytes is stripped before sending, and one is prepended to the single
// datagram read back (at most UDPMaxLen bytes) to match exchangeTCP's framing.
func (c *Client) exchangeUDP(rc net.Conn, reqBytes []byte) ([]byte, error) {
	_, err := rc.Write(reqBytes[2:])
	if err != nil {
		log.F("[dns] failed to write req message: %v", err)
		return nil, err
	}

	frame := make([]byte, 2+UDPMaxLen)
	n, err := rc.Read(frame[2:])
	if err != nil {
		return nil, err
	}

	binary.BigEndian.PutUint16(frame[:2], uint16(n))
	return frame[:2+n], nil
}
// SetServers sets upstream dns servers for the given domain
func (c *Client) SetServers(domain string, servers ...string) {
c.upServerMap[domain] = append(c.upServerMap[domain], servers...)
2018-07-29 23:44:23 +08:00
}
2018-08-06 00:46:07 +08:00
// GetServers gets upstream dns servers for the given domain
2018-08-05 23:41:34 +08:00
func (c *Client) GetServers(domain string) []string {
2018-07-29 23:44:23 +08:00
domainParts := strings.Split(domain, ".")
length := len(domainParts)
for i := length - 2; i >= 0; i-- {
domain := strings.Join(domainParts[i:length], ".")
if servers, ok := c.upServerMap[domain]; ok {
2018-08-05 23:41:34 +08:00
return servers
2018-07-29 23:44:23 +08:00
}
}
2018-08-05 23:41:34 +08:00
return c.upServers
2018-07-29 23:44:23 +08:00
}
2018-08-06 00:46:07 +08:00
// AddHandler adds a custom handler to handle the resolved result (A and AAAA)
2018-07-29 23:44:23 +08:00
func (c *Client) AddHandler(h HandleFunc) {
c.handlers = append(c.handlers, h)
}
// AddRecord adds custom record to dns cache, format:
// www.example.com/1.2.3.4 or www.example.com/2606:2800:220:1:248:1893:25c8:1946
func (c *Client) AddRecord(record string) error {
r := strings.Split(record, "/")
domain, ip := r[0], r[1]
m, err := c.GenResponse(domain, ip)
if err != nil {
return err
}
b, _ := m.Marshal()
var buf bytes.Buffer
binary.Write(&buf, binary.BigEndian, uint16(len(b)))
buf.Write(b)
c.cache.Put(getKey(m.Question), buf.Bytes(), LongTTL)
return nil
}
2018-08-06 00:46:07 +08:00
// GenResponse generates a dns response message for the given domain and ip
// address. An address with an IPv4 form yields one A answer with 4-byte rdata;
// any other valid address yields one AAAA answer with 16-byte rdata. The
// message is built with NewMessage(0, Response) and its question carries the
// same type and name as the answer. An unparsable ip returns an error.
func (c *Client) GenResponse(domain string, ip string) (*Message, error) {
	addr := net.ParseIP(ip)
	if addr == nil {
		return nil, errors.New("GenResponse: invalid ip format")
	}

	qtype, rdlen, rdata := uint16(QTypeAAAA), uint16(net.IPv6len), []byte(addr)
	if v4 := addr.To4(); v4 != nil {
		qtype, rdlen, rdata = QTypeA, net.IPv4len, v4
	}

	m := NewMessage(0, Response)
	m.SetQuestion(NewQuestion(qtype, domain))
	m.AddAnswer(&RR{
		NAME:     domain,
		TYPE:     qtype,
		CLASS:    ClassINET,
		RDLENGTH: rdlen,
		RDATA:    rdata,
	})

	return m, nil
}
func getKey(q *Question) string {
qtype := ""
switch q.QTYPE {
case QTypeA:
qtype = "A"
case QTypeAAAA:
qtype = "AAAA"
}
return q.QNAME + "/" + qtype
2018-07-29 23:44:23 +08:00
}