diff --git a/dev.go b/dev.go index 5f58979..2aff07f 100644 --- a/dev.go +++ b/dev.go @@ -7,6 +7,7 @@ import ( _ "net/http/pprof" _ "github.com/nadoo/glider/proxy/ws" + // _ "github.com/nadoo/glider/proxy/tproxy" ) func init() { diff --git a/main_linux.go b/main_linux.go index 87c6457..636e616 100644 --- a/main_linux.go +++ b/main_linux.go @@ -2,5 +2,4 @@ package main import ( _ "github.com/nadoo/glider/proxy/redir" - // _ "github.com/nadoo/glider/proxy/tproxy" ) diff --git a/proxy/http/http.go b/proxy/http/http.go index 870cf37..c7da285 100644 --- a/proxy/http/http.go +++ b/proxy/http/http.go @@ -118,14 +118,14 @@ func (s *HTTP) Serve(c net.Conn) { // tell the remote server not to keep alive reqHeader.Set("Connection", "close") - url, err := url.ParseRequestURI(requestURI) + u, err := url.ParseRequestURI(requestURI) if err != nil { log.F("[http] parse request url error: %s", err) return } - var tgt = url.Host - if !strings.Contains(url.Host, ":") { + var tgt = u.Host + if !strings.Contains(u.Host, ":") { tgt += ":80" } @@ -139,9 +139,9 @@ func (s *HTTP) Serve(c net.Conn) { // GET http://example.com/a/index.htm HTTP/1.1 --> // GET /a/index.htm HTTP/1.1 - url.Scheme = "" - url.Host = "" - uri := url.String() + u.Scheme = "" + u.Host = "" + uri := u.String() var reqBuf bytes.Buffer writeFirstLine(method, uri, proto, &reqBuf) diff --git a/proxy/ws/client.go b/proxy/ws/client.go index 9035469..d5cd5e9 100644 --- a/proxy/ws/client.go +++ b/proxy/ws/client.go @@ -13,6 +13,7 @@ import ( // Client ws client type Client struct { + host string path string } @@ -24,41 +25,29 @@ type Conn struct { } // NewClient . -func NewClient() (*Client, error) { - c := &Client{} +func NewClient(host, path string) (*Client, error) { + if path == "" { + path = "/" + } + c := &Client{host: host, path: path} return c, nil } // NewConn . func (c *Client) NewConn(rc net.Conn, target string) (*Conn, error) { conn := &Conn{Conn: rc} - conn.Handshake() - return conn, nil + return conn, conn.Handshake(c.host, c.path) } // Handshake handshakes with the server using HTTP to request a protocol upgrade -// -// GET /chat HTTP/1.1 -// Host: server.example.com -// Upgrade: websocket -// Connection: Upgrade -// Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== -// Origin: http://example.com -// Sec-WebSocket-Protocol: chat, superchat -// Sec-WebSocket-Version: 13 -// -// HTTP/1.1 101 Switching Protocols -// Upgrade: websocket -// Connection: Upgrade -// Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= -// Sec-WebSocket-Protocol: chat -func (c *Conn) Handshake() error { - c.Conn.Write([]byte("GET / HTTP/1.1\r\n")) - c.Conn.Write([]byte("Host: echo.websocket.org\r\n")) +func (c *Conn) Handshake(host, path string) error { + c.Conn.Write([]byte("GET " + path + " HTTP/1.1\r\n")) + // c.Conn.Write([]byte("Host: 127.0.0.1\r\n")) + c.Conn.Write([]byte("Host: " + host + "\r\n")) c.Conn.Write([]byte("Upgrade: websocket\r\n")) c.Conn.Write([]byte("Connection: Upgrade\r\n")) - c.Conn.Write([]byte("Origin: http://127.0.0.1\r\n")) - c.Conn.Write([]byte("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n")) + c.Conn.Write([]byte("Origin: http://" + host + "\r\n")) + c.Conn.Write([]byte("Sec-WebSocket-Key: w4v7O6xFTi36lq3RNcgctw==\r\n")) c.Conn.Write([]byte("Sec-WebSocket-Protocol: binary\r\n")) c.Conn.Write([]byte("Sec-WebSocket-Version: 13\r\n")) c.Conn.Write([]byte("\r\n")) @@ -69,17 +58,18 @@ func (c *Conn) Handshake() error { return errors.New("error in ws handshake") } - respHeader, err := tpr.ReadMIMEHeader() - if err != nil { - return err - } + // respHeader, err := tpr.ReadMIMEHeader() + // if err != nil { + // return err + // } - // TODO: verify this - respHeader.Get("Sec-WebSocket-Accept") + // // TODO: verify this + // respHeader.Get("Sec-WebSocket-Accept") // fmt.Printf("respHeader: %+v\n", respHeader) return nil } + func (c *Conn) Write(b []byte) (n int, err error) { if c.writer == nil { c.writer = FrameWriter(c.Conn) @@ -87,6 +77,7 @@ func (c *Conn) Write(b []byte) (n int, err error) { return c.writer.Write(b) } + func (c *Conn) Read(b []byte) (n int, err error) { if c.reader == nil { c.reader = FrameReader(c.Conn) diff --git a/proxy/ws/frame.go b/proxy/ws/frame.go index 674d89c..a38e490 100644 --- a/proxy/ws/frame.go +++ b/proxy/ws/frame.go @@ -10,18 +10,15 @@ import ( ) const ( - finalBit = 1 << 7 - defaultFrameSize = 1 << 13 // 8192 - maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask - maskBit = 1 << 7 - opCodeBinary = 2 - opClose = 8 - maskKeyLen = 4 + finalBit byte = 1 << 7 + defaultFrameSize = 4096 + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask + maskBit byte = 1 << 7 + opCodeBinary byte = 2 + opClose byte = 8 + maskKeyLen = 4 ) -type frame struct { -} - type frameWriter struct { io.Writer buf []byte @@ -46,19 +43,18 @@ func (w *frameWriter) Write(b []byte) (int, error) { func (w *frameWriter) ReadFrom(r io.Reader) (n int64, err error) { for { buf := w.buf + payloadBuf := buf[maxFrameHeaderSize:] - nr, er := r.Read(buf) + nr, er := r.Read(payloadBuf) if nr > 0 { n += int64(nr) - buf[0] |= finalBit - buf[0] |= opCodeBinary - buf[1] |= maskBit + buf[0] = finalBit | opCodeBinary + buf[1] = maskBit lengthFieldLen := 0 switch { case nr <= 125: buf[1] |= byte(nr) - lengthFieldLen = 2 case nr < 65536: buf[1] |= 126 lengthFieldLen = 2 @@ -69,13 +65,24 @@ func (w *frameWriter) ReadFrom(r io.Reader) (n int64, err error) { binary.BigEndian.PutUint64(buf[2:2+lengthFieldLen], uint64(nr)) } - copy(buf[2+lengthFieldLen:], w.maskKey) - payloadBuf := buf[2+lengthFieldLen+maskKeyLen:] + _, ew := w.Writer.Write(buf[:2+lengthFieldLen]) + if ew != nil { + err = ew + break + } + + _, ew = w.Writer.Write(w.maskKey) + if ew != nil { + err = ew + break + } + + payloadBuf = payloadBuf[:nr] for i := range payloadBuf { payloadBuf[i] = payloadBuf[i] ^ w.maskKey[i%4] } - _, ew := w.Writer.Write(buf) + _, ew = w.Writer.Write(payloadBuf) if ew != nil { err = ew break diff --git a/proxy/ws/ws.go b/proxy/ws/ws.go index 2500f2a..14d6162 100644 --- a/proxy/ws/ws.go +++ b/proxy/ws/ws.go @@ -4,6 +4,7 @@ import ( "errors" "net" "net/url" + "strings" "github.com/nadoo/glider/common/log" "github.com/nadoo/glider/proxy" @@ -31,7 +32,13 @@ func NewWS(s string, dialer proxy.Dialer) (*WS, error) { addr := u.Host - client, err := NewClient() + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + serverName := addr[:colonPos] + + client, err := NewClient(serverName, u.Path) if err != nil { log.F("create ws client err: %s", err) return nil, err