pool: return specified size of buffer

This commit is contained in:
nadoo 2020-04-19 23:20:15 +08:00
parent 5326c0a901
commit a0542a028e
16 changed files with 158 additions and 139 deletions

View File

@ -51,18 +51,18 @@ func Relay(left, right net.Conn) (int64, int64, error) {
ch := make(chan res) ch := make(chan res)
go func() { go func() {
buf := pool.GetBuffer(TCPBufSize) b := pool.GetBuffer(TCPBufSize)
n, err := io.CopyBuffer(right, left, buf) n, err := io.CopyBuffer(right, left, b)
pool.PutBuffer(buf) pool.PutBuffer(b)
right.SetDeadline(time.Now()) // wake up the other goroutine blocking on right right.SetDeadline(time.Now()) // wake up the other goroutine blocking on right
left.SetDeadline(time.Now()) // wake up the other goroutine blocking on left left.SetDeadline(time.Now()) // wake up the other goroutine blocking on left
ch <- res{n, err} ch <- res{n, err}
}() }()
buf := pool.GetBuffer(TCPBufSize) b := pool.GetBuffer(TCPBufSize)
n, err := io.CopyBuffer(left, right, buf) n, err := io.CopyBuffer(left, right, b)
pool.PutBuffer(buf) pool.PutBuffer(b)
right.SetDeadline(time.Now()) // wake up the other goroutine blocking on right right.SetDeadline(time.Now()) // wake up the other goroutine blocking on right
left.SetDeadline(time.Now()) // wake up the other goroutine blocking on left left.SetDeadline(time.Now()) // wake up the other goroutine blocking on left
@ -76,17 +76,17 @@ func Relay(left, right net.Conn) (int64, int64, error) {
// RelayUDP copys from src to dst at target with read timeout. // RelayUDP copys from src to dst at target with read timeout.
func RelayUDP(dst net.PacketConn, target net.Addr, src net.PacketConn, timeout time.Duration) error { func RelayUDP(dst net.PacketConn, target net.Addr, src net.PacketConn, timeout time.Duration) error {
buf := pool.GetBuffer(UDPBufSize) b := pool.GetBuffer(UDPBufSize)
defer pool.PutBuffer(buf) defer pool.PutBuffer(b)
for { for {
src.SetReadDeadline(time.Now().Add(timeout)) src.SetReadDeadline(time.Now().Add(timeout))
n, _, err := src.ReadFrom(buf) n, _, err := src.ReadFrom(b)
if err != nil { if err != nil {
return err return err
} }
_, err = dst.WriteTo(buf[:n], target) _, err = dst.WriteTo(b[:n], target)
if err != nil { if err != nil {
return err return err
} }

View File

@ -29,14 +29,14 @@ var bufPools = [...]sync.Pool{
{New: func() interface{} { return make([]byte, 64<<10) }}, {New: func() interface{} { return make([]byte, 64<<10) }},
} }
func GetBuffer(size int64) []byte { func GetBuffer(size int) []byte {
i := 0 i := 0
for ; i < len(bufSizes)-1; i++ { for ; i < len(bufSizes)-1; i++ {
if size <= int64(bufSizes[i]) { if size <= bufSizes[i] {
break break
} }
} }
return bufPools[i].Get().([]byte) return bufPools[i].Get().([]byte)[:size]
} }
func PutBuffer(p []byte) { func PutBuffer(p []byte) {

View File

@ -9,7 +9,7 @@ var writeBufPool = sync.Pool{
New: func() interface{} { return &bytes.Buffer{} }, New: func() interface{} { return &bytes.Buffer{} },
} }
func GetWriteBuffer(size int64) *bytes.Buffer { func GetWriteBuffer() *bytes.Buffer {
return writeBufPool.Get().(*bytes.Buffer) return writeBufPool.Get().(*bytes.Buffer)
} }

View File

@ -1,7 +1,6 @@
package dns package dns
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"io" "io"
@ -10,6 +9,7 @@ import (
"time" "time"
"github.com/nadoo/glider/common/log" "github.com/nadoo/glider/common/log"
"github.com/nadoo/glider/common/pool"
"github.com/nadoo/glider/proxy" "github.com/nadoo/glider/proxy"
) )
@ -257,8 +257,10 @@ func (c *Client) AddRecord(record string) error {
b, _ := m.Marshal() b, _ := m.Marshal()
var buf bytes.Buffer buf := pool.GetWriteBuffer()
binary.Write(&buf, binary.BigEndian, uint16(len(b))) defer pool.PutWriteBuffer(buf)
binary.Write(buf, binary.BigEndian, uint16(len(b)))
buf.Write(b) buf.Write(b)
c.cache.Put(getKey(m.Question), buf.Bytes(), LongTTL) c.cache.Put(getKey(m.Question), buf.Bytes(), LongTTL)

View File

@ -1,7 +1,6 @@
package http package http
import ( import (
"bytes"
"encoding/base64" "encoding/base64"
"errors" "errors"
"net" "net"
@ -9,6 +8,7 @@ import (
"github.com/nadoo/glider/common/conn" "github.com/nadoo/glider/common/conn"
"github.com/nadoo/glider/common/log" "github.com/nadoo/glider/common/log"
"github.com/nadoo/glider/common/pool"
"github.com/nadoo/glider/proxy" "github.com/nadoo/glider/proxy"
) )
@ -33,7 +33,7 @@ func (s *HTTP) Dial(network, addr string) (net.Conn, error) {
return nil, err return nil, err
} }
var buf bytes.Buffer buf := pool.GetWriteBuffer()
buf.WriteString("CONNECT " + addr + " HTTP/1.1\r\n") buf.WriteString("CONNECT " + addr + " HTTP/1.1\r\n")
buf.WriteString("Host: " + addr + "\r\n") buf.WriteString("Host: " + addr + "\r\n")
buf.WriteString("Proxy-Connection: Keep-Alive\r\n") buf.WriteString("Proxy-Connection: Keep-Alive\r\n")
@ -46,6 +46,7 @@ func (s *HTTP) Dial(network, addr string) (net.Conn, error) {
// header ended // header ended
buf.WriteString("\r\n") buf.WriteString("\r\n")
_, err = rc.Write(buf.Bytes()) _, err = rc.Write(buf.Bytes())
pool.PutWriteBuffer(buf)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -5,8 +5,8 @@
package http package http
import ( import (
"bytes"
"encoding/base64" "encoding/base64"
"io"
"net/textproto" "net/textproto"
"net/url" "net/url"
"strings" "strings"
@ -81,17 +81,17 @@ func cleanHeaders(header textproto.MIMEHeader) {
header.Del("Upgrade") header.Del("Upgrade")
} }
func writeStartLine(buf *bytes.Buffer, s1, s2, s3 string) { func writeStartLine(w io.Writer, s1, s2, s3 string) {
buf.WriteString(s1 + " " + s2 + " " + s3 + "\r\n") w.Write([]byte(s1 + " " + s2 + " " + s3 + "\r\n"))
} }
func writeHeaders(buf *bytes.Buffer, header textproto.MIMEHeader) { func writeHeaders(w io.Writer, header textproto.MIMEHeader) {
for key, values := range header { for key, values := range header {
for _, v := range values { for _, v := range values {
buf.WriteString(key + ": " + v + "\r\n") w.Write([]byte(key + ": " + v + "\r\n"))
} }
} }
buf.WriteString("\r\n") w.Write([]byte("\r\n"))
} }
func extractUserPass(auth string) (username, password string, ok bool) { func extractUserPass(auth string) (username, password string, ok bool) {

View File

@ -2,7 +2,6 @@ package http
import ( import (
"bufio" "bufio"
"bytes"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -11,6 +10,7 @@ import (
"github.com/nadoo/glider/common/conn" "github.com/nadoo/glider/common/conn"
"github.com/nadoo/glider/common/log" "github.com/nadoo/glider/common/log"
"github.com/nadoo/glider/common/pool"
"github.com/nadoo/glider/proxy" "github.com/nadoo/glider/proxy"
) )
@ -129,7 +129,10 @@ func (s *HTTP) servHTTP(req *request, c *conn.Conn) {
// copy the left request bytes to remote server. eg. length specificed or chunked body. // copy the left request bytes to remote server. eg. length specificed or chunked body.
go func() { go func() {
if _, err := c.Reader().Peek(1); err == nil { if _, err := c.Reader().Peek(1); err == nil {
io.Copy(rc, c) b := pool.GetBuffer(conn.TCPBufSize)
io.CopyBuffer(rc, c, b)
pool.PutBuffer(b)
rc.SetDeadline(time.Now()) rc.SetDeadline(time.Now())
c.SetDeadline(time.Now()) c.SetDeadline(time.Now())
} }
@ -156,12 +159,15 @@ func (s *HTTP) servHTTP(req *request, c *conn.Conn) {
header.Set("Proxy-Connection", "close") header.Set("Proxy-Connection", "close")
header.Set("Connection", "close") header.Set("Connection", "close")
var buf bytes.Buffer buf := pool.GetWriteBuffer()
writeStartLine(&buf, proto, code, status) writeStartLine(buf, proto, code, status)
writeHeaders(&buf, header) writeHeaders(buf, header)
log.F("[http] %s <-> %s", c.RemoteAddr(), req.target) log.F("[http] %s <-> %s", c.RemoteAddr(), req.target)
c.Write(buf.Bytes()) c.Write(buf.Bytes())
pool.PutWriteBuffer(buf)
io.Copy(c, r) b := pool.GetBuffer(conn.TCPBufSize)
io.CopyBuffer(c, r, b)
pool.PutBuffer(b)
} }

View File

@ -2,11 +2,12 @@ package obfs
import ( import (
"bufio" "bufio"
"bytes"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"io" "io"
"net" "net"
"github.com/nadoo/glider/common/pool"
) )
// HTTPObfs struct // HTTPObfs struct
@ -42,16 +43,19 @@ func (p *HTTPObfs) NewConn(c net.Conn) (net.Conn, error) {
} }
func (c *HTTPObfsConn) writeHeader() (int, error) { func (c *HTTPObfsConn) writeHeader() (int, error) {
var buf bytes.Buffer buf := pool.GetWriteBuffer()
defer pool.PutWriteBuffer(buf)
buf.WriteString("GET " + c.obfsURI + " HTTP/1.1\r\n") buf.WriteString("GET " + c.obfsURI + " HTTP/1.1\r\n")
buf.WriteString("Host: " + c.obfsHost + "\r\n") buf.WriteString("Host: " + c.obfsHost + "\r\n")
buf.WriteString("User-Agent: " + c.obfsUA + "\r\n") buf.WriteString("User-Agent: " + c.obfsUA + "\r\n")
buf.WriteString("Upgrade: websocket\r\n") buf.WriteString("Upgrade: websocket\r\n")
buf.WriteString("Connection: Upgrade\r\n") buf.WriteString("Connection: Upgrade\r\n")
p := make([]byte, 16) b := pool.GetBuffer(16)
rand.Read(p) rand.Read(b)
buf.WriteString("Sec-WebSocket-Key: " + base64.StdEncoding.EncodeToString(p) + "\r\n") buf.WriteString("Sec-WebSocket-Key: " + base64.StdEncoding.EncodeToString(b) + "\r\n")
pool.PutBuffer(b)
buf.WriteString("\r\n") buf.WriteString("\r\n")

View File

@ -17,6 +17,8 @@ import (
"io" "io"
"net" "net"
"time" "time"
"github.com/nadoo/glider/common/pool"
) )
const ( const (
@ -69,12 +71,13 @@ func (c *TLSObfsConn) Write(b []byte) (int, error) {
end = n end = n
} }
var buf bytes.Buffer buf := pool.GetWriteBuffer()
buf.Write([]byte{0x17, 0x03, 0x03}) buf.Write([]byte{0x17, 0x03, 0x03})
binary.Write(&buf, binary.BigEndian, uint16(len(b[i:end]))) binary.Write(buf, binary.BigEndian, uint16(len(b[i:end])))
buf.Write(b[i:end]) buf.Write(b[i:end])
_, err := c.Conn.Write(buf.Bytes()) _, err := c.Conn.Write(buf.Bytes())
pool.PutWriteBuffer(buf)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -124,7 +127,7 @@ func (c *TLSObfsConn) Read(b []byte) (int, error) {
} }
func (c *TLSObfsConn) handshake(b []byte) (int, error) { func (c *TLSObfsConn) handshake(b []byte) (int, error) {
var buf bytes.Buffer buf := pool.GetWriteBuffer()
// prepare extension & clientHello content // prepare extension & clientHello content
bufExt, bufHello := extension(b, c.obfsHost), clientHello() bufExt, bufHello := extension(b, c.obfsHost), clientHello()
@ -142,7 +145,7 @@ func (c *TLSObfsConn) handshake(b []byte) (int, error) {
buf.Write([]byte{0x03, 0x01}) buf.Write([]byte{0x03, 0x01})
// length // length
binary.Write(&buf, binary.BigEndian, uint16(handshakeLen)) binary.Write(buf, binary.BigEndian, uint16(handshakeLen))
// Handshake Begin // Handshake Begin
// Handshake Type: Client Hello (1) // Handshake Type: Client Hello (1)
@ -156,12 +159,13 @@ func (c *TLSObfsConn) handshake(b []byte) (int, error) {
// Extension Begin // Extension Begin
// ext content length // ext content length
binary.Write(&buf, binary.BigEndian, uint16(extLen)) binary.Write(buf, binary.BigEndian, uint16(extLen))
// ext content // ext content
buf.Write(bufExt.Bytes()) buf.Write(bufExt.Bytes())
_, err := c.Conn.Write(buf.Bytes()) _, err := c.Conn.Write(buf.Bytes())
pool.PutWriteBuffer(buf)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View File

@ -1,13 +1,13 @@
package trojan package trojan
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"io" "io"
"net" "net"
"github.com/nadoo/glider/common/conn" "github.com/nadoo/glider/common/conn"
"github.com/nadoo/glider/common/pool"
"github.com/nadoo/glider/common/socks" "github.com/nadoo/glider/common/socks"
) )
@ -62,9 +62,11 @@ func (pc *PktConn) ReadFrom(b []byte) (int, net.Addr, error) {
// WriteTo implements the necessary function of net.PacketConn. // WriteTo implements the necessary function of net.PacketConn.
func (pc *PktConn) WriteTo(b []byte, addr net.Addr) (int, error) { func (pc *PktConn) WriteTo(b []byte, addr net.Addr) (int, error) {
var buf bytes.Buffer buf := pool.GetWriteBuffer()
defer pool.PutWriteBuffer(buf)
buf.Write(pc.tgtAddr) buf.Write(pc.tgtAddr)
binary.Write(&buf, binary.BigEndian, uint16(len(b))) binary.Write(buf, binary.BigEndian, uint16(len(b)))
buf.WriteString("\r\n") buf.WriteString("\r\n")
buf.Write(b) buf.Write(b)
return pc.Write(buf.Bytes()) return pc.Write(buf.Bytes())

View File

@ -4,7 +4,6 @@
package trojan package trojan
import ( import (
"bytes"
"crypto/sha256" "crypto/sha256"
"crypto/tls" "crypto/tls"
"encoding/hex" "encoding/hex"
@ -13,6 +12,7 @@ import (
"strings" "strings"
"github.com/nadoo/glider/common/log" "github.com/nadoo/glider/common/log"
"github.com/nadoo/glider/common/pool"
"github.com/nadoo/glider/common/socks" "github.com/nadoo/glider/common/socks"
"github.com/nadoo/glider/proxy" "github.com/nadoo/glider/proxy"
) )
@ -105,7 +105,9 @@ func (s *Trojan) dial(network, addr string) (net.Conn, error) {
return nil, err return nil, err
} }
var buf bytes.Buffer buf := pool.GetWriteBuffer()
defer pool.PutWriteBuffer(buf)
buf.Write(s.pass[:]) buf.Write(s.pass[:])
buf.WriteString("\r\n") buf.WriteString("\r\n")

View File

@ -21,7 +21,7 @@ func AEADWriter(w io.Writer, aead cipher.AEAD, iv []byte) io.Writer {
return &aeadWriter{ return &aeadWriter{
Writer: w, Writer: w,
AEAD: aead, AEAD: aead,
buf: make([]byte, lenSize+maxChunkSize), buf: make([]byte, lenSize+chunkSize),
nonce: make([]byte, aead.NonceSize()), nonce: make([]byte, aead.NonceSize()),
count: 0, count: 0,
iv: iv, iv: iv,
@ -36,7 +36,7 @@ func (w *aeadWriter) Write(b []byte) (int, error) {
func (w *aeadWriter) ReadFrom(r io.Reader) (n int64, err error) { func (w *aeadWriter) ReadFrom(r io.Reader) (n int64, err error) {
for { for {
buf := w.buf buf := w.buf
payloadBuf := buf[lenSize : lenSize+defaultChunkSize-w.Overhead()] payloadBuf := buf[lenSize : lenSize+chunkSize-w.Overhead()]
nr, er := r.Read(payloadBuf) nr, er := r.Read(payloadBuf)
if nr > 0 { if nr > 0 {
@ -84,7 +84,7 @@ func AEADReader(r io.Reader, aead cipher.AEAD, iv []byte) io.Reader {
return &aeadReader{ return &aeadReader{
Reader: r, Reader: r,
AEAD: aead, AEAD: aead,
buf: make([]byte, lenSize+maxChunkSize), buf: make([]byte, lenSize+chunkSize),
nonce: make([]byte, aead.NonceSize()), nonce: make([]byte, aead.NonceSize()),
count: 0, count: 0,
iv: iv, iv: iv,

View File

@ -1,104 +1,90 @@
package vmess package vmess
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"io" "io"
"github.com/nadoo/glider/common/pool"
) )
const ( const (
lenSize = 2 lenSize = 2
maxChunkSize = 1 << 14 // 16384 chunkSize = 1 << 14 // 16384
defaultChunkSize = 1 << 13 // 8192
) )
type chunkedWriter struct { type chunkedWriter struct {
io.Writer io.Writer
buf []byte
} }
// ChunkedWriter returns a chunked writer // ChunkedWriter returns a chunked writer
func ChunkedWriter(w io.Writer) io.Writer { func ChunkedWriter(w io.Writer) io.Writer {
return &chunkedWriter{ return &chunkedWriter{Writer: w}
Writer: w,
buf: make([]byte, lenSize+maxChunkSize),
}
} }
func (w *chunkedWriter) Write(b []byte) (int, error) { func (w *chunkedWriter) Write(b []byte) (n int, err error) {
n, err := w.ReadFrom(bytes.NewBuffer(b)) buf := pool.GetBuffer(lenSize + chunkSize)
return int(n), err defer pool.PutBuffer(buf)
}
func (w *chunkedWriter) ReadFrom(r io.Reader) (n int64, err error) { left := len(b)
for { for left != 0 {
buf := w.buf writeLen := left
payloadBuf := buf[lenSize : lenSize+defaultChunkSize] if writeLen > chunkSize {
writeLen = chunkSize
}
nr, er := r.Read(payloadBuf) copy(buf[lenSize:], b[n:n+writeLen])
if nr > 0 { binary.BigEndian.PutUint16(buf[:lenSize], uint16(writeLen))
n += int64(nr)
buf = buf[:lenSize+nr]
payloadBuf = payloadBuf[:nr]
binary.BigEndian.PutUint16(buf[:lenSize], uint16(nr))
_, ew := w.Writer.Write(buf) _, err = w.Writer.Write(buf[:lenSize+writeLen])
if ew != nil { if err != nil {
err = ew
break break
} }
n += writeLen
left -= writeLen
} }
if er != nil { return
if er != io.EOF { // ignore EOF as per io.ReaderFrom contract
err = er
}
break
}
}
return n, err
} }
type chunkedReader struct { type chunkedReader struct {
io.Reader io.Reader
buf []byte left int
leftBytes int
} }
// ChunkedReader returns a chunked reader // ChunkedReader returns a chunked reader
func ChunkedReader(r io.Reader) io.Reader { func ChunkedReader(r io.Reader) io.Reader {
return &chunkedReader{ return &chunkedReader{Reader: r}
Reader: r,
buf: make([]byte, lenSize), // NOTE: buf only used to save header bytes now
}
} }
func (r *chunkedReader) Read(b []byte) (int, error) { func (r *chunkedReader) Read(b []byte) (int, error) {
if r.leftBytes == 0 { if r.left == 0 {
// get length // get length
_, err := io.ReadFull(r.Reader, r.buf[:lenSize]) buf := pool.GetBuffer(lenSize)
_, err := io.ReadFull(r.Reader, buf[:lenSize])
if err != nil { if err != nil {
return 0, err return 0, err
} }
r.leftBytes = int(binary.BigEndian.Uint16(r.buf[:lenSize])) r.left = int(binary.BigEndian.Uint16(buf[:lenSize]))
pool.PutBuffer(buf)
// if length == 0, then this is the end // if left == 0, then this is the end
if r.leftBytes == 0 { if r.left == 0 {
return 0, nil return 0, nil
} }
} }
readLen := len(b) readLen := len(b)
if readLen > r.leftBytes { if readLen > r.left {
readLen = r.leftBytes readLen = r.left
} }
m, err := r.Reader.Read(b[:readLen]) n, err := r.Reader.Read(b[:readLen])
if err != nil { if err != nil {
return 0, err return 0, err
} }
r.leftBytes -= m r.left -= n
return m, err
return n, err
} }

View File

@ -1,7 +1,6 @@
package vmess package vmess
import ( import (
"bytes"
"crypto/aes" "crypto/aes"
"crypto/cipher" "crypto/cipher"
"crypto/hmac" "crypto/hmac"
@ -15,6 +14,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/nadoo/glider/common/pool"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
) )
@ -108,7 +108,7 @@ func NewClient(uuidStr, security string, alterID int) (*Client, error) {
// NewConn returns a new vmess conn. // NewConn returns a new vmess conn.
func (c *Client) NewConn(rc net.Conn, target string) (*Conn, error) { func (c *Client) NewConn(rc net.Conn, target string) (*Conn, error) {
r := rand.Intn(c.count) r := rand.Intn(c.count)
conn := &Conn{user: c.users[r], opt: c.opt, security: c.security} conn := &Conn{user: c.users[r], opt: c.opt, security: c.security, Conn: rc}
var err error var err error
conn.atyp, conn.addr, conn.port, err = ParseAddr(target) conn.atyp, conn.addr, conn.port, err = ParseAddr(target)
@ -116,51 +116,50 @@ func (c *Client) NewConn(rc net.Conn, target string) (*Conn, error) {
return nil, err return nil, err
} }
randBytes := make([]byte, 33) randBytes := pool.GetBuffer(32)
rand.Read(randBytes) rand.Read(randBytes)
copy(conn.reqBodyIV[:], randBytes[:16]) copy(conn.reqBodyIV[:], randBytes[:16])
copy(conn.reqBodyKey[:], randBytes[16:32]) copy(conn.reqBodyKey[:], randBytes[16:32])
conn.reqRespV = randBytes[32] pool.PutBuffer(randBytes)
conn.reqRespV = byte(rand.Intn(1 << 8))
conn.respBodyIV = md5.Sum(conn.reqBodyIV[:]) conn.respBodyIV = md5.Sum(conn.reqBodyIV[:])
conn.respBodyKey = md5.Sum(conn.reqBodyKey[:]) conn.respBodyKey = md5.Sum(conn.reqBodyKey[:])
// AuthInfo // Auth
_, err = rc.Write(conn.EncodeAuthInfo()) err = conn.Auth()
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Request // Request
req, err := conn.EncodeRequest() err = conn.Request()
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, err = rc.Write(req)
if err != nil {
return nil, err
}
conn.Conn = rc
return conn, nil return conn, nil
} }
// EncodeAuthInfo returns HMAC("md5", UUID, UTC) result // Auth send auth info: HMAC("md5", UUID, UTC)
func (c *Conn) EncodeAuthInfo() []byte { func (c *Conn) Auth() error {
ts := make([]byte, 8) ts := pool.GetBuffer(8)
defer pool.PutBuffer(ts)
binary.BigEndian.PutUint64(ts, uint64(time.Now().UTC().Unix())) binary.BigEndian.PutUint64(ts, uint64(time.Now().UTC().Unix()))
h := hmac.New(md5.New, c.user.UUID[:]) h := hmac.New(md5.New, c.user.UUID[:])
h.Write(ts) h.Write(ts)
return h.Sum(nil) _, err := c.Conn.Write(h.Sum(nil))
return err
} }
// EncodeRequest encodes requests to network bytes. // Request sends request to server.
func (c *Conn) EncodeRequest() ([]byte, error) { func (c *Conn) Request() error {
var buf bytes.Buffer buf := pool.GetWriteBuffer()
defer pool.PutWriteBuffer(buf)
// Request // Request
buf.WriteByte(1) // Ver buf.WriteByte(1) // Ver
@ -178,9 +177,9 @@ func (c *Conn) EncodeRequest() ([]byte, error) {
buf.WriteByte(CmdTCP) // cmd buf.WriteByte(CmdTCP) // cmd
// target // target
err := binary.Write(&buf, binary.BigEndian, uint16(c.port)) // port err := binary.Write(buf, binary.BigEndian, uint16(c.port)) // port
if err != nil { if err != nil {
return nil, err return err
} }
buf.WriteByte(byte(c.atyp)) // atyp buf.WriteByte(byte(c.atyp)) // atyp
@ -188,28 +187,31 @@ func (c *Conn) EncodeRequest() ([]byte, error) {
// padding // padding
if paddingLen > 0 { if paddingLen > 0 {
padding := make([]byte, paddingLen) padding := pool.GetBuffer(paddingLen)
rand.Read(padding) rand.Read(padding)
buf.Write(padding) buf.Write(padding)
pool.PutBuffer(padding)
} }
// F // F
fnv1a := fnv.New32a() fnv1a := fnv.New32a()
_, err = fnv1a.Write(buf.Bytes()) _, err = fnv1a.Write(buf.Bytes())
if err != nil { if err != nil {
return nil, err return err
} }
buf.Write(fnv1a.Sum(nil)) buf.Write(fnv1a.Sum(nil))
block, err := aes.NewCipher(c.user.CmdKey[:]) block, err := aes.NewCipher(c.user.CmdKey[:])
if err != nil { if err != nil {
return nil, err return err
} }
stream := cipher.NewCFBEncrypter(block, TimestampHash(time.Now().UTC())) stream := cipher.NewCFBEncrypter(block, TimestampHash(time.Now().UTC()))
stream.XORKeyStream(buf.Bytes(), buf.Bytes()) stream.XORKeyStream(buf.Bytes(), buf.Bytes())
return buf.Bytes(), nil _, err = c.Conn.Write(buf.Bytes())
return err
} }
// DecodeRespHeader decodes response header. // DecodeRespHeader decodes response header.
@ -221,20 +223,22 @@ func (c *Conn) DecodeRespHeader() error {
stream := cipher.NewCFBDecrypter(block, c.respBodyIV[:]) stream := cipher.NewCFBDecrypter(block, c.respBodyIV[:])
buf := make([]byte, 4) b := pool.GetBuffer(4)
_, err = io.ReadFull(c.Conn, buf) defer pool.PutBuffer(b)
_, err = io.ReadFull(c.Conn, b)
if err != nil { if err != nil {
return err return err
} }
stream.XORKeyStream(buf, buf) stream.XORKeyStream(b, b)
if buf[0] != c.reqRespV { if b[0] != c.reqRespV {
return errors.New("unexpected response header") return errors.New("unexpected response header")
} }
// TODO: Dynamic port support // TODO: Dynamic port support
if buf[2] != 0 { if b[2] != 0 {
// dataLen := int32(buf[3]) // dataLen := int32(buf[3])
return errors.New("dynamic port is not supported now") return errors.New("dynamic port is not supported now")
} }
@ -259,13 +263,14 @@ func (c *Conn) Write(b []byte) (n int, err error) {
c.dataWriter = AEADWriter(c.Conn, aead, c.reqBodyIV[:]) c.dataWriter = AEADWriter(c.Conn, aead, c.reqBodyIV[:])
case SecurityChacha20Poly1305: case SecurityChacha20Poly1305:
key := make([]byte, 32) key := pool.GetBuffer(32)
t := md5.Sum(c.reqBodyKey[:]) t := md5.Sum(c.reqBodyKey[:])
copy(key, t[:]) copy(key, t[:])
t = md5.Sum(key[:16]) t = md5.Sum(key[:16])
copy(key[16:], t[:]) copy(key[16:], t[:])
aead, _ := chacha20poly1305.New(key) aead, _ := chacha20poly1305.New(key)
c.dataWriter = AEADWriter(c.Conn, aead, c.reqBodyIV[:]) c.dataWriter = AEADWriter(c.Conn, aead, c.reqBodyIV[:])
pool.PutBuffer(key)
} }
} }
@ -294,13 +299,14 @@ func (c *Conn) Read(b []byte) (n int, err error) {
c.dataReader = AEADReader(c.Conn, aead, c.respBodyIV[:]) c.dataReader = AEADReader(c.Conn, aead, c.respBodyIV[:])
case SecurityChacha20Poly1305: case SecurityChacha20Poly1305:
key := make([]byte, 32) key := pool.GetBuffer(32)
t := md5.Sum(c.respBodyKey[:]) t := md5.Sum(c.respBodyKey[:])
copy(key, t[:]) copy(key, t[:])
t = md5.Sum(key[:16]) t = md5.Sum(key[:16])
copy(key[16:], t[:]) copy(key[16:], t[:])
aead, _ := chacha20poly1305.New(key) aead, _ := chacha20poly1305.New(key)
c.dataReader = AEADReader(c.Conn, aead, c.respBodyIV[:]) c.dataReader = AEADReader(c.Conn, aead, c.respBodyIV[:])
pool.PutBuffer(key)
} }
} }

View File

@ -8,6 +8,8 @@ import (
"errors" "errors"
"strings" "strings"
"time" "time"
"github.com/nadoo/glider/common/pool"
) )
// User of vmess client // User of vmess client
@ -73,10 +75,11 @@ func GetKey(uuid [16]byte) []byte {
// TimestampHash returns the iv of AES-128-CFB encrypter // TimestampHash returns the iv of AES-128-CFB encrypter
// IVMD5(X + X + X + X)X = []byte(timestamp.now) (8 bytes, Big Endian) // IVMD5(X + X + X + X)X = []byte(timestamp.now) (8 bytes, Big Endian)
func TimestampHash(t time.Time) []byte { func TimestampHash(t time.Time) []byte {
md5hash := md5.New() ts := pool.GetBuffer(8)
defer pool.PutBuffer(ts)
ts := make([]byte, 8)
binary.BigEndian.PutUint64(ts, uint64(t.UTC().Unix())) binary.BigEndian.PutUint64(ts, uint64(t.UTC().Unix()))
md5hash := md5.New()
md5hash.Write(ts) md5hash.Write(ts)
md5hash.Write(ts) md5hash.Write(ts)
md5hash.Write(ts) md5hash.Write(ts)

View File

@ -2,7 +2,6 @@ package ws
import ( import (
"bufio" "bufio"
"bytes"
"crypto/rand" "crypto/rand"
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
@ -11,6 +10,8 @@ import (
"net" "net"
"net/textproto" "net/textproto"
"strings" "strings"
"github.com/nadoo/glider/common/pool"
) )
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
@ -47,7 +48,9 @@ func (c *Client) NewConn(rc net.Conn, target string) (*Conn, error) {
func (c *Conn) Handshake(host, path string) error { func (c *Conn) Handshake(host, path string) error {
clientKey := generateClientKey() clientKey := generateClientKey()
var buf bytes.Buffer buf := pool.GetWriteBuffer()
defer pool.PutWriteBuffer(buf)
buf.WriteString("GET " + path + " HTTP/1.1\r\n") buf.WriteString("GET " + path + " HTTP/1.1\r\n")
buf.WriteString("Host: " + host + "\r\n") buf.WriteString("Host: " + host + "\r\n")
buf.WriteString("Upgrade: websocket\r\n") buf.WriteString("Upgrade: websocket\r\n")