// MIT License
//
// Copyright (c) 2016-2017 xtaci
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

package smux

import (
	"encoding/binary"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/nadoo/glider/pkg/pool"
)

// Stream is an exported wrapper of stream, kept as a separate type so that
// the garbage collector can reclaim an abandoned stream via a finalizer.
type Stream struct {
	*stream
}
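
// A minimal sketch of the intended GC hook (assumption: the session layer
// registers a finalizer on the exported wrapper, as upstream smux does):
//
//	s := &Stream{stream: newStream(id, frameSize, sess)}
//	runtime.SetFinalizer(s, func(s *Stream) { s.Close() })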

// stream implements net.Conn.
type stream struct {
	id   uint32 // stream identifier
	sess *Session

	buffers [][]byte // the sequential buffers of the stream
	heads   [][]byte // slice heads of the buffers above, kept for recycling

	bufferLock sync.Mutex // protects access to buffers
	frameSize  int        // maximum frame size for the stream

	// notify a read event
	chReadEvent chan struct{}

	// flag that the stream has closed
	die     chan struct{}
	dieOnce sync.Once // ensures the die channel is closed only once

	// FIN command
	chFinEvent   chan struct{}
	finEventOnce sync.Once // ensures chFinEvent is closed only once

	// deadlines
	readDeadline  atomic.Value
	writeDeadline atomic.Value

	// per-stream sliding window control
	numRead    uint32 // number of bytes read
	numWritten uint32 // number of bytes written
	incr       uint32 // bytes read since the last window update was sent

	// UPD command
	peerConsumed uint32        // number of bytes the peer has consumed
	peerWindow   uint32        // peer window, initialized to 256KB, updated by the peer
	chUpdate     chan struct{} // notifies of remote data consumption and window updates
}

// newStream initializes and returns a new stream.
func newStream(id uint32, frameSize int, sess *Session) *stream {
	s := new(stream)
	s.id = id
	s.chReadEvent = make(chan struct{}, 1)
	s.chUpdate = make(chan struct{}, 1)
	s.frameSize = frameSize
	s.sess = sess
	s.die = make(chan struct{})
	s.chFinEvent = make(chan struct{})
	s.peerWindow = initialPeerWindow // set to the initial window size

	return s
}

// ID returns the stream's unique identifier.
func (s *stream) ID() uint32 {
	return s.id
}

// Read reads data from the stream into the provided buffer.
func (s *stream) Read(b []byte) (n int, err error) {
	for {
		n, err = s.tryRead(b)
		if err == ErrWouldBlock {
			if ew := s.waitRead(); ew != nil {
				return 0, ew
			}
		} else {
			return n, err
		}
	}
}
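
// A minimal usage sketch (hypothetical caller; handle is a placeholder):
//
//	buf := make([]byte, 4096)
//	for {
//		n, err := s.Read(buf)
//		if n > 0 {
//			handle(buf[:n])
//		}
//		if err != nil {
//			break // io.EOF once the peer sends FIN and the buffers drain
//		}
//	}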

// tryRead attempts to read data from the stream without blocking.
func (s *stream) tryRead(b []byte) (n int, err error) {
	if s.sess.config.Version == 2 {
		return s.tryReadv2(b)
	}

	if len(b) == 0 {
		return 0, nil
	}

	// a critical section to copy data from the head buffer into b
	s.bufferLock.Lock()
	if len(s.buffers) > 0 {
		n = copy(b, s.buffers[0])
		s.buffers[0] = s.buffers[0][n:]
		if len(s.buffers[0]) == 0 {
			s.buffers[0] = nil
			s.buffers = s.buffers[1:]
			// full recycle
			pool.PutBuffer(s.heads[0])
			s.heads = s.heads[1:]
		}
	}
	s.bufferLock.Unlock()

	if n > 0 {
		s.sess.returnTokens(n)
		return n, nil
	}

	select {
	case <-s.die:
		return 0, io.EOF
	default:
		return 0, ErrWouldBlock
	}
}

// tryReadv2 is the non-blocking version of Read for version 2 streams.
func (s *stream) tryReadv2(b []byte) (n int, err error) {
	if len(b) == 0 {
		return 0, nil
	}

	var notifyConsumed uint32
	s.bufferLock.Lock()
	if len(s.buffers) > 0 {
		n = copy(b, s.buffers[0])
		s.buffers[0] = s.buffers[0][n:]
		if len(s.buffers[0]) == 0 {
			s.buffers[0] = nil
			s.buffers = s.buffers[1:]
			// full recycle
			pool.PutBuffer(s.heads[0])
			s.heads = s.heads[1:]
		}
	}

	// in an ideal environment:
	// if more than half of the buffer has been consumed, send a read ACK to the peer;
	// given the round-trip time of the ACK, continuously flowing data
	// won't slow down waiting for the ACK, as long as the
	// consumer keeps on reading data.
	// e.g. with MaxStreamBuffer = 65536, an update is sent once 32768 bytes
	// have been read since the last update.
	//
	// s.numRead == n implies that this is the initial read
	s.numRead += uint32(n)
	s.incr += uint32(n)

	// for the initial read, send a window update
	if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == uint32(n) {
		notifyConsumed = s.numRead
		s.incr = 0 // reset counting for the next window update
	}
	s.bufferLock.Unlock()

	if n > 0 {
		s.sess.returnTokens(n)

		// send a window update if necessary
		if notifyConsumed > 0 {
			err := s.sendWindowUpdate(notifyConsumed)
			return n, err
		}
		return n, nil
	}

	select {
	case <-s.die:
		return 0, io.EOF
	default:
		return 0, ErrWouldBlock
	}
}

// WriteTo implements io.WriterTo: it calls Write on w in a loop until there
// is no more data to write or an error occurs. The return value n is the
// number of bytes written; any error encountered during the write is also
// returned. If the underlying stream is a v2 stream, it sends window updates
// to the peer when necessary; a v1 stream sends no window updates.
func (s *stream) WriteTo(w io.Writer) (n int64, err error) {
	if s.sess.config.Version == 2 {
		return s.writeTov2(w)
	}

	for {
		var buf []byte
		s.bufferLock.Lock()
		if len(s.buffers) > 0 {
			buf = s.buffers[0]
			s.buffers = s.buffers[1:]
			s.heads = s.heads[1:]
		}
		s.bufferLock.Unlock()

		if buf != nil {
			nw, ew := w.Write(buf)
			// NOTE: WriteTo is a reader, so tokens must be returned here
			s.sess.returnTokens(len(buf))
			pool.PutBuffer(buf)
			if nw > 0 {
				n += int64(nw)
			}

			if ew != nil {
				return n, ew
			}
		} else if ew := s.waitRead(); ew != nil {
			return n, ew
		}
	}
}

// writeTov2 is the v2 counterpart of WriteTo; see the comments there.
func (s *stream) writeTov2(w io.Writer) (n int64, err error) {
	for {
		var notifyConsumed uint32
		var buf []byte
		s.bufferLock.Lock()
		if len(s.buffers) > 0 {
			buf = s.buffers[0]
			s.buffers = s.buffers[1:]
			s.heads = s.heads[1:]
		}
		s.numRead += uint32(len(buf))
		s.incr += uint32(len(buf))
		if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == uint32(len(buf)) {
			notifyConsumed = s.numRead
			s.incr = 0
		}
		s.bufferLock.Unlock()

		if buf != nil {
			nw, ew := w.Write(buf)
			// NOTE: WriteTo is a reader, so tokens must be returned here
			s.sess.returnTokens(len(buf))
			pool.PutBuffer(buf)
			if nw > 0 {
				n += int64(nw)
			}

			if ew != nil {
				return n, ew
			}

			if notifyConsumed > 0 {
				if err := s.sendWindowUpdate(notifyConsumed); err != nil {
					return n, err
				}
			}
		} else if ew := s.waitRead(); ew != nil {
			return n, ew
		}
	}
}

// sendWindowUpdate sends a window update frame to the peer.
func (s *stream) sendWindowUpdate(consumed uint32) error {
	var timer *time.Timer
	var deadline <-chan time.Time
	if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer = time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}

	frame := newFrame(byte(s.sess.config.Version), cmdUPD, s.id)
	var hdr updHeader
	binary.LittleEndian.PutUint32(hdr[:], consumed)
	binary.LittleEndian.PutUint32(hdr[4:], uint32(s.sess.config.MaxStreamBuffer))
	frame.data = hdr[:]
	_, err := s.sess.writeFrameInternal(frame, deadline, CLSCTRL)
	return err
}
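
// The UPD payload is 8 bytes, little-endian: consumed (uint32) followed by
// the advertised window (uint32). A receiving side would decode it roughly
// like this (sketch; the actual decoding lives in the session's recv loop):
//
//	consumed := binary.LittleEndian.Uint32(hdr[:4])
//	window := binary.LittleEndian.Uint32(hdr[4:8])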

// waitRead blocks until a read event occurs or a deadline is reached.
func (s *stream) waitRead() error {
	var timer *time.Timer
	var deadline <-chan time.Time
	if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer = time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}

	select {
	case <-s.chReadEvent: // some data has arrived, or the stream has closed
		return nil
	case <-s.chFinEvent:
		// BUGFIX(xtaci): fix for https://github.com/xtaci/smux/issues/82
		s.bufferLock.Lock()
		defer s.bufferLock.Unlock()
		if len(s.buffers) > 0 {
			return nil
		}
		return io.EOF
	case <-s.sess.chSocketReadError:
		return s.sess.socketReadError.Load().(error)
	case <-s.sess.chProtoError:
		return s.sess.protoError.Load().(error)
	case <-deadline:
		return ErrTimeout
	case <-s.die:
		return io.ErrClosedPipe
	}
}

// Write implements net.Conn.
//
// Note that the behavior when multiple goroutines write concurrently is not
// deterministic: frames may interleave in arbitrary order.
func (s *stream) Write(b []byte) (n int, err error) {
	if s.sess.config.Version == 2 {
		return s.writeV2(b)
	}

	var deadline <-chan time.Time
	if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer := time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}

	// check if the stream has closed
	select {
	case <-s.chFinEvent: // passive closing
		return 0, io.EOF
	case <-s.die:
		return 0, io.ErrClosedPipe
	default:
	}

	// frame split and transmit
	sent := 0
	frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id)
	bts := b
	for len(bts) > 0 {
		sz := len(bts)
		if sz > s.frameSize {
			sz = s.frameSize
		}
		frame.data = bts[:sz]
		bts = bts[sz:]
		n, err := s.sess.writeFrameInternal(frame, deadline, CLSDATA)
		s.numWritten++
		sent += n
		if err != nil {
			return sent, err
		}
	}

	return sent, nil
}
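
// For example (illustrative numbers): with frameSize = 4096, a single
// 10000-byte Write is transmitted as three PSH frames carrying 4096,
// 4096 and 1808 bytes of payload respectively.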

// writeV2 writes data to the stream for version 2 streams.
func (s *stream) writeV2(b []byte) (n int, err error) {
	// check empty input
	if len(b) == 0 {
		return 0, nil
	}

	// check if the stream has closed
	select {
	case <-s.chFinEvent:
		return 0, io.EOF
	case <-s.die:
		return 0, io.ErrClosedPipe
	default:
	}

	// create the write deadline timer
	var deadline <-chan time.Time
	if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer := time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}

	// frame split and transmit process
	sent := 0
	frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id)

	for {
		// per-stream sliding window control:
		// [.... [consumed... numWritten] ... win... ]
		// [.... [consumed...................+rmtwnd]]
		var bts []byte
		// note:
		// even if uint32 overflows, this math still works:
		// eg1: uint32(0) - uint32(math.MaxUint32) = 1
		// eg2: int32(uint32(0) - uint32(1)) = -1
		//
		// basically, you can treat it as MODULAR ARITHMETIC
		inflight := int32(atomic.LoadUint32(&s.numWritten) - atomic.LoadUint32(&s.peerConsumed))
		if inflight < 0 { // security check for malformed data
			return 0, ErrConsumed
		}

		// note that 'win' is calculated in modular arithmetic (mod 2^32, i.e. 4GB)
		win := int32(atomic.LoadUint32(&s.peerWindow)) - inflight

		if win > 0 {
			// determine how many bytes to send
			if win > int32(len(b)) {
				bts = b
				b = nil
			} else {
				bts = b[:win]
				b = b[win:]
			}

			// frame split and transmit
			for len(bts) > 0 {
				// split the frame
				sz := len(bts)
				if sz > s.frameSize {
					sz = s.frameSize
				}
				frame.data = bts[:sz]
				bts = bts[sz:]

				// transmit the frame
				n, err := s.sess.writeFrameInternal(frame, deadline, CLSDATA)
				atomic.AddUint32(&s.numWritten, uint32(sz))
				sent += n
				if err != nil {
					return sent, err
				}
			}
		}

		// if there is any data left to be sent,
		// wait until the stream closes, the window changes, or the deadline is reached.
		// this blocking behavior back-propagates flow control to the upper layer.
		if len(b) > 0 {
			select {
			case <-s.chFinEvent:
				return 0, io.EOF
			case <-s.die:
				return sent, io.ErrClosedPipe
			case <-deadline:
				return sent, ErrTimeout
			case <-s.sess.chSocketWriteError:
				return sent, s.sess.socketWriteError.Load().(error)
			case <-s.chUpdate: // remote data consumed; window updated
				continue
			}
		} else {
			return sent, nil
		}
	}
}
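
// A worked example of the window arithmetic above (illustrative numbers):
// with peerWindow = 262144, numWritten = 300000 and peerConsumed = 100000,
// inflight = 200000 and win = 262144 - 200000 = 62144, so at most 62144
// bytes may be sent before blocking on chUpdate.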

// Close implements net.Conn.
func (s *stream) Close() error {
	var once bool
	var err error
	s.dieOnce.Do(func() {
		close(s.die)
		once = true
	})

	if once {
		// send FIN in order
		f := newFrame(byte(s.sess.config.Version), cmdFIN, s.id)

		timer := time.NewTimer(openCloseTimeout)
		defer timer.Stop()

		_, err = s.sess.writeFrameInternal(f, timer.C, CLSDATA)
		s.sess.streamClosed(s.id)
		return err
	}
	return io.ErrClosedPipe
}

// GetDieCh returns a read-only channel that becomes readable
// when the stream is about to close.
func (s *stream) GetDieCh() <-chan struct{} {
	return s.die
}
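
// A typical non-blocking check (usage sketch):
//
//	select {
//	case <-stream.GetDieCh():
//		// the stream is closing; stop feeding it data
//	default:
//	}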

// SetReadDeadline sets the read deadline as defined by
// net.Conn.SetReadDeadline.
// A zero time value disables the deadline.
func (s *stream) SetReadDeadline(t time.Time) error {
	s.readDeadline.Store(t)
	s.notifyReadEvent()
	return nil
}

// SetWriteDeadline sets the write deadline as defined by
// net.Conn.SetWriteDeadline.
// A zero time value disables the deadline.
func (s *stream) SetWriteDeadline(t time.Time) error {
	s.writeDeadline.Store(t)
	return nil
}

// SetDeadline sets both read and write deadlines as defined by
// net.Conn.SetDeadline.
// A zero time value disables the deadlines.
func (s *stream) SetDeadline(t time.Time) error {
	if err := s.SetReadDeadline(t); err != nil {
		return err
	}
	if err := s.SetWriteDeadline(t); err != nil {
		return err
	}
	return nil
}
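
// Deadline usage sketch, mirroring net.Conn semantics:
//
//	stream.SetDeadline(time.Now().Add(10 * time.Second)) // arm the deadline
//	stream.SetDeadline(time.Time{})                      // disarm it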

// sessionClose marks the stream as dead when the owning session closes.
func (s *stream) sessionClose() { s.dieOnce.Do(func() { close(s.die) }) }

// LocalAddr satisfies the net.Conn interface.
func (s *stream) LocalAddr() net.Addr {
	if ts, ok := s.sess.conn.(interface {
		LocalAddr() net.Addr
	}); ok {
		return ts.LocalAddr()
	}
	return nil
}

// RemoteAddr satisfies the net.Conn interface.
func (s *stream) RemoteAddr() net.Addr {
	if ts, ok := s.sess.conn.(interface {
		RemoteAddr() net.Addr
	}); ok {
		return ts.RemoteAddr()
	}
	return nil
}

// pushBytes appends buf to the stream's buffers.
func (s *stream) pushBytes(buf []byte) (written int, err error) {
	s.bufferLock.Lock()
	s.buffers = append(s.buffers, buf)
	s.heads = append(s.heads, buf)
	s.bufferLock.Unlock()
	return
}

// recycleTokens transforms the remaining bytes into tokens (truncates the buffers).
func (s *stream) recycleTokens() (n int) {
	s.bufferLock.Lock()
	for k := range s.buffers {
		n += len(s.buffers[k])
		pool.PutBuffer(s.heads[k])
	}
	s.buffers = nil
	s.heads = nil
	s.bufferLock.Unlock()
	return
}

// notifyReadEvent signals a read event without blocking.
func (s *stream) notifyReadEvent() {
	select {
	case s.chReadEvent <- struct{}{}:
	default:
	}
}

// update handles a window update (cmdUPD) from the peer.
func (s *stream) update(consumed uint32, window uint32) {
	atomic.StoreUint32(&s.peerConsumed, consumed)
	atomic.StoreUint32(&s.peerWindow, window)
	select {
	case s.chUpdate <- struct{}{}:
	default:
	}
}
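
// Receive-side sketch (assumption: the session's recv loop decodes the
// 8-byte UPD header produced by sendWindowUpdate and forwards it here):
//
//	s.update(binary.LittleEndian.Uint32(hdr[:4]), binary.LittleEndian.Uint32(hdr[4:8]))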

// fin marks this stream as closed at the protocol level (FIN received).
func (s *stream) fin() {
	s.finEventOnce.Do(func() {
		close(s.chFinEvent)
	})
}