pool: return specified size of buffer

This commit is contained in:
nadoo 2020-04-19 23:20:15 +08:00
parent 5326c0a901
commit a0542a028e
16 changed files with 158 additions and 139 deletions

View File

@ -51,18 +51,18 @@ func Relay(left, right net.Conn) (int64, int64, error) {
ch := make(chan res)
go func() {
buf := pool.GetBuffer(TCPBufSize)
n, err := io.CopyBuffer(right, left, buf)
pool.PutBuffer(buf)
b := pool.GetBuffer(TCPBufSize)
n, err := io.CopyBuffer(right, left, b)
pool.PutBuffer(b)
right.SetDeadline(time.Now()) // wake up the other goroutine blocking on right
left.SetDeadline(time.Now()) // wake up the other goroutine blocking on left
ch <- res{n, err}
}()
buf := pool.GetBuffer(TCPBufSize)
n, err := io.CopyBuffer(left, right, buf)
pool.PutBuffer(buf)
b := pool.GetBuffer(TCPBufSize)
n, err := io.CopyBuffer(left, right, b)
pool.PutBuffer(b)
right.SetDeadline(time.Now()) // wake up the other goroutine blocking on right
left.SetDeadline(time.Now()) // wake up the other goroutine blocking on left
@ -76,17 +76,17 @@ func Relay(left, right net.Conn) (int64, int64, error) {
// RelayUDP copys from src to dst at target with read timeout.
func RelayUDP(dst net.PacketConn, target net.Addr, src net.PacketConn, timeout time.Duration) error {
buf := pool.GetBuffer(UDPBufSize)
defer pool.PutBuffer(buf)
b := pool.GetBuffer(UDPBufSize)
defer pool.PutBuffer(b)
for {
src.SetReadDeadline(time.Now().Add(timeout))
n, _, err := src.ReadFrom(buf)
n, _, err := src.ReadFrom(b)
if err != nil {
return err
}
_, err = dst.WriteTo(buf[:n], target)
_, err = dst.WriteTo(b[:n], target)
if err != nil {
return err
}

View File

@ -29,14 +29,14 @@ var bufPools = [...]sync.Pool{
{New: func() interface{} { return make([]byte, 64<<10) }},
}
func GetBuffer(size int64) []byte {
func GetBuffer(size int) []byte {
i := 0
for ; i < len(bufSizes)-1; i++ {
if size <= int64(bufSizes[i]) {
if size <= bufSizes[i] {
break
}
}
return bufPools[i].Get().([]byte)
return bufPools[i].Get().([]byte)[:size]
}
func PutBuffer(p []byte) {

View File

@ -9,7 +9,7 @@ var writeBufPool = sync.Pool{
New: func() interface{} { return &bytes.Buffer{} },
}
func GetWriteBuffer(size int64) *bytes.Buffer {
func GetWriteBuffer() *bytes.Buffer {
return writeBufPool.Get().(*bytes.Buffer)
}

View File

@ -1,7 +1,6 @@
package dns
import (
"bytes"
"encoding/binary"
"errors"
"io"
@ -10,6 +9,7 @@ import (
"time"
"github.com/nadoo/glider/common/log"
"github.com/nadoo/glider/common/pool"
"github.com/nadoo/glider/proxy"
)
@ -257,8 +257,10 @@ func (c *Client) AddRecord(record string) error {
b, _ := m.Marshal()
var buf bytes.Buffer
binary.Write(&buf, binary.BigEndian, uint16(len(b)))
buf := pool.GetWriteBuffer()
defer pool.PutWriteBuffer(buf)
binary.Write(buf, binary.BigEndian, uint16(len(b)))
buf.Write(b)
c.cache.Put(getKey(m.Question), buf.Bytes(), LongTTL)

View File

@ -1,7 +1,6 @@
package http
import (
"bytes"
"encoding/base64"
"errors"
"net"
@ -9,6 +8,7 @@ import (
"github.com/nadoo/glider/common/conn"
"github.com/nadoo/glider/common/log"
"github.com/nadoo/glider/common/pool"
"github.com/nadoo/glider/proxy"
)
@ -33,7 +33,7 @@ func (s *HTTP) Dial(network, addr string) (net.Conn, error) {
return nil, err
}
var buf bytes.Buffer
buf := pool.GetWriteBuffer()
buf.WriteString("CONNECT " + addr + " HTTP/1.1\r\n")
buf.WriteString("Host: " + addr + "\r\n")
buf.WriteString("Proxy-Connection: Keep-Alive\r\n")
@ -46,6 +46,7 @@ func (s *HTTP) Dial(network, addr string) (net.Conn, error) {
// header ended
buf.WriteString("\r\n")
_, err = rc.Write(buf.Bytes())
pool.PutWriteBuffer(buf)
if err != nil {
return nil, err
}

View File

@ -5,8 +5,8 @@
package http
import (
"bytes"
"encoding/base64"
"io"
"net/textproto"
"net/url"
"strings"
@ -81,17 +81,17 @@ func cleanHeaders(header textproto.MIMEHeader) {
header.Del("Upgrade")
}
func writeStartLine(buf *bytes.Buffer, s1, s2, s3 string) {
buf.WriteString(s1 + " " + s2 + " " + s3 + "\r\n")
func writeStartLine(w io.Writer, s1, s2, s3 string) {
w.Write([]byte(s1 + " " + s2 + " " + s3 + "\r\n"))
}
func writeHeaders(buf *bytes.Buffer, header textproto.MIMEHeader) {
func writeHeaders(w io.Writer, header textproto.MIMEHeader) {
for key, values := range header {
for _, v := range values {
buf.WriteString(key + ": " + v + "\r\n")
w.Write([]byte(key + ": " + v + "\r\n"))
}
}
buf.WriteString("\r\n")
w.Write([]byte("\r\n"))
}
func extractUserPass(auth string) (username, password string, ok bool) {

View File

@ -2,7 +2,6 @@ package http
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
@ -11,6 +10,7 @@ import (
"github.com/nadoo/glider/common/conn"
"github.com/nadoo/glider/common/log"
"github.com/nadoo/glider/common/pool"
"github.com/nadoo/glider/proxy"
)
@ -129,7 +129,10 @@ func (s *HTTP) servHTTP(req *request, c *conn.Conn) {
// copy the left request bytes to remote server. eg. length specificed or chunked body.
go func() {
if _, err := c.Reader().Peek(1); err == nil {
io.Copy(rc, c)
b := pool.GetBuffer(conn.TCPBufSize)
io.CopyBuffer(rc, c, b)
pool.PutBuffer(b)
rc.SetDeadline(time.Now())
c.SetDeadline(time.Now())
}
@ -156,12 +159,15 @@ func (s *HTTP) servHTTP(req *request, c *conn.Conn) {
header.Set("Proxy-Connection", "close")
header.Set("Connection", "close")
var buf bytes.Buffer
writeStartLine(&buf, proto, code, status)
writeHeaders(&buf, header)
buf := pool.GetWriteBuffer()
writeStartLine(buf, proto, code, status)
writeHeaders(buf, header)
log.F("[http] %s <-> %s", c.RemoteAddr(), req.target)
c.Write(buf.Bytes())
pool.PutWriteBuffer(buf)
io.Copy(c, r)
b := pool.GetBuffer(conn.TCPBufSize)
io.CopyBuffer(c, r, b)
pool.PutBuffer(b)
}

View File

@ -2,11 +2,12 @@ package obfs
import (
"bufio"
"bytes"
"crypto/rand"
"encoding/base64"
"io"
"net"
"github.com/nadoo/glider/common/pool"
)
// HTTPObfs struct
@ -42,16 +43,19 @@ func (p *HTTPObfs) NewConn(c net.Conn) (net.Conn, error) {
}
func (c *HTTPObfsConn) writeHeader() (int, error) {
var buf bytes.Buffer
buf := pool.GetWriteBuffer()
defer pool.PutWriteBuffer(buf)
buf.WriteString("GET " + c.obfsURI + " HTTP/1.1\r\n")
buf.WriteString("Host: " + c.obfsHost + "\r\n")
buf.WriteString("User-Agent: " + c.obfsUA + "\r\n")
buf.WriteString("Upgrade: websocket\r\n")
buf.WriteString("Connection: Upgrade\r\n")
p := make([]byte, 16)
rand.Read(p)
buf.WriteString("Sec-WebSocket-Key: " + base64.StdEncoding.EncodeToString(p) + "\r\n")
b := pool.GetBuffer(16)
rand.Read(b)
buf.WriteString("Sec-WebSocket-Key: " + base64.StdEncoding.EncodeToString(b) + "\r\n")
pool.PutBuffer(b)
buf.WriteString("\r\n")

View File

@ -17,6 +17,8 @@ import (
"io"
"net"
"time"
"github.com/nadoo/glider/common/pool"
)
const (
@ -69,12 +71,13 @@ func (c *TLSObfsConn) Write(b []byte) (int, error) {
end = n
}
var buf bytes.Buffer
buf := pool.GetWriteBuffer()
buf.Write([]byte{0x17, 0x03, 0x03})
binary.Write(&buf, binary.BigEndian, uint16(len(b[i:end])))
binary.Write(buf, binary.BigEndian, uint16(len(b[i:end])))
buf.Write(b[i:end])
_, err := c.Conn.Write(buf.Bytes())
pool.PutWriteBuffer(buf)
if err != nil {
return 0, err
}
@ -124,7 +127,7 @@ func (c *TLSObfsConn) Read(b []byte) (int, error) {
}
func (c *TLSObfsConn) handshake(b []byte) (int, error) {
var buf bytes.Buffer
buf := pool.GetWriteBuffer()
// prepare extension & clientHello content
bufExt, bufHello := extension(b, c.obfsHost), clientHello()
@ -142,7 +145,7 @@ func (c *TLSObfsConn) handshake(b []byte) (int, error) {
buf.Write([]byte{0x03, 0x01})
// length
binary.Write(&buf, binary.BigEndian, uint16(handshakeLen))
binary.Write(buf, binary.BigEndian, uint16(handshakeLen))
// Handshake Begin
// Handshake Type: Client Hello (1)
@ -156,12 +159,13 @@ func (c *TLSObfsConn) handshake(b []byte) (int, error) {
// Extension Begin
// ext content length
binary.Write(&buf, binary.BigEndian, uint16(extLen))
binary.Write(buf, binary.BigEndian, uint16(extLen))
// ext content
buf.Write(bufExt.Bytes())
_, err := c.Conn.Write(buf.Bytes())
pool.PutWriteBuffer(buf)
if err != nil {
return 0, err
}

View File

@ -1,13 +1,13 @@
package trojan
import (
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"github.com/nadoo/glider/common/conn"
"github.com/nadoo/glider/common/pool"
"github.com/nadoo/glider/common/socks"
)
@ -62,9 +62,11 @@ func (pc *PktConn) ReadFrom(b []byte) (int, net.Addr, error) {
// WriteTo implements the necessary function of net.PacketConn.
func (pc *PktConn) WriteTo(b []byte, addr net.Addr) (int, error) {
var buf bytes.Buffer
buf := pool.GetWriteBuffer()
defer pool.PutWriteBuffer(buf)
buf.Write(pc.tgtAddr)
binary.Write(&buf, binary.BigEndian, uint16(len(b)))
binary.Write(buf, binary.BigEndian, uint16(len(b)))
buf.WriteString("\r\n")
buf.Write(b)
return pc.Write(buf.Bytes())

View File

@ -4,7 +4,6 @@
package trojan
import (
"bytes"
"crypto/sha256"
"crypto/tls"
"encoding/hex"
@ -13,6 +12,7 @@ import (
"strings"
"github.com/nadoo/glider/common/log"
"github.com/nadoo/glider/common/pool"
"github.com/nadoo/glider/common/socks"
"github.com/nadoo/glider/proxy"
)
@ -105,7 +105,9 @@ func (s *Trojan) dial(network, addr string) (net.Conn, error) {
return nil, err
}
var buf bytes.Buffer
buf := pool.GetWriteBuffer()
defer pool.PutWriteBuffer(buf)
buf.Write(s.pass[:])
buf.WriteString("\r\n")

View File

@ -21,7 +21,7 @@ func AEADWriter(w io.Writer, aead cipher.AEAD, iv []byte) io.Writer {
return &aeadWriter{
Writer: w,
AEAD: aead,
buf: make([]byte, lenSize+maxChunkSize),
buf: make([]byte, lenSize+chunkSize),
nonce: make([]byte, aead.NonceSize()),
count: 0,
iv: iv,
@ -36,7 +36,7 @@ func (w *aeadWriter) Write(b []byte) (int, error) {
func (w *aeadWriter) ReadFrom(r io.Reader) (n int64, err error) {
for {
buf := w.buf
payloadBuf := buf[lenSize : lenSize+defaultChunkSize-w.Overhead()]
payloadBuf := buf[lenSize : lenSize+chunkSize-w.Overhead()]
nr, er := r.Read(payloadBuf)
if nr > 0 {
@ -84,7 +84,7 @@ func AEADReader(r io.Reader, aead cipher.AEAD, iv []byte) io.Reader {
return &aeadReader{
Reader: r,
AEAD: aead,
buf: make([]byte, lenSize+maxChunkSize),
buf: make([]byte, lenSize+chunkSize),
nonce: make([]byte, aead.NonceSize()),
count: 0,
iv: iv,

View File

@ -1,104 +1,90 @@
package vmess
import (
"bytes"
"encoding/binary"
"io"
"github.com/nadoo/glider/common/pool"
)
const (
lenSize = 2
maxChunkSize = 1 << 14 // 16384
defaultChunkSize = 1 << 13 // 8192
lenSize = 2
chunkSize = 1 << 14 // 16384
)
type chunkedWriter struct {
io.Writer
buf []byte
}
// ChunkedWriter returns a chunked writer
func ChunkedWriter(w io.Writer) io.Writer {
return &chunkedWriter{
Writer: w,
buf: make([]byte, lenSize+maxChunkSize),
}
return &chunkedWriter{Writer: w}
}
func (w *chunkedWriter) Write(b []byte) (int, error) {
n, err := w.ReadFrom(bytes.NewBuffer(b))
return int(n), err
}
func (w *chunkedWriter) Write(b []byte) (n int, err error) {
buf := pool.GetBuffer(lenSize + chunkSize)
defer pool.PutBuffer(buf)
func (w *chunkedWriter) ReadFrom(r io.Reader) (n int64, err error) {
for {
buf := w.buf
payloadBuf := buf[lenSize : lenSize+defaultChunkSize]
nr, er := r.Read(payloadBuf)
if nr > 0 {
n += int64(nr)
buf = buf[:lenSize+nr]
payloadBuf = payloadBuf[:nr]
binary.BigEndian.PutUint16(buf[:lenSize], uint16(nr))
_, ew := w.Writer.Write(buf)
if ew != nil {
err = ew
break
}
left := len(b)
for left != 0 {
writeLen := left
if writeLen > chunkSize {
writeLen = chunkSize
}
if er != nil {
if er != io.EOF { // ignore EOF as per io.ReaderFrom contract
err = er
}
copy(buf[lenSize:], b[n:n+writeLen])
binary.BigEndian.PutUint16(buf[:lenSize], uint16(writeLen))
_, err = w.Writer.Write(buf[:lenSize+writeLen])
if err != nil {
break
}
n += writeLen
left -= writeLen
}
return n, err
return
}
type chunkedReader struct {
io.Reader
buf []byte
leftBytes int
left int
}
// ChunkedReader returns a chunked reader
func ChunkedReader(r io.Reader) io.Reader {
return &chunkedReader{
Reader: r,
buf: make([]byte, lenSize), // NOTE: buf only used to save header bytes now
}
return &chunkedReader{Reader: r}
}
func (r *chunkedReader) Read(b []byte) (int, error) {
if r.leftBytes == 0 {
if r.left == 0 {
// get length
_, err := io.ReadFull(r.Reader, r.buf[:lenSize])
buf := pool.GetBuffer(lenSize)
_, err := io.ReadFull(r.Reader, buf[:lenSize])
if err != nil {
return 0, err
}
r.leftBytes = int(binary.BigEndian.Uint16(r.buf[:lenSize]))
r.left = int(binary.BigEndian.Uint16(buf[:lenSize]))
pool.PutBuffer(buf)
// if length == 0, then this is the end
if r.leftBytes == 0 {
// if left == 0, then this is the end
if r.left == 0 {
return 0, nil
}
}
readLen := len(b)
if readLen > r.leftBytes {
readLen = r.leftBytes
if readLen > r.left {
readLen = r.left
}
m, err := r.Reader.Read(b[:readLen])
n, err := r.Reader.Read(b[:readLen])
if err != nil {
return 0, err
}
r.leftBytes -= m
return m, err
r.left -= n
return n, err
}

View File

@ -1,7 +1,6 @@
package vmess
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
@ -15,6 +14,7 @@ import (
"strings"
"time"
"github.com/nadoo/glider/common/pool"
"golang.org/x/crypto/chacha20poly1305"
)
@ -108,7 +108,7 @@ func NewClient(uuidStr, security string, alterID int) (*Client, error) {
// NewConn returns a new vmess conn.
func (c *Client) NewConn(rc net.Conn, target string) (*Conn, error) {
r := rand.Intn(c.count)
conn := &Conn{user: c.users[r], opt: c.opt, security: c.security}
conn := &Conn{user: c.users[r], opt: c.opt, security: c.security, Conn: rc}
var err error
conn.atyp, conn.addr, conn.port, err = ParseAddr(target)
@ -116,51 +116,50 @@ func (c *Client) NewConn(rc net.Conn, target string) (*Conn, error) {
return nil, err
}
randBytes := make([]byte, 33)
randBytes := pool.GetBuffer(32)
rand.Read(randBytes)
copy(conn.reqBodyIV[:], randBytes[:16])
copy(conn.reqBodyKey[:], randBytes[16:32])
conn.reqRespV = randBytes[32]
pool.PutBuffer(randBytes)
conn.reqRespV = byte(rand.Intn(1 << 8))
conn.respBodyIV = md5.Sum(conn.reqBodyIV[:])
conn.respBodyKey = md5.Sum(conn.reqBodyKey[:])
// AuthInfo
_, err = rc.Write(conn.EncodeAuthInfo())
// Auth
err = conn.Auth()
if err != nil {
return nil, err
}
// Request
req, err := conn.EncodeRequest()
err = conn.Request()
if err != nil {
return nil, err
}
_, err = rc.Write(req)
if err != nil {
return nil, err
}
conn.Conn = rc
return conn, nil
}
// EncodeAuthInfo returns HMAC("md5", UUID, UTC) result
func (c *Conn) EncodeAuthInfo() []byte {
ts := make([]byte, 8)
// Auth send auth info: HMAC("md5", UUID, UTC)
func (c *Conn) Auth() error {
ts := pool.GetBuffer(8)
defer pool.PutBuffer(ts)
binary.BigEndian.PutUint64(ts, uint64(time.Now().UTC().Unix()))
h := hmac.New(md5.New, c.user.UUID[:])
h.Write(ts)
return h.Sum(nil)
_, err := c.Conn.Write(h.Sum(nil))
return err
}
// EncodeRequest encodes requests to network bytes.
func (c *Conn) EncodeRequest() ([]byte, error) {
var buf bytes.Buffer
// Request sends request to server.
func (c *Conn) Request() error {
buf := pool.GetWriteBuffer()
defer pool.PutWriteBuffer(buf)
// Request
buf.WriteByte(1) // Ver
@ -178,9 +177,9 @@ func (c *Conn) EncodeRequest() ([]byte, error) {
buf.WriteByte(CmdTCP) // cmd
// target
err := binary.Write(&buf, binary.BigEndian, uint16(c.port)) // port
err := binary.Write(buf, binary.BigEndian, uint16(c.port)) // port
if err != nil {
return nil, err
return err
}
buf.WriteByte(byte(c.atyp)) // atyp
@ -188,28 +187,31 @@ func (c *Conn) EncodeRequest() ([]byte, error) {
// padding
if paddingLen > 0 {
padding := make([]byte, paddingLen)
padding := pool.GetBuffer(paddingLen)
rand.Read(padding)
buf.Write(padding)
pool.PutBuffer(padding)
}
// F
fnv1a := fnv.New32a()
_, err = fnv1a.Write(buf.Bytes())
if err != nil {
return nil, err
return err
}
buf.Write(fnv1a.Sum(nil))
block, err := aes.NewCipher(c.user.CmdKey[:])
if err != nil {
return nil, err
return err
}
stream := cipher.NewCFBEncrypter(block, TimestampHash(time.Now().UTC()))
stream.XORKeyStream(buf.Bytes(), buf.Bytes())
return buf.Bytes(), nil
_, err = c.Conn.Write(buf.Bytes())
return err
}
// DecodeRespHeader decodes response header.
@ -221,20 +223,22 @@ func (c *Conn) DecodeRespHeader() error {
stream := cipher.NewCFBDecrypter(block, c.respBodyIV[:])
buf := make([]byte, 4)
_, err = io.ReadFull(c.Conn, buf)
b := pool.GetBuffer(4)
defer pool.PutBuffer(b)
_, err = io.ReadFull(c.Conn, b)
if err != nil {
return err
}
stream.XORKeyStream(buf, buf)
stream.XORKeyStream(b, b)
if buf[0] != c.reqRespV {
if b[0] != c.reqRespV {
return errors.New("unexpected response header")
}
// TODO: Dynamic port support
if buf[2] != 0 {
if b[2] != 0 {
// dataLen := int32(buf[3])
return errors.New("dynamic port is not supported now")
}
@ -259,13 +263,14 @@ func (c *Conn) Write(b []byte) (n int, err error) {
c.dataWriter = AEADWriter(c.Conn, aead, c.reqBodyIV[:])
case SecurityChacha20Poly1305:
key := make([]byte, 32)
key := pool.GetBuffer(32)
t := md5.Sum(c.reqBodyKey[:])
copy(key, t[:])
t = md5.Sum(key[:16])
copy(key[16:], t[:])
aead, _ := chacha20poly1305.New(key)
c.dataWriter = AEADWriter(c.Conn, aead, c.reqBodyIV[:])
pool.PutBuffer(key)
}
}
@ -294,13 +299,14 @@ func (c *Conn) Read(b []byte) (n int, err error) {
c.dataReader = AEADReader(c.Conn, aead, c.respBodyIV[:])
case SecurityChacha20Poly1305:
key := make([]byte, 32)
key := pool.GetBuffer(32)
t := md5.Sum(c.respBodyKey[:])
copy(key, t[:])
t = md5.Sum(key[:16])
copy(key[16:], t[:])
aead, _ := chacha20poly1305.New(key)
c.dataReader = AEADReader(c.Conn, aead, c.respBodyIV[:])
pool.PutBuffer(key)
}
}

View File

@ -8,6 +8,8 @@ import (
"errors"
"strings"
"time"
"github.com/nadoo/glider/common/pool"
)
// User of vmess client
@ -73,10 +75,11 @@ func GetKey(uuid [16]byte) []byte {
// TimestampHash returns the iv of AES-128-CFB encrypter
// IVMD5(X + X + X + X)X = []byte(timestamp.now) (8 bytes, Big Endian)
func TimestampHash(t time.Time) []byte {
md5hash := md5.New()
ts := pool.GetBuffer(8)
defer pool.PutBuffer(ts)
ts := make([]byte, 8)
binary.BigEndian.PutUint64(ts, uint64(t.UTC().Unix()))
md5hash := md5.New()
md5hash.Write(ts)
md5hash.Write(ts)
md5hash.Write(ts)

View File

@ -2,7 +2,6 @@ package ws
import (
"bufio"
"bytes"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
@ -11,6 +10,8 @@ import (
"net"
"net/textproto"
"strings"
"github.com/nadoo/glider/common/pool"
)
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
@ -47,7 +48,9 @@ func (c *Client) NewConn(rc net.Conn, target string) (*Conn, error) {
func (c *Conn) Handshake(host, path string) error {
clientKey := generateClientKey()
var buf bytes.Buffer
buf := pool.GetWriteBuffer()
defer pool.PutWriteBuffer(buf)
buf.WriteString("GET " + path + " HTTP/1.1\r\n")
buf.WriteString("Host: " + host + "\r\n")
buf.WriteString("Upgrade: websocket\r\n")