mirror of
https://github.com/nadoo/glider.git
synced 2025-02-23 01:15:41 +08:00
proxy: added smux support
This commit is contained in:
parent
34a053b875
commit
dbd2e04521
37
README.md
37
README.md
@ -63,6 +63,7 @@ we can set up local listeners as proxy servers, and forward requests to internet
|
||||
|TLS |√| |√| |transport client & server
|
||||
|KCP | |√|√| |transport client & server
|
||||
|Unix |√|√|√|√|transport client & server
|
||||
|Smux |√| |√| |transport client & server
|
||||
|Websocket |√| |√| |transport client & server
|
||||
|Simple-Obfs | | |√| |transport client only
|
||||
|Redir |√| | | |linux only
|
||||
@ -96,7 +97,7 @@ glider -h
|
||||
<summary>click to see details</summary>
|
||||
|
||||
```bash
|
||||
glider 0.13.0 usage:
|
||||
glider 0.14.0 usage:
|
||||
-check string
|
||||
check=tcp[://HOST:PORT]: tcp port connect check
|
||||
check=http://HOST[:PORT][/URI][#expect=STRING_IN_RESP_LINE]
|
||||
@ -152,27 +153,10 @@ glider 0.13.0 usage:
|
||||
forward strategy, default: rr (default "rr")
|
||||
-verbose
|
||||
verbose mode
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
run:
|
||||
```bash
|
||||
glider -config CONFIGPATH
|
||||
```
|
||||
```bash
|
||||
glider -verbose -listen :8443 -forward SCHEME://HOST:PORT
|
||||
```
|
||||
|
||||
#### Schemes
|
||||
|
||||
<details>
|
||||
<summary>click to see details</summary>
|
||||
|
||||
```bash
|
||||
Available schemes:
|
||||
listen: mixed ss socks5 http vless trojan trojanc redir redir6 tcp udp tls ws unix kcp
|
||||
forward: reject ss socks4 socks5 http ssr ssh vless vmess trojan trojanc tcp udp tls ws unix kcp simple-obfs
|
||||
listen: mixed ss socks5 http vless trojan trojanc redir redir6 tcp udp tls ws unix smux kcp
|
||||
forward: reject ss socks4 socks5 http ssr ssh vless vmess trojan trojanc tcp udp tls ws unix smux kcp simple-obfs
|
||||
|
||||
Socks5 scheme:
|
||||
socks://[user:pass@]host:port
|
||||
@ -251,6 +235,9 @@ TLS and Websocket with a specified proxy protocol:
|
||||
Unix domain socket scheme:
|
||||
unix://path
|
||||
|
||||
Smux scheme:
|
||||
smux://host:port
|
||||
|
||||
KCP scheme:
|
||||
kcp://CRYPT:KEY@host:port[?dataShards=NUM&parityShards=NUM&mode=MODE]
|
||||
|
||||
@ -298,16 +285,8 @@ Config file format(see `./glider.conf.example` as an example):
|
||||
KEY=VALUE
|
||||
KEY=VALUE
|
||||
# KEY equals to command line flag name: listen forward strategy...
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
#### Examples
|
||||
|
||||
<details>
|
||||
<summary>click to see details</summary>
|
||||
|
||||
```bash
|
||||
Examples:
|
||||
./glider -config glider.conf
|
||||
-run glider with specified config file.
|
||||
|
||||
|
@ -131,8 +131,8 @@ func usage() {
|
||||
fmt.Fprintf(w, "\n")
|
||||
|
||||
fmt.Fprintf(w, "Available schemes:\n")
|
||||
fmt.Fprintf(w, " listen: mixed ss socks5 http vless trojan trojanc redir redir6 tcp udp tls ws unix kcp\n")
|
||||
fmt.Fprintf(w, " forward: reject ss socks4 socks5 http ssr ssh vless vmess trojan trojanc tcp udp tls ws unix kcp simple-obfs\n")
|
||||
fmt.Fprintf(w, " listen: mixed ss socks5 http vless trojan trojanc redir redir6 tcp udp tls ws unix smux kcp\n")
|
||||
fmt.Fprintf(w, " forward: reject ss socks4 socks5 http ssr ssh vless vmess trojan trojanc tcp udp tls ws unix smux kcp simple-obfs\n")
|
||||
fmt.Fprintf(w, "\n")
|
||||
|
||||
fmt.Fprintf(w, "Socks5 scheme:\n")
|
||||
@ -231,6 +231,10 @@ func usage() {
|
||||
fmt.Fprintf(w, " unix://path\n")
|
||||
fmt.Fprintf(w, "\n")
|
||||
|
||||
fmt.Fprintf(w, "Smux scheme:\n")
|
||||
fmt.Fprintf(w, " smux://host:port\n")
|
||||
fmt.Fprintf(w, "\n")
|
||||
|
||||
fmt.Fprintf(w, "KCP scheme:\n")
|
||||
fmt.Fprintf(w, " kcp://CRYPT:KEY@host:port[?dataShards=NUM&parityShards=NUM&mode=MODE]\n")
|
||||
fmt.Fprintf(w, "\n")
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
_ "github.com/nadoo/glider/proxy/mixed"
|
||||
_ "github.com/nadoo/glider/proxy/obfs"
|
||||
_ "github.com/nadoo/glider/proxy/reject"
|
||||
_ "github.com/nadoo/glider/proxy/smux"
|
||||
_ "github.com/nadoo/glider/proxy/socks4"
|
||||
_ "github.com/nadoo/glider/proxy/socks5"
|
||||
_ "github.com/nadoo/glider/proxy/ss"
|
||||
|
2
go.mod
2
go.mod
@ -8,7 +8,7 @@ require (
|
||||
github.com/dgryski/go-idea v0.0.0-20170306091226-d2fb45a411fb
|
||||
github.com/dgryski/go-rc2 v0.0.0-20150621095337-8a9021637152
|
||||
github.com/ebfe/rc2 v0.0.0-20131011165748-24b9757f5521 // indirect
|
||||
github.com/insomniacslk/dhcp v0.0.0-20210315110227-c51060810aaa
|
||||
github.com/insomniacslk/dhcp v0.0.0-20210420161629-6bd1ce0fd305
|
||||
github.com/klauspost/cpuid/v2 v2.0.6 // indirect
|
||||
github.com/klauspost/reedsolomon v1.9.12 // indirect
|
||||
github.com/mdlayher/raw v0.0.0-20210412142147-51b895745faf // indirect
|
||||
|
4
go.sum
4
go.sum
@ -39,8 +39,8 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
|
||||
github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714 h1:/jC7qQFrv8CrSJVmaolDVOxTfS9kc36uB6H40kdbQq8=
|
||||
github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714/go.mod h1:2Goc3h8EklBH5mspfHFxBnEoURQCGzQQH1ga9Myjvis=
|
||||
github.com/insomniacslk/dhcp v0.0.0-20210315110227-c51060810aaa h1:ff6iUOGfajZ+T8RBuRtOjYMW7/6aF8iJG+iCp2wQyQ4=
|
||||
github.com/insomniacslk/dhcp v0.0.0-20210315110227-c51060810aaa/go.mod h1:TKl4jN3Voofo4UJIicyNhWGp/nlQqQkFxmwIFTvBkKI=
|
||||
github.com/insomniacslk/dhcp v0.0.0-20210420161629-6bd1ce0fd305 h1:DGmCtsdLE6r7tuFH5YHnDoKJU9WXjptF/8Q8iKc41Tk=
|
||||
github.com/insomniacslk/dhcp v0.0.0-20210420161629-6bd1ce0fd305/go.mod h1:TKl4jN3Voofo4UJIicyNhWGp/nlQqQkFxmwIFTvBkKI=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a/go.mod h1:Oz+70psSo5OFh8DBl0Zv2ACw7Esh6pPUphlvZG9x7uw=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4/go.mod h1:WGuG/smIU4J/54PblvSbh+xvCZmpJnFgr3ds6Z55XMQ=
|
||||
github.com/jsimonetti/rtnetlink v0.0.0-20201009170750-9c6f07d100c1/go.mod h1:hqoO/u39cqLeBLebZ8fWdE96O7FxrAsRYhnVOdgHxok=
|
||||
|
21
proxy/protocol/smux/LICENSE
Normal file
21
proxy/protocol/smux/LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2016-2017 Daniel Fu
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
81
proxy/protocol/smux/frame.go
Normal file
81
proxy/protocol/smux/frame.go
Normal file
@ -0,0 +1,81 @@
|
||||
package smux
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const ( // cmds
|
||||
// protocol version 1:
|
||||
cmdSYN byte = iota // stream open
|
||||
cmdFIN // stream close, a.k.a EOF mark
|
||||
cmdPSH // data push
|
||||
cmdNOP // no operation
|
||||
|
||||
// protocol version 2 extra commands
|
||||
// notify bytes consumed by remote peer-end
|
||||
cmdUPD
|
||||
)
|
||||
|
||||
const (
|
||||
// data size of cmdUPD, format:
|
||||
// |4B data consumed(ACK)| 4B window size(WINDOW) |
|
||||
szCmdUPD = 8
|
||||
)
|
||||
|
||||
const (
|
||||
// initial peer window guess, a slow-start
|
||||
initialPeerWindow = 262144
|
||||
)
|
||||
|
||||
const (
|
||||
sizeOfVer = 1
|
||||
sizeOfCmd = 1
|
||||
sizeOfLength = 2
|
||||
sizeOfSid = 4
|
||||
headerSize = sizeOfVer + sizeOfCmd + sizeOfSid + sizeOfLength
|
||||
)
|
||||
|
||||
// Frame defines a packet from or to be multiplexed into a single connection
|
||||
type Frame struct {
|
||||
ver byte
|
||||
cmd byte
|
||||
sid uint32
|
||||
data []byte
|
||||
}
|
||||
|
||||
func newFrame(version byte, cmd byte, sid uint32) Frame {
|
||||
return Frame{ver: version, cmd: cmd, sid: sid}
|
||||
}
|
||||
|
||||
type rawHeader [headerSize]byte
|
||||
|
||||
func (h rawHeader) Version() byte {
|
||||
return h[0]
|
||||
}
|
||||
|
||||
func (h rawHeader) Cmd() byte {
|
||||
return h[1]
|
||||
}
|
||||
|
||||
func (h rawHeader) Length() uint16 {
|
||||
return binary.LittleEndian.Uint16(h[2:])
|
||||
}
|
||||
|
||||
func (h rawHeader) StreamID() uint32 {
|
||||
return binary.LittleEndian.Uint32(h[4:])
|
||||
}
|
||||
|
||||
func (h rawHeader) String() string {
|
||||
return fmt.Sprintf("Version:%d Cmd:%d StreamID:%d Length:%d",
|
||||
h.Version(), h.Cmd(), h.StreamID(), h.Length())
|
||||
}
|
||||
|
||||
type updHeader [szCmdUPD]byte
|
||||
|
||||
func (h updHeader) Consumed() uint32 {
|
||||
return binary.LittleEndian.Uint32(h[:])
|
||||
}
|
||||
func (h updHeader) Window() uint32 {
|
||||
return binary.LittleEndian.Uint32(h[4:])
|
||||
}
|
110
proxy/protocol/smux/mux.go
Normal file
110
proxy/protocol/smux/mux.go
Normal file
@ -0,0 +1,110 @@
|
||||
// Package smux is a multiplexing library for Golang.
|
||||
//
|
||||
// It relies on an underlying connection to provide reliability and ordering, such as TCP or KCP,
|
||||
// and provides stream-oriented multiplexing over a single channel.
|
||||
package smux
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config is used to tune the Smux session
|
||||
type Config struct {
|
||||
// SMUX Protocol version, support 1,2
|
||||
Version int
|
||||
|
||||
// Disabled keepalive
|
||||
KeepAliveDisabled bool
|
||||
|
||||
// KeepAliveInterval is how often to send a NOP command to the remote
|
||||
KeepAliveInterval time.Duration
|
||||
|
||||
// KeepAliveTimeout is how long the session
|
||||
// will be closed if no data has arrived
|
||||
KeepAliveTimeout time.Duration
|
||||
|
||||
// MaxFrameSize is used to control the maximum
|
||||
// frame size to sent to the remote
|
||||
MaxFrameSize int
|
||||
|
||||
// MaxReceiveBuffer is used to control the maximum
|
||||
// number of data in the buffer pool
|
||||
MaxReceiveBuffer int
|
||||
|
||||
// MaxStreamBuffer is used to control the maximum
|
||||
// number of data per stream
|
||||
MaxStreamBuffer int
|
||||
}
|
||||
|
||||
// DefaultConfig is used to return a default configuration
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Version: 1,
|
||||
KeepAliveInterval: 10 * time.Second,
|
||||
KeepAliveTimeout: 30 * time.Second,
|
||||
MaxFrameSize: 32768,
|
||||
MaxReceiveBuffer: 4194304,
|
||||
MaxStreamBuffer: 65536,
|
||||
}
|
||||
}
|
||||
|
||||
// VerifyConfig is used to verify the sanity of configuration
|
||||
func VerifyConfig(config *Config) error {
|
||||
if !(config.Version == 1 || config.Version == 2) {
|
||||
return errors.New("unsupported protocol version")
|
||||
}
|
||||
if !config.KeepAliveDisabled {
|
||||
if config.KeepAliveInterval == 0 {
|
||||
return errors.New("keep-alive interval must be positive")
|
||||
}
|
||||
if config.KeepAliveTimeout < config.KeepAliveInterval {
|
||||
return fmt.Errorf("keep-alive timeout must be larger than keep-alive interval")
|
||||
}
|
||||
}
|
||||
if config.MaxFrameSize <= 0 {
|
||||
return errors.New("max frame size must be positive")
|
||||
}
|
||||
if config.MaxFrameSize > 65535 {
|
||||
return errors.New("max frame size must not be larger than 65535")
|
||||
}
|
||||
if config.MaxReceiveBuffer <= 0 {
|
||||
return errors.New("max receive buffer must be positive")
|
||||
}
|
||||
if config.MaxStreamBuffer <= 0 {
|
||||
return errors.New("max stream buffer must be positive")
|
||||
}
|
||||
if config.MaxStreamBuffer > config.MaxReceiveBuffer {
|
||||
return errors.New("max stream buffer must not be larger than max receive buffer")
|
||||
}
|
||||
if config.MaxStreamBuffer > math.MaxInt32 {
|
||||
return errors.New("max stream buffer cannot be larger than 2147483647")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Server is used to initialize a new server-side connection.
|
||||
func Server(conn io.ReadWriteCloser, config *Config) (*Session, error) {
|
||||
if config == nil {
|
||||
config = DefaultConfig()
|
||||
}
|
||||
if err := VerifyConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newSession(config, conn, false), nil
|
||||
}
|
||||
|
||||
// Client is used to initialize a new client-side connection.
|
||||
func Client(conn io.ReadWriteCloser, config *Config) (*Session, error) {
|
||||
if config == nil {
|
||||
config = DefaultConfig()
|
||||
}
|
||||
|
||||
if err := VerifyConfig(config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return newSession(config, conn, true), nil
|
||||
}
|
86
proxy/protocol/smux/mux_test.go
Normal file
86
proxy/protocol/smux/mux_test.go
Normal file
@ -0,0 +1,86 @@
|
||||
package smux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type buffer struct {
|
||||
bytes.Buffer
|
||||
}
|
||||
|
||||
func (b *buffer) Close() error {
|
||||
b.Buffer.Reset()
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
VerifyConfig(DefaultConfig())
|
||||
|
||||
config := DefaultConfig()
|
||||
config.KeepAliveInterval = 0
|
||||
err := VerifyConfig(config)
|
||||
t.Log(err)
|
||||
if err == nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config = DefaultConfig()
|
||||
config.KeepAliveInterval = 10
|
||||
config.KeepAliveTimeout = 5
|
||||
err = VerifyConfig(config)
|
||||
t.Log(err)
|
||||
if err == nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config = DefaultConfig()
|
||||
config.MaxFrameSize = 0
|
||||
err = VerifyConfig(config)
|
||||
t.Log(err)
|
||||
if err == nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config = DefaultConfig()
|
||||
config.MaxFrameSize = 65536
|
||||
err = VerifyConfig(config)
|
||||
t.Log(err)
|
||||
if err == nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config = DefaultConfig()
|
||||
config.MaxReceiveBuffer = 0
|
||||
err = VerifyConfig(config)
|
||||
t.Log(err)
|
||||
if err == nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config = DefaultConfig()
|
||||
config.MaxStreamBuffer = 0
|
||||
err = VerifyConfig(config)
|
||||
t.Log(err)
|
||||
if err == nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config = DefaultConfig()
|
||||
config.MaxStreamBuffer = 100
|
||||
config.MaxReceiveBuffer = 99
|
||||
err = VerifyConfig(config)
|
||||
t.Log(err)
|
||||
if err == nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var bts buffer
|
||||
if _, err := Server(&bts, config); err == nil {
|
||||
t.Fatal("server started with wrong config")
|
||||
}
|
||||
|
||||
if _, err := Client(&bts, config); err == nil {
|
||||
t.Fatal("client started with wrong config")
|
||||
}
|
||||
}
|
527
proxy/protocol/smux/session.go
Normal file
527
proxy/protocol/smux/session.go
Normal file
@ -0,0 +1,527 @@
|
||||
package smux
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/nadoo/glider/pool"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultAcceptBacklog = 1024
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidProtocol = errors.New("invalid protocol")
|
||||
ErrConsumed = errors.New("peer consumed more than sent")
|
||||
ErrGoAway = errors.New("stream id overflows, should start a new connection")
|
||||
ErrTimeout = errors.New("timeout")
|
||||
ErrWouldBlock = errors.New("operation would block on IO")
|
||||
)
|
||||
|
||||
type writeRequest struct {
|
||||
prio uint64
|
||||
frame Frame
|
||||
result chan writeResult
|
||||
}
|
||||
|
||||
type writeResult struct {
|
||||
n int
|
||||
err error
|
||||
}
|
||||
|
||||
type buffersWriter interface {
|
||||
WriteBuffers(v [][]byte) (n int, err error)
|
||||
}
|
||||
|
||||
// Session defines a multiplexed connection for streams
|
||||
type Session struct {
|
||||
conn io.ReadWriteCloser
|
||||
|
||||
config *Config
|
||||
nextStreamID uint32 // next stream identifier
|
||||
nextStreamIDLock sync.Mutex
|
||||
|
||||
bucket int32 // token bucket
|
||||
bucketNotify chan struct{} // used for waiting for tokens
|
||||
|
||||
streams map[uint32]*Stream // all streams in this session
|
||||
streamLock sync.Mutex // locks streams
|
||||
|
||||
die chan struct{} // flag session has died
|
||||
dieOnce sync.Once
|
||||
|
||||
// socket error handling
|
||||
socketReadError atomic.Value
|
||||
socketWriteError atomic.Value
|
||||
chSocketReadError chan struct{}
|
||||
chSocketWriteError chan struct{}
|
||||
socketReadErrorOnce sync.Once
|
||||
socketWriteErrorOnce sync.Once
|
||||
|
||||
// smux protocol errors
|
||||
protoError atomic.Value
|
||||
chProtoError chan struct{}
|
||||
protoErrorOnce sync.Once
|
||||
|
||||
chAccepts chan *Stream
|
||||
|
||||
dataReady int32 // flag data has arrived
|
||||
|
||||
goAway int32 // flag id exhausted
|
||||
|
||||
deadline atomic.Value
|
||||
|
||||
shaper chan writeRequest // a shaper for writing
|
||||
writes chan writeRequest
|
||||
}
|
||||
|
||||
func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
|
||||
s := new(Session)
|
||||
s.die = make(chan struct{})
|
||||
s.conn = conn
|
||||
s.config = config
|
||||
s.streams = make(map[uint32]*Stream)
|
||||
s.chAccepts = make(chan *Stream, defaultAcceptBacklog)
|
||||
s.bucket = int32(config.MaxReceiveBuffer)
|
||||
s.bucketNotify = make(chan struct{}, 1)
|
||||
s.shaper = make(chan writeRequest)
|
||||
s.writes = make(chan writeRequest)
|
||||
s.chSocketReadError = make(chan struct{})
|
||||
s.chSocketWriteError = make(chan struct{})
|
||||
s.chProtoError = make(chan struct{})
|
||||
|
||||
if client {
|
||||
s.nextStreamID = 1
|
||||
} else {
|
||||
s.nextStreamID = 0
|
||||
}
|
||||
|
||||
go s.shaperLoop()
|
||||
go s.recvLoop()
|
||||
go s.sendLoop()
|
||||
if !config.KeepAliveDisabled {
|
||||
go s.keepalive()
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// OpenStream is used to create a new stream
|
||||
func (s *Session) OpenStream() (*Stream, error) {
|
||||
if s.IsClosed() {
|
||||
return nil, io.ErrClosedPipe
|
||||
}
|
||||
|
||||
// generate stream id
|
||||
s.nextStreamIDLock.Lock()
|
||||
if s.goAway > 0 {
|
||||
s.nextStreamIDLock.Unlock()
|
||||
return nil, ErrGoAway
|
||||
}
|
||||
|
||||
s.nextStreamID += 2
|
||||
sid := s.nextStreamID
|
||||
if sid == sid%2 { // stream-id overflows
|
||||
s.goAway = 1
|
||||
s.nextStreamIDLock.Unlock()
|
||||
return nil, ErrGoAway
|
||||
}
|
||||
s.nextStreamIDLock.Unlock()
|
||||
|
||||
stream := newStream(sid, s.config.MaxFrameSize, s)
|
||||
|
||||
if _, err := s.writeFrame(newFrame(byte(s.config.Version), cmdSYN, sid)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.streamLock.Lock()
|
||||
defer s.streamLock.Unlock()
|
||||
select {
|
||||
case <-s.chSocketReadError:
|
||||
return nil, s.socketReadError.Load().(error)
|
||||
case <-s.chSocketWriteError:
|
||||
return nil, s.socketWriteError.Load().(error)
|
||||
case <-s.die:
|
||||
return nil, io.ErrClosedPipe
|
||||
default:
|
||||
s.streams[sid] = stream
|
||||
return stream, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Open returns a generic ReadWriteCloser
|
||||
func (s *Session) Open() (io.ReadWriteCloser, error) {
|
||||
return s.OpenStream()
|
||||
}
|
||||
|
||||
// AcceptStream is used to block until the next available stream
|
||||
// is ready to be accepted.
|
||||
func (s *Session) AcceptStream() (*Stream, error) {
|
||||
var deadline <-chan time.Time
|
||||
if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() {
|
||||
timer := time.NewTimer(time.Until(d))
|
||||
defer timer.Stop()
|
||||
deadline = timer.C
|
||||
}
|
||||
|
||||
select {
|
||||
case stream := <-s.chAccepts:
|
||||
return stream, nil
|
||||
case <-deadline:
|
||||
return nil, ErrTimeout
|
||||
case <-s.chSocketReadError:
|
||||
return nil, s.socketReadError.Load().(error)
|
||||
case <-s.chProtoError:
|
||||
return nil, s.protoError.Load().(error)
|
||||
case <-s.die:
|
||||
return nil, io.ErrClosedPipe
|
||||
}
|
||||
}
|
||||
|
||||
// Accept Returns a generic ReadWriteCloser instead of smux.Stream
|
||||
func (s *Session) Accept() (io.ReadWriteCloser, error) {
|
||||
return s.AcceptStream()
|
||||
}
|
||||
|
||||
// Close is used to close the session and all streams.
|
||||
func (s *Session) Close() error {
|
||||
var once bool
|
||||
s.dieOnce.Do(func() {
|
||||
close(s.die)
|
||||
once = true
|
||||
})
|
||||
|
||||
if once {
|
||||
s.streamLock.Lock()
|
||||
for k := range s.streams {
|
||||
s.streams[k].sessionClose()
|
||||
}
|
||||
s.streamLock.Unlock()
|
||||
return s.conn.Close()
|
||||
} else {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
}
|
||||
|
||||
// notifyBucket notifies recvLoop that bucket is available
|
||||
func (s *Session) notifyBucket() {
|
||||
select {
|
||||
case s.bucketNotify <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) notifyReadError(err error) {
|
||||
s.socketReadErrorOnce.Do(func() {
|
||||
s.socketReadError.Store(err)
|
||||
close(s.chSocketReadError)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) notifyWriteError(err error) {
|
||||
s.socketWriteErrorOnce.Do(func() {
|
||||
s.socketWriteError.Store(err)
|
||||
close(s.chSocketWriteError)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Session) notifyProtoError(err error) {
|
||||
s.protoErrorOnce.Do(func() {
|
||||
s.protoError.Store(err)
|
||||
close(s.chProtoError)
|
||||
})
|
||||
}
|
||||
|
||||
// IsClosed does a safe check to see if we have shutdown
|
||||
func (s *Session) IsClosed() bool {
|
||||
select {
|
||||
case <-s.die:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// NumStreams returns the number of currently open streams
|
||||
func (s *Session) NumStreams() int {
|
||||
if s.IsClosed() {
|
||||
return 0
|
||||
}
|
||||
s.streamLock.Lock()
|
||||
defer s.streamLock.Unlock()
|
||||
return len(s.streams)
|
||||
}
|
||||
|
||||
// SetDeadline sets a deadline used by Accept* calls.
|
||||
// A zero time value disables the deadline.
|
||||
func (s *Session) SetDeadline(t time.Time) error {
|
||||
s.deadline.Store(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LocalAddr satisfies net.Conn interface
|
||||
func (s *Session) LocalAddr() net.Addr {
|
||||
if ts, ok := s.conn.(interface {
|
||||
LocalAddr() net.Addr
|
||||
}); ok {
|
||||
return ts.LocalAddr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoteAddr satisfies net.Conn interface
|
||||
func (s *Session) RemoteAddr() net.Addr {
|
||||
if ts, ok := s.conn.(interface {
|
||||
RemoteAddr() net.Addr
|
||||
}); ok {
|
||||
return ts.RemoteAddr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// notify the session that a stream has closed
|
||||
func (s *Session) streamClosed(sid uint32) {
|
||||
s.streamLock.Lock()
|
||||
if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket
|
||||
if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
|
||||
s.notifyBucket()
|
||||
}
|
||||
}
|
||||
delete(s.streams, sid)
|
||||
s.streamLock.Unlock()
|
||||
}
|
||||
|
||||
// returnTokens is called by stream to return token after read
|
||||
func (s *Session) returnTokens(n int) {
|
||||
if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
|
||||
s.notifyBucket()
|
||||
}
|
||||
}
|
||||
|
||||
// recvLoop keeps on reading from underlying connection if tokens are available
|
||||
func (s *Session) recvLoop() {
|
||||
var hdr rawHeader
|
||||
var updHdr updHeader
|
||||
|
||||
for {
|
||||
for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() {
|
||||
select {
|
||||
case <-s.bucketNotify:
|
||||
case <-s.die:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// read header first
|
||||
if _, err := io.ReadFull(s.conn, hdr[:]); err == nil {
|
||||
atomic.StoreInt32(&s.dataReady, 1)
|
||||
if hdr.Version() != byte(s.config.Version) {
|
||||
s.notifyProtoError(ErrInvalidProtocol)
|
||||
return
|
||||
}
|
||||
sid := hdr.StreamID()
|
||||
switch hdr.Cmd() {
|
||||
case cmdNOP:
|
||||
case cmdSYN:
|
||||
s.streamLock.Lock()
|
||||
if _, ok := s.streams[sid]; !ok {
|
||||
stream := newStream(sid, s.config.MaxFrameSize, s)
|
||||
s.streams[sid] = stream
|
||||
select {
|
||||
case s.chAccepts <- stream:
|
||||
case <-s.die:
|
||||
}
|
||||
}
|
||||
s.streamLock.Unlock()
|
||||
case cmdFIN:
|
||||
s.streamLock.Lock()
|
||||
if stream, ok := s.streams[sid]; ok {
|
||||
stream.fin()
|
||||
stream.notifyReadEvent()
|
||||
}
|
||||
s.streamLock.Unlock()
|
||||
case cmdPSH:
|
||||
if hdr.Length() > 0 {
|
||||
newbuf := pool.GetBuffer(int(hdr.Length()))
|
||||
if written, err := io.ReadFull(s.conn, newbuf); err == nil {
|
||||
s.streamLock.Lock()
|
||||
if stream, ok := s.streams[sid]; ok {
|
||||
stream.pushBytes(newbuf)
|
||||
atomic.AddInt32(&s.bucket, -int32(written))
|
||||
stream.notifyReadEvent()
|
||||
}
|
||||
s.streamLock.Unlock()
|
||||
} else {
|
||||
s.notifyReadError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
case cmdUPD:
|
||||
if _, err := io.ReadFull(s.conn, updHdr[:]); err == nil {
|
||||
s.streamLock.Lock()
|
||||
if stream, ok := s.streams[sid]; ok {
|
||||
stream.update(updHdr.Consumed(), updHdr.Window())
|
||||
}
|
||||
s.streamLock.Unlock()
|
||||
} else {
|
||||
s.notifyReadError(err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
s.notifyProtoError(ErrInvalidProtocol)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
s.notifyReadError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) keepalive() {
|
||||
tickerPing := time.NewTicker(s.config.KeepAliveInterval)
|
||||
tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout)
|
||||
defer tickerPing.Stop()
|
||||
defer tickerTimeout.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-tickerPing.C:
|
||||
s.writeFrameInternal(newFrame(byte(s.config.Version), cmdNOP, 0), tickerPing.C, 0)
|
||||
s.notifyBucket() // force a signal to the recvLoop
|
||||
case <-tickerTimeout.C:
|
||||
if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) {
|
||||
// recvLoop may block while bucket is 0, in this case,
|
||||
// session should not be closed.
|
||||
if atomic.LoadInt32(&s.bucket) > 0 {
|
||||
s.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
case <-s.die:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// shaper shapes the sending sequence among streams
|
||||
func (s *Session) shaperLoop() {
|
||||
var reqs shaperHeap
|
||||
var next writeRequest
|
||||
var chWrite chan writeRequest
|
||||
|
||||
for {
|
||||
if len(reqs) > 0 {
|
||||
chWrite = s.writes
|
||||
next = heap.Pop(&reqs).(writeRequest)
|
||||
} else {
|
||||
chWrite = nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-s.die:
|
||||
return
|
||||
case r := <-s.shaper:
|
||||
if chWrite != nil { // next is valid, reshape
|
||||
heap.Push(&reqs, next)
|
||||
}
|
||||
heap.Push(&reqs, r)
|
||||
case chWrite <- next:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) sendLoop() {
|
||||
var buf []byte
|
||||
var n int
|
||||
var err error
|
||||
var vec [][]byte // vector for writeBuffers
|
||||
|
||||
bw, ok := s.conn.(buffersWriter)
|
||||
if ok {
|
||||
buf = make([]byte, headerSize)
|
||||
vec = make([][]byte, 2)
|
||||
} else {
|
||||
buf = make([]byte, (1<<16)+headerSize)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.die:
|
||||
return
|
||||
case request := <-s.writes:
|
||||
buf[0] = request.frame.ver
|
||||
buf[1] = request.frame.cmd
|
||||
binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data)))
|
||||
binary.LittleEndian.PutUint32(buf[4:], request.frame.sid)
|
||||
|
||||
if len(vec) > 0 {
|
||||
vec[0] = buf[:headerSize]
|
||||
vec[1] = request.frame.data
|
||||
n, err = bw.WriteBuffers(vec)
|
||||
} else {
|
||||
copy(buf[headerSize:], request.frame.data)
|
||||
n, err = s.conn.Write(buf[:headerSize+len(request.frame.data)])
|
||||
}
|
||||
|
||||
n -= headerSize
|
||||
if n < 0 {
|
||||
n = 0
|
||||
}
|
||||
|
||||
result := writeResult{
|
||||
n: n,
|
||||
err: err,
|
||||
}
|
||||
|
||||
request.result <- result
|
||||
close(request.result)
|
||||
|
||||
// store conn error
|
||||
if err != nil {
|
||||
s.notifyWriteError(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeFrame writes the frame to the underlying connection
|
||||
// and returns the number of bytes written if successful
|
||||
func (s *Session) writeFrame(f Frame) (n int, err error) {
|
||||
return s.writeFrameInternal(f, nil, 0)
|
||||
}
|
||||
|
||||
// internal writeFrame version to support deadline used in keepalive
|
||||
func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, prio uint64) (int, error) {
|
||||
req := writeRequest{
|
||||
prio: prio,
|
||||
frame: f,
|
||||
result: make(chan writeResult, 1),
|
||||
}
|
||||
select {
|
||||
case s.shaper <- req:
|
||||
case <-s.die:
|
||||
return 0, io.ErrClosedPipe
|
||||
case <-s.chSocketWriteError:
|
||||
return 0, s.socketWriteError.Load().(error)
|
||||
case <-deadline:
|
||||
return 0, ErrTimeout
|
||||
}
|
||||
|
||||
select {
|
||||
case result := <-req.result:
|
||||
return result.n, result.err
|
||||
case <-s.die:
|
||||
return 0, io.ErrClosedPipe
|
||||
case <-s.chSocketWriteError:
|
||||
return 0, s.socketWriteError.Load().(error)
|
||||
case <-deadline:
|
||||
return 0, ErrTimeout
|
||||
}
|
||||
}
|
1090
proxy/protocol/smux/session_test.go
Normal file
1090
proxy/protocol/smux/session_test.go
Normal file
File diff suppressed because it is too large
Load Diff
16
proxy/protocol/smux/shaper.go
Normal file
16
proxy/protocol/smux/shaper.go
Normal file
@ -0,0 +1,16 @@
|
||||
package smux
|
||||
|
||||
type shaperHeap []writeRequest
|
||||
|
||||
func (h shaperHeap) Len() int { return len(h) }
|
||||
func (h shaperHeap) Less(i, j int) bool { return h[i].prio < h[j].prio }
|
||||
func (h shaperHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
|
||||
func (h *shaperHeap) Push(x interface{}) { *h = append(*h, x.(writeRequest)) }
|
||||
|
||||
func (h *shaperHeap) Pop() interface{} {
|
||||
old := *h
|
||||
n := len(old)
|
||||
x := old[n-1]
|
||||
*h = old[0 : n-1]
|
||||
return x
|
||||
}
|
30
proxy/protocol/smux/shaper_test.go
Normal file
30
proxy/protocol/smux/shaper_test.go
Normal file
@ -0,0 +1,30 @@
|
||||
package smux
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestShaper(t *testing.T) {
|
||||
w1 := writeRequest{prio: 10}
|
||||
w2 := writeRequest{prio: 10}
|
||||
w3 := writeRequest{prio: 20}
|
||||
w4 := writeRequest{prio: 100}
|
||||
|
||||
var reqs shaperHeap
|
||||
heap.Push(&reqs, w4)
|
||||
heap.Push(&reqs, w3)
|
||||
heap.Push(&reqs, w2)
|
||||
heap.Push(&reqs, w1)
|
||||
|
||||
var lastPrio uint64
|
||||
for len(reqs) > 0 {
|
||||
w := heap.Pop(&reqs).(writeRequest)
|
||||
if w.prio < lastPrio {
|
||||
t.Fatal("incorrect shaper priority")
|
||||
}
|
||||
|
||||
t.Log("prio:", w.prio)
|
||||
lastPrio = w.prio
|
||||
}
|
||||
}
|
545
proxy/protocol/smux/stream.go
Normal file
545
proxy/protocol/smux/stream.go
Normal file
@ -0,0 +1,545 @@
|
||||
package smux
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/nadoo/glider/pool"
|
||||
)
|
||||
|
||||
// Stream implements net.Conn
|
||||
type Stream struct {
|
||||
id uint32
|
||||
sess *Session
|
||||
|
||||
buffers [][]byte
|
||||
heads [][]byte // slice heads kept for recycle
|
||||
|
||||
bufferLock sync.Mutex
|
||||
frameSize int
|
||||
|
||||
// notify a read event
|
||||
chReadEvent chan struct{}
|
||||
|
||||
// flag the stream has closed
|
||||
die chan struct{}
|
||||
dieOnce sync.Once
|
||||
|
||||
// FIN command
|
||||
chFinEvent chan struct{}
|
||||
finEventOnce sync.Once
|
||||
|
||||
// deadlines
|
||||
readDeadline atomic.Value
|
||||
writeDeadline atomic.Value
|
||||
|
||||
// per stream sliding window control
|
||||
numRead uint32 // number of consumed bytes
|
||||
numWritten uint32 // count num of bytes written
|
||||
incr uint32 // counting for sending
|
||||
|
||||
// UPD command
|
||||
peerConsumed uint32 // num of bytes the peer has consumed
|
||||
peerWindow uint32 // peer window, initialized to 256KB, updated by peer
|
||||
chUpdate chan struct{} // notify of remote data consuming and window update
|
||||
}
|
||||
|
||||
// newStream initiates a Stream struct
|
||||
func newStream(id uint32, frameSize int, sess *Session) *Stream {
|
||||
s := new(Stream)
|
||||
s.id = id
|
||||
s.chReadEvent = make(chan struct{}, 1)
|
||||
s.chUpdate = make(chan struct{}, 1)
|
||||
s.frameSize = frameSize
|
||||
s.sess = sess
|
||||
s.die = make(chan struct{})
|
||||
s.chFinEvent = make(chan struct{})
|
||||
s.peerWindow = initialPeerWindow // set to initial window size
|
||||
return s
|
||||
}
|
||||
|
||||
// ID returns the unique stream ID.
|
||||
func (s *Stream) ID() uint32 {
|
||||
return s.id
|
||||
}
|
||||
|
||||
// Read implements net.Conn
|
||||
func (s *Stream) Read(b []byte) (n int, err error) {
|
||||
for {
|
||||
n, err = s.tryRead(b)
|
||||
if err == ErrWouldBlock {
|
||||
if ew := s.waitRead(); ew != nil {
|
||||
return 0, ew
|
||||
}
|
||||
} else {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tryRead is the nonblocking version of Read
|
||||
func (s *Stream) tryRead(b []byte) (n int, err error) {
|
||||
if s.sess.config.Version == 2 {
|
||||
return s.tryReadv2(b)
|
||||
}
|
||||
|
||||
if len(b) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
s.bufferLock.Lock()
|
||||
if len(s.buffers) > 0 {
|
||||
n = copy(b, s.buffers[0])
|
||||
s.buffers[0] = s.buffers[0][n:]
|
||||
if len(s.buffers[0]) == 0 {
|
||||
s.buffers[0] = nil
|
||||
s.buffers = s.buffers[1:]
|
||||
// full recycle
|
||||
pool.PutBuffer(s.heads[0])
|
||||
s.heads = s.heads[1:]
|
||||
}
|
||||
}
|
||||
s.bufferLock.Unlock()
|
||||
|
||||
if n > 0 {
|
||||
s.sess.returnTokens(n)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-s.die:
|
||||
return 0, io.EOF
|
||||
default:
|
||||
return 0, ErrWouldBlock
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Stream) tryReadv2(b []byte) (n int, err error) {
|
||||
if len(b) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
var notifyConsumed uint32
|
||||
s.bufferLock.Lock()
|
||||
if len(s.buffers) > 0 {
|
||||
n = copy(b, s.buffers[0])
|
||||
s.buffers[0] = s.buffers[0][n:]
|
||||
if len(s.buffers[0]) == 0 {
|
||||
s.buffers[0] = nil
|
||||
s.buffers = s.buffers[1:]
|
||||
// full recycle
|
||||
pool.PutBuffer(s.heads[0])
|
||||
s.heads = s.heads[1:]
|
||||
}
|
||||
}
|
||||
|
||||
// in an ideal environment:
|
||||
// if more than half of buffer has consumed, send read ack to peer
|
||||
// based on round-trip time of ACK, continous flowing data
|
||||
// won't slow down because of waiting for ACK, as long as the
|
||||
// consumer keeps on reading data
|
||||
// s.numRead == n also notify window at the first read
|
||||
s.numRead += uint32(n)
|
||||
s.incr += uint32(n)
|
||||
if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == uint32(n) {
|
||||
notifyConsumed = s.numRead
|
||||
s.incr = 0
|
||||
}
|
||||
s.bufferLock.Unlock()
|
||||
|
||||
if n > 0 {
|
||||
s.sess.returnTokens(n)
|
||||
if notifyConsumed > 0 {
|
||||
err := s.sendWindowUpdate(notifyConsumed)
|
||||
return n, err
|
||||
} else {
|
||||
return n, nil
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-s.die:
|
||||
return 0, io.EOF
|
||||
default:
|
||||
return 0, ErrWouldBlock
|
||||
}
|
||||
}
|
||||
|
||||
// WriteTo implements io.WriteTo
|
||||
func (s *Stream) WriteTo(w io.Writer) (n int64, err error) {
|
||||
if s.sess.config.Version == 2 {
|
||||
return s.writeTov2(w)
|
||||
}
|
||||
|
||||
for {
|
||||
var buf []byte
|
||||
s.bufferLock.Lock()
|
||||
if len(s.buffers) > 0 {
|
||||
buf = s.buffers[0]
|
||||
s.buffers = s.buffers[1:]
|
||||
s.heads = s.heads[1:]
|
||||
}
|
||||
s.bufferLock.Unlock()
|
||||
|
||||
if buf != nil {
|
||||
nw, ew := w.Write(buf)
|
||||
s.sess.returnTokens(len(buf))
|
||||
pool.PutBuffer(buf)
|
||||
if nw > 0 {
|
||||
n += int64(nw)
|
||||
}
|
||||
|
||||
if ew != nil {
|
||||
return n, ew
|
||||
}
|
||||
} else if ew := s.waitRead(); ew != nil {
|
||||
return n, ew
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Stream) writeTov2(w io.Writer) (n int64, err error) {
|
||||
for {
|
||||
var notifyConsumed uint32
|
||||
var buf []byte
|
||||
s.bufferLock.Lock()
|
||||
if len(s.buffers) > 0 {
|
||||
buf = s.buffers[0]
|
||||
s.buffers = s.buffers[1:]
|
||||
s.heads = s.heads[1:]
|
||||
}
|
||||
s.numRead += uint32(len(buf))
|
||||
s.incr += uint32(len(buf))
|
||||
if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == uint32(len(buf)) {
|
||||
notifyConsumed = s.numRead
|
||||
s.incr = 0
|
||||
}
|
||||
s.bufferLock.Unlock()
|
||||
|
||||
if buf != nil {
|
||||
nw, ew := w.Write(buf)
|
||||
s.sess.returnTokens(len(buf))
|
||||
pool.PutBuffer(buf)
|
||||
if nw > 0 {
|
||||
n += int64(nw)
|
||||
}
|
||||
|
||||
if ew != nil {
|
||||
return n, ew
|
||||
}
|
||||
|
||||
if notifyConsumed > 0 {
|
||||
if err := s.sendWindowUpdate(notifyConsumed); err != nil {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
} else if ew := s.waitRead(); ew != nil {
|
||||
return n, ew
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Stream) sendWindowUpdate(consumed uint32) error {
|
||||
var timer *time.Timer
|
||||
var deadline <-chan time.Time
|
||||
if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
|
||||
timer = time.NewTimer(time.Until(d))
|
||||
defer timer.Stop()
|
||||
deadline = timer.C
|
||||
}
|
||||
|
||||
frame := newFrame(byte(s.sess.config.Version), cmdUPD, s.id)
|
||||
var hdr updHeader
|
||||
binary.LittleEndian.PutUint32(hdr[:], consumed)
|
||||
binary.LittleEndian.PutUint32(hdr[4:], uint32(s.sess.config.MaxStreamBuffer))
|
||||
frame.data = hdr[:]
|
||||
_, err := s.sess.writeFrameInternal(frame, deadline, 0)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Stream) waitRead() error {
|
||||
var timer *time.Timer
|
||||
var deadline <-chan time.Time
|
||||
if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
|
||||
timer = time.NewTimer(time.Until(d))
|
||||
defer timer.Stop()
|
||||
deadline = timer.C
|
||||
}
|
||||
|
||||
select {
|
||||
case <-s.chReadEvent:
|
||||
return nil
|
||||
case <-s.chFinEvent:
|
||||
return io.EOF
|
||||
case <-s.sess.chSocketReadError:
|
||||
return s.sess.socketReadError.Load().(error)
|
||||
case <-s.sess.chProtoError:
|
||||
return s.sess.protoError.Load().(error)
|
||||
case <-deadline:
|
||||
return ErrTimeout
|
||||
case <-s.die:
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Write implements net.Conn
|
||||
//
|
||||
// Note that the behavior when multiple goroutines write concurrently is not deterministic,
|
||||
// frames may interleave in random way.
|
||||
func (s *Stream) Write(b []byte) (n int, err error) {
|
||||
if s.sess.config.Version == 2 {
|
||||
return s.writeV2(b)
|
||||
}
|
||||
|
||||
var deadline <-chan time.Time
|
||||
if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
|
||||
timer := time.NewTimer(time.Until(d))
|
||||
defer timer.Stop()
|
||||
deadline = timer.C
|
||||
}
|
||||
|
||||
// check if stream has closed
|
||||
select {
|
||||
case <-s.die:
|
||||
return 0, io.ErrClosedPipe
|
||||
default:
|
||||
}
|
||||
|
||||
// frame split and transmit
|
||||
sent := 0
|
||||
frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id)
|
||||
bts := b
|
||||
for len(bts) > 0 {
|
||||
sz := len(bts)
|
||||
if sz > s.frameSize {
|
||||
sz = s.frameSize
|
||||
}
|
||||
frame.data = bts[:sz]
|
||||
bts = bts[sz:]
|
||||
n, err := s.sess.writeFrameInternal(frame, deadline, uint64(s.numWritten))
|
||||
s.numWritten++
|
||||
sent += n
|
||||
if err != nil {
|
||||
return sent, err
|
||||
}
|
||||
}
|
||||
|
||||
return sent, nil
|
||||
}
|
||||
|
||||
func (s *Stream) writeV2(b []byte) (n int, err error) {
|
||||
// check empty input
|
||||
if len(b) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// check if stream has closed
|
||||
select {
|
||||
case <-s.die:
|
||||
return 0, io.ErrClosedPipe
|
||||
default:
|
||||
}
|
||||
|
||||
// create write deadline timer
|
||||
var deadline <-chan time.Time
|
||||
if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
|
||||
timer := time.NewTimer(time.Until(d))
|
||||
defer timer.Stop()
|
||||
deadline = timer.C
|
||||
}
|
||||
|
||||
// frame split and transmit process
|
||||
sent := 0
|
||||
frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id)
|
||||
|
||||
for {
|
||||
// per stream sliding window control
|
||||
// [.... [consumed... numWritten] ... win... ]
|
||||
// [.... [consumed...................+rmtwnd]]
|
||||
var bts []byte
|
||||
// note:
|
||||
// even if uint32 overflow, this math still works:
|
||||
// eg1: uint32(0) - uint32(math.MaxUint32) = 1
|
||||
// eg2: int32(uint32(0) - uint32(1)) = -1
|
||||
// security check for misbehavior
|
||||
inflight := int32(atomic.LoadUint32(&s.numWritten) - atomic.LoadUint32(&s.peerConsumed))
|
||||
if inflight < 0 {
|
||||
return 0, ErrConsumed
|
||||
}
|
||||
|
||||
win := int32(atomic.LoadUint32(&s.peerWindow)) - inflight
|
||||
if win > 0 {
|
||||
if win > int32(len(b)) {
|
||||
bts = b
|
||||
b = nil
|
||||
} else {
|
||||
bts = b[:win]
|
||||
b = b[win:]
|
||||
}
|
||||
|
||||
for len(bts) > 0 {
|
||||
sz := len(bts)
|
||||
if sz > s.frameSize {
|
||||
sz = s.frameSize
|
||||
}
|
||||
frame.data = bts[:sz]
|
||||
bts = bts[sz:]
|
||||
n, err := s.sess.writeFrameInternal(frame, deadline, uint64(atomic.LoadUint32(&s.numWritten)))
|
||||
atomic.AddUint32(&s.numWritten, uint32(sz))
|
||||
sent += n
|
||||
if err != nil {
|
||||
return sent, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if there is any data remaining to be sent
|
||||
// wait until stream closes, window changes or deadline reached
|
||||
// this blocking behavior will inform upper layer to do flow control
|
||||
if len(b) > 0 {
|
||||
select {
|
||||
case <-s.chFinEvent: // if fin arrived, future window update is impossible
|
||||
return 0, io.EOF
|
||||
case <-s.die:
|
||||
return sent, io.ErrClosedPipe
|
||||
case <-deadline:
|
||||
return sent, ErrTimeout
|
||||
case <-s.sess.chSocketWriteError:
|
||||
return sent, s.sess.socketWriteError.Load().(error)
|
||||
case <-s.chUpdate:
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
return sent, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close implements net.Conn
|
||||
func (s *Stream) Close() error {
|
||||
var once bool
|
||||
var err error
|
||||
s.dieOnce.Do(func() {
|
||||
close(s.die)
|
||||
once = true
|
||||
})
|
||||
|
||||
if once {
|
||||
_, err = s.sess.writeFrame(newFrame(byte(s.sess.config.Version), cmdFIN, s.id))
|
||||
s.sess.streamClosed(s.id)
|
||||
return err
|
||||
} else {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
}
|
||||
|
||||
// GetDieCh returns a readonly chan which can be readable
|
||||
// when the stream is to be closed.
|
||||
func (s *Stream) GetDieCh() <-chan struct{} {
|
||||
return s.die
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the read deadline as defined by
|
||||
// net.Conn.SetReadDeadline.
|
||||
// A zero time value disables the deadline.
|
||||
func (s *Stream) SetReadDeadline(t time.Time) error {
|
||||
s.readDeadline.Store(t)
|
||||
s.notifyReadEvent()
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the write deadline as defined by
|
||||
// net.Conn.SetWriteDeadline.
|
||||
// A zero time value disables the deadline.
|
||||
func (s *Stream) SetWriteDeadline(t time.Time) error {
|
||||
s.writeDeadline.Store(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDeadline sets both read and write deadlines as defined by
|
||||
// net.Conn.SetDeadline.
|
||||
// A zero time value disables the deadlines.
|
||||
func (s *Stream) SetDeadline(t time.Time) error {
|
||||
if err := s.SetReadDeadline(t); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.SetWriteDeadline(t); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// session closes
|
||||
func (s *Stream) sessionClose() { s.dieOnce.Do(func() { close(s.die) }) }
|
||||
|
||||
// LocalAddr satisfies net.Conn interface
|
||||
func (s *Stream) LocalAddr() net.Addr {
|
||||
if ts, ok := s.sess.conn.(interface {
|
||||
LocalAddr() net.Addr
|
||||
}); ok {
|
||||
return ts.LocalAddr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoteAddr satisfies net.Conn interface
|
||||
func (s *Stream) RemoteAddr() net.Addr {
|
||||
if ts, ok := s.sess.conn.(interface {
|
||||
RemoteAddr() net.Addr
|
||||
}); ok {
|
||||
return ts.RemoteAddr()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// pushBytes append buf to buffers
|
||||
func (s *Stream) pushBytes(buf []byte) (written int, err error) {
|
||||
s.bufferLock.Lock()
|
||||
s.buffers = append(s.buffers, buf)
|
||||
s.heads = append(s.heads, buf)
|
||||
s.bufferLock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// recycleTokens transform remaining bytes to tokens(will truncate buffer)
|
||||
func (s *Stream) recycleTokens() (n int) {
|
||||
s.bufferLock.Lock()
|
||||
for k := range s.buffers {
|
||||
n += len(s.buffers[k])
|
||||
pool.PutBuffer(s.heads[k])
|
||||
}
|
||||
s.buffers = nil
|
||||
s.heads = nil
|
||||
s.bufferLock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// notify read event
|
||||
func (s *Stream) notifyReadEvent() {
|
||||
select {
|
||||
case s.chReadEvent <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// update command
|
||||
func (s *Stream) update(consumed uint32, window uint32) {
|
||||
atomic.StoreUint32(&s.peerConsumed, consumed)
|
||||
atomic.StoreUint32(&s.peerWindow, window)
|
||||
select {
|
||||
case s.chUpdate <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// mark this stream has been closed in protocol
|
||||
func (s *Stream) fin() {
|
||||
s.finEventOnce.Do(func() {
|
||||
close(s.chFinEvent)
|
||||
})
|
||||
}
|
76
proxy/smux/client.go
Normal file
76
proxy/smux/client.go
Normal file
@ -0,0 +1,76 @@
|
||||
package smux
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/url"
|
||||
|
||||
"github.com/nadoo/glider/log"
|
||||
"github.com/nadoo/glider/proxy"
|
||||
|
||||
"github.com/nadoo/glider/proxy/protocol/smux"
|
||||
)
|
||||
|
||||
// SmuxClient struct.
|
||||
type SmuxClient struct {
|
||||
dialer proxy.Dialer
|
||||
addr string
|
||||
session *smux.Session
|
||||
}
|
||||
|
||||
func init() {
|
||||
proxy.RegisterDialer("smux", NewSmuxDialer)
|
||||
}
|
||||
|
||||
// NewSmuxDialer returns a smux dialer.
|
||||
func NewSmuxDialer(s string, d proxy.Dialer) (proxy.Dialer, error) {
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
log.F("[smux] parse url err: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := &SmuxClient{
|
||||
dialer: d,
|
||||
addr: u.Host,
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Addr returns forwarder's address.
|
||||
func (s *SmuxClient) Addr() string {
|
||||
if s.addr == "" {
|
||||
return s.dialer.Addr()
|
||||
}
|
||||
return s.addr
|
||||
}
|
||||
|
||||
// Dial connects to the address addr on the network net via the proxy.
|
||||
func (s *SmuxClient) Dial(network, addr string) (net.Conn, error) {
|
||||
if s.session != nil {
|
||||
if c, err := s.session.OpenStream(); err == nil {
|
||||
return c, err
|
||||
}
|
||||
s.session.Close()
|
||||
}
|
||||
if err := s.initConn(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.session.OpenStream()
|
||||
}
|
||||
|
||||
// DialUDP connects to the given address via the proxy.
|
||||
func (s *SmuxClient) DialUDP(network, addr string) (net.PacketConn, net.Addr, error) {
|
||||
return nil, nil, errors.New("smux client does not support udp now")
|
||||
}
|
||||
|
||||
func (s *SmuxClient) initConn() error {
|
||||
conn, err := s.dialer.Dial("tcp", s.addr)
|
||||
if err != nil {
|
||||
log.F("[smux] dial to %s error: %s", s.addr, err)
|
||||
return err
|
||||
}
|
||||
s.session, err = smux.Client(conn, nil)
|
||||
return err
|
||||
}
|
98
proxy/smux/server.go
Normal file
98
proxy/smux/server.go
Normal file
@ -0,0 +1,98 @@
|
||||
package smux
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/nadoo/glider/log"
|
||||
"github.com/nadoo/glider/proxy"
|
||||
|
||||
"github.com/nadoo/glider/proxy/protocol/smux"
|
||||
)
|
||||
|
||||
// SmuxServer struct.
|
||||
type SmuxServer struct {
|
||||
proxy proxy.Proxy
|
||||
addr string
|
||||
server proxy.Server
|
||||
}
|
||||
|
||||
func init() {
|
||||
proxy.RegisterServer("smux", NewSmuxServer)
|
||||
}
|
||||
|
||||
// NewSmuxServer returns a smux transport layer before the real server.
|
||||
func NewSmuxServer(s string, p proxy.Proxy) (proxy.Server, error) {
|
||||
transport := strings.Split(s, ",")
|
||||
|
||||
u, err := url.Parse(transport[0])
|
||||
if err != nil {
|
||||
log.F("[smux] parse url err: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := &SmuxServer{
|
||||
proxy: p,
|
||||
addr: u.Host,
|
||||
}
|
||||
|
||||
if len(transport) > 1 {
|
||||
m.server, err = proxy.ServerFromURL(transport[1], p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// ListenAndServe listens on server's addr and serves connections.
|
||||
func (s *SmuxServer) ListenAndServe() {
|
||||
l, err := net.Listen("tcp", s.addr)
|
||||
if err != nil {
|
||||
log.F("[smux] failed to listen on %s: %v", s.addr, err)
|
||||
return
|
||||
}
|
||||
defer l.Close()
|
||||
|
||||
log.F("[smux] listening mux on %s", s.addr)
|
||||
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
log.F("[smux] failed to accept: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go s.Serve(c)
|
||||
}
|
||||
}
|
||||
|
||||
// Serve serves a connection.
|
||||
func (s *SmuxServer) Serve(c net.Conn) {
|
||||
// we know the internal server will close the connection after serve
|
||||
// defer c.Close()
|
||||
|
||||
session, err := smux.Server(c, nil)
|
||||
if err != nil {
|
||||
log.F("[smux] failed to create session: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
// Accept a stream
|
||||
stream, err := session.AcceptStream()
|
||||
if err != nil {
|
||||
session.Close()
|
||||
break
|
||||
}
|
||||
go s.ServeStream(stream)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SmuxServer) ServeStream(c *smux.Stream) {
|
||||
if s.server != nil {
|
||||
s.server.Serve(c)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user