glider/vendor/github.com/nadoo/go-shadowsocks2/shadowstream/stream.go

172 lines
3.0 KiB
Go
Raw Normal View History

2018-09-02 22:16:24 +08:00
package shadowstream
import (
"bytes"
"crypto/cipher"
"crypto/rand"
"io"
"net"
)
// bufSize is the size in bytes (32 KiB) of the scratch buffers allocated for
// the reader and writer wrappers, including those created lazily by conn.
const bufSize = 32 << 10
// writer encrypts everything written to it with a cipher.Stream and forwards
// the ciphertext to the wrapped io.Writer.
type writer struct {
	io.Writer
	cipher.Stream
	buf []byte // scratch buffer holding ciphertext; allocated by the constructor
}

// NewWriter wraps an io.Writer with stream cipher encryption.
func NewWriter(w io.Writer, s cipher.Stream) io.Writer {
	return &writer{Writer: w, Stream: s, buf: make([]byte, bufSize)}
}

// ReadFrom reads plaintext from r until EOF or an error, encrypts it, and
// writes the ciphertext to the underlying writer. As the io.ReaderFrom
// contract specifies, n is the number of bytes read from r, and io.EOF from
// r is not reported as an error.
func (w *writer) ReadFrom(r io.Reader) (n int64, err error) {
	n, _, err = w.pump(r)
	return n, err
}

// Write encrypts b and writes the ciphertext to the underlying writer. The
// returned count is the number of bytes of b that the underlying writer
// accepted, so it is smaller than len(b) whenever an error is returned.
// b itself is never modified; encryption happens in the scratch buffer.
func (w *writer) Write(b []byte) (int, error) {
	_, written, err := w.pump(bytes.NewReader(b))
	return int(written), err
}

// pump moves data from r through the keystream into the underlying writer,
// one scratch buffer at a time. It reports the bytes read from r and the
// bytes accepted by the underlying writer separately, because the two differ
// when a write fails part-way through.
func (w *writer) pump(r io.Reader) (read, written int64, err error) {
	for {
		nr, er := r.Read(w.buf)
		if nr > 0 {
			read += int64(nr)
			chunk := w.buf[:nr]
			w.XORKeyStream(chunk, chunk)
			nw, ew := w.Writer.Write(chunk)
			written += int64(nw)
			if ew == nil && nw != nr {
				// The keystream has already advanced past these bytes, so a
				// short write cannot be retried; surface it the way io.Copy does.
				ew = io.ErrShortWrite
			}
			if ew != nil {
				return read, written, ew
			}
		}
		if er != nil {
			if er == io.EOF { // source drained: success, not a failure
				return read, written, nil
			}
			return read, written, er
		}
	}
}
// reader decrypts everything read from the wrapped io.Reader with a
// cipher.Stream.
type reader struct {
	io.Reader
	cipher.Stream
	buf []byte // scratch buffer used by WriteTo; allocated by the constructor
}

// NewReader wraps an io.Reader with stream cipher decryption.
func NewReader(r io.Reader, s cipher.Stream) io.Reader {
	return &reader{Reader: r, Stream: s, buf: make([]byte, bufSize)}
}

// Read reads ciphertext from the underlying reader into b and decrypts it in
// place. It returns the number of plaintext bytes placed in b and any error
// from the underlying reader.
//
// io.Reader allows data to be returned together with an error (commonly
// io.EOF alongside the final chunk). Such bytes are decrypted and returned
// with the error instead of being discarded. Dropping them would lose data
// and leave the keystream behind the ciphertext position.
func (r *reader) Read(b []byte) (int, error) {
	n, err := r.Reader.Read(b)
	if n > 0 {
		r.XORKeyStream(b[:n], b[:n])
	}
	return n, err
}

// WriteTo decrypts the underlying stream and writes the plaintext to w until
// EOF or an error. It returns the number of bytes written to w; io.EOF from
// the source counts as success, as io.Copy expects when it uses this shortcut.
func (r *reader) WriteTo(w io.Writer) (n int64, err error) {
	for {
		nr, er := r.Read(r.buf)
		if nr > 0 {
			nw, ew := w.Write(r.buf[:nr])
			n += int64(nw)
			if ew == nil && nw != nr {
				// A short count without an error breaks the io.Writer contract;
				// report it (as io.Copy does) rather than silently losing data.
				ew = io.ErrShortWrite
			}
			if ew != nil {
				return n, ew
			}
		}
		if er != nil {
			if er == io.EOF {
				return n, nil
			}
			return n, er
		}
	}
}
// conn is a net.Conn whose outgoing bytes are encrypted and incoming bytes
// decrypted using the embedded Cipher. Both directions are set up lazily:
// the first Read or WriteTo reads the peer's IV from the connection, and the
// first Write or ReadFrom generates and sends a fresh IV.
type conn struct {
	net.Conn
	Cipher
	r *reader // nil until the peer's IV has been read
	w *writer // nil until our IV has been sent
}

// NewConn wraps a stream-oriented net.Conn with stream cipher encryption/decryption.
func NewConn(c net.Conn, ciph Cipher) net.Conn {
	wrapped := &conn{Conn: c, Cipher: ciph}
	return wrapped
}
// initReader sets up decryption on first use: it reads the peer's IV from
// the underlying connection and builds the decrypting reader from it. Once
// c.r exists it does nothing.
func (c *conn) initReader() error {
	if c.r != nil {
		return nil
	}
	scratch := make([]byte, bufSize)
	// NOTE(review): the IV occupies the front of scratch, which then becomes
	// the reader's buffer and is overwritten by later WriteTo calls. This is
	// presumably safe because Decrypter copies the IV (as the crypto/cipher
	// constructors do) — confirm for every Cipher implementation.
	iv := scratch[:c.IVSize()]
	if _, err := io.ReadFull(c.Conn, iv); err != nil {
		return err
	}
	c.r = &reader{Reader: c.Conn, Stream: c.Decrypter(iv), buf: scratch}
	return nil
}

// Read decrypts incoming data into b, consuming the peer's IV first if this
// is the first read.
func (c *conn) Read(b []byte) (int, error) {
	if err := c.initReader(); err != nil {
		return 0, err
	}
	return c.r.Read(b)
}

// WriteTo decrypts the remaining incoming data and writes the plaintext to
// w, consuming the peer's IV first if nothing has been read yet.
func (c *conn) WriteTo(w io.Writer) (int64, error) {
	if err := c.initReader(); err != nil {
		return 0, err
	}
	return c.r.WriteTo(w)
}
// initWriter sets up encryption on first use. It fills a fresh IV from
// crypto/rand, sends it over the underlying connection, and builds the
// encrypting writer from it. Once c.w exists it does nothing.
func (c *conn) initWriter() error {
	if c.w != nil {
		return nil
	}
	scratch := make([]byte, bufSize)
	// NOTE(review): as on the read side, the IV shares storage with the
	// writer's scratch buffer; this relies on Encrypter copying the IV —
	// confirm for every Cipher implementation.
	iv := scratch[:c.IVSize()]
	if _, err := io.ReadFull(rand.Reader, iv); err != nil {
		return err
	}
	if _, err := c.Conn.Write(iv); err != nil {
		return err
	}
	c.w = &writer{Writer: c.Conn, Stream: c.Encrypter(iv), buf: scratch}
	return nil
}

// Write encrypts b and sends it, preceded by our IV if this is the first write.
func (c *conn) Write(b []byte) (int, error) {
	if err := c.initWriter(); err != nil {
		return 0, err
	}
	return c.w.Write(b)
}

// ReadFrom encrypts everything read from r and sends it, preceded by our IV
// if nothing has been written yet.
func (c *conn) ReadFrom(r io.Reader) (int64, error) {
	if err := c.initWriter(); err != nil {
		return 0, err
	}
	return c.w.ReadFrom(r)
}