From 88eee75aa718a56bbc8ec8340fc1245c144e7eee Mon Sep 17 00:00:00 2001 From: nadoo <287492+nadoo@users.noreply.github.com> Date: Tue, 24 Jul 2018 00:45:41 +0800 Subject: [PATCH] ws: add support for security key check --- README.md | 9 ++++-- conf.go | 11 +++++--- dev.go | 6 +++- proxy/vmess/chunk.go | 48 +++++++++++++++----------------- proxy/ws/client.go | 66 ++++++++++++++++++++++++++++++-------------- proxy/ws/frame.go | 32 +++++++++------------ 6 files changed, 98 insertions(+), 74 deletions(-) diff --git a/README.md b/README.md index 66feb4d..a779c05 100644 --- a/README.md +++ b/README.md @@ -155,7 +155,7 @@ TLS scheme: tls://host:port[?skipVerify=true] TLS with a specified proxy protocol: - tls://host:port[?skipVerify=true],proxy://scheme + tls://host:port[?skipVerify=true],scheme:// tls://host:port[?skipVerify=true],http://[user:pass@] tls://host:port[?skipVerify=true],socks5://[user:pass@] tls://host:port[?skipVerify=true],vmess://[security:]uuid@?alterID=num @@ -170,7 +170,7 @@ Websocket with a specified proxy protocol: ws://host:port[/path],vmess://[security:]uuid@?alterID=num TLS and Websocket with a specified proxy protocol: - tls://host:port[?skipVerify=true],ws://[@/path],proxy://scheme + tls://host:port[?skipVerify=true],ws://[@/path],scheme:// tls://host:port[?skipVerify=true],ws://[@/path],http://[user:pass@] tls://host:port[?skipVerify=true],ws://[@/path],socks5://[user:pass@] tls://host:port[?skipVerify=true],ws://[@/path],vmess://[security:]uuid@?alterID=num @@ -211,7 +211,10 @@ Examples: -listen on :1081 as a transparent redirect server, forward all requests via remote ssr server. glider -listen redir://:1081 -forward "tls://1.1.1.1:443,vmess://security:uuid@?alterID=10" - -listen on :1081 as a transparent redirect server, forward all requests via remote vmess server. + -listen on :1081 as a transparent redirect server, forward all requests via remote tls+vmess server. + + glider -listen redir://:1081 -forward "ws://1.1.1.1:80,vmess://security:uuid@?alterID=10" + -listen on :1081 as a transparent redirect server, forward all requests via remote ws+vmess server. glider -listen tcptun://:80=2.2.2.2:80 -forward ss://method:pass@1.1.1.1:8443 -listen on :80 and forward all requests to 2.2.2.2:80 via remote ss server. diff --git a/conf.go b/conf.go index 64ac1d3..ba5b57c 100644 --- a/conf.go +++ b/conf.go @@ -201,7 +201,7 @@ func usage() { fmt.Fprintf(os.Stderr, "\n") fmt.Fprintf(os.Stderr, "TLS with a specified proxy protocol:\n") - fmt.Fprintf(os.Stderr, " tls://host:port[?skipVerify=true],proxy://scheme\n") + fmt.Fprintf(os.Stderr, " tls://host:port[?skipVerify=true],scheme://\n") fmt.Fprintf(os.Stderr, " tls://host:port[?skipVerify=true],http://[user:pass@]\n") fmt.Fprintf(os.Stderr, " tls://host:port[?skipVerify=true],socks5://[user:pass@]\n") fmt.Fprintf(os.Stderr, " tls://host:port[?skipVerify=true],vmess://[security:]uuid@?alterID=num\n") @@ -212,14 +212,14 @@ func usage() { fmt.Fprintf(os.Stderr, "\n") fmt.Fprintf(os.Stderr, "Websocket with a specified proxy protocol:\n") - fmt.Fprintf(os.Stderr, " ws://host:port[/path],proxy://scheme\n") + fmt.Fprintf(os.Stderr, " ws://host:port[/path],scheme://\n") fmt.Fprintf(os.Stderr, " ws://host:port[/path],http://[user:pass@]\n") fmt.Fprintf(os.Stderr, " ws://host:port[/path],socks5://[user:pass@]\n") fmt.Fprintf(os.Stderr, " ws://host:port[/path],vmess://[security:]uuid@?alterID=num\n") fmt.Fprintf(os.Stderr, "\n") fmt.Fprintf(os.Stderr, "TLS and Websocket with a specified proxy protocol:\n") - fmt.Fprintf(os.Stderr, " tls://host:port[?skipVerify=true],ws://[@/path],proxy://scheme\n") + fmt.Fprintf(os.Stderr, " tls://host:port[?skipVerify=true],ws://[@/path],scheme://\n") fmt.Fprintf(os.Stderr, " tls://host:port[?skipVerify=true],ws://[@/path],http://[user:pass@]\n") fmt.Fprintf(os.Stderr, " tls://host:port[?skipVerify=true],ws://[@/path],socks5://[user:pass@]\n") fmt.Fprintf(os.Stderr, " tls://host:port[?skipVerify=true],ws://[@/path],vmess://[security:]uuid@?alterID=num\n") @@ -263,7 +263,10 @@ func usage() { fmt.Fprintf(os.Stderr, " -listen on :1081 as a transparent redirect server, forward all requests via remote ssr server.\n") fmt.Fprintf(os.Stderr, "\n") fmt.Fprintf(os.Stderr, " "+app+" -listen redir://:1081 -forward \"tls://1.1.1.1:443,vmess://security:uuid@?alterID=10\"\n") - fmt.Fprintf(os.Stderr, " -listen on :1081 as a transparent redirect server, forward all requests via remote vmess server.\n") + fmt.Fprintf(os.Stderr, " -listen on :1081 as a transparent redirect server, forward all requests via remote tls+vmess server.\n") + fmt.Fprintf(os.Stderr, "\n") + fmt.Fprintf(os.Stderr, " "+app+" -listen redir://:1081 -forward \"ws://1.1.1.1:80,vmess://security:uuid@?alterID=10\"\n") + fmt.Fprintf(os.Stderr, " -listen on :1081 as a transparent redirect server, forward all requests via remote ws+vmess server.\n") fmt.Fprintf(os.Stderr, "\n") fmt.Fprintf(os.Stderr, " "+app+" -listen tcptun://:80=2.2.2.2:80 -forward ss://method:pass@1.1.1.1:8443\n") fmt.Fprintf(os.Stderr, " -listen on :80 and forward all requests to 2.2.2.2:80 via remote ss server.\n") diff --git a/dev.go b/dev.go index 5f58979..7a43b10 100644 --- a/dev.go +++ b/dev.go @@ -3,6 +3,7 @@ package main import ( + "fmt" "net/http" _ "net/http/pprof" @@ -11,6 +12,9 @@ import ( func init() { go func() { - http.ListenAndServe(":6060", nil) + err := http.ListenAndServe(":6060", nil) + if err != nil { + fmt.Printf("Create pprof server error: %s\n", err) + } }() } diff --git a/proxy/vmess/chunk.go b/proxy/vmess/chunk.go index f845799..89b71bc 100644 --- a/proxy/vmess/chunk.go +++ b/proxy/vmess/chunk.go @@ -62,47 +62,43 @@ func (w *chunkedWriter) ReadFrom(r io.Reader) (n int64, err error) { type chunkedReader struct { io.Reader - buf []byte - leftover []byte + buf []byte + leftBytes int } // ChunkedReader returns a chunked reader func ChunkedReader(r io.Reader) io.Reader { return &chunkedReader{ Reader: r, - buf: make([]byte, lenSize+maxChunkSize), + buf: make([]byte, lenSize), // NOTE: buf only used to save header bytes now } } func (r *chunkedReader) Read(b []byte) (int, error) { - if len(r.leftover) > 0 { - n := copy(b, r.leftover) - r.leftover = r.leftover[n:] - return n, nil + if r.leftBytes == 0 { + // get length + _, err := io.ReadFull(r.Reader, r.buf[:lenSize]) + if err != nil { + return 0, err + } + + // if length == 0, then this is the end + r.leftBytes = int(binary.BigEndian.Uint16(r.buf[:lenSize])) + if r.leftBytes == 0 { + return 0, nil + } } - // get length - _, err := io.ReadFull(r.Reader, r.buf[:lenSize]) + readLen := len(b) + if readLen > r.leftBytes { + readLen = r.leftBytes + } + + m, err := r.Reader.Read(b[:readLen]) if err != nil { return 0, err } - // if length == 0, then this is the end - len := binary.BigEndian.Uint16(r.buf[:lenSize]) - if len == 0 { - return 0, nil - } - - // get payload - _, err = io.ReadFull(r.Reader, r.buf[:len]) - if err != nil { - return 0, err - } - - m := copy(b, r.buf[:len]) - if m < int(len) { - r.leftover = r.buf[m:len] - } - + r.leftBytes -= m return m, err } diff --git a/proxy/ws/client.go b/proxy/ws/client.go index d5cd5e9..44ae373 100644 --- a/proxy/ws/client.go +++ b/proxy/ws/client.go @@ -2,15 +2,19 @@ package ws import ( "bufio" + "bytes" + "crypto/rand" + "crypto/sha1" + "encoding/base64" "errors" "io" "net" "net/textproto" "strings" - - "github.com/nadoo/glider/common/log" ) +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + // Client ws client type Client struct { host string @@ -41,31 +45,38 @@ func (c *Client) NewConn(rc net.Conn, target string) (*Conn, error) { // Handshake handshakes with the server using HTTP to request a protocol upgrade func (c *Conn) Handshake(host, path string) error { - c.Conn.Write([]byte("GET " + path + " HTTP/1.1\r\n")) - // c.Conn.Write([]byte("Host: 127.0.0.1\r\n")) - c.Conn.Write([]byte("Host: " + host + "\r\n")) - c.Conn.Write([]byte("Upgrade: websocket\r\n")) - c.Conn.Write([]byte("Connection: Upgrade\r\n")) - c.Conn.Write([]byte("Origin: http://" + host + "\r\n")) - c.Conn.Write([]byte("Sec-WebSocket-Key: w4v7O6xFTi36lq3RNcgctw==\r\n")) - c.Conn.Write([]byte("Sec-WebSocket-Protocol: binary\r\n")) - c.Conn.Write([]byte("Sec-WebSocket-Version: 13\r\n")) - c.Conn.Write([]byte("\r\n")) + clientKey := generateClientKey() + + var buf bytes.Buffer + buf.Write([]byte("GET " + path + " HTTP/1.1\r\n")) + buf.Write([]byte("Host: " + host + "\r\n")) + buf.Write([]byte("Upgrade: websocket\r\n")) + buf.Write([]byte("Connection: Upgrade\r\n")) + buf.Write([]byte("Origin: http://" + host + "\r\n")) + buf.Write([]byte("Sec-WebSocket-Key: " + clientKey + "\r\n")) + buf.Write([]byte("Sec-WebSocket-Protocol: binary\r\n")) + buf.Write([]byte("Sec-WebSocket-Version: 13\r\n")) + buf.Write([]byte("\r\n")) + + if _, err := c.Conn.Write(buf.Bytes()); err != nil { + return err + } tpr := textproto.NewReader(bufio.NewReader(c.Conn)) _, code, _, ok := parseFirstLine(tpr) if !ok || code != "101" { - return errors.New("error in ws handshake") + return errors.New("[ws] error in ws handshake parseFirstLine") } - // respHeader, err := tpr.ReadMIMEHeader() - // if err != nil { - // return err - // } + respHeader, err := tpr.ReadMIMEHeader() + if err != nil { + return err + } - // // TODO: verify this - // respHeader.Get("Sec-WebSocket-Accept") - // fmt.Printf("respHeader: %+v\n", respHeader) + serverKey := respHeader.Get("Sec-WebSocket-Accept") + if serverKey != computeServerKey(clientKey) { + return errors.New("[ws] error in ws handshake, got wrong Sec-Websocket-Key") + } return nil } @@ -92,7 +103,7 @@ func parseFirstLine(tp *textproto.Reader) (r1, r2, r3 string, ok bool) { line, err := tp.ReadLine() // log.F("first line: %s", line) if err != nil { - log.F("[http] read first line error:%s", err) + // log.F("[ws] read first line error:%s", err) return } @@ -104,3 +115,16 @@ func parseFirstLine(tp *textproto.Reader) (r1, r2, r3 string, ok bool) { s2 += s1 + 1 return line[:s1], line[s1+1 : s2], line[s2+1:], true } + +func generateClientKey() string { + p := make([]byte, 16) + rand.Read(p) + return base64.StdEncoding.EncodeToString(p) +} + +func computeServerKey(clientKey string) string { + h := sha1.New() + h.Write([]byte(clientKey)) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} diff --git a/proxy/ws/frame.go b/proxy/ws/frame.go index 4df6c4d..3d5b675 100644 --- a/proxy/ws/frame.go +++ b/proxy/ws/frame.go @@ -137,7 +137,7 @@ type frameReader struct { func FrameReader(r io.Reader) io.Reader { return &frameReader{ Reader: r, - buf: make([]byte, defaultFrameSize), + buf: make([]byte, maxFrameHeaderSize), // NOTE: buf only used to save header bytes now } } @@ -155,36 +155,30 @@ func (r *frameReader) Read(b []byte) (int, error) { r.leftBytes = int64(r.buf[1] & 0x7f) switch r.leftBytes { case 126: - io.ReadFull(r.Reader, r.buf[:2]) + _, err := io.ReadFull(r.Reader, r.buf[:2]) + if err != nil { + return 0, err + } r.leftBytes = int64(binary.BigEndian.Uint16(r.buf[0:])) case 127: - io.ReadFull(r.Reader, r.buf[:8]) + _, err := io.ReadFull(r.Reader, r.buf[:8]) + if err != nil { + return 0, err + } r.leftBytes = int64(binary.BigEndian.Uint64(r.buf[0:])) } } - var n, m int - var err error - if r.leftBytes > int64(len(r.buf)) { - if len(b) < len(r.buf) { - n, err = r.Reader.Read(r.buf[:len(b)]) - } else { - n, err = r.Reader.Read(r.buf) - } - } else { - if int64(len(b)) < r.leftBytes { - n, err = io.ReadFull(r.Reader, r.buf[:len(b)]) - } else { - n, err = io.ReadFull(r.Reader, r.buf[:r.leftBytes]) - } + readLen := int64(len(b)) + if readLen > r.leftBytes { + readLen = r.leftBytes } + m, err := r.Reader.Read(b[:readLen]) if err != nil { return 0, err } - m = copy(b, r.buf[:n]) r.leftBytes -= int64(m) - return m, err }