dns: optimize codes

This commit is contained in:
nadoo 2020-08-23 23:23:30 +08:00
parent a42d3a68d0
commit f65a983da8
7 changed files with 179 additions and 82 deletions

View File

@ -7,7 +7,6 @@ import (
const ( const (
// number of pools. // number of pools.
// pool sizes: [1<<0 ~ 1<<(num-1)] bytes, [1B~64KB].
num = 17 num = 17
maxsize = 1 << (num - 1) maxsize = 1 << (num - 1)
) )
@ -42,9 +41,10 @@ func GetBuffer(size int) []byte {
// PutBuffer puts a buffer into pool. // PutBuffer puts a buffer into pool.
func PutBuffer(buf []byte) { func PutBuffer(buf []byte) {
size := cap(buf) if size := cap(buf); size >= 1 && size <= maxsize {
i := bits.Len32(uint32(size)) - 1 i := bits.Len32(uint32(size)) - 1
if i < num && sizes[i] == size { if sizes[i] == size {
pools[i].Put(buf) pools[i].Put(buf)
}
} }
} }

View File

@ -3,6 +3,8 @@ package dns
import ( import (
"sync" "sync"
"time" "time"
"github.com/nadoo/glider/common/pool"
) )
// LongTTL is 50 years duration in seconds, used for none-expired items. // LongTTL is 50 years duration in seconds, used for none-expired items.
@ -15,22 +17,26 @@ type item struct {
// Cache is the struct of cache. // Cache is the struct of cache.
type Cache struct { type Cache struct {
m map[string]*item store map[string]*item
l sync.RWMutex mutex sync.RWMutex
storeCopy bool
} }
// NewCache returns a new cache. // NewCache returns a new cache.
func NewCache() (c *Cache) { func NewCache(storeCopy bool) (c *Cache) {
c = &Cache{m: make(map[string]*item)} c = &Cache{store: make(map[string]*item), storeCopy: storeCopy}
go func() { go func() {
for now := range time.Tick(time.Second) { for now := range time.Tick(time.Second) {
c.l.Lock() c.mutex.Lock()
for k, v := range c.m { for k, v := range c.store {
if now.After(v.expire) { if now.After(v.expire) {
delete(c.m, k) delete(c.store, k)
if storeCopy {
pool.PutBuffer(v.value)
}
} }
} }
c.l.Unlock() c.mutex.Unlock()
} }
}() }()
return return
@ -38,29 +44,46 @@ func NewCache() (c *Cache) {
// Len returns the length of cache. // Len returns the length of cache.
func (c *Cache) Len() int { func (c *Cache) Len() int {
return len(c.m) return len(c.store)
} }
// Put an item into cache, invalid after ttl seconds. // Put an item into cache, invalid after ttl seconds.
func (c *Cache) Put(k string, v []byte, ttl int) { func (c *Cache) Put(k string, v []byte, ttl int) {
if len(v) != 0 { if len(v) != 0 {
c.l.Lock() c.mutex.Lock()
it, ok := c.m[k] it, ok := c.store[k]
if !ok { if !ok {
it = &item{value: v} if c.storeCopy {
c.m[k] = it it = &item{value: valCopy(v)}
} else {
it = &item{value: v}
}
c.store[k] = it
} }
it.expire = time.Now().Add(time.Duration(ttl) * time.Second) it.expire = time.Now().Add(time.Duration(ttl) * time.Second)
c.l.Unlock() c.mutex.Unlock()
} }
} }
// Get an item from cache. // Get gets an item from cache(do not modify it).
func (c *Cache) Get(k string) (v []byte) { func (c *Cache) Get(k string) (v []byte) {
c.l.RLock() c.mutex.RLock()
if it, ok := c.m[k]; ok { if it, ok := c.store[k]; ok {
v = it.value v = it.value
} }
c.l.RUnlock() c.mutex.RUnlock()
return
}
// GetCopy gets an item from cache and returns it's copy(so you can modify it).
func (c *Cache) GetCopy(k string) []byte {
return valCopy(c.Get(k))
}
func valCopy(v []byte) (b []byte) {
if v != nil {
b = pool.GetBuffer(len(v))
copy(b, v)
}
return return
} }

View File

@ -1,15 +1,16 @@
package dns package dns
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"io" "io"
"net" "net"
"strings" "strings"
"time" "time"
"github.com/nadoo/glider/common/log" "github.com/nadoo/glider/common/log"
"github.com/nadoo/glider/common/pool"
"github.com/nadoo/glider/proxy" "github.com/nadoo/glider/proxy"
) )
@ -40,7 +41,7 @@ type Client struct {
func NewClient(proxy proxy.Proxy, config *Config) (*Client, error) { func NewClient(proxy proxy.Proxy, config *Config) (*Client, error) {
c := &Client{ c := &Client{
proxy: proxy, proxy: proxy,
cache: NewCache(), cache: NewCache(true),
config: config, config: config,
upStream: NewUPStream(config.Servers), upStream: NewUPStream(config.Servers),
upStreamMap: make(map[string]*UPStream), upStreamMap: make(map[string]*UPStream),
@ -63,8 +64,8 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([
} }
if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA { if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA {
v := c.cache.Get(getKey(req.Question)) v := c.cache.GetCopy(qKey(req.Question))
if v != nil { if len(v) > 4 {
binary.BigEndian.PutUint16(v[2:4], req.ID) binary.BigEndian.PutUint16(v[2:4], req.ID)
log.F("[dns] %s <-> cache, type: %d, %s", log.F("[dns] %s <-> cache, type: %d, %s",
clientAddr, req.Question.QTYPE, req.Question.QNAME) clientAddr, req.Question.QTYPE, req.Question.QNAME)
@ -93,7 +94,7 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([
// add to cache only when there's a valid ip address // add to cache only when there's a valid ip address
if len(ips) != 0 && ttl > 0 { if len(ips) != 0 && ttl > 0 {
c.cache.Put(getKey(resp.Question), respBytes, ttl) c.cache.Put(qKey(resp.Question), respBytes, ttl)
} }
log.F("[dns] %s <-> %s(%s) via %s, type: %d, %s: %s", log.F("[dns] %s <-> %s(%s) via %s, type: %d, %s: %s",
@ -204,7 +205,7 @@ func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) {
return nil, err return nil, err
} }
respBytes := make([]byte, respLen+2) respBytes := pool.GetBuffer(int(respLen) + 2)
binary.BigEndian.PutUint16(respBytes[:2], respLen) binary.BigEndian.PutUint16(respBytes[:2], respLen)
_, err := io.ReadFull(rc, respBytes[2:]) _, err := io.ReadFull(rc, respBytes[2:])
@ -221,7 +222,7 @@ func (c *Client) exchangeUDP(rc net.Conn, reqBytes []byte) ([]byte, error) {
return nil, err return nil, err
} }
respBytes := make([]byte, UDPMaxLen) respBytes := pool.GetBuffer(UDPMaxLen)
n, err := rc.Read(respBytes[2:]) n, err := rc.Read(respBytes[2:])
if err != nil { if err != nil {
return nil, err return nil, err
@ -258,24 +259,29 @@ func (c *Client) AddHandler(h HandleFunc) {
func (c *Client) AddRecord(record string) error { func (c *Client) AddRecord(record string) error {
r := strings.Split(record, "/") r := strings.Split(record, "/")
domain, ip := r[0], r[1] domain, ip := r[0], r[1]
m, err := c.GenResponse(domain, ip) m, err := c.MakeResponse(domain, ip)
if err != nil { if err != nil {
return err return err
} }
b, _ := m.Marshal() wb := pool.GetWriteBuffer()
defer pool.PutWriteBuffer(wb)
var buf bytes.Buffer wb.Write([]byte{0, 0})
binary.Write(&buf, binary.BigEndian, uint16(len(b)))
buf.Write(b)
c.cache.Put(getKey(m.Question), buf.Bytes(), LongTTL) n, err := m.MarshalTo(wb)
if err != nil {
return err
}
binary.BigEndian.PutUint16(wb.Bytes()[:2], uint16(n))
c.cache.Put(qKey(m.Question), wb.Bytes(), LongTTL)
return nil return nil
} }
// GenResponse generates a dns response message for the given domain and ip address. // MakeResponse makes a dns response message for the given domain and ip address.
func (c *Client) GenResponse(domain string, ip string) (*Message, error) { func (c *Client) MakeResponse(domain string, ip string) (*Message, error) {
ipb := net.ParseIP(ip) ipb := net.ParseIP(ip)
if ipb == nil { if ipb == nil {
return nil, errors.New("GenResponse: invalid ip format") return nil, errors.New("GenResponse: invalid ip format")
@ -301,13 +307,6 @@ func (c *Client) GenResponse(domain string, ip string) (*Message, error) {
return m, nil return m, nil
} }
func getKey(q *Question) string { func qKey(q *Question) string {
var qtype string return fmt.Sprintf("%s/%d", q.QNAME, q.QTYPE)
switch q.QTYPE {
case QTypeA:
qtype = "A"
case QTypeAAAA:
qtype = "AAAA"
}
return q.QNAME + "/" + qtype
} }

View File

@ -91,19 +91,40 @@ func (m *Message) AddAnswer(rr *RR) error {
// Marshal marshals message struct to []byte. // Marshal marshals message struct to []byte.
func (m *Message) Marshal() ([]byte, error) { func (m *Message) Marshal() ([]byte, error) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
_, err := m.MarshalTo(buf)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// MarshalTo marshals message struct to []byte and write to w.
func (m *Message) MarshalTo(w io.Writer) (n int, err error) {
m.Header.SetQdcount(1) m.Header.SetQdcount(1)
m.Header.SetAncount(len(m.Answers)) m.Header.SetAncount(len(m.Answers))
// no error when write to bytes.Buffer nn := 0
m.Header.MarshalTo(buf) nn, err = m.Header.MarshalTo(w)
m.Question.MarshalTo(buf) if err != nil {
return
}
n += nn
nn, err = m.Question.MarshalTo(w)
if err != nil {
return
}
n += nn
for _, answer := range m.Answers { for _, answer := range m.Answers {
answer.MarshalTo(buf) nn, err = answer.MarshalTo(w)
if err != nil {
return
}
n += nn
} }
return buf.Bytes(), nil return
} }
// UnmarshalMessage unmarshals []bytes to Message. // UnmarshalMessage unmarshals []bytes to Message.
@ -255,14 +276,25 @@ func NewQuestion(qtype uint16, domain string) *Question {
} }
// MarshalTo marshals Question struct to []byte and write to w. // MarshalTo marshals Question struct to []byte and write to w.
func (q *Question) MarshalTo(w io.Writer) (int, error) { func (q *Question) MarshalTo(w io.Writer) (n int, err error) {
n, _ := MarshalDomainTo(w, q.QNAME) n, err = MarshalDomainTo(w, q.QNAME)
if err != nil {
return
}
binary.Write(w, binary.BigEndian, q.QTYPE) err = binary.Write(w, binary.BigEndian, q.QTYPE)
binary.Write(w, binary.BigEndian, q.QCLASS) if err != nil {
n += 4 return
}
n += 2
return n, nil err = binary.Write(w, binary.BigEndian, q.QCLASS)
if err != nil {
return
}
n += 2
return
} }
// UnmarshalQuestion unmarshals []bytes to Question. // UnmarshalQuestion unmarshals []bytes to Question.
@ -332,19 +364,43 @@ func NewRR() *RR {
} }
// MarshalTo marshals RR struct to []byte and write to w. // MarshalTo marshals RR struct to []byte and write to w.
func (rr *RR) MarshalTo(w io.Writer) (int, error) { func (rr *RR) MarshalTo(w io.Writer) (n int, err error) {
n, _ := MarshalDomainTo(w, rr.NAME) n, err = MarshalDomainTo(w, rr.NAME)
if err != nil {
return
}
binary.Write(w, binary.BigEndian, rr.TYPE) err = binary.Write(w, binary.BigEndian, rr.TYPE)
binary.Write(w, binary.BigEndian, rr.CLASS) if err != nil {
binary.Write(w, binary.BigEndian, rr.TTL) return
binary.Write(w, binary.BigEndian, rr.RDLENGTH) }
n += 10 n += 2
w.Write(rr.RDATA) err = binary.Write(w, binary.BigEndian, rr.CLASS)
if err != nil {
return
}
n += 2
err = binary.Write(w, binary.BigEndian, rr.TTL)
if err != nil {
return
}
n += 4
err = binary.Write(w, binary.BigEndian, rr.RDLENGTH)
if err != nil {
return
}
n += 2
_, err = w.Write(rr.RDATA)
if err != nil {
return
}
n += len(rr.RDATA) n += len(rr.RDATA)
return n, nil return
} }
// UnmarshalRR unmarshals []bytes to RR. // UnmarshalRR unmarshals []bytes to RR.
@ -388,16 +444,29 @@ func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) {
} }
// MarshalDomainTo marshals domain string struct to []byte and write to w. // MarshalDomainTo marshals domain string struct to []byte and write to w.
func MarshalDomainTo(w io.Writer, domain string) (int, error) { func MarshalDomainTo(w io.Writer, domain string) (n int, err error) {
n := 1 nn := 0
for _, seg := range strings.Split(domain, ".") { for _, seg := range strings.Split(domain, ".") {
w.Write([]byte{byte(len(seg))}) nn, err = w.Write([]byte{byte(len(seg))})
io.WriteString(w, seg) if err != nil {
n += 1 + len(seg) return
} }
w.Write([]byte{0x00}) n += nn
return n, nil nn, err = io.WriteString(w, seg)
if err != nil {
return
}
n += nn
}
nn, err = w.Write([]byte{0x00})
if err != nil {
return
}
n += nn
return
} }
// UnmarshalDomain gets domain from bytes. // UnmarshalDomain gets domain from bytes.

View File

@ -65,19 +65,24 @@ func (s *Server) ListenAndServeUDP(wg *sync.WaitGroup) {
n, caddr, err := c.ReadFrom(reqBytes[2:]) n, caddr, err := c.ReadFrom(reqBytes[2:])
if err != nil { if err != nil {
log.F("[dns] local read error: %v", err) log.F("[dns] local read error: %v", err)
pool.PutBuffer(reqBytes)
continue continue
} }
reqLen := uint16(n) reqLen := uint16(n)
if reqLen <= HeaderLen+2 { if reqLen <= HeaderLen+2 {
log.F("[dns] not enough message data") log.F("[dns] not enough message data")
pool.PutBuffer(reqBytes)
continue continue
} }
binary.BigEndian.PutUint16(reqBytes[:2], reqLen) binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
go func() { go func() {
respBytes, err := s.Exchange(reqBytes[:2+n], caddr.String(), false) respBytes, err := s.Exchange(reqBytes[:2+n], caddr.String(), false)
defer pool.PutBuffer(reqBytes) defer func() {
pool.PutBuffer(reqBytes)
pool.PutBuffer(respBytes)
}()
if err != nil { if err != nil {
log.F("[dns] error in exchange: %s", err) log.F("[dns] error in exchange: %s", err)
@ -141,6 +146,7 @@ func (s *Server) ServeTCP(c net.Conn) {
binary.BigEndian.PutUint16(reqBytes[:2], reqLen) binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String(), true) respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String(), true)
defer pool.PutBuffer(respBytes)
if err != nil { if err != nil {
log.F("[dns-tcp] error in exchange: %s", err) log.F("[dns-tcp] error in exchange: %s", err)
return return

4
go.mod
View File

@ -9,9 +9,9 @@ require (
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
github.com/xtaci/kcp-go/v5 v5.5.15 github.com/xtaci/kcp-go/v5 v5.5.15
golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a
golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc // indirect golang.org/x/net v0.0.0-20200822124328-c89045814202 // indirect
golang.org/x/sys v0.0.0-20200821140526-fda516888d29 // indirect golang.org/x/sys v0.0.0-20200821140526-fda516888d29 // indirect
golang.org/x/tools v0.0.0-20200821144610-c886c0b611b7 // indirect golang.org/x/tools v0.0.0-20200822203824-307de81be3f4 // indirect
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
) )

8
go.sum
View File

@ -101,8 +101,8 @@ golang.org/x/net v0.0.0-20200505041828-1ed23360d12c/go.mod h1:qpuaurCH72eLCgpAm/
golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200707034311-ab3426394381 h1:VXak5I6aEWmAXeQjA+QSZzlgNrpq9mjcfDemuexIKsU= golang.org/x/net v0.0.0-20200707034311-ab3426394381 h1:VXak5I6aEWmAXeQjA+QSZzlgNrpq9mjcfDemuexIKsU=
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc h1:zK/HqS5bZxDptfPJNq8v7vJfXtkU7r9TLIoSr1bXaP4= golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA=
golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -124,8 +124,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20200425043458-8463f397d07c h1:iHhCR0b26amDCiiO+kBguKZom9aMF+NrFxh9zeKR/XU= golang.org/x/tools v0.0.0-20200425043458-8463f397d07c h1:iHhCR0b26amDCiiO+kBguKZom9aMF+NrFxh9zeKR/XU=
golang.org/x/tools v0.0.0-20200425043458-8463f397d07c/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200425043458-8463f397d07c/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200808161706-5bf02b21f123/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200808161706-5bf02b21f123/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20200821144610-c886c0b611b7 h1:E83yjTcMGvlL0ixGmFgJr/jvcp8L2LPDg7K0MQONeGA= golang.org/x/tools v0.0.0-20200822203824-307de81be3f4 h1:r0nbB2EeRbGpnVeqxlkgiBpNi/bednpSg78qzZGOuv0=
golang.org/x/tools v0.0.0-20200821144610-c886c0b611b7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/tools v0.0.0-20200822203824-307de81be3f4/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=