diff --git a/config/glider.conf.example b/config/glider.conf.example index a260e7f..2db7b63 100644 --- a/config/glider.conf.example +++ b/config/glider.conf.example @@ -233,6 +233,9 @@ dnsminttl=0 # size of CACHE dnscachesize=4096 +# show query log of dns cache +dnscachelog=True + # custom records dnsrecord=www.example.com/1.2.3.4 dnsrecord=www.example.com/2606:2800:220:1:248:1893:25c8:1946 diff --git a/dns/client.go b/dns/client.go index 2347546..d43b3a0 100644 --- a/dns/client.go +++ b/dns/client.go @@ -109,13 +109,19 @@ func (c *Client) handleAnswer(respBytes []byte, clientAddr, dnsServer, network, } ips, ttl := c.extractAnswer(resp) - if ttl == 0 { // we got a null result - ttl = 600 + if ttl > c.config.MaxTTL { + ttl = c.config.MaxTTL + } else if ttl < c.config.MinTTL { + ttl = c.config.MinTTL + } + + if ttl <= 0 { // we got a null result + ttl = 1800 } c.cache.Set(qKey(resp.Question), valCopy(respBytes), ttl) - log.F("[dns] %s <-> %s(%s) via %s, type: %d, %s: %s", - clientAddr, dnsServer, network, dialerAddr, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ",")) + log.F("[dns] %s <-> %s(%s) via %s, %s/%d: %d %s", + clientAddr, dnsServer, network, dialerAddr, resp.Question.QNAME, resp.Question.QTYPE, ttl, strings.Join(ips, ",")) return nil } @@ -137,12 +143,6 @@ func (c *Client) extractAnswer(resp *Message) ([]string, int) { } } - if ttl > c.config.MaxTTL { - ttl = c.config.MaxTTL - } else if ttl < c.config.MinTTL { - ttl = c.config.MinTTL - } - return ips, ttl } @@ -277,7 +277,7 @@ func (c *Client) AddHandler(h AnswerHandler) { func (c *Client) AddRecord(record string) error { r := strings.Split(record, "/") domain, ip := r[0], r[1] - m, err := c.MakeResponse(domain, ip) + m, err := c.MakeResponse(domain, ip, uint32(c.config.MaxTTL)) if err != nil { return err } @@ -296,10 +296,11 @@ func (c *Client) AddRecord(record string) error { } // MakeResponse makes a dns response message for the given domain and ip address. -func (c *Client) MakeResponse(domain string, ip string) (*Message, error) { +// Note: you should make sure ttl > 0. +func (c *Client) MakeResponse(domain, ip string, ttl uint32) (*Message, error) { ipb := net.ParseIP(ip) if ipb == nil { - return nil, errors.New("GenResponse: invalid ip format") + return nil, errors.New("MakeResponse: invalid ip format") } var rdata []byte @@ -316,7 +317,7 @@ func (c *Client) MakeResponse(domain string, ip string) (*Message, error) { m := NewMessage(0, Response) m.SetQuestion(NewQuestion(qtype, domain)) rr := &RR{NAME: domain, TYPE: qtype, CLASS: ClassINET, - TTL: uint32(c.config.MinTTL), RDLENGTH: rdlen, RDATA: rdata} + TTL: ttl, RDLENGTH: rdlen, RDATA: rdata} m.AddAnswer(rr) return m, nil diff --git a/go.mod b/go.mod index d030756..9fb5b0b 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/xtaci/kcp-go/v5 v5.6.1 golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a golang.org/x/net v0.0.0-20210525063256-abc453219eb5 // indirect + golang.org/x/sys v0.0.0-20210603125802-9665404d3644 // indirect ) // Replace dependency modules with local developing copy diff --git a/go.sum b/go.sum index df9b7ba..b90c019 100644 --- a/go.sum +++ b/go.sum @@ -154,8 +154,11 @@ golang.org/x/sys v0.0.0-20201009025420-dfb3f7c4e634/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201101102859-da207088b7d1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea h1:+WiDlPBBaO+h9vPNZi8uJ3k4BkKQB7Iow3aqwHVA5hI= golang.org/x/sys v0.0.0-20210525143221-35b2ab0089ea/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210601080250-7ecdf8ef093b h1:qh4f65QIVFjq9eBURLEYWqaEXmOyqdUyiBSgaXWccWk= +golang.org/x/sys v0.0.0-20210601080250-7ecdf8ef093b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210603125802-9665404d3644 h1:CA1DEQ4NdKphKeL70tvsWNdT5oFh1lOjihRcEDROi0I= +golang.org/x/sys v0.0.0-20210603125802-9665404d3644/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/service/dhcpd/dhcpd.go b/service/dhcpd/dhcpd.go index e583a1c..e889c11 100644 --- a/service/dhcpd/dhcpd.go +++ b/service/dhcpd/dhcpd.go @@ -67,14 +67,22 @@ func handleDHCP(serverIP net.IP, mask net.IPMask, pool *Pool) server4.Handler { switch mt := m.MessageType(); mt { case dhcpv4.MessageTypeDiscover: replyType = dhcpv4.MessageTypeOffer - case dhcpv4.MessageTypeRequest: + case dhcpv4.MessageTypeRequest, dhcpv4.MessageTypeInform: replyType = dhcpv4.MessageTypeAck + case dhcpv4.MessageTypeRelease: + pool.ReleaseIP(m.ClientHWAddr) + log.F("[dpcpd] %v released ip %v", m.ClientHWAddr, m.ClientIPAddr) + return + case dhcpv4.MessageTypeDecline: + pool.ReleaseIP(m.ClientHWAddr) + log.F("[dpcpd] received decline message from %v", m.ClientHWAddr) + return default: log.F("[dpcpd] can't handle type %v", mt) return } - replyIp, err := pool.AssignIP(m.ClientHWAddr) + replyIp, err := pool.LeaseIP(m.ClientHWAddr) if err != nil { log.F("[dpcpd] can not assign IP, error %s", err) return diff --git a/service/dhcpd/pool.go b/service/dhcpd/pool.go index c06313a..6b483d7 100644 --- a/service/dhcpd/pool.go +++ b/service/dhcpd/pool.go @@ -5,12 +5,21 @@ import ( "errors" "math/rand" "net" + "sync" "time" ) // Pool is a dhcp pool. type Pool struct { items []*item + mutex sync.RWMutex + lease time.Duration +} + +type item struct { + ip net.IP + mac net.HardwareAddr + expire time.Time } // NewPool returns a new dhcp ip pool. @@ -26,15 +35,32 @@ func NewPool(lease time.Duration, start, end net.IP) (*Pool, error) { items := make([]*item, 0, e-s+1) for n := s; n <= e; n++ { - items = append(items, &item{lease: lease, ip: num2ip(n)}) + items = append(items, &item{ip: num2ip(n)}) } rand.Seed(time.Now().Unix()) - return &Pool{items: items}, nil + + p := &Pool{items: items, lease: lease} + go func() { + for now := range time.Tick(time.Second) { + p.mutex.Lock() + for i := 0; i < len(items); i++ { + if !items[i].expire.IsZero() && now.After(items[i].expire) { + items[i].mac = nil + items[i].expire = time.Time{} + } + } + p.mutex.Unlock() + } + }() + + return p, nil } -// AssignIP assigns an ip to mac from dhco pool. -func (p *Pool) AssignIP(mac net.HardwareAddr) (net.IP, error) { - var ip net.IP +// LeaseIP leases an ip to mac from dhcp pool. +func (p *Pool) LeaseIP(mac net.HardwareAddr) (net.IP, error) { + p.mutex.Lock() + defer p.mutex.Unlock() + for _, item := range p.items { if bytes.Equal(mac, item.mac) { return item.ip, nil @@ -43,39 +69,35 @@ func (p *Pool) AssignIP(mac net.HardwareAddr) (net.IP, error) { idx := rand.Intn(len(p.items)) for _, item := range p.items[idx:] { - if ip = item.take(mac); ip != nil { - return ip, nil + if item.mac == nil { + item.mac = mac + item.expire = time.Now().Add(p.lease) + return item.ip, nil } } for _, item := range p.items { - if ip = item.take(mac); ip != nil { - return ip, nil + if item.mac == nil { + item.mac = mac + item.expire = time.Now().Add(p.lease) + return item.ip, nil } } - return nil, errors.New("no more ip can be assigned") + + return nil, errors.New("no more ip can be leased") } -type item struct { - taken bool - ip net.IP - lease time.Duration - mac net.HardwareAddr -} +// ReleaseIP releases ip from pool according to the given mac. +func (p *Pool) ReleaseIP(mac net.HardwareAddr) { + p.mutex.Lock() + defer p.mutex.Unlock() -func (i *item) take(addr net.HardwareAddr) net.IP { - if !i.taken { - i.taken = true - go func() { - timer := time.NewTimer(i.lease) - <-timer.C - i.mac = nil - i.taken = false - }() - i.mac = addr - return i.ip + for _, item := range p.items { + if bytes.Equal(mac, item.mac) { + item.mac = nil + item.expire = time.Time{} + } } - return nil } func ip2num(ip net.IP) uint32 {