From 91f41c6dfa8e4370f1ca2570c6db63ec3b02da2d Mon Sep 17 00:00:00 2001 From: mzz2017 Date: Thu, 23 Dec 2021 18:53:40 +0800 Subject: [PATCH] fix(vmess): panic caused by length exceeded --- proxy/vmess/aead.go | 61 ++++++++++++++++++++++----------------------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/proxy/vmess/aead.go b/proxy/vmess/aead.go index 7436c61..632b0b2 100644 --- a/proxy/vmess/aead.go +++ b/proxy/vmess/aead.go @@ -79,54 +79,53 @@ func AEADReader(r io.Reader, aead cipher.AEAD, iv []byte, chunkSizeDecoder Chunk return ar } -func (r *aeadReader) read(p []byte) (int, error) { - if _, err := io.ReadFull(r.Reader, p[:r.chunkSizeDecoder.SizeBytes()]); err != nil { - return 0, err +func (r *aeadReader) readChunkPool() ([]byte, error) { + bSize := pool.GetBuffer(int(r.chunkSizeDecoder.SizeBytes())) + defer pool.PutBuffer(bSize) + if _, err := io.ReadFull(r.Reader, bSize); err != nil { + return nil, err } - size, err := r.chunkSizeDecoder.Decode(p[:r.chunkSizeDecoder.SizeBytes()]) + size, err := r.chunkSizeDecoder.Decode(bSize) if err != nil { - return 0, err + return nil, err } - p = p[:size] - if _, err := io.ReadFull(r.Reader, p); err != nil { - return 0, err + chunk := pool.GetBuffer(int(size)) + if _, err := io.ReadFull(r.Reader, chunk); err != nil { + return nil, err } binary.BigEndian.PutUint16(r.nonce[:2], r.count) - _, err = r.Open(p[:0], r.nonce[:r.NonceSize()], p, nil) + _, err = r.Open(chunk[:0], r.nonce[:r.NonceSize()], chunk, nil) r.count++ if err != nil { - return 0, err + return nil, err } - return int(size) - r.Overhead(), nil + return chunk[:int(size)-r.Overhead()], nil } func (r *aeadReader) Read(p []byte) (int, error) { - if r.buf == nil { - if len(p) >= chunkSize { - return r.read(p) + if r.buf != nil { + n := copy(p, r.buf[r.offset:]) + r.offset += n + if r.offset >= len(r.buf) { + pool.PutBuffer(r.buf) + r.buf = nil } - - buf := pool.GetBuffer(chunkSize) - n, err := r.read(buf) - if err != nil || n == 0 { - pool.PutBuffer(buf) - return 0, err - } - - r.buf = buf[:n] - r.offset = 0 + return n, nil } - - n := copy(p, r.buf[r.offset:]) - r.offset += n - if r.offset == len(r.buf) { - pool.PutBuffer(r.buf) - r.buf = nil + chunk, err := r.readChunkPool() + if err != nil { + return 0, err + } + n := copy(p, chunk) + if len(chunk) > len(p) { + r.buf = chunk + r.offset = len(p) + } else { + pool.PutBuffer(chunk) } - return n, nil }