diff --git a/dev.go b/dev.go new file mode 100644 index 0000000..5f58979 --- /dev/null +++ b/dev.go @@ -0,0 +1,16 @@ +//+build dev + +package main + +import ( + "net/http" + _ "net/http/pprof" + + _ "github.com/nadoo/glider/proxy/ws" +) + +func init() { + go func() { + http.ListenAndServe(":6060", nil) + }() +} diff --git a/pprof.go b/pprof.go deleted file mode 100644 index 75e6eac..0000000 --- a/pprof.go +++ /dev/null @@ -1,12 +0,0 @@ -//+build pprof - -package main - -import _ "net/http/pprof" -import "net/http" - -func init() { - go func() { - http.ListenAndServe(":6060", nil) - }() -} diff --git a/proxy/ws/client.go b/proxy/ws/client.go new file mode 100644 index 0000000..9035469 --- /dev/null +++ b/proxy/ws/client.go @@ -0,0 +1,115 @@ +package ws + +import ( + "bufio" + "errors" + "io" + "net" + "net/textproto" + "strings" + + "github.com/nadoo/glider/common/log" +) + +// Client ws client +type Client struct { + path string +} + +// Conn is a connection to ws server +type Conn struct { + net.Conn + reader io.Reader + writer io.Writer +} + +// NewClient . +func NewClient() (*Client, error) { + c := &Client{} + return c, nil +} + +// NewConn . +func (c *Client) NewConn(rc net.Conn, target string) (*Conn, error) { + conn := &Conn{Conn: rc} + conn.Handshake() + return conn, nil +} + +// Handshake handshakes with the server using HTTP to request a protocol upgrade +// +// GET /chat HTTP/1.1 +// Host: server.example.com +// Upgrade: websocket +// Connection: Upgrade +// Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ== +// Origin: http://example.com +// Sec-WebSocket-Protocol: chat, superchat +// Sec-WebSocket-Version: 13 +// +// HTTP/1.1 101 Switching Protocols +// Upgrade: websocket +// Connection: Upgrade +// Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo= +// Sec-WebSocket-Protocol: chat +func (c *Conn) Handshake() error { + c.Conn.Write([]byte("GET / HTTP/1.1\r\n")) + c.Conn.Write([]byte("Host: echo.websocket.org\r\n")) + c.Conn.Write([]byte("Upgrade: websocket\r\n")) + c.Conn.Write([]byte("Connection: Upgrade\r\n")) + c.Conn.Write([]byte("Origin: http://127.0.0.1\r\n")) + c.Conn.Write([]byte("Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n")) + c.Conn.Write([]byte("Sec-WebSocket-Protocol: binary\r\n")) + c.Conn.Write([]byte("Sec-WebSocket-Version: 13\r\n")) + c.Conn.Write([]byte("\r\n")) + + tpr := textproto.NewReader(bufio.NewReader(c.Conn)) + _, code, _, ok := parseFirstLine(tpr) + if !ok || code != "101" { + return errors.New("error in ws handshake") + } + + respHeader, err := tpr.ReadMIMEHeader() + if err != nil { + return err + } + + // TODO: verify this + respHeader.Get("Sec-WebSocket-Accept") + // fmt.Printf("respHeader: %+v\n", respHeader) + + return nil +} +func (c *Conn) Write(b []byte) (n int, err error) { + if c.writer == nil { + c.writer = FrameWriter(c.Conn) + } + + return c.writer.Write(b) +} +func (c *Conn) Read(b []byte) (n int, err error) { + if c.reader == nil { + c.reader = FrameReader(c.Conn) + } + + return c.reader.Read(b) +} + +// parseFirstLine parses "GET /foo HTTP/1.1" OR "HTTP/1.1 200 OK" into its three parts. +// TODO: move to seperate http lib package for reuse(also for http proxy module) +func parseFirstLine(tp *textproto.Reader) (r1, r2, r3 string, ok bool) { + line, err := tp.ReadLine() + // log.F("first line: %s", line) + if err != nil { + log.F("[http] read first line error:%s", err) + return + } + + s1 := strings.Index(line, " ") + s2 := strings.Index(line[s1+1:], " ") + if s1 < 0 || s2 < 0 { + return + } + s2 += s1 + 1 + return line[:s1], line[s1+1 : s2], line[s2+1:], true +} diff --git a/proxy/ws/frame.go b/proxy/ws/frame.go new file mode 100644 index 0000000..674d89c --- /dev/null +++ b/proxy/ws/frame.go @@ -0,0 +1,149 @@ +// https://tools.ietf.org/html/rfc6455 + +package ws + +import ( + "bytes" + "encoding/binary" + "io" + "math/rand" +) + +const ( + finalBit = 1 << 7 + defaultFrameSize = 1 << 13 // 8192 + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask + maskBit = 1 << 7 + opCodeBinary = 2 + opClose = 8 + maskKeyLen = 4 +) + +type frame struct { +} + +type frameWriter struct { + io.Writer + buf []byte + maskKey []byte +} + +// FrameWriter returns a frame writer +func FrameWriter(w io.Writer) io.Writer { + n := rand.Uint32() + return &frameWriter{ + Writer: w, + buf: make([]byte, maxFrameHeaderSize+defaultFrameSize), + maskKey: []byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)}, + } +} + +func (w *frameWriter) Write(b []byte) (int, error) { + n, err := w.ReadFrom(bytes.NewBuffer(b)) + return int(n), err +} + +func (w *frameWriter) ReadFrom(r io.Reader) (n int64, err error) { + for { + buf := w.buf + + nr, er := r.Read(buf) + if nr > 0 { + n += int64(nr) + buf[0] |= finalBit + buf[0] |= opCodeBinary + buf[1] |= maskBit + + lengthFieldLen := 0 + switch { + case nr <= 125: + buf[1] |= byte(nr) + lengthFieldLen = 2 + case nr < 65536: + buf[1] |= 126 + lengthFieldLen = 2 + binary.BigEndian.PutUint16(buf[2:2+lengthFieldLen], uint16(nr)) + default: + buf[1] |= 127 + lengthFieldLen = 8 + binary.BigEndian.PutUint64(buf[2:2+lengthFieldLen], uint64(nr)) + } + + copy(buf[2+lengthFieldLen:], w.maskKey) + payloadBuf := buf[2+lengthFieldLen+maskKeyLen:] + for i := range payloadBuf { + payloadBuf[i] = payloadBuf[i] ^ w.maskKey[i%4] + } + + _, ew := w.Writer.Write(buf) + if ew != nil { + err = ew + break + } + } + + if er != nil { + if er != io.EOF { // ignore EOF as per io.ReaderFrom contract + err = er + } + break + } + + } + + return n, err +} + +type frameReader struct { + io.Reader + buf []byte + leftover []byte +} + +// FrameReader returns a chunked reader +func FrameReader(r io.Reader) io.Reader { + return &frameReader{ + Reader: r, + buf: make([]byte, defaultFrameSize), + } +} + +func (r *frameReader) Read(b []byte) (int, error) { + if len(r.leftover) > 0 { + n := copy(b, r.leftover) + r.leftover = r.leftover[n:] + return n, nil + } + + // get msg header + _, err := io.ReadFull(r.Reader, r.buf[:2]) + if err != nil { + return 0, err + } + + // final := r.buf[0]&finalBit != 0 + // frameType := int(r.buf[0] & 0xf) + // mask := r.buf[1]&maskBit != 0 + len := int64(r.buf[1] & 0x7f) + switch len { + case 126: + io.ReadFull(r.Reader, r.buf[:2]) + len = int64(binary.BigEndian.Uint16(r.buf[0:])) + case 127: + io.ReadFull(r.Reader, r.buf[:8]) + len = int64(binary.BigEndian.Uint64(r.buf[0:])) + } + + // get payload + _, err = io.ReadFull(r.Reader, r.buf[:len]) + if err != nil { + return 0, err + } + + m := copy(b, r.buf[:len]) + if m < int(len) { + r.leftover = r.buf[m:len] + } + + return m, err +} diff --git a/proxy/ws/ws.go b/proxy/ws/ws.go new file mode 100644 index 0000000..2500f2a --- /dev/null +++ b/proxy/ws/ws.go @@ -0,0 +1,78 @@ +package ws + +import ( + "errors" + "net" + "net/url" + + "github.com/nadoo/glider/common/log" + "github.com/nadoo/glider/proxy" +) + +// WS . +type WS struct { + dialer proxy.Dialer + addr string + + client *Client +} + +func init() { + proxy.RegisterDialer("ws", NewWSDialer) +} + +// NewWS returns a websocket proxy. +func NewWS(s string, dialer proxy.Dialer) (*WS, error) { + u, err := url.Parse(s) + if err != nil { + log.F("parse url err: %s", err) + return nil, err + } + + addr := u.Host + + client, err := NewClient() + if err != nil { + log.F("create ws client err: %s", err) + return nil, err + } + + p := &WS{ + dialer: dialer, + addr: addr, + client: client, + } + + return p, nil +} + +// NewWSDialer returns a ws proxy dialer. +func NewWSDialer(s string, dialer proxy.Dialer) (proxy.Dialer, error) { + return NewWS(s, dialer) +} + +// Addr returns forwarder's address +func (s *WS) Addr() string { + if s.addr == "" { + return s.dialer.Addr() + } + return s.addr +} + +// NextDialer returns the next dialer +func (s *WS) NextDialer(dstAddr string) proxy.Dialer { return s.dialer.NextDialer(dstAddr) } + +// Dial connects to the address addr on the network net via the proxy. +func (s *WS) Dial(network, addr string) (net.Conn, error) { + rc, err := s.dialer.Dial("tcp", s.addr) + if err != nil { + return nil, err + } + + return s.client.NewConn(rc, addr) +} + +// DialUDP connects to the given address via the proxy. +func (s *WS) DialUDP(network, addr string) (net.PacketConn, net.Addr, error) { + return nil, nil, errors.New("ws client does not support udp now") +}