From 7e7c7553ccfcc5fb5d2c6125f49807f4136f9bb1 Mon Sep 17 00:00:00 2001 From: nadoo <287492+nadoo@users.noreply.github.com> Date: Fri, 28 Jan 2022 23:35:29 +0800 Subject: [PATCH] proxy: improve addr handling --- config.go | 7 ++----- pkg/log/log.go | 16 ++++++++------- pkg/socks/socks.go | 9 ++++---- proxy/direct.go | 5 +++-- proxy/redir/redir_linux.go | 10 ++++++--- proxy/socks5/client.go | 37 ++++----------------------------- proxy/socks5/server.go | 3 ++- proxy/vless/addr.go | 10 ++++----- proxy/vmess/addr.go | 15 +++++++------- rule/proxy.go | 42 +++++++++++++++++++++----------------- service/dhcpd/dhcpd.go | 28 +++++++++++++++++-------- service/dhcpd/pool.go | 26 +++++++++++------------ 12 files changed, 99 insertions(+), 109 deletions(-) diff --git a/config.go b/config.go index 798ca8c..c34ca34 100644 --- a/config.go +++ b/config.go @@ -87,11 +87,8 @@ func parseConfig() *Config { os.Exit(-1) } - // setup a log func - if conf.Verbose { - log.SetFlags(conf.LogFlags) - log.F = log.Debugf - } + // setup logger + log.Set(conf.Verbose, conf.LogFlags) if len(conf.Listens) == 0 && conf.DNS == "" && len(conf.Services) == 0 { // flag.Usage() diff --git a/pkg/log/log.go b/pkg/log/log.go index 31937a6..65536f4 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -5,17 +5,19 @@ import ( stdlog "log" ) -// F is the main log function. -var F = func(string, ...any) {} +var verbose = false -// SetFlags sets the output flags for the logger. -func SetFlags(flag int) { +// Set sets the logger's verbose mode and output flags. +func Set(v bool, flag int) { + verbose = v stdlog.SetFlags(flag) } -// Debugf prints debug log. -func Debugf(f string, v ...any) { - stdlog.Output(2, fmt.Sprintf(f, v...)) +// F prints debug log. +func F(f string, v ...any) { + if verbose { + stdlog.Output(2, fmt.Sprintf(f, v...)) + } } // Print prints log. diff --git a/pkg/socks/socks.go b/pkg/socks/socks.go index d62df84..448ccd1 100644 --- a/pkg/socks/socks.go +++ b/pkg/socks/socks.go @@ -4,6 +4,7 @@ import ( "errors" "io" "net" + "net/netip" "strconv" ) @@ -139,16 +140,16 @@ func ParseAddr(s string) Addr { if err != nil { return nil } - if ip := net.ParseIP(host); ip != nil { - if ip4 := ip.To4(); ip4 != nil { + + if ip, err := netip.ParseAddr(host); err == nil { + if ip.Is4() { addr = make([]byte, 1+net.IPv4len+2) addr[0] = ATypIP4 - copy(addr[1:], ip4) } else { addr = make([]byte, 1+net.IPv6len+2) addr[0] = ATypIP6 - copy(addr[1:], ip) } + copy(addr[1:], ip.AsSlice()) } else { if len(host) > 255 { return nil diff --git a/proxy/direct.go b/proxy/direct.go index dc5c3ed..99d9de3 100644 --- a/proxy/direct.go +++ b/proxy/direct.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "net/netip" "time" "github.com/nadoo/glider/pkg/log" @@ -27,8 +28,8 @@ func NewDirect(intface string, dialTimeout, relayTimeout time.Duration) (*Direct d := &Direct{dialTimeout: dialTimeout, relayTimeout: relayTimeout} if intface != "" { - if ip := net.ParseIP(intface); ip != nil { - d.ip = ip + if addr, err := netip.ParseAddr(intface); err == nil { + d.ip = addr.AsSlice() } else { iface, err := net.InterfaceByName(intface) if err != nil { diff --git a/proxy/redir/redir_linux.go b/proxy/redir/redir_linux.go index 21d884c..a5c68fb 100644 --- a/proxy/redir/redir_linux.go +++ b/proxy/redir/redir_linux.go @@ -60,7 +60,7 @@ func (s *RedirProxy) ListenAndServe() { return } - log.F("[redir] listening TCP on %s", s.addr) + log.F("[redir] listening TCP on " + s.addr) for { c, err := l.Accept() @@ -140,7 +140,9 @@ func getorigdst(fd uintptr) (netip.AddrPort, error) { if err := socketcall(GETSOCKOPT, fd, syscall.IPPROTO_IP, _SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&raw)), uintptr(unsafe.Pointer(&siz)), 0); err != nil { return netip.AddrPort{}, err } - port := raw.Port<<8 | raw.Port>>8 // raw.Port is big-endian + // NOTE: raw.Port is big-endian, just change it to little-endian + // TODO: improve here when we add big-endian $GOARCH support + port := raw.Port<<8 | raw.Port>>8 return netip.AddrPortFrom(netip.AddrFrom4(raw.Addr), port), nil } @@ -152,6 +154,8 @@ func getorigdstIPv6(fd uintptr) (netip.AddrPort, error) { if err := socketcall(GETSOCKOPT, fd, syscall.IPPROTO_IPV6, _IP6T_SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&raw)), uintptr(unsafe.Pointer(&siz)), 0); err != nil { return netip.AddrPort{}, err } - port := raw.Port<<8 | raw.Port>>8 // raw.Port is big-endian + // NOTE: raw.Port is big-endian, just change it to little-endian + // TODO: improve here when we add big-endian $GOARCH support + port := raw.Port<<8 | raw.Port>>8 return netip.AddrPortFrom(netip.AddrFrom16(raw.Addr), port), nil } diff --git a/proxy/socks5/client.go b/proxy/socks5/client.go index 08e2118..c470c3c 100644 --- a/proxy/socks5/client.go +++ b/proxy/socks5/client.go @@ -98,23 +98,11 @@ func (s *Socks5) DialUDP(network, addr string) (pc net.PacketConn, writeTo net.A // and commands the server to extend that connection to target, // which must be a canonical address with a host and port. func (s *Socks5) connect(conn net.Conn, target string, cmd byte) (addr socks.Addr, err error) { - host, portStr, err := net.SplitHostPort(target) - if err != nil { - return - } - - port, err := strconv.Atoi(portStr) - if err != nil { - return addr, errors.New("proxy: failed to parse port number: " + portStr) - } - if port < 1 || port > 0xffff { - return addr, errors.New("proxy: port number out of range: " + portStr) - } - // the size here is just an estimate - buf := make([]byte, 0, 6+len(host)) + buf := pool.GetBuffer(socks.MaxAddrLen) + defer pool.PutBuffer(buf) - buf = append(buf, Version) + buf = append(buf[:0], Version) if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 { buf = append(buf, 2 /* num auth methods */, socks.AuthNone, socks.AuthPassword) } else { @@ -158,24 +146,7 @@ func (s *Socks5) connect(conn net.Conn, target string, cmd byte) (addr socks.Add buf = buf[:0] buf = append(buf, Version, cmd, 0 /* reserved */) - - if ip := net.ParseIP(host); ip != nil { - if ip4 := ip.To4(); ip4 != nil { - buf = append(buf, socks.ATypIP4) - ip = ip4 - } else { - buf = append(buf, socks.ATypIP6) - } - buf = append(buf, ip...) - } else { - if len(host) > 255 { - return addr, errors.New("proxy: destination hostname too long: " + host) - } - buf = append(buf, socks.ATypDomain) - buf = append(buf, byte(len(host))) - buf = append(buf, host...) - } - buf = append(buf, byte(port>>8), byte(port)) + buf = append(buf, socks.ParseAddr(target)...) if _, err := conn.Write(buf); err != nil { return addr, errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) diff --git a/proxy/socks5/server.go b/proxy/socks5/server.go index 7daa2a5..a753193 100644 --- a/proxy/socks5/server.go +++ b/proxy/socks5/server.go @@ -184,7 +184,8 @@ func newSession(key string, src net.Addr, srcPC *PktConn) *Session { // Handshake fast-tracks SOCKS initialization to get target address to connect. func (s *Socks5) handshake(c net.Conn) (socks.Addr, error) { // Read RFC 1928 for request and reply structure and sizes - buf := make([]byte, socks.MaxAddrLen) + buf := pool.GetBuffer(socks.MaxAddrLen) + defer pool.PutBuffer(buf) // read VER, NMETHODS, METHODS if _, err := io.ReadFull(c, buf[:2]); err != nil { diff --git a/proxy/vless/addr.go b/proxy/vless/addr.go index 1276906..525d5d1 100644 --- a/proxy/vless/addr.go +++ b/proxy/vless/addr.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "io" "net" + "net/netip" "strconv" "github.com/nadoo/glider/pkg/pool" @@ -39,16 +40,13 @@ func ParseAddr(s string) (Atyp, Addr, Port, error) { return 0, nil, 0, err } - if ip := net.ParseIP(host); ip != nil { - if ip4 := ip.To4(); ip4 != nil { - addr = make([]byte, net.IPv4len) + if ip, err := netip.ParseAddr(host); err == nil { + if ip.Is4() { atyp = AtypIP4 - copy(addr[:], ip4) } else { - addr = make([]byte, net.IPv6len) atyp = AtypIP6 - copy(addr[:], ip) } + addr = ip.AsSlice() } else { if len(host) > MaxHostLen { return 0, nil, 0, err diff --git a/proxy/vmess/addr.go b/proxy/vmess/addr.go index 4f6d5a0..02ed7a0 100644 --- a/proxy/vmess/addr.go +++ b/proxy/vmess/addr.go @@ -2,6 +2,7 @@ package vmess import ( "net" + "net/netip" "strconv" ) @@ -19,6 +20,9 @@ const ( // Addr is vmess addr. type Addr []byte +// MaxHostLen is the maximum size of host in bytes. +const MaxHostLen = 255 + // Port is vmess addr port. type Port uint16 @@ -32,18 +36,15 @@ func ParseAddr(s string) (Atyp, Addr, Port, error) { return 0, nil, 0, err } - if ip := net.ParseIP(host); ip != nil { - if ip4 := ip.To4(); ip4 != nil { - addr = make([]byte, net.IPv4len) + if ip, err := netip.ParseAddr(host); err == nil { + if ip.Is4() { atyp = AtypIP4 - copy(addr[:], ip4) } else { - addr = make([]byte, net.IPv6len) atyp = AtypIP6 - copy(addr[:], ip) } + addr = ip.AsSlice() } else { - if len(host) > 255 { + if len(host) > MaxHostLen { return 0, nil, 0, err } addr = make([]byte, 1+len(host)) diff --git a/rule/proxy.go b/rule/proxy.go index a2ba3f5..437dc92 100644 --- a/rule/proxy.go +++ b/rule/proxy.go @@ -2,9 +2,11 @@ package rule import ( "net" + "net/netip" "strings" "sync" + "github.com/nadoo/glider/pkg/log" "github.com/nadoo/glider/proxy" ) @@ -34,9 +36,12 @@ func NewProxy(mainForwarders []string, mainStrategy *Strategy, rules []*Config) } for _, s := range r.CIDR { - if _, cidr, err := net.ParseCIDR(s); err == nil { - rd.cidrMap.Store(cidr, group) + cidr, err := netip.ParsePrefix(s) + if err != nil { + log.F("[rule] parse cidr error: %s", err) + continue } + rd.cidrMap.Store(cidr, group) } } @@ -48,7 +53,7 @@ func NewProxy(mainForwarders []string, mainStrategy *Strategy, rules []*Config) for _, f := range rd.main.fwdrs { addr := strings.Split(f.addr, ",")[0] host, _, _ := net.SplitHostPort(addr) - if ip := net.ParseIP(host); ip == nil { + if _, err := netip.ParseAddr(host); err != nil { rd.domainMap.Store(strings.ToLower(host), direct) } } @@ -74,18 +79,26 @@ func (p *Proxy) findDialer(dstAddr string) *FwdrGroup { return p.main } - // find ip - if ip := net.ParseIP(host); ip != nil { - // check ip - if proxy, ok := p.ipMap.Load(ip.String()); ok { + // check ip + // TODO: ipv4 should equal to ipv4-mapped ipv6? but it'll need to parse the ip address + if proxy, ok := p.ipMap.Load(host); ok { + return proxy.(*FwdrGroup) + } + + // check host + host = strings.ToLower(host) + for i := len(host); i != -1; { + i = strings.LastIndexByte(host[:i], '.') + if proxy, ok := p.domainMap.Load(host[i+1:]); ok { return proxy.(*FwdrGroup) } + } + // check cidr + if ip, err := netip.ParseAddr(host); err == nil { var ret *FwdrGroup - // check cidr p.cidrMap.Range(func(key, value any) bool { - cidr := key.(*net.IPNet) - if cidr.Contains(ip) { + if key.(netip.Prefix).Contains(ip) { ret = value.(*FwdrGroup) return false } @@ -95,15 +108,6 @@ func (p *Proxy) findDialer(dstAddr string) *FwdrGroup { if ret != nil { return ret } - - } - - host = strings.ToLower(host) - for i := len(host); i != -1; { - i = strings.LastIndexByte(host[:i], '.') - if proxy, ok := p.domainMap.Load(host[i+1:]); ok { - return proxy.(*FwdrGroup) - } } return p.main diff --git a/service/dhcpd/dhcpd.go b/service/dhcpd/dhcpd.go index eefa484..a340f23 100644 --- a/service/dhcpd/dhcpd.go +++ b/service/dhcpd/dhcpd.go @@ -3,6 +3,7 @@ package dhcpd import ( "errors" "net" + "net/netip" "strconv" "strings" "time" @@ -29,7 +30,7 @@ func (*dpcpd) Run(args ...string) { return } - iface, startIP, endIP, leaseMin := args[0], args[1], args[2], args[3] + iface, start, end, leaseMin := args[0], args[1], args[2], args[3] if i, err := strconv.Atoi(leaseMin); err != nil { leaseTime = time.Duration(i) * time.Minute } @@ -45,7 +46,19 @@ func (*dpcpd) Run(args ...string) { return } - pool, err := NewPool(leaseTime, net.ParseIP(startIP), net.ParseIP(endIP)) + startIP, err := netip.ParseAddr(start) + if err != nil { + log.F("[dhcpd] startIP %s is not valid: %s", start, err) + return + } + + endIP, err := netip.ParseAddr(end) + if err != nil { + log.F("[dhcpd] endIP %s is not valid: %s", end, err) + return + } + + pool, err := NewPool(leaseTime, startIP, endIP) if err != nil { log.F("[dhcpd] error in pool init: %s", err) return @@ -55,12 +68,11 @@ func (*dpcpd) Run(args ...string) { for _, host := range args[4:] { pair := strings.Split(host, "=") if len(pair) == 2 { - mac, err := net.ParseMAC(pair[0]) - if err != nil { - break + if mac, err := net.ParseMAC(pair[0]); err == nil { + if ip, err := netip.ParseAddr(pair[1]); err == nil { + pool.LeaseStaticIP(mac, ip) + } } - ip := net.ParseIP(pair[1]) - pool.LeaseStaticIP(mac, ip) } } @@ -109,7 +121,7 @@ func handleDHCP(serverIP net.IP, mask net.IPMask, pool *Pool) server4.Handler { dhcpv4.WithMessageType(replyType), dhcpv4.WithServerIP(serverIP), dhcpv4.WithNetmask(mask), - dhcpv4.WithYourIP(replyIP), + dhcpv4.WithYourIP(replyIP.AsSlice()), dhcpv4.WithRouter(serverIP), dhcpv4.WithDNS(serverIP), // RFC 2131, Section 4.3.1. Server Identifier: MUST diff --git a/service/dhcpd/pool.go b/service/dhcpd/pool.go index 6a98247..7d3bfd7 100644 --- a/service/dhcpd/pool.go +++ b/service/dhcpd/pool.go @@ -5,6 +5,7 @@ import ( "errors" "math/rand" "net" + "net/netip" "sync" "time" ) @@ -17,18 +18,14 @@ type Pool struct { } type item struct { - ip net.IP + ip netip.Addr mac net.HardwareAddr expire time.Time } // NewPool returns a new dhcp ip pool. -func NewPool(lease time.Duration, start, end net.IP) (*Pool, error) { - if start == nil || end == nil { - return nil, errors.New("start ip or end ip is wrong/nil, please check your config") - } - - s, e := ip2num(start.To4()), ip2num(end.To4()) +func NewPool(lease time.Duration, start, end netip.Addr) (*Pool, error) { + s, e := ip2num(start), ip2num(end) if e < s { return nil, errors.New("start ip larger than end ip") } @@ -57,7 +54,7 @@ func NewPool(lease time.Duration, start, end net.IP) (*Pool, error) { } // LeaseIP leases an ip to mac from dhcp pool. -func (p *Pool) LeaseIP(mac net.HardwareAddr) (net.IP, error) { +func (p *Pool) LeaseIP(mac net.HardwareAddr) (netip.Addr, error) { p.mutex.Lock() defer p.mutex.Unlock() @@ -84,16 +81,16 @@ func (p *Pool) LeaseIP(mac net.HardwareAddr) (net.IP, error) { } } - return nil, errors.New("no more ip can be leased") + return netip.Addr{}, errors.New("no more ip can be leased") } // LeaseStaticIP leases static ip from pool according to the given mac. -func (p *Pool) LeaseStaticIP(mac net.HardwareAddr, ip net.IP) { +func (p *Pool) LeaseStaticIP(mac net.HardwareAddr, ip netip.Addr) { p.mutex.Lock() defer p.mutex.Unlock() for _, item := range p.items { - if item.ip.Equal(ip) { + if item.ip == ip { item.mac = mac item.expire = time.Now().Add(time.Hour * 24 * 365 * 50) // 50 years } @@ -113,11 +110,12 @@ func (p *Pool) ReleaseIP(mac net.HardwareAddr) { } } -func ip2num(ip net.IP) uint32 { +func ip2num(addr netip.Addr) uint32 { + ip := addr.As4() n := uint32(ip[0])<<24 + uint32(ip[1])<<16 return n + uint32(ip[2])<<8 + uint32(ip[3]) } -func num2ip(n uint32) net.IP { - return []byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)} +func num2ip(n uint32) netip.Addr { + return netip.AddrFrom4([4]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)}) }