mirror of
https://github.com/nadoo/glider.git
synced 2025-02-24 17:55:41 +08:00
271 lines
5.5 KiB
Go
271 lines
5.5 KiB
Go
![]() |
package shadowaead
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"crypto/cipher"
|
||
|
"crypto/rand"
|
||
|
"io"
|
||
|
"net"
|
||
|
)
|
||
|
|
||
|
// payloadSizeMask is the maximum size of payload in bytes.
|
||
|
const payloadSizeMask = 0x3FFF // 16*1024 - 1
|
||
|
|
||
|
type writer struct {
|
||
|
io.Writer
|
||
|
cipher.AEAD
|
||
|
nonce []byte
|
||
|
buf []byte
|
||
|
}
|
||
|
|
||
|
// NewWriter wraps an io.Writer with AEAD encryption.
|
||
|
func NewWriter(w io.Writer, aead cipher.AEAD) io.Writer { return newWriter(w, aead) }
|
||
|
|
||
|
func newWriter(w io.Writer, aead cipher.AEAD) *writer {
|
||
|
return &writer{
|
||
|
Writer: w,
|
||
|
AEAD: aead,
|
||
|
buf: make([]byte, 2+aead.Overhead()+payloadSizeMask+aead.Overhead()),
|
||
|
nonce: make([]byte, aead.NonceSize()),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Write encrypts b and writes to the embedded io.Writer.
|
||
|
func (w *writer) Write(b []byte) (int, error) {
|
||
|
n, err := w.ReadFrom(bytes.NewBuffer(b))
|
||
|
return int(n), err
|
||
|
}
|
||
|
|
||
|
// ReadFrom reads from the given io.Reader until EOF or error, encrypts and
|
||
|
// writes to the embedded io.Writer. Returns number of bytes read from r and
|
||
|
// any error encountered.
|
||
|
func (w *writer) ReadFrom(r io.Reader) (n int64, err error) {
|
||
|
for {
|
||
|
buf := w.buf
|
||
|
payloadBuf := buf[2+w.Overhead() : 2+w.Overhead()+payloadSizeMask]
|
||
|
nr, er := r.Read(payloadBuf)
|
||
|
|
||
|
if nr > 0 {
|
||
|
n += int64(nr)
|
||
|
buf = buf[:2+w.Overhead()+nr+w.Overhead()]
|
||
|
payloadBuf = payloadBuf[:nr]
|
||
|
buf[0], buf[1] = byte(nr>>8), byte(nr) // big-endian payload size
|
||
|
w.Seal(buf[:0], w.nonce, buf[:2], nil)
|
||
|
increment(w.nonce)
|
||
|
|
||
|
w.Seal(payloadBuf[:0], w.nonce, payloadBuf, nil)
|
||
|
increment(w.nonce)
|
||
|
|
||
|
_, ew := w.Writer.Write(buf)
|
||
|
if ew != nil {
|
||
|
err = ew
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if er != nil {
|
||
|
if er != io.EOF { // ignore EOF as per io.ReaderFrom contract
|
||
|
err = er
|
||
|
}
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return n, err
|
||
|
}
|
||
|
|
||
|
type reader struct {
|
||
|
io.Reader
|
||
|
cipher.AEAD
|
||
|
nonce []byte
|
||
|
buf []byte
|
||
|
leftover []byte
|
||
|
}
|
||
|
|
||
|
// NewReader wraps an io.Reader with AEAD decryption.
|
||
|
func NewReader(r io.Reader, aead cipher.AEAD) io.Reader { return newReader(r, aead) }
|
||
|
|
||
|
func newReader(r io.Reader, aead cipher.AEAD) *reader {
|
||
|
return &reader{
|
||
|
Reader: r,
|
||
|
AEAD: aead,
|
||
|
buf: make([]byte, payloadSizeMask+aead.Overhead()),
|
||
|
nonce: make([]byte, aead.NonceSize()),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// read and decrypt a record into the internal buffer. Return decrypted payload length and any error encountered.
|
||
|
func (r *reader) read() (int, error) {
|
||
|
// decrypt payload size
|
||
|
buf := r.buf[:2+r.Overhead()]
|
||
|
_, err := io.ReadFull(r.Reader, buf)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
|
||
|
_, err = r.Open(buf[:0], r.nonce, buf, nil)
|
||
|
increment(r.nonce)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
|
||
|
size := (int(buf[0])<<8 + int(buf[1])) & payloadSizeMask
|
||
|
|
||
|
// decrypt payload
|
||
|
buf = r.buf[:size+r.Overhead()]
|
||
|
_, err = io.ReadFull(r.Reader, buf)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
|
||
|
_, err = r.Open(buf[:0], r.nonce, buf, nil)
|
||
|
increment(r.nonce)
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
|
||
|
return size, nil
|
||
|
}
|
||
|
|
||
|
// Read reads from the embedded io.Reader, decrypts and writes to b.
|
||
|
func (r *reader) Read(b []byte) (int, error) {
|
||
|
// copy decrypted bytes (if any) from previous record first
|
||
|
if len(r.leftover) > 0 {
|
||
|
n := copy(b, r.leftover)
|
||
|
r.leftover = r.leftover[n:]
|
||
|
return n, nil
|
||
|
}
|
||
|
|
||
|
n, err := r.read()
|
||
|
m := copy(b, r.buf[:n])
|
||
|
if m < n { // insufficient len(b), keep leftover for next read
|
||
|
r.leftover = r.buf[m:n]
|
||
|
}
|
||
|
return m, err
|
||
|
}
|
||
|
|
||
|
// WriteTo reads from the embedded io.Reader, decrypts and writes to w until
|
||
|
// there's no more data to write or when an error occurs. Return number of
|
||
|
// bytes written to w and any error encountered.
|
||
|
func (r *reader) WriteTo(w io.Writer) (n int64, err error) {
|
||
|
// write decrypted bytes left over from previous record
|
||
|
for len(r.leftover) > 0 {
|
||
|
nw, ew := w.Write(r.leftover)
|
||
|
r.leftover = r.leftover[nw:]
|
||
|
n += int64(nw)
|
||
|
if ew != nil {
|
||
|
return n, ew
|
||
|
}
|
||
|
}
|
||
|
|
||
|
for {
|
||
|
nr, er := r.read()
|
||
|
if nr > 0 {
|
||
|
nw, ew := w.Write(r.buf[:nr])
|
||
|
n += int64(nw)
|
||
|
|
||
|
if ew != nil {
|
||
|
err = ew
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if er != nil {
|
||
|
if er != io.EOF { // ignore EOF as per io.Copy contract (using src.WriteTo shortcut)
|
||
|
err = er
|
||
|
}
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return n, err
|
||
|
}
|
||
|
|
||
|
// increment little-endian encoded unsigned integer b. Wrap around on overflow.
|
||
|
func increment(b []byte) {
|
||
|
for i := range b {
|
||
|
b[i]++
|
||
|
if b[i] != 0 {
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type streamConn struct {
|
||
|
net.Conn
|
||
|
Cipher
|
||
|
r *reader
|
||
|
w *writer
|
||
|
}
|
||
|
|
||
|
func (c *streamConn) initReader() error {
|
||
|
salt := make([]byte, c.SaltSize())
|
||
|
if _, err := io.ReadFull(c.Conn, salt); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
aead, err := c.Decrypter(salt)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
c.r = newReader(c.Conn, aead)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *streamConn) Read(b []byte) (int, error) {
|
||
|
if c.r == nil {
|
||
|
if err := c.initReader(); err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
}
|
||
|
return c.r.Read(b)
|
||
|
}
|
||
|
|
||
|
func (c *streamConn) WriteTo(w io.Writer) (int64, error) {
|
||
|
if c.r == nil {
|
||
|
if err := c.initReader(); err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
}
|
||
|
return c.r.WriteTo(w)
|
||
|
}
|
||
|
|
||
|
func (c *streamConn) initWriter() error {
|
||
|
salt := make([]byte, c.SaltSize())
|
||
|
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
aead, err := c.Encrypter(salt)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
_, err = c.Conn.Write(salt)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
c.w = newWriter(c.Conn, aead)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (c *streamConn) Write(b []byte) (int, error) {
|
||
|
if c.w == nil {
|
||
|
if err := c.initWriter(); err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
}
|
||
|
return c.w.Write(b)
|
||
|
}
|
||
|
|
||
|
func (c *streamConn) ReadFrom(r io.Reader) (int64, error) {
|
||
|
if c.w == nil {
|
||
|
if err := c.initWriter(); err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
}
|
||
|
return c.w.ReadFrom(r)
|
||
|
}
|
||
|
|
||
|
// NewConn wraps a stream-oriented net.Conn with cipher.
|
||
|
func NewConn(c net.Conn, ciph Cipher) net.Conn { return &streamConn{Conn: c, Cipher: ciph} }
|