mirror of
https://github.com/nadoo/glider.git
synced 2025-02-23 01:15:41 +08:00
rule: use new concurrent sync.Map instead of old map
This commit is contained in:
parent
65f33beaa3
commit
1b400d984e
15
main.go
15
main.go
@ -26,15 +26,13 @@ func dialerFromConf() Dialer {
|
||||
forwarders = append(forwarders, forward)
|
||||
}
|
||||
|
||||
forwarder := NewStrategyDialer(conf.Strategy, forwarders, conf.CheckWebSite, conf.CheckDuration)
|
||||
|
||||
return NewRuleDialer(conf.rules, forwarder)
|
||||
return NewStrategyDialer(conf.Strategy, forwarders, conf.CheckWebSite, conf.CheckDuration)
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
confInit()
|
||||
sDialer := dialerFromConf()
|
||||
sDialer := NewRuleDialer(conf.rules, dialerFromConf())
|
||||
|
||||
for _, listen := range conf.Listen {
|
||||
local, err := ServerFromURL(listen, sDialer)
|
||||
@ -60,13 +58,8 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
// test here
|
||||
dns.AddAnswerHandler(func(domain, ip string) error {
|
||||
if ip != "" {
|
||||
logf("domain: %s, ip: %s\n", domain, ip)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
// add a handler to update proxy rules when a domain resolved
|
||||
dns.AddAnswerHandler(sDialer.AddDomainIP)
|
||||
|
||||
go dns.ListenAndServe()
|
||||
}
|
||||
|
66
rule.go
66
rule.go
@ -4,24 +4,20 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// RuleDialer .
|
||||
type RuleDialer struct {
|
||||
gDialer Dialer
|
||||
|
||||
domainMap map[string]Dialer
|
||||
ipMap map[string]Dialer
|
||||
cidrMap map[string]Dialer
|
||||
domainMap sync.Map
|
||||
ipMap sync.Map
|
||||
cidrMap sync.Map
|
||||
}
|
||||
|
||||
// NewRuleDialer .
|
||||
func NewRuleDialer(rules []*RuleConf, gDialer Dialer) Dialer {
|
||||
|
||||
if len(rules) == 0 {
|
||||
return gDialer
|
||||
}
|
||||
|
||||
func NewRuleDialer(rules []*RuleConf, gDialer Dialer) *RuleDialer {
|
||||
rd := &RuleDialer{gDialer: gDialer}
|
||||
|
||||
for _, r := range rules {
|
||||
@ -40,19 +36,16 @@ func NewRuleDialer(rules []*RuleConf, gDialer Dialer) Dialer {
|
||||
|
||||
sd := NewStrategyDialer(r.Strategy, forwarders, r.CheckWebSite, r.CheckDuration)
|
||||
|
||||
rd.domainMap = make(map[string]Dialer)
|
||||
for _, domain := range r.Domain {
|
||||
rd.domainMap[domain] = sd
|
||||
rd.domainMap.Store(domain, sd)
|
||||
}
|
||||
|
||||
rd.ipMap = make(map[string]Dialer)
|
||||
for _, ip := range r.IP {
|
||||
rd.ipMap[ip] = sd
|
||||
rd.ipMap.Store(ip, sd)
|
||||
}
|
||||
|
||||
rd.cidrMap = make(map[string]Dialer)
|
||||
for _, cidr := range r.CIDR {
|
||||
rd.cidrMap[cidr] = sd
|
||||
rd.cidrMap.Store(cidr, sd)
|
||||
}
|
||||
}
|
||||
|
||||
@ -74,19 +67,28 @@ func (p *RuleDialer) NextDialer(dstAddr string) Dialer {
|
||||
// find ip
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
// check ip
|
||||
if d, ok := p.ipMap[ip.String()]; ok {
|
||||
return d
|
||||
if d, ok := p.ipMap.Load(ip.String()); ok {
|
||||
return d.(Dialer)
|
||||
}
|
||||
|
||||
var ret Dialer
|
||||
// check cidr
|
||||
// TODO: do not parse cidr every time
|
||||
for cidrStr, d := range p.cidrMap {
|
||||
if _, net, err := net.ParseCIDR(cidrStr); err == nil {
|
||||
p.cidrMap.Range(func(key, value interface{}) bool {
|
||||
if _, net, err := net.ParseCIDR(key.(string)); err == nil {
|
||||
if net.Contains(ip) {
|
||||
return d
|
||||
ret = value.(Dialer)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
if ret != nil {
|
||||
return ret
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
domainParts := strings.Split(host, ".")
|
||||
@ -95,8 +97,8 @@ func (p *RuleDialer) NextDialer(dstAddr string) Dialer {
|
||||
domain := strings.Join(domainParts[i:length], ".")
|
||||
|
||||
// find in domainMap
|
||||
if d, ok := p.domainMap[domain]; ok {
|
||||
return d
|
||||
if d, ok := p.domainMap.Load(domain); ok {
|
||||
return d.(Dialer)
|
||||
}
|
||||
}
|
||||
|
||||
@ -106,3 +108,23 @@ func (p *RuleDialer) NextDialer(dstAddr string) Dialer {
|
||||
func (rd *RuleDialer) Dial(network, addr string) (net.Conn, error) {
|
||||
return rd.NextDialer(addr).Dial(network, addr)
|
||||
}
|
||||
|
||||
// AddDomainIP used to update ipMap rules according to domainMap rule
|
||||
func (rd *RuleDialer) AddDomainIP(domain, ip string) error {
|
||||
if ip != "" {
|
||||
logf("domain: %s, ip: %s\n", domain, ip)
|
||||
|
||||
domainParts := strings.Split(domain, ".")
|
||||
length := len(domainParts)
|
||||
for i := length - 2; i >= 0; i-- {
|
||||
domain := strings.Join(domainParts[i:length], ".")
|
||||
|
||||
// find in domainMap
|
||||
if d, ok := rd.domainMap.Load(domain); ok {
|
||||
rd.ipMap.Store(ip, d)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user