Add rule ordering based on file names

This commit is contained in:
Ivan ROGER 2022-11-03 16:18:35 +01:00
parent 03157367ca
commit 09c0202d4c
3 changed files with 110 additions and 64 deletions

View File

@ -18,6 +18,7 @@ type StatusHandler func(*Forwarder)
// Forwarder associates with a `-forward` command, usually a dialer or a chain of dialers. // Forwarder associates with a `-forward` command, usually a dialer or a chain of dialers.
type Forwarder struct { type Forwarder struct {
proxy.Dialer proxy.Dialer
name string
url string url string
addr string addr string
priority uint32 priority uint32
@ -30,12 +31,15 @@ type Forwarder struct {
} }
// ForwarderFromURL parses `forward=` command value and returns a new forwarder. // ForwarderFromURL parses `forward=` command value and returns a new forwarder.
func ForwarderFromURL(s, intface string, dialTimeout, relayTimeout time.Duration) (f *Forwarder, err error) { func ForwarderFromURL(s, name string, intface string, dialTimeout, relayTimeout time.Duration) (f *Forwarder, err error) {
f = &Forwarder{url: s} f = &Forwarder{url: s, name: name}
ss := strings.Split(s, "#") ss := strings.Split(s, "#")
if len(ss) > 1 { if len(ss) > 1 {
err = f.parseOption(ss[1]) err = f.parseOption(ss[1])
if err != nil {
return nil, err
}
} }
iface := intface iface := intface

View File

@ -5,7 +5,6 @@ import (
"hash/fnv" "hash/fnv"
"net" "net"
"net/url" "net/url"
"path/filepath"
"sort" "sort"
"strings" "strings"
"sync" "sync"
@ -36,16 +35,17 @@ type FwdrGroup struct {
} }
// NewFwdrGroup returns a new forward group. // NewFwdrGroup returns a new forward group.
func NewFwdrGroup(rulePath string, s []string, c *Strategy) *FwdrGroup { func NewFwdrGroup(name string, s []string, c *Strategy) *FwdrGroup {
var fwdrs []*Forwarder var fwdrs []*Forwarder
for _, chain := range s { for _, chain := range s {
fwdr, err := ForwarderFromURL(chain, c.IntFace, fwdr, err := ForwarderFromURL(chain, name, c.IntFace,
time.Duration(c.DialTimeout)*time.Second, time.Duration(c.RelayTimeout)*time.Second) time.Duration(c.DialTimeout)*time.Second, time.Duration(c.RelayTimeout)*time.Second)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
fwdr.SetMaxFailures(uint32(c.MaxFailures)) fwdr.SetMaxFailures(uint32(c.MaxFailures))
fwdrs = append(fwdrs, fwdr) fwdrs = append(fwdrs, fwdr)
log.F("[group] %s: has forwarder %s", name, fwdr.Addr())
} }
if len(fwdrs) == 0 { if len(fwdrs) == 0 {
@ -57,9 +57,9 @@ func NewFwdrGroup(rulePath string, s []string, c *Strategy) *FwdrGroup {
} }
fwdrs = append(fwdrs, direct) fwdrs = append(fwdrs, direct)
c.Strategy = "rr" c.Strategy = "rr"
log.F("[group] %s: has DIRECT forwarder", name)
} }
name := strings.TrimSuffix(filepath.Base(rulePath), filepath.Ext(rulePath))
return newFwdrGroup(name, fwdrs, c) return newFwdrGroup(name, fwdrs, c)
} }
@ -101,6 +101,10 @@ func newFwdrGroup(name string, fwdrs []*Forwarder, c *Strategy) *FwdrGroup {
return p return p
} }
func (p *FwdrGroup) Name() string {
return p.name
}
// Dial connects to the address addr on the network net. // Dial connects to the address addr on the network net.
func (p *FwdrGroup) Dial(network, addr string) (net.Conn, proxy.Dialer, error) { func (p *FwdrGroup) Dial(network, addr string) (net.Conn, proxy.Dialer, error) {
nd := p.NextDialer(addr) nd := p.NextDialer(addr)
@ -127,6 +131,17 @@ func (p *FwdrGroup) NextDialer(dstAddr string) proxy.Dialer {
return p.next(dstAddr) return p.next(dstAddr)
} }
// Record records result while using the dialer from proxy.
func (p *FwdrGroup) Record(dialer proxy.Dialer, success bool) {
if fwdr, ok := dialer.(*Forwarder); ok {
if !success {
fwdr.IncFailures()
return
}
fwdr.Enable()
}
}
// Priority returns the active priority of dialer. // Priority returns the active priority of dialer.
func (p *FwdrGroup) Priority() uint32 { return atomic.LoadUint32(&p.priority) } func (p *FwdrGroup) Priority() uint32 { return atomic.LoadUint32(&p.priority) }

View File

@ -3,41 +3,49 @@ package rule
import ( import (
"net" "net"
"net/netip" "net/netip"
"path/filepath"
"strings" "strings"
"sync"
"github.com/nadoo/glider/pkg/log" "github.com/nadoo/glider/pkg/log"
"github.com/nadoo/glider/proxy" "github.com/nadoo/glider/proxy"
) )
type Rule struct {
name string
forwarders *FwdrGroup
domains []string
ips []netip.Addr
cidrs []netip.Prefix
}
// Proxy implements the proxy.Proxy interface with rule support. // Proxy implements the proxy.Proxy interface with rule support.
type Proxy struct { type Proxy struct {
main *FwdrGroup main *FwdrGroup
all []*FwdrGroup rules []*Rule
domainMap sync.Map
ipMap sync.Map
cidrMap sync.Map
} }
// NewProxy returns a new rule proxy. // NewProxy returns a new rule proxy.
func NewProxy(mainForwarders []string, mainStrategy *Strategy, rules []*Config) *Proxy { func NewProxy(mainForwarders []string, mainStrategy *Strategy, rules []*Config) *Proxy {
rd := &Proxy{main: NewFwdrGroup("main", mainForwarders, mainStrategy)} proxy := &Proxy{main: NewFwdrGroup("main", mainForwarders, mainStrategy)}
for _, r := range rules { for _, r := range rules {
group := NewFwdrGroup(r.RulePath, r.Forward, &r.Strategy) name := strings.TrimSuffix(filepath.Base(r.RulePath), filepath.Ext(r.RulePath))
rd.all = append(rd.all, group) forwarders := NewFwdrGroup(name, r.Forward, &r.Strategy)
rule := &Rule{name: name, forwarders: forwarders}
for _, domain := range r.Domain { for _, domain := range r.Domain {
rd.domainMap.Store(strings.ToLower(domain), group) rule.domains = append(rule.domains, strings.ToLower(domain))
log.F("[rule] %s: has domain rule for %s", name, domain)
} }
for _, s := range r.IP { for _, s := range r.IP {
ip, err := netip.ParseAddr(s) ip, err := netip.ParseAddr(s)
if err != nil { if err != nil {
log.F("[rule] parse ip error: %s", err) log.F("[rule] %s: parse ip error: %s", name, err)
continue continue
} }
rd.ipMap.Store(ip, group) rule.ips = append(rule.ips, ip)
log.F("[rule] %s: has IP rule for %s", name, ip.String())
} }
for _, s := range r.CIDR { for _, s := range r.CIDR {
@ -46,25 +54,75 @@ func NewProxy(mainForwarders []string, mainStrategy *Strategy, rules []*Config)
log.F("[rule] parse cidr error: %s", err) log.F("[rule] parse cidr error: %s", err)
continue continue
} }
rd.cidrMap.Store(cidr, group) rule.cidrs = append(rule.cidrs, cidr)
} log.F("[rule] %s: has CIDR rule for %s", name, cidr.String())
} }
direct := NewFwdrGroup("", nil, mainStrategy) proxy.rules = append(proxy.rules, rule)
rd.domainMap.Store("direct", direct) }
direct := NewFwdrGroup("direct", nil, mainStrategy)
directRule := &Rule{name: "direct", forwarders: direct, domains: []string{"direct"}}
// if there's any forwarder defined in main config, make sure they will be accessed directly. // if there's any forwarder defined in main config, make sure they will be accessed directly.
if len(mainForwarders) > 0 { if len(mainForwarders) > 0 {
for _, f := range rd.main.fwdrs { for _, f := range proxy.main.fwdrs {
addr := strings.Split(f.addr, ",")[0] addr := strings.Split(f.addr, ",")[0]
host, _, _ := net.SplitHostPort(addr) host, _, _ := net.SplitHostPort(addr)
if _, err := netip.ParseAddr(host); err != nil { if _, err := netip.ParseAddr(host); err != nil {
rd.domainMap.Store(strings.ToLower(host), direct) directRule.domains = append(directRule.domains, strings.ToLower(host))
log.F("[rule] direct: has domain rule for %s", host)
} }
} }
} }
return rd proxy.rules = append(proxy.rules, directRule)
return proxy
}
func (r *Rule) checkDomain(host string) bool {
host = strings.ToLower(host)
for i := len(host); i != -1; {
i = strings.LastIndexByte(host[:i], '.')
for _, domain := range r.domains {
if domain == host[i+1:] {
return true
}
}
}
return false
}
func (r *Rule) checkIP(host string) bool {
if ip, err := netip.ParseAddr(host); err == nil {
// check ip
for _, addr := range r.ips {
if addr == ip {
return true
}
}
// check cidr
for _, prefix := range r.cidrs {
if prefix.Contains(ip) {
return true
}
}
}
return false
}
// checkMatch checks whether the given dstAddr matches the rules
func (r *Rule) checkMatch(dstAddr string) bool {
host, _, err := net.SplitHostPort(dstAddr)
if err != nil {
return false
}
return r.checkIP(host) || r.checkDomain(host)
} }
// Dial dials to targer addr and return a conn. // Dial dials to targer addr and return a conn.
@ -79,38 +137,9 @@ func (p *Proxy) DialUDP(network, addr string) (pc net.PacketConn, dialer proxy.U
// findDialer returns a dialer by dstAddr according to rule. // findDialer returns a dialer by dstAddr according to rule.
func (p *Proxy) findDialer(dstAddr string) *FwdrGroup { func (p *Proxy) findDialer(dstAddr string) *FwdrGroup {
host, _, err := net.SplitHostPort(dstAddr) for _, rule := range p.rules {
if err != nil { if rule.checkMatch(dstAddr) {
return p.main return rule.forwarders
}
if ip, err := netip.ParseAddr(host); err == nil {
// check ip
if proxy, ok := p.ipMap.Load(ip); ok {
return proxy.(*FwdrGroup)
}
// check cidr
var ret *FwdrGroup
p.cidrMap.Range(func(key, value any) bool {
if key.(netip.Prefix).Contains(ip) {
ret = value.(*FwdrGroup)
return false
}
return true
})
if ret != nil {
return ret
}
}
// check host
host = strings.ToLower(host)
for i := len(host); i != -1; {
i = strings.LastIndexByte(host[:i], '.')
if proxy, ok := p.domainMap.Load(host[i+1:]); ok {
return proxy.(*FwdrGroup)
} }
} }
@ -135,14 +164,12 @@ func (p *Proxy) Record(dialer proxy.Dialer, success bool) {
// AddDomainIP used to update ipMap rules according to domainMap rule. // AddDomainIP used to update ipMap rules according to domainMap rule.
func (p *Proxy) AddDomainIP(domain string, ip netip.Addr) error { func (p *Proxy) AddDomainIP(domain string, ip netip.Addr) error {
domain = strings.ToLower(domain) for _, rule := range p.rules {
for i := len(domain); i != -1; { if rule.checkDomain(domain) {
i = strings.LastIndexByte(domain[:i], '.') rule.ips = append(rule.ips, ip)
if dialer, ok := p.domainMap.Load(domain[i+1:]); ok {
p.ipMap.Store(ip, dialer)
// log.F("[rule] update map: %s/%s based on rule: domain=%s\n", domain, ip, domain[i+1:])
} }
} }
return nil return nil
} }
@ -150,7 +177,7 @@ func (p *Proxy) AddDomainIP(domain string, ip netip.Addr) error {
func (p *Proxy) Check() { func (p *Proxy) Check() {
p.main.Check() p.main.Check()
for _, fwdrGroup := range p.all { for _, rule := range p.rules {
fwdrGroup.Check() rule.forwarders.Check()
} }
} }