proxy: improve addr handling

This commit is contained in:
nadoo 2022-01-28 23:35:29 +08:00
parent e12642b47a
commit 7e7c7553cc
12 changed files with 99 additions and 109 deletions

View File

@ -87,11 +87,8 @@ func parseConfig() *Config {
os.Exit(-1)
}
// setup a log func
if conf.Verbose {
log.SetFlags(conf.LogFlags)
log.F = log.Debugf
}
// setup logger
log.Set(conf.Verbose, conf.LogFlags)
if len(conf.Listens) == 0 && conf.DNS == "" && len(conf.Services) == 0 {
// flag.Usage()

View File

@ -5,17 +5,19 @@ import (
stdlog "log"
)
// F is the main log function.
var F = func(string, ...any) {}
var verbose = false
// SetFlags sets the output flags for the logger.
func SetFlags(flag int) {
// Set sets the logger's verbose mode and output flags.
func Set(v bool, flag int) {
verbose = v
stdlog.SetFlags(flag)
}
// Debugf prints debug log.
func Debugf(f string, v ...any) {
stdlog.Output(2, fmt.Sprintf(f, v...))
// F prints debug log.
func F(f string, v ...any) {
if verbose {
stdlog.Output(2, fmt.Sprintf(f, v...))
}
}
// Print prints log.

View File

@ -4,6 +4,7 @@ import (
"errors"
"io"
"net"
"net/netip"
"strconv"
)
@ -139,16 +140,16 @@ func ParseAddr(s string) Addr {
if err != nil {
return nil
}
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
if ip, err := netip.ParseAddr(host); err == nil {
if ip.Is4() {
addr = make([]byte, 1+net.IPv4len+2)
addr[0] = ATypIP4
copy(addr[1:], ip4)
} else {
addr = make([]byte, 1+net.IPv6len+2)
addr[0] = ATypIP6
copy(addr[1:], ip)
}
copy(addr[1:], ip.AsSlice())
} else {
if len(host) > 255 {
return nil

View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"net"
"net/netip"
"time"
"github.com/nadoo/glider/pkg/log"
@ -27,8 +28,8 @@ func NewDirect(intface string, dialTimeout, relayTimeout time.Duration) (*Direct
d := &Direct{dialTimeout: dialTimeout, relayTimeout: relayTimeout}
if intface != "" {
if ip := net.ParseIP(intface); ip != nil {
d.ip = ip
if addr, err := netip.ParseAddr(intface); err == nil {
d.ip = addr.AsSlice()
} else {
iface, err := net.InterfaceByName(intface)
if err != nil {

View File

@ -60,7 +60,7 @@ func (s *RedirProxy) ListenAndServe() {
return
}
log.F("[redir] listening TCP on %s", s.addr)
log.F("[redir] listening TCP on " + s.addr)
for {
c, err := l.Accept()
@ -140,7 +140,9 @@ func getorigdst(fd uintptr) (netip.AddrPort, error) {
if err := socketcall(GETSOCKOPT, fd, syscall.IPPROTO_IP, _SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&raw)), uintptr(unsafe.Pointer(&siz)), 0); err != nil {
return netip.AddrPort{}, err
}
port := raw.Port<<8 | raw.Port>>8 // raw.Port is big-endian
// NOTE: raw.Port is big-endian, just change it to little-endian
// TODO: improve here when we add big-endian $GOARCH support
port := raw.Port<<8 | raw.Port>>8
return netip.AddrPortFrom(netip.AddrFrom4(raw.Addr), port), nil
}
@ -152,6 +154,8 @@ func getorigdstIPv6(fd uintptr) (netip.AddrPort, error) {
if err := socketcall(GETSOCKOPT, fd, syscall.IPPROTO_IPV6, _IP6T_SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&raw)), uintptr(unsafe.Pointer(&siz)), 0); err != nil {
return netip.AddrPort{}, err
}
port := raw.Port<<8 | raw.Port>>8 // raw.Port is big-endian
// NOTE: raw.Port is big-endian, just change it to little-endian
// TODO: improve here when we add big-endian $GOARCH support
port := raw.Port<<8 | raw.Port>>8
return netip.AddrPortFrom(netip.AddrFrom16(raw.Addr), port), nil
}

View File

@ -98,23 +98,11 @@ func (s *Socks5) DialUDP(network, addr string) (pc net.PacketConn, writeTo net.A
// and commands the server to extend that connection to target,
// which must be a canonical address with a host and port.
func (s *Socks5) connect(conn net.Conn, target string, cmd byte) (addr socks.Addr, err error) {
host, portStr, err := net.SplitHostPort(target)
if err != nil {
return
}
port, err := strconv.Atoi(portStr)
if err != nil {
return addr, errors.New("proxy: failed to parse port number: " + portStr)
}
if port < 1 || port > 0xffff {
return addr, errors.New("proxy: port number out of range: " + portStr)
}
// the size here is just an estimate
buf := make([]byte, 0, 6+len(host))
buf := pool.GetBuffer(socks.MaxAddrLen)
defer pool.PutBuffer(buf)
buf = append(buf, Version)
buf = append(buf[:0], Version)
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
buf = append(buf, 2 /* num auth methods */, socks.AuthNone, socks.AuthPassword)
} else {
@ -158,24 +146,7 @@ func (s *Socks5) connect(conn net.Conn, target string, cmd byte) (addr socks.Add
buf = buf[:0]
buf = append(buf, Version, cmd, 0 /* reserved */)
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
buf = append(buf, socks.ATypIP4)
ip = ip4
} else {
buf = append(buf, socks.ATypIP6)
}
buf = append(buf, ip...)
} else {
if len(host) > 255 {
return addr, errors.New("proxy: destination hostname too long: " + host)
}
buf = append(buf, socks.ATypDomain)
buf = append(buf, byte(len(host)))
buf = append(buf, host...)
}
buf = append(buf, byte(port>>8), byte(port))
buf = append(buf, socks.ParseAddr(target)...)
if _, err := conn.Write(buf); err != nil {
return addr, errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())

View File

@ -184,7 +184,8 @@ func newSession(key string, src net.Addr, srcPC *PktConn) *Session {
// Handshake fast-tracks SOCKS initialization to get target address to connect.
func (s *Socks5) handshake(c net.Conn) (socks.Addr, error) {
// Read RFC 1928 for request and reply structure and sizes
buf := make([]byte, socks.MaxAddrLen)
buf := pool.GetBuffer(socks.MaxAddrLen)
defer pool.PutBuffer(buf)
// read VER, NMETHODS, METHODS
if _, err := io.ReadFull(c, buf[:2]); err != nil {

View File

@ -4,6 +4,7 @@ import (
"encoding/binary"
"io"
"net"
"net/netip"
"strconv"
"github.com/nadoo/glider/pkg/pool"
@ -39,16 +40,13 @@ func ParseAddr(s string) (Atyp, Addr, Port, error) {
return 0, nil, 0, err
}
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
addr = make([]byte, net.IPv4len)
if ip, err := netip.ParseAddr(host); err == nil {
if ip.Is4() {
atyp = AtypIP4
copy(addr[:], ip4)
} else {
addr = make([]byte, net.IPv6len)
atyp = AtypIP6
copy(addr[:], ip)
}
addr = ip.AsSlice()
} else {
if len(host) > MaxHostLen {
return 0, nil, 0, err

View File

@ -2,6 +2,7 @@ package vmess
import (
"net"
"net/netip"
"strconv"
)
@ -19,6 +20,9 @@ const (
// Addr is vmess addr.
type Addr []byte
// MaxHostLen is the maximum size of host in bytes.
const MaxHostLen = 255
// Port is vmess addr port.
type Port uint16
@ -32,18 +36,15 @@ func ParseAddr(s string) (Atyp, Addr, Port, error) {
return 0, nil, 0, err
}
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
addr = make([]byte, net.IPv4len)
if ip, err := netip.ParseAddr(host); err == nil {
if ip.Is4() {
atyp = AtypIP4
copy(addr[:], ip4)
} else {
addr = make([]byte, net.IPv6len)
atyp = AtypIP6
copy(addr[:], ip)
}
addr = ip.AsSlice()
} else {
if len(host) > 255 {
if len(host) > MaxHostLen {
return 0, nil, 0, err
}
addr = make([]byte, 1+len(host))

View File

@ -2,9 +2,11 @@ package rule
import (
"net"
"net/netip"
"strings"
"sync"
"github.com/nadoo/glider/pkg/log"
"github.com/nadoo/glider/proxy"
)
@ -34,9 +36,12 @@ func NewProxy(mainForwarders []string, mainStrategy *Strategy, rules []*Config)
}
for _, s := range r.CIDR {
if _, cidr, err := net.ParseCIDR(s); err == nil {
rd.cidrMap.Store(cidr, group)
cidr, err := netip.ParsePrefix(s)
if err != nil {
log.F("[rule] parse cidr error: %s", err)
continue
}
rd.cidrMap.Store(cidr, group)
}
}
@ -48,7 +53,7 @@ func NewProxy(mainForwarders []string, mainStrategy *Strategy, rules []*Config)
for _, f := range rd.main.fwdrs {
addr := strings.Split(f.addr, ",")[0]
host, _, _ := net.SplitHostPort(addr)
if ip := net.ParseIP(host); ip == nil {
if _, err := netip.ParseAddr(host); err != nil {
rd.domainMap.Store(strings.ToLower(host), direct)
}
}
@ -74,18 +79,26 @@ func (p *Proxy) findDialer(dstAddr string) *FwdrGroup {
return p.main
}
// find ip
if ip := net.ParseIP(host); ip != nil {
// check ip
if proxy, ok := p.ipMap.Load(ip.String()); ok {
// check ip
// TODO: ipv4 should equal to ipv4-mapped ipv6? but it'll need to parse the ip address
if proxy, ok := p.ipMap.Load(host); ok {
return proxy.(*FwdrGroup)
}
// check host
host = strings.ToLower(host)
for i := len(host); i != -1; {
i = strings.LastIndexByte(host[:i], '.')
if proxy, ok := p.domainMap.Load(host[i+1:]); ok {
return proxy.(*FwdrGroup)
}
}
// check cidr
if ip, err := netip.ParseAddr(host); err == nil {
var ret *FwdrGroup
// check cidr
p.cidrMap.Range(func(key, value any) bool {
cidr := key.(*net.IPNet)
if cidr.Contains(ip) {
if key.(netip.Prefix).Contains(ip) {
ret = value.(*FwdrGroup)
return false
}
@ -95,15 +108,6 @@ func (p *Proxy) findDialer(dstAddr string) *FwdrGroup {
if ret != nil {
return ret
}
}
host = strings.ToLower(host)
for i := len(host); i != -1; {
i = strings.LastIndexByte(host[:i], '.')
if proxy, ok := p.domainMap.Load(host[i+1:]); ok {
return proxy.(*FwdrGroup)
}
}
return p.main

View File

@ -3,6 +3,7 @@ package dhcpd
import (
"errors"
"net"
"net/netip"
"strconv"
"strings"
"time"
@ -29,7 +30,7 @@ func (*dpcpd) Run(args ...string) {
return
}
iface, startIP, endIP, leaseMin := args[0], args[1], args[2], args[3]
iface, start, end, leaseMin := args[0], args[1], args[2], args[3]
if i, err := strconv.Atoi(leaseMin); err != nil {
leaseTime = time.Duration(i) * time.Minute
}
@ -45,7 +46,19 @@ func (*dpcpd) Run(args ...string) {
return
}
pool, err := NewPool(leaseTime, net.ParseIP(startIP), net.ParseIP(endIP))
startIP, err := netip.ParseAddr(start)
if err != nil {
log.F("[dhcpd] startIP %s is not valid: %s", start, err)
return
}
endIP, err := netip.ParseAddr(end)
if err != nil {
log.F("[dhcpd] endIP %s is not valid: %s", end, err)
return
}
pool, err := NewPool(leaseTime, startIP, endIP)
if err != nil {
log.F("[dhcpd] error in pool init: %s", err)
return
@ -55,12 +68,11 @@ func (*dpcpd) Run(args ...string) {
for _, host := range args[4:] {
pair := strings.Split(host, "=")
if len(pair) == 2 {
mac, err := net.ParseMAC(pair[0])
if err != nil {
break
if mac, err := net.ParseMAC(pair[0]); err == nil {
if ip, err := netip.ParseAddr(pair[1]); err == nil {
pool.LeaseStaticIP(mac, ip)
}
}
ip := net.ParseIP(pair[1])
pool.LeaseStaticIP(mac, ip)
}
}
@ -109,7 +121,7 @@ func handleDHCP(serverIP net.IP, mask net.IPMask, pool *Pool) server4.Handler {
dhcpv4.WithMessageType(replyType),
dhcpv4.WithServerIP(serverIP),
dhcpv4.WithNetmask(mask),
dhcpv4.WithYourIP(replyIP),
dhcpv4.WithYourIP(replyIP.AsSlice()),
dhcpv4.WithRouter(serverIP),
dhcpv4.WithDNS(serverIP),
// RFC 2131, Section 4.3.1. Server Identifier: MUST

View File

@ -5,6 +5,7 @@ import (
"errors"
"math/rand"
"net"
"net/netip"
"sync"
"time"
)
@ -17,18 +18,14 @@ type Pool struct {
}
type item struct {
ip net.IP
ip netip.Addr
mac net.HardwareAddr
expire time.Time
}
// NewPool returns a new dhcp ip pool.
func NewPool(lease time.Duration, start, end net.IP) (*Pool, error) {
if start == nil || end == nil {
return nil, errors.New("start ip or end ip is wrong/nil, please check your config")
}
s, e := ip2num(start.To4()), ip2num(end.To4())
func NewPool(lease time.Duration, start, end netip.Addr) (*Pool, error) {
s, e := ip2num(start), ip2num(end)
if e < s {
return nil, errors.New("start ip larger than end ip")
}
@ -57,7 +54,7 @@ func NewPool(lease time.Duration, start, end net.IP) (*Pool, error) {
}
// LeaseIP leases an ip to mac from dhcp pool.
func (p *Pool) LeaseIP(mac net.HardwareAddr) (net.IP, error) {
func (p *Pool) LeaseIP(mac net.HardwareAddr) (netip.Addr, error) {
p.mutex.Lock()
defer p.mutex.Unlock()
@ -84,16 +81,16 @@ func (p *Pool) LeaseIP(mac net.HardwareAddr) (net.IP, error) {
}
}
return nil, errors.New("no more ip can be leased")
return netip.Addr{}, errors.New("no more ip can be leased")
}
// LeaseStaticIP leases static ip from pool according to the given mac.
func (p *Pool) LeaseStaticIP(mac net.HardwareAddr, ip net.IP) {
func (p *Pool) LeaseStaticIP(mac net.HardwareAddr, ip netip.Addr) {
p.mutex.Lock()
defer p.mutex.Unlock()
for _, item := range p.items {
if item.ip.Equal(ip) {
if item.ip == ip {
item.mac = mac
item.expire = time.Now().Add(time.Hour * 24 * 365 * 50) // 50 years
}
@ -113,11 +110,12 @@ func (p *Pool) ReleaseIP(mac net.HardwareAddr) {
}
}
func ip2num(ip net.IP) uint32 {
func ip2num(addr netip.Addr) uint32 {
ip := addr.As4()
n := uint32(ip[0])<<24 + uint32(ip[1])<<16
return n + uint32(ip[2])<<8 + uint32(ip[3])
}
func num2ip(n uint32) net.IP {
return []byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)}
func num2ip(n uint32) netip.Addr {
return netip.AddrFrom4([4]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)})
}