mirror of
https://github.com/nadoo/glider.git
synced 2025-02-23 17:35:40 +08:00
ws: fixed bug in ws codes and now it worked
This commit is contained in:
parent
9856f943ad
commit
ff1fb8c291
1
dev.go
1
dev.go
@ -7,6 +7,7 @@ import (
|
|||||||
_ "net/http/pprof"
|
_ "net/http/pprof"
|
||||||
|
|
||||||
_ "github.com/nadoo/glider/proxy/ws"
|
_ "github.com/nadoo/glider/proxy/ws"
|
||||||
|
// _ "github.com/nadoo/glider/proxy/tproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -2,5 +2,4 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
_ "github.com/nadoo/glider/proxy/redir"
|
_ "github.com/nadoo/glider/proxy/redir"
|
||||||
// _ "github.com/nadoo/glider/proxy/tproxy"
|
|
||||||
)
|
)
|
||||||
|
@ -118,14 +118,14 @@ func (s *HTTP) Serve(c net.Conn) {
|
|||||||
// tell the remote server not to keep alive
|
// tell the remote server not to keep alive
|
||||||
reqHeader.Set("Connection", "close")
|
reqHeader.Set("Connection", "close")
|
||||||
|
|
||||||
url, err := url.ParseRequestURI(requestURI)
|
u, err := url.ParseRequestURI(requestURI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.F("[http] parse request url error: %s", err)
|
log.F("[http] parse request url error: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var tgt = url.Host
|
var tgt = u.Host
|
||||||
if !strings.Contains(url.Host, ":") {
|
if !strings.Contains(u.Host, ":") {
|
||||||
tgt += ":80"
|
tgt += ":80"
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -139,9 +139,9 @@ func (s *HTTP) Serve(c net.Conn) {
|
|||||||
|
|
||||||
// GET http://example.com/a/index.htm HTTP/1.1 -->
|
// GET http://example.com/a/index.htm HTTP/1.1 -->
|
||||||
// GET /a/index.htm HTTP/1.1
|
// GET /a/index.htm HTTP/1.1
|
||||||
url.Scheme = ""
|
u.Scheme = ""
|
||||||
url.Host = ""
|
u.Host = ""
|
||||||
uri := url.String()
|
uri := u.String()
|
||||||
|
|
||||||
var reqBuf bytes.Buffer
|
var reqBuf bytes.Buffer
|
||||||
writeFirstLine(method, uri, proto, &reqBuf)
|
writeFirstLine(method, uri, proto, &reqBuf)
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
// Client ws client
|
// Client ws client
|
||||||
type Client struct {
|
type Client struct {
|
||||||
|
host string
|
||||||
path string
|
path string
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -24,41 +25,29 @@ type Conn struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewClient .
|
// NewClient .
|
||||||
func NewClient() (*Client, error) {
|
func NewClient(host, path string) (*Client, error) {
|
||||||
c := &Client{}
|
if path == "" {
|
||||||
|
path = "/"
|
||||||
|
}
|
||||||
|
c := &Client{host: host, path: path}
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConn .
|
// NewConn .
|
||||||
func (c *Client) NewConn(rc net.Conn, target string) (*Conn, error) {
|
func (c *Client) NewConn(rc net.Conn, target string) (*Conn, error) {
|
||||||
conn := &Conn{Conn: rc}
|
conn := &Conn{Conn: rc}
|
||||||
conn.Handshake()
|
return conn, conn.Handshake(c.host, c.path)
|
||||||
return conn, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handshake handshakes with the server using HTTP to request a protocol upgrade
|
// Handshake handshakes with the server using HTTP to request a protocol upgrade
|
||||||
//
|
func (c *Conn) Handshake(host, path string) error {
|
||||||
// GET /chat HTTP/1.1
|
c.Conn.Write([]byte("GET " + path + " HTTP/1.1\r\n"))
|
||||||
// Host: server.example.com
|
// c.Conn.Write([]byte("Host: 127.0.0.1\r\n"))
|
||||||
// Upgrade: websocket
|
c.Conn.Write([]byte("Host: " + host + "\r\n"))
|
||||||
// Connection: Upgrade
|
|
||||||
// Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
|
|
||||||
// Origin: http://example.com
|
|
||||||
// Sec-WebSocket-Protocol: chat, superchat
|
|
||||||
// Sec-WebSocket-Version: 13
|
|
||||||
//
|
|
||||||
// HTTP/1.1 101 Switching Protocols
|
|
||||||
// Upgrade: websocket
|
|
||||||
// Connection: Upgrade
|
|
||||||
// Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
|
|
||||||
// Sec-WebSocket-Protocol: chat
|
|
||||||
func (c *Conn) Handshake() error {
|
|
||||||
c.Conn.Write([]byte("GET / HTTP/1.1\r\n"))
|
|
||||||
c.Conn.Write([]byte("Host: echo.websocket.org\r\n"))
|
|
||||||
c.Conn.Write([]byte("Upgrade: websocket\r\n"))
|
c.Conn.Write([]byte("Upgrade: websocket\r\n"))
|
||||||
c.Conn.Write([]byte("Connection: Upgrade\r\n"))
|
c.Conn.Write([]byte("Connection: Upgrade\r\n"))
|
||||||
c.Conn.Write([]byte("Origin: http://127.0.0.1\r\n"))
|
c.Conn.Write([]byte("Origin: http://" + host + "\r\n"))
|
||||||
c.Conn.Write([]byte("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"))
|
c.Conn.Write([]byte("Sec-WebSocket-Key: w4v7O6xFTi36lq3RNcgctw==\r\n"))
|
||||||
c.Conn.Write([]byte("Sec-WebSocket-Protocol: binary\r\n"))
|
c.Conn.Write([]byte("Sec-WebSocket-Protocol: binary\r\n"))
|
||||||
c.Conn.Write([]byte("Sec-WebSocket-Version: 13\r\n"))
|
c.Conn.Write([]byte("Sec-WebSocket-Version: 13\r\n"))
|
||||||
c.Conn.Write([]byte("\r\n"))
|
c.Conn.Write([]byte("\r\n"))
|
||||||
@ -69,17 +58,18 @@ func (c *Conn) Handshake() error {
|
|||||||
return errors.New("error in ws handshake")
|
return errors.New("error in ws handshake")
|
||||||
}
|
}
|
||||||
|
|
||||||
respHeader, err := tpr.ReadMIMEHeader()
|
// respHeader, err := tpr.ReadMIMEHeader()
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return err
|
// return err
|
||||||
}
|
// }
|
||||||
|
|
||||||
// TODO: verify this
|
// // TODO: verify this
|
||||||
respHeader.Get("Sec-WebSocket-Accept")
|
// respHeader.Get("Sec-WebSocket-Accept")
|
||||||
// fmt.Printf("respHeader: %+v\n", respHeader)
|
// fmt.Printf("respHeader: %+v\n", respHeader)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||||
if c.writer == nil {
|
if c.writer == nil {
|
||||||
c.writer = FrameWriter(c.Conn)
|
c.writer = FrameWriter(c.Conn)
|
||||||
@ -87,6 +77,7 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
|||||||
|
|
||||||
return c.writer.Write(b)
|
return c.writer.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||||
if c.reader == nil {
|
if c.reader == nil {
|
||||||
c.reader = FrameReader(c.Conn)
|
c.reader = FrameReader(c.Conn)
|
||||||
|
@ -10,18 +10,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
finalBit = 1 << 7
|
finalBit byte = 1 << 7
|
||||||
defaultFrameSize = 1 << 13 // 8192
|
defaultFrameSize = 4096
|
||||||
maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
|
maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask
|
||||||
maskBit = 1 << 7
|
maskBit byte = 1 << 7
|
||||||
opCodeBinary = 2
|
opCodeBinary byte = 2
|
||||||
opClose = 8
|
opClose byte = 8
|
||||||
maskKeyLen = 4
|
maskKeyLen = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
type frame struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
type frameWriter struct {
|
type frameWriter struct {
|
||||||
io.Writer
|
io.Writer
|
||||||
buf []byte
|
buf []byte
|
||||||
@ -46,19 +43,18 @@ func (w *frameWriter) Write(b []byte) (int, error) {
|
|||||||
func (w *frameWriter) ReadFrom(r io.Reader) (n int64, err error) {
|
func (w *frameWriter) ReadFrom(r io.Reader) (n int64, err error) {
|
||||||
for {
|
for {
|
||||||
buf := w.buf
|
buf := w.buf
|
||||||
|
payloadBuf := buf[maxFrameHeaderSize:]
|
||||||
|
|
||||||
nr, er := r.Read(buf)
|
nr, er := r.Read(payloadBuf)
|
||||||
if nr > 0 {
|
if nr > 0 {
|
||||||
n += int64(nr)
|
n += int64(nr)
|
||||||
buf[0] |= finalBit
|
buf[0] = finalBit | opCodeBinary
|
||||||
buf[0] |= opCodeBinary
|
buf[1] = maskBit
|
||||||
buf[1] |= maskBit
|
|
||||||
|
|
||||||
lengthFieldLen := 0
|
lengthFieldLen := 0
|
||||||
switch {
|
switch {
|
||||||
case nr <= 125:
|
case nr <= 125:
|
||||||
buf[1] |= byte(nr)
|
buf[1] |= byte(nr)
|
||||||
lengthFieldLen = 2
|
|
||||||
case nr < 65536:
|
case nr < 65536:
|
||||||
buf[1] |= 126
|
buf[1] |= 126
|
||||||
lengthFieldLen = 2
|
lengthFieldLen = 2
|
||||||
@ -69,13 +65,24 @@ func (w *frameWriter) ReadFrom(r io.Reader) (n int64, err error) {
|
|||||||
binary.BigEndian.PutUint64(buf[2:2+lengthFieldLen], uint64(nr))
|
binary.BigEndian.PutUint64(buf[2:2+lengthFieldLen], uint64(nr))
|
||||||
}
|
}
|
||||||
|
|
||||||
copy(buf[2+lengthFieldLen:], w.maskKey)
|
_, ew := w.Writer.Write(buf[:2+lengthFieldLen])
|
||||||
payloadBuf := buf[2+lengthFieldLen+maskKeyLen:]
|
if ew != nil {
|
||||||
|
err = ew
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
_, ew = w.Writer.Write(w.maskKey)
|
||||||
|
if ew != nil {
|
||||||
|
err = ew
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
payloadBuf = payloadBuf[:nr]
|
||||||
for i := range payloadBuf {
|
for i := range payloadBuf {
|
||||||
payloadBuf[i] = payloadBuf[i] ^ w.maskKey[i%4]
|
payloadBuf[i] = payloadBuf[i] ^ w.maskKey[i%4]
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ew := w.Writer.Write(buf)
|
_, ew = w.Writer.Write(payloadBuf)
|
||||||
if ew != nil {
|
if ew != nil {
|
||||||
err = ew
|
err = ew
|
||||||
break
|
break
|
||||||
|
@ -4,6 +4,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/nadoo/glider/common/log"
|
"github.com/nadoo/glider/common/log"
|
||||||
"github.com/nadoo/glider/proxy"
|
"github.com/nadoo/glider/proxy"
|
||||||
@ -31,7 +32,13 @@ func NewWS(s string, dialer proxy.Dialer) (*WS, error) {
|
|||||||
|
|
||||||
addr := u.Host
|
addr := u.Host
|
||||||
|
|
||||||
client, err := NewClient()
|
colonPos := strings.LastIndex(addr, ":")
|
||||||
|
if colonPos == -1 {
|
||||||
|
colonPos = len(addr)
|
||||||
|
}
|
||||||
|
serverName := addr[:colonPos]
|
||||||
|
|
||||||
|
client, err := NewClient(serverName, u.Path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.F("create ws client err: %s", err)
|
log.F("create ws client err: %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
Loading…
Reference in New Issue
Block a user