From d84f540c4da3f12870049e91d417db1efc57de15 Mon Sep 17 00:00:00 2001 From: nadoo <287492+nadoo@users.noreply.github.com> Date: Wed, 11 Jul 2018 00:26:05 +0800 Subject: [PATCH] vmess: add aead chpher --- proxy/vmess/aead.go | 135 ++++++++++++++++++++++++++++++++++++++++++ proxy/vmess/chunk.go | 107 ++++++++++++++++----------------- proxy/vmess/client.go | 99 ++++++++++++++++++++----------- 3 files changed, 255 insertions(+), 86 deletions(-) create mode 100644 proxy/vmess/aead.go diff --git a/proxy/vmess/aead.go b/proxy/vmess/aead.go new file mode 100644 index 0000000..959d285 --- /dev/null +++ b/proxy/vmess/aead.go @@ -0,0 +1,135 @@ +package vmess + +import ( + "bytes" + "crypto/cipher" + "encoding/binary" + "io" +) + +type aeadWriter struct { + io.Writer + cipher.AEAD + nonce []byte + buf []byte + count uint16 + iv []byte +} + +// AEADWriter returns a aead writer +func AEADWriter(w io.Writer, aead cipher.AEAD, iv []byte) io.Writer { + return &aeadWriter{ + Writer: w, + AEAD: aead, + buf: make([]byte, lenSize+maxChunkSize), + nonce: make([]byte, aead.NonceSize()), + count: 0, + iv: iv, + } +} + +func (w *aeadWriter) Write(b []byte) (int, error) { + n, err := w.ReadFrom(bytes.NewBuffer(b)) + return int(n), err +} + +func (w *aeadWriter) ReadFrom(r io.Reader) (n int64, err error) { + for { + buf := w.buf + payloadBuf := buf[lenSize : lenSize+defaultChunkSize] + + nr, er := r.Read(payloadBuf) + if nr > 0 { + n += int64(nr) + buf = buf[:lenSize+nr+w.Overhead()] + payloadBuf = payloadBuf[:nr] + binary.BigEndian.PutUint16(buf[:lenSize], uint16(nr)) + + binary.BigEndian.PutUint16(w.nonce[:2], w.count) + copy(w.nonce[2:], w.iv[2:12]) + + w.Seal(payloadBuf[:0], w.nonce, payloadBuf, nil) + w.count++ + + _, ew := w.Writer.Write(buf) + if ew != nil { + err = ew + break + } + } + + if er != nil { + if er != io.EOF { // ignore EOF as per io.ReaderFrom contract + err = er + } + break + } + } + + return n, err +} + +type aeadReader struct { + io.Reader + cipher.AEAD + nonce []byte + buf []byte + leftover []byte + count uint16 + iv []byte +} + +// AEADReader returns a aead reader +func AEADReader(r io.Reader, aead cipher.AEAD, iv []byte) io.Reader { + return &aeadReader{ + Reader: r, + AEAD: aead, + buf: make([]byte, lenSize+maxChunkSize), + nonce: make([]byte, aead.NonceSize()), + count: 0, + iv: iv, + } +} + +func (r *aeadReader) Read(b []byte) (int, error) { + if len(r.leftover) > 0 { + n := copy(b, r.leftover) + r.leftover = r.leftover[n:] + return n, nil + } + + // get length + _, err := io.ReadFull(r.Reader, r.buf[:lenSize]) + if err != nil { + return 0, err + } + + // if length == 0, then this is the end + len := binary.BigEndian.Uint16(r.buf[:lenSize]) + if len == 0 { + return 0, nil + } + + // get payload + buf := r.buf[:len] + _, err = io.ReadFull(r.Reader, buf) + if err != nil { + return 0, err + } + + binary.BigEndian.PutUint16(r.nonce[:2], r.count) + copy(r.nonce[2:], r.iv[2:12]) + + _, err = r.Open(buf[:0], r.nonce, buf, nil) + r.count++ + if err != nil { + return 0, err + } + + m := copy(b, r.buf[:len]) + if m < int(len) { + r.leftover = r.buf[m:len] + } + + return m, err +} diff --git a/proxy/vmess/chunk.go b/proxy/vmess/chunk.go index 51e93c5..f845799 100644 --- a/proxy/vmess/chunk.go +++ b/proxy/vmess/chunk.go @@ -6,68 +6,22 @@ import ( "io" ) -// chunk: plain, AES-128-CFB, AES-128-GCM, ChaCha20-Poly1305 - const ( + lenSize = 2 maxChunkSize = 1 << 14 // 16384 defaultChunkSize = 1 << 13 // 8192 ) -type chunkedReader struct { - io.Reader - buf []byte - leftover []byte -} - -func newChunkedReader(r io.Reader) io.Reader { - return &chunkedReader{ - Reader: r, - buf: make([]byte, 2+maxChunkSize), - } -} - -func (r *chunkedReader) Read(b []byte) (int, error) { - if len(r.leftover) > 0 { - n := copy(b, r.leftover) - r.leftover = r.leftover[n:] - return n, nil - } - - // get length - _, err := io.ReadFull(r.Reader, r.buf[:2]) - if err != nil { - return 0, err - } - - // if length == 0, then this is the end - len := binary.BigEndian.Uint16(r.buf[:2]) - if len == 0 { - return 0, nil - } - - // get payload - _, err = io.ReadFull(r.Reader, r.buf[:len]) - if err != nil { - return 0, err - } - - m := copy(b, r.buf[:len]) - if m < int(len) { - r.leftover = r.buf[m:len] - } - - return m, err -} - type chunkedWriter struct { io.Writer buf []byte } -func newChunkedWriter(w io.Writer) io.Writer { +// ChunkedWriter returns a chunked writer +func ChunkedWriter(w io.Writer) io.Writer { return &chunkedWriter{ Writer: w, - buf: make([]byte, 2+maxChunkSize), + buf: make([]byte, lenSize+maxChunkSize), } } @@ -79,14 +33,14 @@ func (w *chunkedWriter) Write(b []byte) (int, error) { func (w *chunkedWriter) ReadFrom(r io.Reader) (n int64, err error) { for { buf := w.buf - payloadBuf := buf[2 : 2+defaultChunkSize] + payloadBuf := buf[lenSize : lenSize+defaultChunkSize] nr, er := r.Read(payloadBuf) if nr > 0 { n += int64(nr) - buf = buf[:2+nr] + buf = buf[:lenSize+nr] payloadBuf = payloadBuf[:nr] - binary.BigEndian.PutUint16(buf[:2], uint16(nr)) + binary.BigEndian.PutUint16(buf[:lenSize], uint16(nr)) _, ew := w.Writer.Write(buf) if ew != nil { @@ -105,3 +59,50 @@ func (w *chunkedWriter) ReadFrom(r io.Reader) (n int64, err error) { return n, err } + +type chunkedReader struct { + io.Reader + buf []byte + leftover []byte +} + +// ChunkedReader returns a chunked reader +func ChunkedReader(r io.Reader) io.Reader { + return &chunkedReader{ + Reader: r, + buf: make([]byte, lenSize+maxChunkSize), + } +} + +func (r *chunkedReader) Read(b []byte) (int, error) { + if len(r.leftover) > 0 { + n := copy(b, r.leftover) + r.leftover = r.leftover[n:] + return n, nil + } + + // get length + _, err := io.ReadFull(r.Reader, r.buf[:lenSize]) + if err != nil { + return 0, err + } + + // if length == 0, then this is the end + len := binary.BigEndian.Uint16(r.buf[:lenSize]) + if len == 0 { + return 0, nil + } + + // get payload + _, err = io.ReadFull(r.Reader, r.buf[:len]) + if err != nil { + return 0, err + } + + m := copy(b, r.buf[:len]) + if m < int(len) { + r.leftover = r.buf[m:len] + } + + return m, err +} diff --git a/proxy/vmess/client.go b/proxy/vmess/client.go index 83f5d49..a2dc9b9 100644 --- a/proxy/vmess/client.go +++ b/proxy/vmess/client.go @@ -14,21 +14,20 @@ import ( "net" "strings" "time" + + "golang.org/x/crypto/chacha20poly1305" ) // Request Options const ( - OptBasicFormat byte = 0 - OptChunkStream byte = 1 - OptReuseTCPConnection byte = 2 - OptMetadataObfuscate byte = 4 + OptBasicFormat byte = 0 + OptChunkStream byte = 1 + // OptReuseTCPConnection byte = 2 + // OptMetadataObfuscate byte = 4 ) // Security types const ( - SecurityUnknown byte = 0 // don't use in client - SecurityLegacy byte = 1 // don't use in client (aes-128-cfb) - SecurityAuto byte = 2 // don't use in client SecurityAES128GCM byte = 3 SecurityChacha20Poly1305 byte = 4 SecurityNone byte = 5 @@ -61,12 +60,10 @@ type Conn struct { reqBodyIV [16]byte reqBodyKey [16]byte reqRespV byte - respBodyKey [16]byte respBodyIV [16]byte + respBodyKey [16]byte net.Conn - connected bool - dataReader io.Reader dataWriter io.Writer } @@ -84,19 +81,22 @@ func NewClient(uuidStr, security string, alterID int) (*Client, error) { c.users = append(c.users, user.GenAlterIDUsers(alterID)...) c.count = len(c.users) - // TODO: config? c.opt = OptChunkStream security = strings.ToLower(security) switch security { - case "aes-128-cfb": - c.security = SecurityLegacy case "aes-128-gcm": c.security = SecurityAES128GCM case "chacha20-poly1305": c.security = SecurityChacha20Poly1305 - default: + case "none": c.security = SecurityNone + case "": + // NOTE: use basic format when no method specified + c.opt = OptBasicFormat + c.security = SecurityNone + default: + return nil, errors.New("unknown security type: " + security) } return c, nil @@ -218,33 +218,66 @@ func (c *Conn) DecodeRespHeader() error { return errors.New("dynamic port is not supported now") } - c.connected = true return nil } -func (c *Conn) Read(b []byte) (n int, err error) { - if !c.connected { - c.DecodeRespHeader() +func (c *Conn) Write(b []byte) (n int, err error) { + if c.dataWriter != nil { + return c.dataWriter.Write(b) } - if c.opt&OptChunkStream != 0 { - if c.dataReader == nil { - c.dataReader = newChunkedReader(c.Conn) - } + c.dataWriter = c.Conn + if c.opt&OptChunkStream == OptChunkStream { + switch c.security { + case SecurityNone: + c.dataWriter = ChunkedWriter(c.Conn) + case SecurityAES128GCM: + block, _ := aes.NewCipher(c.reqBodyKey[:]) + aead, _ := cipher.NewGCM(block) + c.dataWriter = AEADWriter(c.Conn, aead, c.reqBodyIV[:]) + + case SecurityChacha20Poly1305: + h := md5.New() + h.Write(c.reqBodyKey[:]) + key := h.Sum(h.Sum(nil)) + aead, _ := chacha20poly1305.New(key) + c.dataWriter = AEADWriter(c.Conn, aead, c.reqBodyIV[:]) + } + } + + return c.dataWriter.Write(b) +} + +func (c *Conn) Read(b []byte) (n int, err error) { + if c.dataReader != nil { return c.dataReader.Read(b) } - return c.Conn.Read(b) -} - -func (c *Conn) Write(b []byte) (n int, err error) { - if c.opt&OptChunkStream != 0 { - if c.dataWriter == nil { - c.dataWriter = newChunkedWriter(c.Conn) - } - - return c.dataWriter.Write(b) + err = c.DecodeRespHeader() + if err != nil { + return 0, err } - return c.Conn.Write(b) + + c.dataReader = c.Conn + if c.opt&OptChunkStream == OptChunkStream { + switch c.security { + case SecurityNone: + c.dataReader = ChunkedReader(c.Conn) + + case SecurityAES128GCM: + block, _ := aes.NewCipher(c.respBodyKey[:]) + aead, _ := cipher.NewGCM(block) + c.dataReader = AEADReader(c.Conn, aead, c.respBodyIV[:]) + + case SecurityChacha20Poly1305: + h := md5.New() + h.Write(c.respBodyKey[:]) + key := h.Sum(h.Sum(nil)) + aead, _ := chacha20poly1305.New(key) + c.dataReader = AEADReader(c.Conn, aead, c.respBodyIV[:]) + } + } + + return c.dataReader.Read(b) }