ws: optimized the frame implementation

This commit is contained in:
nadoo 2020-10-20 20:28:35 +08:00
parent beec9d205f
commit cd78995cd4
6 changed files with 70 additions and 98 deletions

View File

@ -59,7 +59,7 @@ we can set up local listeners as proxy servers, and forward requests to internet
|tls |√| |√| |transport client & server
|kcp | |√|√| |transport client & server
|unix |√| |√| |transport client & server
|websocket |√| |√| |transport client only
|websocket |√| |√| |transport client & server
|simple-obfs | | |√| |transport client only
|tcptun |√| | | |transport server only
|udptun | |√| | |transport server only

4
go.mod
View File

@ -11,9 +11,9 @@ require (
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
github.com/xtaci/kcp-go/v5 v5.6.1
golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897
golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0 // indirect
golang.org/x/net v0.0.0-20201020065357-d65d470038a5 // indirect
golang.org/x/sys v0.0.0-20201018230417-eeed37f84f13 // indirect
golang.org/x/tools v0.0.0-20201017001424-6003fad69a88 // indirect
golang.org/x/tools v0.0.0-20201019175715-b894a3290fff // indirect
gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b // indirect
)

8
go.sum
View File

@ -141,8 +141,8 @@ golang.org/x/net v0.0.0-20200707034311-ab3426394381 h1:VXak5I6aEWmAXeQjA+QSZzlgN
golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0 h1:5kGOVHlq0euqwzgTC9Vu15p6fV1Wi0ArVi8da2urnVg=
golang.org/x/net v0.0.0-20201016165138-7b1cca2348c0/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20201020065357-d65d470038a5 h1:KrxvpY64uUzANd9wKWr6ZAsufiii93XnvXaeikyCJ2g=
golang.org/x/net v0.0.0-20201020065357-d65d470038a5/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -177,8 +177,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn
golang.org/x/tools v0.0.0-20200425043458-8463f397d07c h1:iHhCR0b26amDCiiO+kBguKZom9aMF+NrFxh9zeKR/XU=
golang.org/x/tools v0.0.0-20200425043458-8463f397d07c/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20200808161706-5bf02b21f123/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA=
golang.org/x/tools v0.0.0-20201017001424-6003fad69a88 h1:ZB1XYzdDo7c/O48jzjMkvIjnC120Z9/CwgDWhePjQdQ=
golang.org/x/tools v0.0.0-20201017001424-6003fad69a88/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
golang.org/x/tools v0.0.0-20201019175715-b894a3290fff h1:HiwHyqQ9ttqCHuTa++R4wNxOg6MY1hduSDT8j2aXoMM=
golang.org/x/tools v0.0.0-20201019175715-b894a3290fff/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=

View File

@ -103,14 +103,14 @@ func (c *ClientConn) Handshake(host, path string) error {
func (c *ClientConn) Write(b []byte) (n int, err error) {
if c.writer == nil {
c.writer = FrameWriter(c.Conn, true)
c.writer = FrameWriter(c.Conn, false)
}
return c.writer.Write(b)
}
func (c *ClientConn) Read(b []byte) (n int, err error) {
if c.reader == nil {
c.reader = FrameReader(c.Conn, true)
c.reader = FrameReader(c.Conn, false)
}
return c.reader.Read(b)
}

View File

@ -23,10 +23,11 @@
package ws
import (
"bytes"
"encoding/binary"
"io"
"math/rand"
"github.com/nadoo/glider/pool"
)
const (
@ -43,97 +44,67 @@ const (
type frameWriter struct {
io.Writer
buf []byte
client bool
maskKey [4]byte
header [maxHeaderSize]byte
server bool
maskKey [4]byte
maskOffset int
}
// FrameWriter returns a frame writer.
func FrameWriter(w io.Writer, client bool) io.Writer {
func FrameWriter(w io.Writer, server bool) io.Writer {
n := rand.Uint32()
return &frameWriter{
Writer: w,
buf: make([]byte, maxHeaderSize+defaultFrameSize),
client: client,
server: server,
maskKey: [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)},
}
}
func (w *frameWriter) Write(b []byte) (int, error) {
n, err := w.ReadFrom(bytes.NewBuffer(b))
return int(n), err
}
func (w *frameWriter) ReadFrom(r io.Reader) (n int64, err error) {
for {
buf := w.buf
payloadBuf := buf[maxHeaderSize:]
nr, er := r.Read(payloadBuf)
if nr > 0 {
n += int64(nr)
buf[0] = opCodeBinary
buf[1] = 0
if w.client {
buf[0] |= finalBit
buf[1] = maskBit
}
lengthFieldLen := 0
switch {
case nr <= 125:
buf[1] |= byte(nr)
case nr < 65536:
buf[1] |= 126
lengthFieldLen = 2
binary.BigEndian.PutUint16(buf[2:2+lengthFieldLen], uint16(nr))
default:
buf[1] |= 127
lengthFieldLen = 8
binary.BigEndian.PutUint64(buf[2:2+lengthFieldLen], uint64(nr))
}
// header and length
_, ew := w.Writer.Write(buf[:2+lengthFieldLen])
if ew != nil {
err = ew
break
}
payloadBuf = payloadBuf[:nr]
if w.client {
// maskkey
_, ew = w.Writer.Write(w.maskKey[:])
if ew != nil {
err = ew
break
}
// payload mask
for i := range payloadBuf {
payloadBuf[i] = payloadBuf[i] ^ w.maskKey[i%4]
}
}
_, ew = w.Writer.Write(payloadBuf)
if ew != nil {
err = ew
break
}
}
if er != nil {
if er != io.EOF { // ignore EOF as per io.ReaderFrom contract
err = er
}
break
}
hdr := w.header
hdr[0], hdr[1] = opCodeBinary|finalBit, 0
if !w.server {
hdr[1] = maskBit
}
return n, err
nPayload, lenFieldLen := len(b), 0
switch {
case nPayload <= 125:
hdr[1] |= byte(nPayload)
case nPayload < 65536:
hdr[1] |= 126
lenFieldLen = 2
binary.BigEndian.PutUint16(hdr[2:2+lenFieldLen], uint16(nPayload))
default:
hdr[1] |= 127
lenFieldLen = 8
binary.BigEndian.PutUint64(hdr[2:2+lenFieldLen], uint64(nPayload))
}
// header and length
_, err := w.Writer.Write(hdr[:2+lenFieldLen])
if err != nil {
return 0, err
}
if w.server {
return w.Writer.Write(b)
}
buf := pool.GetBuffer(nPayload)
pool.PutBuffer(buf)
_, err = w.Writer.Write(w.maskKey[:])
if err != nil {
return 0, err
}
// payload mask
for i := 0; i < nPayload; i++ {
buf[i] = b[i] ^ w.maskKey[i%4]
}
return w.Writer.Write(buf)
}
type frameReader struct {
@ -146,22 +117,21 @@ type frameReader struct {
}
// FrameReader returns a chunked reader.
func FrameReader(r io.Reader, client bool) io.Reader {
return &frameReader{Reader: r, server: !client}
func FrameReader(r io.Reader, server bool) io.Reader {
return &frameReader{Reader: r, server: server}
}
func (r *frameReader) Read(b []byte) (int, error) {
if r.left == 0 {
// get msg header
_, err := io.ReadFull(r.Reader, r.buf[:2])
if err != nil {
return 0, err
}
// final := r.buf[0]&finalBit != 0
// final := r.buf[0]&finalBit == finalBit
// frameType := int(r.buf[0] & 0xf)
// r.mask = r.buf[1]&maskBit != 0
// r.mask = r.buf[1]&maskBit == maskBit
r.left = int64(r.buf[1] & 0x7f)
switch r.left {
@ -193,9 +163,9 @@ func (r *frameReader) Read(b []byte) (int, error) {
readLen = r.left
}
m, err := r.Reader.Read(b[:readLen])
m, err := io.ReadFull(r.Reader, b[:readLen])
if err != nil {
return 0, err
return m, err
}
if r.server {

View File

@ -3,6 +3,7 @@ package ws
import (
"bufio"
"errors"
"fmt"
"io"
"net"
"net/textproto"
@ -70,6 +71,7 @@ func (s *WS) Serve(c net.Conn) {
if s.server != nil {
sc, err := s.NewServerConn(c)
if err != nil {
log.F("[ws] handshake error: %s", err)
return
}
s.server.Serve(sc)
@ -109,7 +111,7 @@ func (c *ServerConn) Handshake(host, path string) error {
}
if reqHeader.Get("Host") != host {
return errors.New("[ws] got wrong host")
return fmt.Errorf("[ws] got wrong host: %s, expected: %s", reqHeader.Get("Host"), host)
}
clientKey := reqHeader.Get("Sec-WebSocket-Key")
@ -131,14 +133,14 @@ func (c *ServerConn) Handshake(host, path string) error {
func (c *ServerConn) Write(b []byte) (n int, err error) {
if c.writer == nil {
c.writer = FrameWriter(c.Conn, false)
c.writer = FrameWriter(c.Conn, true)
}
return c.writer.Write(b)
}
func (c *ServerConn) Read(b []byte) (n int, err error) {
if c.reader == nil {
c.reader = FrameReader(c.Conn, false)
c.reader = FrameReader(c.Conn, true)
}
return c.reader.Read(b)
}