From dbd2e04521270fb28a65ff2987ed38a41da6e392 Mon Sep 17 00:00:00 2001
From: nadoo <287492+nadoo@users.noreply.github.com>
Date: Wed, 21 Apr 2021 00:29:17 +0800
Subject: [PATCH] proxy: add smux support
---
README.md | 37 +-
config.go | 8 +-
feature.go | 1 +
go.mod | 2 +-
go.sum | 4 +-
proxy/protocol/smux/LICENSE | 21 +
proxy/protocol/smux/frame.go | 81 ++
proxy/protocol/smux/mux.go | 110 +++
proxy/protocol/smux/mux_test.go | 86 +++
proxy/protocol/smux/session.go | 527 +++++++++++++
proxy/protocol/smux/session_test.go | 1090 +++++++++++++++++++++++++++
proxy/protocol/smux/shaper.go | 16 +
proxy/protocol/smux/shaper_test.go | 30 +
proxy/protocol/smux/stream.go | 545 ++++++++++++++
proxy/smux/client.go | 76 ++
proxy/smux/server.go | 98 +++
16 files changed, 2698 insertions(+), 34 deletions(-)
create mode 100644 proxy/protocol/smux/LICENSE
create mode 100644 proxy/protocol/smux/frame.go
create mode 100644 proxy/protocol/smux/mux.go
create mode 100644 proxy/protocol/smux/mux_test.go
create mode 100644 proxy/protocol/smux/session.go
create mode 100644 proxy/protocol/smux/session_test.go
create mode 100644 proxy/protocol/smux/shaper.go
create mode 100644 proxy/protocol/smux/shaper_test.go
create mode 100644 proxy/protocol/smux/stream.go
create mode 100644 proxy/smux/client.go
create mode 100644 proxy/smux/server.go
diff --git a/README.md b/README.md
index 54b52e1..5d74798 100644
--- a/README.md
+++ b/README.md
@@ -63,6 +63,7 @@ we can set up local listeners as proxy servers, and forward requests to internet
|TLS |√| |√| |transport client & server
|KCP | |√|√| |transport client & server
|Unix |√|√|√|√|transport client & server
+|Smux |√| |√| |transport client & server
|Websocket |√| |√| |transport client & server
|Simple-Obfs | | |√| |transport client only
|Redir |√| | | |linux only
@@ -96,7 +97,7 @@ glider -h
click to see details
```bash
-glider 0.13.0 usage:
+glider 0.14.0 usage:
-check string
check=tcp[://HOST:PORT]: tcp port connect check
check=http://HOST[:PORT][/URI][#expect=STRING_IN_RESP_LINE]
@@ -152,27 +153,10 @@ glider 0.13.0 usage:
forward strategy, default: rr (default "rr")
-verbose
verbose mode
-```
-
-
-run:
-```bash
-glider -config CONFIGPATH
-```
-```bash
-glider -verbose -listen :8443 -forward SCHEME://HOST:PORT
-```
-
-#### Schemes
-
-
-click to see details
-
-```bash
Available schemes:
- listen: mixed ss socks5 http vless trojan trojanc redir redir6 tcp udp tls ws unix kcp
- forward: reject ss socks4 socks5 http ssr ssh vless vmess trojan trojanc tcp udp tls ws unix kcp simple-obfs
+ listen: mixed ss socks5 http vless trojan trojanc redir redir6 tcp udp tls ws unix smux kcp
+ forward: reject ss socks4 socks5 http ssr ssh vless vmess trojan trojanc tcp udp tls ws unix smux kcp simple-obfs
Socks5 scheme:
socks://[user:pass@]host:port
@@ -251,6 +235,9 @@ TLS and Websocket with a specified proxy protocol:
Unix domain socket scheme:
unix://path
+Smux scheme:
+ smux://host:port
+
KCP scheme:
kcp://CRYPT:KEY@host:port[?dataShards=NUM&parityShards=NUM&mode=MODE]
@@ -298,16 +285,8 @@ Config file format(see `./glider.conf.example` as an example):
KEY=VALUE
KEY=VALUE
# KEY equals to command line flag name: listen forward strategy...
-```
-
-
-#### Examples
-
-
-click to see details
-
-```bash
+Examples:
./glider -config glider.conf
-run glider with specified config file.
diff --git a/config.go b/config.go
index d84e137..b2eaaf2 100644
--- a/config.go
+++ b/config.go
@@ -131,8 +131,8 @@ func usage() {
fmt.Fprintf(w, "\n")
fmt.Fprintf(w, "Available schemes:\n")
- fmt.Fprintf(w, " listen: mixed ss socks5 http vless trojan trojanc redir redir6 tcp udp tls ws unix kcp\n")
- fmt.Fprintf(w, " forward: reject ss socks4 socks5 http ssr ssh vless vmess trojan trojanc tcp udp tls ws unix kcp simple-obfs\n")
+ fmt.Fprintf(w, " listen: mixed ss socks5 http vless trojan trojanc redir redir6 tcp udp tls ws unix smux kcp\n")
+ fmt.Fprintf(w, " forward: reject ss socks4 socks5 http ssr ssh vless vmess trojan trojanc tcp udp tls ws unix smux kcp simple-obfs\n")
fmt.Fprintf(w, "\n")
fmt.Fprintf(w, "Socks5 scheme:\n")
@@ -231,6 +231,10 @@ func usage() {
fmt.Fprintf(w, " unix://path\n")
fmt.Fprintf(w, "\n")
+ fmt.Fprintf(w, "Smux scheme:\n")
+ fmt.Fprintf(w, " smux://host:port\n")
+ fmt.Fprintf(w, "\n")
+
fmt.Fprintf(w, "KCP scheme:\n")
fmt.Fprintf(w, " kcp://CRYPT:KEY@host:port[?dataShards=NUM&parityShards=NUM&mode=MODE]\n")
fmt.Fprintf(w, "\n")
diff --git a/feature.go b/feature.go
index 930a157..c4c5b07 100644
--- a/feature.go
+++ b/feature.go
@@ -10,6 +10,7 @@ import (
_ "github.com/nadoo/glider/proxy/mixed"
_ "github.com/nadoo/glider/proxy/obfs"
_ "github.com/nadoo/glider/proxy/reject"
+ _ "github.com/nadoo/glider/proxy/smux"
_ "github.com/nadoo/glider/proxy/socks4"
_ "github.com/nadoo/glider/proxy/socks5"
_ "github.com/nadoo/glider/proxy/ss"
diff --git a/go.mod b/go.mod
index 313b23c..0fe7e0f 100644
--- a/go.mod
+++ b/go.mod
@@ -8,7 +8,7 @@ require (
github.com/dgryski/go-idea v0.0.0-20170306091226-d2fb45a411fb
github.com/dgryski/go-rc2 v0.0.0-20150621095337-8a9021637152
github.com/ebfe/rc2 v0.0.0-20131011165748-24b9757f5521 // indirect
- github.com/insomniacslk/dhcp v0.0.0-20210315110227-c51060810aaa
+ github.com/insomniacslk/dhcp v0.0.0-20210420161629-6bd1ce0fd305
github.com/klauspost/cpuid/v2 v2.0.6 // indirect
github.com/klauspost/reedsolomon v1.9.12 // indirect
github.com/mdlayher/raw v0.0.0-20210412142147-51b895745faf // indirect
diff --git a/go.sum b/go.sum
index 0bb9ef8..847e797 100644
--- a/go.sum
+++ b/go.sum
@@ -39,8 +39,8 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714 h1:/jC7qQFrv8CrSJVmaolDVOxTfS9kc36uB6H40kdbQq8=
github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714/go.mod h1:2Goc3h8EklBH5mspfHFxBnEoURQCGzQQH1ga9Myjvis=
-github.com/insomniacslk/dhcp v0.0.0-20210315110227-c51060810aaa h1:ff6iUOGfajZ+T8RBuRtOjYMW7/6aF8iJG+iCp2wQyQ4=
-github.com/insomniacslk/dhcp v0.0.0-20210315110227-c51060810aaa/go.mod h1:TKl4jN3Voofo4UJIicyNhWGp/nlQqQkFxmwIFTvBkKI=
+github.com/insomniacslk/dhcp v0.0.0-20210420161629-6bd1ce0fd305 h1:DGmCtsdLE6r7tuFH5YHnDoKJU9WXjptF/8Q8iKc41Tk=
+github.com/insomniacslk/dhcp v0.0.0-20210420161629-6bd1ce0fd305/go.mod h1:TKl4jN3Voofo4UJIicyNhWGp/nlQqQkFxmwIFTvBkKI=
github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a/go.mod h1:Oz+70psSo5OFh8DBl0Zv2ACw7Esh6pPUphlvZG9x7uw=
github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4/go.mod h1:WGuG/smIU4J/54PblvSbh+xvCZmpJnFgr3ds6Z55XMQ=
github.com/jsimonetti/rtnetlink v0.0.0-20201009170750-9c6f07d100c1/go.mod h1:hqoO/u39cqLeBLebZ8fWdE96O7FxrAsRYhnVOdgHxok=
diff --git a/proxy/protocol/smux/LICENSE b/proxy/protocol/smux/LICENSE
new file mode 100644
index 0000000..eed41ac
--- /dev/null
+++ b/proxy/protocol/smux/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2016-2017 Daniel Fu
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/proxy/protocol/smux/frame.go b/proxy/protocol/smux/frame.go
new file mode 100644
index 0000000..467a058
--- /dev/null
+++ b/proxy/protocol/smux/frame.go
@@ -0,0 +1,81 @@
+package smux
+
+import (
+ "encoding/binary"
+ "fmt"
+)
+
+const ( // cmds
+ // protocol version 1:
+ cmdSYN byte = iota // stream open
+ cmdFIN // stream close, a.k.a. EOF mark
+ cmdPSH // data push
+ cmdNOP // no operation
+
+ // protocol version 2 extra commands:
+ // notify bytes consumed by the remote peer
+ cmdUPD
+)
+
+const (
+ // data size of cmdUPD, format:
+ // |4B data consumed(ACK)| 4B window size(WINDOW) |
+ szCmdUPD = 8
+)
+
+const (
+ // initial peer window guess, a slow-start
+ initialPeerWindow = 262144
+)
+
+const (
+ sizeOfVer = 1
+ sizeOfCmd = 1
+ sizeOfLength = 2
+ sizeOfSid = 4
+ headerSize = sizeOfVer + sizeOfCmd + sizeOfSid + sizeOfLength
+)
+
+// Frame defines a packet that is multiplexed into, or read back from, a single connection
+type Frame struct {
+ ver byte
+ cmd byte
+ sid uint32
+ data []byte
+}
+
+func newFrame(version byte, cmd byte, sid uint32) Frame {
+ return Frame{ver: version, cmd: cmd, sid: sid}
+}
+
+type rawHeader [headerSize]byte
+
+func (h rawHeader) Version() byte {
+ return h[0]
+}
+
+func (h rawHeader) Cmd() byte {
+ return h[1]
+}
+
+func (h rawHeader) Length() uint16 {
+ return binary.LittleEndian.Uint16(h[2:])
+}
+
+func (h rawHeader) StreamID() uint32 {
+ return binary.LittleEndian.Uint32(h[4:])
+}
+
+func (h rawHeader) String() string {
+ return fmt.Sprintf("Version:%d Cmd:%d StreamID:%d Length:%d",
+ h.Version(), h.Cmd(), h.StreamID(), h.Length())
+}
+
+type updHeader [szCmdUPD]byte
+
+func (h updHeader) Consumed() uint32 {
+ return binary.LittleEndian.Uint32(h[:])
+}
+
+func (h updHeader) Window() uint32 {
+ return binary.LittleEndian.Uint32(h[4:])
+}
diff --git a/proxy/protocol/smux/mux.go b/proxy/protocol/smux/mux.go
new file mode 100644
index 0000000..c0b8ab8
--- /dev/null
+++ b/proxy/protocol/smux/mux.go
@@ -0,0 +1,110 @@
+// Package smux is a multiplexing library for Go.
+//
+// It relies on an underlying connection to provide reliability and ordering, such as TCP or KCP,
+// and provides stream-oriented multiplexing over a single channel.
+package smux
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "math"
+ "time"
+)
+
+// Config is used to tune the Smux session
+type Config struct {
+ // Version is the SMUX protocol version; versions 1 and 2 are supported
+ Version int
+
+ // KeepAliveDisabled disables the keepalive mechanism
+ KeepAliveDisabled bool
+
+ // KeepAliveInterval is how often to send a NOP command to the remote
+ KeepAliveInterval time.Duration
+
+ // KeepAliveTimeout is how long to wait for data to arrive
+ // before the session is closed
+ KeepAliveTimeout time.Duration
+
+ // MaxFrameSize is used to control the maximum
+ // frame size to send to the remote
+ MaxFrameSize int
+
+ // MaxReceiveBuffer is used to control the maximum
+ // amount of data buffered per session
+ MaxReceiveBuffer int
+
+ // MaxStreamBuffer is used to control the maximum
+ // amount of data buffered per stream
+ MaxStreamBuffer int
+}
+
+// DefaultConfig is used to return a default configuration
+func DefaultConfig() *Config {
+ return &Config{
+ Version: 1,
+ KeepAliveInterval: 10 * time.Second,
+ KeepAliveTimeout: 30 * time.Second,
+ MaxFrameSize: 32768,
+ MaxReceiveBuffer: 4194304,
+ MaxStreamBuffer: 65536,
+ }
+}
+
+// VerifyConfig is used to verify the sanity of configuration
+func VerifyConfig(config *Config) error {
+ if !(config.Version == 1 || config.Version == 2) {
+ return errors.New("unsupported protocol version")
+ }
+ if !config.KeepAliveDisabled {
+ if config.KeepAliveInterval == 0 {
+ return errors.New("keep-alive interval must be positive")
+ }
+ if config.KeepAliveTimeout < config.KeepAliveInterval {
+ return fmt.Errorf("keep-alive timeout must not be less than keep-alive interval")
+ }
+ }
+ if config.MaxFrameSize <= 0 {
+ return errors.New("max frame size must be positive")
+ }
+ if config.MaxFrameSize > 65535 {
+ return errors.New("max frame size must not be larger than 65535")
+ }
+ if config.MaxReceiveBuffer <= 0 {
+ return errors.New("max receive buffer must be positive")
+ }
+ if config.MaxStreamBuffer <= 0 {
+ return errors.New("max stream buffer must be positive")
+ }
+ if config.MaxStreamBuffer > config.MaxReceiveBuffer {
+ return errors.New("max stream buffer must not be larger than max receive buffer")
+ }
+ if config.MaxStreamBuffer > math.MaxInt32 {
+ return errors.New("max stream buffer cannot be larger than 2147483647")
+ }
+ return nil
+}
+
+// Server is used to initialize a new server-side connection.
+func Server(conn io.ReadWriteCloser, config *Config) (*Session, error) {
+ if config == nil {
+ config = DefaultConfig()
+ }
+ if err := VerifyConfig(config); err != nil {
+ return nil, err
+ }
+ return newSession(config, conn, false), nil
+}
+
+// Client is used to initialize a new client-side connection.
+func Client(conn io.ReadWriteCloser, config *Config) (*Session, error) {
+ if config == nil {
+ config = DefaultConfig()
+ }
+
+ if err := VerifyConfig(config); err != nil {
+ return nil, err
+ }
+ return newSession(config, conn, true), nil
+}
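+
+// A minimal usage sketch (assuming conn is an established io.ReadWriteCloser
+// such as a net.Conn; error handling elided):
+//
+//	// client side: multiplex streams over conn
+//	sess, _ := Client(conn, DefaultConfig())
+//	stream, _ := sess.OpenStream()
+//	stream.Write([]byte("ping"))
+//
+//	// server side: accept the peer's streams over its own conn
+//	sess, _ := Server(conn, nil) // nil config falls back to DefaultConfig()
+//	stream, _ := sess.AcceptStream()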
diff --git a/proxy/protocol/smux/mux_test.go b/proxy/protocol/smux/mux_test.go
new file mode 100644
index 0000000..dc9c1c1
--- /dev/null
+++ b/proxy/protocol/smux/mux_test.go
@@ -0,0 +1,86 @@
+package smux
+
+import (
+ "bytes"
+ "testing"
+)
+
+type buffer struct {
+ bytes.Buffer
+}
+
+func (b *buffer) Close() error {
+ b.Buffer.Reset()
+ return nil
+}
+
+func TestConfig(t *testing.T) {
+ VerifyConfig(DefaultConfig())
+
+ config := DefaultConfig()
+ config.KeepAliveInterval = 0
+ err := VerifyConfig(config)
+ t.Log(err)
+ if err == nil {
+ t.Fatal(err)
+ }
+
+ config = DefaultConfig()
+ config.KeepAliveInterval = 10
+ config.KeepAliveTimeout = 5
+ err = VerifyConfig(config)
+ t.Log(err)
+ if err == nil {
+ t.Fatal(err)
+ }
+
+ config = DefaultConfig()
+ config.MaxFrameSize = 0
+ err = VerifyConfig(config)
+ t.Log(err)
+ if err == nil {
+ t.Fatal(err)
+ }
+
+ config = DefaultConfig()
+ config.MaxFrameSize = 65536
+ err = VerifyConfig(config)
+ t.Log(err)
+ if err == nil {
+ t.Fatal(err)
+ }
+
+ config = DefaultConfig()
+ config.MaxReceiveBuffer = 0
+ err = VerifyConfig(config)
+ t.Log(err)
+ if err == nil {
+ t.Fatal(err)
+ }
+
+ config = DefaultConfig()
+ config.MaxStreamBuffer = 0
+ err = VerifyConfig(config)
+ t.Log(err)
+ if err == nil {
+ t.Fatal(err)
+ }
+
+ config = DefaultConfig()
+ config.MaxStreamBuffer = 100
+ config.MaxReceiveBuffer = 99
+ err = VerifyConfig(config)
+ t.Log(err)
+ if err == nil {
+ t.Fatal(err)
+ }
+
+ var bts buffer
+ if _, err := Server(&bts, config); err == nil {
+ t.Fatal("server started with wrong config")
+ }
+
+ if _, err := Client(&bts, config); err == nil {
+ t.Fatal("client started with wrong config")
+ }
+}
diff --git a/proxy/protocol/smux/session.go b/proxy/protocol/smux/session.go
new file mode 100644
index 0000000..f8cfee7
--- /dev/null
+++ b/proxy/protocol/smux/session.go
@@ -0,0 +1,527 @@
+package smux
+
+import (
+ "container/heap"
+ "encoding/binary"
+ "errors"
+ "io"
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/nadoo/glider/pool"
+)
+
+const (
+ defaultAcceptBacklog = 1024
+)
+
+var (
+ ErrInvalidProtocol = errors.New("invalid protocol")
+ ErrConsumed = errors.New("peer consumed more than sent")
+ ErrGoAway = errors.New("stream id overflows, should start a new connection")
+ ErrTimeout = errors.New("timeout")
+ ErrWouldBlock = errors.New("operation would block on IO")
+)
+
+type writeRequest struct {
+ prio uint64
+ frame Frame
+ result chan writeResult
+}
+
+type writeResult struct {
+ n int
+ err error
+}
+
+type buffersWriter interface {
+ WriteBuffers(v [][]byte) (n int, err error)
+}
+
+// Session defines a multiplexed connection for streams
+type Session struct {
+ conn io.ReadWriteCloser
+
+ config *Config
+ nextStreamID uint32 // next stream identifier
+ nextStreamIDLock sync.Mutex
+
+ bucket int32 // token bucket
+ bucketNotify chan struct{} // used for waiting for tokens
+
+ streams map[uint32]*Stream // all streams in this session
+ streamLock sync.Mutex // locks streams
+
+ die chan struct{} // flag session has died
+ dieOnce sync.Once
+
+ // socket error handling
+ socketReadError atomic.Value
+ socketWriteError atomic.Value
+ chSocketReadError chan struct{}
+ chSocketWriteError chan struct{}
+ socketReadErrorOnce sync.Once
+ socketWriteErrorOnce sync.Once
+
+ // smux protocol errors
+ protoError atomic.Value
+ chProtoError chan struct{}
+ protoErrorOnce sync.Once
+
+ chAccepts chan *Stream
+
+ dataReady int32 // flag data has arrived
+
+ goAway int32 // flag id exhausted
+
+ deadline atomic.Value
+
+ shaper chan writeRequest // a shaper for writing
+ writes chan writeRequest
+}
+
+func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
+ s := new(Session)
+ s.die = make(chan struct{})
+ s.conn = conn
+ s.config = config
+ s.streams = make(map[uint32]*Stream)
+ s.chAccepts = make(chan *Stream, defaultAcceptBacklog)
+ s.bucket = int32(config.MaxReceiveBuffer)
+ s.bucketNotify = make(chan struct{}, 1)
+ s.shaper = make(chan writeRequest)
+ s.writes = make(chan writeRequest)
+ s.chSocketReadError = make(chan struct{})
+ s.chSocketWriteError = make(chan struct{})
+ s.chProtoError = make(chan struct{})
+
+ if client {
+ s.nextStreamID = 1
+ } else {
+ s.nextStreamID = 0
+ }
+
+ go s.shaperLoop()
+ go s.recvLoop()
+ go s.sendLoop()
+ if !config.KeepAliveDisabled {
+ go s.keepalive()
+ }
+ return s
+}
+
+// OpenStream is used to create a new stream
+func (s *Session) OpenStream() (*Stream, error) {
+ if s.IsClosed() {
+ return nil, io.ErrClosedPipe
+ }
+
+ // generate stream id
+ s.nextStreamIDLock.Lock()
+ if s.goAway > 0 {
+ s.nextStreamIDLock.Unlock()
+ return nil, ErrGoAway
+ }
+
+ s.nextStreamID += 2
+ sid := s.nextStreamID
+ if sid == sid%2 { // stream-id overflowed and wrapped to 0/1
+ s.goAway = 1
+ s.nextStreamIDLock.Unlock()
+ return nil, ErrGoAway
+ }
+ s.nextStreamIDLock.Unlock()
+
+ stream := newStream(sid, s.config.MaxFrameSize, s)
+
+ if _, err := s.writeFrame(newFrame(byte(s.config.Version), cmdSYN, sid)); err != nil {
+ return nil, err
+ }
+
+ s.streamLock.Lock()
+ defer s.streamLock.Unlock()
+ select {
+ case <-s.chSocketReadError:
+ return nil, s.socketReadError.Load().(error)
+ case <-s.chSocketWriteError:
+ return nil, s.socketWriteError.Load().(error)
+ case <-s.die:
+ return nil, io.ErrClosedPipe
+ default:
+ s.streams[sid] = stream
+ return stream, nil
+ }
+}
+
+// Open returns a generic ReadWriteCloser
+func (s *Session) Open() (io.ReadWriteCloser, error) {
+ return s.OpenStream()
+}
+
+// AcceptStream is used to block until the next available stream
+// is ready to be accepted.
+func (s *Session) AcceptStream() (*Stream, error) {
+ var deadline <-chan time.Time
+ if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() {
+ timer := time.NewTimer(time.Until(d))
+ defer timer.Stop()
+ deadline = timer.C
+ }
+
+ select {
+ case stream := <-s.chAccepts:
+ return stream, nil
+ case <-deadline:
+ return nil, ErrTimeout
+ case <-s.chSocketReadError:
+ return nil, s.socketReadError.Load().(error)
+ case <-s.chProtoError:
+ return nil, s.protoError.Load().(error)
+ case <-s.die:
+ return nil, io.ErrClosedPipe
+ }
+}
+
+// Accept returns a generic ReadWriteCloser instead of a smux.Stream
+func (s *Session) Accept() (io.ReadWriteCloser, error) {
+ return s.AcceptStream()
+}
+
+// Close is used to close the session and all streams.
+func (s *Session) Close() error {
+ var once bool
+ s.dieOnce.Do(func() {
+ close(s.die)
+ once = true
+ })
+
+ if once {
+ s.streamLock.Lock()
+ for k := range s.streams {
+ s.streams[k].sessionClose()
+ }
+ s.streamLock.Unlock()
+ return s.conn.Close()
+ } else {
+ return io.ErrClosedPipe
+ }
+}
+
+// notifyBucket notifies recvLoop that bucket is available
+func (s *Session) notifyBucket() {
+ select {
+ case s.bucketNotify <- struct{}{}:
+ default:
+ }
+}
+
+func (s *Session) notifyReadError(err error) {
+ s.socketReadErrorOnce.Do(func() {
+ s.socketReadError.Store(err)
+ close(s.chSocketReadError)
+ })
+}
+
+func (s *Session) notifyWriteError(err error) {
+ s.socketWriteErrorOnce.Do(func() {
+ s.socketWriteError.Store(err)
+ close(s.chSocketWriteError)
+ })
+}
+
+func (s *Session) notifyProtoError(err error) {
+ s.protoErrorOnce.Do(func() {
+ s.protoError.Store(err)
+ close(s.chProtoError)
+ })
+}
+
+// IsClosed does a safe check to see if we have shut down
+func (s *Session) IsClosed() bool {
+ select {
+ case <-s.die:
+ return true
+ default:
+ return false
+ }
+}
+
+// NumStreams returns the number of currently open streams
+func (s *Session) NumStreams() int {
+ if s.IsClosed() {
+ return 0
+ }
+ s.streamLock.Lock()
+ defer s.streamLock.Unlock()
+ return len(s.streams)
+}
+
+// SetDeadline sets a deadline used by Accept* calls.
+// A zero time value disables the deadline.
+func (s *Session) SetDeadline(t time.Time) error {
+ s.deadline.Store(t)
+ return nil
+}
+
+// LocalAddr satisfies net.Conn interface
+func (s *Session) LocalAddr() net.Addr {
+ if ts, ok := s.conn.(interface {
+ LocalAddr() net.Addr
+ }); ok {
+ return ts.LocalAddr()
+ }
+ return nil
+}
+
+// RemoteAddr satisfies net.Conn interface
+func (s *Session) RemoteAddr() net.Addr {
+ if ts, ok := s.conn.(interface {
+ RemoteAddr() net.Addr
+ }); ok {
+ return ts.RemoteAddr()
+ }
+ return nil
+}
+
+// notify the session that a stream has closed
+func (s *Session) streamClosed(sid uint32) {
+ s.streamLock.Lock()
+ if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket
+ if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
+ s.notifyBucket()
+ }
+ }
+ delete(s.streams, sid)
+ s.streamLock.Unlock()
+}
+
+// returnTokens is called by a stream to return tokens after a read
+func (s *Session) returnTokens(n int) {
+ if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
+ s.notifyBucket()
+ }
+}
+
+// recvLoop keeps on reading from underlying connection if tokens are available
+func (s *Session) recvLoop() {
+ var hdr rawHeader
+ var updHdr updHeader
+
+ for {
+ for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() {
+ select {
+ case <-s.bucketNotify:
+ case <-s.die:
+ return
+ }
+ }
+
+ // read header first
+ if _, err := io.ReadFull(s.conn, hdr[:]); err == nil {
+ atomic.StoreInt32(&s.dataReady, 1)
+ if hdr.Version() != byte(s.config.Version) {
+ s.notifyProtoError(ErrInvalidProtocol)
+ return
+ }
+ sid := hdr.StreamID()
+ switch hdr.Cmd() {
+ case cmdNOP:
+ case cmdSYN:
+ s.streamLock.Lock()
+ if _, ok := s.streams[sid]; !ok {
+ stream := newStream(sid, s.config.MaxFrameSize, s)
+ s.streams[sid] = stream
+ select {
+ case s.chAccepts <- stream:
+ case <-s.die:
+ }
+ }
+ s.streamLock.Unlock()
+ case cmdFIN:
+ s.streamLock.Lock()
+ if stream, ok := s.streams[sid]; ok {
+ stream.fin()
+ stream.notifyReadEvent()
+ }
+ s.streamLock.Unlock()
+ case cmdPSH:
+ if hdr.Length() > 0 {
+ newbuf := pool.GetBuffer(int(hdr.Length()))
+ if written, err := io.ReadFull(s.conn, newbuf); err == nil {
+ s.streamLock.Lock()
+ if stream, ok := s.streams[sid]; ok {
+ stream.pushBytes(newbuf)
+ atomic.AddInt32(&s.bucket, -int32(written))
+ stream.notifyReadEvent()
+ }
+ s.streamLock.Unlock()
+ } else {
+ s.notifyReadError(err)
+ return
+ }
+ }
+ case cmdUPD:
+ if _, err := io.ReadFull(s.conn, updHdr[:]); err == nil {
+ s.streamLock.Lock()
+ if stream, ok := s.streams[sid]; ok {
+ stream.update(updHdr.Consumed(), updHdr.Window())
+ }
+ s.streamLock.Unlock()
+ } else {
+ s.notifyReadError(err)
+ return
+ }
+ default:
+ s.notifyProtoError(ErrInvalidProtocol)
+ return
+ }
+ } else {
+ s.notifyReadError(err)
+ return
+ }
+ }
+}
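+
+// note on flow control: recvLoop debits the shared token bucket for every
+// cmdPSH payload it buffers into a stream, Stream reads credit the bucket
+// back through returnTokens, and recvLoop parks on bucketNotify whenever
+// the bucket is exhausted, giving session-wide backpressure.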
+
+func (s *Session) keepalive() {
+ tickerPing := time.NewTicker(s.config.KeepAliveInterval)
+ tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout)
+ defer tickerPing.Stop()
+ defer tickerTimeout.Stop()
+ for {
+ select {
+ case <-tickerPing.C:
+ s.writeFrameInternal(newFrame(byte(s.config.Version), cmdNOP, 0), tickerPing.C, 0)
+ s.notifyBucket() // force a signal to the recvLoop
+ case <-tickerTimeout.C:
+ if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) {
+ // recvLoop may block while bucket is 0, in this case,
+ // session should not be closed.
+ if atomic.LoadInt32(&s.bucket) > 0 {
+ s.Close()
+ return
+ }
+ }
+ case <-s.die:
+ return
+ }
+ }
+}
+
+// shaperLoop shapes the sending sequence among streams
+func (s *Session) shaperLoop() {
+ var reqs shaperHeap
+ var next writeRequest
+ var chWrite chan writeRequest
+
+ for {
+ if len(reqs) > 0 {
+ chWrite = s.writes
+ next = heap.Pop(&reqs).(writeRequest)
+ } else {
+ chWrite = nil
+ }
+
+ select {
+ case <-s.die:
+ return
+ case r := <-s.shaper:
+ if chWrite != nil { // next is valid, reshape
+ heap.Push(&reqs, next)
+ }
+ heap.Push(&reqs, r)
+ case chWrite <- next:
+ }
+ }
+}
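+
+// a nil chWrite above disables the send case: sending on a nil channel
+// blocks forever, so select only offers a frame to sendLoop once one has
+// been popped from the heap; a newly arrived request pushes the popped
+// frame back, keeping the lowest-prio frame first.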
+
+func (s *Session) sendLoop() {
+ var buf []byte
+ var n int
+ var err error
+ var vec [][]byte // vector for writeBuffers
+
+ bw, ok := s.conn.(buffersWriter)
+ if ok {
+ buf = make([]byte, headerSize)
+ vec = make([][]byte, 2)
+ } else {
+ buf = make([]byte, (1<<16)+headerSize)
+ }
+
+ for {
+ select {
+ case <-s.die:
+ return
+ case request := <-s.writes:
+ buf[0] = request.frame.ver
+ buf[1] = request.frame.cmd
+ binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data)))
+ binary.LittleEndian.PutUint32(buf[4:], request.frame.sid)
+
+ if len(vec) > 0 {
+ vec[0] = buf[:headerSize]
+ vec[1] = request.frame.data
+ n, err = bw.WriteBuffers(vec)
+ } else {
+ copy(buf[headerSize:], request.frame.data)
+ n, err = s.conn.Write(buf[:headerSize+len(request.frame.data)])
+ }
+
+ n -= headerSize
+ if n < 0 {
+ n = 0
+ }
+
+ result := writeResult{
+ n: n,
+ err: err,
+ }
+
+ request.result <- result
+ close(request.result)
+
+ // store conn error
+ if err != nil {
+ s.notifyWriteError(err)
+ return
+ }
+ }
+ }
+}
+
+// writeFrame writes the frame to the underlying connection
+// and returns the number of bytes written if successful
+func (s *Session) writeFrame(f Frame) (n int, err error) {
+ return s.writeFrameInternal(f, nil, 0)
+}
+
+// writeFrameInternal is a writeFrame version that supports a deadline, used by keepalive
+func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, prio uint64) (int, error) {
+ req := writeRequest{
+ prio: prio,
+ frame: f,
+ result: make(chan writeResult, 1),
+ }
+ select {
+ case s.shaper <- req:
+ case <-s.die:
+ return 0, io.ErrClosedPipe
+ case <-s.chSocketWriteError:
+ return 0, s.socketWriteError.Load().(error)
+ case <-deadline:
+ return 0, ErrTimeout
+ }
+
+ select {
+ case result := <-req.result:
+ return result.n, result.err
+ case <-s.die:
+ return 0, io.ErrClosedPipe
+ case <-s.chSocketWriteError:
+ return 0, s.socketWriteError.Load().(error)
+ case <-deadline:
+ return 0, ErrTimeout
+ }
+}
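+
+// note: writeFrameInternal is a two-phase handoff: the request is queued
+// into the shaper's priority heap first, then the caller blocks until
+// sendLoop delivers a writeResult, the session dies, the socket write
+// fails, or the deadline fires.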
diff --git a/proxy/protocol/smux/session_test.go b/proxy/protocol/smux/session_test.go
new file mode 100644
index 0000000..3479570
--- /dev/null
+++ b/proxy/protocol/smux/session_test.go
@@ -0,0 +1,1090 @@
+package smux
+
+import (
+ "bytes"
+ crand "crypto/rand"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "log"
+ "math/rand"
+ "net"
+ "net/http"
+ _ "net/http/pprof"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+func init() {
+ go func() {
+ log.Println(http.ListenAndServe("0.0.0.0:6060", nil))
+ }()
+}
+
+// setupServer starts a new server listening on a random localhost port and
+// returns the address of the server, a function to stop the server, and a new
+// client connection to this server, or an error.
+func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn, err error) {
+ ln, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ return "", nil, nil, err
+ }
+ go func() {
+ conn, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ go handleConnection(conn)
+ }()
+ addr = ln.Addr().String()
+ conn, err := net.Dial("tcp", addr)
+ if err != nil {
+ ln.Close()
+ return "", nil, nil, err
+ }
+ return ln.Addr().String(), func() { ln.Close() }, conn, nil
+}
+
+func handleConnection(conn net.Conn) {
+ session, _ := Server(conn, nil)
+ for {
+ if stream, err := session.AcceptStream(); err == nil {
+ go func(s io.ReadWriteCloser) {
+ buf := make([]byte, 65536)
+ for {
+ n, err := s.Read(buf)
+ if err != nil {
+ return
+ }
+ s.Write(buf[:n])
+ }
+ }(stream)
+ } else {
+ return
+ }
+ }
+}
+
+// setupServerV2 starts a new server speaking protocol version 2 on a random
+// localhost port and returns the address of the server, a function to stop
+// the server, and a new client connection to this server, or an error.
+func setupServerV2(tb testing.TB) (addr string, stopfunc func(), client net.Conn, err error) {
+ ln, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ return "", nil, nil, err
+ }
+ go func() {
+ conn, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ go handleConnectionV2(conn)
+ }()
+ addr = ln.Addr().String()
+ conn, err := net.Dial("tcp", addr)
+ if err != nil {
+ ln.Close()
+ return "", nil, nil, err
+ }
+ return ln.Addr().String(), func() { ln.Close() }, conn, nil
+}
+
+func handleConnectionV2(conn net.Conn) {
+ config := DefaultConfig()
+ config.Version = 2
+ session, _ := Server(conn, config)
+ for {
+ if stream, err := session.AcceptStream(); err == nil {
+ go func(s io.ReadWriteCloser) {
+ buf := make([]byte, 65536)
+ for {
+ n, err := s.Read(buf)
+ if err != nil {
+ return
+ }
+ s.Write(buf[:n])
+ }
+ }(stream)
+ } else {
+ return
+ }
+ }
+}
+
+func TestEcho(t *testing.T) {
+ _, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, nil)
+ stream, _ := session.OpenStream()
+ const N = 100
+ buf := make([]byte, 10)
+ var sent string
+ var received string
+ for i := 0; i < N; i++ {
+ msg := fmt.Sprintf("hello%v", i)
+ stream.Write([]byte(msg))
+ sent += msg
+ if n, err := stream.Read(buf); err != nil {
+ t.Fatal(err)
+ } else {
+ received += string(buf[:n])
+ }
+ }
+ if sent != received {
+ t.Fatal("data mimatch")
+ }
+ session.Close()
+}
+
+func TestWriteTo(t *testing.T) {
+ const N = 1 << 20
+ // server
+ ln, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ go func() {
+ conn, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ session, _ := Server(conn, nil)
+ for {
+ if stream, err := session.AcceptStream(); err == nil {
+ go func(s io.ReadWriteCloser) {
+ numBytes := 0
+ buf := make([]byte, 65536)
+ for {
+ n, err := s.Read(buf)
+ if err != nil {
+ return
+ }
+ s.Write(buf[:n])
+ numBytes += n
+
+ if numBytes == N {
+ s.Close()
+ return
+ }
+ }
+ }(stream)
+ } else {
+ return
+ }
+ }
+ }()
+
+ addr := ln.Addr().String()
+ conn, err := net.Dial("tcp", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ // client
+ session, _ := Client(conn, nil)
+ stream, _ := session.OpenStream()
+ sndbuf := make([]byte, N)
+ for i := range sndbuf {
+ sndbuf[i] = byte(rand.Int())
+ }
+
+ go stream.Write(sndbuf)
+
+ var rcvbuf bytes.Buffer
+ nw, ew := stream.WriteTo(&rcvbuf)
+ if ew != io.EOF {
+ t.Fatal(ew)
+ }
+
+ if nw != N {
+ t.Fatal("WriteTo nw mismatch", nw)
+ }
+
+ if !bytes.Equal(sndbuf, rcvbuf.Bytes()) {
+ t.Fatal("mismatched echo bytes")
+ }
+}
+
+func TestWriteToV2(t *testing.T) {
+ config := DefaultConfig()
+ config.Version = 2
+ const N = 1 << 20
+ // server
+ ln, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+
+ go func() {
+ conn, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ session, _ := Server(conn, config)
+ for {
+ if stream, err := session.AcceptStream(); err == nil {
+ go func(s io.ReadWriteCloser) {
+ numBytes := 0
+ buf := make([]byte, 65536)
+ for {
+ n, err := s.Read(buf)
+ if err != nil {
+ return
+ }
+ s.Write(buf[:n])
+ numBytes += n
+
+ if numBytes == N {
+ s.Close()
+ return
+ }
+ }
+ }(stream)
+ } else {
+ return
+ }
+ }
+ }()
+
+ addr := ln.Addr().String()
+ conn, err := net.Dial("tcp", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ // client
+ session, _ := Client(conn, config)
+ stream, _ := session.OpenStream()
+ sndbuf := make([]byte, N)
+ for i := range sndbuf {
+ sndbuf[i] = byte(rand.Int())
+ }
+
+ go stream.Write(sndbuf)
+
+ var rcvbuf bytes.Buffer
+ nw, ew := stream.WriteTo(&rcvbuf)
+ if ew != io.EOF {
+ t.Fatal(ew)
+ }
+
+ if nw != N {
+ t.Fatal("WriteTo nw mismatch", nw)
+ }
+
+ if !bytes.Equal(sndbuf, rcvbuf.Bytes()) {
+ t.Fatal("mismatched echo bytes")
+ }
+}
+
+func TestGetDieCh(t *testing.T) {
+ cs, ss, err := getSmuxStreamPair()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ss.Close()
+ dieCh := ss.GetDieCh()
+ go func() {
+ select {
+ case <-dieCh:
+ case <-time.After(time.Second):
+ t.Error("wait die chan timeout")
+ }
+ }()
+ cs.Close()
+}
+
+func TestSpeed(t *testing.T) {
+ _, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, nil)
+ stream, _ := session.OpenStream()
+ t.Log(stream.LocalAddr(), stream.RemoteAddr())
+
+ start := time.Now()
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ buf := make([]byte, 1024*1024)
+ nrecv := 0
+ for {
+ n, err := stream.Read(buf)
+ if err != nil {
+ t.Error(err)
+ break
+ } else {
+ nrecv += n
+ if nrecv == 4096*4096 {
+ break
+ }
+ }
+ }
+ stream.Close()
+ t.Log("time for 16MB rtt", time.Since(start))
+ wg.Done()
+ }()
+ msg := make([]byte, 8192)
+ for i := 0; i < 2048; i++ {
+ stream.Write(msg)
+ }
+ wg.Wait()
+ session.Close()
+}
+
+func TestParallel(t *testing.T) {
+ _, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, nil)
+
+ par := 1000
+ messages := 100
+ var wg sync.WaitGroup
+ wg.Add(par)
+ for i := 0; i < par; i++ {
+ stream, _ := session.OpenStream()
+ go func(s *Stream) {
+ buf := make([]byte, 20)
+ for j := 0; j < messages; j++ {
+ msg := fmt.Sprintf("hello%v", j)
+ s.Write([]byte(msg))
+ if _, err := s.Read(buf); err != nil {
+ break
+ }
+ }
+ s.Close()
+ wg.Done()
+ }(stream)
+ }
+ t.Log("created", session.NumStreams(), "streams")
+ wg.Wait()
+ session.Close()
+}
+
+func TestParallelV2(t *testing.T) {
+ config := DefaultConfig()
+ config.Version = 2
+ _, stop, cli, err := setupServerV2(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, config)
+
+ par := 1000
+ messages := 100
+ var wg sync.WaitGroup
+ wg.Add(par)
+ for i := 0; i < par; i++ {
+ stream, _ := session.OpenStream()
+ go func(s *Stream) {
+ buf := make([]byte, 20)
+ for j := 0; j < messages; j++ {
+ msg := fmt.Sprintf("hello%v", j)
+ s.Write([]byte(msg))
+ if _, err := s.Read(buf); err != nil {
+ break
+ }
+ }
+ s.Close()
+ wg.Done()
+ }(stream)
+ }
+ t.Log("created", session.NumStreams(), "streams")
+ wg.Wait()
+ session.Close()
+}
+
+func TestCloseThenOpen(t *testing.T) {
+ _, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, nil)
+ session.Close()
+ if _, err := session.OpenStream(); err == nil {
+ t.Fatal("opened after close")
+ }
+}
+
+func TestSessionDoubleClose(t *testing.T) {
+ _, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, nil)
+ session.Close()
+ if err := session.Close(); err == nil {
+ t.Fatal("session double close doesn't return error")
+ }
+}
+
+func TestStreamDoubleClose(t *testing.T) {
+ _, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, nil)
+ stream, _ := session.OpenStream()
+ stream.Close()
+ if err := stream.Close(); err == nil {
+ t.Fatal("stream double close doesn't return error")
+ }
+ session.Close()
+}
+
+func TestConcurrentClose(t *testing.T) {
+ _, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, nil)
+ numStreams := 100
+ streams := make([]*Stream, 0, numStreams)
+ var wg sync.WaitGroup
+ wg.Add(numStreams)
+ for i := 0; i < 100; i++ {
+ stream, _ := session.OpenStream()
+ streams = append(streams, stream)
+ }
+ for _, s := range streams {
+ stream := s
+ go func() {
+ stream.Close()
+ wg.Done()
+ }()
+ }
+ session.Close()
+ wg.Wait()
+}
+
+func TestTinyReadBuffer(t *testing.T) {
+ _, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, nil)
+ stream, _ := session.OpenStream()
+ const N = 100
+ tinybuf := make([]byte, 6)
+ var sent string
+ var received string
+ for i := 0; i < N; i++ {
+ msg := fmt.Sprintf("hello%v", i)
+ sent += msg
+ nsent, err := stream.Write([]byte(msg))
+ if err != nil {
+ t.Fatal("cannot write")
+ }
+ nrecv := 0
+ for nrecv < nsent {
+ if n, err := stream.Read(tinybuf); err == nil {
+ nrecv += n
+ received += string(tinybuf[:n])
+ } else {
+ t.Fatal("cannot read with tiny buffer")
+ }
+ }
+ }
+
+ if sent != received {
+ t.Fatal("data mimatch")
+ }
+ session.Close()
+}
+
+func TestIsClose(t *testing.T) {
+ _, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, nil)
+ session.Close()
+ if !session.IsClosed() {
+ t.Fatal("still open after close")
+ }
+}
+
+func TestKeepAliveTimeout(t *testing.T) {
+ ln, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+ go func() {
+ ln.Accept()
+ }()
+
+ cli, err := net.Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cli.Close()
+
+ config := DefaultConfig()
+ config.KeepAliveInterval = time.Second
+ config.KeepAliveTimeout = 2 * time.Second
+ session, _ := Client(cli, config)
+ time.Sleep(3 * time.Second)
+ if !session.IsClosed() {
+ t.Fatal("keepalive-timeout failed")
+ }
+}
+
+type blockWriteConn struct {
+ net.Conn
+}
+
+func (c *blockWriteConn) Write(b []byte) (n int, err error) {
+ forever := time.Hour * 24
+ time.Sleep(forever)
+ return c.Conn.Write(b)
+}
+
+func TestKeepAliveBlockWriteTimeout(t *testing.T) {
+ ln, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+ go func() {
+ ln.Accept()
+ }()
+
+ cli, err := net.Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cli.Close()
+ // when writeFrame blocks, keepalive in the old version never times out
+ blockWriteCli := &blockWriteConn{cli}
+
+ config := DefaultConfig()
+ config.KeepAliveInterval = time.Second
+ config.KeepAliveTimeout = 2 * time.Second
+ session, _ := Client(blockWriteCli, config)
+ time.Sleep(3 * time.Second)
+ if !session.IsClosed() {
+ t.Fatal("keepalive-timeout failed")
+ }
+}
+
+func TestServerEcho(t *testing.T) {
+ ln, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ln.Close()
+ go func() {
+ err := func() error {
+ conn, err := ln.Accept()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+ session, err := Server(conn, nil)
+ if err != nil {
+ return err
+ }
+ defer session.Close()
+ buf := make([]byte, 10)
+ stream, err := session.OpenStream()
+ if err != nil {
+ return err
+ }
+ defer stream.Close()
+ for i := 0; i < 100; i++ {
+ msg := fmt.Sprintf("hello%v", i)
+ stream.Write([]byte(msg))
+ n, err := stream.Read(buf)
+ if err != nil {
+ return err
+ }
+ if got := string(buf[:n]); got != msg {
+ return fmt.Errorf("got: %q, want: %q", got, msg)
+ }
+ }
+ return nil
+ }()
+ if err != nil {
+ t.Error(err)
+ }
+ }()
+
+ cli, err := net.Dial("tcp", ln.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cli.Close()
+ if session, err := Client(cli, nil); err == nil {
+ if stream, err := session.AcceptStream(); err == nil {
+ buf := make([]byte, 65536)
+ for {
+ n, err := stream.Read(buf)
+ if err != nil {
+ break
+ }
+ stream.Write(buf[:n])
+ }
+ } else {
+ t.Fatal(err)
+ }
+ } else {
+ t.Fatal(err)
+ }
+}
+
+func TestSendWithoutRecv(t *testing.T) {
+ _, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, nil)
+ stream, _ := session.OpenStream()
+ const N = 100
+ for i := 0; i < N; i++ {
+ msg := fmt.Sprintf("hello%v", i)
+ stream.Write([]byte(msg))
+ }
+ buf := make([]byte, 1)
+ if _, err := stream.Read(buf); err != nil {
+ t.Fatal(err)
+ }
+ stream.Close()
+}
+
+func TestWriteAfterClose(t *testing.T) {
+ _, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, nil)
+ stream, _ := session.OpenStream()
+ stream.Close()
+ if _, err := stream.Write([]byte("write after close")); err == nil {
+ t.Fatal("write after close failed")
+ }
+}
+
+func TestReadStreamAfterSessionClose(t *testing.T) {
+ _, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, nil)
+ stream, _ := session.OpenStream()
+ session.Close()
+ buf := make([]byte, 10)
+ if _, err := stream.Read(buf); err != nil {
+ t.Log(err)
+ } else {
+ t.Fatal("read stream after session close succeeded")
+ }
+}
+
+func TestWriteStreamAfterConnectionClose(t *testing.T) {
+ _, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, nil)
+ stream, _ := session.OpenStream()
+ session.conn.Close()
+ if _, err := stream.Write([]byte("write after connection close")); err == nil {
+ t.Fatal("write after connection close failed")
+ }
+}
+
+func TestNumStreamAfterClose(t *testing.T) {
+ _, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, nil)
+ if _, err := session.OpenStream(); err == nil {
+ if session.NumStreams() != 1 {
+ t.Fatal("wrong number of streams after opened")
+ }
+ session.Close()
+ if session.NumStreams() != 0 {
+ t.Fatal("wrong number of streams after session closed")
+ }
+ } else {
+ t.Fatal(err)
+ }
+ cli.Close()
+}
+
+func TestRandomFrame(t *testing.T) {
+ addr, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ // pure random
+ session, _ := Client(cli, nil)
+ for i := 0; i < 100; i++ {
+ rnd := make([]byte, rand.Uint32()%1024)
+ io.ReadFull(crand.Reader, rnd)
+ session.conn.Write(rnd)
+ }
+ cli.Close()
+
+ // double syn
+ cli, err = net.Dial("tcp", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ session, _ = Client(cli, nil)
+ for i := 0; i < 100; i++ {
+ f := newFrame(1, cmdSYN, 1000)
+ session.writeFrame(f)
+ }
+ cli.Close()
+
+ // random cmds
+ cli, err = net.Dial("tcp", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ allcmds := []byte{cmdSYN, cmdFIN, cmdPSH, cmdNOP}
+ session, _ = Client(cli, nil)
+ for i := 0; i < 100; i++ {
+ f := newFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32())
+ session.writeFrame(f)
+ }
+ cli.Close()
+
+ // random cmds & sids
+ cli, err = net.Dial("tcp", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ session, _ = Client(cli, nil)
+ for i := 0; i < 100; i++ {
+ f := newFrame(1, byte(rand.Uint32()), rand.Uint32())
+ session.writeFrame(f)
+ }
+ cli.Close()
+
+ // random version
+ cli, err = net.Dial("tcp", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ session, _ = Client(cli, nil)
+ for i := 0; i < 100; i++ {
+ f := newFrame(1, byte(rand.Uint32()), rand.Uint32())
+ f.ver = byte(rand.Uint32())
+ session.writeFrame(f)
+ }
+ cli.Close()
+
+ // incorrect size
+ cli, err = net.Dial("tcp", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ session, _ = Client(cli, nil)
+
+ f := newFrame(1, byte(rand.Uint32()), rand.Uint32())
+ rnd := make([]byte, rand.Uint32()%1024)
+ io.ReadFull(crand.Reader, rnd)
+ f.data = rnd
+
+ buf := make([]byte, headerSize+len(f.data))
+ buf[0] = f.ver
+ buf[1] = f.cmd
+ binary.LittleEndian.PutUint16(buf[2:], uint16(len(rnd)+1)) // deliberately incorrect size
+ binary.LittleEndian.PutUint32(buf[4:], f.sid)
+ copy(buf[headerSize:], f.data)
+
+ session.conn.Write(buf)
+ cli.Close()
+
+ // writeFrame after die
+ cli, err = net.Dial("tcp", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ session, _ = Client(cli, nil)
+ // close first
+ session.Close()
+ for i := 0; i < 100; i++ {
+ f := newFrame(1, byte(rand.Uint32()), rand.Uint32())
+ session.writeFrame(f)
+ }
+}
+
+func TestWriteFrameInternal(t *testing.T) {
+ addr, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ // pure random
+ session, _ := Client(cli, nil)
+ for i := 0; i < 100; i++ {
+ rnd := make([]byte, rand.Uint32()%1024)
+ io.ReadFull(crand.Reader, rnd)
+ session.conn.Write(rnd)
+ }
+ cli.Close()
+
+ // writeFrame after die
+ cli, err = net.Dial("tcp", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ session, _ = Client(cli, nil)
+ //close first
+ session.Close()
+ for i := 0; i < 100; i++ {
+ f := newFrame(1, byte(rand.Uint32()), rand.Uint32())
+ session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout), 0)
+ }
+
+ // random cmds
+ cli, err = net.Dial("tcp", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ allcmds := []byte{cmdSYN, cmdFIN, cmdPSH, cmdNOP}
+ session, _ = Client(cli, nil)
+ for i := 0; i < 100; i++ {
+ f := newFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32())
+ session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout), 0)
+ }
+ // deadline occurs
+ {
+ c := make(chan time.Time)
+ close(c)
+ f := newFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32())
+ _, err := session.writeFrameInternal(f, c, 0)
+ if !strings.Contains(err.Error(), "timeout") {
+ t.Fatal("write frame with deadline failed", err)
+ }
+ }
+ cli.Close()
+
+ {
+ cli, err = net.Dial("tcp", addr)
+ if err != nil {
+ t.Fatal(err)
+ }
+ config := DefaultConfig()
+ config.KeepAliveInterval = time.Second
+ config.KeepAliveTimeout = 2 * time.Second
+ session, _ = Client(&blockWriteConn{cli}, config)
+ f := newFrame(1, byte(rand.Uint32()), rand.Uint32())
+ c := make(chan time.Time)
+ go func() {
+ // die first, deadline second, better for coverage
+ time.Sleep(time.Second)
+ session.Close()
+ time.Sleep(time.Second)
+ close(c)
+ }()
+ _, err = session.writeFrameInternal(f, c, 0)
+ if !strings.Contains(err.Error(), "closed pipe") {
+ t.Fatal("write frame with to closed conn failed", err)
+ }
+ }
+}
+
+func TestReadDeadline(t *testing.T) {
+ _, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, nil)
+ stream, _ := session.OpenStream()
+ const N = 100
+ buf := make([]byte, 10)
+ var readErr error
+ for i := 0; i < N; i++ {
+ stream.SetReadDeadline(time.Now().Add(-1 * time.Minute))
+ if _, readErr = stream.Read(buf); readErr != nil {
+ break
+ }
+ }
+ if readErr != nil {
+ if !strings.Contains(readErr.Error(), "timeout") {
+ t.Fatalf("Wrong error: %v", readErr)
+ }
+ } else {
+ t.Fatal("No error when reading with past deadline")
+ }
+ session.Close()
+}
+
+func TestWriteDeadline(t *testing.T) {
+ _, stop, cli, err := setupServer(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, nil)
+ stream, _ := session.OpenStream()
+ buf := make([]byte, 10)
+ var writeErr error
+ for {
+ stream.SetWriteDeadline(time.Now().Add(-1 * time.Minute))
+ if _, writeErr = stream.Write(buf); writeErr != nil {
+ if !strings.Contains(writeErr.Error(), "timeout") {
+ t.Fatalf("Wrong error: %v", writeErr)
+ }
+ break
+ }
+ }
+ session.Close()
+}
+
+func BenchmarkAcceptClose(b *testing.B) {
+ _, stop, cli, err := setupServer(b)
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer stop()
+ session, _ := Client(cli, nil)
+ for i := 0; i < b.N; i++ {
+ if stream, err := session.OpenStream(); err == nil {
+ stream.Close()
+ } else {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkConnSmux(b *testing.B) {
+ cs, ss, err := getSmuxStreamPair()
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer cs.Close()
+ defer ss.Close()
+ bench(b, cs, ss)
+}
+
+func BenchmarkConnTCP(b *testing.B) {
+ cs, ss, err := getTCPConnectionPair()
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer cs.Close()
+ defer ss.Close()
+ bench(b, cs, ss)
+}
+
+func getSmuxStreamPair() (*Stream, *Stream, error) {
+ c1, c2, err := getTCPConnectionPair()
+ if err != nil {
+ return nil, nil, err
+ }
+
+ s, err := Server(c2, nil)
+ if err != nil {
+ return nil, nil, err
+ }
+ c, err := Client(c1, nil)
+ if err != nil {
+ return nil, nil, err
+ }
+ var ss *Stream
+ done := make(chan error)
+ go func() {
+ var rerr error
+ ss, rerr = s.AcceptStream()
+ done <- rerr
+ close(done)
+ }()
+ cs, err := c.OpenStream()
+ if err != nil {
+ return nil, nil, err
+ }
+ err = <-done
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return cs, ss, nil
+}
+
+func getTCPConnectionPair() (net.Conn, net.Conn, error) {
+ lst, err := net.Listen("tcp", "localhost:0")
+ if err != nil {
+ return nil, nil, err
+ }
+ defer lst.Close()
+
+ var conn0 net.Conn
+ var err0 error
+ done := make(chan struct{})
+ go func() {
+ conn0, err0 = lst.Accept()
+ close(done)
+ }()
+
+ conn1, err := net.Dial("tcp", lst.Addr().String())
+ if err != nil {
+ return nil, nil, err
+ }
+
+ <-done
+ if err0 != nil {
+ return nil, nil, err0
+ }
+ return conn0, conn1, nil
+}
+
+func bench(b *testing.B, rd io.Reader, wr io.Writer) {
+ buf := make([]byte, 128*1024)
+ buf2 := make([]byte, 128*1024)
+ b.SetBytes(128 * 1024)
+ b.ResetTimer()
+ b.ReportAllocs()
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ count := 0
+ for {
+ n, _ := rd.Read(buf2)
+ count += n
+ if count == 128*1024*b.N {
+ return
+ }
+ }
+ }()
+ for i := 0; i < b.N; i++ {
+ wr.Write(buf)
+ }
+ wg.Wait()
+}
diff --git a/proxy/protocol/smux/shaper.go b/proxy/protocol/smux/shaper.go
new file mode 100644
index 0000000..be03406
--- /dev/null
+++ b/proxy/protocol/smux/shaper.go
@@ -0,0 +1,16 @@
+package smux
+
+type shaperHeap []writeRequest
+
+func (h shaperHeap) Len() int { return len(h) }
+func (h shaperHeap) Less(i, j int) bool { return h[i].prio < h[j].prio }
+func (h shaperHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
+func (h *shaperHeap) Push(x interface{}) { *h = append(*h, x.(writeRequest)) }
+
+func (h *shaperHeap) Pop() interface{} {
+ old := *h
+ n := len(old)
+ x := old[n-1]
+ *h = old[0 : n-1]
+ return x
+}
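+
+// container/heap pops the smallest prio first, so lower prio values are
+// written to the connection earlier; Stream.Write passes its running
+// numWritten counter as prio, which keeps a stream's frames in order while
+// letting streams that have written less data go first.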
diff --git a/proxy/protocol/smux/shaper_test.go b/proxy/protocol/smux/shaper_test.go
new file mode 100644
index 0000000..6f6d265
--- /dev/null
+++ b/proxy/protocol/smux/shaper_test.go
@@ -0,0 +1,30 @@
+package smux
+
+import (
+ "container/heap"
+ "testing"
+)
+
+func TestShaper(t *testing.T) {
+ w1 := writeRequest{prio: 10}
+ w2 := writeRequest{prio: 10}
+ w3 := writeRequest{prio: 20}
+ w4 := writeRequest{prio: 100}
+
+ var reqs shaperHeap
+ heap.Push(&reqs, w4)
+ heap.Push(&reqs, w3)
+ heap.Push(&reqs, w2)
+ heap.Push(&reqs, w1)
+
+ var lastPrio uint64
+ for len(reqs) > 0 {
+ w := heap.Pop(&reqs).(writeRequest)
+ if w.prio < lastPrio {
+ t.Fatal("incorrect shaper priority")
+ }
+
+ t.Log("prio:", w.prio)
+ lastPrio = w.prio
+ }
+}
diff --git a/proxy/protocol/smux/stream.go b/proxy/protocol/smux/stream.go
new file mode 100644
index 0000000..f0ce104
--- /dev/null
+++ b/proxy/protocol/smux/stream.go
@@ -0,0 +1,545 @@
+package smux
+
+import (
+ "encoding/binary"
+ "io"
+ "net"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/nadoo/glider/pool"
+)
+
+// Stream implements net.Conn
+type Stream struct {
+ id uint32
+ sess *Session
+
+ buffers [][]byte
+ heads [][]byte // slice heads kept for recycle
+
+ bufferLock sync.Mutex
+ frameSize int
+
+ // notify a read event
+ chReadEvent chan struct{}
+
+ // flag the stream has closed
+ die chan struct{}
+ dieOnce sync.Once
+
+ // FIN command
+ chFinEvent chan struct{}
+ finEventOnce sync.Once
+
+ // deadlines
+ readDeadline atomic.Value
+ writeDeadline atomic.Value
+
+ // per stream sliding window control
+ numRead uint32 // number of consumed bytes
+ numWritten uint32 // count num of bytes written
+ incr uint32 // counting for sending
+
+ // UPD command
+ peerConsumed uint32 // num of bytes the peer has consumed
+ peerWindow uint32 // peer window, initialized to 256KB, updated by peer
+ chUpdate chan struct{} // notify of remote data consuming and window update
+}
+
+// newStream initializes a Stream struct
+func newStream(id uint32, frameSize int, sess *Session) *Stream {
+ s := new(Stream)
+ s.id = id
+ s.chReadEvent = make(chan struct{}, 1)
+ s.chUpdate = make(chan struct{}, 1)
+ s.frameSize = frameSize
+ s.sess = sess
+ s.die = make(chan struct{})
+ s.chFinEvent = make(chan struct{})
+ s.peerWindow = initialPeerWindow // set to initial window size
+ return s
+}
+
+// ID returns the unique stream ID.
+func (s *Stream) ID() uint32 {
+ return s.id
+}
+
+// Read implements net.Conn
+func (s *Stream) Read(b []byte) (n int, err error) {
+ for {
+ n, err = s.tryRead(b)
+ if err == ErrWouldBlock {
+ if ew := s.waitRead(); ew != nil {
+ return 0, ew
+ }
+ } else {
+ return n, err
+ }
+ }
+}
+
+// tryRead is the nonblocking version of Read
+func (s *Stream) tryRead(b []byte) (n int, err error) {
+ if s.sess.config.Version == 2 {
+ return s.tryReadv2(b)
+ }
+
+ if len(b) == 0 {
+ return 0, nil
+ }
+
+ s.bufferLock.Lock()
+ if len(s.buffers) > 0 {
+ n = copy(b, s.buffers[0])
+ s.buffers[0] = s.buffers[0][n:]
+ if len(s.buffers[0]) == 0 {
+ s.buffers[0] = nil
+ s.buffers = s.buffers[1:]
+ // full recycle
+ pool.PutBuffer(s.heads[0])
+ s.heads = s.heads[1:]
+ }
+ }
+ s.bufferLock.Unlock()
+
+ if n > 0 {
+ s.sess.returnTokens(n)
+ return n, nil
+ }
+
+ select {
+ case <-s.die:
+ return 0, io.EOF
+ default:
+ return 0, ErrWouldBlock
+ }
+}
+
+func (s *Stream) tryReadv2(b []byte) (n int, err error) {
+ if len(b) == 0 {
+ return 0, nil
+ }
+
+ var notifyConsumed uint32
+ s.bufferLock.Lock()
+ if len(s.buffers) > 0 {
+ n = copy(b, s.buffers[0])
+ s.buffers[0] = s.buffers[0][n:]
+ if len(s.buffers[0]) == 0 {
+ s.buffers[0] = nil
+ s.buffers = s.buffers[1:]
+ // full recycle
+ pool.PutBuffer(s.heads[0])
+ s.heads = s.heads[1:]
+ }
+ }
+
+ // in an ideal environment:
+ // if more than half of the buffer has been consumed, send a read ACK to the peer.
+ // based on the round-trip time of the ACK, continuously flowing data
+ // won't slow down waiting for ACKs, as long as the consumer
+ // keeps reading data.
+ // s.numRead == n also notifies the window on the first read.
+ s.numRead += uint32(n)
+ s.incr += uint32(n)
+ if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == uint32(n) {
+ notifyConsumed = s.numRead
+ s.incr = 0
+ }
+ s.bufferLock.Unlock()
+
+ if n > 0 {
+ s.sess.returnTokens(n)
+ if notifyConsumed > 0 {
+ err := s.sendWindowUpdate(notifyConsumed)
+ return n, err
+ } else {
+ return n, nil
+ }
+ }
+
+ select {
+ case <-s.die:
+ return 0, io.EOF
+ default:
+ return 0, ErrWouldBlock
+ }
+}
+
+// WriteTo implements io.WriterTo
+func (s *Stream) WriteTo(w io.Writer) (n int64, err error) {
+ if s.sess.config.Version == 2 {
+ return s.writeTov2(w)
+ }
+
+ for {
+ var buf []byte
+ s.bufferLock.Lock()
+ if len(s.buffers) > 0 {
+ buf = s.buffers[0]
+ s.buffers = s.buffers[1:]
+ s.heads = s.heads[1:]
+ }
+ s.bufferLock.Unlock()
+
+ if buf != nil {
+ nw, ew := w.Write(buf)
+ s.sess.returnTokens(len(buf))
+ pool.PutBuffer(buf)
+ if nw > 0 {
+ n += int64(nw)
+ }
+
+ if ew != nil {
+ return n, ew
+ }
+ } else if ew := s.waitRead(); ew != nil {
+ return n, ew
+ }
+ }
+}
+
+func (s *Stream) writeTov2(w io.Writer) (n int64, err error) {
+ for {
+ var notifyConsumed uint32
+ var buf []byte
+ s.bufferLock.Lock()
+ if len(s.buffers) > 0 {
+ buf = s.buffers[0]
+ s.buffers = s.buffers[1:]
+ s.heads = s.heads[1:]
+ }
+ s.numRead += uint32(len(buf))
+ s.incr += uint32(len(buf))
+ if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == uint32(len(buf)) {
+ notifyConsumed = s.numRead
+ s.incr = 0
+ }
+ s.bufferLock.Unlock()
+
+ if buf != nil {
+ nw, ew := w.Write(buf)
+ s.sess.returnTokens(len(buf))
+ pool.PutBuffer(buf)
+ if nw > 0 {
+ n += int64(nw)
+ }
+
+ if ew != nil {
+ return n, ew
+ }
+
+ if notifyConsumed > 0 {
+ if err := s.sendWindowUpdate(notifyConsumed); err != nil {
+ return n, err
+ }
+ }
+ } else if ew := s.waitRead(); ew != nil {
+ return n, ew
+ }
+ }
+}
+
+func (s *Stream) sendWindowUpdate(consumed uint32) error {
+ var timer *time.Timer
+ var deadline <-chan time.Time
+ if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
+ timer = time.NewTimer(time.Until(d))
+ defer timer.Stop()
+ deadline = timer.C
+ }
+
+ frame := newFrame(byte(s.sess.config.Version), cmdUPD, s.id)
+ var hdr updHeader
+ binary.LittleEndian.PutUint32(hdr[:], consumed)
+ binary.LittleEndian.PutUint32(hdr[4:], uint32(s.sess.config.MaxStreamBuffer))
+ frame.data = hdr[:]
+ _, err := s.sess.writeFrameInternal(frame, deadline, 0)
+ return err
+}
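+
+// (the 8-byte cmdUPD payload built above mirrors updHeader on the receive
+// side: 4 bytes of consumed count followed by 4 bytes of window size, both
+// little-endian)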
+
+func (s *Stream) waitRead() error {
+ var timer *time.Timer
+ var deadline <-chan time.Time
+ if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
+ timer = time.NewTimer(time.Until(d))
+ defer timer.Stop()
+ deadline = timer.C
+ }
+
+ select {
+ case <-s.chReadEvent:
+ return nil
+ case <-s.chFinEvent:
+ return io.EOF
+ case <-s.sess.chSocketReadError:
+ return s.sess.socketReadError.Load().(error)
+ case <-s.sess.chProtoError:
+ return s.sess.protoError.Load().(error)
+ case <-deadline:
+ return ErrTimeout
+ case <-s.die:
+ return io.ErrClosedPipe
+	}
+}
+
+// Write implements net.Conn
+//
+// Note that when multiple goroutines write concurrently, the behavior is
+// not deterministic: frames may interleave in arbitrary order.
+func (s *Stream) Write(b []byte) (n int, err error) {
+ if s.sess.config.Version == 2 {
+ return s.writeV2(b)
+ }
+
+ var deadline <-chan time.Time
+ if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
+ timer := time.NewTimer(time.Until(d))
+ defer timer.Stop()
+ deadline = timer.C
+ }
+
+ // check if stream has closed
+ select {
+ case <-s.die:
+ return 0, io.ErrClosedPipe
+ default:
+ }
+
+ // frame split and transmit
+ sent := 0
+ frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id)
+ bts := b
+ for len(bts) > 0 {
+ sz := len(bts)
+ if sz > s.frameSize {
+ sz = s.frameSize
+ }
+ frame.data = bts[:sz]
+ bts = bts[sz:]
+ n, err := s.sess.writeFrameInternal(frame, deadline, uint64(s.numWritten))
+ s.numWritten++
+ sent += n
+ if err != nil {
+ return sent, err
+ }
+ }
+
+ return sent, nil
+}
+
+func (s *Stream) writeV2(b []byte) (n int, err error) {
+ // check empty input
+ if len(b) == 0 {
+ return 0, nil
+ }
+
+ // check if stream has closed
+ select {
+ case <-s.die:
+ return 0, io.ErrClosedPipe
+ default:
+ }
+
+ // create write deadline timer
+ var deadline <-chan time.Time
+ if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
+ timer := time.NewTimer(time.Until(d))
+ defer timer.Stop()
+ deadline = timer.C
+ }
+
+ // frame split and transmit process
+ sent := 0
+ frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id)
+
+ for {
+ // per stream sliding window control
+ // [.... [consumed... numWritten] ... win... ]
+ // [.... [consumed...................+rmtwnd]]
+ var bts []byte
+		// note:
+		// even if the uint32 counters overflow, this math still works:
+		// e.g. uint32(0) - uint32(math.MaxUint32) = 1
+		// e.g. int32(uint32(0) - uint32(1)) = -1
+		// the check below guards against a misbehaving peer
+ inflight := int32(atomic.LoadUint32(&s.numWritten) - atomic.LoadUint32(&s.peerConsumed))
+ if inflight < 0 {
+ return 0, ErrConsumed
+ }
+
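+		// usable send window: the peer's advertised window minus bytes in flight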
+ win := int32(atomic.LoadUint32(&s.peerWindow)) - inflight
+ if win > 0 {
+ if win > int32(len(b)) {
+ bts = b
+ b = nil
+ } else {
+ bts = b[:win]
+ b = b[win:]
+ }
+
+ for len(bts) > 0 {
+ sz := len(bts)
+ if sz > s.frameSize {
+ sz = s.frameSize
+ }
+ frame.data = bts[:sz]
+ bts = bts[sz:]
+ n, err := s.sess.writeFrameInternal(frame, deadline, uint64(atomic.LoadUint32(&s.numWritten)))
+ atomic.AddUint32(&s.numWritten, uint32(sz))
+ sent += n
+ if err != nil {
+ return sent, err
+ }
+ }
+ }
+
+ // if there is any data remaining to be sent
+ // wait until stream closes, window changes or deadline reached
+ // this blocking behavior will inform upper layer to do flow control
+ if len(b) > 0 {
+ select {
+ case <-s.chFinEvent: // if fin arrived, future window update is impossible
+ return 0, io.EOF
+ case <-s.die:
+ return sent, io.ErrClosedPipe
+ case <-deadline:
+ return sent, ErrTimeout
+ case <-s.sess.chSocketWriteError:
+ return sent, s.sess.socketWriteError.Load().(error)
+ case <-s.chUpdate:
+ continue
+ }
+ } else {
+ return sent, nil
+ }
+ }
+}
+
+// Close implements net.Conn
+func (s *Stream) Close() error {
+ var once bool
+ var err error
+ s.dieOnce.Do(func() {
+ close(s.die)
+ once = true
+ })
+
+ if once {
+ _, err = s.sess.writeFrame(newFrame(byte(s.sess.config.Version), cmdFIN, s.id))
+ s.sess.streamClosed(s.id)
+ return err
+ } else {
+ return io.ErrClosedPipe
+ }
+}
+
+// GetDieCh returns a read-only channel that is closed
+// when the stream is being shut down.
+func (s *Stream) GetDieCh() <-chan struct{} {
+ return s.die
+}
+
+// SetReadDeadline sets the read deadline as defined by
+// net.Conn.SetReadDeadline.
+// A zero time value disables the deadline.
+func (s *Stream) SetReadDeadline(t time.Time) error {
+ s.readDeadline.Store(t)
+ s.notifyReadEvent()
+ return nil
+}
+
+// SetWriteDeadline sets the write deadline as defined by
+// net.Conn.SetWriteDeadline.
+// A zero time value disables the deadline.
+func (s *Stream) SetWriteDeadline(t time.Time) error {
+ s.writeDeadline.Store(t)
+ return nil
+}
+
+// SetDeadline sets both read and write deadlines as defined by
+// net.Conn.SetDeadline.
+// A zero time value disables the deadlines.
+func (s *Stream) SetDeadline(t time.Time) error {
+ if err := s.SetReadDeadline(t); err != nil {
+ return err
+ }
+ if err := s.SetWriteDeadline(t); err != nil {
+ return err
+ }
+ return nil
+}
+
+// sessionClose marks the stream as dead when the underlying session closes
+func (s *Stream) sessionClose() { s.dieOnce.Do(func() { close(s.die) }) }
+
+// LocalAddr satisfies net.Conn interface
+func (s *Stream) LocalAddr() net.Addr {
+ if ts, ok := s.sess.conn.(interface {
+ LocalAddr() net.Addr
+ }); ok {
+ return ts.LocalAddr()
+ }
+ return nil
+}
+
+// RemoteAddr satisfies net.Conn interface
+func (s *Stream) RemoteAddr() net.Addr {
+ if ts, ok := s.sess.conn.(interface {
+ RemoteAddr() net.Addr
+ }); ok {
+ return ts.RemoteAddr()
+ }
+ return nil
+}
+
+// pushBytes appends buf to the receive buffers (heads keeps the original slice for pool recycling)
+func (s *Stream) pushBytes(buf []byte) (written int, err error) {
+ s.bufferLock.Lock()
+ s.buffers = append(s.buffers, buf)
+ s.heads = append(s.heads, buf)
+ s.bufferLock.Unlock()
+ return
+}
+
+// recycleTokens transforms the remaining buffered bytes back into tokens (truncates the buffers)
+func (s *Stream) recycleTokens() (n int) {
+ s.bufferLock.Lock()
+ for k := range s.buffers {
+ n += len(s.buffers[k])
+ pool.PutBuffer(s.heads[k])
+ }
+ s.buffers = nil
+ s.heads = nil
+ s.bufferLock.Unlock()
+ return
+}
+
+// notify read event
+func (s *Stream) notifyReadEvent() {
+ select {
+ case s.chReadEvent <- struct{}{}:
+ default:
+ }
+}
+
+// update handles a window update (cmdUPD) from the peer
+func (s *Stream) update(consumed uint32, window uint32) {
+ atomic.StoreUint32(&s.peerConsumed, consumed)
+ atomic.StoreUint32(&s.peerWindow, window)
+ select {
+ case s.chUpdate <- struct{}{}:
+ default:
+ }
+}
+
+// fin marks this stream as closed at the protocol level
+func (s *Stream) fin() {
+ s.finEventOnce.Do(func() {
+ close(s.chFinEvent)
+ })
+}
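
Taken together, `Stream` satisfies `net.Conn` end to end. A minimal client-side sketch of using it directly (assuming an already-established `net.Conn` named `conn`; as in `client.go` below, a `nil` config selects the library defaults):

```go
package example

import (
	"net"
	"time"

	"github.com/nadoo/glider/proxy/protocol/smux"
)

// openStream multiplexes one logical stream over an existing connection.
func openStream(conn net.Conn) (net.Conn, error) {
	sess, err := smux.Client(conn, nil) // nil config: library defaults
	if err != nil {
		return nil, err
	}
	stream, err := sess.OpenStream()
	if err != nil {
		return nil, err
	}
	// a Stream is a net.Conn, so the usual deadline calls apply
	stream.SetDeadline(time.Now().Add(10 * time.Second))
	return stream, nil
}
```
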
diff --git a/proxy/smux/client.go b/proxy/smux/client.go
new file mode 100644
index 0000000..a586364
--- /dev/null
+++ b/proxy/smux/client.go
@@ -0,0 +1,76 @@
+package smux
+
+import (
+ "errors"
+ "net"
+ "net/url"
+
+ "github.com/nadoo/glider/log"
+ "github.com/nadoo/glider/proxy"
+
+ "github.com/nadoo/glider/proxy/protocol/smux"
+)
+
+// SmuxClient struct.
+type SmuxClient struct {
+ dialer proxy.Dialer
+ addr string
+ session *smux.Session
+}
+
+func init() {
+ proxy.RegisterDialer("smux", NewSmuxDialer)
+}
+
+// NewSmuxDialer returns a smux dialer.
+func NewSmuxDialer(s string, d proxy.Dialer) (proxy.Dialer, error) {
+ u, err := url.Parse(s)
+ if err != nil {
+ log.F("[smux] parse url err: %s", err)
+ return nil, err
+ }
+
+ c := &SmuxClient{
+ dialer: d,
+ addr: u.Host,
+ }
+
+ return c, nil
+}
+
+// Addr returns forwarder's address.
+func (s *SmuxClient) Addr() string {
+ if s.addr == "" {
+ return s.dialer.Addr()
+ }
+ return s.addr
+}
+
+// Dial connects to the address addr on the network net via the proxy.
+func (s *SmuxClient) Dial(network, addr string) (net.Conn, error) {
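+	// reuse the existing session when possible; rebuild it if opening a stream fails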
+ if s.session != nil {
+ if c, err := s.session.OpenStream(); err == nil {
+ return c, err
+ }
+ s.session.Close()
+ }
+ if err := s.initConn(); err != nil {
+ return nil, err
+ }
+ return s.session.OpenStream()
+}
+
+// DialUDP connects to the given address via the proxy.
+func (s *SmuxClient) DialUDP(network, addr string) (net.PacketConn, net.Addr, error) {
+	return nil, nil, errors.New("smux client does not support UDP yet")
+}
+
+func (s *SmuxClient) initConn() error {
+ conn, err := s.dialer.Dial("tcp", s.addr)
+ if err != nil {
+ log.F("[smux] dial to %s error: %s", s.addr, err)
+ return err
+ }
+ s.session, err = smux.Client(conn, nil)
+ return err
+}
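
The dialer keeps one session per forwarder and opens a fresh smux stream per `Dial` call. A hypothetical sketch of chaining it over another `proxy.Dialer` (the `base` dialer, helper name, and host are placeholders, not part of this patch):

```go
package smux

import (
	"net"

	"github.com/nadoo/glider/proxy"
)

// dialViaSmux is a hypothetical helper: base establishes the underlying
// TCP connection and SmuxClient multiplexes logical streams on top of it.
func dialViaSmux(base proxy.Dialer, target string) (net.Conn, error) {
	d, err := NewSmuxDialer("smux://server.example.com:8443", base)
	if err != nil {
		return nil, err
	}
	// the first Dial creates the session; later calls reuse it
	return d.Dial("tcp", target)
}
```
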
diff --git a/proxy/smux/server.go b/proxy/smux/server.go
new file mode 100644
index 0000000..3adbd9c
--- /dev/null
+++ b/proxy/smux/server.go
@@ -0,0 +1,98 @@
+package smux
+
+import (
+ "net"
+ "net/url"
+ "strings"
+
+ "github.com/nadoo/glider/log"
+ "github.com/nadoo/glider/proxy"
+
+ "github.com/nadoo/glider/proxy/protocol/smux"
+)
+
+// SmuxServer struct.
+type SmuxServer struct {
+ proxy proxy.Proxy
+ addr string
+ server proxy.Server
+}
+
+func init() {
+ proxy.RegisterServer("smux", NewSmuxServer)
+}
+
+// NewSmuxServer returns a smux transport server that sits in front of the chained server.
+func NewSmuxServer(s string, p proxy.Proxy) (proxy.Server, error) {
+ transport := strings.Split(s, ",")
+
+ u, err := url.Parse(transport[0])
+ if err != nil {
+ log.F("[smux] parse url err: %s", err)
+ return nil, err
+ }
+
+ m := &SmuxServer{
+ proxy: p,
+ addr: u.Host,
+ }
+
+ if len(transport) > 1 {
+ m.server, err = proxy.ServerFromURL(transport[1], p)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ return m, nil
+}
+
+// ListenAndServe listens on server's addr and serves connections.
+func (s *SmuxServer) ListenAndServe() {
+ l, err := net.Listen("tcp", s.addr)
+ if err != nil {
+ log.F("[smux] failed to listen on %s: %v", s.addr, err)
+ return
+ }
+ defer l.Close()
+
+ log.F("[smux] listening mux on %s", s.addr)
+
+ for {
+ c, err := l.Accept()
+ if err != nil {
+ log.F("[smux] failed to accept: %v", err)
+ continue
+ }
+
+ go s.Serve(c)
+ }
+}
+
+// Serve serves a connection.
+func (s *SmuxServer) Serve(c net.Conn) {
+	// the internal server will close the connection after serving,
+	// so we intentionally skip `defer c.Close()` here
+
+ session, err := smux.Server(c, nil)
+ if err != nil {
+ log.F("[smux] failed to create session: %v", err)
+ return
+ }
+
+ for {
+ // Accept a stream
+ stream, err := session.AcceptStream()
+ if err != nil {
+ session.Close()
+ break
+ }
+ go s.ServeStream(stream)
+ }
+}
+
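+// ServeStream serves a single multiplexed stream with the chained server.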
+func (s *SmuxServer) ServeStream(c *smux.Stream) {
+ if s.server != nil {
+ s.server.Serve(c)
+ }
+}
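
On the listen side, smux acts as a transport in front of another server, mirroring the `listen=smux://host:port,SCHEME://` chaining that `NewSmuxServer` parses above. A sketch under stated assumptions (a `proxy.Proxy` value `p` built elsewhere, and the chained scheme's package imported for its `init` registration, as glider's main package does):

```go
package smux

import "github.com/nadoo/glider/proxy"

// listenSmux is a hypothetical helper mirroring `listen=smux://:8443,http://`:
// smux terminates the multiplexed transport and the chained http server
// handles each accepted stream.
func listenSmux(p proxy.Proxy) (proxy.Server, error) {
	server, err := proxy.ServerFromURL("smux://:8443,http://", p)
	if err != nil {
		return nil, err
	}
	go server.ListenAndServe()
	return server, nil
}
```
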