proxy: improve addr handling

This commit is contained in:
nadoo 2022-01-28 23:35:29 +08:00
parent e12642b47a
commit 7e7c7553cc
12 changed files with 99 additions and 109 deletions

View File

@ -87,11 +87,8 @@ func parseConfig() *Config {
os.Exit(-1) os.Exit(-1)
} }
// setup a log func // setup logger
if conf.Verbose { log.Set(conf.Verbose, conf.LogFlags)
log.SetFlags(conf.LogFlags)
log.F = log.Debugf
}
if len(conf.Listens) == 0 && conf.DNS == "" && len(conf.Services) == 0 { if len(conf.Listens) == 0 && conf.DNS == "" && len(conf.Services) == 0 {
// flag.Usage() // flag.Usage()

View File

@ -5,17 +5,19 @@ import (
stdlog "log" stdlog "log"
) )
// F is the main log function. var verbose = false
var F = func(string, ...any) {}
// SetFlags sets the output flags for the logger. // Set sets the logger's verbose mode and output flags.
func SetFlags(flag int) { func Set(v bool, flag int) {
verbose = v
stdlog.SetFlags(flag) stdlog.SetFlags(flag)
} }
// Debugf prints debug log. // F prints debug log.
func Debugf(f string, v ...any) { func F(f string, v ...any) {
if verbose {
stdlog.Output(2, fmt.Sprintf(f, v...)) stdlog.Output(2, fmt.Sprintf(f, v...))
}
} }
// Print prints log. // Print prints log.

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"io" "io"
"net" "net"
"net/netip"
"strconv" "strconv"
) )
@ -139,16 +140,16 @@ func ParseAddr(s string) Addr {
if err != nil { if err != nil {
return nil return nil
} }
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil { if ip, err := netip.ParseAddr(host); err == nil {
if ip.Is4() {
addr = make([]byte, 1+net.IPv4len+2) addr = make([]byte, 1+net.IPv4len+2)
addr[0] = ATypIP4 addr[0] = ATypIP4
copy(addr[1:], ip4)
} else { } else {
addr = make([]byte, 1+net.IPv6len+2) addr = make([]byte, 1+net.IPv6len+2)
addr[0] = ATypIP6 addr[0] = ATypIP6
copy(addr[1:], ip)
} }
copy(addr[1:], ip.AsSlice())
} else { } else {
if len(host) > 255 { if len(host) > 255 {
return nil return nil

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"net" "net"
"net/netip"
"time" "time"
"github.com/nadoo/glider/pkg/log" "github.com/nadoo/glider/pkg/log"
@ -27,8 +28,8 @@ func NewDirect(intface string, dialTimeout, relayTimeout time.Duration) (*Direct
d := &Direct{dialTimeout: dialTimeout, relayTimeout: relayTimeout} d := &Direct{dialTimeout: dialTimeout, relayTimeout: relayTimeout}
if intface != "" { if intface != "" {
if ip := net.ParseIP(intface); ip != nil { if addr, err := netip.ParseAddr(intface); err == nil {
d.ip = ip d.ip = addr.AsSlice()
} else { } else {
iface, err := net.InterfaceByName(intface) iface, err := net.InterfaceByName(intface)
if err != nil { if err != nil {

View File

@ -60,7 +60,7 @@ func (s *RedirProxy) ListenAndServe() {
return return
} }
log.F("[redir] listening TCP on %s", s.addr) log.F("[redir] listening TCP on " + s.addr)
for { for {
c, err := l.Accept() c, err := l.Accept()
@ -140,7 +140,9 @@ func getorigdst(fd uintptr) (netip.AddrPort, error) {
if err := socketcall(GETSOCKOPT, fd, syscall.IPPROTO_IP, _SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&raw)), uintptr(unsafe.Pointer(&siz)), 0); err != nil { if err := socketcall(GETSOCKOPT, fd, syscall.IPPROTO_IP, _SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&raw)), uintptr(unsafe.Pointer(&siz)), 0); err != nil {
return netip.AddrPort{}, err return netip.AddrPort{}, err
} }
port := raw.Port<<8 | raw.Port>>8 // raw.Port is big-endian // NOTE: raw.Port is big-endian, just change it to little-endian
// TODO: improve here when we add big-endian $GOARCH support
port := raw.Port<<8 | raw.Port>>8
return netip.AddrPortFrom(netip.AddrFrom4(raw.Addr), port), nil return netip.AddrPortFrom(netip.AddrFrom4(raw.Addr), port), nil
} }
@ -152,6 +154,8 @@ func getorigdstIPv6(fd uintptr) (netip.AddrPort, error) {
if err := socketcall(GETSOCKOPT, fd, syscall.IPPROTO_IPV6, _IP6T_SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&raw)), uintptr(unsafe.Pointer(&siz)), 0); err != nil { if err := socketcall(GETSOCKOPT, fd, syscall.IPPROTO_IPV6, _IP6T_SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&raw)), uintptr(unsafe.Pointer(&siz)), 0); err != nil {
return netip.AddrPort{}, err return netip.AddrPort{}, err
} }
port := raw.Port<<8 | raw.Port>>8 // raw.Port is big-endian // NOTE: raw.Port is big-endian, just change it to little-endian
// TODO: improve here when we add big-endian $GOARCH support
port := raw.Port<<8 | raw.Port>>8
return netip.AddrPortFrom(netip.AddrFrom16(raw.Addr), port), nil return netip.AddrPortFrom(netip.AddrFrom16(raw.Addr), port), nil
} }

View File

@ -98,23 +98,11 @@ func (s *Socks5) DialUDP(network, addr string) (pc net.PacketConn, writeTo net.A
// and commands the server to extend that connection to target, // and commands the server to extend that connection to target,
// which must be a canonical address with a host and port. // which must be a canonical address with a host and port.
func (s *Socks5) connect(conn net.Conn, target string, cmd byte) (addr socks.Addr, err error) { func (s *Socks5) connect(conn net.Conn, target string, cmd byte) (addr socks.Addr, err error) {
host, portStr, err := net.SplitHostPort(target)
if err != nil {
return
}
port, err := strconv.Atoi(portStr)
if err != nil {
return addr, errors.New("proxy: failed to parse port number: " + portStr)
}
if port < 1 || port > 0xffff {
return addr, errors.New("proxy: port number out of range: " + portStr)
}
// the size here is just an estimate // the size here is just an estimate
buf := make([]byte, 0, 6+len(host)) buf := pool.GetBuffer(socks.MaxAddrLen)
defer pool.PutBuffer(buf)
buf = append(buf, Version) buf = append(buf[:0], Version)
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 { if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
buf = append(buf, 2 /* num auth methods */, socks.AuthNone, socks.AuthPassword) buf = append(buf, 2 /* num auth methods */, socks.AuthNone, socks.AuthPassword)
} else { } else {
@ -158,24 +146,7 @@ func (s *Socks5) connect(conn net.Conn, target string, cmd byte) (addr socks.Add
buf = buf[:0] buf = buf[:0]
buf = append(buf, Version, cmd, 0 /* reserved */) buf = append(buf, Version, cmd, 0 /* reserved */)
buf = append(buf, socks.ParseAddr(target)...)
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
buf = append(buf, socks.ATypIP4)
ip = ip4
} else {
buf = append(buf, socks.ATypIP6)
}
buf = append(buf, ip...)
} else {
if len(host) > 255 {
return addr, errors.New("proxy: destination hostname too long: " + host)
}
buf = append(buf, socks.ATypDomain)
buf = append(buf, byte(len(host)))
buf = append(buf, host...)
}
buf = append(buf, byte(port>>8), byte(port))
if _, err := conn.Write(buf); err != nil { if _, err := conn.Write(buf); err != nil {
return addr, errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) return addr, errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())

View File

@ -184,7 +184,8 @@ func newSession(key string, src net.Addr, srcPC *PktConn) *Session {
// Handshake fast-tracks SOCKS initialization to get target address to connect. // Handshake fast-tracks SOCKS initialization to get target address to connect.
func (s *Socks5) handshake(c net.Conn) (socks.Addr, error) { func (s *Socks5) handshake(c net.Conn) (socks.Addr, error) {
// Read RFC 1928 for request and reply structure and sizes // Read RFC 1928 for request and reply structure and sizes
buf := make([]byte, socks.MaxAddrLen) buf := pool.GetBuffer(socks.MaxAddrLen)
defer pool.PutBuffer(buf)
// read VER, NMETHODS, METHODS // read VER, NMETHODS, METHODS
if _, err := io.ReadFull(c, buf[:2]); err != nil { if _, err := io.ReadFull(c, buf[:2]); err != nil {

View File

@ -4,6 +4,7 @@ import (
"encoding/binary" "encoding/binary"
"io" "io"
"net" "net"
"net/netip"
"strconv" "strconv"
"github.com/nadoo/glider/pkg/pool" "github.com/nadoo/glider/pkg/pool"
@ -39,16 +40,13 @@ func ParseAddr(s string) (Atyp, Addr, Port, error) {
return 0, nil, 0, err return 0, nil, 0, err
} }
if ip := net.ParseIP(host); ip != nil { if ip, err := netip.ParseAddr(host); err == nil {
if ip4 := ip.To4(); ip4 != nil { if ip.Is4() {
addr = make([]byte, net.IPv4len)
atyp = AtypIP4 atyp = AtypIP4
copy(addr[:], ip4)
} else { } else {
addr = make([]byte, net.IPv6len)
atyp = AtypIP6 atyp = AtypIP6
copy(addr[:], ip)
} }
addr = ip.AsSlice()
} else { } else {
if len(host) > MaxHostLen { if len(host) > MaxHostLen {
return 0, nil, 0, err return 0, nil, 0, err

View File

@ -2,6 +2,7 @@ package vmess
import ( import (
"net" "net"
"net/netip"
"strconv" "strconv"
) )
@ -19,6 +20,9 @@ const (
// Addr is vmess addr. // Addr is vmess addr.
type Addr []byte type Addr []byte
// MaxHostLen is the maximum size of host in bytes.
const MaxHostLen = 255
// Port is vmess addr port. // Port is vmess addr port.
type Port uint16 type Port uint16
@ -32,18 +36,15 @@ func ParseAddr(s string) (Atyp, Addr, Port, error) {
return 0, nil, 0, err return 0, nil, 0, err
} }
if ip := net.ParseIP(host); ip != nil { if ip, err := netip.ParseAddr(host); err == nil {
if ip4 := ip.To4(); ip4 != nil { if ip.Is4() {
addr = make([]byte, net.IPv4len)
atyp = AtypIP4 atyp = AtypIP4
copy(addr[:], ip4)
} else { } else {
addr = make([]byte, net.IPv6len)
atyp = AtypIP6 atyp = AtypIP6
copy(addr[:], ip)
} }
addr = ip.AsSlice()
} else { } else {
if len(host) > 255 { if len(host) > MaxHostLen {
return 0, nil, 0, err return 0, nil, 0, err
} }
addr = make([]byte, 1+len(host)) addr = make([]byte, 1+len(host))

View File

@ -2,9 +2,11 @@ package rule
import ( import (
"net" "net"
"net/netip"
"strings" "strings"
"sync" "sync"
"github.com/nadoo/glider/pkg/log"
"github.com/nadoo/glider/proxy" "github.com/nadoo/glider/proxy"
) )
@ -34,9 +36,12 @@ func NewProxy(mainForwarders []string, mainStrategy *Strategy, rules []*Config)
} }
for _, s := range r.CIDR { for _, s := range r.CIDR {
if _, cidr, err := net.ParseCIDR(s); err == nil { cidr, err := netip.ParsePrefix(s)
rd.cidrMap.Store(cidr, group) if err != nil {
log.F("[rule] parse cidr error: %s", err)
continue
} }
rd.cidrMap.Store(cidr, group)
} }
} }
@ -48,7 +53,7 @@ func NewProxy(mainForwarders []string, mainStrategy *Strategy, rules []*Config)
for _, f := range rd.main.fwdrs { for _, f := range rd.main.fwdrs {
addr := strings.Split(f.addr, ",")[0] addr := strings.Split(f.addr, ",")[0]
host, _, _ := net.SplitHostPort(addr) host, _, _ := net.SplitHostPort(addr)
if ip := net.ParseIP(host); ip == nil { if _, err := netip.ParseAddr(host); err != nil {
rd.domainMap.Store(strings.ToLower(host), direct) rd.domainMap.Store(strings.ToLower(host), direct)
} }
} }
@ -74,18 +79,26 @@ func (p *Proxy) findDialer(dstAddr string) *FwdrGroup {
return p.main return p.main
} }
// find ip
if ip := net.ParseIP(host); ip != nil {
// check ip // check ip
if proxy, ok := p.ipMap.Load(ip.String()); ok { // TODO: ipv4 should equal to ipv4-mapped ipv6? but it'll need to parse the ip address
if proxy, ok := p.ipMap.Load(host); ok {
return proxy.(*FwdrGroup) return proxy.(*FwdrGroup)
} }
var ret *FwdrGroup // check host
host = strings.ToLower(host)
for i := len(host); i != -1; {
i = strings.LastIndexByte(host[:i], '.')
if proxy, ok := p.domainMap.Load(host[i+1:]); ok {
return proxy.(*FwdrGroup)
}
}
// check cidr // check cidr
if ip, err := netip.ParseAddr(host); err == nil {
var ret *FwdrGroup
p.cidrMap.Range(func(key, value any) bool { p.cidrMap.Range(func(key, value any) bool {
cidr := key.(*net.IPNet) if key.(netip.Prefix).Contains(ip) {
if cidr.Contains(ip) {
ret = value.(*FwdrGroup) ret = value.(*FwdrGroup)
return false return false
} }
@ -95,15 +108,6 @@ func (p *Proxy) findDialer(dstAddr string) *FwdrGroup {
if ret != nil { if ret != nil {
return ret return ret
} }
}
host = strings.ToLower(host)
for i := len(host); i != -1; {
i = strings.LastIndexByte(host[:i], '.')
if proxy, ok := p.domainMap.Load(host[i+1:]); ok {
return proxy.(*FwdrGroup)
}
} }
return p.main return p.main

View File

@ -3,6 +3,7 @@ package dhcpd
import ( import (
"errors" "errors"
"net" "net"
"net/netip"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -29,7 +30,7 @@ func (*dpcpd) Run(args ...string) {
return return
} }
iface, startIP, endIP, leaseMin := args[0], args[1], args[2], args[3] iface, start, end, leaseMin := args[0], args[1], args[2], args[3]
if i, err := strconv.Atoi(leaseMin); err != nil { if i, err := strconv.Atoi(leaseMin); err != nil {
leaseTime = time.Duration(i) * time.Minute leaseTime = time.Duration(i) * time.Minute
} }
@ -45,7 +46,19 @@ func (*dpcpd) Run(args ...string) {
return return
} }
pool, err := NewPool(leaseTime, net.ParseIP(startIP), net.ParseIP(endIP)) startIP, err := netip.ParseAddr(start)
if err != nil {
log.F("[dhcpd] startIP %s is not valid: %s", start, err)
return
}
endIP, err := netip.ParseAddr(end)
if err != nil {
log.F("[dhcpd] endIP %s is not valid: %s", end, err)
return
}
pool, err := NewPool(leaseTime, startIP, endIP)
if err != nil { if err != nil {
log.F("[dhcpd] error in pool init: %s", err) log.F("[dhcpd] error in pool init: %s", err)
return return
@ -55,14 +68,13 @@ func (*dpcpd) Run(args ...string) {
for _, host := range args[4:] { for _, host := range args[4:] {
pair := strings.Split(host, "=") pair := strings.Split(host, "=")
if len(pair) == 2 { if len(pair) == 2 {
mac, err := net.ParseMAC(pair[0]) if mac, err := net.ParseMAC(pair[0]); err == nil {
if err != nil { if ip, err := netip.ParseAddr(pair[1]); err == nil {
break
}
ip := net.ParseIP(pair[1])
pool.LeaseStaticIP(mac, ip) pool.LeaseStaticIP(mac, ip)
} }
} }
}
}
laddr := net.UDPAddr{IP: net.IPv4(0, 0, 0, 0), Port: 67} laddr := net.UDPAddr{IP: net.IPv4(0, 0, 0, 0), Port: 67}
server, err := server4.NewServer(iface, &laddr, handleDHCP(ip, mask, pool)) server, err := server4.NewServer(iface, &laddr, handleDHCP(ip, mask, pool))
@ -109,7 +121,7 @@ func handleDHCP(serverIP net.IP, mask net.IPMask, pool *Pool) server4.Handler {
dhcpv4.WithMessageType(replyType), dhcpv4.WithMessageType(replyType),
dhcpv4.WithServerIP(serverIP), dhcpv4.WithServerIP(serverIP),
dhcpv4.WithNetmask(mask), dhcpv4.WithNetmask(mask),
dhcpv4.WithYourIP(replyIP), dhcpv4.WithYourIP(replyIP.AsSlice()),
dhcpv4.WithRouter(serverIP), dhcpv4.WithRouter(serverIP),
dhcpv4.WithDNS(serverIP), dhcpv4.WithDNS(serverIP),
// RFC 2131, Section 4.3.1. Server Identifier: MUST // RFC 2131, Section 4.3.1. Server Identifier: MUST

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"math/rand" "math/rand"
"net" "net"
"net/netip"
"sync" "sync"
"time" "time"
) )
@ -17,18 +18,14 @@ type Pool struct {
} }
type item struct { type item struct {
ip net.IP ip netip.Addr
mac net.HardwareAddr mac net.HardwareAddr
expire time.Time expire time.Time
} }
// NewPool returns a new dhcp ip pool. // NewPool returns a new dhcp ip pool.
func NewPool(lease time.Duration, start, end net.IP) (*Pool, error) { func NewPool(lease time.Duration, start, end netip.Addr) (*Pool, error) {
if start == nil || end == nil { s, e := ip2num(start), ip2num(end)
return nil, errors.New("start ip or end ip is wrong/nil, please check your config")
}
s, e := ip2num(start.To4()), ip2num(end.To4())
if e < s { if e < s {
return nil, errors.New("start ip larger than end ip") return nil, errors.New("start ip larger than end ip")
} }
@ -57,7 +54,7 @@ func NewPool(lease time.Duration, start, end net.IP) (*Pool, error) {
} }
// LeaseIP leases an ip to mac from dhcp pool. // LeaseIP leases an ip to mac from dhcp pool.
func (p *Pool) LeaseIP(mac net.HardwareAddr) (net.IP, error) { func (p *Pool) LeaseIP(mac net.HardwareAddr) (netip.Addr, error) {
p.mutex.Lock() p.mutex.Lock()
defer p.mutex.Unlock() defer p.mutex.Unlock()
@ -84,16 +81,16 @@ func (p *Pool) LeaseIP(mac net.HardwareAddr) (net.IP, error) {
} }
} }
return nil, errors.New("no more ip can be leased") return netip.Addr{}, errors.New("no more ip can be leased")
} }
// LeaseStaticIP leases static ip from pool according to the given mac. // LeaseStaticIP leases static ip from pool according to the given mac.
func (p *Pool) LeaseStaticIP(mac net.HardwareAddr, ip net.IP) { func (p *Pool) LeaseStaticIP(mac net.HardwareAddr, ip netip.Addr) {
p.mutex.Lock() p.mutex.Lock()
defer p.mutex.Unlock() defer p.mutex.Unlock()
for _, item := range p.items { for _, item := range p.items {
if item.ip.Equal(ip) { if item.ip == ip {
item.mac = mac item.mac = mac
item.expire = time.Now().Add(time.Hour * 24 * 365 * 50) // 50 years item.expire = time.Now().Add(time.Hour * 24 * 365 * 50) // 50 years
} }
@ -113,11 +110,12 @@ func (p *Pool) ReleaseIP(mac net.HardwareAddr) {
} }
} }
func ip2num(ip net.IP) uint32 { func ip2num(addr netip.Addr) uint32 {
ip := addr.As4()
n := uint32(ip[0])<<24 + uint32(ip[1])<<16 n := uint32(ip[0])<<24 + uint32(ip[1])<<16
return n + uint32(ip[2])<<8 + uint32(ip[3]) return n + uint32(ip[2])<<8 + uint32(ip[3])
} }
func num2ip(n uint32) net.IP { func num2ip(n uint32) netip.Addr {
return []byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)} return netip.AddrFrom4([4]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)})
} }