mirror of
https://github.com/nadoo/glider.git
synced 2025-02-23 09:25:41 +08:00
proxy: improve addr handling
This commit is contained in:
parent
e12642b47a
commit
7e7c7553cc
@ -87,11 +87,8 @@ func parseConfig() *Config {
|
|||||||
os.Exit(-1)
|
os.Exit(-1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// setup a log func
|
// setup logger
|
||||||
if conf.Verbose {
|
log.Set(conf.Verbose, conf.LogFlags)
|
||||||
log.SetFlags(conf.LogFlags)
|
|
||||||
log.F = log.Debugf
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(conf.Listens) == 0 && conf.DNS == "" && len(conf.Services) == 0 {
|
if len(conf.Listens) == 0 && conf.DNS == "" && len(conf.Services) == 0 {
|
||||||
// flag.Usage()
|
// flag.Usage()
|
||||||
|
@ -5,17 +5,19 @@ import (
|
|||||||
stdlog "log"
|
stdlog "log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// F is the main log function.
|
var verbose = false
|
||||||
var F = func(string, ...any) {}
|
|
||||||
|
|
||||||
// SetFlags sets the output flags for the logger.
|
// Set sets the logger's verbose mode and output flags.
|
||||||
func SetFlags(flag int) {
|
func Set(v bool, flag int) {
|
||||||
|
verbose = v
|
||||||
stdlog.SetFlags(flag)
|
stdlog.SetFlags(flag)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Debugf prints debug log.
|
// F prints debug log.
|
||||||
func Debugf(f string, v ...any) {
|
func F(f string, v ...any) {
|
||||||
stdlog.Output(2, fmt.Sprintf(f, v...))
|
if verbose {
|
||||||
|
stdlog.Output(2, fmt.Sprintf(f, v...))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Print prints log.
|
// Print prints log.
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -139,16 +140,16 @@ func ParseAddr(s string) Addr {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
if ip, err := netip.ParseAddr(host); err == nil {
|
||||||
|
if ip.Is4() {
|
||||||
addr = make([]byte, 1+net.IPv4len+2)
|
addr = make([]byte, 1+net.IPv4len+2)
|
||||||
addr[0] = ATypIP4
|
addr[0] = ATypIP4
|
||||||
copy(addr[1:], ip4)
|
|
||||||
} else {
|
} else {
|
||||||
addr = make([]byte, 1+net.IPv6len+2)
|
addr = make([]byte, 1+net.IPv6len+2)
|
||||||
addr[0] = ATypIP6
|
addr[0] = ATypIP6
|
||||||
copy(addr[1:], ip)
|
|
||||||
}
|
}
|
||||||
|
copy(addr[1:], ip.AsSlice())
|
||||||
} else {
|
} else {
|
||||||
if len(host) > 255 {
|
if len(host) > 255 {
|
||||||
return nil
|
return nil
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/nadoo/glider/pkg/log"
|
"github.com/nadoo/glider/pkg/log"
|
||||||
@ -27,8 +28,8 @@ func NewDirect(intface string, dialTimeout, relayTimeout time.Duration) (*Direct
|
|||||||
d := &Direct{dialTimeout: dialTimeout, relayTimeout: relayTimeout}
|
d := &Direct{dialTimeout: dialTimeout, relayTimeout: relayTimeout}
|
||||||
|
|
||||||
if intface != "" {
|
if intface != "" {
|
||||||
if ip := net.ParseIP(intface); ip != nil {
|
if addr, err := netip.ParseAddr(intface); err == nil {
|
||||||
d.ip = ip
|
d.ip = addr.AsSlice()
|
||||||
} else {
|
} else {
|
||||||
iface, err := net.InterfaceByName(intface)
|
iface, err := net.InterfaceByName(intface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -60,7 +60,7 @@ func (s *RedirProxy) ListenAndServe() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.F("[redir] listening TCP on %s", s.addr)
|
log.F("[redir] listening TCP on " + s.addr)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
c, err := l.Accept()
|
c, err := l.Accept()
|
||||||
@ -140,7 +140,9 @@ func getorigdst(fd uintptr) (netip.AddrPort, error) {
|
|||||||
if err := socketcall(GETSOCKOPT, fd, syscall.IPPROTO_IP, _SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&raw)), uintptr(unsafe.Pointer(&siz)), 0); err != nil {
|
if err := socketcall(GETSOCKOPT, fd, syscall.IPPROTO_IP, _SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&raw)), uintptr(unsafe.Pointer(&siz)), 0); err != nil {
|
||||||
return netip.AddrPort{}, err
|
return netip.AddrPort{}, err
|
||||||
}
|
}
|
||||||
port := raw.Port<<8 | raw.Port>>8 // raw.Port is big-endian
|
// NOTE: raw.Port is big-endian, just change it to little-endian
|
||||||
|
// TODO: improve here when we add big-endian $GOARCH support
|
||||||
|
port := raw.Port<<8 | raw.Port>>8
|
||||||
return netip.AddrPortFrom(netip.AddrFrom4(raw.Addr), port), nil
|
return netip.AddrPortFrom(netip.AddrFrom4(raw.Addr), port), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -152,6 +154,8 @@ func getorigdstIPv6(fd uintptr) (netip.AddrPort, error) {
|
|||||||
if err := socketcall(GETSOCKOPT, fd, syscall.IPPROTO_IPV6, _IP6T_SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&raw)), uintptr(unsafe.Pointer(&siz)), 0); err != nil {
|
if err := socketcall(GETSOCKOPT, fd, syscall.IPPROTO_IPV6, _IP6T_SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&raw)), uintptr(unsafe.Pointer(&siz)), 0); err != nil {
|
||||||
return netip.AddrPort{}, err
|
return netip.AddrPort{}, err
|
||||||
}
|
}
|
||||||
port := raw.Port<<8 | raw.Port>>8 // raw.Port is big-endian
|
// NOTE: raw.Port is big-endian, just change it to little-endian
|
||||||
|
// TODO: improve here when we add big-endian $GOARCH support
|
||||||
|
port := raw.Port<<8 | raw.Port>>8
|
||||||
return netip.AddrPortFrom(netip.AddrFrom16(raw.Addr), port), nil
|
return netip.AddrPortFrom(netip.AddrFrom16(raw.Addr), port), nil
|
||||||
}
|
}
|
||||||
|
@ -98,23 +98,11 @@ func (s *Socks5) DialUDP(network, addr string) (pc net.PacketConn, writeTo net.A
|
|||||||
// and commands the server to extend that connection to target,
|
// and commands the server to extend that connection to target,
|
||||||
// which must be a canonical address with a host and port.
|
// which must be a canonical address with a host and port.
|
||||||
func (s *Socks5) connect(conn net.Conn, target string, cmd byte) (addr socks.Addr, err error) {
|
func (s *Socks5) connect(conn net.Conn, target string, cmd byte) (addr socks.Addr, err error) {
|
||||||
host, portStr, err := net.SplitHostPort(target)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
port, err := strconv.Atoi(portStr)
|
|
||||||
if err != nil {
|
|
||||||
return addr, errors.New("proxy: failed to parse port number: " + portStr)
|
|
||||||
}
|
|
||||||
if port < 1 || port > 0xffff {
|
|
||||||
return addr, errors.New("proxy: port number out of range: " + portStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// the size here is just an estimate
|
// the size here is just an estimate
|
||||||
buf := make([]byte, 0, 6+len(host))
|
buf := pool.GetBuffer(socks.MaxAddrLen)
|
||||||
|
defer pool.PutBuffer(buf)
|
||||||
|
|
||||||
buf = append(buf, Version)
|
buf = append(buf[:0], Version)
|
||||||
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
|
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
|
||||||
buf = append(buf, 2 /* num auth methods */, socks.AuthNone, socks.AuthPassword)
|
buf = append(buf, 2 /* num auth methods */, socks.AuthNone, socks.AuthPassword)
|
||||||
} else {
|
} else {
|
||||||
@ -158,24 +146,7 @@ func (s *Socks5) connect(conn net.Conn, target string, cmd byte) (addr socks.Add
|
|||||||
|
|
||||||
buf = buf[:0]
|
buf = buf[:0]
|
||||||
buf = append(buf, Version, cmd, 0 /* reserved */)
|
buf = append(buf, Version, cmd, 0 /* reserved */)
|
||||||
|
buf = append(buf, socks.ParseAddr(target)...)
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
|
||||||
buf = append(buf, socks.ATypIP4)
|
|
||||||
ip = ip4
|
|
||||||
} else {
|
|
||||||
buf = append(buf, socks.ATypIP6)
|
|
||||||
}
|
|
||||||
buf = append(buf, ip...)
|
|
||||||
} else {
|
|
||||||
if len(host) > 255 {
|
|
||||||
return addr, errors.New("proxy: destination hostname too long: " + host)
|
|
||||||
}
|
|
||||||
buf = append(buf, socks.ATypDomain)
|
|
||||||
buf = append(buf, byte(len(host)))
|
|
||||||
buf = append(buf, host...)
|
|
||||||
}
|
|
||||||
buf = append(buf, byte(port>>8), byte(port))
|
|
||||||
|
|
||||||
if _, err := conn.Write(buf); err != nil {
|
if _, err := conn.Write(buf); err != nil {
|
||||||
return addr, errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
return addr, errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
|
||||||
|
@ -184,7 +184,8 @@ func newSession(key string, src net.Addr, srcPC *PktConn) *Session {
|
|||||||
// Handshake fast-tracks SOCKS initialization to get target address to connect.
|
// Handshake fast-tracks SOCKS initialization to get target address to connect.
|
||||||
func (s *Socks5) handshake(c net.Conn) (socks.Addr, error) {
|
func (s *Socks5) handshake(c net.Conn) (socks.Addr, error) {
|
||||||
// Read RFC 1928 for request and reply structure and sizes
|
// Read RFC 1928 for request and reply structure and sizes
|
||||||
buf := make([]byte, socks.MaxAddrLen)
|
buf := pool.GetBuffer(socks.MaxAddrLen)
|
||||||
|
defer pool.PutBuffer(buf)
|
||||||
|
|
||||||
// read VER, NMETHODS, METHODS
|
// read VER, NMETHODS, METHODS
|
||||||
if _, err := io.ReadFull(c, buf[:2]); err != nil {
|
if _, err := io.ReadFull(c, buf[:2]); err != nil {
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/nadoo/glider/pkg/pool"
|
"github.com/nadoo/glider/pkg/pool"
|
||||||
@ -39,16 +40,13 @@ func ParseAddr(s string) (Atyp, Addr, Port, error) {
|
|||||||
return 0, nil, 0, err
|
return 0, nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
if ip, err := netip.ParseAddr(host); err == nil {
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
if ip.Is4() {
|
||||||
addr = make([]byte, net.IPv4len)
|
|
||||||
atyp = AtypIP4
|
atyp = AtypIP4
|
||||||
copy(addr[:], ip4)
|
|
||||||
} else {
|
} else {
|
||||||
addr = make([]byte, net.IPv6len)
|
|
||||||
atyp = AtypIP6
|
atyp = AtypIP6
|
||||||
copy(addr[:], ip)
|
|
||||||
}
|
}
|
||||||
|
addr = ip.AsSlice()
|
||||||
} else {
|
} else {
|
||||||
if len(host) > MaxHostLen {
|
if len(host) > MaxHostLen {
|
||||||
return 0, nil, 0, err
|
return 0, nil, 0, err
|
||||||
|
@ -2,6 +2,7 @@ package vmess
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -19,6 +20,9 @@ const (
|
|||||||
// Addr is vmess addr.
|
// Addr is vmess addr.
|
||||||
type Addr []byte
|
type Addr []byte
|
||||||
|
|
||||||
|
// MaxHostLen is the maximum size of host in bytes.
|
||||||
|
const MaxHostLen = 255
|
||||||
|
|
||||||
// Port is vmess addr port.
|
// Port is vmess addr port.
|
||||||
type Port uint16
|
type Port uint16
|
||||||
|
|
||||||
@ -32,18 +36,15 @@ func ParseAddr(s string) (Atyp, Addr, Port, error) {
|
|||||||
return 0, nil, 0, err
|
return 0, nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
if ip, err := netip.ParseAddr(host); err == nil {
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
if ip.Is4() {
|
||||||
addr = make([]byte, net.IPv4len)
|
|
||||||
atyp = AtypIP4
|
atyp = AtypIP4
|
||||||
copy(addr[:], ip4)
|
|
||||||
} else {
|
} else {
|
||||||
addr = make([]byte, net.IPv6len)
|
|
||||||
atyp = AtypIP6
|
atyp = AtypIP6
|
||||||
copy(addr[:], ip)
|
|
||||||
}
|
}
|
||||||
|
addr = ip.AsSlice()
|
||||||
} else {
|
} else {
|
||||||
if len(host) > 255 {
|
if len(host) > MaxHostLen {
|
||||||
return 0, nil, 0, err
|
return 0, nil, 0, err
|
||||||
}
|
}
|
||||||
addr = make([]byte, 1+len(host))
|
addr = make([]byte, 1+len(host))
|
||||||
|
@ -2,9 +2,11 @@ package rule
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/nadoo/glider/pkg/log"
|
||||||
"github.com/nadoo/glider/proxy"
|
"github.com/nadoo/glider/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -34,9 +36,12 @@ func NewProxy(mainForwarders []string, mainStrategy *Strategy, rules []*Config)
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, s := range r.CIDR {
|
for _, s := range r.CIDR {
|
||||||
if _, cidr, err := net.ParseCIDR(s); err == nil {
|
cidr, err := netip.ParsePrefix(s)
|
||||||
rd.cidrMap.Store(cidr, group)
|
if err != nil {
|
||||||
|
log.F("[rule] parse cidr error: %s", err)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
rd.cidrMap.Store(cidr, group)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -48,7 +53,7 @@ func NewProxy(mainForwarders []string, mainStrategy *Strategy, rules []*Config)
|
|||||||
for _, f := range rd.main.fwdrs {
|
for _, f := range rd.main.fwdrs {
|
||||||
addr := strings.Split(f.addr, ",")[0]
|
addr := strings.Split(f.addr, ",")[0]
|
||||||
host, _, _ := net.SplitHostPort(addr)
|
host, _, _ := net.SplitHostPort(addr)
|
||||||
if ip := net.ParseIP(host); ip == nil {
|
if _, err := netip.ParseAddr(host); err != nil {
|
||||||
rd.domainMap.Store(strings.ToLower(host), direct)
|
rd.domainMap.Store(strings.ToLower(host), direct)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -74,18 +79,26 @@ func (p *Proxy) findDialer(dstAddr string) *FwdrGroup {
|
|||||||
return p.main
|
return p.main
|
||||||
}
|
}
|
||||||
|
|
||||||
// find ip
|
// check ip
|
||||||
if ip := net.ParseIP(host); ip != nil {
|
// TODO: ipv4 should equal to ipv4-mapped ipv6? but it'll need to parse the ip address
|
||||||
// check ip
|
if proxy, ok := p.ipMap.Load(host); ok {
|
||||||
if proxy, ok := p.ipMap.Load(ip.String()); ok {
|
return proxy.(*FwdrGroup)
|
||||||
|
}
|
||||||
|
|
||||||
|
// check host
|
||||||
|
host = strings.ToLower(host)
|
||||||
|
for i := len(host); i != -1; {
|
||||||
|
i = strings.LastIndexByte(host[:i], '.')
|
||||||
|
if proxy, ok := p.domainMap.Load(host[i+1:]); ok {
|
||||||
return proxy.(*FwdrGroup)
|
return proxy.(*FwdrGroup)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check cidr
|
||||||
|
if ip, err := netip.ParseAddr(host); err == nil {
|
||||||
var ret *FwdrGroup
|
var ret *FwdrGroup
|
||||||
// check cidr
|
|
||||||
p.cidrMap.Range(func(key, value any) bool {
|
p.cidrMap.Range(func(key, value any) bool {
|
||||||
cidr := key.(*net.IPNet)
|
if key.(netip.Prefix).Contains(ip) {
|
||||||
if cidr.Contains(ip) {
|
|
||||||
ret = value.(*FwdrGroup)
|
ret = value.(*FwdrGroup)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@ -95,15 +108,6 @@ func (p *Proxy) findDialer(dstAddr string) *FwdrGroup {
|
|||||||
if ret != nil {
|
if ret != nil {
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
host = strings.ToLower(host)
|
|
||||||
for i := len(host); i != -1; {
|
|
||||||
i = strings.LastIndexByte(host[:i], '.')
|
|
||||||
if proxy, ok := p.domainMap.Load(host[i+1:]); ok {
|
|
||||||
return proxy.(*FwdrGroup)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return p.main
|
return p.main
|
||||||
|
@ -3,6 +3,7 @@ package dhcpd
|
|||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@ -29,7 +30,7 @@ func (*dpcpd) Run(args ...string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
iface, startIP, endIP, leaseMin := args[0], args[1], args[2], args[3]
|
iface, start, end, leaseMin := args[0], args[1], args[2], args[3]
|
||||||
if i, err := strconv.Atoi(leaseMin); err != nil {
|
if i, err := strconv.Atoi(leaseMin); err != nil {
|
||||||
leaseTime = time.Duration(i) * time.Minute
|
leaseTime = time.Duration(i) * time.Minute
|
||||||
}
|
}
|
||||||
@ -45,7 +46,19 @@ func (*dpcpd) Run(args ...string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
pool, err := NewPool(leaseTime, net.ParseIP(startIP), net.ParseIP(endIP))
|
startIP, err := netip.ParseAddr(start)
|
||||||
|
if err != nil {
|
||||||
|
log.F("[dhcpd] startIP %s is not valid: %s", start, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
endIP, err := netip.ParseAddr(end)
|
||||||
|
if err != nil {
|
||||||
|
log.F("[dhcpd] endIP %s is not valid: %s", end, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pool, err := NewPool(leaseTime, startIP, endIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.F("[dhcpd] error in pool init: %s", err)
|
log.F("[dhcpd] error in pool init: %s", err)
|
||||||
return
|
return
|
||||||
@ -55,12 +68,11 @@ func (*dpcpd) Run(args ...string) {
|
|||||||
for _, host := range args[4:] {
|
for _, host := range args[4:] {
|
||||||
pair := strings.Split(host, "=")
|
pair := strings.Split(host, "=")
|
||||||
if len(pair) == 2 {
|
if len(pair) == 2 {
|
||||||
mac, err := net.ParseMAC(pair[0])
|
if mac, err := net.ParseMAC(pair[0]); err == nil {
|
||||||
if err != nil {
|
if ip, err := netip.ParseAddr(pair[1]); err == nil {
|
||||||
break
|
pool.LeaseStaticIP(mac, ip)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ip := net.ParseIP(pair[1])
|
|
||||||
pool.LeaseStaticIP(mac, ip)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,7 +121,7 @@ func handleDHCP(serverIP net.IP, mask net.IPMask, pool *Pool) server4.Handler {
|
|||||||
dhcpv4.WithMessageType(replyType),
|
dhcpv4.WithMessageType(replyType),
|
||||||
dhcpv4.WithServerIP(serverIP),
|
dhcpv4.WithServerIP(serverIP),
|
||||||
dhcpv4.WithNetmask(mask),
|
dhcpv4.WithNetmask(mask),
|
||||||
dhcpv4.WithYourIP(replyIP),
|
dhcpv4.WithYourIP(replyIP.AsSlice()),
|
||||||
dhcpv4.WithRouter(serverIP),
|
dhcpv4.WithRouter(serverIP),
|
||||||
dhcpv4.WithDNS(serverIP),
|
dhcpv4.WithDNS(serverIP),
|
||||||
// RFC 2131, Section 4.3.1. Server Identifier: MUST
|
// RFC 2131, Section 4.3.1. Server Identifier: MUST
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -17,18 +18,14 @@ type Pool struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type item struct {
|
type item struct {
|
||||||
ip net.IP
|
ip netip.Addr
|
||||||
mac net.HardwareAddr
|
mac net.HardwareAddr
|
||||||
expire time.Time
|
expire time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPool returns a new dhcp ip pool.
|
// NewPool returns a new dhcp ip pool.
|
||||||
func NewPool(lease time.Duration, start, end net.IP) (*Pool, error) {
|
func NewPool(lease time.Duration, start, end netip.Addr) (*Pool, error) {
|
||||||
if start == nil || end == nil {
|
s, e := ip2num(start), ip2num(end)
|
||||||
return nil, errors.New("start ip or end ip is wrong/nil, please check your config")
|
|
||||||
}
|
|
||||||
|
|
||||||
s, e := ip2num(start.To4()), ip2num(end.To4())
|
|
||||||
if e < s {
|
if e < s {
|
||||||
return nil, errors.New("start ip larger than end ip")
|
return nil, errors.New("start ip larger than end ip")
|
||||||
}
|
}
|
||||||
@ -57,7 +54,7 @@ func NewPool(lease time.Duration, start, end net.IP) (*Pool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// LeaseIP leases an ip to mac from dhcp pool.
|
// LeaseIP leases an ip to mac from dhcp pool.
|
||||||
func (p *Pool) LeaseIP(mac net.HardwareAddr) (net.IP, error) {
|
func (p *Pool) LeaseIP(mac net.HardwareAddr) (netip.Addr, error) {
|
||||||
p.mutex.Lock()
|
p.mutex.Lock()
|
||||||
defer p.mutex.Unlock()
|
defer p.mutex.Unlock()
|
||||||
|
|
||||||
@ -84,16 +81,16 @@ func (p *Pool) LeaseIP(mac net.HardwareAddr) (net.IP, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, errors.New("no more ip can be leased")
|
return netip.Addr{}, errors.New("no more ip can be leased")
|
||||||
}
|
}
|
||||||
|
|
||||||
// LeaseStaticIP leases static ip from pool according to the given mac.
|
// LeaseStaticIP leases static ip from pool according to the given mac.
|
||||||
func (p *Pool) LeaseStaticIP(mac net.HardwareAddr, ip net.IP) {
|
func (p *Pool) LeaseStaticIP(mac net.HardwareAddr, ip netip.Addr) {
|
||||||
p.mutex.Lock()
|
p.mutex.Lock()
|
||||||
defer p.mutex.Unlock()
|
defer p.mutex.Unlock()
|
||||||
|
|
||||||
for _, item := range p.items {
|
for _, item := range p.items {
|
||||||
if item.ip.Equal(ip) {
|
if item.ip == ip {
|
||||||
item.mac = mac
|
item.mac = mac
|
||||||
item.expire = time.Now().Add(time.Hour * 24 * 365 * 50) // 50 years
|
item.expire = time.Now().Add(time.Hour * 24 * 365 * 50) // 50 years
|
||||||
}
|
}
|
||||||
@ -113,11 +110,12 @@ func (p *Pool) ReleaseIP(mac net.HardwareAddr) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ip2num(ip net.IP) uint32 {
|
func ip2num(addr netip.Addr) uint32 {
|
||||||
|
ip := addr.As4()
|
||||||
n := uint32(ip[0])<<24 + uint32(ip[1])<<16
|
n := uint32(ip[0])<<24 + uint32(ip[1])<<16
|
||||||
return n + uint32(ip[2])<<8 + uint32(ip[3])
|
return n + uint32(ip[2])<<8 + uint32(ip[3])
|
||||||
}
|
}
|
||||||
|
|
||||||
func num2ip(n uint32) net.IP {
|
func num2ip(n uint32) netip.Addr {
|
||||||
return []byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)}
|
return netip.AddrFrom4([4]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)})
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user