diff --git a/strategy.go b/strategy.go index 6c2efdc..2852dcc 100644 --- a/strategy.go +++ b/strategy.go @@ -10,24 +10,26 @@ import ( ) // NewStrategyDialer returns a new Strategy Dialer -func NewStrategyDialer(strategy string, dialers []Dialer, website string, duration int) Dialer { - var dialer Dialer +func NewStrategyDialer(strategy string, dialers []Dialer, website string, interval int) Dialer { if len(dialers) == 0 { - dialer = Direct - } else if len(dialers) == 1 { + return Direct + } + + if len(dialers) == 1 { + return dialers[0] + } + + var dialer Dialer + switch strategy { + case "rr": + dialer = newRRDialer(dialers, website, interval) + logf("forward to remote servers in round robin mode.") + case "ha": + dialer = newHADialer(dialers, website, interval) + logf("forward to remote servers in high availability mode.") + default: + logf("not supported forward mode '%s', just use the first forward server.", conf.Strategy) dialer = dialers[0] - } else if len(dialers) > 1 { - switch strategy { - case "rr": - dialer = newRRDialer(dialers, website, duration) - logf("forward to remote servers in round robin mode.") - case "ha": - dialer = newHADialer(dialers, website, duration) - logf("forward to remote servers in high availability mode.") - default: - logf("not supported forward mode '%s', just use the first forward server.", conf.Strategy) - dialer = dialers[0] - } } return dialer @@ -42,15 +44,15 @@ type rrDialer struct { // for checking website string - duration int + interval int } // newRRDialer returns a new rrDialer -func newRRDialer(dialers []Dialer, website string, duration int) *rrDialer { +func newRRDialer(dialers []Dialer, website string, interval int) *rrDialer { rr := &rrDialer{dialers: dialers} rr.website = website - rr.duration = duration + rr.interval = interval for k := range dialers { rr.status.Store(k, true) @@ -104,7 +106,7 @@ func (rr *rrDialer) checkDialer(idx int) { d := rr.dialers[idx] for { - time.Sleep(time.Duration(rr.duration) * time.Second * time.Duration(retry>>1)) + time.Sleep(time.Duration(rr.interval) * time.Second * time.Duration(retry>>1)) retry <<= 1 if retry > 16 { @@ -154,7 +156,6 @@ func (ha *haDialer) Dial(network, addr string) (net.Conn, error) { d := ha.dialers[ha.idx] result, ok := ha.status.Load(ha.idx) - if ok && !result.(bool) { d = ha.NextDialer(addr) }