diff --git a/rule/forward.go b/rule/forward.go index 2304220..db9e0bd 100644 --- a/rule/forward.go +++ b/rule/forward.go @@ -18,6 +18,7 @@ type StatusHandler func(*Forwarder) // Forwarder associates with a `-forward` command, usually a dialer or a chain of dialers. type Forwarder struct { proxy.Dialer + name string url string addr string priority uint32 @@ -30,12 +31,15 @@ type Forwarder struct { } // ForwarderFromURL parses `forward=` command value and returns a new forwarder. -func ForwarderFromURL(s, intface string, dialTimeout, relayTimeout time.Duration) (f *Forwarder, err error) { - f = &Forwarder{url: s} +func ForwarderFromURL(s, name string, intface string, dialTimeout, relayTimeout time.Duration) (f *Forwarder, err error) { + f = &Forwarder{url: s, name: name} ss := strings.Split(s, "#") if len(ss) > 1 { err = f.parseOption(ss[1]) + if err != nil { + return nil, err + } } iface := intface diff --git a/rule/group.go b/rule/group.go index fc9607b..684acb7 100644 --- a/rule/group.go +++ b/rule/group.go @@ -5,7 +5,6 @@ import ( "hash/fnv" "net" "net/url" - "path/filepath" "sort" "strings" "sync" @@ -36,16 +35,17 @@ type FwdrGroup struct { } // NewFwdrGroup returns a new forward group. -func NewFwdrGroup(rulePath string, s []string, c *Strategy) *FwdrGroup { +func NewFwdrGroup(name string, s []string, c *Strategy) *FwdrGroup { var fwdrs []*Forwarder for _, chain := range s { - fwdr, err := ForwarderFromURL(chain, c.IntFace, + fwdr, err := ForwarderFromURL(chain, name, c.IntFace, time.Duration(c.DialTimeout)*time.Second, time.Duration(c.RelayTimeout)*time.Second) if err != nil { log.Fatal(err) } fwdr.SetMaxFailures(uint32(c.MaxFailures)) fwdrs = append(fwdrs, fwdr) + log.F("[group] %s: has forwarder %s", name, fwdr.Addr()) } if len(fwdrs) == 0 { @@ -57,9 +57,9 @@ func NewFwdrGroup(rulePath string, s []string, c *Strategy) *FwdrGroup { } fwdrs = append(fwdrs, direct) c.Strategy = "rr" + log.F("[group] %s: has DIRECT forwarder", name) } - name := strings.TrimSuffix(filepath.Base(rulePath), filepath.Ext(rulePath)) return newFwdrGroup(name, fwdrs, c) } @@ -101,6 +101,10 @@ func newFwdrGroup(name string, fwdrs []*Forwarder, c *Strategy) *FwdrGroup { return p } +func (p *FwdrGroup) Name() string { + return p.name +} + // Dial connects to the address addr on the network net. func (p *FwdrGroup) Dial(network, addr string) (net.Conn, proxy.Dialer, error) { nd := p.NextDialer(addr) @@ -127,6 +131,17 @@ func (p *FwdrGroup) NextDialer(dstAddr string) proxy.Dialer { return p.next(dstAddr) } +// Record records result while using the dialer from proxy. +func (p *FwdrGroup) Record(dialer proxy.Dialer, success bool) { + if fwdr, ok := dialer.(*Forwarder); ok { + if !success { + fwdr.IncFailures() + return + } + fwdr.Enable() + } +} + // Priority returns the active priority of dialer. func (p *FwdrGroup) Priority() uint32 { return atomic.LoadUint32(&p.priority) } diff --git a/rule/proxy.go b/rule/proxy.go index 495b83e..a3e18c1 100644 --- a/rule/proxy.go +++ b/rule/proxy.go @@ -3,41 +3,49 @@ package rule import ( "net" "net/netip" + "path/filepath" "strings" - "sync" "github.com/nadoo/glider/pkg/log" "github.com/nadoo/glider/proxy" ) +type Rule struct { + name string + forwarders *FwdrGroup + domains []string + ips []netip.Addr + cidrs []netip.Prefix +} + // Proxy implements the proxy.Proxy interface with rule support. type Proxy struct { - main *FwdrGroup - all []*FwdrGroup - domainMap sync.Map - ipMap sync.Map - cidrMap sync.Map + main *FwdrGroup + rules []*Rule } // NewProxy returns a new rule proxy. func NewProxy(mainForwarders []string, mainStrategy *Strategy, rules []*Config) *Proxy { - rd := &Proxy{main: NewFwdrGroup("main", mainForwarders, mainStrategy)} + proxy := &Proxy{main: NewFwdrGroup("main", mainForwarders, mainStrategy)} for _, r := range rules { - group := NewFwdrGroup(r.RulePath, r.Forward, &r.Strategy) - rd.all = append(rd.all, group) + name := strings.TrimSuffix(filepath.Base(r.RulePath), filepath.Ext(r.RulePath)) + forwarders := NewFwdrGroup(name, r.Forward, &r.Strategy) + rule := &Rule{name: name, forwarders: forwarders} for _, domain := range r.Domain { - rd.domainMap.Store(strings.ToLower(domain), group) + rule.domains = append(rule.domains, strings.ToLower(domain)) + log.F("[rule] %s: has domain rule for %s", name, domain) } for _, s := range r.IP { ip, err := netip.ParseAddr(s) if err != nil { - log.F("[rule] parse ip error: %s", err) + log.F("[rule] %s: parse ip error: %s", name, err) continue } - rd.ipMap.Store(ip, group) + rule.ips = append(rule.ips, ip) + log.F("[rule] %s: has IP rule for %s", name, ip.String()) } for _, s := range r.CIDR { @@ -46,25 +54,75 @@ func NewProxy(mainForwarders []string, mainStrategy *Strategy, rules []*Config) log.F("[rule] parse cidr error: %s", err) continue } - rd.cidrMap.Store(cidr, group) + rule.cidrs = append(rule.cidrs, cidr) + log.F("[rule] %s: has CIDR rule for %s", name, cidr.String()) } + + proxy.rules = append(proxy.rules, rule) } - direct := NewFwdrGroup("", nil, mainStrategy) - rd.domainMap.Store("direct", direct) + direct := NewFwdrGroup("direct", nil, mainStrategy) + directRule := &Rule{name: "direct", forwarders: direct, domains: []string{"direct"}} // if there's any forwarder defined in main config, make sure they will be accessed directly. if len(mainForwarders) > 0 { - for _, f := range rd.main.fwdrs { + for _, f := range proxy.main.fwdrs { addr := strings.Split(f.addr, ",")[0] host, _, _ := net.SplitHostPort(addr) if _, err := netip.ParseAddr(host); err != nil { - rd.domainMap.Store(strings.ToLower(host), direct) + directRule.domains = append(directRule.domains, strings.ToLower(host)) + log.F("[rule] direct: has domain rule for %s", host) } } } - return rd + proxy.rules = append(proxy.rules, directRule) + + return proxy +} + +func (r *Rule) checkDomain(host string) bool { + host = strings.ToLower(host) + for i := len(host); i != -1; { + i = strings.LastIndexByte(host[:i], '.') + for _, domain := range r.domains { + if domain == host[i+1:] { + return true + } + } + } + + return false +} + +func (r *Rule) checkIP(host string) bool { + if ip, err := netip.ParseAddr(host); err == nil { + // check ip + for _, addr := range r.ips { + if addr == ip { + return true + } + } + + // check cidr + for _, prefix := range r.cidrs { + if prefix.Contains(ip) { + return true + } + } + } + + return false +} + +// checkMatch checks whether the given dstAddr matches the rules +func (r *Rule) checkMatch(dstAddr string) bool { + host, _, err := net.SplitHostPort(dstAddr) + if err != nil { + return false + } + + return r.checkIP(host) || r.checkDomain(host) } // Dial dials to targer addr and return a conn. @@ -79,38 +137,9 @@ func (p *Proxy) DialUDP(network, addr string) (pc net.PacketConn, dialer proxy.U // findDialer returns a dialer by dstAddr according to rule. func (p *Proxy) findDialer(dstAddr string) *FwdrGroup { - host, _, err := net.SplitHostPort(dstAddr) - if err != nil { - return p.main - } - - if ip, err := netip.ParseAddr(host); err == nil { - // check ip - if proxy, ok := p.ipMap.Load(ip); ok { - return proxy.(*FwdrGroup) - } - - // check cidr - var ret *FwdrGroup - p.cidrMap.Range(func(key, value any) bool { - if key.(netip.Prefix).Contains(ip) { - ret = value.(*FwdrGroup) - return false - } - return true - }) - - if ret != nil { - return ret - } - } - - // check host - host = strings.ToLower(host) - for i := len(host); i != -1; { - i = strings.LastIndexByte(host[:i], '.') - if proxy, ok := p.domainMap.Load(host[i+1:]); ok { - return proxy.(*FwdrGroup) + for _, rule := range p.rules { + if rule.checkMatch(dstAddr) { + return rule.forwarders } } @@ -135,14 +164,12 @@ func (p *Proxy) Record(dialer proxy.Dialer, success bool) { // AddDomainIP used to update ipMap rules according to domainMap rule. func (p *Proxy) AddDomainIP(domain string, ip netip.Addr) error { - domain = strings.ToLower(domain) - for i := len(domain); i != -1; { - i = strings.LastIndexByte(domain[:i], '.') - if dialer, ok := p.domainMap.Load(domain[i+1:]); ok { - p.ipMap.Store(ip, dialer) - // log.F("[rule] update map: %s/%s based on rule: domain=%s\n", domain, ip, domain[i+1:]) + for _, rule := range p.rules { + if rule.checkDomain(domain) { + rule.ips = append(rule.ips, ip) } } + return nil } @@ -150,7 +177,7 @@ func (p *Proxy) AddDomainIP(domain string, ip netip.Addr) error { func (p *Proxy) Check() { p.main.Check() - for _, fwdrGroup := range p.all { - fwdrGroup.Check() + for _, rule := range p.rules { + rule.forwarders.Check() } }