dns: optimize codes

This commit is contained in:
nadoo 2020-08-23 23:23:30 +08:00
parent a42d3a68d0
commit f65a983da8
7 changed files with 179 additions and 82 deletions

View File

@ -7,7 +7,6 @@ import (
const (
// number of pools.
// pool sizes: [1<<0 ~ 1<<(num-1)] bytes, [1B~64KB].
num = 17
maxsize = 1 << (num - 1)
)
@ -42,9 +41,10 @@ func GetBuffer(size int) []byte {
// PutBuffer puts a buffer into pool.
func PutBuffer(buf []byte) {
size := cap(buf)
i := bits.Len32(uint32(size)) - 1
if i < num && sizes[i] == size {
pools[i].Put(buf)
if size := cap(buf); size >= 1 && size <= maxsize {
i := bits.Len32(uint32(size)) - 1
if sizes[i] == size {
pools[i].Put(buf)
}
}
}

View File

@ -3,6 +3,8 @@ package dns
import (
"sync"
"time"
"github.com/nadoo/glider/common/pool"
)
// LongTTL is 50 years duration in seconds, used for none-expired items.
@ -15,22 +17,26 @@ type item struct {
// Cache is the struct of cache.
type Cache struct {
m map[string]*item
l sync.RWMutex
store map[string]*item
mutex sync.RWMutex
storeCopy bool
}
// NewCache returns a new cache.
func NewCache() (c *Cache) {
c = &Cache{m: make(map[string]*item)}
func NewCache(storeCopy bool) (c *Cache) {
c = &Cache{store: make(map[string]*item), storeCopy: storeCopy}
go func() {
for now := range time.Tick(time.Second) {
c.l.Lock()
for k, v := range c.m {
c.mutex.Lock()
for k, v := range c.store {
if now.After(v.expire) {
delete(c.m, k)
delete(c.store, k)
if storeCopy {
pool.PutBuffer(v.value)
}
}
}
c.l.Unlock()
c.mutex.Unlock()
}
}()
return
@ -38,29 +44,46 @@ func NewCache() (c *Cache) {
// Len returns the length of cache.
func (c *Cache) Len() int {
return len(c.m)
return len(c.store)
}
// Put an item into cache, invalid after ttl seconds.
func (c *Cache) Put(k string, v []byte, ttl int) {
if len(v) != 0 {
c.l.Lock()
it, ok := c.m[k]
c.mutex.Lock()
it, ok := c.store[k]
if !ok {
it = &item{value: v}
c.m[k] = it
if c.storeCopy {
it = &item{value: valCopy(v)}
} else {
it = &item{value: v}
}
c.store[k] = it
}
it.expire = time.Now().Add(time.Duration(ttl) * time.Second)
c.l.Unlock()
c.mutex.Unlock()
}
}
// Get an item from cache.
// Get gets an item from cache(do not modify it).
func (c *Cache) Get(k string) (v []byte) {
c.l.RLock()
if it, ok := c.m[k]; ok {
c.mutex.RLock()
if it, ok := c.store[k]; ok {
v = it.value
}
c.l.RUnlock()
c.mutex.RUnlock()
return
}
// GetCopy gets an item from cache and returns it's copy(so you can modify it).
func (c *Cache) GetCopy(k string) []byte {
return valCopy(c.Get(k))
}
func valCopy(v []byte) (b []byte) {
if v != nil {
b = pool.GetBuffer(len(v))
copy(b, v)
}
return
}

View File

@ -1,15 +1,16 @@
package dns
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"strings"
"time"
"github.com/nadoo/glider/common/log"
"github.com/nadoo/glider/common/pool"
"github.com/nadoo/glider/proxy"
)
@ -40,7 +41,7 @@ type Client struct {
func NewClient(proxy proxy.Proxy, config *Config) (*Client, error) {
c := &Client{
proxy: proxy,
cache: NewCache(),
cache: NewCache(true),
config: config,
upStream: NewUPStream(config.Servers),
upStreamMap: make(map[string]*UPStream),
@ -63,8 +64,8 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([
}
if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA {
v := c.cache.Get(getKey(req.Question))
if v != nil {
v := c.cache.GetCopy(qKey(req.Question))
if len(v) > 4 {
binary.BigEndian.PutUint16(v[2:4], req.ID)
log.F("[dns] %s <-> cache, type: %d, %s",
clientAddr, req.Question.QTYPE, req.Question.QNAME)
@ -93,7 +94,7 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([
// add to cache only when there's a valid ip address
if len(ips) != 0 && ttl > 0 {
c.cache.Put(getKey(resp.Question), respBytes, ttl)
c.cache.Put(qKey(resp.Question), respBytes, ttl)
}
log.F("[dns] %s <-> %s(%s) via %s, type: %d, %s: %s",
@ -204,7 +205,7 @@ func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) {
return nil, err
}
respBytes := make([]byte, respLen+2)
respBytes := pool.GetBuffer(int(respLen) + 2)
binary.BigEndian.PutUint16(respBytes[:2], respLen)
_, err := io.ReadFull(rc, respBytes[2:])
@ -221,7 +222,7 @@ func (c *Client) exchangeUDP(rc net.Conn, reqBytes []byte) ([]byte, error) {
return nil, err
}
respBytes := make([]byte, UDPMaxLen)
respBytes := pool.GetBuffer(UDPMaxLen)
n, err := rc.Read(respBytes[2:])
if err != nil {
return nil, err
@ -258,24 +259,29 @@ func (c *Client) AddHandler(h HandleFunc) {
func (c *Client) AddRecord(record string) error {
r := strings.Split(record, "/")
domain, ip := r[0], r[1]
m, err := c.GenResponse(domain, ip)
m, err := c.MakeResponse(domain, ip)
if err != nil {
return err
}
b, _ := m.Marshal()
wb := pool.GetWriteBuffer()
defer pool.PutWriteBuffer(wb)
var buf bytes.Buffer
binary.Write(&buf, binary.BigEndian, uint16(len(b)))
buf.Write(b)
wb.Write([]byte{0, 0})
c.cache.Put(getKey(m.Question), buf.Bytes(), LongTTL)
n, err := m.MarshalTo(wb)
if err != nil {
return err
}
binary.BigEndian.PutUint16(wb.Bytes()[:2], uint16(n))
c.cache.Put(qKey(m.Question), wb.Bytes(), LongTTL)
return nil
}
// GenResponse generates a dns response message for the given domain and ip address.
func (c *Client) GenResponse(domain string, ip string) (*Message, error) {
// MakeResponse makes a dns response message for the given domain and ip address.
func (c *Client) MakeResponse(domain string, ip string) (*Message, error) {
ipb := net.ParseIP(ip)
if ipb == nil {
return nil, errors.New("GenResponse: invalid ip format")
@ -301,13 +307,6 @@ func (c *Client) GenResponse(domain string, ip string) (*Message, error) {
return m, nil
}
func getKey(q *Question) string {
var qtype string
switch q.QTYPE {
case QTypeA:
qtype = "A"
case QTypeAAAA:
qtype = "AAAA"
}
return q.QNAME + "/" + qtype
func qKey(q *Question) string {
return fmt.Sprintf("%s/%d", q.QNAME, q.QTYPE)
}

View File

@ -91,19 +91,40 @@ func (m *Message) AddAnswer(rr *RR) error {
// Marshal marshals message struct to []byte.
func (m *Message) Marshal() ([]byte, error) {
buf := &bytes.Buffer{}
_, err := m.MarshalTo(buf)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// MarshalTo marshals message struct to []byte and write to w.
func (m *Message) MarshalTo(w io.Writer) (n int, err error) {
m.Header.SetQdcount(1)
m.Header.SetAncount(len(m.Answers))
// no error when write to bytes.Buffer
m.Header.MarshalTo(buf)
m.Question.MarshalTo(buf)
nn := 0
nn, err = m.Header.MarshalTo(w)
if err != nil {
return
}
n += nn
nn, err = m.Question.MarshalTo(w)
if err != nil {
return
}
n += nn
for _, answer := range m.Answers {
answer.MarshalTo(buf)
nn, err = answer.MarshalTo(w)
if err != nil {
return
}
n += nn
}
return buf.Bytes(), nil
return
}
// UnmarshalMessage unmarshals []bytes to Message.
@ -255,14 +276,25 @@ func NewQuestion(qtype uint16, domain string) *Question {
}
// MarshalTo marshals Question struct to []byte and write to w.
func (q *Question) MarshalTo(w io.Writer) (int, error) {
n, _ := MarshalDomainTo(w, q.QNAME)
func (q *Question) MarshalTo(w io.Writer) (n int, err error) {
n, err = MarshalDomainTo(w, q.QNAME)
if err != nil {
return
}
binary.Write(w, binary.BigEndian, q.QTYPE)
binary.Write(w, binary.BigEndian, q.QCLASS)
n += 4
err = binary.Write(w, binary.BigEndian, q.QTYPE)
if err != nil {
return
}
n += 2
return n, nil
err = binary.Write(w, binary.BigEndian, q.QCLASS)
if err != nil {
return
}
n += 2
return
}
// UnmarshalQuestion unmarshals []bytes to Question.
@ -332,19 +364,43 @@ func NewRR() *RR {
}
// MarshalTo marshals RR struct to []byte and write to w.
func (rr *RR) MarshalTo(w io.Writer) (int, error) {
n, _ := MarshalDomainTo(w, rr.NAME)
func (rr *RR) MarshalTo(w io.Writer) (n int, err error) {
n, err = MarshalDomainTo(w, rr.NAME)
if err != nil {
return
}
binary.Write(w, binary.BigEndian, rr.TYPE)
binary.Write(w, binary.BigEndian, rr.CLASS)
binary.Write(w, binary.BigEndian, rr.TTL)
binary.Write(w, binary.BigEndian, rr.RDLENGTH)
n += 10
err = binary.Write(w, binary.BigEndian, rr.TYPE)
if err != nil {
return
}
n += 2
w.Write(rr.RDATA)
err = binary.Write(w, binary.BigEndian, rr.CLASS)
if err != nil {
return
}
n += 2
err = binary.Write(w, binary.BigEndian, rr.TTL)
if err != nil {
return
}
n += 4
err = binary.Write(w, binary.BigEndian, rr.RDLENGTH)
if err != nil {
return
}
n += 2
_, err = w.Write(rr.RDATA)
if err != nil {
return
}
n += len(rr.RDATA)
return n, nil
return
}
// UnmarshalRR unmarshals []bytes to RR.
@ -388,16 +444,29 @@ func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) {
}
// MarshalDomainTo marshals domain string struct to []byte and write to w.
func MarshalDomainTo(w io.Writer, domain string) (int, error) {
n := 1
func MarshalDomainTo(w io.Writer, domain string) (n int, err error) {
nn := 0
for _, seg := range strings.Split(domain, ".") {
w.Write([]byte{byte(len(seg))})
io.WriteString(w, seg)
n += 1 + len(seg)
}
w.Write([]byte{0x00})
nn, err = w.Write([]byte{byte(len(seg))})
if err != nil {
return
}
n += nn
return n, nil
nn, err = io.WriteString(w, seg)
if err != nil {
return
}
n += nn
}
nn, err = w.Write([]byte{0x00})
if err != nil {
return
}
n += nn
return
}
// UnmarshalDomain gets domain from bytes.

View File

@ -65,19 +65,24 @@ func (s *Server) ListenAndServeUDP(wg *sync.WaitGroup) {
n, caddr, err := c.ReadFrom(reqBytes[2:])
if err != nil {
log.F("[dns] local read error: %v", err)
pool.PutBuffer(reqBytes)
continue
}
reqLen := uint16(n)
if reqLen <= HeaderLen+2 {
log.F("[dns] not enough message data")
pool.PutBuffer(reqBytes)
continue
}
binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
go func() {
respBytes, err := s.Exchange(reqBytes[:2+n], caddr.String(), false)
defer pool.PutBuffer(reqBytes)
defer func() {
pool.PutBuffer(reqBytes)
pool.PutBuffer(respBytes)
}()
if err != nil {
log.F("[dns] error in exchange: %s", err)
@ -141,6 +146,7 @@ func (s *Server) ServeTCP(c net.Conn) {
binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String(), true)
defer pool.PutBuffer(respBytes)
if err != nil {
log.F("[dns-tcp] error in exchange: %s", err)
return

4
go.mod
View File

@ -9,9 +9,9 @@ require (
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
github.com/xtaci/kcp-go/v5 v5.5.15
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a
golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc // indirect
golang.org/x/net v0.0.0-20200822124328-c89045814202 // indirect
golang.org/x/sys v0.0.0-20200821140526-fda516888d29 // indirect
golang.org/x/tools v0.0.0-20200821144610-c886c0b611b7 // indirect
golang.org/x/tools v0.0.0-20200822203824-307de81be3f4 // indirect
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
)

8
go.sum
View File

@ -101,8 +101,8 @@ golang.org/x/net v0.0.0-20200505041828-1ed23360d12c/go.mod h1:qpuaurCH72eLCgpAm/
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200707034311-ab3426394381 h1:VXak5I6aEWmAXeQjA+QSZzlgNrpq9mjcfDemuexIKsU=
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc h1:zK/HqS5bZxDptfPJNq8v7vJfXtkU7r9TLIoSr1bXaP4=
golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -124,8 +124,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20200425043458-8463f397d07c h1:iHhCR0b26amDCiiO+kBguKZom9aMF+NrFxh9zeKR/XU=
golang.org/x/tools v0.0.0-20200425043458-8463f397d07c/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200808161706-5bf02b21f123/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200821144610-c886c0b611b7 h1:E83yjTcMGvlL0ixGmFgJr/jvcp8L2LPDg7K0MQONeGA=
golang.org/x/tools v0.0.0-20200821144610-c886c0b611b7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200822203824-307de81be3f4 h1:r0nbB2EeRbGpnVeqxlkgiBpNi/bednpSg78qzZGOuv0=
golang.org/x/tools v0.0.0-20200822203824-307de81be3f4/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=