From dbd2e04521270fb28a65ff2987ed38a41da6e392 Mon Sep 17 00:00:00 2001 From: nadoo <287492+nadoo@users.noreply.github.com> Date: Wed, 21 Apr 2021 00:29:17 +0800 Subject: [PATCH] proxy: added smux support --- README.md | 37 +- config.go | 8 +- feature.go | 1 + go.mod | 2 +- go.sum | 4 +- proxy/protocol/smux/LICENSE | 21 + proxy/protocol/smux/frame.go | 81 ++ proxy/protocol/smux/mux.go | 110 +++ proxy/protocol/smux/mux_test.go | 86 +++ proxy/protocol/smux/session.go | 527 +++++++++++++ proxy/protocol/smux/session_test.go | 1090 +++++++++++++++++++++++++++ proxy/protocol/smux/shaper.go | 16 + proxy/protocol/smux/shaper_test.go | 30 + proxy/protocol/smux/stream.go | 545 ++++++++++++++ proxy/smux/client.go | 76 ++ proxy/smux/server.go | 98 +++ 16 files changed, 2698 insertions(+), 34 deletions(-) create mode 100644 proxy/protocol/smux/LICENSE create mode 100644 proxy/protocol/smux/frame.go create mode 100644 proxy/protocol/smux/mux.go create mode 100644 proxy/protocol/smux/mux_test.go create mode 100644 proxy/protocol/smux/session.go create mode 100644 proxy/protocol/smux/session_test.go create mode 100644 proxy/protocol/smux/shaper.go create mode 100644 proxy/protocol/smux/shaper_test.go create mode 100644 proxy/protocol/smux/stream.go create mode 100644 proxy/smux/client.go create mode 100644 proxy/smux/server.go diff --git a/README.md b/README.md index 54b52e1..5d74798 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,7 @@ we can set up local listeners as proxy servers, and forward requests to internet |TLS |√| |√| |transport client & server |KCP | |√|√| |transport client & server |Unix |√|√|√|√|transport client & server +|Smux |√| |√| |transport client & server |Websocket |√| |√| |transport client & server |Simple-Obfs | | |√| |transport client only |Redir |√| | | |linux only @@ -96,7 +97,7 @@ glider -h click to see details ```bash -glider 0.13.0 usage: +glider 0.14.0 usage: -check string check=tcp[://HOST:PORT]: tcp port connect check check=http://HOST[:PORT][/URI][#expect=STRING_IN_RESP_LINE] @@ -152,27 +153,10 @@ glider 0.13.0 usage: forward strategy, default: rr (default "rr") -verbose verbose mode -``` - - -run: -```bash -glider -config CONFIGPATH -``` -```bash -glider -verbose -listen :8443 -forward SCHEME://HOST:PORT -``` - -#### Schemes - -
-click to see details - -```bash Available schemes: - listen: mixed ss socks5 http vless trojan trojanc redir redir6 tcp udp tls ws unix kcp - forward: reject ss socks4 socks5 http ssr ssh vless vmess trojan trojanc tcp udp tls ws unix kcp simple-obfs + listen: mixed ss socks5 http vless trojan trojanc redir redir6 tcp udp tls ws unix smux kcp + forward: reject ss socks4 socks5 http ssr ssh vless vmess trojan trojanc tcp udp tls ws unix smux kcp simple-obfs Socks5 scheme: socks://[user:pass@]host:port @@ -251,6 +235,9 @@ TLS and Websocket with a specified proxy protocol: Unix domain socket scheme: unix://path +Smux scheme: + smux://host:port + KCP scheme: kcp://CRYPT:KEY@host:port[?dataShards=NUM&parityShards=NUM&mode=MODE] @@ -298,16 +285,8 @@ Config file format(see `./glider.conf.example` as an example): KEY=VALUE KEY=VALUE # KEY equals to command line flag name: listen forward strategy... -``` -
- -#### Examples - -
-click to see details - -```bash +Examples: ./glider -config glider.conf -run glider with specified config file. diff --git a/config.go b/config.go index d84e137..b2eaaf2 100644 --- a/config.go +++ b/config.go @@ -131,8 +131,8 @@ func usage() { fmt.Fprintf(w, "\n") fmt.Fprintf(w, "Available schemes:\n") - fmt.Fprintf(w, " listen: mixed ss socks5 http vless trojan trojanc redir redir6 tcp udp tls ws unix kcp\n") - fmt.Fprintf(w, " forward: reject ss socks4 socks5 http ssr ssh vless vmess trojan trojanc tcp udp tls ws unix kcp simple-obfs\n") + fmt.Fprintf(w, " listen: mixed ss socks5 http vless trojan trojanc redir redir6 tcp udp tls ws unix smux kcp\n") + fmt.Fprintf(w, " forward: reject ss socks4 socks5 http ssr ssh vless vmess trojan trojanc tcp udp tls ws unix smux kcp simple-obfs\n") fmt.Fprintf(w, "\n") fmt.Fprintf(w, "Socks5 scheme:\n") @@ -231,6 +231,10 @@ func usage() { fmt.Fprintf(w, " unix://path\n") fmt.Fprintf(w, "\n") + fmt.Fprintf(w, "Smux scheme:\n") + fmt.Fprintf(w, " smux://host:port\n") + fmt.Fprintf(w, "\n") + fmt.Fprintf(w, "KCP scheme:\n") fmt.Fprintf(w, " kcp://CRYPT:KEY@host:port[?dataShards=NUM&parityShards=NUM&mode=MODE]\n") fmt.Fprintf(w, "\n") diff --git a/feature.go b/feature.go index 930a157..c4c5b07 100644 --- a/feature.go +++ b/feature.go @@ -10,6 +10,7 @@ import ( _ "github.com/nadoo/glider/proxy/mixed" _ "github.com/nadoo/glider/proxy/obfs" _ "github.com/nadoo/glider/proxy/reject" + _ "github.com/nadoo/glider/proxy/smux" _ "github.com/nadoo/glider/proxy/socks4" _ "github.com/nadoo/glider/proxy/socks5" _ "github.com/nadoo/glider/proxy/ss" diff --git a/go.mod b/go.mod index 313b23c..0fe7e0f 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/dgryski/go-idea v0.0.0-20170306091226-d2fb45a411fb github.com/dgryski/go-rc2 v0.0.0-20150621095337-8a9021637152 github.com/ebfe/rc2 v0.0.0-20131011165748-24b9757f5521 // indirect - github.com/insomniacslk/dhcp v0.0.0-20210315110227-c51060810aaa + github.com/insomniacslk/dhcp v0.0.0-20210420161629-6bd1ce0fd305 github.com/klauspost/cpuid/v2 v2.0.6 // indirect github.com/klauspost/reedsolomon v1.9.12 // indirect github.com/mdlayher/raw v0.0.0-20210412142147-51b895745faf // indirect diff --git a/go.sum b/go.sum index 0bb9ef8..847e797 100644 --- a/go.sum +++ b/go.sum @@ -39,8 +39,8 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714 h1:/jC7qQFrv8CrSJVmaolDVOxTfS9kc36uB6H40kdbQq8= github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714/go.mod h1:2Goc3h8EklBH5mspfHFxBnEoURQCGzQQH1ga9Myjvis= -github.com/insomniacslk/dhcp v0.0.0-20210315110227-c51060810aaa h1:ff6iUOGfajZ+T8RBuRtOjYMW7/6aF8iJG+iCp2wQyQ4= -github.com/insomniacslk/dhcp v0.0.0-20210315110227-c51060810aaa/go.mod h1:TKl4jN3Voofo4UJIicyNhWGp/nlQqQkFxmwIFTvBkKI= +github.com/insomniacslk/dhcp v0.0.0-20210420161629-6bd1ce0fd305 h1:DGmCtsdLE6r7tuFH5YHnDoKJU9WXjptF/8Q8iKc41Tk= +github.com/insomniacslk/dhcp v0.0.0-20210420161629-6bd1ce0fd305/go.mod h1:TKl4jN3Voofo4UJIicyNhWGp/nlQqQkFxmwIFTvBkKI= github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a/go.mod h1:Oz+70psSo5OFh8DBl0Zv2ACw7Esh6pPUphlvZG9x7uw= github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4/go.mod h1:WGuG/smIU4J/54PblvSbh+xvCZmpJnFgr3ds6Z55XMQ= github.com/jsimonetti/rtnetlink v0.0.0-20201009170750-9c6f07d100c1/go.mod h1:hqoO/u39cqLeBLebZ8fWdE96O7FxrAsRYhnVOdgHxok= diff --git a/proxy/protocol/smux/LICENSE b/proxy/protocol/smux/LICENSE new file mode 100644 index 0000000..eed41ac --- /dev/null +++ b/proxy/protocol/smux/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2016-2017 Daniel Fu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/proxy/protocol/smux/frame.go b/proxy/protocol/smux/frame.go new file mode 100644 index 0000000..467a058 --- /dev/null +++ b/proxy/protocol/smux/frame.go @@ -0,0 +1,81 @@ +package smux + +import ( + "encoding/binary" + "fmt" +) + +const ( // cmds + // protocol version 1: + cmdSYN byte = iota // stream open + cmdFIN // stream close, a.k.a EOF mark + cmdPSH // data push + cmdNOP // no operation + + // protocol version 2 extra commands + // notify bytes consumed by remote peer-end + cmdUPD +) + +const ( + // data size of cmdUPD, format: + // |4B data consumed(ACK)| 4B window size(WINDOW) | + szCmdUPD = 8 +) + +const ( + // initial peer window guess, a slow-start + initialPeerWindow = 262144 +) + +const ( + sizeOfVer = 1 + sizeOfCmd = 1 + sizeOfLength = 2 + sizeOfSid = 4 + headerSize = sizeOfVer + sizeOfCmd + sizeOfSid + sizeOfLength +) + +// Frame defines a packet from or to be multiplexed into a single connection +type Frame struct { + ver byte + cmd byte + sid uint32 + data []byte +} + +func newFrame(version byte, cmd byte, sid uint32) Frame { + return Frame{ver: version, cmd: cmd, sid: sid} +} + +type rawHeader [headerSize]byte + +func (h rawHeader) Version() byte { + return h[0] +} + +func (h rawHeader) Cmd() byte { + return h[1] +} + +func (h rawHeader) Length() uint16 { + return binary.LittleEndian.Uint16(h[2:]) +} + +func (h rawHeader) StreamID() uint32 { + return binary.LittleEndian.Uint32(h[4:]) +} + +func (h rawHeader) String() string { + return fmt.Sprintf("Version:%d Cmd:%d StreamID:%d Length:%d", + h.Version(), h.Cmd(), h.StreamID(), h.Length()) +} + +type updHeader [szCmdUPD]byte + +func (h updHeader) Consumed() uint32 { + return binary.LittleEndian.Uint32(h[:]) +} +func (h updHeader) Window() uint32 { + return binary.LittleEndian.Uint32(h[4:]) +} diff --git a/proxy/protocol/smux/mux.go b/proxy/protocol/smux/mux.go new file mode 100644 index 0000000..c0b8ab8 --- /dev/null +++ b/proxy/protocol/smux/mux.go @@ -0,0 +1,110 @@ +// Package smux is a multiplexing library for Golang. +// +// It relies on an underlying connection to provide reliability and ordering, such as TCP or KCP, +// and provides stream-oriented multiplexing over a single channel. +package smux + +import ( + "errors" + "fmt" + "io" + "math" + "time" +) + +// Config is used to tune the Smux session +type Config struct { + // SMUX Protocol version, support 1,2 + Version int + + // Disabled keepalive + KeepAliveDisabled bool + + // KeepAliveInterval is how often to send a NOP command to the remote + KeepAliveInterval time.Duration + + // KeepAliveTimeout is how long the session + // will be closed if no data has arrived + KeepAliveTimeout time.Duration + + // MaxFrameSize is used to control the maximum + // frame size to sent to the remote + MaxFrameSize int + + // MaxReceiveBuffer is used to control the maximum + // number of data in the buffer pool + MaxReceiveBuffer int + + // MaxStreamBuffer is used to control the maximum + // number of data per stream + MaxStreamBuffer int +} + +// DefaultConfig is used to return a default configuration +func DefaultConfig() *Config { + return &Config{ + Version: 1, + KeepAliveInterval: 10 * time.Second, + KeepAliveTimeout: 30 * time.Second, + MaxFrameSize: 32768, + MaxReceiveBuffer: 4194304, + MaxStreamBuffer: 65536, + } +} + +// VerifyConfig is used to verify the sanity of configuration +func VerifyConfig(config *Config) error { + if !(config.Version == 1 || config.Version == 2) { + return errors.New("unsupported protocol version") + } + if !config.KeepAliveDisabled { + if config.KeepAliveInterval == 0 { + return errors.New("keep-alive interval must be positive") + } + if config.KeepAliveTimeout < config.KeepAliveInterval { + return fmt.Errorf("keep-alive timeout must be larger than keep-alive interval") + } + } + if config.MaxFrameSize <= 0 { + return errors.New("max frame size must be positive") + } + if config.MaxFrameSize > 65535 { + return errors.New("max frame size must not be larger than 65535") + } + if config.MaxReceiveBuffer <= 0 { + return errors.New("max receive buffer must be positive") + } + if config.MaxStreamBuffer <= 0 { + return errors.New("max stream buffer must be positive") + } + if config.MaxStreamBuffer > config.MaxReceiveBuffer { + return errors.New("max stream buffer must not be larger than max receive buffer") + } + if config.MaxStreamBuffer > math.MaxInt32 { + return errors.New("max stream buffer cannot be larger than 2147483647") + } + return nil +} + +// Server is used to initialize a new server-side connection. +func Server(conn io.ReadWriteCloser, config *Config) (*Session, error) { + if config == nil { + config = DefaultConfig() + } + if err := VerifyConfig(config); err != nil { + return nil, err + } + return newSession(config, conn, false), nil +} + +// Client is used to initialize a new client-side connection. +func Client(conn io.ReadWriteCloser, config *Config) (*Session, error) { + if config == nil { + config = DefaultConfig() + } + + if err := VerifyConfig(config); err != nil { + return nil, err + } + return newSession(config, conn, true), nil +} diff --git a/proxy/protocol/smux/mux_test.go b/proxy/protocol/smux/mux_test.go new file mode 100644 index 0000000..dc9c1c1 --- /dev/null +++ b/proxy/protocol/smux/mux_test.go @@ -0,0 +1,86 @@ +package smux + +import ( + "bytes" + "testing" +) + +type buffer struct { + bytes.Buffer +} + +func (b *buffer) Close() error { + b.Buffer.Reset() + return nil +} + +func TestConfig(t *testing.T) { + VerifyConfig(DefaultConfig()) + + config := DefaultConfig() + config.KeepAliveInterval = 0 + err := VerifyConfig(config) + t.Log(err) + if err == nil { + t.Fatal(err) + } + + config = DefaultConfig() + config.KeepAliveInterval = 10 + config.KeepAliveTimeout = 5 + err = VerifyConfig(config) + t.Log(err) + if err == nil { + t.Fatal(err) + } + + config = DefaultConfig() + config.MaxFrameSize = 0 + err = VerifyConfig(config) + t.Log(err) + if err == nil { + t.Fatal(err) + } + + config = DefaultConfig() + config.MaxFrameSize = 65536 + err = VerifyConfig(config) + t.Log(err) + if err == nil { + t.Fatal(err) + } + + config = DefaultConfig() + config.MaxReceiveBuffer = 0 + err = VerifyConfig(config) + t.Log(err) + if err == nil { + t.Fatal(err) + } + + config = DefaultConfig() + config.MaxStreamBuffer = 0 + err = VerifyConfig(config) + t.Log(err) + if err == nil { + t.Fatal(err) + } + + config = DefaultConfig() + config.MaxStreamBuffer = 100 + config.MaxReceiveBuffer = 99 + err = VerifyConfig(config) + t.Log(err) + if err == nil { + t.Fatal(err) + } + + var bts buffer + if _, err := Server(&bts, config); err == nil { + t.Fatal("server started with wrong config") + } + + if _, err := Client(&bts, config); err == nil { + t.Fatal("client started with wrong config") + } +} diff --git a/proxy/protocol/smux/session.go b/proxy/protocol/smux/session.go new file mode 100644 index 0000000..f8cfee7 --- /dev/null +++ b/proxy/protocol/smux/session.go @@ -0,0 +1,527 @@ +package smux + +import ( + "container/heap" + "encoding/binary" + "errors" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/nadoo/glider/pool" +) + +const ( + defaultAcceptBacklog = 1024 +) + +var ( + ErrInvalidProtocol = errors.New("invalid protocol") + ErrConsumed = errors.New("peer consumed more than sent") + ErrGoAway = errors.New("stream id overflows, should start a new connection") + ErrTimeout = errors.New("timeout") + ErrWouldBlock = errors.New("operation would block on IO") +) + +type writeRequest struct { + prio uint64 + frame Frame + result chan writeResult +} + +type writeResult struct { + n int + err error +} + +type buffersWriter interface { + WriteBuffers(v [][]byte) (n int, err error) +} + +// Session defines a multiplexed connection for streams +type Session struct { + conn io.ReadWriteCloser + + config *Config + nextStreamID uint32 // next stream identifier + nextStreamIDLock sync.Mutex + + bucket int32 // token bucket + bucketNotify chan struct{} // used for waiting for tokens + + streams map[uint32]*Stream // all streams in this session + streamLock sync.Mutex // locks streams + + die chan struct{} // flag session has died + dieOnce sync.Once + + // socket error handling + socketReadError atomic.Value + socketWriteError atomic.Value + chSocketReadError chan struct{} + chSocketWriteError chan struct{} + socketReadErrorOnce sync.Once + socketWriteErrorOnce sync.Once + + // smux protocol errors + protoError atomic.Value + chProtoError chan struct{} + protoErrorOnce sync.Once + + chAccepts chan *Stream + + dataReady int32 // flag data has arrived + + goAway int32 // flag id exhausted + + deadline atomic.Value + + shaper chan writeRequest // a shaper for writing + writes chan writeRequest +} + +func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { + s := new(Session) + s.die = make(chan struct{}) + s.conn = conn + s.config = config + s.streams = make(map[uint32]*Stream) + s.chAccepts = make(chan *Stream, defaultAcceptBacklog) + s.bucket = int32(config.MaxReceiveBuffer) + s.bucketNotify = make(chan struct{}, 1) + s.shaper = make(chan writeRequest) + s.writes = make(chan writeRequest) + s.chSocketReadError = make(chan struct{}) + s.chSocketWriteError = make(chan struct{}) + s.chProtoError = make(chan struct{}) + + if client { + s.nextStreamID = 1 + } else { + s.nextStreamID = 0 + } + + go s.shaperLoop() + go s.recvLoop() + go s.sendLoop() + if !config.KeepAliveDisabled { + go s.keepalive() + } + return s +} + +// OpenStream is used to create a new stream +func (s *Session) OpenStream() (*Stream, error) { + if s.IsClosed() { + return nil, io.ErrClosedPipe + } + + // generate stream id + s.nextStreamIDLock.Lock() + if s.goAway > 0 { + s.nextStreamIDLock.Unlock() + return nil, ErrGoAway + } + + s.nextStreamID += 2 + sid := s.nextStreamID + if sid == sid%2 { // stream-id overflows + s.goAway = 1 + s.nextStreamIDLock.Unlock() + return nil, ErrGoAway + } + s.nextStreamIDLock.Unlock() + + stream := newStream(sid, s.config.MaxFrameSize, s) + + if _, err := s.writeFrame(newFrame(byte(s.config.Version), cmdSYN, sid)); err != nil { + return nil, err + } + + s.streamLock.Lock() + defer s.streamLock.Unlock() + select { + case <-s.chSocketReadError: + return nil, s.socketReadError.Load().(error) + case <-s.chSocketWriteError: + return nil, s.socketWriteError.Load().(error) + case <-s.die: + return nil, io.ErrClosedPipe + default: + s.streams[sid] = stream + return stream, nil + } +} + +// Open returns a generic ReadWriteCloser +func (s *Session) Open() (io.ReadWriteCloser, error) { + return s.OpenStream() +} + +// AcceptStream is used to block until the next available stream +// is ready to be accepted. +func (s *Session) AcceptStream() (*Stream, error) { + var deadline <-chan time.Time + if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() { + timer := time.NewTimer(time.Until(d)) + defer timer.Stop() + deadline = timer.C + } + + select { + case stream := <-s.chAccepts: + return stream, nil + case <-deadline: + return nil, ErrTimeout + case <-s.chSocketReadError: + return nil, s.socketReadError.Load().(error) + case <-s.chProtoError: + return nil, s.protoError.Load().(error) + case <-s.die: + return nil, io.ErrClosedPipe + } +} + +// Accept Returns a generic ReadWriteCloser instead of smux.Stream +func (s *Session) Accept() (io.ReadWriteCloser, error) { + return s.AcceptStream() +} + +// Close is used to close the session and all streams. +func (s *Session) Close() error { + var once bool + s.dieOnce.Do(func() { + close(s.die) + once = true + }) + + if once { + s.streamLock.Lock() + for k := range s.streams { + s.streams[k].sessionClose() + } + s.streamLock.Unlock() + return s.conn.Close() + } else { + return io.ErrClosedPipe + } +} + +// notifyBucket notifies recvLoop that bucket is available +func (s *Session) notifyBucket() { + select { + case s.bucketNotify <- struct{}{}: + default: + } +} + +func (s *Session) notifyReadError(err error) { + s.socketReadErrorOnce.Do(func() { + s.socketReadError.Store(err) + close(s.chSocketReadError) + }) +} + +func (s *Session) notifyWriteError(err error) { + s.socketWriteErrorOnce.Do(func() { + s.socketWriteError.Store(err) + close(s.chSocketWriteError) + }) +} + +func (s *Session) notifyProtoError(err error) { + s.protoErrorOnce.Do(func() { + s.protoError.Store(err) + close(s.chProtoError) + }) +} + +// IsClosed does a safe check to see if we have shutdown +func (s *Session) IsClosed() bool { + select { + case <-s.die: + return true + default: + return false + } +} + +// NumStreams returns the number of currently open streams +func (s *Session) NumStreams() int { + if s.IsClosed() { + return 0 + } + s.streamLock.Lock() + defer s.streamLock.Unlock() + return len(s.streams) +} + +// SetDeadline sets a deadline used by Accept* calls. +// A zero time value disables the deadline. +func (s *Session) SetDeadline(t time.Time) error { + s.deadline.Store(t) + return nil +} + +// LocalAddr satisfies net.Conn interface +func (s *Session) LocalAddr() net.Addr { + if ts, ok := s.conn.(interface { + LocalAddr() net.Addr + }); ok { + return ts.LocalAddr() + } + return nil +} + +// RemoteAddr satisfies net.Conn interface +func (s *Session) RemoteAddr() net.Addr { + if ts, ok := s.conn.(interface { + RemoteAddr() net.Addr + }); ok { + return ts.RemoteAddr() + } + return nil +} + +// notify the session that a stream has closed +func (s *Session) streamClosed(sid uint32) { + s.streamLock.Lock() + if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket + if atomic.AddInt32(&s.bucket, int32(n)) > 0 { + s.notifyBucket() + } + } + delete(s.streams, sid) + s.streamLock.Unlock() +} + +// returnTokens is called by stream to return token after read +func (s *Session) returnTokens(n int) { + if atomic.AddInt32(&s.bucket, int32(n)) > 0 { + s.notifyBucket() + } +} + +// recvLoop keeps on reading from underlying connection if tokens are available +func (s *Session) recvLoop() { + var hdr rawHeader + var updHdr updHeader + + for { + for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() { + select { + case <-s.bucketNotify: + case <-s.die: + return + } + } + + // read header first + if _, err := io.ReadFull(s.conn, hdr[:]); err == nil { + atomic.StoreInt32(&s.dataReady, 1) + if hdr.Version() != byte(s.config.Version) { + s.notifyProtoError(ErrInvalidProtocol) + return + } + sid := hdr.StreamID() + switch hdr.Cmd() { + case cmdNOP: + case cmdSYN: + s.streamLock.Lock() + if _, ok := s.streams[sid]; !ok { + stream := newStream(sid, s.config.MaxFrameSize, s) + s.streams[sid] = stream + select { + case s.chAccepts <- stream: + case <-s.die: + } + } + s.streamLock.Unlock() + case cmdFIN: + s.streamLock.Lock() + if stream, ok := s.streams[sid]; ok { + stream.fin() + stream.notifyReadEvent() + } + s.streamLock.Unlock() + case cmdPSH: + if hdr.Length() > 0 { + newbuf := pool.GetBuffer(int(hdr.Length())) + if written, err := io.ReadFull(s.conn, newbuf); err == nil { + s.streamLock.Lock() + if stream, ok := s.streams[sid]; ok { + stream.pushBytes(newbuf) + atomic.AddInt32(&s.bucket, -int32(written)) + stream.notifyReadEvent() + } + s.streamLock.Unlock() + } else { + s.notifyReadError(err) + return + } + } + case cmdUPD: + if _, err := io.ReadFull(s.conn, updHdr[:]); err == nil { + s.streamLock.Lock() + if stream, ok := s.streams[sid]; ok { + stream.update(updHdr.Consumed(), updHdr.Window()) + } + s.streamLock.Unlock() + } else { + s.notifyReadError(err) + return + } + default: + s.notifyProtoError(ErrInvalidProtocol) + return + } + } else { + s.notifyReadError(err) + return + } + } +} + +func (s *Session) keepalive() { + tickerPing := time.NewTicker(s.config.KeepAliveInterval) + tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout) + defer tickerPing.Stop() + defer tickerTimeout.Stop() + for { + select { + case <-tickerPing.C: + s.writeFrameInternal(newFrame(byte(s.config.Version), cmdNOP, 0), tickerPing.C, 0) + s.notifyBucket() // force a signal to the recvLoop + case <-tickerTimeout.C: + if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) { + // recvLoop may block while bucket is 0, in this case, + // session should not be closed. + if atomic.LoadInt32(&s.bucket) > 0 { + s.Close() + return + } + } + case <-s.die: + return + } + } +} + +// shaper shapes the sending sequence among streams +func (s *Session) shaperLoop() { + var reqs shaperHeap + var next writeRequest + var chWrite chan writeRequest + + for { + if len(reqs) > 0 { + chWrite = s.writes + next = heap.Pop(&reqs).(writeRequest) + } else { + chWrite = nil + } + + select { + case <-s.die: + return + case r := <-s.shaper: + if chWrite != nil { // next is valid, reshape + heap.Push(&reqs, next) + } + heap.Push(&reqs, r) + case chWrite <- next: + } + } +} + +func (s *Session) sendLoop() { + var buf []byte + var n int + var err error + var vec [][]byte // vector for writeBuffers + + bw, ok := s.conn.(buffersWriter) + if ok { + buf = make([]byte, headerSize) + vec = make([][]byte, 2) + } else { + buf = make([]byte, (1<<16)+headerSize) + } + + for { + select { + case <-s.die: + return + case request := <-s.writes: + buf[0] = request.frame.ver + buf[1] = request.frame.cmd + binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data))) + binary.LittleEndian.PutUint32(buf[4:], request.frame.sid) + + if len(vec) > 0 { + vec[0] = buf[:headerSize] + vec[1] = request.frame.data + n, err = bw.WriteBuffers(vec) + } else { + copy(buf[headerSize:], request.frame.data) + n, err = s.conn.Write(buf[:headerSize+len(request.frame.data)]) + } + + n -= headerSize + if n < 0 { + n = 0 + } + + result := writeResult{ + n: n, + err: err, + } + + request.result <- result + close(request.result) + + // store conn error + if err != nil { + s.notifyWriteError(err) + return + } + } + } +} + +// writeFrame writes the frame to the underlying connection +// and returns the number of bytes written if successful +func (s *Session) writeFrame(f Frame) (n int, err error) { + return s.writeFrameInternal(f, nil, 0) +} + +// internal writeFrame version to support deadline used in keepalive +func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, prio uint64) (int, error) { + req := writeRequest{ + prio: prio, + frame: f, + result: make(chan writeResult, 1), + } + select { + case s.shaper <- req: + case <-s.die: + return 0, io.ErrClosedPipe + case <-s.chSocketWriteError: + return 0, s.socketWriteError.Load().(error) + case <-deadline: + return 0, ErrTimeout + } + + select { + case result := <-req.result: + return result.n, result.err + case <-s.die: + return 0, io.ErrClosedPipe + case <-s.chSocketWriteError: + return 0, s.socketWriteError.Load().(error) + case <-deadline: + return 0, ErrTimeout + } +} diff --git a/proxy/protocol/smux/session_test.go b/proxy/protocol/smux/session_test.go new file mode 100644 index 0000000..3479570 --- /dev/null +++ b/proxy/protocol/smux/session_test.go @@ -0,0 +1,1090 @@ +package smux + +import ( + "bytes" + crand "crypto/rand" + "encoding/binary" + "fmt" + "io" + "log" + "math/rand" + "net" + "net/http" + _ "net/http/pprof" + "strings" + "sync" + "testing" + "time" +) + +func init() { + go func() { + log.Println(http.ListenAndServe("0.0.0.0:6060", nil)) + }() +} + +// setupServer starts new server listening on a random localhost port and +// returns address of the server, function to stop the server, new client +// connection to this server or an error. +func setupServer(tb testing.TB) (addr string, stopfunc func(), client net.Conn, err error) { + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + return "", nil, nil, err + } + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + go handleConnection(conn) + }() + addr = ln.Addr().String() + conn, err := net.Dial("tcp", addr) + if err != nil { + ln.Close() + return "", nil, nil, err + } + return ln.Addr().String(), func() { ln.Close() }, conn, nil +} + +func handleConnection(conn net.Conn) { + session, _ := Server(conn, nil) + for { + if stream, err := session.AcceptStream(); err == nil { + go func(s io.ReadWriteCloser) { + buf := make([]byte, 65536) + for { + n, err := s.Read(buf) + if err != nil { + return + } + s.Write(buf[:n]) + } + }(stream) + } else { + return + } + } +} + +// setupServer starts new server listening on a random localhost port and +// returns address of the server, function to stop the server, new client +// connection to this server or an error. +func setupServerV2(tb testing.TB) (addr string, stopfunc func(), client net.Conn, err error) { + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + return "", nil, nil, err + } + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + go handleConnectionV2(conn) + }() + addr = ln.Addr().String() + conn, err := net.Dial("tcp", addr) + if err != nil { + ln.Close() + return "", nil, nil, err + } + return ln.Addr().String(), func() { ln.Close() }, conn, nil +} + +func handleConnectionV2(conn net.Conn) { + config := DefaultConfig() + config.Version = 2 + session, _ := Server(conn, config) + for { + if stream, err := session.AcceptStream(); err == nil { + go func(s io.ReadWriteCloser) { + buf := make([]byte, 65536) + for { + n, err := s.Read(buf) + if err != nil { + return + } + s.Write(buf[:n]) + } + }(stream) + } else { + return + } + } +} + +func TestEcho(t *testing.T) { + _, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + stream, _ := session.OpenStream() + const N = 100 + buf := make([]byte, 10) + var sent string + var received string + for i := 0; i < N; i++ { + msg := fmt.Sprintf("hello%v", i) + stream.Write([]byte(msg)) + sent += msg + if n, err := stream.Read(buf); err != nil { + t.Fatal(err) + } else { + received += string(buf[:n]) + } + } + if sent != received { + t.Fatal("data mimatch") + } + session.Close() +} + +func TestWriteTo(t *testing.T) { + const N = 1 << 20 + // server + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + session, _ := Server(conn, nil) + for { + if stream, err := session.AcceptStream(); err == nil { + go func(s io.ReadWriteCloser) { + numBytes := 0 + buf := make([]byte, 65536) + for { + n, err := s.Read(buf) + if err != nil { + return + } + s.Write(buf[:n]) + numBytes += n + + if numBytes == N { + s.Close() + return + } + } + }(stream) + } else { + return + } + } + }() + + addr := ln.Addr().String() + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // client + session, _ := Client(conn, nil) + stream, _ := session.OpenStream() + sndbuf := make([]byte, N) + for i := range sndbuf { + sndbuf[i] = byte(rand.Int()) + } + + go stream.Write(sndbuf) + + var rcvbuf bytes.Buffer + nw, ew := stream.WriteTo(&rcvbuf) + if ew != io.EOF { + t.Fatal(ew) + } + + if nw != N { + t.Fatal("WriteTo nw mismatch", nw) + } + + if bytes.Compare(sndbuf, rcvbuf.Bytes()) != 0 { + t.Fatal("mismatched echo bytes") + } +} + +func TestWriteToV2(t *testing.T) { + config := DefaultConfig() + config.Version = 2 + const N = 1 << 20 + // server + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + session, _ := Server(conn, config) + for { + if stream, err := session.AcceptStream(); err == nil { + go func(s io.ReadWriteCloser) { + numBytes := 0 + buf := make([]byte, 65536) + for { + n, err := s.Read(buf) + if err != nil { + return + } + s.Write(buf[:n]) + numBytes += n + + if numBytes == N { + s.Close() + return + } + } + }(stream) + } else { + return + } + } + }() + + addr := ln.Addr().String() + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // client + session, _ := Client(conn, config) + stream, _ := session.OpenStream() + sndbuf := make([]byte, N) + for i := range sndbuf { + sndbuf[i] = byte(rand.Int()) + } + + go stream.Write(sndbuf) + + var rcvbuf bytes.Buffer + nw, ew := stream.WriteTo(&rcvbuf) + if ew != io.EOF { + t.Fatal(ew) + } + + if nw != N { + t.Fatal("WriteTo nw mismatch", nw) + } + + if bytes.Compare(sndbuf, rcvbuf.Bytes()) != 0 { + t.Fatal("mismatched echo bytes") + } +} + +func TestGetDieCh(t *testing.T) { + cs, ss, err := getSmuxStreamPair() + if err != nil { + t.Fatal(err) + } + defer ss.Close() + dieCh := ss.GetDieCh() + go func() { + select { + case <-dieCh: + case <-time.Tick(time.Second): + t.Fatal("wait die chan timeout") + } + }() + cs.Close() +} + +func TestSpeed(t *testing.T) { + _, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + stream, _ := session.OpenStream() + t.Log(stream.LocalAddr(), stream.RemoteAddr()) + + start := time.Now() + var wg sync.WaitGroup + wg.Add(1) + go func() { + buf := make([]byte, 1024*1024) + nrecv := 0 + for { + n, err := stream.Read(buf) + if err != nil { + t.Error(err) + break + } else { + nrecv += n + if nrecv == 4096*4096 { + break + } + } + } + stream.Close() + t.Log("time for 16MB rtt", time.Since(start)) + wg.Done() + }() + msg := make([]byte, 8192) + for i := 0; i < 2048; i++ { + stream.Write(msg) + } + wg.Wait() + session.Close() +} + +func TestParallel(t *testing.T) { + _, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + + par := 1000 + messages := 100 + var wg sync.WaitGroup + wg.Add(par) + for i := 0; i < par; i++ { + stream, _ := session.OpenStream() + go func(s *Stream) { + buf := make([]byte, 20) + for j := 0; j < messages; j++ { + msg := fmt.Sprintf("hello%v", j) + s.Write([]byte(msg)) + if _, err := s.Read(buf); err != nil { + break + } + } + s.Close() + wg.Done() + }(stream) + } + t.Log("created", session.NumStreams(), "streams") + wg.Wait() + session.Close() +} + +func TestParallelV2(t *testing.T) { + config := DefaultConfig() + config.Version = 2 + _, stop, cli, err := setupServerV2(t) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, config) + + par := 1000 + messages := 100 + var wg sync.WaitGroup + wg.Add(par) + for i := 0; i < par; i++ { + stream, _ := session.OpenStream() + go func(s *Stream) { + buf := make([]byte, 20) + for j := 0; j < messages; j++ { + msg := fmt.Sprintf("hello%v", j) + s.Write([]byte(msg)) + if _, err := s.Read(buf); err != nil { + break + } + } + s.Close() + wg.Done() + }(stream) + } + t.Log("created", session.NumStreams(), "streams") + wg.Wait() + session.Close() +} + +func TestCloseThenOpen(t *testing.T) { + _, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + session.Close() + if _, err := session.OpenStream(); err == nil { + t.Fatal("opened after close") + } +} + +func TestSessionDoubleClose(t *testing.T) { + _, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + session.Close() + if err := session.Close(); err == nil { + t.Fatal("session double close doesn't return error") + } +} + +func TestStreamDoubleClose(t *testing.T) { + _, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + stream, _ := session.OpenStream() + stream.Close() + if err := stream.Close(); err == nil { + t.Fatal("stream double close doesn't return error") + } + session.Close() +} + +func TestConcurrentClose(t *testing.T) { + _, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + numStreams := 100 + streams := make([]*Stream, 0, numStreams) + var wg sync.WaitGroup + wg.Add(numStreams) + for i := 0; i < 100; i++ { + stream, _ := session.OpenStream() + streams = append(streams, stream) + } + for _, s := range streams { + stream := s + go func() { + stream.Close() + wg.Done() + }() + } + session.Close() + wg.Wait() +} + +func TestTinyReadBuffer(t *testing.T) { + _, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + stream, _ := session.OpenStream() + const N = 100 + tinybuf := make([]byte, 6) + var sent string + var received string + for i := 0; i < N; i++ { + msg := fmt.Sprintf("hello%v", i) + sent += msg + nsent, err := stream.Write([]byte(msg)) + if err != nil { + t.Fatal("cannot write") + } + nrecv := 0 + for nrecv < nsent { + if n, err := stream.Read(tinybuf); err == nil { + nrecv += n + received += string(tinybuf[:n]) + } else { + t.Fatal("cannot read with tiny buffer") + } + } + } + + if sent != received { + t.Fatal("data mimatch") + } + session.Close() +} + +func TestIsClose(t *testing.T) { + _, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + session.Close() + if !session.IsClosed() { + t.Fatal("still open after close") + } +} + +func TestKeepAliveTimeout(t *testing.T) { + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + go func() { + ln.Accept() + }() + + cli, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer cli.Close() + + config := DefaultConfig() + config.KeepAliveInterval = time.Second + config.KeepAliveTimeout = 2 * time.Second + session, _ := Client(cli, config) + time.Sleep(3 * time.Second) + if !session.IsClosed() { + t.Fatal("keepalive-timeout failed") + } +} + +type blockWriteConn struct { + net.Conn +} + +func (c *blockWriteConn) Write(b []byte) (n int, err error) { + forever := time.Hour * 24 + time.Sleep(forever) + return c.Conn.Write(b) +} + +func TestKeepAliveBlockWriteTimeout(t *testing.T) { + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + go func() { + ln.Accept() + }() + + cli, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer cli.Close() + //when writeFrame block, keepalive in old version never timeout + blockWriteCli := &blockWriteConn{cli} + + config := DefaultConfig() + config.KeepAliveInterval = time.Second + config.KeepAliveTimeout = 2 * time.Second + session, _ := Client(blockWriteCli, config) + time.Sleep(3 * time.Second) + if !session.IsClosed() { + t.Fatal("keepalive-timeout failed") + } +} + +func TestServerEcho(t *testing.T) { + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + go func() { + err := func() error { + conn, err := ln.Accept() + if err != nil { + return err + } + defer conn.Close() + session, err := Server(conn, nil) + if err != nil { + return err + } + defer session.Close() + buf := make([]byte, 10) + stream, err := session.OpenStream() + if err != nil { + return err + } + defer stream.Close() + for i := 0; i < 100; i++ { + msg := fmt.Sprintf("hello%v", i) + stream.Write([]byte(msg)) + n, err := stream.Read(buf) + if err != nil { + return err + } + if got := string(buf[:n]); got != msg { + return fmt.Errorf("got: %q, want: %q", got, msg) + } + } + return nil + }() + if err != nil { + t.Error(err) + } + }() + + cli, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer cli.Close() + if session, err := Client(cli, nil); err == nil { + if stream, err := session.AcceptStream(); err == nil { + buf := make([]byte, 65536) + for { + n, err := stream.Read(buf) + if err != nil { + break + } + stream.Write(buf[:n]) + } + } else { + t.Fatal(err) + } + } else { + t.Fatal(err) + } +} + +func TestSendWithoutRecv(t *testing.T) { + _, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + stream, _ := session.OpenStream() + const N = 100 + for i := 0; i < N; i++ { + msg := fmt.Sprintf("hello%v", i) + stream.Write([]byte(msg)) + } + buf := make([]byte, 1) + if _, err := stream.Read(buf); err != nil { + t.Fatal(err) + } + stream.Close() +} + +func TestWriteAfterClose(t *testing.T) { + _, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + stream, _ := session.OpenStream() + stream.Close() + if _, err := stream.Write([]byte("write after close")); err == nil { + t.Fatal("write after close failed") + } +} + +func TestReadStreamAfterSessionClose(t *testing.T) { + _, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + stream, _ := session.OpenStream() + session.Close() + buf := make([]byte, 10) + if _, err := stream.Read(buf); err != nil { + t.Log(err) + } else { + t.Fatal("read stream after session close succeeded") + } +} + +func TestWriteStreamAfterConnectionClose(t *testing.T) { + _, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + stream, _ := session.OpenStream() + session.conn.Close() + if _, err := stream.Write([]byte("write after connection close")); err == nil { + t.Fatal("write after connection close failed") + } +} + +func TestNumStreamAfterClose(t *testing.T) { + _, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + if _, err := session.OpenStream(); err == nil { + if session.NumStreams() != 1 { + t.Fatal("wrong number of streams after opened") + } + session.Close() + if session.NumStreams() != 0 { + t.Fatal("wrong number of streams after session closed") + } + } else { + t.Fatal(err) + } + cli.Close() +} + +func TestRandomFrame(t *testing.T) { + addr, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + // pure random + session, _ := Client(cli, nil) + for i := 0; i < 100; i++ { + rnd := make([]byte, rand.Uint32()%1024) + io.ReadFull(crand.Reader, rnd) + session.conn.Write(rnd) + } + cli.Close() + + // double syn + cli, err = net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + session, _ = Client(cli, nil) + for i := 0; i < 100; i++ { + f := newFrame(1, cmdSYN, 1000) + session.writeFrame(f) + } + cli.Close() + + // random cmds + cli, err = net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + allcmds := []byte{cmdSYN, cmdFIN, cmdPSH, cmdNOP} + session, _ = Client(cli, nil) + for i := 0; i < 100; i++ { + f := newFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32()) + session.writeFrame(f) + } + cli.Close() + + // random cmds & sids + cli, err = net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + session, _ = Client(cli, nil) + for i := 0; i < 100; i++ { + f := newFrame(1, byte(rand.Uint32()), rand.Uint32()) + session.writeFrame(f) + } + cli.Close() + + // random version + cli, err = net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + session, _ = Client(cli, nil) + for i := 0; i < 100; i++ { + f := newFrame(1, byte(rand.Uint32()), rand.Uint32()) + f.ver = byte(rand.Uint32()) + session.writeFrame(f) + } + cli.Close() + + // incorrect size + cli, err = net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + session, _ = Client(cli, nil) + + f := newFrame(1, byte(rand.Uint32()), rand.Uint32()) + rnd := make([]byte, rand.Uint32()%1024) + io.ReadFull(crand.Reader, rnd) + f.data = rnd + + buf := make([]byte, headerSize+len(f.data)) + buf[0] = f.ver + buf[1] = f.cmd + binary.LittleEndian.PutUint16(buf[2:], uint16(len(rnd)+1)) /// incorrect size + binary.LittleEndian.PutUint32(buf[4:], f.sid) + copy(buf[headerSize:], f.data) + + session.conn.Write(buf) + cli.Close() + + // writeFrame after die + cli, err = net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + session, _ = Client(cli, nil) + //close first + session.Close() + for i := 0; i < 100; i++ { + f := newFrame(1, byte(rand.Uint32()), rand.Uint32()) + session.writeFrame(f) + } +} + +func TestWriteFrameInternal(t *testing.T) { + addr, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + // pure random + session, _ := Client(cli, nil) + for i := 0; i < 100; i++ { + rnd := make([]byte, rand.Uint32()%1024) + io.ReadFull(crand.Reader, rnd) + session.conn.Write(rnd) + } + cli.Close() + + // writeFrame after die + cli, err = net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + session, _ = Client(cli, nil) + //close first + session.Close() + for i := 0; i < 100; i++ { + f := newFrame(1, byte(rand.Uint32()), rand.Uint32()) + session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout), 0) + } + + // random cmds + cli, err = net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + allcmds := []byte{cmdSYN, cmdFIN, cmdPSH, cmdNOP} + session, _ = Client(cli, nil) + for i := 0; i < 100; i++ { + f := newFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32()) + session.writeFrameInternal(f, time.After(session.config.KeepAliveTimeout), 0) + } + //deadline occur + { + c := make(chan time.Time) + close(c) + f := newFrame(1, allcmds[rand.Int()%len(allcmds)], rand.Uint32()) + _, err := session.writeFrameInternal(f, c, 0) + if !strings.Contains(err.Error(), "timeout") { + t.Fatal("write frame with deadline failed", err) + } + } + cli.Close() + + { + cli, err = net.Dial("tcp", addr) + if err != nil { + t.Fatal(err) + } + config := DefaultConfig() + config.KeepAliveInterval = time.Second + config.KeepAliveTimeout = 2 * time.Second + session, _ = Client(&blockWriteConn{cli}, config) + f := newFrame(1, byte(rand.Uint32()), rand.Uint32()) + c := make(chan time.Time) + go func() { + //die first, deadline second, better for coverage + time.Sleep(time.Second) + session.Close() + time.Sleep(time.Second) + close(c) + }() + _, err = session.writeFrameInternal(f, c, 0) + if !strings.Contains(err.Error(), "closed pipe") { + t.Fatal("write frame with to closed conn failed", err) + } + } +} + +func TestReadDeadline(t *testing.T) { + _, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + stream, _ := session.OpenStream() + const N = 100 + buf := make([]byte, 10) + var readErr error + for i := 0; i < N; i++ { + stream.SetReadDeadline(time.Now().Add(-1 * time.Minute)) + if _, readErr = stream.Read(buf); readErr != nil { + break + } + } + if readErr != nil { + if !strings.Contains(readErr.Error(), "timeout") { + t.Fatalf("Wrong error: %v", readErr) + } + } else { + t.Fatal("No error when reading with past deadline") + } + session.Close() +} + +func TestWriteDeadline(t *testing.T) { + _, stop, cli, err := setupServer(t) + if err != nil { + t.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + stream, _ := session.OpenStream() + buf := make([]byte, 10) + var writeErr error + for { + stream.SetWriteDeadline(time.Now().Add(-1 * time.Minute)) + if _, writeErr = stream.Write(buf); writeErr != nil { + if !strings.Contains(writeErr.Error(), "timeout") { + t.Fatalf("Wrong error: %v", writeErr) + } + break + } + } + session.Close() +} + +func BenchmarkAcceptClose(b *testing.B) { + _, stop, cli, err := setupServer(b) + if err != nil { + b.Fatal(err) + } + defer stop() + session, _ := Client(cli, nil) + for i := 0; i < b.N; i++ { + if stream, err := session.OpenStream(); err == nil { + stream.Close() + } else { + b.Fatal(err) + } + } +} +func BenchmarkConnSmux(b *testing.B) { + cs, ss, err := getSmuxStreamPair() + if err != nil { + b.Fatal(err) + } + defer cs.Close() + defer ss.Close() + bench(b, cs, ss) +} + +func BenchmarkConnTCP(b *testing.B) { + cs, ss, err := getTCPConnectionPair() + if err != nil { + b.Fatal(err) + } + defer cs.Close() + defer ss.Close() + bench(b, cs, ss) +} + +func getSmuxStreamPair() (*Stream, *Stream, error) { + c1, c2, err := getTCPConnectionPair() + if err != nil { + return nil, nil, err + } + + s, err := Server(c2, nil) + if err != nil { + return nil, nil, err + } + c, err := Client(c1, nil) + if err != nil { + return nil, nil, err + } + var ss *Stream + done := make(chan error) + go func() { + var rerr error + ss, rerr = s.AcceptStream() + done <- rerr + close(done) + }() + cs, err := c.OpenStream() + if err != nil { + return nil, nil, err + } + err = <-done + if err != nil { + return nil, nil, err + } + + return cs, ss, nil +} + +func getTCPConnectionPair() (net.Conn, net.Conn, error) { + lst, err := net.Listen("tcp", "localhost:0") + if err != nil { + return nil, nil, err + } + defer lst.Close() + + var conn0 net.Conn + var err0 error + done := make(chan struct{}) + go func() { + conn0, err0 = lst.Accept() + close(done) + }() + + conn1, err := net.Dial("tcp", lst.Addr().String()) + if err != nil { + return nil, nil, err + } + + <-done + if err0 != nil { + return nil, nil, err0 + } + return conn0, conn1, nil +} + +func bench(b *testing.B, rd io.Reader, wr io.Writer) { + buf := make([]byte, 128*1024) + buf2 := make([]byte, 128*1024) + b.SetBytes(128 * 1024) + b.ResetTimer() + b.ReportAllocs() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + count := 0 + for { + n, _ := rd.Read(buf2) + count += n + if count == 128*1024*b.N { + return + } + } + }() + for i := 0; i < b.N; i++ { + wr.Write(buf) + } + wg.Wait() +} diff --git a/proxy/protocol/smux/shaper.go b/proxy/protocol/smux/shaper.go new file mode 100644 index 0000000..be03406 --- /dev/null +++ b/proxy/protocol/smux/shaper.go @@ -0,0 +1,16 @@ +package smux + +type shaperHeap []writeRequest + +func (h shaperHeap) Len() int { return len(h) } +func (h shaperHeap) Less(i, j int) bool { return h[i].prio < h[j].prio } +func (h shaperHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h *shaperHeap) Push(x interface{}) { *h = append(*h, x.(writeRequest)) } + +func (h *shaperHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} diff --git a/proxy/protocol/smux/shaper_test.go b/proxy/protocol/smux/shaper_test.go new file mode 100644 index 0000000..6f6d265 --- /dev/null +++ b/proxy/protocol/smux/shaper_test.go @@ -0,0 +1,30 @@ +package smux + +import ( + "container/heap" + "testing" +) + +func TestShaper(t *testing.T) { + w1 := writeRequest{prio: 10} + w2 := writeRequest{prio: 10} + w3 := writeRequest{prio: 20} + w4 := writeRequest{prio: 100} + + var reqs shaperHeap + heap.Push(&reqs, w4) + heap.Push(&reqs, w3) + heap.Push(&reqs, w2) + heap.Push(&reqs, w1) + + var lastPrio uint64 + for len(reqs) > 0 { + w := heap.Pop(&reqs).(writeRequest) + if w.prio < lastPrio { + t.Fatal("incorrect shaper priority") + } + + t.Log("prio:", w.prio) + lastPrio = w.prio + } +} diff --git a/proxy/protocol/smux/stream.go b/proxy/protocol/smux/stream.go new file mode 100644 index 0000000..f0ce104 --- /dev/null +++ b/proxy/protocol/smux/stream.go @@ -0,0 +1,545 @@ +package smux + +import ( + "encoding/binary" + "io" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/nadoo/glider/pool" +) + +// Stream implements net.Conn +type Stream struct { + id uint32 + sess *Session + + buffers [][]byte + heads [][]byte // slice heads kept for recycle + + bufferLock sync.Mutex + frameSize int + + // notify a read event + chReadEvent chan struct{} + + // flag the stream has closed + die chan struct{} + dieOnce sync.Once + + // FIN command + chFinEvent chan struct{} + finEventOnce sync.Once + + // deadlines + readDeadline atomic.Value + writeDeadline atomic.Value + + // per stream sliding window control + numRead uint32 // number of consumed bytes + numWritten uint32 // count num of bytes written + incr uint32 // counting for sending + + // UPD command + peerConsumed uint32 // num of bytes the peer has consumed + peerWindow uint32 // peer window, initialized to 256KB, updated by peer + chUpdate chan struct{} // notify of remote data consuming and window update +} + +// newStream initiates a Stream struct +func newStream(id uint32, frameSize int, sess *Session) *Stream { + s := new(Stream) + s.id = id + s.chReadEvent = make(chan struct{}, 1) + s.chUpdate = make(chan struct{}, 1) + s.frameSize = frameSize + s.sess = sess + s.die = make(chan struct{}) + s.chFinEvent = make(chan struct{}) + s.peerWindow = initialPeerWindow // set to initial window size + return s +} + +// ID returns the unique stream ID. +func (s *Stream) ID() uint32 { + return s.id +} + +// Read implements net.Conn +func (s *Stream) Read(b []byte) (n int, err error) { + for { + n, err = s.tryRead(b) + if err == ErrWouldBlock { + if ew := s.waitRead(); ew != nil { + return 0, ew + } + } else { + return n, err + } + } +} + +// tryRead is the nonblocking version of Read +func (s *Stream) tryRead(b []byte) (n int, err error) { + if s.sess.config.Version == 2 { + return s.tryReadv2(b) + } + + if len(b) == 0 { + return 0, nil + } + + s.bufferLock.Lock() + if len(s.buffers) > 0 { + n = copy(b, s.buffers[0]) + s.buffers[0] = s.buffers[0][n:] + if len(s.buffers[0]) == 0 { + s.buffers[0] = nil + s.buffers = s.buffers[1:] + // full recycle + pool.PutBuffer(s.heads[0]) + s.heads = s.heads[1:] + } + } + s.bufferLock.Unlock() + + if n > 0 { + s.sess.returnTokens(n) + return n, nil + } + + select { + case <-s.die: + return 0, io.EOF + default: + return 0, ErrWouldBlock + } +} + +func (s *Stream) tryReadv2(b []byte) (n int, err error) { + if len(b) == 0 { + return 0, nil + } + + var notifyConsumed uint32 + s.bufferLock.Lock() + if len(s.buffers) > 0 { + n = copy(b, s.buffers[0]) + s.buffers[0] = s.buffers[0][n:] + if len(s.buffers[0]) == 0 { + s.buffers[0] = nil + s.buffers = s.buffers[1:] + // full recycle + pool.PutBuffer(s.heads[0]) + s.heads = s.heads[1:] + } + } + + // in an ideal environment: + // if more than half of buffer has consumed, send read ack to peer + // based on round-trip time of ACK, continous flowing data + // won't slow down because of waiting for ACK, as long as the + // consumer keeps on reading data + // s.numRead == n also notify window at the first read + s.numRead += uint32(n) + s.incr += uint32(n) + if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == uint32(n) { + notifyConsumed = s.numRead + s.incr = 0 + } + s.bufferLock.Unlock() + + if n > 0 { + s.sess.returnTokens(n) + if notifyConsumed > 0 { + err := s.sendWindowUpdate(notifyConsumed) + return n, err + } else { + return n, nil + } + } + + select { + case <-s.die: + return 0, io.EOF + default: + return 0, ErrWouldBlock + } +} + +// WriteTo implements io.WriteTo +func (s *Stream) WriteTo(w io.Writer) (n int64, err error) { + if s.sess.config.Version == 2 { + return s.writeTov2(w) + } + + for { + var buf []byte + s.bufferLock.Lock() + if len(s.buffers) > 0 { + buf = s.buffers[0] + s.buffers = s.buffers[1:] + s.heads = s.heads[1:] + } + s.bufferLock.Unlock() + + if buf != nil { + nw, ew := w.Write(buf) + s.sess.returnTokens(len(buf)) + pool.PutBuffer(buf) + if nw > 0 { + n += int64(nw) + } + + if ew != nil { + return n, ew + } + } else if ew := s.waitRead(); ew != nil { + return n, ew + } + } +} + +func (s *Stream) writeTov2(w io.Writer) (n int64, err error) { + for { + var notifyConsumed uint32 + var buf []byte + s.bufferLock.Lock() + if len(s.buffers) > 0 { + buf = s.buffers[0] + s.buffers = s.buffers[1:] + s.heads = s.heads[1:] + } + s.numRead += uint32(len(buf)) + s.incr += uint32(len(buf)) + if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == uint32(len(buf)) { + notifyConsumed = s.numRead + s.incr = 0 + } + s.bufferLock.Unlock() + + if buf != nil { + nw, ew := w.Write(buf) + s.sess.returnTokens(len(buf)) + pool.PutBuffer(buf) + if nw > 0 { + n += int64(nw) + } + + if ew != nil { + return n, ew + } + + if notifyConsumed > 0 { + if err := s.sendWindowUpdate(notifyConsumed); err != nil { + return n, err + } + } + } else if ew := s.waitRead(); ew != nil { + return n, ew + } + } +} + +func (s *Stream) sendWindowUpdate(consumed uint32) error { + var timer *time.Timer + var deadline <-chan time.Time + if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() { + timer = time.NewTimer(time.Until(d)) + defer timer.Stop() + deadline = timer.C + } + + frame := newFrame(byte(s.sess.config.Version), cmdUPD, s.id) + var hdr updHeader + binary.LittleEndian.PutUint32(hdr[:], consumed) + binary.LittleEndian.PutUint32(hdr[4:], uint32(s.sess.config.MaxStreamBuffer)) + frame.data = hdr[:] + _, err := s.sess.writeFrameInternal(frame, deadline, 0) + return err +} + +func (s *Stream) waitRead() error { + var timer *time.Timer + var deadline <-chan time.Time + if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() { + timer = time.NewTimer(time.Until(d)) + defer timer.Stop() + deadline = timer.C + } + + select { + case <-s.chReadEvent: + return nil + case <-s.chFinEvent: + return io.EOF + case <-s.sess.chSocketReadError: + return s.sess.socketReadError.Load().(error) + case <-s.sess.chProtoError: + return s.sess.protoError.Load().(error) + case <-deadline: + return ErrTimeout + case <-s.die: + return io.ErrClosedPipe + } + +} + +// Write implements net.Conn +// +// Note that the behavior when multiple goroutines write concurrently is not deterministic, +// frames may interleave in random way. +func (s *Stream) Write(b []byte) (n int, err error) { + if s.sess.config.Version == 2 { + return s.writeV2(b) + } + + var deadline <-chan time.Time + if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() { + timer := time.NewTimer(time.Until(d)) + defer timer.Stop() + deadline = timer.C + } + + // check if stream has closed + select { + case <-s.die: + return 0, io.ErrClosedPipe + default: + } + + // frame split and transmit + sent := 0 + frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id) + bts := b + for len(bts) > 0 { + sz := len(bts) + if sz > s.frameSize { + sz = s.frameSize + } + frame.data = bts[:sz] + bts = bts[sz:] + n, err := s.sess.writeFrameInternal(frame, deadline, uint64(s.numWritten)) + s.numWritten++ + sent += n + if err != nil { + return sent, err + } + } + + return sent, nil +} + +func (s *Stream) writeV2(b []byte) (n int, err error) { + // check empty input + if len(b) == 0 { + return 0, nil + } + + // check if stream has closed + select { + case <-s.die: + return 0, io.ErrClosedPipe + default: + } + + // create write deadline timer + var deadline <-chan time.Time + if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() { + timer := time.NewTimer(time.Until(d)) + defer timer.Stop() + deadline = timer.C + } + + // frame split and transmit process + sent := 0 + frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id) + + for { + // per stream sliding window control + // [.... [consumed... numWritten] ... win... ] + // [.... [consumed...................+rmtwnd]] + var bts []byte + // note: + // even if uint32 overflow, this math still works: + // eg1: uint32(0) - uint32(math.MaxUint32) = 1 + // eg2: int32(uint32(0) - uint32(1)) = -1 + // security check for misbehavior + inflight := int32(atomic.LoadUint32(&s.numWritten) - atomic.LoadUint32(&s.peerConsumed)) + if inflight < 0 { + return 0, ErrConsumed + } + + win := int32(atomic.LoadUint32(&s.peerWindow)) - inflight + if win > 0 { + if win > int32(len(b)) { + bts = b + b = nil + } else { + bts = b[:win] + b = b[win:] + } + + for len(bts) > 0 { + sz := len(bts) + if sz > s.frameSize { + sz = s.frameSize + } + frame.data = bts[:sz] + bts = bts[sz:] + n, err := s.sess.writeFrameInternal(frame, deadline, uint64(atomic.LoadUint32(&s.numWritten))) + atomic.AddUint32(&s.numWritten, uint32(sz)) + sent += n + if err != nil { + return sent, err + } + } + } + + // if there is any data remaining to be sent + // wait until stream closes, window changes or deadline reached + // this blocking behavior will inform upper layer to do flow control + if len(b) > 0 { + select { + case <-s.chFinEvent: // if fin arrived, future window update is impossible + return 0, io.EOF + case <-s.die: + return sent, io.ErrClosedPipe + case <-deadline: + return sent, ErrTimeout + case <-s.sess.chSocketWriteError: + return sent, s.sess.socketWriteError.Load().(error) + case <-s.chUpdate: + continue + } + } else { + return sent, nil + } + } +} + +// Close implements net.Conn +func (s *Stream) Close() error { + var once bool + var err error + s.dieOnce.Do(func() { + close(s.die) + once = true + }) + + if once { + _, err = s.sess.writeFrame(newFrame(byte(s.sess.config.Version), cmdFIN, s.id)) + s.sess.streamClosed(s.id) + return err + } else { + return io.ErrClosedPipe + } +} + +// GetDieCh returns a readonly chan which can be readable +// when the stream is to be closed. +func (s *Stream) GetDieCh() <-chan struct{} { + return s.die +} + +// SetReadDeadline sets the read deadline as defined by +// net.Conn.SetReadDeadline. +// A zero time value disables the deadline. +func (s *Stream) SetReadDeadline(t time.Time) error { + s.readDeadline.Store(t) + s.notifyReadEvent() + return nil +} + +// SetWriteDeadline sets the write deadline as defined by +// net.Conn.SetWriteDeadline. +// A zero time value disables the deadline. +func (s *Stream) SetWriteDeadline(t time.Time) error { + s.writeDeadline.Store(t) + return nil +} + +// SetDeadline sets both read and write deadlines as defined by +// net.Conn.SetDeadline. +// A zero time value disables the deadlines. +func (s *Stream) SetDeadline(t time.Time) error { + if err := s.SetReadDeadline(t); err != nil { + return err + } + if err := s.SetWriteDeadline(t); err != nil { + return err + } + return nil +} + +// session closes +func (s *Stream) sessionClose() { s.dieOnce.Do(func() { close(s.die) }) } + +// LocalAddr satisfies net.Conn interface +func (s *Stream) LocalAddr() net.Addr { + if ts, ok := s.sess.conn.(interface { + LocalAddr() net.Addr + }); ok { + return ts.LocalAddr() + } + return nil +} + +// RemoteAddr satisfies net.Conn interface +func (s *Stream) RemoteAddr() net.Addr { + if ts, ok := s.sess.conn.(interface { + RemoteAddr() net.Addr + }); ok { + return ts.RemoteAddr() + } + return nil +} + +// pushBytes append buf to buffers +func (s *Stream) pushBytes(buf []byte) (written int, err error) { + s.bufferLock.Lock() + s.buffers = append(s.buffers, buf) + s.heads = append(s.heads, buf) + s.bufferLock.Unlock() + return +} + +// recycleTokens transform remaining bytes to tokens(will truncate buffer) +func (s *Stream) recycleTokens() (n int) { + s.bufferLock.Lock() + for k := range s.buffers { + n += len(s.buffers[k]) + pool.PutBuffer(s.heads[k]) + } + s.buffers = nil + s.heads = nil + s.bufferLock.Unlock() + return +} + +// notify read event +func (s *Stream) notifyReadEvent() { + select { + case s.chReadEvent <- struct{}{}: + default: + } +} + +// update command +func (s *Stream) update(consumed uint32, window uint32) { + atomic.StoreUint32(&s.peerConsumed, consumed) + atomic.StoreUint32(&s.peerWindow, window) + select { + case s.chUpdate <- struct{}{}: + default: + } +} + +// mark this stream has been closed in protocol +func (s *Stream) fin() { + s.finEventOnce.Do(func() { + close(s.chFinEvent) + }) +} diff --git a/proxy/smux/client.go b/proxy/smux/client.go new file mode 100644 index 0000000..a586364 --- /dev/null +++ b/proxy/smux/client.go @@ -0,0 +1,76 @@ +package smux + +import ( + "errors" + "net" + "net/url" + + "github.com/nadoo/glider/log" + "github.com/nadoo/glider/proxy" + + "github.com/nadoo/glider/proxy/protocol/smux" +) + +// SmuxClient struct. +type SmuxClient struct { + dialer proxy.Dialer + addr string + session *smux.Session +} + +func init() { + proxy.RegisterDialer("smux", NewSmuxDialer) +} + +// NewSmuxDialer returns a smux dialer. +func NewSmuxDialer(s string, d proxy.Dialer) (proxy.Dialer, error) { + u, err := url.Parse(s) + if err != nil { + log.F("[smux] parse url err: %s", err) + return nil, err + } + + c := &SmuxClient{ + dialer: d, + addr: u.Host, + } + + return c, nil +} + +// Addr returns forwarder's address. +func (s *SmuxClient) Addr() string { + if s.addr == "" { + return s.dialer.Addr() + } + return s.addr +} + +// Dial connects to the address addr on the network net via the proxy. +func (s *SmuxClient) Dial(network, addr string) (net.Conn, error) { + if s.session != nil { + if c, err := s.session.OpenStream(); err == nil { + return c, err + } + s.session.Close() + } + if err := s.initConn(); err != nil { + return nil, err + } + return s.session.OpenStream() +} + +// DialUDP connects to the given address via the proxy. +func (s *SmuxClient) DialUDP(network, addr string) (net.PacketConn, net.Addr, error) { + return nil, nil, errors.New("smux client does not support udp now") +} + +func (s *SmuxClient) initConn() error { + conn, err := s.dialer.Dial("tcp", s.addr) + if err != nil { + log.F("[smux] dial to %s error: %s", s.addr, err) + return err + } + s.session, err = smux.Client(conn, nil) + return err +} diff --git a/proxy/smux/server.go b/proxy/smux/server.go new file mode 100644 index 0000000..3adbd9c --- /dev/null +++ b/proxy/smux/server.go @@ -0,0 +1,98 @@ +package smux + +import ( + "net" + "net/url" + "strings" + + "github.com/nadoo/glider/log" + "github.com/nadoo/glider/proxy" + + "github.com/nadoo/glider/proxy/protocol/smux" +) + +// SmuxServer struct. +type SmuxServer struct { + proxy proxy.Proxy + addr string + server proxy.Server +} + +func init() { + proxy.RegisterServer("smux", NewSmuxServer) +} + +// NewSmuxServer returns a smux transport layer before the real server. +func NewSmuxServer(s string, p proxy.Proxy) (proxy.Server, error) { + transport := strings.Split(s, ",") + + u, err := url.Parse(transport[0]) + if err != nil { + log.F("[smux] parse url err: %s", err) + return nil, err + } + + m := &SmuxServer{ + proxy: p, + addr: u.Host, + } + + if len(transport) > 1 { + m.server, err = proxy.ServerFromURL(transport[1], p) + if err != nil { + return nil, err + } + } + + return m, nil +} + +// ListenAndServe listens on server's addr and serves connections. +func (s *SmuxServer) ListenAndServe() { + l, err := net.Listen("tcp", s.addr) + if err != nil { + log.F("[smux] failed to listen on %s: %v", s.addr, err) + return + } + defer l.Close() + + log.F("[smux] listening mux on %s", s.addr) + + for { + c, err := l.Accept() + if err != nil { + log.F("[smux] failed to accept: %v", err) + continue + } + + go s.Serve(c) + } +} + +// Serve serves a connection. +func (s *SmuxServer) Serve(c net.Conn) { + // we know the internal server will close the connection after serve + // defer c.Close() + + session, err := smux.Server(c, nil) + if err != nil { + log.F("[smux] failed to create session: %v", err) + return + } + + for { + // Accept a stream + stream, err := session.AcceptStream() + if err != nil { + session.Close() + break + } + go s.ServeStream(stream) + } +} + +func (s *SmuxServer) ServeStream(c *smux.Stream) { + if s.server != nil { + s.server.Serve(c) + } +}