diff --git a/dns/client.go b/dns/client.go index 3c61889..e9042fc 100644 --- a/dns/client.go +++ b/dns/client.go @@ -7,6 +7,7 @@ import ( "io" "net" "strings" + "time" "github.com/nadoo/glider/common/log" "github.com/nadoo/glider/proxy" @@ -105,7 +106,6 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([ func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (server, network string, respBytes []byte, err error) { // use tcp to connect upstream server default network = "tcp" - dialer := c.dialer.NextDialer(qname + ":53") // If client uses udp and no forwarders specified, use udp @@ -113,21 +113,30 @@ func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (server network = "udp" } - server = c.GetServer(qname) - rc, err := dialer.Dial(network, server) - if err != nil { - log.F("[dns] failed to connect to server %v: %v", server, err) - return + servers := c.GetServers(qname) + for _, server = range servers { + rc, err := dialer.Dial(network, server) + // TODO: check the timeout setting here, some dns client has 5 seconds timeout + rc.SetDeadline(time.Now().Add(time.Duration(3) * time.Second)) + if err != nil { + log.F("[dns] failed to connect to server %v: %v", server, err) + continue + } + + switch network { + case "tcp": + respBytes, err = c.exchangeTCP(rc, reqBytes) + case "udp": + respBytes, err = c.exchangeUDP(rc, reqBytes) + } + + if err == nil { + break + } + log.F("[dns] failed to exchange with server %v: %v", server, err) } - switch network { - case "tcp": - respBytes, err = c.exchangeTCP(rc, reqBytes) - case "udp": - respBytes, err = c.exchangeUDP(rc, reqBytes) - } - - return + return server, network, respBytes, err } // exchangeTCP exchange with server over tcp @@ -181,20 +190,20 @@ func (c *Client) SetServer(domain string, servers ...string) { c.upServerMap[domain] = append(c.upServerMap[domain], servers...) } -// GetServer . -func (c *Client) GetServer(domain string) string { +// GetServers . +func (c *Client) GetServers(domain string) []string { domainParts := strings.Split(domain, ".") length := len(domainParts) for i := length - 2; i >= 0; i-- { domain := strings.Join(domainParts[i:length], ".") if servers, ok := c.upServerMap[domain]; ok { - return servers[0] + return servers } } // TODO: - return c.upServers[0] + return c.upServers } // AddHandler .