dns: allow to switch dns server permanently (#97)

This commit is contained in:
nadoo 2020-05-02 20:02:19 +08:00
parent b99a730968
commit a9a1985a4b
5 changed files with 55 additions and 19 deletions

View File

@ -31,8 +31,8 @@ type Client struct {
proxy proxy.Proxy proxy proxy.Proxy
cache *Cache cache *Cache
config *Config config *Config
upServers []string upStream *UpStream
upServerMap map[string][]string upStreamMap map[string]*UpStream
handlers []HandleFunc handlers []HandleFunc
} }
@ -42,8 +42,8 @@ func NewClient(proxy proxy.Proxy, config *Config) (*Client, error) {
proxy: proxy, proxy: proxy,
cache: NewCache(), cache: NewCache(),
config: config, config: config,
upServers: config.Servers, upStream: NewUpStream(config.Servers),
upServerMap: make(map[string][]string), upStreamMap: make(map[string]*UpStream),
} }
// custom records // custom records
@ -148,12 +148,13 @@ func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (
network = "udp" network = "udp"
} }
servers := c.GetServers(qname) ups := c.UpStream(qname)
for _, server = range servers { server = ups.Server()
for i := 0; i < ups.Len(); i++ {
var rc net.Conn var rc net.Conn
rc, err = dialer.Dial(network, server) rc, err = dialer.Dial(network, server)
if err != nil { if err != nil {
log.F("[dns] error in resolving %s, failed to connect to server %v: %v", qname, server, err) log.F("[dns] error in resolving %s, failed to connect to server %v via %s: %v", qname, server, dialer.Addr(), err)
continue continue
} }
defer rc.Close() defer rc.Close()
@ -172,12 +173,18 @@ func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (
break break
} }
log.F("[dns] error in resolving %s, failed to exchange with server %v: %v", qname, server, err) newServer := ups.Switch()
log.F("[dns] error in resolving %s, failed to exchange with server %v via %s: %v, switch to %s",
qname, server, dialer.Addr(), err, newServer)
server = newServer
} }
// if all dns upstreams failed, then maybe the forwarder is not available.
if err != nil { if err != nil {
c.proxy.Record(dialer, false) c.proxy.Record(dialer, false)
} }
return server, network, dialer.Addr(), respBytes, err return server, network, dialer.Addr(), respBytes, err
} }
@ -220,23 +227,23 @@ func (c *Client) exchangeUDP(rc net.Conn, reqBytes []byte) ([]byte, error) {
} }
// SetServers sets upstream dns servers for the given domain. // SetServers sets upstream dns servers for the given domain.
func (c *Client) SetServers(domain string, servers ...string) { func (c *Client) SetServers(domain string, servers []string) {
c.upServerMap[domain] = append(c.upServerMap[domain], servers...) c.upStreamMap[domain] = NewUpStream(servers)
} }
// GetServers gets upstream dns servers for the given domain // UpStream returns upstream dns server for the given domain.
func (c *Client) GetServers(domain string) []string { func (c *Client) UpStream(domain string) *UpStream {
domainParts := strings.Split(domain, ".") domainParts := strings.Split(domain, ".")
length := len(domainParts) length := len(domainParts)
for i := length - 1; i >= 0; i-- { for i := length - 1; i >= 0; i-- {
domain := strings.Join(domainParts[i:length], ".") domain := strings.Join(domainParts[i:length], ".")
if servers, ok := c.upServerMap[domain]; ok { if upstream, ok := c.upStreamMap[domain]; ok {
return servers return upstream
} }
} }
return c.upServers return c.upStream
} }
// AddHandler adds a custom handler to handle the resolved result (A and AAAA). // AddHandler adds a custom handler to handle the resolved result (A and AAAA).

29
dns/upstream.go Normal file
View File

@ -0,0 +1,29 @@
package dns
import "sync/atomic"
// UpStream is a dns upstream.
type UpStream struct {
index uint32
servers []string
}
// NewUpStream returns a new UpStream.
func NewUpStream(servers []string) *UpStream {
return &UpStream{servers: servers}
}
// Server returns a dns server.
func (u *UpStream) Server() string {
return u.servers[atomic.LoadUint32(&u.index)%uint32(len(u.servers))]
}
// Switch switches to the next dns server.
func (u *UpStream) Switch() string {
return u.servers[atomic.AddUint32(&u.index, 1)%uint32(len(u.servers))]
}
// Len returns the number of dns servers.
func (u *UpStream) Len() int {
return len(u.servers)
}

2
go.mod
View File

@ -13,7 +13,7 @@ require (
github.com/xtaci/kcp-go/v5 v5.5.12 github.com/xtaci/kcp-go/v5 v5.5.12
golang.org/x/crypto v0.0.0-20200429183012-4b2356b1ed79 golang.org/x/crypto v0.0.0-20200429183012-4b2356b1ed79
golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5 // indirect golang.org/x/net v0.0.0-20200501053045-e0ff5e5a1de5 // indirect
golang.org/x/sys v0.0.0-20200501052902-10377860bb8e // indirect golang.org/x/sys v0.0.0-20200501145240-bc7a7d42d5c3 // indirect
) )
// Replace dependency modules with local developing copy // Replace dependency modules with local developing copy

4
go.sum
View File

@ -75,8 +75,8 @@ golang.org/x/sys v0.0.0-20191010194322-b09406accb47/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191020212454-3e7259c5e7c2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191020212454-3e7259c5e7c2/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200501052902-10377860bb8e h1:hq86ru83GdWTlfQFZGO4nZJTU4Bs2wfHl8oFHRaXsfc= golang.org/x/sys v0.0.0-20200501145240-bc7a7d42d5c3 h1:5B6i6EAiSYyejWfvc5Rc9BbI3rzIsrrXfAQBWnYfn+w=
golang.org/x/sys v0.0.0-20200501052902-10377860bb8e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200501145240-bc7a7d42d5c3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@ -62,7 +62,7 @@ func main() {
for _, r := range conf.rules { for _, r := range conf.rules {
for _, domain := range r.Domain { for _, domain := range r.Domain {
if len(r.DNSServers) > 0 { if len(r.DNSServers) > 0 {
d.SetServers(domain, r.DNSServers...) d.SetServers(domain, r.DNSServers)
} }
} }
} }