From 6eda2b79c8342eb3393879b310b78458fc5d2c29 Mon Sep 17 00:00:00 2001 From: nadoo <287492+nadoo@users.noreply.github.com> Date: Fri, 9 Oct 2020 22:02:19 +0800 Subject: [PATCH] dns: update cache when an item expired --- config.go | 2 +- dns/cache.go | 37 +++++++++--------- dns/client.go | 67 +++++++++++++++++--------------- dns/server.go | 26 +++++-------- go.mod | 6 +-- go.sum | 12 +++--- main.go | 2 +- proxy/redir/redir_linux.go | 4 +- proxy/redir/redir_linux_386.go | 1 + proxy/redir/redir_linux_other.go | 1 + 10 files changed, 80 insertions(+), 78 deletions(-) diff --git a/config.go b/config.go index 021420a..e10e968 100644 --- a/config.go +++ b/config.go @@ -63,7 +63,7 @@ func parseConfig() *Config { flag.IntVar(&conf.DNSConfig.Timeout, "dnstimeout", 3, "timeout value used in multiple dnsservers switch(seconds)") flag.IntVar(&conf.DNSConfig.MaxTTL, "dnsmaxttl", 1800, "maximum TTL value for entries in the CACHE(seconds)") flag.IntVar(&conf.DNSConfig.MinTTL, "dnsminttl", 0, "minimum TTL value for entries in the CACHE(seconds)") - flag.IntVar(&conf.DNSConfig.CacheSize, "dnscachesize", 1024, "size of CACHE") + flag.IntVar(&conf.DNSConfig.CacheSize, "dnscachesize", 4096, "size of CACHE") flag.StringSliceUniqVar(&conf.DNSConfig.Records, "dnsrecord", nil, "custom dns record, format: domain/ip") // service configs diff --git a/dns/cache.go b/dns/cache.go index 4771b43..9fca040 100644 --- a/dns/cache.go +++ b/dns/cache.go @@ -9,31 +9,31 @@ import ( type LruCache struct { mu sync.Mutex size int - head *Item - tail *Item - cache map[string]*Item + head *item + tail *item + cache map[string]*item store map[string][]byte } -// Item is the struct of cache item. -type Item struct { +// item is the struct of cache item. +type item struct { key string val []byte exp int64 - prev *Item - next *Item + prev *item + next *item } -// NewCache returns a new LruCache. +// NewLruCache returns a new LruCache. func NewLruCache(size int) *LruCache { - // init 2 items here, it doesn't matter cuz they will be deleted when the cache if full - head, tail := &Item{key: "head"}, &Item{key: "tail"} + // init 2 items here, it doesn't matter cuz they will be deleted when the cache is full + head, tail := &item{key: "head"}, &item{key: "tail"} head.next, tail.prev = tail, head c := &LruCache{ size: size, head: head, tail: tail, - cache: make(map[string]*Item, size), + cache: make(map[string]*item, size), store: make(map[string][]byte), } c.cache[head.key], c.cache[tail.key] = head, tail @@ -82,6 +82,9 @@ func (c *LruCache) Set(k string, v []byte, ttl int) { } c.putToHead(k, v, exp) + + // NOTE: the cache size will always be c.size + 2, + // but it doesn't matter in our environment. if len(c.cache) > c.size { c.removeTail() } @@ -89,17 +92,17 @@ func (c *LruCache) Set(k string, v []byte, ttl int) { // putToHead puts a new item to cache's head. func (c *LruCache) putToHead(k string, v []byte, exp int64) { - it := &Item{key: k, val: v, exp: exp, prev: nil, next: c.head} - c.cache[k] = it - + it := &item{key: k, val: v, exp: exp, prev: nil, next: c.head} it.prev = nil it.next = c.head c.head.prev = it c.head = it + + c.cache[k] = it } // moveToHead moves an existing item to cache's head. -func (c *LruCache) moveToHead(it *Item) { +func (c *LruCache) moveToHead(it *item) { if it != c.head { if c.tail == it { c.tail = it.prev @@ -119,8 +122,6 @@ func (c *LruCache) moveToHead(it *Item) { func (c *LruCache) removeTail() { delete(c.cache, c.tail.key) - if c.tail.prev != nil { - c.tail.prev.next = nil - } + c.tail.prev.next = nil c.tail = c.tail.prev } diff --git a/dns/client.go b/dns/client.go index 757fc5f..616c3b4 100644 --- a/dns/client.go +++ b/dns/client.go @@ -5,6 +5,7 @@ import ( "errors" "io" "net" + "strconv" "strings" "time" @@ -56,21 +57,29 @@ func NewClient(proxy proxy.Proxy, config *Config) (*Client, error) { } // Exchange handles request message and returns response message. -// NOTE: reqBytes = reqLen + reqMsg. +// TODO: optimize it func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([]byte, error) { - req, err := UnmarshalMessage(reqBytes[2:]) + req, err := UnmarshalMessage(reqBytes) if err != nil { return nil, err } if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA { - v, _ := c.cache.Get(qKey(req.Question)) - if len(v) > 4 { + if v, expired := c.cache.Get(qKey(req.Question)); len(v) > 2 { v = valCopy(v) - binary.BigEndian.PutUint16(v[2:4], req.ID) + binary.BigEndian.PutUint16(v[:2], req.ID) + log.F("[dns] %s <-> cache, type: %d, %s", clientAddr, req.Question.QTYPE, req.Question.QNAME) + if expired { // update cache + go func(qname string, reqBytes []byte, preferTCP bool) { + defer pool.PutBuffer(reqBytes) + if dnsServer, network, dialerAddr, respBytes, err := c.exchange(qname, reqBytes, preferTCP); err == nil { + c.handleAnswer(respBytes, clientAddr, dnsServer, network, dialerAddr) + } + }(req.Question.QNAME, valCopy(reqBytes), preferTCP) + } return v, nil } } @@ -86,14 +95,17 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([ return respBytes, nil } - resp, err := UnmarshalMessage(respBytes[2:]) + err = c.handleAnswer(respBytes, clientAddr, dnsServer, network, dialerAddr) + return respBytes, err +} + +func (c *Client) handleAnswer(respBytes []byte, clientAddr, dnsServer, network, dialerAddr string) error { + resp, err := UnmarshalMessage(respBytes) if err != nil { - return respBytes, err + return err } ips, ttl := c.extractAnswer(resp) - - // add to cache only when there's a valid ip address if len(ips) != 0 && ttl > 0 { c.cache.Set(qKey(resp.Question), valCopy(respBytes), ttl) } @@ -101,7 +113,7 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([ log.F("[dns] %s <-> %s(%s) via %s, type: %d, %s: %s", clientAddr, dnsServer, network, dialerAddr, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ",")) - return respBytes, nil + return nil } func (c *Client) extractAnswer(resp *Message) ([]string, int) { @@ -197,7 +209,12 @@ func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) ( // exchangeTCP exchange with server over tcp. func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) { - if _, err := rc.Write(reqBytes); err != nil { + reqLen := pool.GetBuffer(2) + defer pool.PutBuffer(reqLen) + + binary.BigEndian.PutUint16(reqLen, uint16(len(reqBytes))) + + if _, err := (&net.Buffers{reqLen, reqBytes}).WriteTo(rc); err != nil { return nil, err } @@ -206,10 +223,8 @@ func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) { return nil, err } - respBytes := pool.GetBuffer(int(respLen) + 2) - binary.BigEndian.PutUint16(respBytes[:2], respLen) - - _, err := io.ReadFull(rc, respBytes[2:]) + respBytes := pool.GetBuffer(int(respLen)) + _, err := io.ReadFull(rc, respBytes) if err != nil { return nil, err } @@ -219,18 +234,17 @@ func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) { // exchangeUDP exchange with server over udp. func (c *Client) exchangeUDP(rc net.Conn, reqBytes []byte) ([]byte, error) { - if _, err := rc.Write(reqBytes[2:]); err != nil { + if _, err := rc.Write(reqBytes); err != nil { return nil, err } respBytes := pool.GetBuffer(UDPMaxLen) - n, err := rc.Read(respBytes[2:]) + n, err := rc.Read(respBytes) if err != nil { return nil, err } - binary.BigEndian.PutUint16(respBytes[:2], uint16(n)) - return respBytes[:2+n], nil + return respBytes[:n], nil } // SetServers sets upstream dns servers for the given domain. @@ -268,15 +282,12 @@ func (c *Client) AddRecord(record string) error { wb := pool.GetWriteBuffer() defer pool.PutWriteBuffer(wb) - wb.Write([]byte{0, 0}) - - n, err := m.MarshalTo(wb) + _, err = m.MarshalTo(wb) if err != nil { return err } - binary.BigEndian.PutUint16(wb.Bytes()[:2], uint16(n)) - c.cache.Set(qKey(m.Question), wb.Bytes(), 0) + c.cache.Set(qKey(m.Question), valCopy(wb.Bytes()), 0) return nil } @@ -309,13 +320,7 @@ func (c *Client) MakeResponse(domain string, ip string) (*Message, error) { } func qKey(q *Question) string { - switch q.QTYPE { - case QTypeA: - return q.QNAME + "/4" - case QTypeAAAA: - return q.QNAME + "/6" - } - return q.QNAME + return q.QNAME + "/" + strconv.FormatUint(uint64(q.QTYPE), 10) } func valCopy(v []byte) (b []byte) { diff --git a/dns/server.go b/dns/server.go index c44e63b..3cb6fe6 100644 --- a/dns/server.go +++ b/dns/server.go @@ -62,22 +62,14 @@ func (s *Server) ListenAndServeUDP(wg *sync.WaitGroup) { for { reqBytes := pool.GetBuffer(UDPMaxLen) - n, caddr, err := pc.ReadFrom(reqBytes[2:]) + n, caddr, err := pc.ReadFrom(reqBytes) if err != nil { log.F("[dns] local read error: %v", err) pool.PutBuffer(reqBytes) continue } - reqLen := uint16(n) - if reqLen <= HeaderLen+2 { - log.F("[dns] not enough message data") - pool.PutBuffer(reqBytes) - continue - } - binary.BigEndian.PutUint16(reqBytes[:2], reqLen) - - go s.ServePacket(pc, caddr, reqBytes[:2+n]) + go s.ServePacket(pc, caddr, reqBytes[:n]) } } @@ -94,7 +86,7 @@ func (s *Server) ServePacket(pc net.PacketConn, caddr net.Addr, reqBytes []byte) return } - _, err = pc.WriteTo(respBytes[2:], caddr) + _, err = pc.WriteTo(respBytes, caddr) if err != nil { log.F("[dns] error in local write: %s", err) return @@ -135,17 +127,15 @@ func (s *Server) ServeTCP(c net.Conn) { return } - reqBytes := pool.GetBuffer(int(reqLen) + 2) + reqBytes := pool.GetBuffer(int(reqLen)) defer pool.PutBuffer(reqBytes) - _, err := io.ReadFull(c, reqBytes[2:]) + _, err := io.ReadFull(c, reqBytes) if err != nil { log.F("[dns-tcp] error in read reqBytes %s", err) return } - binary.BigEndian.PutUint16(reqBytes[:2], reqLen) - respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String(), true) defer pool.PutBuffer(respBytes) if err != nil { @@ -153,7 +143,11 @@ func (s *Server) ServeTCP(c net.Conn) { return } - if _, err := c.Write(respBytes); err != nil { + respLen := pool.GetBuffer(2) + defer pool.PutBuffer(respLen) + binary.BigEndian.PutUint16(respLen, uint16(len(respBytes))) + + if _, err := (&net.Buffers{respLen, respBytes}).WriteTo(c); err != nil { log.F("[dns-tcp] error in write respBytes: %s", err) return } diff --git a/go.mod b/go.mod index 04139d4..dced3ac 100644 --- a/go.mod +++ b/go.mod @@ -11,9 +11,9 @@ require ( github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/xtaci/kcp-go/v5 v5.5.17 golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0 - golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0 // indirect - golang.org/x/sys v0.0.0-20201008064518-c1f3e3309c71 // indirect - golang.org/x/tools v0.0.0-20201008025239-9df69603baec // indirect + golang.org/x/net v0.0.0-20201009032441-dbdefad45b89 // indirect + golang.org/x/sys v0.0.0-20201009025420-dfb3f7c4e634 // indirect + golang.org/x/tools v0.0.0-20201009032223-96877f285f7e // indirect gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b // indirect ) diff --git a/go.sum b/go.sum index c687699..684dc79 100644 --- a/go.sum +++ b/go.sum @@ -141,8 +141,8 @@ golang.org/x/net v0.0.0-20200707034311-ab3426394381 h1:VXak5I6aEWmAXeQjA+QSZzlgN golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0 h1:wBouT66WTYFXdxfVdz9sVWARVd/2vfGcmI45D2gj45M= -golang.org/x/net v0.0.0-20201006153459-a7d1128ccaa0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201009032441-dbdefad45b89 h1:1GKfLldebiSdhTlt3nalwrb7L40Tixr/0IH+kSbRgmk= +golang.org/x/net v0.0.0-20201009032441-dbdefad45b89/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -167,8 +167,8 @@ golang.org/x/sys v0.0.0-20200808120158-1030fc2bf1d9/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200916030750-2334cc1a136f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201008064518-c1f3e3309c71 h1:ZPX6UakxrJCxWiyGWpXtFY+fp86Esy7xJT/jJCG8bgU= -golang.org/x/sys v0.0.0-20201008064518-c1f3e3309c71/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201009025420-dfb3f7c4e634 h1:bNEHhJCnrwMKNMmOx3yAynp5vs5/gRy+XWFtZFu7NBM= +golang.org/x/sys v0.0.0-20201009025420-dfb3f7c4e634/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -177,8 +177,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200425043458-8463f397d07c h1:iHhCR0b26amDCiiO+kBguKZom9aMF+NrFxh9zeKR/XU= golang.org/x/tools v0.0.0-20200425043458-8463f397d07c/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200808161706-5bf02b21f123/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20201008025239-9df69603baec h1:RY2OghEV/7X1MLaecgm1mwFd3sGvUddm5pGVSxQvX0c= -golang.org/x/tools v0.0.0-20201008025239-9df69603baec/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= +golang.org/x/tools v0.0.0-20201009032223-96877f285f7e h1:G1acLyqfyttmexrW7XPhzsaS8m6s+P9XsW9djwh10s4= +golang.org/x/tools v0.0.0-20201009032223-96877f285f7e/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= diff --git a/main.go b/main.go index c2dcbc4..9619197 100644 --- a/main.go +++ b/main.go @@ -18,7 +18,7 @@ import ( ) var ( - version = "0.11.2" + version = "0.12.0" config = parseConfig() ) diff --git a/proxy/redir/redir_linux.go b/proxy/redir/redir_linux.go index d929806..acf365b 100644 --- a/proxy/redir/redir_linux.go +++ b/proxy/redir/redir_linux.go @@ -125,7 +125,7 @@ func getOrigDst(c *net.TCPConn, ipv6 bool) (*net.TCPAddr, error) { var addr *net.TCPAddr rc.Control(func(fd uintptr) { if ipv6 { - addr, err = ipv6_getorigdst(fd) + addr, err = getorigdstIPv6(fd) } else { addr, err = getorigdst(fd) } @@ -150,7 +150,7 @@ func getorigdst(fd uintptr) (*net.TCPAddr, error) { // Call ipv6_getorigdst() from linux/net/ipv6/netfilter/nf_conntrack_l3proto_ipv6.c // NOTE: I haven't tried yet but it should work since Linux 3.8. -func ipv6_getorigdst(fd uintptr) (*net.TCPAddr, error) { +func getorigdstIPv6(fd uintptr) (*net.TCPAddr, error) { const _IP6T_SO_ORIGINAL_DST = 80 // from linux/include/uapi/linux/netfilter_ipv6/ip6_tables.h var raw syscall.RawSockaddrInet6 siz := unsafe.Sizeof(raw) diff --git a/proxy/redir/redir_linux_386.go b/proxy/redir/redir_linux_386.go index 32f692d..6d50aa9 100644 --- a/proxy/redir/redir_linux_386.go +++ b/proxy/redir/redir_linux_386.go @@ -5,6 +5,7 @@ import ( "unsafe" ) +// https://github.com/golang/go/blob/9e6b79a5dfb2f6fe4301ced956419a0da83bd025/src/syscall/syscall_linux_386.go#L196 const GETSOCKOPT = 15 // https://golang.org/src/syscall/syscall_linux_386.go#L183 func socketcall(call, a0, a1, a2, a3, a4, a5 uintptr) error { diff --git a/proxy/redir/redir_linux_other.go b/proxy/redir/redir_linux_other.go index 9547282..a2677b9 100644 --- a/proxy/redir/redir_linux_other.go +++ b/proxy/redir/redir_linux_other.go @@ -4,6 +4,7 @@ package redir import "syscall" +// GETSOCKOPT from syscall const GETSOCKOPT = syscall.SYS_GETSOCKOPT func socketcall(call, a0, a1, a2, a3, a4, a5 uintptr) error {