strategy: optimize code

This commit is contained in:
nadoo 2018-01-13 20:08:49 +08:00
parent 045531d68d
commit 5b2518fac0

View File

@ -10,25 +10,27 @@ import (
) )
// NewStrategyDialer returns a new Strategy Dialer // NewStrategyDialer returns a new Strategy Dialer
func NewStrategyDialer(strategy string, dialers []Dialer, website string, duration int) Dialer { func NewStrategyDialer(strategy string, dialers []Dialer, website string, interval int) Dialer {
var dialer Dialer
if len(dialers) == 0 { if len(dialers) == 0 {
dialer = Direct return Direct
} else if len(dialers) == 1 { }
dialer = dialers[0]
} else if len(dialers) > 1 { if len(dialers) == 1 {
return dialers[0]
}
var dialer Dialer
switch strategy { switch strategy {
case "rr": case "rr":
dialer = newRRDialer(dialers, website, duration) dialer = newRRDialer(dialers, website, interval)
logf("forward to remote servers in round robin mode.") logf("forward to remote servers in round robin mode.")
case "ha": case "ha":
dialer = newHADialer(dialers, website, duration) dialer = newHADialer(dialers, website, interval)
logf("forward to remote servers in high availability mode.") logf("forward to remote servers in high availability mode.")
default: default:
logf("not supported forward mode '%s', just use the first forward server.", conf.Strategy) logf("not supported forward mode '%s', just use the first forward server.", conf.Strategy)
dialer = dialers[0] dialer = dialers[0]
} }
}
return dialer return dialer
} }
@ -42,15 +44,15 @@ type rrDialer struct {
// for checking // for checking
website string website string
duration int interval int
} }
// newRRDialer returns a new rrDialer // newRRDialer returns a new rrDialer
func newRRDialer(dialers []Dialer, website string, duration int) *rrDialer { func newRRDialer(dialers []Dialer, website string, interval int) *rrDialer {
rr := &rrDialer{dialers: dialers} rr := &rrDialer{dialers: dialers}
rr.website = website rr.website = website
rr.duration = duration rr.interval = interval
for k := range dialers { for k := range dialers {
rr.status.Store(k, true) rr.status.Store(k, true)
@ -104,7 +106,7 @@ func (rr *rrDialer) checkDialer(idx int) {
d := rr.dialers[idx] d := rr.dialers[idx]
for { for {
time.Sleep(time.Duration(rr.duration) * time.Second * time.Duration(retry>>1)) time.Sleep(time.Duration(rr.interval) * time.Second * time.Duration(retry>>1))
retry <<= 1 retry <<= 1
if retry > 16 { if retry > 16 {
@ -154,7 +156,6 @@ func (ha *haDialer) Dial(network, addr string) (net.Conn, error) {
d := ha.dialers[ha.idx] d := ha.dialers[ha.idx]
result, ok := ha.status.Load(ha.idx) result, ok := ha.status.Load(ha.idx)
if ok && !result.(bool) { if ok && !result.(bool) {
d = ha.NextDialer(addr) d = ha.NextDialer(addr)
} }