dns: switch upstream when dial error occurred

This commit is contained in:
nadoo 2020-05-03 20:02:11 +08:00
parent 2fe9c3990b
commit 48e059db6c
5 changed files with 32 additions and 19 deletions

View File

@ -52,6 +52,7 @@ we can set up local listeners as proxy servers, and forward requests to internet
|trojan | | |√|√|client only |trojan | | |√|√|client only
|vmess | | |√| |client only |vmess | | |√| |client only
|redir |√| | | |linux only |redir |√| | | |linux only
|redir6 |√| | | |linux only(ipv6)
|tls |√| |√| |transport client & server |tls |√| |√| |transport client & server
|kcp | |√|√| |transport client & server |kcp | |√|√| |transport client & server
|unix |√| |√| |transport client & server |unix |√| |√| |transport client & server

View File

@ -154,7 +154,10 @@ func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (
var rc net.Conn var rc net.Conn
rc, err = dialer.Dial(network, server) rc, err = dialer.Dial(network, server)
if err != nil { if err != nil {
log.F("[dns] error in resolving %s, failed to connect to server %v via %s: %v", qname, server, dialer.Addr(), err) newServer := ups.SwitchIf(server)
log.F("[dns] error in resolving %s, failed to connect to server %v via %s: %v, switch to %s",
qname, server, dialer.Addr(), err, newServer)
server = newServer
continue continue
} }
defer rc.Close() defer rc.Close()
@ -173,7 +176,7 @@ func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (
break break
} }
newServer := ups.Switch() newServer := ups.SwitchIf(server)
log.F("[dns] error in resolving %s, failed to exchange with server %v via %s: %v, switch to %s", log.F("[dns] error in resolving %s, failed to exchange with server %v via %s: %v, switch to %s",
qname, server, dialer.Addr(), err, newServer) qname, server, dialer.Addr(), err, newServer)

View File

@ -23,6 +23,14 @@ func (u *UpStream) Switch() string {
return u.servers[atomic.AddUint32(&u.index, 1)%uint32(len(u.servers))] return u.servers[atomic.AddUint32(&u.index, 1)%uint32(len(u.servers))]
} }
// SwitchIf switches to the next dns server if needed.
func (u *UpStream) SwitchIf(server string) string {
if u.Server() == server {
return u.Switch()
}
return u.Server()
}
// Len returns the number of dns servers. // Len returns the number of dns servers.
func (u *UpStream) Len() int { func (u *UpStream) Len() int {
return len(u.servers) return len(u.servers)

View File

@ -3,6 +3,7 @@ package proxy
import ( import (
"errors" "errors"
"net" "net"
"time"
"github.com/nadoo/glider/common/log" "github.com/nadoo/glider/common/log"
) )
@ -76,7 +77,7 @@ func dial(network, addr string, localIP net.IP) (net.Conn, error) {
la = &net.UDPAddr{IP: localIP} la = &net.UDPAddr{IP: localIP}
} }
dialer := &net.Dialer{LocalAddr: la} dialer := &net.Dialer{LocalAddr: la, Timeout: time.Second * 3}
c, err := dialer.Dial(network, addr) c, err := dialer.Dial(network, addr)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -27,7 +27,7 @@ type Forwarder struct {
handlers []StatusHandler handlers []StatusHandler
} }
// ForwarderFromURL parses `forward=` command value and returns a new forwarder // ForwarderFromURL parses `forward=` command value and returns a new forwarder.
func ForwarderFromURL(s, intface string) (f *Forwarder, err error) { func ForwarderFromURL(s, intface string) (f *Forwarder, err error) {
f = &Forwarder{} f = &Forwarder{}
@ -63,7 +63,7 @@ func ForwarderFromURL(s, intface string) (f *Forwarder, err error) {
return f, err return f, err
} }
// DirectForwarder returns a direct forwarder // DirectForwarder returns a direct forwarder.
func DirectForwarder(intface string) *Forwarder { func DirectForwarder(intface string) *Forwarder {
d, err := proxy.NewDirect(intface) d, err := proxy.NewDirect(intface)
if err != nil { if err != nil {
@ -91,12 +91,12 @@ func (f *Forwarder) parseOption(option string) error {
return err return err
} }
// Addr . // Addr returns the forwarder's addr.
func (f *Forwarder) Addr() string { func (f *Forwarder) Addr() string {
return f.addr return f.addr
} }
// Dial . // Dial dials to addr and returns conn.
func (f *Forwarder) Dial(network, addr string) (c net.Conn, err error) { func (f *Forwarder) Dial(network, addr string) (c net.Conn, err error) {
c, err = f.Dialer.Dial(network, addr) c, err = f.Dialer.Dial(network, addr)
if err != nil { if err != nil {
@ -106,12 +106,12 @@ func (f *Forwarder) Dial(network, addr string) (c net.Conn, err error) {
return c, err return c, err
} }
// Failures returns the failuer count of forwarder // Failures returns the failuer count of forwarder.
func (f *Forwarder) Failures() uint32 { func (f *Forwarder) Failures() uint32 {
return atomic.LoadUint32(&f.failures) return atomic.LoadUint32(&f.failures)
} }
// IncFailures increase the failuer count by 1 // IncFailures increase the failuer count by 1.
func (f *Forwarder) IncFailures() { func (f *Forwarder) IncFailures() {
failures := atomic.AddUint32(&f.failures, 1) failures := atomic.AddUint32(&f.failures, 1)
log.F("[forwarder] %s recorded %d failures, maxfailures: %d", f.addr, failures, f.MaxFailures()) log.F("[forwarder] %s recorded %d failures, maxfailures: %d", f.addr, failures, f.MaxFailures())
@ -122,12 +122,12 @@ func (f *Forwarder) IncFailures() {
} }
} }
// AddHandler adds a custom handler to handle the status change event // AddHandler adds a custom handler to handle the status change event.
func (f *Forwarder) AddHandler(h StatusHandler) { func (f *Forwarder) AddHandler(h StatusHandler) {
f.handlers = append(f.handlers, h) f.handlers = append(f.handlers, h)
} }
// Enable the forwarder // Enable the forwarder.
func (f *Forwarder) Enable() { func (f *Forwarder) Enable() {
if atomic.CompareAndSwapUint32(&f.disabled, 1, 0) { if atomic.CompareAndSwapUint32(&f.disabled, 1, 0) {
for _, h := range f.handlers { for _, h := range f.handlers {
@ -137,7 +137,7 @@ func (f *Forwarder) Enable() {
atomic.StoreUint32(&f.failures, 0) atomic.StoreUint32(&f.failures, 0)
} }
// Disable the forwarder // Disable the forwarder.
func (f *Forwarder) Disable() { func (f *Forwarder) Disable() {
if atomic.CompareAndSwapUint32(&f.disabled, 0, 1) { if atomic.CompareAndSwapUint32(&f.disabled, 0, 1) {
for _, h := range f.handlers { for _, h := range f.handlers {
@ -146,7 +146,7 @@ func (f *Forwarder) Disable() {
} }
} }
// Enabled returns the status of forwarder // Enabled returns the status of forwarder.
func (f *Forwarder) Enabled() bool { func (f *Forwarder) Enabled() bool {
return !isTrue(atomic.LoadUint32(&f.disabled)) return !isTrue(atomic.LoadUint32(&f.disabled))
} }
@ -155,32 +155,32 @@ func isTrue(n uint32) bool {
return n&1 == 1 return n&1 == 1
} }
// Priority returns the priority of forwarder // Priority returns the priority of forwarder.
func (f *Forwarder) Priority() uint32 { func (f *Forwarder) Priority() uint32 {
return atomic.LoadUint32(&f.priority) return atomic.LoadUint32(&f.priority)
} }
// SetPriority sets the priority of forwarder // SetPriority sets the priority of forwarder.
func (f *Forwarder) SetPriority(l uint32) { func (f *Forwarder) SetPriority(l uint32) {
atomic.StoreUint32(&f.priority, l) atomic.StoreUint32(&f.priority, l)
} }
// MaxFailures returns the maxFailures of forwarder // MaxFailures returns the maxFailures of forwarder.
func (f *Forwarder) MaxFailures() uint32 { func (f *Forwarder) MaxFailures() uint32 {
return atomic.LoadUint32(&f.maxFailures) return atomic.LoadUint32(&f.maxFailures)
} }
// SetMaxFailures sets the maxFailures of forwarder // SetMaxFailures sets the maxFailures of forwarder.
func (f *Forwarder) SetMaxFailures(l uint32) { func (f *Forwarder) SetMaxFailures(l uint32) {
atomic.StoreUint32(&f.maxFailures, l) atomic.StoreUint32(&f.maxFailures, l)
} }
// Latency returns the latency of forwarder // Latency returns the latency of forwarder.
func (f *Forwarder) Latency() int64 { func (f *Forwarder) Latency() int64 {
return atomic.LoadInt64(&f.latency) return atomic.LoadInt64(&f.latency)
} }
// SetLatency sets the latency of forwarder // SetLatency sets the latency of forwarder.
func (f *Forwarder) SetLatency(l int64) { func (f *Forwarder) SetLatency(l int64) {
atomic.StoreInt64(&f.latency, l) atomic.StoreInt64(&f.latency, l)
} }