rule: use new concurrent sync.Map instead of old map

This commit is contained in:
nadoo 2017-08-23 18:58:24 +08:00
parent 65f33beaa3
commit 1b400d984e
2 changed files with 48 additions and 33 deletions

15
main.go
View File

@ -26,15 +26,13 @@ func dialerFromConf() Dialer {
forwarders = append(forwarders, forward) forwarders = append(forwarders, forward)
} }
forwarder := NewStrategyDialer(conf.Strategy, forwarders, conf.CheckWebSite, conf.CheckDuration) return NewStrategyDialer(conf.Strategy, forwarders, conf.CheckWebSite, conf.CheckDuration)
return NewRuleDialer(conf.rules, forwarder)
} }
func main() { func main() {
confInit() confInit()
sDialer := dialerFromConf() sDialer := NewRuleDialer(conf.rules, dialerFromConf())
for _, listen := range conf.Listen { for _, listen := range conf.Listen {
local, err := ServerFromURL(listen, sDialer) local, err := ServerFromURL(listen, sDialer)
@ -60,13 +58,8 @@ func main() {
} }
} }
// test here // add a handler to update proxy rules when a domain resolved
dns.AddAnswerHandler(func(domain, ip string) error { dns.AddAnswerHandler(sDialer.AddDomainIP)
if ip != "" {
logf("domain: %s, ip: %s\n", domain, ip)
}
return nil
})
go dns.ListenAndServe() go dns.ListenAndServe()
} }

66
rule.go
View File

@ -4,24 +4,20 @@ import (
"log" "log"
"net" "net"
"strings" "strings"
"sync"
) )
// RuleDialer . // RuleDialer .
type RuleDialer struct { type RuleDialer struct {
gDialer Dialer gDialer Dialer
domainMap map[string]Dialer domainMap sync.Map
ipMap map[string]Dialer ipMap sync.Map
cidrMap map[string]Dialer cidrMap sync.Map
} }
// NewRuleDialer . // NewRuleDialer .
func NewRuleDialer(rules []*RuleConf, gDialer Dialer) Dialer { func NewRuleDialer(rules []*RuleConf, gDialer Dialer) *RuleDialer {
if len(rules) == 0 {
return gDialer
}
rd := &RuleDialer{gDialer: gDialer} rd := &RuleDialer{gDialer: gDialer}
for _, r := range rules { for _, r := range rules {
@ -40,19 +36,16 @@ func NewRuleDialer(rules []*RuleConf, gDialer Dialer) Dialer {
sd := NewStrategyDialer(r.Strategy, forwarders, r.CheckWebSite, r.CheckDuration) sd := NewStrategyDialer(r.Strategy, forwarders, r.CheckWebSite, r.CheckDuration)
rd.domainMap = make(map[string]Dialer)
for _, domain := range r.Domain { for _, domain := range r.Domain {
rd.domainMap[domain] = sd rd.domainMap.Store(domain, sd)
} }
rd.ipMap = make(map[string]Dialer)
for _, ip := range r.IP { for _, ip := range r.IP {
rd.ipMap[ip] = sd rd.ipMap.Store(ip, sd)
} }
rd.cidrMap = make(map[string]Dialer)
for _, cidr := range r.CIDR { for _, cidr := range r.CIDR {
rd.cidrMap[cidr] = sd rd.cidrMap.Store(cidr, sd)
} }
} }
@ -74,19 +67,28 @@ func (p *RuleDialer) NextDialer(dstAddr string) Dialer {
// find ip // find ip
if ip := net.ParseIP(host); ip != nil { if ip := net.ParseIP(host); ip != nil {
// check ip // check ip
if d, ok := p.ipMap[ip.String()]; ok { if d, ok := p.ipMap.Load(ip.String()); ok {
return d return d.(Dialer)
} }
var ret Dialer
// check cidr // check cidr
// TODO: do not parse cidr every time // TODO: do not parse cidr every time
for cidrStr, d := range p.cidrMap { p.cidrMap.Range(func(key, value interface{}) bool {
if _, net, err := net.ParseCIDR(cidrStr); err == nil { if _, net, err := net.ParseCIDR(key.(string)); err == nil {
if net.Contains(ip) { if net.Contains(ip) {
return d ret = value.(Dialer)
return false
} }
} }
return true
})
if ret != nil {
return ret
} }
} }
domainParts := strings.Split(host, ".") domainParts := strings.Split(host, ".")
@ -95,8 +97,8 @@ func (p *RuleDialer) NextDialer(dstAddr string) Dialer {
domain := strings.Join(domainParts[i:length], ".") domain := strings.Join(domainParts[i:length], ".")
// find in domainMap // find in domainMap
if d, ok := p.domainMap[domain]; ok { if d, ok := p.domainMap.Load(domain); ok {
return d return d.(Dialer)
} }
} }
@ -106,3 +108,23 @@ func (p *RuleDialer) NextDialer(dstAddr string) Dialer {
func (rd *RuleDialer) Dial(network, addr string) (net.Conn, error) { func (rd *RuleDialer) Dial(network, addr string) (net.Conn, error) {
return rd.NextDialer(addr).Dial(network, addr) return rd.NextDialer(addr).Dial(network, addr)
} }
// AddDomainIP used to update ipMap rules according to domainMap rule
func (rd *RuleDialer) AddDomainIP(domain, ip string) error {
if ip != "" {
logf("domain: %s, ip: %s\n", domain, ip)
domainParts := strings.Split(domain, ".")
length := len(domainParts)
for i := length - 2; i >= 0; i-- {
domain := strings.Join(domainParts[i:length], ".")
// find in domainMap
if d, ok := rd.domainMap.Load(domain); ok {
rd.ipMap.Store(ip, d)
}
}
}
return nil
}