diff --git a/main.go b/main.go index 772ae4c..c8d5fb3 100644 --- a/main.go +++ b/main.go @@ -26,15 +26,13 @@ func dialerFromConf() Dialer { forwarders = append(forwarders, forward) } - forwarder := NewStrategyDialer(conf.Strategy, forwarders, conf.CheckWebSite, conf.CheckDuration) - - return NewRuleDialer(conf.rules, forwarder) + return NewStrategyDialer(conf.Strategy, forwarders, conf.CheckWebSite, conf.CheckDuration) } func main() { confInit() - sDialer := dialerFromConf() + sDialer := NewRuleDialer(conf.rules, dialerFromConf()) for _, listen := range conf.Listen { local, err := ServerFromURL(listen, sDialer) @@ -60,13 +58,8 @@ func main() { } } - // test here - dns.AddAnswerHandler(func(domain, ip string) error { - if ip != "" { - logf("domain: %s, ip: %s\n", domain, ip) - } - return nil - }) + // add a handler to update proxy rules when a domain resolved + dns.AddAnswerHandler(sDialer.AddDomainIP) go dns.ListenAndServe() } diff --git a/rule.go b/rule.go index 9f87d32..10802b7 100644 --- a/rule.go +++ b/rule.go @@ -4,24 +4,20 @@ import ( "log" "net" "strings" + "sync" ) // RuleDialer . type RuleDialer struct { gDialer Dialer - domainMap map[string]Dialer - ipMap map[string]Dialer - cidrMap map[string]Dialer + domainMap sync.Map + ipMap sync.Map + cidrMap sync.Map } // NewRuleDialer . -func NewRuleDialer(rules []*RuleConf, gDialer Dialer) Dialer { - - if len(rules) == 0 { - return gDialer - } - +func NewRuleDialer(rules []*RuleConf, gDialer Dialer) *RuleDialer { rd := &RuleDialer{gDialer: gDialer} for _, r := range rules { @@ -40,19 +36,16 @@ func NewRuleDialer(rules []*RuleConf, gDialer Dialer) Dialer { sd := NewStrategyDialer(r.Strategy, forwarders, r.CheckWebSite, r.CheckDuration) - rd.domainMap = make(map[string]Dialer) for _, domain := range r.Domain { - rd.domainMap[domain] = sd + rd.domainMap.Store(domain, sd) } - rd.ipMap = make(map[string]Dialer) for _, ip := range r.IP { - rd.ipMap[ip] = sd + rd.ipMap.Store(ip, sd) } - rd.cidrMap = make(map[string]Dialer) for _, cidr := range r.CIDR { - rd.cidrMap[cidr] = sd + rd.cidrMap.Store(cidr, sd) } } @@ -74,19 +67,28 @@ func (p *RuleDialer) NextDialer(dstAddr string) Dialer { // find ip if ip := net.ParseIP(host); ip != nil { // check ip - if d, ok := p.ipMap[ip.String()]; ok { - return d + if d, ok := p.ipMap.Load(ip.String()); ok { + return d.(Dialer) } + var ret Dialer // check cidr // TODO: do not parse cidr every time - for cidrStr, d := range p.cidrMap { - if _, net, err := net.ParseCIDR(cidrStr); err == nil { + p.cidrMap.Range(func(key, value interface{}) bool { + if _, net, err := net.ParseCIDR(key.(string)); err == nil { if net.Contains(ip) { - return d + ret = value.(Dialer) + return false } } + + return true + }) + + if ret != nil { + return ret } + } domainParts := strings.Split(host, ".") @@ -95,8 +97,8 @@ func (p *RuleDialer) NextDialer(dstAddr string) Dialer { domain := strings.Join(domainParts[i:length], ".") // find in domainMap - if d, ok := p.domainMap[domain]; ok { - return d + if d, ok := p.domainMap.Load(domain); ok { + return d.(Dialer) } } @@ -106,3 +108,23 @@ func (p *RuleDialer) NextDialer(dstAddr string) Dialer { func (rd *RuleDialer) Dial(network, addr string) (net.Conn, error) { return rd.NextDialer(addr).Dial(network, addr) } + +// AddDomainIP used to update ipMap rules according to domainMap rule +func (rd *RuleDialer) AddDomainIP(domain, ip string) error { + if ip != "" { + logf("domain: %s, ip: %s\n", domain, ip) + + domainParts := strings.Split(domain, ".") + length := len(domainParts) + for i := length - 2; i >= 0; i-- { + domain := strings.Join(domainParts[i:length], ".") + + // find in domainMap + if d, ok := rd.domainMap.Load(domain); ok { + rd.ipMap.Store(ip, d) + } + } + + } + return nil +}