glider/vendor/github.com/shadowsocks/go-shadowsocks2/shadowaead/stream.go

271 lines
5.5 KiB
Go
Raw Normal View History

2018-07-07 11:07:38 +08:00
package shadowaead
import (
"bytes"
"crypto/cipher"
"crypto/rand"
"io"
"net"
)
// payloadSizeMask is the maximum size of payload in bytes carried by a single
// record. The same value is applied as a bit mask to the decoded length field
// when a record is read, so a decoded length can never exceed it.
const payloadSizeMask = 0x3FFF // 16*1024 - 1
// writer seals data with an AEAD and forwards the sealed records to the
// embedded io.Writer.
type writer struct {
// Destination that receives the sealed records.
io.Writer
// AEAD used to seal both the length field and the payload of every record.
cipher.AEAD
// nonce is passed to Seal and advanced with increment after every Seal call.
nonce []byte
// buf is the scratch buffer in which one record (sealed length followed by
// sealed payload) is assembled before it is written out.
buf []byte
}
// NewWriter returns an io.Writer that wraps w with AEAD encryption using aead.
func NewWriter(w io.Writer, aead cipher.AEAD) io.Writer {
	return newWriter(w, aead)
}
// newWriter builds a writer around w. Its scratch buffer is sized for one
// maximum-length record: a 2-byte length plus AEAD overhead, followed by up to
// payloadSizeMask payload bytes plus AEAD overhead. The nonce starts out as
// aead.NonceSize() zero bytes.
func newWriter(w io.Writer, aead cipher.AEAD) *writer {
	recordCap := 2 + aead.Overhead() + payloadSizeMask + aead.Overhead()

	enc := &writer{Writer: w, AEAD: aead}
	enc.buf = make([]byte, recordCap)
	enc.nonce = make([]byte, aead.NonceSize())
	return enc
}
// Write encrypts b and writes the result to the embedded io.Writer by
// streaming b through ReadFrom. The returned count is the number of bytes of b
// that ReadFrom consumed.
func (w *writer) Write(b []byte) (int, error) {
	consumed, err := w.ReadFrom(bytes.NewReader(b))
	return int(consumed), err
}
// ReadFrom drains r until EOF or an error, turning every non-empty read into
// one record on the embedded io.Writer. A record is the sealed 2-byte
// big-endian payload length followed by the sealed payload; the nonce is
// advanced after each of the two Seal calls. It returns the number of bytes
// read from r and the first error encountered. io.EOF from r is reported as a
// nil error.
func (w *writer) ReadFrom(r io.Reader) (n int64, err error) {
	tagLen := w.Overhead()
	headerLen := 2 + tagLen

	for {
		// Read straight into the payload area so sealing can happen in place.
		payload := w.buf[headerLen : headerLen+payloadSizeMask]
		got, readErr := r.Read(payload)

		if got > 0 {
			n += int64(got)
			record := w.buf[:headerLen+got+tagLen]
			payload = payload[:got]

			// Sealed length field (big-endian payload size).
			record[0], record[1] = byte(got>>8), byte(got)
			w.Seal(record[:0], w.nonce, record[:2], nil)
			increment(w.nonce)

			// Sealed payload, written over the plaintext it was read into.
			w.Seal(payload[:0], w.nonce, payload, nil)
			increment(w.nonce)

			if _, writeErr := w.Writer.Write(record); writeErr != nil {
				return n, writeErr
			}
		}

		if readErr == io.EOF {
			// EOF is the normal way for the loop to end and is not an error.
			return n, nil
		}
		if readErr != nil {
			return n, readErr
		}
	}
}
// reader reads sealed records from the embedded io.Reader and opens them with
// an AEAD.
type reader struct {
// Source of the sealed records.
io.Reader
// AEAD used to open both the length field and the payload of every record.
cipher.AEAD
// nonce is passed to Open and advanced with increment after every Open call.
nonce []byte
// buf receives one sealed chunk at a time and is decrypted in place, so its
// leading bytes hold the plaintext after a successful read.
buf []byte
// leftover is the part of the last decrypted payload (a sub-slice of buf)
// that did not fit into the caller's buffer and is returned by the next Read.
leftover []byte
}
// NewReader returns an io.Reader that wraps r with AEAD decryption using aead.
func NewReader(r io.Reader, aead cipher.AEAD) io.Reader {
	return newReader(r, aead)
}
// newReader builds a reader around r. Its buffer holds one maximum-length
// sealed payload (payloadSizeMask bytes plus AEAD overhead), and the nonce
// starts out as aead.NonceSize() zero bytes.
func newReader(r io.Reader, aead cipher.AEAD) *reader {
	chunkCap := payloadSizeMask + aead.Overhead()

	dec := &reader{Reader: r, AEAD: aead}
	dec.buf = make([]byte, chunkCap)
	dec.nonce = make([]byte, aead.NonceSize())
	return dec
}
// read reads one record from the embedded io.Reader and decrypts it into the
// internal buffer. It returns the decrypted payload length and any error
// encountered; on error the length is 0.
//
// A record consists of two sealed chunks: a 2-byte big-endian payload length,
// then the payload itself. The nonce is advanced after every Open call,
// whether or not that call succeeded.
//
// io.EOF is returned only when the stream ends before any byte of a new
// record was read. A stream that ends after the length chunk but before the
// payload is reported as io.ErrUnexpectedEOF, so callers that treat io.EOF as
// a clean end of stream do not mistake a truncated record for one.
func (r *reader) read() (int, error) {
	// decrypt payload size
	buf := r.buf[:2+r.Overhead()]
	if _, err := io.ReadFull(r.Reader, buf); err != nil {
		return 0, err
	}

	_, err := r.Open(buf[:0], r.nonce, buf, nil)
	increment(r.nonce)
	if err != nil {
		return 0, err
	}

	// Mask the length so it can never exceed the buffer sized in newReader.
	size := (int(buf[0])<<8 + int(buf[1])) & payloadSizeMask

	// decrypt payload
	buf = r.buf[:size+r.Overhead()]
	if _, err := io.ReadFull(r.Reader, buf); err != nil {
		if err == io.EOF {
			// io.ReadFull reports plain EOF when zero bytes were read, but
			// the length chunk already announced this payload, so the
			// stream was cut mid-record.
			err = io.ErrUnexpectedEOF
		}
		return 0, err
	}

	_, err = r.Open(buf[:0], r.nonce, buf, nil)
	increment(r.nonce)
	if err != nil {
		return 0, err
	}

	return size, nil
}
// Read fills b with decrypted data. Bytes left over from a previously
// decrypted record are handed out first, without touching the embedded
// io.Reader. Otherwise one new record is read and decrypted; whatever does not
// fit into b is remembered for the following call.
func (r *reader) Read(b []byte) (int, error) {
	if pending := r.leftover; len(pending) > 0 {
		copied := copy(b, pending)
		r.leftover = pending[copied:]
		return copied, nil
	}

	size, err := r.read()
	plain := r.buf[:size]
	copied := copy(b, plain)
	if copied < size {
		// b was too small for the whole payload; keep the tail for later.
		r.leftover = plain[copied:]
	}
	return copied, err
}
// WriteTo decrypts records from the embedded io.Reader and writes their
// payloads to w until the source is exhausted or an error occurs. Bytes left
// over from an earlier Read are flushed to w first. It returns the number of
// bytes written to w and the first error encountered. io.EOF from the source
// is reported as a nil error.
func (r *reader) WriteTo(w io.Writer) (n int64, err error) {
	// Flush what an earlier Read could not deliver.
	for len(r.leftover) > 0 {
		written, writeErr := w.Write(r.leftover)
		r.leftover = r.leftover[written:]
		n += int64(written)
		if writeErr != nil {
			return n, writeErr
		}
	}

	for {
		size, readErr := r.read()
		if size > 0 {
			written, writeErr := w.Write(r.buf[:size])
			n += int64(written)
			if writeErr != nil {
				return n, writeErr
			}
		}

		if readErr == io.EOF {
			// End of stream is the expected way to finish (io.Copy contract
			// when using the src.WriteTo shortcut).
			return n, nil
		}
		if readErr != nil {
			return n, readErr
		}
	}
}
// increment adds one to b, interpreted as a little-endian unsigned integer.
// The carry propagates from b[0] upwards; when every byte overflows the value
// wraps around to all zeros.
func increment(b []byte) {
	for i := 0; i < len(b); i++ {
		if b[i]++; b[i] != 0 {
			// No carry out of this byte, so the higher bytes are unchanged.
			return
		}
	}
}
// streamConn is a net.Conn whose reads and writes go through AEAD decryption
// and encryption.
type streamConn struct {
// Underlying connection carrying the salts and the sealed records.
net.Conn
// Cipher supplies SaltSize and derives the per-direction AEADs from a salt
// through Encrypter and Decrypter.
Cipher
// r is created by initReader on the first Read or WriteTo; nil until then.
r *reader
// w is created by initWriter on the first Write or ReadFrom; nil until then.
w *writer
}
// initReader reads a salt of c.SaltSize() bytes from the underlying
// connection, asks the cipher for the matching decrypter and installs a record
// reader on c. On any error c.r is left unset.
func (c *streamConn) initReader() error {
	salt := make([]byte, c.SaltSize())
	_, readErr := io.ReadFull(c.Conn, salt)
	if readErr != nil {
		return readErr
	}

	aead, deriveErr := c.Decrypter(salt)
	if deriveErr == nil {
		c.r = newReader(c.Conn, aead)
		return nil
	}
	return deriveErr
}
// Read returns decrypted data from the connection. The decrypting reader is
// set up on first use; if that setup fails, its error is returned and no data
// is read.
func (c *streamConn) Read(b []byte) (int, error) {
	if c.r != nil {
		return c.r.Read(b)
	}
	if err := c.initReader(); err != nil {
		return 0, err
	}
	return c.r.Read(b)
}
// WriteTo copies decrypted data from the connection to w. The decrypting
// reader is set up on first use; if that setup fails, its error is returned
// and nothing is written.
func (c *streamConn) WriteTo(w io.Writer) (int64, error) {
	if c.r != nil {
		return c.r.WriteTo(w)
	}
	if err := c.initReader(); err != nil {
		return 0, err
	}
	return c.r.WriteTo(w)
}
// initWriter generates a random salt of c.SaltSize() bytes, asks the cipher
// for the matching encrypter, sends the salt in the clear over the underlying
// connection and installs a record writer on c. On any error c.w is left
// unset.
func (c *streamConn) initWriter() error {
	salt := make([]byte, c.SaltSize())
	_, randErr := io.ReadFull(rand.Reader, salt)
	if randErr != nil {
		return randErr
	}

	aead, err := c.Encrypter(salt)
	if err != nil {
		return err
	}

	// The peer needs the salt before it can open any record.
	if _, err = c.Conn.Write(salt); err != nil {
		return err
	}

	c.w = newWriter(c.Conn, aead)
	return nil
}
// Write encrypts b and sends it over the connection. The encrypting writer is
// set up on first use; if that setup fails, its error is returned and nothing
// from b is written.
func (c *streamConn) Write(b []byte) (int, error) {
	if c.w != nil {
		return c.w.Write(b)
	}
	if err := c.initWriter(); err != nil {
		return 0, err
	}
	return c.w.Write(b)
}
// ReadFrom encrypts everything read from r and sends it over the connection.
// The encrypting writer is set up on first use; if that setup fails, its error
// is returned and r is not read.
func (c *streamConn) ReadFrom(r io.Reader) (int64, error) {
	if c.w != nil {
		return c.w.ReadFrom(r)
	}
	if err := c.initWriter(); err != nil {
		return 0, err
	}
	return c.w.ReadFrom(r)
}
// NewConn wraps the stream-oriented connection c with ciph. The encrypting and
// decrypting halves are created lazily on the first write and read.
func NewConn(c net.Conn, ciph Cipher) net.Conn {
	conn := &streamConn{Conn: c, Cipher: ciph}
	return conn
}