glider/service/dhcpd/dhcpd.go

213 lines
5.1 KiB
Go

package dhcpd
import (
"bytes"
"errors"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"time"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv4/server4"
"github.com/nadoo/glider/pkg/log"
"github.com/nadoo/glider/service"
)
// leaseTime is the lease duration announced to clients. It starts at twelve
// hours; Run may replace it from its lease-minutes argument.
var leaseTime = 12 * time.Hour
func init() {
service.Register("dhcpd", &dhcpd{})
service.Register("dhcpd-failover", &dhcpd{detect: true})
}
// dhcpd is the service behind the "dhcpd" and "dhcpd-failover" registrations.
type dhcpd struct {
	// detect enables probing for other DHCP servers; it is true for the
	// "dhcpd-failover" registration.
	detect bool

	// intface is the interface the server listens on; assigned by Run.
	intface *net.Interface

	// mu guards failover.
	mu sync.Mutex
	// failover is true while another DHCP server has been detected; the
	// request handler sends no replies in that state.
	failover bool
}
// Run starts the DHCP service and blocks while the server is serving.
//
// Expected arguments:
//
//	args[0]  name of the interface to listen on
//	args[1]  first address of the dynamic pool
//	args[2]  last address of the dynamic pool
//	args[3]  lease time in minutes
//	args[4:] optional static bindings, each written as "MAC=IP"
//
// Setup errors are logged and make Run return without serving.
func (d *dhcpd) Run(args ...string) {
	if len(args) < 4 {
		log.F("[dhcpd] not enough parameters, exiting")
		return
	}

	iface, start, end, leaseMin := args[0], args[1], args[2], args[3]

	// Override the default lease time only with a valid, positive number of
	// minutes; anything else would hand out zero-length or negative leases.
	if i, err := strconv.Atoi(leaseMin); err == nil && i > 0 {
		leaseTime = time.Duration(i) * time.Minute
	} else {
		log.F("[dhcpd] invalid lease time %q, using default %s", leaseMin, leaseTime)
	}

	intf, ip, mask, err := ifaceAddr(iface)
	if err != nil {
		log.F("[dhcpd] get ip of interface '%s' error: %s", iface, err)
		return
	}

	startIP, err := netip.ParseAddr(start)
	if err != nil {
		log.F("[dhcpd] startIP %s is not valid: %s", start, err)
		return
	}

	endIP, err := netip.ParseAddr(end)
	if err != nil {
		log.F("[dhcpd] endIP %s is not valid: %s", end, err)
		return
	}

	pool, err := NewPool(leaseTime, startIP, endIP)
	if err != nil {
		log.F("[dhcpd] error in pool init: %s", err)
		return
	}

	// Static bindings; malformed entries are reported and skipped instead of
	// being dropped silently.
	for _, host := range args[4:] {
		macStr, ipStr, ok := strings.Cut(host, "=")
		if !ok {
			log.F("[dhcpd] ignore static ip entry %q: expected MAC=IP", host)
			continue
		}
		mac, err := net.ParseMAC(macStr)
		if err != nil {
			log.F("[dhcpd] ignore static ip entry %q: %s", host, err)
			continue
		}
		staticIP, err := netip.ParseAddr(ipStr)
		if err != nil {
			log.F("[dhcpd] ignore static ip entry %q: %s", host, err)
			continue
		}
		pool.LeaseStaticIP(mac, staticIP)
	}

	d.intface = intf

	// Start failover detection only after every argument has been validated,
	// so an early return above does not leave the probe goroutine running.
	// NOTE(review): the probe goroutine has no stop mechanism and keeps
	// running even if server creation or Serve fails below.
	if d.detect {
		d.setFailoverMode(discovery(intf))
		go d.detectServer(time.Second * 60)
	}

	laddr := net.UDPAddr{IP: net.IPv4(0, 0, 0, 0), Port: 67}
	server, err := server4.NewServer(iface, &laddr, d.handleDHCP(ip, mask, pool))
	if err != nil {
		log.F("[dhcpd] error in server creation: %s", err)
		return
	}

	log.F("[dhcpd] Listening on interface %s(%s/%d.%d.%d.%d), server detection: %t",
		iface, ip, mask[0], mask[1], mask[2], mask[3], d.detect)

	if err := server.Serve(); err != nil {
		log.F("[dhcpd] server on interface %s stopped: %s", iface, err)
	}
}
// handleDHCP builds the server4.Handler that answers client messages. serverIP
// is advertised as server address, server identifier, router and DNS server,
// mask as the subnet mask, and client addresses come from pool.
//
// DISCOVER is answered with OFFER and REQUEST with ACK, both carrying an
// address leased from pool plus the lease time. INFORM is answered with an
// ACK carrying configuration options only: per RFC 2131 section 4.3.5 the
// client already has an address, so nothing is leased from pool and no lease
// time is sent. RELEASE and DECLINE return the client's address to pool
// without replying. No reply is sent while in failover mode or for requests
// carrying our own interface's hardware address.
func (d *dhcpd) handleDHCP(serverIP net.IP, mask net.IPMask, pool *Pool) server4.Handler {
	return func(conn net.PacketConn, peer net.Addr, m *dhcpv4.DHCPv4) {
		var reqType, replyType dhcpv4.MessageType
		switch reqType = m.MessageType(); reqType {
		case dhcpv4.MessageTypeDiscover:
			replyType = dhcpv4.MessageTypeOffer
		case dhcpv4.MessageTypeRequest, dhcpv4.MessageTypeInform:
			replyType = dhcpv4.MessageTypeAck
		case dhcpv4.MessageTypeRelease:
			pool.ReleaseIP(m.ClientHWAddr)
			log.F("[dhcpd] %v released ip %v", m.ClientHWAddr, m.ClientIPAddr)
			return
		case dhcpv4.MessageTypeDecline:
			pool.ReleaseIP(m.ClientHWAddr)
			log.F("[dhcpd] received decline message from %v", m.ClientHWAddr)
			return
		default:
			log.F("[dhcpd] can't handle type %v", reqType)
			return
		}

		// Stay silent while another DHCP server is active, and never answer
		// requests carrying our own interface's hardware address.
		if d.inFailoverMode() || bytes.Equal(d.intface.HardwareAddr, m.ClientHWAddr) {
			return
		}

		modifiers := []dhcpv4.Modifier{
			dhcpv4.WithMessageType(replyType),
			dhcpv4.WithServerIP(serverIP),
			dhcpv4.WithNetmask(mask),
			dhcpv4.WithRouter(serverIP),
			dhcpv4.WithDNS(serverIP),
			// RFC 2131, Section 4.3.1. Server Identifier: MUST
			dhcpv4.WithOption(dhcpv4.OptServerIdentifier(serverIP)),
		}

		// INFORM clients already own an address: do not take one from the
		// pool, and do not send a lease time (RFC 2131, Section 4.3.5).
		isInform := reqType == dhcpv4.MessageTypeInform
		if !isInform {
			replyIP, err := pool.LeaseIP(m.ClientHWAddr)
			if err != nil {
				log.F("[dhcpd] can not assign IP, error %s", err)
				return
			}
			modifiers = append(modifiers,
				dhcpv4.WithYourIP(replyIP.AsSlice()),
				// RFC 2131, Section 4.3.1. IP lease time: MUST
				dhcpv4.WithOption(dhcpv4.OptIPAddressLeaseTime(leaseTime)),
			)
		}

		reply, err := dhcpv4.NewReplyFromRequest(m, modifiers...)
		if err != nil {
			log.F("[dhcpd] can not create reply message, error %s", err)
			return
		}

		// Copy the request's client identifier option into the reply.
		if val := m.Options.Get(dhcpv4.OptionClientIdentifier); len(val) > 0 {
			reply.UpdateOption(dhcpv4.OptGeneric(dhcpv4.OptionClientIdentifier, val))
		}

		if _, err := conn.WriteTo(reply.ToBytes(), peer); err != nil {
			log.F("[dhcpd] could not write to client %s(%s): %s", peer, reply.ClientHWAddr, err)
			return
		}

		if isInform {
			log.F("[dhcpd] answered inform from client %v", reply.ClientHWAddr)
			return
		}
		log.F("[dhcpd] lease %v to client %v", reply.YourIPAddr, reply.ClientHWAddr)
	}
}
// inFailoverMode reports whether another DHCP server is currently believed to
// be active; while it is, the request handler sends no replies. The flag is
// read under d.mu.
func (d *dhcpd) inFailoverMode() bool {
	d.mu.Lock()
	active := d.failover
	d.mu.Unlock()
	return active
}
// setFailoverMode records whether another DHCP server was detected. The flag
// is written under d.mu, the same lock inFailoverMode reads it under. A log
// line is emitted only when the mode actually changes.
func (d *dhcpd) setFailoverMode(v bool) {
	d.mu.Lock()
	defer d.mu.Unlock()

	if d.failover == v {
		return
	}
	d.failover = v

	if v {
		log.F("[dhcpd] existing dhcp server detected, enter failover mode")
		return
	}
	log.F("[dhcpd] no dhcp server detected, exit failover mode")
}
// detectServer probes d.intface for another DHCP server and then sleeps for
// interval, repeating forever and updating failover mode with every result.
// It never returns; Run starts it in its own goroutine.
func (d *dhcpd) detectServer(interval time.Duration) {
	for ; ; time.Sleep(interval) {
		d.setFailoverMode(discovery(d.intface))
	}
}
// ifaceAddr looks up the interface named iface and returns it together with
// the first IPv4 address (in 4-byte form) and mask configured on it.
//
// It fails when the interface cannot be found, when its addresses cannot be
// listed, when a loopback address is encountered before any IPv4 address, or
// when no IPv4 address is configured. Except for a failed lookup, the
// interface is returned even alongside an error.
func ifaceAddr(iface string) (*net.Interface, net.IP, net.IPMask, error) {
	intf, err := net.InterfaceByName(iface)
	if err != nil {
		return nil, nil, nil, err
	}

	addrs, err := intf.Addrs()
	if err != nil {
		return intf, nil, nil, err
	}

	for _, a := range addrs {
		ipnet, isNet := a.(*net.IPNet)
		if !isNet {
			continue
		}
		switch v4 := ipnet.IP.To4(); {
		case ipnet.IP.IsLoopback():
			return intf, nil, nil, errors.New("can't use loopback interface")
		case v4 != nil:
			return intf, v4, ipnet.Mask, nil
		}
	}

	return intf, nil, nil, errors.New("no ip/mask defined on this interface")
}