From a9a1985a4bf01a8e893d0ff6714f6c28b6723240 Mon Sep 17 00:00:00 2001 From: nadoo <287492+nadoo@users.noreply.github.com> Date: Sat, 2 May 2020 20:02:19 +0800 Subject: [PATCH] dns: allow to switch dns server permanently (#97) --- dns/client.go | 37 ++++++++++++++++++++++--------------- dns/upstream.go | 29 +++++++++++++++++++++++++++++ go.mod | 2 +- go.sum | 4 ++-- main.go | 2 +- 5 files changed, 55 insertions(+), 19 deletions(-) create mode 100644 dns/upstream.go diff --git a/dns/client.go b/dns/client.go index 9f3602a..70635b4 100644 --- a/dns/client.go +++ b/dns/client.go @@ -31,8 +31,8 @@ type Client struct { proxy proxy.Proxy cache *Cache config *Config - upServers []string - upServerMap map[string][]string + upStream *UpStream + upStreamMap map[string]*UpStream handlers []HandleFunc } @@ -42,8 +42,8 @@ func NewClient(proxy proxy.Proxy, config *Config) (*Client, error) { proxy: proxy, cache: NewCache(), config: config, - upServers: config.Servers, - upServerMap: make(map[string][]string), + upStream: NewUpStream(config.Servers), + upStreamMap: make(map[string]*UpStream), } // custom records @@ -148,12 +148,13 @@ func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) ( network = "udp" } - servers := c.GetServers(qname) - for _, server = range servers { + ups := c.UpStream(qname) + server = ups.Server() + for i := 0; i < ups.Len(); i++ { var rc net.Conn rc, err = dialer.Dial(network, server) if err != nil { - log.F("[dns] error in resolving %s, failed to connect to server %v: %v", qname, server, err) + log.F("[dns] error in resolving %s, failed to connect to server %v via %s: %v", qname, server, dialer.Addr(), err) continue } defer rc.Close() @@ -172,12 +173,18 @@ func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) ( break } - log.F("[dns] error in resolving %s, failed to exchange with server %v: %v", qname, server, err) + newServer := ups.Switch() + log.F("[dns] error in resolving %s, failed to exchange with server %v via %s: %v, switch to %s", + qname, server, dialer.Addr(), err, newServer) + + server = newServer } + // if all dns upstreams failed, then maybe the forwarder is not available. if err != nil { c.proxy.Record(dialer, false) } + return server, network, dialer.Addr(), respBytes, err } @@ -220,23 +227,23 @@ func (c *Client) exchangeUDP(rc net.Conn, reqBytes []byte) ([]byte, error) { } // SetServers sets upstream dns servers for the given domain. -func (c *Client) SetServers(domain string, servers ...string) { - c.upServerMap[domain] = append(c.upServerMap[domain], servers...) +func (c *Client) SetServers(domain string, servers []string) { + c.upStreamMap[domain] = NewUpStream(servers) } -// GetServers gets upstream dns servers for the given domain -func (c *Client) GetServers(domain string) []string { +// UpStream returns upstream dns server for the given domain. +func (c *Client) UpStream(domain string) *UpStream { domainParts := strings.Split(domain, ".") length := len(domainParts) for i := length - 1; i >= 0; i-- { domain := strings.Join(domainParts[i:length], ".") - if servers, ok := c.upServerMap[domain]; ok { - return servers + if upstream, ok := c.upStreamMap[domain]; ok { + return upstream } } - return c.upServers + return c.upStream } // AddHandler adds a custom handler to handle the resolved result (A and AAAA). diff --git a/dns/upstream.go b/dns/upstream.go new file mode 100644 index 0000000..7bf23e4 --- /dev/null +++ b/dns/upstream.go @@ -0,0 +1,29 @@ +package dns + +import "sync/atomic" + +// UpStream is a dns upstream. +type UpStream struct { + index uint32 + servers []string +} + +// NewUpStream returns a new UpStream. +func NewUpStream(servers []string) *UpStream { + return &UpStream{servers: servers} +} + +// Server returns a dns server. +func (u *UpStream) Server() string { + return u.servers[atomic.LoadUint32(&u.index)%uint32(len(u.servers))] +} + +// Switch switches to the next dns server. +func (u *UpStream) Switch() string { + return u.servers[atomic.AddUint32(&u.index, 1)%uint32(len(u.servers))] +} + +// Len returns the number of dns servers. +func (u *UpStream) Len() int { + return len(u.servers) +} diff --git a/go.mod b/go.mod index 2329787..6c7ad4c 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/xtaci/kcp-go/v5 v5.5.12 golang.org/x/crypto v0.0.0-20200429183012-4b2356b1ed79 golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5 // indirect - golang.org/x/sys v0.0.0-20200501052902-10377860bb8e // indirect + golang.org/x/sys v0.0.0-20200501145240-bc7a7d42d5c3 // indirect ) // Replace dependency modules with local developing copy diff --git a/go.sum b/go.sum index 606d2c1..502469b 100644 --- a/go.sum +++ b/go.sum @@ -75,8 +75,8 @@ golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191020212454-3e7259c5e7c2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200501052902-10377860bb8e h1:hq86ru83GdWTlfQFZGO4nZJTU4Bs2wfHl8oFHRaXsfc= -golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200501145240-bc7a7d42d5c3 h1:5B6i6EAiSYyejWfvc5Rc9BbI3rzIsrrXfAQBWnYfn+w= +golang.org/x/sys v0.0.0-20200501145240-bc7a7d42d5c3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/main.go b/main.go index 872f805..1e7739e 100644 --- a/main.go +++ b/main.go @@ -62,7 +62,7 @@ func main() { for _, r := range conf.rules { for _, domain := range r.Domain { if len(r.DNSServers) > 0 { - d.SetServers(domain, r.DNSServers...) + d.SetServers(domain, r.DNSServers) } } }