rule: use new concurrent sync.Map instead of old map

This commit is contained in:
nadoo 2017-08-23 18:58:24 +08:00
parent 65f33beaa3
commit 1b400d984e
2 changed files with 48 additions and 33 deletions

15
main.go
View File

@ -26,15 +26,13 @@ func dialerFromConf() Dialer {
forwarders = append(forwarders, forward)
}
forwarder := NewStrategyDialer(conf.Strategy, forwarders, conf.CheckWebSite, conf.CheckDuration)
return NewRuleDialer(conf.rules, forwarder)
return NewStrategyDialer(conf.Strategy, forwarders, conf.CheckWebSite, conf.CheckDuration)
}
func main() {
confInit()
sDialer := dialerFromConf()
sDialer := NewRuleDialer(conf.rules, dialerFromConf())
for _, listen := range conf.Listen {
local, err := ServerFromURL(listen, sDialer)
@ -60,13 +58,8 @@ func main() {
}
}
// test here
dns.AddAnswerHandler(func(domain, ip string) error {
if ip != "" {
logf("domain: %s, ip: %s\n", domain, ip)
}
return nil
})
// add a handler to update proxy rules when a domain resolved
dns.AddAnswerHandler(sDialer.AddDomainIP)
go dns.ListenAndServe()
}

66
rule.go
View File

@ -4,24 +4,20 @@ import (
"log"
"net"
"strings"
"sync"
)
// RuleDialer .
type RuleDialer struct {
gDialer Dialer
domainMap map[string]Dialer
ipMap map[string]Dialer
cidrMap map[string]Dialer
domainMap sync.Map
ipMap sync.Map
cidrMap sync.Map
}
// NewRuleDialer .
func NewRuleDialer(rules []*RuleConf, gDialer Dialer) Dialer {
if len(rules) == 0 {
return gDialer
}
func NewRuleDialer(rules []*RuleConf, gDialer Dialer) *RuleDialer {
rd := &RuleDialer{gDialer: gDialer}
for _, r := range rules {
@ -40,19 +36,16 @@ func NewRuleDialer(rules []*RuleConf, gDialer Dialer) Dialer {
sd := NewStrategyDialer(r.Strategy, forwarders, r.CheckWebSite, r.CheckDuration)
rd.domainMap = make(map[string]Dialer)
for _, domain := range r.Domain {
rd.domainMap[domain] = sd
rd.domainMap.Store(domain, sd)
}
rd.ipMap = make(map[string]Dialer)
for _, ip := range r.IP {
rd.ipMap[ip] = sd
rd.ipMap.Store(ip, sd)
}
rd.cidrMap = make(map[string]Dialer)
for _, cidr := range r.CIDR {
rd.cidrMap[cidr] = sd
rd.cidrMap.Store(cidr, sd)
}
}
@ -74,19 +67,28 @@ func (p *RuleDialer) NextDialer(dstAddr string) Dialer {
// find ip
if ip := net.ParseIP(host); ip != nil {
// check ip
if d, ok := p.ipMap[ip.String()]; ok {
return d
if d, ok := p.ipMap.Load(ip.String()); ok {
return d.(Dialer)
}
var ret Dialer
// check cidr
// TODO: do not parse cidr every time
for cidrStr, d := range p.cidrMap {
if _, net, err := net.ParseCIDR(cidrStr); err == nil {
p.cidrMap.Range(func(key, value interface{}) bool {
if _, net, err := net.ParseCIDR(key.(string)); err == nil {
if net.Contains(ip) {
return d
ret = value.(Dialer)
return false
}
}
return true
})
if ret != nil {
return ret
}
}
domainParts := strings.Split(host, ".")
@ -95,8 +97,8 @@ func (p *RuleDialer) NextDialer(dstAddr string) Dialer {
domain := strings.Join(domainParts[i:length], ".")
// find in domainMap
if d, ok := p.domainMap[domain]; ok {
return d
if d, ok := p.domainMap.Load(domain); ok {
return d.(Dialer)
}
}
@ -106,3 +108,23 @@ func (p *RuleDialer) NextDialer(dstAddr string) Dialer {
func (rd *RuleDialer) Dial(network, addr string) (net.Conn, error) {
return rd.NextDialer(addr).Dial(network, addr)
}
// AddDomainIP used to update ipMap rules according to domainMap rule
func (rd *RuleDialer) AddDomainIP(domain, ip string) error {
if ip != "" {
logf("domain: %s, ip: %s\n", domain, ip)
domainParts := strings.Split(domain, ".")
length := len(domainParts)
for i := length - 2; i >= 0; i-- {
domain := strings.Join(domainParts[i:length], ".")
// find in domainMap
if d, ok := rd.domainMap.Load(domain); ok {
rd.ipMap.Store(ip, d)
}
}
}
return nil
}