ws: fixed bug in ws codes and now it worked

This commit is contained in:
nadoo 2018-07-22 18:21:27 +08:00
parent 9856f943ad
commit ff1fb8c291
6 changed files with 61 additions and 56 deletions

1
dev.go
View File

@ -7,6 +7,7 @@ import (
_ "net/http/pprof" _ "net/http/pprof"
_ "github.com/nadoo/glider/proxy/ws" _ "github.com/nadoo/glider/proxy/ws"
// _ "github.com/nadoo/glider/proxy/tproxy"
) )
func init() { func init() {

View File

@ -2,5 +2,4 @@ package main
import ( import (
_ "github.com/nadoo/glider/proxy/redir" _ "github.com/nadoo/glider/proxy/redir"
// _ "github.com/nadoo/glider/proxy/tproxy"
) )

View File

@ -118,14 +118,14 @@ func (s *HTTP) Serve(c net.Conn) {
// tell the remote server not to keep alive // tell the remote server not to keep alive
reqHeader.Set("Connection", "close") reqHeader.Set("Connection", "close")
url, err := url.ParseRequestURI(requestURI) u, err := url.ParseRequestURI(requestURI)
if err != nil { if err != nil {
log.F("[http] parse request url error: %s", err) log.F("[http] parse request url error: %s", err)
return return
} }
var tgt = url.Host var tgt = u.Host
if !strings.Contains(url.Host, ":") { if !strings.Contains(u.Host, ":") {
tgt += ":80" tgt += ":80"
} }
@ -139,9 +139,9 @@ func (s *HTTP) Serve(c net.Conn) {
// GET http://example.com/a/index.htm HTTP/1.1 --> // GET http://example.com/a/index.htm HTTP/1.1 -->
// GET /a/index.htm HTTP/1.1 // GET /a/index.htm HTTP/1.1
url.Scheme = "" u.Scheme = ""
url.Host = "" u.Host = ""
uri := url.String() uri := u.String()
var reqBuf bytes.Buffer var reqBuf bytes.Buffer
writeFirstLine(method, uri, proto, &reqBuf) writeFirstLine(method, uri, proto, &reqBuf)

View File

@ -13,6 +13,7 @@ import (
// Client ws client // Client ws client
type Client struct { type Client struct {
host string
path string path string
} }
@ -24,41 +25,29 @@ type Conn struct {
} }
// NewClient . // NewClient .
func NewClient() (*Client, error) { func NewClient(host, path string) (*Client, error) {
c := &Client{} if path == "" {
path = "/"
}
c := &Client{host: host, path: path}
return c, nil return c, nil
} }
// NewConn . // NewConn .
func (c *Client) NewConn(rc net.Conn, target string) (*Conn, error) { func (c *Client) NewConn(rc net.Conn, target string) (*Conn, error) {
conn := &Conn{Conn: rc} conn := &Conn{Conn: rc}
conn.Handshake() return conn, conn.Handshake(c.host, c.path)
return conn, nil
} }
// Handshake handshakes with the server using HTTP to request a protocol upgrade // Handshake handshakes with the server using HTTP to request a protocol upgrade
// func (c *Conn) Handshake(host, path string) error {
// GET /chat HTTP/1.1 c.Conn.Write([]byte("GET " + path + " HTTP/1.1\r\n"))
// Host: server.example.com // c.Conn.Write([]byte("Host: 127.0.0.1\r\n"))
// Upgrade: websocket c.Conn.Write([]byte("Host: " + host + "\r\n"))
// Connection: Upgrade
// Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
// Origin: http://example.com
// Sec-WebSocket-Protocol: chat, superchat
// Sec-WebSocket-Version: 13
//
// HTTP/1.1 101 Switching Protocols
// Upgrade: websocket
// Connection: Upgrade
// Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
// Sec-WebSocket-Protocol: chat
func (c *Conn) Handshake() error {
c.Conn.Write([]byte("GET / HTTP/1.1\r\n"))
c.Conn.Write([]byte("Host: echo.websocket.org\r\n"))
c.Conn.Write([]byte("Upgrade: websocket\r\n")) c.Conn.Write([]byte("Upgrade: websocket\r\n"))
c.Conn.Write([]byte("Connection: Upgrade\r\n")) c.Conn.Write([]byte("Connection: Upgrade\r\n"))
c.Conn.Write([]byte("Origin: http://127.0.0.1\r\n")) c.Conn.Write([]byte("Origin: http://" + host + "\r\n"))
c.Conn.Write([]byte("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n")) c.Conn.Write([]byte("Sec-WebSocket-Key: w4v7O6xFTi36lq3RNcgctw==\r\n"))
c.Conn.Write([]byte("Sec-WebSocket-Protocol: binary\r\n")) c.Conn.Write([]byte("Sec-WebSocket-Protocol: binary\r\n"))
c.Conn.Write([]byte("Sec-WebSocket-Version: 13\r\n")) c.Conn.Write([]byte("Sec-WebSocket-Version: 13\r\n"))
c.Conn.Write([]byte("\r\n")) c.Conn.Write([]byte("\r\n"))
@ -69,17 +58,18 @@ func (c *Conn) Handshake() error {
return errors.New("error in ws handshake") return errors.New("error in ws handshake")
} }
respHeader, err := tpr.ReadMIMEHeader() // respHeader, err := tpr.ReadMIMEHeader()
if err != nil { // if err != nil {
return err // return err
} // }
// TODO: verify this // // TODO: verify this
respHeader.Get("Sec-WebSocket-Accept") // respHeader.Get("Sec-WebSocket-Accept")
// fmt.Printf("respHeader: %+v\n", respHeader) // fmt.Printf("respHeader: %+v\n", respHeader)
return nil return nil
} }
func (c *Conn) Write(b []byte) (n int, err error) { func (c *Conn) Write(b []byte) (n int, err error) {
if c.writer == nil { if c.writer == nil {
c.writer = FrameWriter(c.Conn) c.writer = FrameWriter(c.Conn)
@ -87,6 +77,7 @@ func (c *Conn) Write(b []byte) (n int, err error) {
return c.writer.Write(b) return c.writer.Write(b)
} }
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) {
if c.reader == nil { if c.reader == nil {
c.reader = FrameReader(c.Conn) c.reader = FrameReader(c.Conn)

View File

@ -10,18 +10,15 @@ import (
) )
const ( const (
finalBit = 1 << 7 finalBit byte = 1 << 7
defaultFrameSize = 1 << 13 // 8192 defaultFrameSize = 4096
maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
maskBit = 1 << 7 maskBit byte = 1 << 7
opCodeBinary = 2 opCodeBinary byte = 2
opClose = 8 opClose byte = 8
maskKeyLen = 4 maskKeyLen = 4
) )
type frame struct {
}
type frameWriter struct { type frameWriter struct {
io.Writer io.Writer
buf []byte buf []byte
@ -46,19 +43,18 @@ func (w *frameWriter) Write(b []byte) (int, error) {
func (w *frameWriter) ReadFrom(r io.Reader) (n int64, err error) { func (w *frameWriter) ReadFrom(r io.Reader) (n int64, err error) {
for { for {
buf := w.buf buf := w.buf
payloadBuf := buf[maxFrameHeaderSize:]
nr, er := r.Read(buf) nr, er := r.Read(payloadBuf)
if nr > 0 { if nr > 0 {
n += int64(nr) n += int64(nr)
buf[0] |= finalBit buf[0] = finalBit | opCodeBinary
buf[0] |= opCodeBinary buf[1] = maskBit
buf[1] |= maskBit
lengthFieldLen := 0 lengthFieldLen := 0
switch { switch {
case nr <= 125: case nr <= 125:
buf[1] |= byte(nr) buf[1] |= byte(nr)
lengthFieldLen = 2
case nr < 65536: case nr < 65536:
buf[1] |= 126 buf[1] |= 126
lengthFieldLen = 2 lengthFieldLen = 2
@ -69,13 +65,24 @@ func (w *frameWriter) ReadFrom(r io.Reader) (n int64, err error) {
binary.BigEndian.PutUint64(buf[2:2+lengthFieldLen], uint64(nr)) binary.BigEndian.PutUint64(buf[2:2+lengthFieldLen], uint64(nr))
} }
copy(buf[2+lengthFieldLen:], w.maskKey) _, ew := w.Writer.Write(buf[:2+lengthFieldLen])
payloadBuf := buf[2+lengthFieldLen+maskKeyLen:] if ew != nil {
err = ew
break
}
_, ew = w.Writer.Write(w.maskKey)
if ew != nil {
err = ew
break
}
payloadBuf = payloadBuf[:nr]
for i := range payloadBuf { for i := range payloadBuf {
payloadBuf[i] = payloadBuf[i] ^ w.maskKey[i%4] payloadBuf[i] = payloadBuf[i] ^ w.maskKey[i%4]
} }
_, ew := w.Writer.Write(buf) _, ew = w.Writer.Write(payloadBuf)
if ew != nil { if ew != nil {
err = ew err = ew
break break

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"net" "net"
"net/url" "net/url"
"strings"
"github.com/nadoo/glider/common/log" "github.com/nadoo/glider/common/log"
"github.com/nadoo/glider/proxy" "github.com/nadoo/glider/proxy"
@ -31,7 +32,13 @@ func NewWS(s string, dialer proxy.Dialer) (*WS, error) {
addr := u.Host addr := u.Host
client, err := NewClient() colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
serverName := addr[:colonPos]
client, err := NewClient(serverName, u.Path)
if err != nil { if err != nil {
log.F("create ws client err: %s", err) log.F("create ws client err: %s", err)
return nil, err return nil, err