diff --git a/common/pool/buffer.go b/common/pool/buffer.go index 1ebaed0..1336c15 100644 --- a/common/pool/buffer.go +++ b/common/pool/buffer.go @@ -7,7 +7,6 @@ import ( const ( // number of pools. - // pool sizes: [1<<0 ~ 1<<(num-1)] bytes, [1B~64KB]. num = 17 maxsize = 1 << (num - 1) ) @@ -42,9 +41,10 @@ func GetBuffer(size int) []byte { // PutBuffer puts a buffer into pool. func PutBuffer(buf []byte) { - size := cap(buf) - i := bits.Len32(uint32(size)) - 1 - if i < num && sizes[i] == size { - pools[i].Put(buf) + if size := cap(buf); size >= 1 && size <= maxsize { + i := bits.Len32(uint32(size)) - 1 + if sizes[i] == size { + pools[i].Put(buf) + } } } diff --git a/dns/cache.go b/dns/cache.go index 80e6e9e..a172034 100644 --- a/dns/cache.go +++ b/dns/cache.go @@ -3,6 +3,8 @@ package dns import ( "sync" "time" + + "github.com/nadoo/glider/common/pool" ) // LongTTL is 50 years duration in seconds, used for none-expired items. @@ -15,22 +17,26 @@ type item struct { // Cache is the struct of cache. type Cache struct { - m map[string]*item - l sync.RWMutex + store map[string]*item + mutex sync.RWMutex + storeCopy bool } // NewCache returns a new cache. -func NewCache() (c *Cache) { - c = &Cache{m: make(map[string]*item)} +func NewCache(storeCopy bool) (c *Cache) { + c = &Cache{store: make(map[string]*item), storeCopy: storeCopy} go func() { for now := range time.Tick(time.Second) { - c.l.Lock() - for k, v := range c.m { + c.mutex.Lock() + for k, v := range c.store { if now.After(v.expire) { - delete(c.m, k) + delete(c.store, k) + if storeCopy { + pool.PutBuffer(v.value) + } } } - c.l.Unlock() + c.mutex.Unlock() } }() return @@ -38,29 +44,46 @@ func NewCache() (c *Cache) { // Len returns the length of cache. func (c *Cache) Len() int { - return len(c.m) + return len(c.store) } // Put an item into cache, invalid after ttl seconds. func (c *Cache) Put(k string, v []byte, ttl int) { if len(v) != 0 { - c.l.Lock() - it, ok := c.m[k] + c.mutex.Lock() + it, ok := c.store[k] if !ok { - it = &item{value: v} - c.m[k] = it + if c.storeCopy { + it = &item{value: valCopy(v)} + } else { + it = &item{value: v} + } + c.store[k] = it } it.expire = time.Now().Add(time.Duration(ttl) * time.Second) - c.l.Unlock() + c.mutex.Unlock() } } -// Get an item from cache. +// Get gets an item from cache(do not modify it). func (c *Cache) Get(k string) (v []byte) { - c.l.RLock() - if it, ok := c.m[k]; ok { + c.mutex.RLock() + if it, ok := c.store[k]; ok { v = it.value } - c.l.RUnlock() + c.mutex.RUnlock() + return +} + +// GetCopy gets an item from cache and returns it's copy(so you can modify it). +func (c *Cache) GetCopy(k string) []byte { + return valCopy(c.Get(k)) +} + +func valCopy(v []byte) (b []byte) { + if v != nil { + b = pool.GetBuffer(len(v)) + copy(b, v) + } return } diff --git a/dns/client.go b/dns/client.go index 51be180..3ed10c8 100644 --- a/dns/client.go +++ b/dns/client.go @@ -1,15 +1,16 @@ package dns import ( - "bytes" "encoding/binary" "errors" + "fmt" "io" "net" "strings" "time" "github.com/nadoo/glider/common/log" + "github.com/nadoo/glider/common/pool" "github.com/nadoo/glider/proxy" ) @@ -40,7 +41,7 @@ type Client struct { func NewClient(proxy proxy.Proxy, config *Config) (*Client, error) { c := &Client{ proxy: proxy, - cache: NewCache(), + cache: NewCache(true), config: config, upStream: NewUPStream(config.Servers), upStreamMap: make(map[string]*UPStream), @@ -63,8 +64,8 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([ } if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA { - v := c.cache.Get(getKey(req.Question)) - if v != nil { + v := c.cache.GetCopy(qKey(req.Question)) + if len(v) > 4 { binary.BigEndian.PutUint16(v[2:4], req.ID) log.F("[dns] %s <-> cache, type: %d, %s", clientAddr, req.Question.QTYPE, req.Question.QNAME) @@ -93,7 +94,7 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([ // add to cache only when there's a valid ip address if len(ips) != 0 && ttl > 0 { - c.cache.Put(getKey(resp.Question), respBytes, ttl) + c.cache.Put(qKey(resp.Question), respBytes, ttl) } log.F("[dns] %s <-> %s(%s) via %s, type: %d, %s: %s", @@ -204,7 +205,7 @@ func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) { return nil, err } - respBytes := make([]byte, respLen+2) + respBytes := pool.GetBuffer(int(respLen) + 2) binary.BigEndian.PutUint16(respBytes[:2], respLen) _, err := io.ReadFull(rc, respBytes[2:]) @@ -221,7 +222,7 @@ func (c *Client) exchangeUDP(rc net.Conn, reqBytes []byte) ([]byte, error) { return nil, err } - respBytes := make([]byte, UDPMaxLen) + respBytes := pool.GetBuffer(UDPMaxLen) n, err := rc.Read(respBytes[2:]) if err != nil { return nil, err @@ -258,24 +259,29 @@ func (c *Client) AddHandler(h HandleFunc) { func (c *Client) AddRecord(record string) error { r := strings.Split(record, "/") domain, ip := r[0], r[1] - m, err := c.GenResponse(domain, ip) + m, err := c.MakeResponse(domain, ip) if err != nil { return err } - b, _ := m.Marshal() + wb := pool.GetWriteBuffer() + defer pool.PutWriteBuffer(wb) - var buf bytes.Buffer - binary.Write(&buf, binary.BigEndian, uint16(len(b))) - buf.Write(b) + wb.Write([]byte{0, 0}) - c.cache.Put(getKey(m.Question), buf.Bytes(), LongTTL) + n, err := m.MarshalTo(wb) + if err != nil { + return err + } + + binary.BigEndian.PutUint16(wb.Bytes()[:2], uint16(n)) + c.cache.Put(qKey(m.Question), wb.Bytes(), LongTTL) return nil } -// GenResponse generates a dns response message for the given domain and ip address. -func (c *Client) GenResponse(domain string, ip string) (*Message, error) { +// MakeResponse makes a dns response message for the given domain and ip address. +func (c *Client) MakeResponse(domain string, ip string) (*Message, error) { ipb := net.ParseIP(ip) if ipb == nil { return nil, errors.New("GenResponse: invalid ip format") @@ -301,13 +307,6 @@ func (c *Client) GenResponse(domain string, ip string) (*Message, error) { return m, nil } -func getKey(q *Question) string { - var qtype string - switch q.QTYPE { - case QTypeA: - qtype = "A" - case QTypeAAAA: - qtype = "AAAA" - } - return q.QNAME + "/" + qtype +func qKey(q *Question) string { + return fmt.Sprintf("%s/%d", q.QNAME, q.QTYPE) } diff --git a/dns/message.go b/dns/message.go index 3cdde75..c76d883 100644 --- a/dns/message.go +++ b/dns/message.go @@ -91,19 +91,40 @@ func (m *Message) AddAnswer(rr *RR) error { // Marshal marshals message struct to []byte. func (m *Message) Marshal() ([]byte, error) { buf := &bytes.Buffer{} + _, err := m.MarshalTo(buf) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} +// MarshalTo marshals message struct to []byte and write to w. +func (m *Message) MarshalTo(w io.Writer) (n int, err error) { m.Header.SetQdcount(1) m.Header.SetAncount(len(m.Answers)) - // no error when write to bytes.Buffer - m.Header.MarshalTo(buf) - m.Question.MarshalTo(buf) + nn := 0 + nn, err = m.Header.MarshalTo(w) + if err != nil { + return + } + n += nn + + nn, err = m.Question.MarshalTo(w) + if err != nil { + return + } + n += nn for _, answer := range m.Answers { - answer.MarshalTo(buf) + nn, err = answer.MarshalTo(w) + if err != nil { + return + } + n += nn } - return buf.Bytes(), nil + return } // UnmarshalMessage unmarshals []bytes to Message. @@ -255,14 +276,25 @@ func NewQuestion(qtype uint16, domain string) *Question { } // MarshalTo marshals Question struct to []byte and write to w. -func (q *Question) MarshalTo(w io.Writer) (int, error) { - n, _ := MarshalDomainTo(w, q.QNAME) +func (q *Question) MarshalTo(w io.Writer) (n int, err error) { + n, err = MarshalDomainTo(w, q.QNAME) + if err != nil { + return + } - binary.Write(w, binary.BigEndian, q.QTYPE) - binary.Write(w, binary.BigEndian, q.QCLASS) - n += 4 + err = binary.Write(w, binary.BigEndian, q.QTYPE) + if err != nil { + return + } + n += 2 - return n, nil + err = binary.Write(w, binary.BigEndian, q.QCLASS) + if err != nil { + return + } + n += 2 + + return } // UnmarshalQuestion unmarshals []bytes to Question. @@ -332,19 +364,43 @@ func NewRR() *RR { } // MarshalTo marshals RR struct to []byte and write to w. -func (rr *RR) MarshalTo(w io.Writer) (int, error) { - n, _ := MarshalDomainTo(w, rr.NAME) +func (rr *RR) MarshalTo(w io.Writer) (n int, err error) { + n, err = MarshalDomainTo(w, rr.NAME) + if err != nil { + return + } - binary.Write(w, binary.BigEndian, rr.TYPE) - binary.Write(w, binary.BigEndian, rr.CLASS) - binary.Write(w, binary.BigEndian, rr.TTL) - binary.Write(w, binary.BigEndian, rr.RDLENGTH) - n += 10 + err = binary.Write(w, binary.BigEndian, rr.TYPE) + if err != nil { + return + } + n += 2 - w.Write(rr.RDATA) + err = binary.Write(w, binary.BigEndian, rr.CLASS) + if err != nil { + return + } + n += 2 + + err = binary.Write(w, binary.BigEndian, rr.TTL) + if err != nil { + return + } + n += 4 + + err = binary.Write(w, binary.BigEndian, rr.RDLENGTH) + if err != nil { + return + } + n += 2 + + _, err = w.Write(rr.RDATA) + if err != nil { + return + } n += len(rr.RDATA) - return n, nil + return } // UnmarshalRR unmarshals []bytes to RR. @@ -388,16 +444,29 @@ func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) { } // MarshalDomainTo marshals domain string struct to []byte and write to w. -func MarshalDomainTo(w io.Writer, domain string) (int, error) { - n := 1 +func MarshalDomainTo(w io.Writer, domain string) (n int, err error) { + nn := 0 for _, seg := range strings.Split(domain, ".") { - w.Write([]byte{byte(len(seg))}) - io.WriteString(w, seg) - n += 1 + len(seg) - } - w.Write([]byte{0x00}) + nn, err = w.Write([]byte{byte(len(seg))}) + if err != nil { + return + } + n += nn - return n, nil + nn, err = io.WriteString(w, seg) + if err != nil { + return + } + n += nn + } + + nn, err = w.Write([]byte{0x00}) + if err != nil { + return + } + n += nn + + return } // UnmarshalDomain gets domain from bytes. diff --git a/dns/server.go b/dns/server.go index 3312c77..d6c1556 100644 --- a/dns/server.go +++ b/dns/server.go @@ -65,19 +65,24 @@ func (s *Server) ListenAndServeUDP(wg *sync.WaitGroup) { n, caddr, err := c.ReadFrom(reqBytes[2:]) if err != nil { log.F("[dns] local read error: %v", err) + pool.PutBuffer(reqBytes) continue } reqLen := uint16(n) if reqLen <= HeaderLen+2 { log.F("[dns] not enough message data") + pool.PutBuffer(reqBytes) continue } binary.BigEndian.PutUint16(reqBytes[:2], reqLen) go func() { respBytes, err := s.Exchange(reqBytes[:2+n], caddr.String(), false) - defer pool.PutBuffer(reqBytes) + defer func() { + pool.PutBuffer(reqBytes) + pool.PutBuffer(respBytes) + }() if err != nil { log.F("[dns] error in exchange: %s", err) @@ -141,6 +146,7 @@ func (s *Server) ServeTCP(c net.Conn) { binary.BigEndian.PutUint16(reqBytes[:2], reqLen) respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String(), true) + defer pool.PutBuffer(respBytes) if err != nil { log.F("[dns-tcp] error in exchange: %s", err) return diff --git a/go.mod b/go.mod index 1114178..f0eb814 100644 --- a/go.mod +++ b/go.mod @@ -9,9 +9,9 @@ require ( github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/xtaci/kcp-go/v5 v5.5.15 golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a - golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc // indirect + golang.org/x/net v0.0.0-20200822124328-c89045814202 // indirect golang.org/x/sys v0.0.0-20200821140526-fda516888d29 // indirect - golang.org/x/tools v0.0.0-20200821144610-c886c0b611b7 // indirect + golang.org/x/tools v0.0.0-20200822203824-307de81be3f4 // indirect gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect ) diff --git a/go.sum b/go.sum index 6ec9773..112f3cf 100644 --- a/go.sum +++ b/go.sum @@ -101,8 +101,8 @@ golang.org/x/net v0.0.0-20200505041828-1ed23360d12c/go.mod h1:qpuaurCH72eLCgpAm/ golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20200707034311-ab3426394381 h1:VXak5I6aEWmAXeQjA+QSZzlgNrpq9mjcfDemuexIKsU= golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc h1:zK/HqS5bZxDptfPJNq8v7vJfXtkU7r9TLIoSr1bXaP4= -golang.org/x/net v0.0.0-20200813134508-3edf25e44fcc/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -124,8 +124,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200425043458-8463f397d07c h1:iHhCR0b26amDCiiO+kBguKZom9aMF+NrFxh9zeKR/XU= golang.org/x/tools v0.0.0-20200425043458-8463f397d07c/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200808161706-5bf02b21f123/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200821144610-c886c0b611b7 h1:E83yjTcMGvlL0ixGmFgJr/jvcp8L2LPDg7K0MQONeGA= -golang.org/x/tools v0.0.0-20200821144610-c886c0b611b7/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/tools v0.0.0-20200822203824-307de81be3f4 h1:r0nbB2EeRbGpnVeqxlkgiBpNi/bednpSg78qzZGOuv0= +golang.org/x/tools v0.0.0-20200822203824-307de81be3f4/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=