glider/dns/client.go

345 lines
8.8 KiB
Go
Raw Permalink Normal View History

2018-07-29 23:44:23 +08:00
package dns
import (
"encoding/binary"
"errors"
2018-07-29 23:44:23 +08:00
"io"
"net"
2022-01-28 15:12:02 +08:00
"net/netip"
2020-10-09 22:02:19 +08:00
"strconv"
2018-07-29 23:44:23 +08:00
"strings"
2018-08-05 23:41:34 +08:00
"time"
2018-07-29 23:44:23 +08:00
"github.com/nadoo/glider/pkg/log"
"github.com/nadoo/glider/pkg/pool"
2018-07-29 23:44:23 +08:00
"github.com/nadoo/glider/proxy"
)
// AnswerHandler function handles the dns TypeA or TypeAAAA answer.
type AnswerHandler func(domain string, ip netip.Addr) error
2018-07-29 23:44:23 +08:00
// Config for dns.
type Config struct {
	// Servers are the default upstream dns servers.
	Servers []string

	// Timeout is the per-connection deadline for upstream exchanges, in
	// seconds; values <= 0 disable it. MaxTTL and MinTTL bound the TTL used
	// when caching upstream answers; MaxTTL is also the TTL given to custom
	// Records.
	Timeout, MaxTTL, MinTTL int

	// Records are custom records in "domain/ip" form, e.g.
	// www.example.com/1.2.3.4.
	Records []string

	// AlwaysTCP makes the client use tcp to upstream servers even when the
	// request arrived over udp.
	AlwaysTCP bool

	// CacheSize is passed to NewLruCache (presumably the cache capacity).
	CacheSize int

	// CacheLog logs every query answered from the cache. NoAAAA answers AAAA
	// queries locally with an answerless response instead of resolving them.
	CacheLog, NoAAAA bool
}
// Client is a dns client struct.
2018-07-29 23:44:23 +08:00
type Client struct {
proxy proxy.Proxy
2020-10-08 18:48:23 +08:00
cache *LruCache
config *Config
2020-05-06 20:10:18 +08:00
upStream *UPStream
upStreamMap map[string]*UPStream
handlers []AnswerHandler
2018-07-29 23:44:23 +08:00
}
// NewClient returns a new dns client.
func NewClient(proxy proxy.Proxy, config *Config) (*Client, error) {
2018-07-29 23:44:23 +08:00
c := &Client{
proxy: proxy,
2020-10-08 18:48:23 +08:00
cache: NewLruCache(config.CacheSize),
config: config,
2020-05-06 20:10:18 +08:00
upStream: NewUPStream(config.Servers),
upStreamMap: make(map[string]*UPStream),
2018-07-29 23:44:23 +08:00
}
// custom records
for _, record := range config.Records {
if err := c.AddRecord(record); err != nil {
log.F("[dns] add record '%s' error: %s", record, err)
}
}
2018-07-29 23:44:23 +08:00
return c, nil
}
// Exchange handles request message and returns response message.
2020-10-09 22:02:19 +08:00
// TODO: optimize it
func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([]byte, error) {
2020-10-09 22:02:19 +08:00
req, err := UnmarshalMessage(reqBytes)
2018-07-29 23:44:23 +08:00
if err != nil {
return nil, err
2018-07-29 23:44:23 +08:00
}
if c.config.NoAAAA && req.Question.QTYPE == QTypeAAAA {
respBytes := valCopy(reqBytes)
respBytes[2] |= uint8(ResponseMsg) << 7
return respBytes, nil
}
if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA {
2020-10-09 22:02:19 +08:00
if v, expired := c.cache.Get(qKey(req.Question)); len(v) > 2 {
2020-10-08 18:48:23 +08:00
v = valCopy(v)
2020-10-09 22:02:19 +08:00
binary.BigEndian.PutUint16(v[:2], req.ID)
if c.config.CacheLog {
log.F("[dns] %s <-> cache, type: %d, %s",
clientAddr, req.Question.QTYPE, req.Question.QNAME)
}
2020-10-09 22:02:19 +08:00
if expired { // update cache
go func(qname string, reqBytes []byte, preferTCP bool) {
defer pool.PutBuffer(reqBytes)
if dnsServer, network, dialerAddr, respBytes, err := c.exchange(qname, reqBytes, preferTCP); err == nil {
c.handleAnswer(respBytes, "cache", dnsServer, network, dialerAddr)
2020-10-09 22:02:19 +08:00
}
}(req.Question.QNAME, valCopy(reqBytes), preferTCP)
}
return v, nil
}
2018-07-29 23:44:23 +08:00
}
dnsServer, network, dialerAddr, respBytes, err := c.exchange(req.Question.QNAME, reqBytes, preferTCP)
2018-07-29 23:44:23 +08:00
if err != nil {
return nil, err
2018-07-29 23:44:23 +08:00
}
if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA {
log.F("[dns] %s <-> %s(%s) via %s, type: %d, %s",
clientAddr, dnsServer, network, dialerAddr, req.Question.QTYPE, req.Question.QNAME)
return respBytes, nil
2018-07-29 23:44:23 +08:00
}
2020-10-09 22:02:19 +08:00
err = c.handleAnswer(respBytes, clientAddr, dnsServer, network, dialerAddr)
return respBytes, err
}
func (c *Client) handleAnswer(respBytes []byte, clientAddr, dnsServer, network, dialerAddr string) error {
resp, err := UnmarshalMessage(respBytes)
2018-07-29 23:44:23 +08:00
if err != nil {
2020-10-09 22:02:19 +08:00
return err
2018-07-29 23:44:23 +08:00
}
2020-04-13 00:55:11 +08:00
ips, ttl := c.extractAnswer(resp)
if ttl > c.config.MaxTTL {
ttl = c.config.MaxTTL
} else if ttl < c.config.MinTTL {
ttl = c.config.MinTTL
}
if ttl <= 0 { // we got a null result
ttl = 1800
2020-04-13 00:55:11 +08:00
}
c.cache.Set(qKey(resp.Question), valCopy(respBytes), ttl)
log.F("[dns] %s <-> %s(%s) via %s, %s/%d: %s, ttl: %ds",
2021-07-06 20:31:39 +08:00
clientAddr, dnsServer, network, dialerAddr, resp.Question.QNAME, resp.Question.QTYPE, strings.Join(ips, ","), ttl)
2020-04-13 00:55:11 +08:00
2020-10-09 22:02:19 +08:00
return nil
2020-04-13 00:55:11 +08:00
}
func (c *Client) extractAnswer(resp *Message) ([]string, int) {
var ips []string
ttl := c.config.MinTTL
for _, answer := range resp.Answers {
if answer.TYPE == QTypeA || answer.TYPE == QTypeAAAA {
if answer.IP.IsValid() && !answer.IP.IsUnspecified() {
2022-01-29 21:10:09 +08:00
for _, h := range c.handlers {
h(resp.Question.QNAME, answer.IP)
2022-01-29 21:10:09 +08:00
}
ips = append(ips, answer.IP.String())
2018-07-29 23:44:23 +08:00
}
if answer.TTL != 0 {
ttl = int(answer.TTL)
}
2018-08-01 00:36:11 +08:00
}
}
2020-04-13 00:55:11 +08:00
return ips, ttl
}
// exchange chooses an upstream dns server based on qname and exchanges reqBytes
// with it on the network.
//
// The dialer comes from c.proxy for qname; a "REJECT" dialer is replaced by the
// direct one so the name can still be resolved. tcp is used unless the caller
// did not prefer tcp, Config.AlwaysTCP is off and the dialer is "DIRECT", in
// which case udp is used. At most ups.Len() attempts are made; after a failed
// attempt ups.SwitchIf picks the server for the next one. If every attempt
// fails, the dialer is reported to the proxy as failed.
//
// It returns the upstream server address, the network, the dialer address and
// the raw response. An upstream group with no servers yields an error.
func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (
	server, network, dialerAddr string, respBytes []byte, err error) {
	// use tcp to connect upstream server default
	network = "tcp"
	dialer := c.proxy.NextDialer(qname + ":0")

	// if we are resolving a domain which uses a forwarder `REJECT`, then use `DIRECT` instead
	// so we can resolve it correctly.
	// TODO: dialer.Addr() == "REJECT", tricky
	if dialer.Addr() == "REJECT" {
		dialer = c.proxy.NextDialer("direct:0")
	}

	// If client uses udp and no forwarders specified, use udp
	// TODO: dialer.Addr() == "DIRECT", tricky
	if !preferTCP && !c.config.AlwaysTCP && dialer.Addr() == "DIRECT" {
		network = "udp"
	}

	ups := c.UpStream(qname)
	if ups.Len() == 0 {
		// Without this guard the loop below never runs and the caller would get
		// a nil response together with a nil error.
		return "", network, dialer.Addr(), nil, errors.New("no upstream dns server configured")
	}

	server = ups.Server()
	for i := 0; i < ups.Len(); i++ {
		var rc net.Conn
		rc, err = dialer.Dial(network, server)
		if err != nil {
			newServer := ups.SwitchIf(server)
			log.F("[dns] error in resolving %s, failed to connect to server %v via %s: %v, next server: %s",
				qname, server, dialer.Addr(), err, newServer)
			server = newServer
			continue
		}

		// TODO: support timeout setting for different upstream server
		if c.config.Timeout > 0 {
			rc.SetDeadline(time.Now().Add(time.Duration(c.config.Timeout) * time.Second))
		}

		switch network {
		case "tcp":
			respBytes, err = c.exchangeTCP(rc, reqBytes)
		case "udp":
			respBytes, err = c.exchangeUDP(rc, reqBytes)
		}

		// Close right away: a deferred Close inside this loop would keep the
		// connection of every failed attempt open until exchange returns.
		rc.Close()

		if err == nil {
			break
		}

		newServer := ups.SwitchIf(server)
		log.F("[dns] error in resolving %s, failed to exchange with server %v via %s: %v, next server: %s",
			qname, server, dialer.Addr(), err, newServer)
		server = newServer
	}

	// if all dns upstreams failed, then maybe the forwarder is not available.
	if err != nil {
		c.proxy.Record(dialer, false)
	}

	return server, network, dialer.Addr(), respBytes, err
}
// exchangeTCP exchange with server over tcp.
//
// The request is sent with the dns-over-tcp framing: a 2-byte big-endian length
// prefix followed by the message (RFC 1035 section 4.2.2). One length-prefixed
// response is then read back into a buffer obtained from pool.GetBuffer.
// A request longer than 65535 bytes cannot be framed and is rejected.
func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) {
	const maxMsgLen = 1<<16 - 1 // the length prefix is 16 bits wide
	if len(reqBytes) > maxMsgLen {
		return nil, errors.New("dns request too large to send over tcp")
	}

	lenBuf := pool.GetBuffer(2)
	defer pool.PutBuffer(lenBuf)

	binary.BigEndian.PutUint16(lenBuf, uint16(len(reqBytes)))
	if _, err := (&net.Buffers{lenBuf, reqBytes}).WriteTo(rc); err != nil {
		return nil, err
	}

	// Reuse lenBuf for the response length rather than binary.Read, which
	// goes through reflection; the error semantics (io.EOF /
	// io.ErrUnexpectedEOF) are the same.
	if _, err := io.ReadFull(rc, lenBuf); err != nil {
		return nil, err
	}
	respLen := binary.BigEndian.Uint16(lenBuf)

	respBytes := pool.GetBuffer(int(respLen))
	if _, err := io.ReadFull(rc, respBytes); err != nil {
		pool.PutBuffer(respBytes)
		return nil, err
	}

	return respBytes, nil
}
// exchangeUDP exchange with server over udp.
func (c *Client) exchangeUDP(rc net.Conn, reqBytes []byte) ([]byte, error) {
2020-10-09 22:02:19 +08:00
if _, err := rc.Write(reqBytes); err != nil {
return nil, err
}
2020-08-23 23:23:30 +08:00
respBytes := pool.GetBuffer(UDPMaxLen)
2020-10-09 22:02:19 +08:00
n, err := rc.Read(respBytes)
if err != nil {
return nil, err
}
2020-10-09 22:02:19 +08:00
return respBytes[:n], nil
}
// SetServers sets upstream dns servers for the given domain.
func (c *Client) SetServers(domain string, servers []string) {
c.upStreamMap[strings.ToLower(domain)] = NewUPStream(servers)
2018-07-29 23:44:23 +08:00
}
// UpStream returns upstream dns server for the given domain.
2020-05-06 20:10:18 +08:00
func (c *Client) UpStream(domain string) *UPStream {
domain = strings.ToLower(domain)
for i := len(domain); i != -1; {
i = strings.LastIndexByte(domain[:i], '.')
if upstream, ok := c.upStreamMap[domain[i+1:]]; ok {
return upstream
2018-07-29 23:44:23 +08:00
}
}
return c.upStream
2018-07-29 23:44:23 +08:00
}
// AddHandler adds a custom handler to handle the resolved result (A and AAAA).
//
// A nil handler is ignored: registering it would make the client panic the
// next time it processes an A/AAAA answer. There is no locking, so presumably
// handlers should be added before the client starts serving queries.
func (c *Client) AddHandler(h AnswerHandler) {
	if h == nil {
		return
	}
	c.handlers = append(c.handlers, h)
}
// AddRecord adds custom record to dns cache, format:
// www.example.com/1.2.3.4 or www.example.com/2606:2800:220:1:248:1893:25c8:1946
func (c *Client) AddRecord(record string) error {
domain, ip, found := strings.Cut(record, "/")
if !found {
return errors.New("wrong record format, must contain '/'")
}
m, err := MakeResponse(domain, ip, uint32(c.config.MaxTTL))
if err != nil {
2022-01-28 15:12:02 +08:00
log.F("[dns] add custom record error: %s", err)
return err
}
2020-11-03 22:52:50 +08:00
wb := pool.GetBytesBuffer()
defer pool.PutBytesBuffer(wb)
2020-10-09 22:02:19 +08:00
_, err = m.MarshalTo(wb)
2020-08-23 23:23:30 +08:00
if err != nil {
return err
}
2020-10-09 22:02:19 +08:00
c.cache.Set(qKey(m.Question), valCopy(wb.Bytes()), 0)
return nil
}
// MakeResponse makes a dns response message for the given domain and ip address.
// Note: you should make sure ttl > 0.
//
// An address for which netip reports Is6 (including IPv4-mapped IPv6) produces
// an AAAA answer with 16 bytes of RDATA; any other address produces an A answer
// with 4 bytes. The error from parsing ip is returned unchanged.
func MakeResponse(domain, ip string, ttl uint32) (*Message, error) {
	addr, err := netip.ParseAddr(ip)
	if err != nil {
		return nil, err
	}

	var qtype, rdlen uint16
	if addr.Is6() {
		qtype, rdlen = QTypeAAAA, net.IPv6len
	} else {
		qtype, rdlen = QTypeA, net.IPv4len
	}

	answer := &RR{
		NAME:     domain,
		TYPE:     qtype,
		CLASS:    ClassINET,
		TTL:      ttl,
		RDLENGTH: rdlen,
		RDATA:    addr.AsSlice(),
	}

	msg := NewMessage(0, ResponseMsg)
	msg.SetQuestion(NewQuestion(qtype, domain))
	msg.AddAnswer(answer)
	return msg, nil
}
2020-08-23 23:23:30 +08:00
func qKey(q *Question) string {
2020-10-09 22:02:19 +08:00
return q.QNAME + "/" + strconv.FormatUint(uint64(q.QTYPE), 10)
2018-07-29 23:44:23 +08:00
}
2020-10-08 18:48:23 +08:00
func valCopy(v []byte) (b []byte) {
if v != nil {
b = pool.GetBuffer(len(v))
copy(b, v)
}
return
}