mirror of
				https://github.com/nadoo/glider.git
				synced 2025-11-04 07:42:38 +08:00 
			
		
		
		
	proxy: added smux support
This commit is contained in:
		
							parent
							
								
									34a053b875
								
							
						
					
					
						commit
						dbd2e04521
					
				
							
								
								
									
										37
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										37
									
								
								README.md
									
									
									
									
									
								
							@ -63,6 +63,7 @@ we can set up local listeners as proxy servers, and forward requests to internet
 | 
				
			|||||||
|TLS          |√| |√| |transport client & server
 | 
					|TLS          |√| |√| |transport client & server
 | 
				
			||||||
|KCP          | |√|√| |transport client & server
 | 
					|KCP          | |√|√| |transport client & server
 | 
				
			||||||
|Unix         |√|√|√|√|transport client & server
 | 
					|Unix         |√|√|√|√|transport client & server
 | 
				
			||||||
 | 
					|Smux         |√| |√| |transport client & server
 | 
				
			||||||
|Websocket    |√| |√| |transport client & server
 | 
					|Websocket    |√| |√| |transport client & server
 | 
				
			||||||
|Simple-Obfs  | | |√| |transport client only
 | 
					|Simple-Obfs  | | |√| |transport client only
 | 
				
			||||||
|Redir        |√| | | |linux only
 | 
					|Redir        |√| | | |linux only
 | 
				
			||||||
@ -96,7 +97,7 @@ glider -h
 | 
				
			|||||||
<summary>click to see details</summary>
 | 
					<summary>click to see details</summary>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
```bash
 | 
					```bash
 | 
				
			||||||
glider 0.13.0 usage:
 | 
					glider 0.14.0 usage:
 | 
				
			||||||
  -check string
 | 
					  -check string
 | 
				
			||||||
    	check=tcp[://HOST:PORT]: tcp port connect check
 | 
					    	check=tcp[://HOST:PORT]: tcp port connect check
 | 
				
			||||||
    	check=http://HOST[:PORT][/URI][#expect=STRING_IN_RESP_LINE]
 | 
					    	check=http://HOST[:PORT][/URI][#expect=STRING_IN_RESP_LINE]
 | 
				
			||||||
@ -152,27 +153,10 @@ glider 0.13.0 usage:
 | 
				
			|||||||
    	forward strategy, default: rr (default "rr")
 | 
					    	forward strategy, default: rr (default "rr")
 | 
				
			||||||
  -verbose
 | 
					  -verbose
 | 
				
			||||||
    	verbose mode
 | 
					    	verbose mode
 | 
				
			||||||
```
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
</details>
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
run:
 | 
					 | 
				
			||||||
```bash
 | 
					 | 
				
			||||||
glider -config CONFIGPATH
 | 
					 | 
				
			||||||
```
 | 
					 | 
				
			||||||
```bash
 | 
					 | 
				
			||||||
glider -verbose -listen :8443 -forward SCHEME://HOST:PORT
 | 
					 | 
				
			||||||
```
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#### Schemes
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
<details>
 | 
					 | 
				
			||||||
<summary>click to see details</summary>
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
```bash
 | 
					 | 
				
			||||||
Available schemes:
 | 
					Available schemes:
 | 
				
			||||||
  listen: mixed ss socks5 http vless trojan trojanc redir redir6 tcp udp tls ws unix kcp
 | 
					  listen: mixed ss socks5 http vless trojan trojanc redir redir6 tcp udp tls ws unix smux kcp
 | 
				
			||||||
  forward: reject ss socks4 socks5 http ssr ssh vless vmess trojan trojanc tcp udp tls ws unix kcp simple-obfs
 | 
					  forward: reject ss socks4 socks5 http ssr ssh vless vmess trojan trojanc tcp udp tls ws unix smux kcp simple-obfs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Socks5 scheme:
 | 
					Socks5 scheme:
 | 
				
			||||||
  socks://[user:pass@]host:port
 | 
					  socks://[user:pass@]host:port
 | 
				
			||||||
@ -251,6 +235,9 @@ TLS and Websocket with a specified proxy protocol:
 | 
				
			|||||||
Unix domain socket scheme:
 | 
					Unix domain socket scheme:
 | 
				
			||||||
  unix://path
 | 
					  unix://path
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Smux scheme:
 | 
				
			||||||
 | 
					  smux://host:port
 | 
				
			||||||
 | 
					
 | 
				
			||||||
KCP scheme:
 | 
					KCP scheme:
 | 
				
			||||||
  kcp://CRYPT:KEY@host:port[?dataShards=NUM&parityShards=NUM&mode=MODE]
 | 
					  kcp://CRYPT:KEY@host:port[?dataShards=NUM&parityShards=NUM&mode=MODE]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -298,16 +285,8 @@ Config file format(see `./glider.conf.example` as an example):
 | 
				
			|||||||
  KEY=VALUE
 | 
					  KEY=VALUE
 | 
				
			||||||
  KEY=VALUE
 | 
					  KEY=VALUE
 | 
				
			||||||
  # KEY equals to command line flag name: listen forward strategy...
 | 
					  # KEY equals to command line flag name: listen forward strategy...
 | 
				
			||||||
```
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
</details>
 | 
					Examples:
 | 
				
			||||||
 | 
					 | 
				
			||||||
#### Examples
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
<details>
 | 
					 | 
				
			||||||
<summary>click to see details</summary>
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
```bash
 | 
					 | 
				
			||||||
  ./glider -config glider.conf
 | 
					  ./glider -config glider.conf
 | 
				
			||||||
    -run glider with specified config file.
 | 
					    -run glider with specified config file.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -131,8 +131,8 @@ func usage() {
 | 
				
			|||||||
	fmt.Fprintf(w, "\n")
 | 
						fmt.Fprintf(w, "\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	fmt.Fprintf(w, "Available schemes:\n")
 | 
						fmt.Fprintf(w, "Available schemes:\n")
 | 
				
			||||||
	fmt.Fprintf(w, "  listen: mixed ss socks5 http vless trojan trojanc redir redir6 tcp udp tls ws unix kcp\n")
 | 
						fmt.Fprintf(w, "  listen: mixed ss socks5 http vless trojan trojanc redir redir6 tcp udp tls ws unix smux kcp\n")
 | 
				
			||||||
	fmt.Fprintf(w, "  forward: reject ss socks4 socks5 http ssr ssh vless vmess trojan trojanc tcp udp tls ws unix kcp simple-obfs\n")
 | 
						fmt.Fprintf(w, "  forward: reject ss socks4 socks5 http ssr ssh vless vmess trojan trojanc tcp udp tls ws unix smux kcp simple-obfs\n")
 | 
				
			||||||
	fmt.Fprintf(w, "\n")
 | 
						fmt.Fprintf(w, "\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	fmt.Fprintf(w, "Socks5 scheme:\n")
 | 
						fmt.Fprintf(w, "Socks5 scheme:\n")
 | 
				
			||||||
@ -231,6 +231,10 @@ func usage() {
 | 
				
			|||||||
	fmt.Fprintf(w, "  unix://path\n")
 | 
						fmt.Fprintf(w, "  unix://path\n")
 | 
				
			||||||
	fmt.Fprintf(w, "\n")
 | 
						fmt.Fprintf(w, "\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						fmt.Fprintf(w, "Smux scheme:\n")
 | 
				
			||||||
 | 
						fmt.Fprintf(w, "  smux://host:port\n")
 | 
				
			||||||
 | 
						fmt.Fprintf(w, "\n")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	fmt.Fprintf(w, "KCP scheme:\n")
 | 
						fmt.Fprintf(w, "KCP scheme:\n")
 | 
				
			||||||
	fmt.Fprintf(w, "  kcp://CRYPT:KEY@host:port[?dataShards=NUM&parityShards=NUM&mode=MODE]\n")
 | 
						fmt.Fprintf(w, "  kcp://CRYPT:KEY@host:port[?dataShards=NUM&parityShards=NUM&mode=MODE]\n")
 | 
				
			||||||
	fmt.Fprintf(w, "\n")
 | 
						fmt.Fprintf(w, "\n")
 | 
				
			||||||
 | 
				
			|||||||
@ -10,6 +10,7 @@ import (
 | 
				
			|||||||
	_ "github.com/nadoo/glider/proxy/mixed"
 | 
						_ "github.com/nadoo/glider/proxy/mixed"
 | 
				
			||||||
	_ "github.com/nadoo/glider/proxy/obfs"
 | 
						_ "github.com/nadoo/glider/proxy/obfs"
 | 
				
			||||||
	_ "github.com/nadoo/glider/proxy/reject"
 | 
						_ "github.com/nadoo/glider/proxy/reject"
 | 
				
			||||||
 | 
						_ "github.com/nadoo/glider/proxy/smux"
 | 
				
			||||||
	_ "github.com/nadoo/glider/proxy/socks4"
 | 
						_ "github.com/nadoo/glider/proxy/socks4"
 | 
				
			||||||
	_ "github.com/nadoo/glider/proxy/socks5"
 | 
						_ "github.com/nadoo/glider/proxy/socks5"
 | 
				
			||||||
	_ "github.com/nadoo/glider/proxy/ss"
 | 
						_ "github.com/nadoo/glider/proxy/ss"
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										2
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
									
									
									
									
								
							@ -8,7 +8,7 @@ require (
 | 
				
			|||||||
	github.com/dgryski/go-idea v0.0.0-20170306091226-d2fb45a411fb
 | 
						github.com/dgryski/go-idea v0.0.0-20170306091226-d2fb45a411fb
 | 
				
			||||||
	github.com/dgryski/go-rc2 v0.0.0-20150621095337-8a9021637152
 | 
						github.com/dgryski/go-rc2 v0.0.0-20150621095337-8a9021637152
 | 
				
			||||||
	github.com/ebfe/rc2 v0.0.0-20131011165748-24b9757f5521 // indirect
 | 
						github.com/ebfe/rc2 v0.0.0-20131011165748-24b9757f5521 // indirect
 | 
				
			||||||
	github.com/insomniacslk/dhcp v0.0.0-20210315110227-c51060810aaa
 | 
						github.com/insomniacslk/dhcp v0.0.0-20210420161629-6bd1ce0fd305
 | 
				
			||||||
	github.com/klauspost/cpuid/v2 v2.0.6 // indirect
 | 
						github.com/klauspost/cpuid/v2 v2.0.6 // indirect
 | 
				
			||||||
	github.com/klauspost/reedsolomon v1.9.12 // indirect
 | 
						github.com/klauspost/reedsolomon v1.9.12 // indirect
 | 
				
			||||||
	github.com/mdlayher/raw v0.0.0-20210412142147-51b895745faf // indirect
 | 
						github.com/mdlayher/raw v0.0.0-20210412142147-51b895745faf // indirect
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										4
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										4
									
								
								go.sum
									
									
									
									
									
								
							@ -39,8 +39,8 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
 | 
				
			|||||||
github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
 | 
					github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY=
 | 
				
			||||||
github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714 h1:/jC7qQFrv8CrSJVmaolDVOxTfS9kc36uB6H40kdbQq8=
 | 
					github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714 h1:/jC7qQFrv8CrSJVmaolDVOxTfS9kc36uB6H40kdbQq8=
 | 
				
			||||||
github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714/go.mod h1:2Goc3h8EklBH5mspfHFxBnEoURQCGzQQH1ga9Myjvis=
 | 
					github.com/hugelgupf/socketpair v0.0.0-20190730060125-05d35a94e714/go.mod h1:2Goc3h8EklBH5mspfHFxBnEoURQCGzQQH1ga9Myjvis=
 | 
				
			||||||
github.com/insomniacslk/dhcp v0.0.0-20210315110227-c51060810aaa h1:ff6iUOGfajZ+T8RBuRtOjYMW7/6aF8iJG+iCp2wQyQ4=
 | 
					github.com/insomniacslk/dhcp v0.0.0-20210420161629-6bd1ce0fd305 h1:DGmCtsdLE6r7tuFH5YHnDoKJU9WXjptF/8Q8iKc41Tk=
 | 
				
			||||||
github.com/insomniacslk/dhcp v0.0.0-20210315110227-c51060810aaa/go.mod h1:TKl4jN3Voofo4UJIicyNhWGp/nlQqQkFxmwIFTvBkKI=
 | 
					github.com/insomniacslk/dhcp v0.0.0-20210420161629-6bd1ce0fd305/go.mod h1:TKl4jN3Voofo4UJIicyNhWGp/nlQqQkFxmwIFTvBkKI=
 | 
				
			||||||
github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a/go.mod h1:Oz+70psSo5OFh8DBl0Zv2ACw7Esh6pPUphlvZG9x7uw=
 | 
					github.com/jsimonetti/rtnetlink v0.0.0-20190606172950-9527aa82566a/go.mod h1:Oz+70psSo5OFh8DBl0Zv2ACw7Esh6pPUphlvZG9x7uw=
 | 
				
			||||||
github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4/go.mod h1:WGuG/smIU4J/54PblvSbh+xvCZmpJnFgr3ds6Z55XMQ=
 | 
					github.com/jsimonetti/rtnetlink v0.0.0-20200117123717-f846d4f6c1f4/go.mod h1:WGuG/smIU4J/54PblvSbh+xvCZmpJnFgr3ds6Z55XMQ=
 | 
				
			||||||
github.com/jsimonetti/rtnetlink v0.0.0-20201009170750-9c6f07d100c1/go.mod h1:hqoO/u39cqLeBLebZ8fWdE96O7FxrAsRYhnVOdgHxok=
 | 
					github.com/jsimonetti/rtnetlink v0.0.0-20201009170750-9c6f07d100c1/go.mod h1:hqoO/u39cqLeBLebZ8fWdE96O7FxrAsRYhnVOdgHxok=
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										21
									
								
								proxy/protocol/smux/LICENSE
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								proxy/protocol/smux/LICENSE
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,21 @@
 | 
				
			|||||||
 | 
					MIT License
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Copyright (c) 2016-2017 Daniel Fu
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Permission is hereby granted, free of charge, to any person obtaining a copy
 | 
				
			||||||
 | 
					of this software and associated documentation files (the "Software"), to deal
 | 
				
			||||||
 | 
					in the Software without restriction, including without limitation the rights
 | 
				
			||||||
 | 
					to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 | 
				
			||||||
 | 
					copies of the Software, and to permit persons to whom the Software is
 | 
				
			||||||
 | 
					furnished to do so, subject to the following conditions:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					The above copyright notice and this permission notice shall be included in all
 | 
				
			||||||
 | 
					copies or substantial portions of the Software.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 | 
				
			||||||
 | 
					IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 | 
				
			||||||
 | 
					FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 | 
				
			||||||
 | 
					AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 | 
				
			||||||
 | 
					LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 | 
				
			||||||
 | 
					OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 | 
				
			||||||
 | 
					SOFTWARE.
 | 
				
			||||||
							
								
								
									
										81
									
								
								proxy/protocol/smux/frame.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										81
									
								
								proxy/protocol/smux/frame.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,81 @@
 | 
				
			|||||||
 | 
					package smux
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"encoding/binary"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const ( // cmds
 | 
				
			||||||
 | 
						// protocol version 1:
 | 
				
			||||||
 | 
						cmdSYN byte = iota // stream open
 | 
				
			||||||
 | 
						cmdFIN             // stream close, a.k.a EOF mark
 | 
				
			||||||
 | 
						cmdPSH             // data push
 | 
				
			||||||
 | 
						cmdNOP             // no operation
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// protocol version 2 extra commands
 | 
				
			||||||
 | 
						// notify bytes consumed by remote peer-end
 | 
				
			||||||
 | 
						cmdUPD
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						// data size of cmdUPD, format:
 | 
				
			||||||
 | 
						// |4B data consumed(ACK)| 4B window size(WINDOW) |
 | 
				
			||||||
 | 
						szCmdUPD = 8
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						// initial peer window guess, a slow-start
 | 
				
			||||||
 | 
						initialPeerWindow = 262144
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						sizeOfVer    = 1
 | 
				
			||||||
 | 
						sizeOfCmd    = 1
 | 
				
			||||||
 | 
						sizeOfLength = 2
 | 
				
			||||||
 | 
						sizeOfSid    = 4
 | 
				
			||||||
 | 
						headerSize   = sizeOfVer + sizeOfCmd + sizeOfSid + sizeOfLength
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Frame defines a packet from or to be multiplexed into a single connection
 | 
				
			||||||
 | 
					type Frame struct {
 | 
				
			||||||
 | 
						ver  byte
 | 
				
			||||||
 | 
						cmd  byte
 | 
				
			||||||
 | 
						sid  uint32
 | 
				
			||||||
 | 
						data []byte
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newFrame(version byte, cmd byte, sid uint32) Frame {
 | 
				
			||||||
 | 
						return Frame{ver: version, cmd: cmd, sid: sid}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type rawHeader [headerSize]byte
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h rawHeader) Version() byte {
 | 
				
			||||||
 | 
						return h[0]
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h rawHeader) Cmd() byte {
 | 
				
			||||||
 | 
						return h[1]
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h rawHeader) Length() uint16 {
 | 
				
			||||||
 | 
						return binary.LittleEndian.Uint16(h[2:])
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h rawHeader) StreamID() uint32 {
 | 
				
			||||||
 | 
						return binary.LittleEndian.Uint32(h[4:])
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h rawHeader) String() string {
 | 
				
			||||||
 | 
						return fmt.Sprintf("Version:%d Cmd:%d StreamID:%d Length:%d",
 | 
				
			||||||
 | 
							h.Version(), h.Cmd(), h.StreamID(), h.Length())
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type updHeader [szCmdUPD]byte
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h updHeader) Consumed() uint32 {
 | 
				
			||||||
 | 
						return binary.LittleEndian.Uint32(h[:])
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					func (h updHeader) Window() uint32 {
 | 
				
			||||||
 | 
						return binary.LittleEndian.Uint32(h[4:])
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										110
									
								
								proxy/protocol/smux/mux.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								proxy/protocol/smux/mux.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,110 @@
 | 
				
			|||||||
 | 
					// Package smux is a multiplexing library for Golang.
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// It relies on an underlying connection to provide reliability and ordering, such as TCP or KCP,
 | 
				
			||||||
 | 
					// and provides stream-oriented multiplexing over a single channel.
 | 
				
			||||||
 | 
					package smux
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"math"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Config is used to tune the Smux session
 | 
				
			||||||
 | 
					type Config struct {
 | 
				
			||||||
 | 
						// SMUX Protocol version, support 1,2
 | 
				
			||||||
 | 
						Version int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Disabled keepalive
 | 
				
			||||||
 | 
						KeepAliveDisabled bool
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// KeepAliveInterval is how often to send a NOP command to the remote
 | 
				
			||||||
 | 
						KeepAliveInterval time.Duration
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// KeepAliveTimeout is how long the session
 | 
				
			||||||
 | 
						// will be closed if no data has arrived
 | 
				
			||||||
 | 
						KeepAliveTimeout time.Duration
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// MaxFrameSize is used to control the maximum
 | 
				
			||||||
 | 
						// frame size to sent to the remote
 | 
				
			||||||
 | 
						MaxFrameSize int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// MaxReceiveBuffer is used to control the maximum
 | 
				
			||||||
 | 
						// number of data in the buffer pool
 | 
				
			||||||
 | 
						MaxReceiveBuffer int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// MaxStreamBuffer is used to control the maximum
 | 
				
			||||||
 | 
						// number of data per stream
 | 
				
			||||||
 | 
						MaxStreamBuffer int
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DefaultConfig is used to return a default configuration
 | 
				
			||||||
 | 
					func DefaultConfig() *Config {
 | 
				
			||||||
 | 
						return &Config{
 | 
				
			||||||
 | 
							Version:           1,
 | 
				
			||||||
 | 
							KeepAliveInterval: 10 * time.Second,
 | 
				
			||||||
 | 
							KeepAliveTimeout:  30 * time.Second,
 | 
				
			||||||
 | 
							MaxFrameSize:      32768,
 | 
				
			||||||
 | 
							MaxReceiveBuffer:  4194304,
 | 
				
			||||||
 | 
							MaxStreamBuffer:   65536,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// VerifyConfig is used to verify the sanity of configuration
 | 
				
			||||||
 | 
					func VerifyConfig(config *Config) error {
 | 
				
			||||||
 | 
						if !(config.Version == 1 || config.Version == 2) {
 | 
				
			||||||
 | 
							return errors.New("unsupported protocol version")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !config.KeepAliveDisabled {
 | 
				
			||||||
 | 
							if config.KeepAliveInterval == 0 {
 | 
				
			||||||
 | 
								return errors.New("keep-alive interval must be positive")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if config.KeepAliveTimeout < config.KeepAliveInterval {
 | 
				
			||||||
 | 
								return fmt.Errorf("keep-alive timeout must be larger than keep-alive interval")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if config.MaxFrameSize <= 0 {
 | 
				
			||||||
 | 
							return errors.New("max frame size must be positive")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if config.MaxFrameSize > 65535 {
 | 
				
			||||||
 | 
							return errors.New("max frame size must not be larger than 65535")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if config.MaxReceiveBuffer <= 0 {
 | 
				
			||||||
 | 
							return errors.New("max receive buffer must be positive")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if config.MaxStreamBuffer <= 0 {
 | 
				
			||||||
 | 
							return errors.New("max stream buffer must be positive")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if config.MaxStreamBuffer > config.MaxReceiveBuffer {
 | 
				
			||||||
 | 
							return errors.New("max stream buffer must not be larger than max receive buffer")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if config.MaxStreamBuffer > math.MaxInt32 {
 | 
				
			||||||
 | 
							return errors.New("max stream buffer cannot be larger than 2147483647")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Server is used to initialize a new server-side connection.
 | 
				
			||||||
 | 
					func Server(conn io.ReadWriteCloser, config *Config) (*Session, error) {
 | 
				
			||||||
 | 
						if config == nil {
 | 
				
			||||||
 | 
							config = DefaultConfig()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := VerifyConfig(config); err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return newSession(config, conn, false), nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Client is used to initialize a new client-side connection.
 | 
				
			||||||
 | 
					func Client(conn io.ReadWriteCloser, config *Config) (*Session, error) {
 | 
				
			||||||
 | 
						if config == nil {
 | 
				
			||||||
 | 
							config = DefaultConfig()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if err := VerifyConfig(config); err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return newSession(config, conn, true), nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										86
									
								
								proxy/protocol/smux/mux_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								proxy/protocol/smux/mux_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,86 @@
 | 
				
			|||||||
 | 
					package smux
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type buffer struct {
 | 
				
			||||||
 | 
						bytes.Buffer
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (b *buffer) Close() error {
 | 
				
			||||||
 | 
						b.Buffer.Reset()
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestConfig(t *testing.T) {
 | 
				
			||||||
 | 
						VerifyConfig(DefaultConfig())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						config := DefaultConfig()
 | 
				
			||||||
 | 
						config.KeepAliveInterval = 0
 | 
				
			||||||
 | 
						err := VerifyConfig(config)
 | 
				
			||||||
 | 
						t.Log(err)
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						config = DefaultConfig()
 | 
				
			||||||
 | 
						config.KeepAliveInterval = 10
 | 
				
			||||||
 | 
						config.KeepAliveTimeout = 5
 | 
				
			||||||
 | 
						err = VerifyConfig(config)
 | 
				
			||||||
 | 
						t.Log(err)
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						config = DefaultConfig()
 | 
				
			||||||
 | 
						config.MaxFrameSize = 0
 | 
				
			||||||
 | 
						err = VerifyConfig(config)
 | 
				
			||||||
 | 
						t.Log(err)
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						config = DefaultConfig()
 | 
				
			||||||
 | 
						config.MaxFrameSize = 65536
 | 
				
			||||||
 | 
						err = VerifyConfig(config)
 | 
				
			||||||
 | 
						t.Log(err)
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						config = DefaultConfig()
 | 
				
			||||||
 | 
						config.MaxReceiveBuffer = 0
 | 
				
			||||||
 | 
						err = VerifyConfig(config)
 | 
				
			||||||
 | 
						t.Log(err)
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						config = DefaultConfig()
 | 
				
			||||||
 | 
						config.MaxStreamBuffer = 0
 | 
				
			||||||
 | 
						err = VerifyConfig(config)
 | 
				
			||||||
 | 
						t.Log(err)
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						config = DefaultConfig()
 | 
				
			||||||
 | 
						config.MaxStreamBuffer = 100
 | 
				
			||||||
 | 
						config.MaxReceiveBuffer = 99
 | 
				
			||||||
 | 
						err = VerifyConfig(config)
 | 
				
			||||||
 | 
						t.Log(err)
 | 
				
			||||||
 | 
						if err == nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var bts buffer
 | 
				
			||||||
 | 
						if _, err := Server(&bts, config); err == nil {
 | 
				
			||||||
 | 
							t.Fatal("server started with wrong config")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if _, err := Client(&bts, config); err == nil {
 | 
				
			||||||
 | 
							t.Fatal("client started with wrong config")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										527
									
								
								proxy/protocol/smux/session.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										527
									
								
								proxy/protocol/smux/session.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,527 @@
 | 
				
			|||||||
 | 
					package smux
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"container/heap"
 | 
				
			||||||
 | 
						"encoding/binary"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"sync/atomic"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/nadoo/glider/pool"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const (
 | 
				
			||||||
 | 
						defaultAcceptBacklog = 1024
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					var (
 | 
				
			||||||
 | 
						ErrInvalidProtocol = errors.New("invalid protocol")
 | 
				
			||||||
 | 
						ErrConsumed        = errors.New("peer consumed more than sent")
 | 
				
			||||||
 | 
						ErrGoAway          = errors.New("stream id overflows, should start a new connection")
 | 
				
			||||||
 | 
						ErrTimeout         = errors.New("timeout")
 | 
				
			||||||
 | 
						ErrWouldBlock      = errors.New("operation would block on IO")
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type writeRequest struct {
 | 
				
			||||||
 | 
						prio   uint64
 | 
				
			||||||
 | 
						frame  Frame
 | 
				
			||||||
 | 
						result chan writeResult
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type writeResult struct {
 | 
				
			||||||
 | 
						n   int
 | 
				
			||||||
 | 
						err error
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type buffersWriter interface {
 | 
				
			||||||
 | 
						WriteBuffers(v [][]byte) (n int, err error)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Session defines a multiplexed connection for streams
 | 
				
			||||||
 | 
					type Session struct {
 | 
				
			||||||
 | 
						conn io.ReadWriteCloser
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						config           *Config
 | 
				
			||||||
 | 
						nextStreamID     uint32 // next stream identifier
 | 
				
			||||||
 | 
						nextStreamIDLock sync.Mutex
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						bucket       int32         // token bucket
 | 
				
			||||||
 | 
						bucketNotify chan struct{} // used for waiting for tokens
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						streams    map[uint32]*Stream // all streams in this session
 | 
				
			||||||
 | 
						streamLock sync.Mutex         // locks streams
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						die     chan struct{} // flag session has died
 | 
				
			||||||
 | 
						dieOnce sync.Once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// socket error handling
 | 
				
			||||||
 | 
						socketReadError      atomic.Value
 | 
				
			||||||
 | 
						socketWriteError     atomic.Value
 | 
				
			||||||
 | 
						chSocketReadError    chan struct{}
 | 
				
			||||||
 | 
						chSocketWriteError   chan struct{}
 | 
				
			||||||
 | 
						socketReadErrorOnce  sync.Once
 | 
				
			||||||
 | 
						socketWriteErrorOnce sync.Once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// smux protocol errors
 | 
				
			||||||
 | 
						protoError     atomic.Value
 | 
				
			||||||
 | 
						chProtoError   chan struct{}
 | 
				
			||||||
 | 
						protoErrorOnce sync.Once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						chAccepts chan *Stream
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						dataReady int32 // flag data has arrived
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						goAway int32 // flag id exhausted
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						deadline atomic.Value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						shaper chan writeRequest // a shaper for writing
 | 
				
			||||||
 | 
						writes chan writeRequest
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
 | 
				
			||||||
 | 
						s := new(Session)
 | 
				
			||||||
 | 
						s.die = make(chan struct{})
 | 
				
			||||||
 | 
						s.conn = conn
 | 
				
			||||||
 | 
						s.config = config
 | 
				
			||||||
 | 
						s.streams = make(map[uint32]*Stream)
 | 
				
			||||||
 | 
						s.chAccepts = make(chan *Stream, defaultAcceptBacklog)
 | 
				
			||||||
 | 
						s.bucket = int32(config.MaxReceiveBuffer)
 | 
				
			||||||
 | 
						s.bucketNotify = make(chan struct{}, 1)
 | 
				
			||||||
 | 
						s.shaper = make(chan writeRequest)
 | 
				
			||||||
 | 
						s.writes = make(chan writeRequest)
 | 
				
			||||||
 | 
						s.chSocketReadError = make(chan struct{})
 | 
				
			||||||
 | 
						s.chSocketWriteError = make(chan struct{})
 | 
				
			||||||
 | 
						s.chProtoError = make(chan struct{})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if client {
 | 
				
			||||||
 | 
							s.nextStreamID = 1
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							s.nextStreamID = 0
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						go s.shaperLoop()
 | 
				
			||||||
 | 
						go s.recvLoop()
 | 
				
			||||||
 | 
						go s.sendLoop()
 | 
				
			||||||
 | 
						if !config.KeepAliveDisabled {
 | 
				
			||||||
 | 
							go s.keepalive()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return s
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// OpenStream is used to create a new stream
 | 
				
			||||||
 | 
					func (s *Session) OpenStream() (*Stream, error) {
 | 
				
			||||||
 | 
						if s.IsClosed() {
 | 
				
			||||||
 | 
							return nil, io.ErrClosedPipe
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// generate stream id
 | 
				
			||||||
 | 
						s.nextStreamIDLock.Lock()
 | 
				
			||||||
 | 
						if s.goAway > 0 {
 | 
				
			||||||
 | 
							s.nextStreamIDLock.Unlock()
 | 
				
			||||||
 | 
							return nil, ErrGoAway
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						s.nextStreamID += 2
 | 
				
			||||||
 | 
						sid := s.nextStreamID
 | 
				
			||||||
 | 
						if sid == sid%2 { // stream-id overflows
 | 
				
			||||||
 | 
							s.goAway = 1
 | 
				
			||||||
 | 
							s.nextStreamIDLock.Unlock()
 | 
				
			||||||
 | 
							return nil, ErrGoAway
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						s.nextStreamIDLock.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						stream := newStream(sid, s.config.MaxFrameSize, s)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if _, err := s.writeFrame(newFrame(byte(s.config.Version), cmdSYN, sid)); err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						s.streamLock.Lock()
 | 
				
			||||||
 | 
						defer s.streamLock.Unlock()
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case <-s.chSocketReadError:
 | 
				
			||||||
 | 
							return nil, s.socketReadError.Load().(error)
 | 
				
			||||||
 | 
						case <-s.chSocketWriteError:
 | 
				
			||||||
 | 
							return nil, s.socketWriteError.Load().(error)
 | 
				
			||||||
 | 
						case <-s.die:
 | 
				
			||||||
 | 
							return nil, io.ErrClosedPipe
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							s.streams[sid] = stream
 | 
				
			||||||
 | 
							return stream, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Open returns a generic ReadWriteCloser
 | 
				
			||||||
 | 
					func (s *Session) Open() (io.ReadWriteCloser, error) {
 | 
				
			||||||
 | 
						return s.OpenStream()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// AcceptStream is used to block until the next available stream
 | 
				
			||||||
 | 
					// is ready to be accepted.
 | 
				
			||||||
 | 
					func (s *Session) AcceptStream() (*Stream, error) {
 | 
				
			||||||
 | 
						var deadline <-chan time.Time
 | 
				
			||||||
 | 
						if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() {
 | 
				
			||||||
 | 
							timer := time.NewTimer(time.Until(d))
 | 
				
			||||||
 | 
							defer timer.Stop()
 | 
				
			||||||
 | 
							deadline = timer.C
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case stream := <-s.chAccepts:
 | 
				
			||||||
 | 
							return stream, nil
 | 
				
			||||||
 | 
						case <-deadline:
 | 
				
			||||||
 | 
							return nil, ErrTimeout
 | 
				
			||||||
 | 
						case <-s.chSocketReadError:
 | 
				
			||||||
 | 
							return nil, s.socketReadError.Load().(error)
 | 
				
			||||||
 | 
						case <-s.chProtoError:
 | 
				
			||||||
 | 
							return nil, s.protoError.Load().(error)
 | 
				
			||||||
 | 
						case <-s.die:
 | 
				
			||||||
 | 
							return nil, io.ErrClosedPipe
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Accept Returns a generic ReadWriteCloser instead of smux.Stream
 | 
				
			||||||
 | 
					func (s *Session) Accept() (io.ReadWriteCloser, error) {
 | 
				
			||||||
 | 
						return s.AcceptStream()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Close is used to close the session and all streams.
 | 
				
			||||||
 | 
					func (s *Session) Close() error {
 | 
				
			||||||
 | 
						var once bool
 | 
				
			||||||
 | 
						s.dieOnce.Do(func() {
 | 
				
			||||||
 | 
							close(s.die)
 | 
				
			||||||
 | 
							once = true
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if once {
 | 
				
			||||||
 | 
							s.streamLock.Lock()
 | 
				
			||||||
 | 
							for k := range s.streams {
 | 
				
			||||||
 | 
								s.streams[k].sessionClose()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							s.streamLock.Unlock()
 | 
				
			||||||
 | 
							return s.conn.Close()
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							return io.ErrClosedPipe
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// notifyBucket notifies recvLoop that bucket is available
 | 
				
			||||||
 | 
					func (s *Session) notifyBucket() {
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case s.bucketNotify <- struct{}{}:
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Session) notifyReadError(err error) {
 | 
				
			||||||
 | 
						s.socketReadErrorOnce.Do(func() {
 | 
				
			||||||
 | 
							s.socketReadError.Store(err)
 | 
				
			||||||
 | 
							close(s.chSocketReadError)
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Session) notifyWriteError(err error) {
 | 
				
			||||||
 | 
						s.socketWriteErrorOnce.Do(func() {
 | 
				
			||||||
 | 
							s.socketWriteError.Store(err)
 | 
				
			||||||
 | 
							close(s.chSocketWriteError)
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Session) notifyProtoError(err error) {
 | 
				
			||||||
 | 
						s.protoErrorOnce.Do(func() {
 | 
				
			||||||
 | 
							s.protoError.Store(err)
 | 
				
			||||||
 | 
							close(s.chProtoError)
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// IsClosed does a safe check to see if we have shutdown
 | 
				
			||||||
 | 
					func (s *Session) IsClosed() bool {
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case <-s.die:
 | 
				
			||||||
 | 
							return true
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							return false
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NumStreams returns the number of currently open streams
 | 
				
			||||||
 | 
					func (s *Session) NumStreams() int {
 | 
				
			||||||
 | 
						if s.IsClosed() {
 | 
				
			||||||
 | 
							return 0
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						s.streamLock.Lock()
 | 
				
			||||||
 | 
						defer s.streamLock.Unlock()
 | 
				
			||||||
 | 
						return len(s.streams)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SetDeadline sets a deadline used by Accept* calls.
 | 
				
			||||||
 | 
					// A zero time value disables the deadline.
 | 
				
			||||||
 | 
					func (s *Session) SetDeadline(t time.Time) error {
 | 
				
			||||||
 | 
						s.deadline.Store(t)
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// LocalAddr satisfies net.Conn interface
 | 
				
			||||||
 | 
					func (s *Session) LocalAddr() net.Addr {
 | 
				
			||||||
 | 
						if ts, ok := s.conn.(interface {
 | 
				
			||||||
 | 
							LocalAddr() net.Addr
 | 
				
			||||||
 | 
						}); ok {
 | 
				
			||||||
 | 
							return ts.LocalAddr()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RemoteAddr satisfies net.Conn interface
 | 
				
			||||||
 | 
					func (s *Session) RemoteAddr() net.Addr {
 | 
				
			||||||
 | 
						if ts, ok := s.conn.(interface {
 | 
				
			||||||
 | 
							RemoteAddr() net.Addr
 | 
				
			||||||
 | 
						}); ok {
 | 
				
			||||||
 | 
							return ts.RemoteAddr()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// notify the session that a stream has closed
 | 
				
			||||||
 | 
					func (s *Session) streamClosed(sid uint32) {
 | 
				
			||||||
 | 
						s.streamLock.Lock()
 | 
				
			||||||
 | 
						if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket
 | 
				
			||||||
 | 
							if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
 | 
				
			||||||
 | 
								s.notifyBucket()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						delete(s.streams, sid)
 | 
				
			||||||
 | 
						s.streamLock.Unlock()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// returnTokens is called by stream to return token after read
 | 
				
			||||||
 | 
					func (s *Session) returnTokens(n int) {
 | 
				
			||||||
 | 
						if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
 | 
				
			||||||
 | 
							s.notifyBucket()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// recvLoop keeps on reading from underlying connection if tokens are available
 | 
				
			||||||
 | 
					func (s *Session) recvLoop() {
 | 
				
			||||||
 | 
						var hdr rawHeader
 | 
				
			||||||
 | 
						var updHdr updHeader
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() {
 | 
				
			||||||
 | 
								select {
 | 
				
			||||||
 | 
								case <-s.bucketNotify:
 | 
				
			||||||
 | 
								case <-s.die:
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// read header first
 | 
				
			||||||
 | 
							if _, err := io.ReadFull(s.conn, hdr[:]); err == nil {
 | 
				
			||||||
 | 
								atomic.StoreInt32(&s.dataReady, 1)
 | 
				
			||||||
 | 
								if hdr.Version() != byte(s.config.Version) {
 | 
				
			||||||
 | 
									s.notifyProtoError(ErrInvalidProtocol)
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								sid := hdr.StreamID()
 | 
				
			||||||
 | 
								switch hdr.Cmd() {
 | 
				
			||||||
 | 
								case cmdNOP:
 | 
				
			||||||
 | 
								case cmdSYN:
 | 
				
			||||||
 | 
									s.streamLock.Lock()
 | 
				
			||||||
 | 
									if _, ok := s.streams[sid]; !ok {
 | 
				
			||||||
 | 
										stream := newStream(sid, s.config.MaxFrameSize, s)
 | 
				
			||||||
 | 
										s.streams[sid] = stream
 | 
				
			||||||
 | 
										select {
 | 
				
			||||||
 | 
										case s.chAccepts <- stream:
 | 
				
			||||||
 | 
										case <-s.die:
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									s.streamLock.Unlock()
 | 
				
			||||||
 | 
								case cmdFIN:
 | 
				
			||||||
 | 
									s.streamLock.Lock()
 | 
				
			||||||
 | 
									if stream, ok := s.streams[sid]; ok {
 | 
				
			||||||
 | 
										stream.fin()
 | 
				
			||||||
 | 
										stream.notifyReadEvent()
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									s.streamLock.Unlock()
 | 
				
			||||||
 | 
								case cmdPSH:
 | 
				
			||||||
 | 
									if hdr.Length() > 0 {
 | 
				
			||||||
 | 
										newbuf := pool.GetBuffer(int(hdr.Length()))
 | 
				
			||||||
 | 
										if written, err := io.ReadFull(s.conn, newbuf); err == nil {
 | 
				
			||||||
 | 
											s.streamLock.Lock()
 | 
				
			||||||
 | 
											if stream, ok := s.streams[sid]; ok {
 | 
				
			||||||
 | 
												stream.pushBytes(newbuf)
 | 
				
			||||||
 | 
												atomic.AddInt32(&s.bucket, -int32(written))
 | 
				
			||||||
 | 
												stream.notifyReadEvent()
 | 
				
			||||||
 | 
											}
 | 
				
			||||||
 | 
											s.streamLock.Unlock()
 | 
				
			||||||
 | 
										} else {
 | 
				
			||||||
 | 
											s.notifyReadError(err)
 | 
				
			||||||
 | 
											return
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								case cmdUPD:
 | 
				
			||||||
 | 
									if _, err := io.ReadFull(s.conn, updHdr[:]); err == nil {
 | 
				
			||||||
 | 
										s.streamLock.Lock()
 | 
				
			||||||
 | 
										if stream, ok := s.streams[sid]; ok {
 | 
				
			||||||
 | 
											stream.update(updHdr.Consumed(), updHdr.Window())
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										s.streamLock.Unlock()
 | 
				
			||||||
 | 
									} else {
 | 
				
			||||||
 | 
										s.notifyReadError(err)
 | 
				
			||||||
 | 
										return
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								default:
 | 
				
			||||||
 | 
									s.notifyProtoError(ErrInvalidProtocol)
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								s.notifyReadError(err)
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Session) keepalive() {
 | 
				
			||||||
 | 
						tickerPing := time.NewTicker(s.config.KeepAliveInterval)
 | 
				
			||||||
 | 
						tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout)
 | 
				
			||||||
 | 
						defer tickerPing.Stop()
 | 
				
			||||||
 | 
						defer tickerTimeout.Stop()
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case <-tickerPing.C:
 | 
				
			||||||
 | 
								s.writeFrameInternal(newFrame(byte(s.config.Version), cmdNOP, 0), tickerPing.C, 0)
 | 
				
			||||||
 | 
								s.notifyBucket() // force a signal to the recvLoop
 | 
				
			||||||
 | 
							case <-tickerTimeout.C:
 | 
				
			||||||
 | 
								if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) {
 | 
				
			||||||
 | 
									// recvLoop may block while bucket is 0, in this case,
 | 
				
			||||||
 | 
									// session should not be closed.
 | 
				
			||||||
 | 
									if atomic.LoadInt32(&s.bucket) > 0 {
 | 
				
			||||||
 | 
										s.Close()
 | 
				
			||||||
 | 
										return
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							case <-s.die:
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// shaper shapes the sending sequence among streams
 | 
				
			||||||
 | 
					func (s *Session) shaperLoop() {
 | 
				
			||||||
 | 
						var reqs shaperHeap
 | 
				
			||||||
 | 
						var next writeRequest
 | 
				
			||||||
 | 
						var chWrite chan writeRequest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							if len(reqs) > 0 {
 | 
				
			||||||
 | 
								chWrite = s.writes
 | 
				
			||||||
 | 
								next = heap.Pop(&reqs).(writeRequest)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								chWrite = nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case <-s.die:
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							case r := <-s.shaper:
 | 
				
			||||||
 | 
								if chWrite != nil { // next is valid, reshape
 | 
				
			||||||
 | 
									heap.Push(&reqs, next)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								heap.Push(&reqs, r)
 | 
				
			||||||
 | 
							case chWrite <- next:
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Session) sendLoop() {
 | 
				
			||||||
 | 
						var buf []byte
 | 
				
			||||||
 | 
						var n int
 | 
				
			||||||
 | 
						var err error
 | 
				
			||||||
 | 
						var vec [][]byte // vector for writeBuffers
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						bw, ok := s.conn.(buffersWriter)
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							buf = make([]byte, headerSize)
 | 
				
			||||||
 | 
							vec = make([][]byte, 2)
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							buf = make([]byte, (1<<16)+headerSize)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							select {
 | 
				
			||||||
 | 
							case <-s.die:
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							case request := <-s.writes:
 | 
				
			||||||
 | 
								buf[0] = request.frame.ver
 | 
				
			||||||
 | 
								buf[1] = request.frame.cmd
 | 
				
			||||||
 | 
								binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data)))
 | 
				
			||||||
 | 
								binary.LittleEndian.PutUint32(buf[4:], request.frame.sid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if len(vec) > 0 {
 | 
				
			||||||
 | 
									vec[0] = buf[:headerSize]
 | 
				
			||||||
 | 
									vec[1] = request.frame.data
 | 
				
			||||||
 | 
									n, err = bw.WriteBuffers(vec)
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									copy(buf[headerSize:], request.frame.data)
 | 
				
			||||||
 | 
									n, err = s.conn.Write(buf[:headerSize+len(request.frame.data)])
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								n -= headerSize
 | 
				
			||||||
 | 
								if n < 0 {
 | 
				
			||||||
 | 
									n = 0
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								result := writeResult{
 | 
				
			||||||
 | 
									n:   n,
 | 
				
			||||||
 | 
									err: err,
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								request.result <- result
 | 
				
			||||||
 | 
								close(request.result)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// store conn error
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									s.notifyWriteError(err)
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// writeFrame writes the frame to the underlying connection
 | 
				
			||||||
 | 
					// and returns the number of bytes written if successful
 | 
				
			||||||
 | 
					func (s *Session) writeFrame(f Frame) (n int, err error) {
 | 
				
			||||||
 | 
						return s.writeFrameInternal(f, nil, 0)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// internal writeFrame version to support deadline used in keepalive
 | 
				
			||||||
 | 
					func (s *Session) writeFrameInternal(f Frame, deadline <-chan time.Time, prio uint64) (int, error) {
 | 
				
			||||||
 | 
						req := writeRequest{
 | 
				
			||||||
 | 
							prio:   prio,
 | 
				
			||||||
 | 
							frame:  f,
 | 
				
			||||||
 | 
							result: make(chan writeResult, 1),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case s.shaper <- req:
 | 
				
			||||||
 | 
						case <-s.die:
 | 
				
			||||||
 | 
							return 0, io.ErrClosedPipe
 | 
				
			||||||
 | 
						case <-s.chSocketWriteError:
 | 
				
			||||||
 | 
							return 0, s.socketWriteError.Load().(error)
 | 
				
			||||||
 | 
						case <-deadline:
 | 
				
			||||||
 | 
							return 0, ErrTimeout
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case result := <-req.result:
 | 
				
			||||||
 | 
							return result.n, result.err
 | 
				
			||||||
 | 
						case <-s.die:
 | 
				
			||||||
 | 
							return 0, io.ErrClosedPipe
 | 
				
			||||||
 | 
						case <-s.chSocketWriteError:
 | 
				
			||||||
 | 
							return 0, s.socketWriteError.Load().(error)
 | 
				
			||||||
 | 
						case <-deadline:
 | 
				
			||||||
 | 
							return 0, ErrTimeout
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										1090
									
								
								proxy/protocol/smux/session_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1090
									
								
								proxy/protocol/smux/session_test.go
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										16
									
								
								proxy/protocol/smux/shaper.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								proxy/protocol/smux/shaper.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,16 @@
 | 
				
			|||||||
 | 
					package smux
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type shaperHeap []writeRequest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h shaperHeap) Len() int            { return len(h) }
 | 
				
			||||||
 | 
					func (h shaperHeap) Less(i, j int) bool  { return h[i].prio < h[j].prio }
 | 
				
			||||||
 | 
					func (h shaperHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
 | 
				
			||||||
 | 
					func (h *shaperHeap) Push(x interface{}) { *h = append(*h, x.(writeRequest)) }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (h *shaperHeap) Pop() interface{} {
 | 
				
			||||||
 | 
						old := *h
 | 
				
			||||||
 | 
						n := len(old)
 | 
				
			||||||
 | 
						x := old[n-1]
 | 
				
			||||||
 | 
						*h = old[0 : n-1]
 | 
				
			||||||
 | 
						return x
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										30
									
								
								proxy/protocol/smux/shaper_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								proxy/protocol/smux/shaper_test.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,30 @@
 | 
				
			|||||||
 | 
					package smux
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"container/heap"
 | 
				
			||||||
 | 
						"testing"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestShaper(t *testing.T) {
 | 
				
			||||||
 | 
						w1 := writeRequest{prio: 10}
 | 
				
			||||||
 | 
						w2 := writeRequest{prio: 10}
 | 
				
			||||||
 | 
						w3 := writeRequest{prio: 20}
 | 
				
			||||||
 | 
						w4 := writeRequest{prio: 100}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var reqs shaperHeap
 | 
				
			||||||
 | 
						heap.Push(&reqs, w4)
 | 
				
			||||||
 | 
						heap.Push(&reqs, w3)
 | 
				
			||||||
 | 
						heap.Push(&reqs, w2)
 | 
				
			||||||
 | 
						heap.Push(&reqs, w1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var lastPrio uint64
 | 
				
			||||||
 | 
						for len(reqs) > 0 {
 | 
				
			||||||
 | 
							w := heap.Pop(&reqs).(writeRequest)
 | 
				
			||||||
 | 
							if w.prio < lastPrio {
 | 
				
			||||||
 | 
								t.Fatal("incorrect shaper priority")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							t.Log("prio:", w.prio)
 | 
				
			||||||
 | 
							lastPrio = w.prio
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										545
									
								
								proxy/protocol/smux/stream.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										545
									
								
								proxy/protocol/smux/stream.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,545 @@
 | 
				
			|||||||
 | 
					package smux
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"encoding/binary"
 | 
				
			||||||
 | 
						"io"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"sync/atomic"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/nadoo/glider/pool"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Stream implements net.Conn
 | 
				
			||||||
 | 
					type Stream struct {
 | 
				
			||||||
 | 
						id   uint32
 | 
				
			||||||
 | 
						sess *Session
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						buffers [][]byte
 | 
				
			||||||
 | 
						heads   [][]byte // slice heads kept for recycle
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						bufferLock sync.Mutex
 | 
				
			||||||
 | 
						frameSize  int
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// notify a read event
 | 
				
			||||||
 | 
						chReadEvent chan struct{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// flag the stream has closed
 | 
				
			||||||
 | 
						die     chan struct{}
 | 
				
			||||||
 | 
						dieOnce sync.Once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// FIN command
 | 
				
			||||||
 | 
						chFinEvent   chan struct{}
 | 
				
			||||||
 | 
						finEventOnce sync.Once
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// deadlines
 | 
				
			||||||
 | 
						readDeadline  atomic.Value
 | 
				
			||||||
 | 
						writeDeadline atomic.Value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// per stream sliding window control
 | 
				
			||||||
 | 
						numRead    uint32 // number of consumed bytes
 | 
				
			||||||
 | 
						numWritten uint32 // count num of bytes written
 | 
				
			||||||
 | 
						incr       uint32 // counting for sending
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// UPD command
 | 
				
			||||||
 | 
						peerConsumed uint32        // num of bytes the peer has consumed
 | 
				
			||||||
 | 
						peerWindow   uint32        // peer window, initialized to 256KB, updated by peer
 | 
				
			||||||
 | 
						chUpdate     chan struct{} // notify of remote data consuming and window update
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// newStream initiates a Stream struct
 | 
				
			||||||
 | 
					func newStream(id uint32, frameSize int, sess *Session) *Stream {
 | 
				
			||||||
 | 
						s := new(Stream)
 | 
				
			||||||
 | 
						s.id = id
 | 
				
			||||||
 | 
						s.chReadEvent = make(chan struct{}, 1)
 | 
				
			||||||
 | 
						s.chUpdate = make(chan struct{}, 1)
 | 
				
			||||||
 | 
						s.frameSize = frameSize
 | 
				
			||||||
 | 
						s.sess = sess
 | 
				
			||||||
 | 
						s.die = make(chan struct{})
 | 
				
			||||||
 | 
						s.chFinEvent = make(chan struct{})
 | 
				
			||||||
 | 
						s.peerWindow = initialPeerWindow // set to initial window size
 | 
				
			||||||
 | 
						return s
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ID returns the unique stream ID.
 | 
				
			||||||
 | 
					func (s *Stream) ID() uint32 {
 | 
				
			||||||
 | 
						return s.id
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Read implements net.Conn
 | 
				
			||||||
 | 
					func (s *Stream) Read(b []byte) (n int, err error) {
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							n, err = s.tryRead(b)
 | 
				
			||||||
 | 
							if err == ErrWouldBlock {
 | 
				
			||||||
 | 
								if ew := s.waitRead(); ew != nil {
 | 
				
			||||||
 | 
									return 0, ew
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								return n, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// tryRead is the nonblocking version of Read
 | 
				
			||||||
 | 
					func (s *Stream) tryRead(b []byte) (n int, err error) {
 | 
				
			||||||
 | 
						if s.sess.config.Version == 2 {
 | 
				
			||||||
 | 
							return s.tryReadv2(b)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(b) == 0 {
 | 
				
			||||||
 | 
							return 0, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						s.bufferLock.Lock()
 | 
				
			||||||
 | 
						if len(s.buffers) > 0 {
 | 
				
			||||||
 | 
							n = copy(b, s.buffers[0])
 | 
				
			||||||
 | 
							s.buffers[0] = s.buffers[0][n:]
 | 
				
			||||||
 | 
							if len(s.buffers[0]) == 0 {
 | 
				
			||||||
 | 
								s.buffers[0] = nil
 | 
				
			||||||
 | 
								s.buffers = s.buffers[1:]
 | 
				
			||||||
 | 
								// full recycle
 | 
				
			||||||
 | 
								pool.PutBuffer(s.heads[0])
 | 
				
			||||||
 | 
								s.heads = s.heads[1:]
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						s.bufferLock.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if n > 0 {
 | 
				
			||||||
 | 
							s.sess.returnTokens(n)
 | 
				
			||||||
 | 
							return n, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case <-s.die:
 | 
				
			||||||
 | 
							return 0, io.EOF
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							return 0, ErrWouldBlock
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Stream) tryReadv2(b []byte) (n int, err error) {
 | 
				
			||||||
 | 
						if len(b) == 0 {
 | 
				
			||||||
 | 
							return 0, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var notifyConsumed uint32
 | 
				
			||||||
 | 
						s.bufferLock.Lock()
 | 
				
			||||||
 | 
						if len(s.buffers) > 0 {
 | 
				
			||||||
 | 
							n = copy(b, s.buffers[0])
 | 
				
			||||||
 | 
							s.buffers[0] = s.buffers[0][n:]
 | 
				
			||||||
 | 
							if len(s.buffers[0]) == 0 {
 | 
				
			||||||
 | 
								s.buffers[0] = nil
 | 
				
			||||||
 | 
								s.buffers = s.buffers[1:]
 | 
				
			||||||
 | 
								// full recycle
 | 
				
			||||||
 | 
								pool.PutBuffer(s.heads[0])
 | 
				
			||||||
 | 
								s.heads = s.heads[1:]
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// in an ideal environment:
 | 
				
			||||||
 | 
						// if more than half of buffer has consumed, send read ack to peer
 | 
				
			||||||
 | 
						// based on round-trip time of ACK, continous flowing data
 | 
				
			||||||
 | 
						// won't slow down because of waiting for ACK, as long as the
 | 
				
			||||||
 | 
						// consumer keeps on reading data
 | 
				
			||||||
 | 
						// s.numRead == n also notify window at the first read
 | 
				
			||||||
 | 
						s.numRead += uint32(n)
 | 
				
			||||||
 | 
						s.incr += uint32(n)
 | 
				
			||||||
 | 
						if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == uint32(n) {
 | 
				
			||||||
 | 
							notifyConsumed = s.numRead
 | 
				
			||||||
 | 
							s.incr = 0
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						s.bufferLock.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if n > 0 {
 | 
				
			||||||
 | 
							s.sess.returnTokens(n)
 | 
				
			||||||
 | 
							if notifyConsumed > 0 {
 | 
				
			||||||
 | 
								err := s.sendWindowUpdate(notifyConsumed)
 | 
				
			||||||
 | 
								return n, err
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								return n, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case <-s.die:
 | 
				
			||||||
 | 
							return 0, io.EOF
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							return 0, ErrWouldBlock
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// WriteTo implements io.WriteTo
 | 
				
			||||||
 | 
					func (s *Stream) WriteTo(w io.Writer) (n int64, err error) {
 | 
				
			||||||
 | 
						if s.sess.config.Version == 2 {
 | 
				
			||||||
 | 
							return s.writeTov2(w)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							var buf []byte
 | 
				
			||||||
 | 
							s.bufferLock.Lock()
 | 
				
			||||||
 | 
							if len(s.buffers) > 0 {
 | 
				
			||||||
 | 
								buf = s.buffers[0]
 | 
				
			||||||
 | 
								s.buffers = s.buffers[1:]
 | 
				
			||||||
 | 
								s.heads = s.heads[1:]
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							s.bufferLock.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if buf != nil {
 | 
				
			||||||
 | 
								nw, ew := w.Write(buf)
 | 
				
			||||||
 | 
								s.sess.returnTokens(len(buf))
 | 
				
			||||||
 | 
								pool.PutBuffer(buf)
 | 
				
			||||||
 | 
								if nw > 0 {
 | 
				
			||||||
 | 
									n += int64(nw)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if ew != nil {
 | 
				
			||||||
 | 
									return n, ew
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else if ew := s.waitRead(); ew != nil {
 | 
				
			||||||
 | 
								return n, ew
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Stream) writeTov2(w io.Writer) (n int64, err error) {
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							var notifyConsumed uint32
 | 
				
			||||||
 | 
							var buf []byte
 | 
				
			||||||
 | 
							s.bufferLock.Lock()
 | 
				
			||||||
 | 
							if len(s.buffers) > 0 {
 | 
				
			||||||
 | 
								buf = s.buffers[0]
 | 
				
			||||||
 | 
								s.buffers = s.buffers[1:]
 | 
				
			||||||
 | 
								s.heads = s.heads[1:]
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							s.numRead += uint32(len(buf))
 | 
				
			||||||
 | 
							s.incr += uint32(len(buf))
 | 
				
			||||||
 | 
							if s.incr >= uint32(s.sess.config.MaxStreamBuffer/2) || s.numRead == uint32(len(buf)) {
 | 
				
			||||||
 | 
								notifyConsumed = s.numRead
 | 
				
			||||||
 | 
								s.incr = 0
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							s.bufferLock.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if buf != nil {
 | 
				
			||||||
 | 
								nw, ew := w.Write(buf)
 | 
				
			||||||
 | 
								s.sess.returnTokens(len(buf))
 | 
				
			||||||
 | 
								pool.PutBuffer(buf)
 | 
				
			||||||
 | 
								if nw > 0 {
 | 
				
			||||||
 | 
									n += int64(nw)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if ew != nil {
 | 
				
			||||||
 | 
									return n, ew
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								if notifyConsumed > 0 {
 | 
				
			||||||
 | 
									if err := s.sendWindowUpdate(notifyConsumed); err != nil {
 | 
				
			||||||
 | 
										return n, err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else if ew := s.waitRead(); ew != nil {
 | 
				
			||||||
 | 
								return n, ew
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Stream) sendWindowUpdate(consumed uint32) error {
 | 
				
			||||||
 | 
						var timer *time.Timer
 | 
				
			||||||
 | 
						var deadline <-chan time.Time
 | 
				
			||||||
 | 
						if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
 | 
				
			||||||
 | 
							timer = time.NewTimer(time.Until(d))
 | 
				
			||||||
 | 
							defer timer.Stop()
 | 
				
			||||||
 | 
							deadline = timer.C
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						frame := newFrame(byte(s.sess.config.Version), cmdUPD, s.id)
 | 
				
			||||||
 | 
						var hdr updHeader
 | 
				
			||||||
 | 
						binary.LittleEndian.PutUint32(hdr[:], consumed)
 | 
				
			||||||
 | 
						binary.LittleEndian.PutUint32(hdr[4:], uint32(s.sess.config.MaxStreamBuffer))
 | 
				
			||||||
 | 
						frame.data = hdr[:]
 | 
				
			||||||
 | 
						_, err := s.sess.writeFrameInternal(frame, deadline, 0)
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Stream) waitRead() error {
 | 
				
			||||||
 | 
						var timer *time.Timer
 | 
				
			||||||
 | 
						var deadline <-chan time.Time
 | 
				
			||||||
 | 
						if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
 | 
				
			||||||
 | 
							timer = time.NewTimer(time.Until(d))
 | 
				
			||||||
 | 
							defer timer.Stop()
 | 
				
			||||||
 | 
							deadline = timer.C
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case <-s.chReadEvent:
 | 
				
			||||||
 | 
							return nil
 | 
				
			||||||
 | 
						case <-s.chFinEvent:
 | 
				
			||||||
 | 
							return io.EOF
 | 
				
			||||||
 | 
						case <-s.sess.chSocketReadError:
 | 
				
			||||||
 | 
							return s.sess.socketReadError.Load().(error)
 | 
				
			||||||
 | 
						case <-s.sess.chProtoError:
 | 
				
			||||||
 | 
							return s.sess.protoError.Load().(error)
 | 
				
			||||||
 | 
						case <-deadline:
 | 
				
			||||||
 | 
							return ErrTimeout
 | 
				
			||||||
 | 
						case <-s.die:
 | 
				
			||||||
 | 
							return io.ErrClosedPipe
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Write implements net.Conn
 | 
				
			||||||
 | 
					//
 | 
				
			||||||
 | 
					// Note that the behavior when multiple goroutines write concurrently is not deterministic,
 | 
				
			||||||
 | 
					// frames may interleave in random way.
 | 
				
			||||||
 | 
					func (s *Stream) Write(b []byte) (n int, err error) {
 | 
				
			||||||
 | 
						if s.sess.config.Version == 2 {
 | 
				
			||||||
 | 
							return s.writeV2(b)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var deadline <-chan time.Time
 | 
				
			||||||
 | 
						if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
 | 
				
			||||||
 | 
							timer := time.NewTimer(time.Until(d))
 | 
				
			||||||
 | 
							defer timer.Stop()
 | 
				
			||||||
 | 
							deadline = timer.C
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// check if stream has closed
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case <-s.die:
 | 
				
			||||||
 | 
							return 0, io.ErrClosedPipe
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// frame split and transmit
 | 
				
			||||||
 | 
						sent := 0
 | 
				
			||||||
 | 
						frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id)
 | 
				
			||||||
 | 
						bts := b
 | 
				
			||||||
 | 
						for len(bts) > 0 {
 | 
				
			||||||
 | 
							sz := len(bts)
 | 
				
			||||||
 | 
							if sz > s.frameSize {
 | 
				
			||||||
 | 
								sz = s.frameSize
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							frame.data = bts[:sz]
 | 
				
			||||||
 | 
							bts = bts[sz:]
 | 
				
			||||||
 | 
							n, err := s.sess.writeFrameInternal(frame, deadline, uint64(s.numWritten))
 | 
				
			||||||
 | 
							s.numWritten++
 | 
				
			||||||
 | 
							sent += n
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return sent, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return sent, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *Stream) writeV2(b []byte) (n int, err error) {
 | 
				
			||||||
 | 
						// check empty input
 | 
				
			||||||
 | 
						if len(b) == 0 {
 | 
				
			||||||
 | 
							return 0, nil
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// check if stream has closed
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case <-s.die:
 | 
				
			||||||
 | 
							return 0, io.ErrClosedPipe
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// create write deadline timer
 | 
				
			||||||
 | 
						var deadline <-chan time.Time
 | 
				
			||||||
 | 
						if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
 | 
				
			||||||
 | 
							timer := time.NewTimer(time.Until(d))
 | 
				
			||||||
 | 
							defer timer.Stop()
 | 
				
			||||||
 | 
							deadline = timer.C
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// frame split and transmit process
 | 
				
			||||||
 | 
						sent := 0
 | 
				
			||||||
 | 
						frame := newFrame(byte(s.sess.config.Version), cmdPSH, s.id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							// per stream sliding window control
 | 
				
			||||||
 | 
							// [.... [consumed... numWritten] ... win... ]
 | 
				
			||||||
 | 
							// [.... [consumed...................+rmtwnd]]
 | 
				
			||||||
 | 
							var bts []byte
 | 
				
			||||||
 | 
							// note:
 | 
				
			||||||
 | 
							// even if uint32 overflow, this math still works:
 | 
				
			||||||
 | 
							// eg1: uint32(0) - uint32(math.MaxUint32) = 1
 | 
				
			||||||
 | 
							// eg2: int32(uint32(0) - uint32(1)) = -1
 | 
				
			||||||
 | 
							// security check for misbehavior
 | 
				
			||||||
 | 
							inflight := int32(atomic.LoadUint32(&s.numWritten) - atomic.LoadUint32(&s.peerConsumed))
 | 
				
			||||||
 | 
							if inflight < 0 {
 | 
				
			||||||
 | 
								return 0, ErrConsumed
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							win := int32(atomic.LoadUint32(&s.peerWindow)) - inflight
 | 
				
			||||||
 | 
							if win > 0 {
 | 
				
			||||||
 | 
								if win > int32(len(b)) {
 | 
				
			||||||
 | 
									bts = b
 | 
				
			||||||
 | 
									b = nil
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									bts = b[:win]
 | 
				
			||||||
 | 
									b = b[win:]
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								for len(bts) > 0 {
 | 
				
			||||||
 | 
									sz := len(bts)
 | 
				
			||||||
 | 
									if sz > s.frameSize {
 | 
				
			||||||
 | 
										sz = s.frameSize
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									frame.data = bts[:sz]
 | 
				
			||||||
 | 
									bts = bts[sz:]
 | 
				
			||||||
 | 
									n, err := s.sess.writeFrameInternal(frame, deadline, uint64(atomic.LoadUint32(&s.numWritten)))
 | 
				
			||||||
 | 
									atomic.AddUint32(&s.numWritten, uint32(sz))
 | 
				
			||||||
 | 
									sent += n
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return sent, err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// if there is any data remaining to be sent
 | 
				
			||||||
 | 
							// wait until stream closes, window changes or deadline reached
 | 
				
			||||||
 | 
							// this blocking behavior will inform upper layer to do flow control
 | 
				
			||||||
 | 
							if len(b) > 0 {
 | 
				
			||||||
 | 
								select {
 | 
				
			||||||
 | 
								case <-s.chFinEvent: // if fin arrived, future window update is impossible
 | 
				
			||||||
 | 
									return 0, io.EOF
 | 
				
			||||||
 | 
								case <-s.die:
 | 
				
			||||||
 | 
									return sent, io.ErrClosedPipe
 | 
				
			||||||
 | 
								case <-deadline:
 | 
				
			||||||
 | 
									return sent, ErrTimeout
 | 
				
			||||||
 | 
								case <-s.sess.chSocketWriteError:
 | 
				
			||||||
 | 
									return sent, s.sess.socketWriteError.Load().(error)
 | 
				
			||||||
 | 
								case <-s.chUpdate:
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								return sent, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Close implements net.Conn
 | 
				
			||||||
 | 
					func (s *Stream) Close() error {
 | 
				
			||||||
 | 
						var once bool
 | 
				
			||||||
 | 
						var err error
 | 
				
			||||||
 | 
						s.dieOnce.Do(func() {
 | 
				
			||||||
 | 
							close(s.die)
 | 
				
			||||||
 | 
							once = true
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if once {
 | 
				
			||||||
 | 
							_, err = s.sess.writeFrame(newFrame(byte(s.sess.config.Version), cmdFIN, s.id))
 | 
				
			||||||
 | 
							s.sess.streamClosed(s.id)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							return io.ErrClosedPipe
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// GetDieCh returns a readonly chan which can be readable
 | 
				
			||||||
 | 
					// when the stream is to be closed.
 | 
				
			||||||
 | 
					func (s *Stream) GetDieCh() <-chan struct{} {
 | 
				
			||||||
 | 
						return s.die
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SetReadDeadline sets the read deadline as defined by
 | 
				
			||||||
 | 
					// net.Conn.SetReadDeadline.
 | 
				
			||||||
 | 
					// A zero time value disables the deadline.
 | 
				
			||||||
 | 
					func (s *Stream) SetReadDeadline(t time.Time) error {
 | 
				
			||||||
 | 
						s.readDeadline.Store(t)
 | 
				
			||||||
 | 
						s.notifyReadEvent()
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SetWriteDeadline sets the write deadline as defined by
 | 
				
			||||||
 | 
					// net.Conn.SetWriteDeadline.
 | 
				
			||||||
 | 
					// A zero time value disables the deadline.
 | 
				
			||||||
 | 
					func (s *Stream) SetWriteDeadline(t time.Time) error {
 | 
				
			||||||
 | 
						s.writeDeadline.Store(t)
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SetDeadline sets both read and write deadlines as defined by
 | 
				
			||||||
 | 
					// net.Conn.SetDeadline.
 | 
				
			||||||
 | 
					// A zero time value disables the deadlines.
 | 
				
			||||||
 | 
					func (s *Stream) SetDeadline(t time.Time) error {
 | 
				
			||||||
 | 
						if err := s.SetReadDeadline(t); err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := s.SetWriteDeadline(t); err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// session closes
 | 
				
			||||||
 | 
					func (s *Stream) sessionClose() { s.dieOnce.Do(func() { close(s.die) }) }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// LocalAddr satisfies net.Conn interface
 | 
				
			||||||
 | 
					func (s *Stream) LocalAddr() net.Addr {
 | 
				
			||||||
 | 
						if ts, ok := s.sess.conn.(interface {
 | 
				
			||||||
 | 
							LocalAddr() net.Addr
 | 
				
			||||||
 | 
						}); ok {
 | 
				
			||||||
 | 
							return ts.LocalAddr()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// RemoteAddr satisfies net.Conn interface
 | 
				
			||||||
 | 
					func (s *Stream) RemoteAddr() net.Addr {
 | 
				
			||||||
 | 
						if ts, ok := s.sess.conn.(interface {
 | 
				
			||||||
 | 
							RemoteAddr() net.Addr
 | 
				
			||||||
 | 
						}); ok {
 | 
				
			||||||
 | 
							return ts.RemoteAddr()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// pushBytes append buf to buffers
 | 
				
			||||||
 | 
					func (s *Stream) pushBytes(buf []byte) (written int, err error) {
 | 
				
			||||||
 | 
						s.bufferLock.Lock()
 | 
				
			||||||
 | 
						s.buffers = append(s.buffers, buf)
 | 
				
			||||||
 | 
						s.heads = append(s.heads, buf)
 | 
				
			||||||
 | 
						s.bufferLock.Unlock()
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// recycleTokens transform remaining bytes to tokens(will truncate buffer)
 | 
				
			||||||
 | 
					func (s *Stream) recycleTokens() (n int) {
 | 
				
			||||||
 | 
						s.bufferLock.Lock()
 | 
				
			||||||
 | 
						for k := range s.buffers {
 | 
				
			||||||
 | 
							n += len(s.buffers[k])
 | 
				
			||||||
 | 
							pool.PutBuffer(s.heads[k])
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						s.buffers = nil
 | 
				
			||||||
 | 
						s.heads = nil
 | 
				
			||||||
 | 
						s.bufferLock.Unlock()
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// notify read event
 | 
				
			||||||
 | 
					func (s *Stream) notifyReadEvent() {
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case s.chReadEvent <- struct{}{}:
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// update command
 | 
				
			||||||
 | 
					func (s *Stream) update(consumed uint32, window uint32) {
 | 
				
			||||||
 | 
						atomic.StoreUint32(&s.peerConsumed, consumed)
 | 
				
			||||||
 | 
						atomic.StoreUint32(&s.peerWindow, window)
 | 
				
			||||||
 | 
						select {
 | 
				
			||||||
 | 
						case s.chUpdate <- struct{}{}:
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// mark this stream has been closed in protocol
 | 
				
			||||||
 | 
					func (s *Stream) fin() {
 | 
				
			||||||
 | 
						s.finEventOnce.Do(func() {
 | 
				
			||||||
 | 
							close(s.chFinEvent)
 | 
				
			||||||
 | 
						})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										76
									
								
								proxy/smux/client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								proxy/smux/client.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,76 @@
 | 
				
			|||||||
 | 
					package smux
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/nadoo/glider/log"
 | 
				
			||||||
 | 
						"github.com/nadoo/glider/proxy"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/nadoo/glider/proxy/protocol/smux"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SmuxClient struct.
 | 
				
			||||||
 | 
					type SmuxClient struct {
 | 
				
			||||||
 | 
						dialer  proxy.Dialer
 | 
				
			||||||
 | 
						addr    string
 | 
				
			||||||
 | 
						session *smux.Session
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						proxy.RegisterDialer("smux", NewSmuxDialer)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NewSmuxDialer returns a smux dialer.
 | 
				
			||||||
 | 
					func NewSmuxDialer(s string, d proxy.Dialer) (proxy.Dialer, error) {
 | 
				
			||||||
 | 
						u, err := url.Parse(s)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.F("[smux] parse url err: %s", err)
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						c := &SmuxClient{
 | 
				
			||||||
 | 
							dialer: d,
 | 
				
			||||||
 | 
							addr:   u.Host,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return c, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Addr returns forwarder's address.
 | 
				
			||||||
 | 
					func (s *SmuxClient) Addr() string {
 | 
				
			||||||
 | 
						if s.addr == "" {
 | 
				
			||||||
 | 
							return s.dialer.Addr()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return s.addr
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Dial connects to the address addr on the network net via the proxy.
 | 
				
			||||||
 | 
					func (s *SmuxClient) Dial(network, addr string) (net.Conn, error) {
 | 
				
			||||||
 | 
						if s.session != nil {
 | 
				
			||||||
 | 
							if c, err := s.session.OpenStream(); err == nil {
 | 
				
			||||||
 | 
								return c, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							s.session.Close()
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if err := s.initConn(); err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return s.session.OpenStream()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DialUDP connects to the given address via the proxy.
 | 
				
			||||||
 | 
					func (s *SmuxClient) DialUDP(network, addr string) (net.PacketConn, net.Addr, error) {
 | 
				
			||||||
 | 
						return nil, nil, errors.New("smux client does not support udp now")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *SmuxClient) initConn() error {
 | 
				
			||||||
 | 
						conn, err := s.dialer.Dial("tcp", s.addr)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.F("[smux] dial to %s error: %s", s.addr, err)
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						s.session, err = smux.Client(conn, nil)
 | 
				
			||||||
 | 
						return err
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										98
									
								
								proxy/smux/server.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								proxy/smux/server.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,98 @@
 | 
				
			|||||||
 | 
					package smux
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
 | 
						"net/url"
 | 
				
			||||||
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/nadoo/glider/log"
 | 
				
			||||||
 | 
						"github.com/nadoo/glider/proxy"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						"github.com/nadoo/glider/proxy/protocol/smux"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// SmuxServer struct.
 | 
				
			||||||
 | 
					type SmuxServer struct {
 | 
				
			||||||
 | 
						proxy  proxy.Proxy
 | 
				
			||||||
 | 
						addr   string
 | 
				
			||||||
 | 
						server proxy.Server
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func init() {
 | 
				
			||||||
 | 
						proxy.RegisterServer("smux", NewSmuxServer)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NewSmuxServer returns a smux transport layer before the real server.
 | 
				
			||||||
 | 
					func NewSmuxServer(s string, p proxy.Proxy) (proxy.Server, error) {
 | 
				
			||||||
 | 
						transport := strings.Split(s, ",")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						u, err := url.Parse(transport[0])
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.F("[smux] parse url err: %s", err)
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m := &SmuxServer{
 | 
				
			||||||
 | 
							proxy: p,
 | 
				
			||||||
 | 
							addr:  u.Host,
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if len(transport) > 1 {
 | 
				
			||||||
 | 
							m.server, err = proxy.ServerFromURL(transport[1], p)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return nil, err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return m, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ListenAndServe listens on server's addr and serves connections.
 | 
				
			||||||
 | 
					func (s *SmuxServer) ListenAndServe() {
 | 
				
			||||||
 | 
						l, err := net.Listen("tcp", s.addr)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.F("[smux] failed to listen on %s: %v", s.addr, err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						defer l.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						log.F("[smux] listening mux on %s", s.addr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							c, err := l.Accept()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								log.F("[smux] failed to accept: %v", err)
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							go s.Serve(c)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Serve serves a connection.
 | 
				
			||||||
 | 
					func (s *SmuxServer) Serve(c net.Conn) {
 | 
				
			||||||
 | 
						// we know the internal server will close the connection after serve
 | 
				
			||||||
 | 
						// defer c.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						session, err := smux.Server(c, nil)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							log.F("[smux] failed to create session: %v", err)
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for {
 | 
				
			||||||
 | 
							// Accept a stream
 | 
				
			||||||
 | 
							stream, err := session.AcceptStream()
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								session.Close()
 | 
				
			||||||
 | 
								break
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							go s.ServeStream(stream)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (s *SmuxServer) ServeStream(c *smux.Stream) {
 | 
				
			||||||
 | 
						if s.server != nil {
 | 
				
			||||||
 | 
							s.server.Serve(c)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user