dns: update cache when an item expired

This commit is contained in:
nadoo 2020-10-09 22:02:19 +08:00
parent 6d64ee4c0e
commit 6eda2b79c8
10 changed files with 80 additions and 78 deletions

View File

@ -63,7 +63,7 @@ func parseConfig() *Config {
flag.IntVar(&conf.DNSConfig.Timeout, "dnstimeout", 3, "timeout value used in multiple dnsservers switch(seconds)")
flag.IntVar(&conf.DNSConfig.MaxTTL, "dnsmaxttl", 1800, "maximum TTL value for entries in the CACHE(seconds)")
flag.IntVar(&conf.DNSConfig.MinTTL, "dnsminttl", 0, "minimum TTL value for entries in the CACHE(seconds)")
flag.IntVar(&conf.DNSConfig.CacheSize, "dnscachesize", 1024, "size of CACHE")
flag.IntVar(&conf.DNSConfig.CacheSize, "dnscachesize", 4096, "size of CACHE")
flag.StringSliceUniqVar(&conf.DNSConfig.Records, "dnsrecord", nil, "custom dns record, format: domain/ip")
// service configs

View File

@ -9,31 +9,31 @@ import (
type LruCache struct {
mu sync.Mutex
size int
head *Item
tail *Item
cache map[string]*Item
head *item
tail *item
cache map[string]*item
store map[string][]byte
}
// Item is the struct of cache item.
type Item struct {
// item is the struct of cache item.
type item struct {
key string
val []byte
exp int64
prev *Item
next *Item
prev *item
next *item
}
// NewCache returns a new LruCache.
// NewLruCache returns a new LruCache.
func NewLruCache(size int) *LruCache {
// init 2 items here, it doesn't matter cuz they will be deleted when the cache if full
head, tail := &Item{key: "head"}, &Item{key: "tail"}
// init 2 items here, it doesn't matter cuz they will be deleted when the cache is full
head, tail := &item{key: "head"}, &item{key: "tail"}
head.next, tail.prev = tail, head
c := &LruCache{
size: size,
head: head,
tail: tail,
cache: make(map[string]*Item, size),
cache: make(map[string]*item, size),
store: make(map[string][]byte),
}
c.cache[head.key], c.cache[tail.key] = head, tail
@ -82,6 +82,9 @@ func (c *LruCache) Set(k string, v []byte, ttl int) {
}
c.putToHead(k, v, exp)
// NOTE: the cache size will always be c.size + 2,
// but it doesn't matter in our environment.
if len(c.cache) > c.size {
c.removeTail()
}
@ -89,17 +92,17 @@ func (c *LruCache) Set(k string, v []byte, ttl int) {
// putToHead puts a new item to cache's head.
func (c *LruCache) putToHead(k string, v []byte, exp int64) {
it := &Item{key: k, val: v, exp: exp, prev: nil, next: c.head}
c.cache[k] = it
it := &item{key: k, val: v, exp: exp, prev: nil, next: c.head}
it.prev = nil
it.next = c.head
c.head.prev = it
c.head = it
c.cache[k] = it
}
// moveToHead moves an existing item to cache's head.
func (c *LruCache) moveToHead(it *Item) {
func (c *LruCache) moveToHead(it *item) {
if it != c.head {
if c.tail == it {
c.tail = it.prev
@ -119,8 +122,6 @@ func (c *LruCache) moveToHead(it *Item) {
func (c *LruCache) removeTail() {
delete(c.cache, c.tail.key)
if c.tail.prev != nil {
c.tail.prev.next = nil
}
c.tail.prev.next = nil
c.tail = c.tail.prev
}

View File

@ -5,6 +5,7 @@ import (
"errors"
"io"
"net"
"strconv"
"strings"
"time"
@ -56,21 +57,29 @@ func NewClient(proxy proxy.Proxy, config *Config) (*Client, error) {
}
// Exchange handles request message and returns response message.
// NOTE: reqBytes = reqLen + reqMsg.
// TODO: optimize it
func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([]byte, error) {
req, err := UnmarshalMessage(reqBytes[2:])
req, err := UnmarshalMessage(reqBytes)
if err != nil {
return nil, err
}
if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA {
v, _ := c.cache.Get(qKey(req.Question))
if len(v) > 4 {
if v, expired := c.cache.Get(qKey(req.Question)); len(v) > 2 {
v = valCopy(v)
binary.BigEndian.PutUint16(v[2:4], req.ID)
binary.BigEndian.PutUint16(v[:2], req.ID)
log.F("[dns] %s <-> cache, type: %d, %s",
clientAddr, req.Question.QTYPE, req.Question.QNAME)
if expired { // update cache
go func(qname string, reqBytes []byte, preferTCP bool) {
defer pool.PutBuffer(reqBytes)
if dnsServer, network, dialerAddr, respBytes, err := c.exchange(qname, reqBytes, preferTCP); err == nil {
c.handleAnswer(respBytes, clientAddr, dnsServer, network, dialerAddr)
}
}(req.Question.QNAME, valCopy(reqBytes), preferTCP)
}
return v, nil
}
}
@ -86,14 +95,17 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([
return respBytes, nil
}
resp, err := UnmarshalMessage(respBytes[2:])
err = c.handleAnswer(respBytes, clientAddr, dnsServer, network, dialerAddr)
return respBytes, err
}
func (c *Client) handleAnswer(respBytes []byte, clientAddr, dnsServer, network, dialerAddr string) error {
resp, err := UnmarshalMessage(respBytes)
if err != nil {
return respBytes, err
return err
}
ips, ttl := c.extractAnswer(resp)
// add to cache only when there's a valid ip address
if len(ips) != 0 && ttl > 0 {
c.cache.Set(qKey(resp.Question), valCopy(respBytes), ttl)
}
@ -101,7 +113,7 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([
log.F("[dns] %s <-> %s(%s) via %s, type: %d, %s: %s",
clientAddr, dnsServer, network, dialerAddr, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
return respBytes, nil
return nil
}
func (c *Client) extractAnswer(resp *Message) ([]string, int) {
@ -197,7 +209,12 @@ func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (
// exchangeTCP exchange with server over tcp.
func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) {
if _, err := rc.Write(reqBytes); err != nil {
reqLen := pool.GetBuffer(2)
defer pool.PutBuffer(reqLen)
binary.BigEndian.PutUint16(reqLen, uint16(len(reqBytes)))
if _, err := (&net.Buffers{reqLen, reqBytes}).WriteTo(rc); err != nil {
return nil, err
}
@ -206,10 +223,8 @@ func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) {
return nil, err
}
respBytes := pool.GetBuffer(int(respLen) + 2)
binary.BigEndian.PutUint16(respBytes[:2], respLen)
_, err := io.ReadFull(rc, respBytes[2:])
respBytes := pool.GetBuffer(int(respLen))
_, err := io.ReadFull(rc, respBytes)
if err != nil {
return nil, err
}
@ -219,18 +234,17 @@ func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) {
// exchangeUDP exchange with server over udp.
func (c *Client) exchangeUDP(rc net.Conn, reqBytes []byte) ([]byte, error) {
if _, err := rc.Write(reqBytes[2:]); err != nil {
if _, err := rc.Write(reqBytes); err != nil {
return nil, err
}
respBytes := pool.GetBuffer(UDPMaxLen)
n, err := rc.Read(respBytes[2:])
n, err := rc.Read(respBytes)
if err != nil {
return nil, err
}
binary.BigEndian.PutUint16(respBytes[:2], uint16(n))
return respBytes[:2+n], nil
return respBytes[:n], nil
}
// SetServers sets upstream dns servers for the given domain.
@ -268,15 +282,12 @@ func (c *Client) AddRecord(record string) error {
wb := pool.GetWriteBuffer()
defer pool.PutWriteBuffer(wb)
wb.Write([]byte{0, 0})
n, err := m.MarshalTo(wb)
_, err = m.MarshalTo(wb)
if err != nil {
return err
}
binary.BigEndian.PutUint16(wb.Bytes()[:2], uint16(n))
c.cache.Set(qKey(m.Question), wb.Bytes(), 0)
c.cache.Set(qKey(m.Question), valCopy(wb.Bytes()), 0)
return nil
}
@ -309,13 +320,7 @@ func (c *Client) MakeResponse(domain string, ip string) (*Message, error) {
}
func qKey(q *Question) string {
switch q.QTYPE {
case QTypeA:
return q.QNAME + "/4"
case QTypeAAAA:
return q.QNAME + "/6"
}
return q.QNAME
return q.QNAME + "/" + strconv.FormatUint(uint64(q.QTYPE), 10)
}
func valCopy(v []byte) (b []byte) {

View File

@ -62,22 +62,14 @@ func (s *Server) ListenAndServeUDP(wg *sync.WaitGroup) {
for {
reqBytes := pool.GetBuffer(UDPMaxLen)
n, caddr, err := pc.ReadFrom(reqBytes[2:])
n, caddr, err := pc.ReadFrom(reqBytes)
if err != nil {
log.F("[dns] local read error: %v", err)
pool.PutBuffer(reqBytes)
continue
}
reqLen := uint16(n)
if reqLen <= HeaderLen+2 {
log.F("[dns] not enough message data")
pool.PutBuffer(reqBytes)
continue
}
binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
go s.ServePacket(pc, caddr, reqBytes[:2+n])
go s.ServePacket(pc, caddr, reqBytes[:n])
}
}
@ -94,7 +86,7 @@ func (s *Server) ServePacket(pc net.PacketConn, caddr net.Addr, reqBytes []byte)
return
}
_, err = pc.WriteTo(respBytes[2:], caddr)
_, err = pc.WriteTo(respBytes, caddr)
if err != nil {
log.F("[dns] error in local write: %s", err)
return
@ -135,17 +127,15 @@ func (s *Server) ServeTCP(c net.Conn) {
return
}
reqBytes := pool.GetBuffer(int(reqLen) + 2)
reqBytes := pool.GetBuffer(int(reqLen))
defer pool.PutBuffer(reqBytes)
_, err := io.ReadFull(c, reqBytes[2:])
_, err := io.ReadFull(c, reqBytes)
if err != nil {
log.F("[dns-tcp] error in read reqBytes %s", err)
return
}
binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String(), true)
defer pool.PutBuffer(respBytes)
if err != nil {
@ -153,7 +143,11 @@ func (s *Server) ServeTCP(c net.Conn) {
return
}
if _, err := c.Write(respBytes); err != nil {
respLen := pool.GetBuffer(2)
defer pool.PutBuffer(respLen)
binary.BigEndian.PutUint16(respLen, uint16(len(respBytes)))
if _, err := (&net.Buffers{respLen, respBytes}).WriteTo(c); err != nil {
log.F("[dns-tcp] error in write respBytes: %s", err)
return
}

6
go.mod
View File

@ -11,9 +11,9 @@ require (
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
github.com/xtaci/kcp-go/v5 v5.5.17
golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0
golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0 // indirect
golang.org/x/sys v0.0.0-20201008064518-c1f3e3309c71 // indirect
golang.org/x/tools v0.0.0-20201008025239-9df69603baec // indirect
golang.org/x/net v0.0.0-20201009032441-dbdefad45b89 // indirect
golang.org/x/sys v0.0.0-20201009025420-dfb3f7c4e634 // indirect
golang.org/x/tools v0.0.0-20201009032223-96877f285f7e // indirect
gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b // indirect
)

12
go.sum
View File

@ -141,8 +141,8 @@ golang.org/x/net v0.0.0-20200707034311-ab3426394381 h1:VXak5I6aEWmAXeQjA+QSZzlgN
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0 h1:wBouT66WTYFXdxfVdz9sVWARVd/2vfGcmI45D2gj45M=
golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201009032441-dbdefad45b89 h1:1GKfLldebiSdhTlt3nalwrb7L40Tixr/0IH+kSbRgmk=
golang.org/x/net v0.0.0-20201009032441-dbdefad45b89/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -167,8 +167,8 @@ golang.org/x/sys v0.0.0-20200808120158-1030fc2bf1d9/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200916030750-2334cc1a136f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201008064518-c1f3e3309c71 h1:ZPX6UakxrJCxWiyGWpXtFY+fp86Esy7xJT/jJCG8bgU=
golang.org/x/sys v0.0.0-20201008064518-c1f3e3309c71/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201009025420-dfb3f7c4e634 h1:bNEHhJCnrwMKNMmOx3yAynp5vs5/gRy+XWFtZFu7NBM=
golang.org/x/sys v0.0.0-20201009025420-dfb3f7c4e634/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@ -177,8 +177,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20200425043458-8463f397d07c h1:iHhCR0b26amDCiiO+kBguKZom9aMF+NrFxh9zeKR/XU=
golang.org/x/tools v0.0.0-20200425043458-8463f397d07c/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200808161706-5bf02b21f123/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20201008025239-9df69603baec h1:RY2OghEV/7X1MLaecgm1mwFd3sGvUddm5pGVSxQvX0c=
golang.org/x/tools v0.0.0-20201008025239-9df69603baec/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
golang.org/x/tools v0.0.0-20201009032223-96877f285f7e h1:G1acLyqfyttmexrW7XPhzsaS8m6s+P9XsW9djwh10s4=
golang.org/x/tools v0.0.0-20201009032223-96877f285f7e/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=

View File

@ -18,7 +18,7 @@ import (
)
var (
version = "0.11.2"
version = "0.12.0"
config = parseConfig()
)

View File

@ -125,7 +125,7 @@ func getOrigDst(c *net.TCPConn, ipv6 bool) (*net.TCPAddr, error) {
var addr *net.TCPAddr
rc.Control(func(fd uintptr) {
if ipv6 {
addr, err = ipv6_getorigdst(fd)
addr, err = getorigdstIPv6(fd)
} else {
addr, err = getorigdst(fd)
}
@ -150,7 +150,7 @@ func getorigdst(fd uintptr) (*net.TCPAddr, error) {
// Call ipv6_getorigdst() from linux/net/ipv6/netfilter/nf_conntrack_l3proto_ipv6.c
// NOTE: I haven't tried yet but it should work since Linux 3.8.
func ipv6_getorigdst(fd uintptr) (*net.TCPAddr, error) {
func getorigdstIPv6(fd uintptr) (*net.TCPAddr, error) {
const _IP6T_SO_ORIGINAL_DST = 80 // from linux/include/uapi/linux/netfilter_ipv6/ip6_tables.h
var raw syscall.RawSockaddrInet6
siz := unsafe.Sizeof(raw)

View File

@ -5,6 +5,7 @@ import (
"unsafe"
)
// https://github.com/golang/go/blob/9e6b79a5dfb2f6fe4301ced956419a0da83bd025/src/syscall/syscall_linux_386.go#L196
const GETSOCKOPT = 15 // https://golang.org/src/syscall/syscall_linux_386.go#L183
func socketcall(call, a0, a1, a2, a3, a4, a5 uintptr) error {

View File

@ -4,6 +4,7 @@ package redir
import "syscall"
// GETSOCKOPT from syscall
const GETSOCKOPT = syscall.SYS_GETSOCKOPT
func socketcall(call, a0, a1, a2, a3, a4, a5 uintptr) error {