diff --git a/proxy/ws/frame.go b/proxy/ws/frame.go index c336039..4df6c4d 100644 --- a/proxy/ws/frame.go +++ b/proxy/ws/frame.go @@ -30,13 +30,16 @@ import ( ) const ( - finalBit byte = 1 << 7 - defaultFrameSize = 4096 - maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask - maskBit byte = 1 << 7 - opCodeBinary byte = 2 - opClose byte = 8 - maskKeyLen = 4 + defaultFrameSize = 4096 + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask + maskKeyLen = 4 +) + +const ( + finalBit byte = 1 << 7 + maskBit byte = 1 << 7 + opCodeBinary byte = 2 + opClose byte = 8 ) type frameWriter struct { @@ -85,18 +88,21 @@ func (w *frameWriter) ReadFrom(r io.Reader) (n int64, err error) { binary.BigEndian.PutUint64(buf[2:2+lengthFieldLen], uint64(nr)) } + // header and length _, ew := w.Writer.Write(buf[:2+lengthFieldLen]) if ew != nil { err = ew break } + // maskkey _, ew = w.Writer.Write(w.maskKey) if ew != nil { err = ew break } + // payload payloadBuf = payloadBuf[:nr] for i := range payloadBuf { payloadBuf[i] = payloadBuf[i] ^ w.maskKey[i%4] @@ -123,8 +129,8 @@ func (w *frameWriter) ReadFrom(r io.Reader) (n int64, err error) { type frameReader struct { io.Reader - buf []byte - leftover []byte + buf []byte + leftBytes int64 } // FrameReader returns a chunked reader @@ -136,41 +142,49 @@ func FrameReader(r io.Reader) io.Reader { } func (r *frameReader) Read(b []byte) (int, error) { - if len(r.leftover) > 0 { - n := copy(b, r.leftover) - r.leftover = r.leftover[n:] - return n, nil + if r.leftBytes == 0 { + // get msg header + _, err := io.ReadFull(r.Reader, r.buf[:2]) + if err != nil { + return 0, err + } + + // final := r.buf[0]&finalBit != 0 + // frameType := int(r.buf[0] & 0xf) + // mask := r.buf[1]&maskBit != 0 + r.leftBytes = int64(r.buf[1] & 0x7f) + switch r.leftBytes { + case 126: + io.ReadFull(r.Reader, r.buf[:2]) + r.leftBytes = int64(binary.BigEndian.Uint16(r.buf[0:])) + case 127: + io.ReadFull(r.Reader, r.buf[:8]) + r.leftBytes = int64(binary.BigEndian.Uint64(r.buf[0:])) + } + } + + var n, m int + var err error + if r.leftBytes > int64(len(r.buf)) { + if len(b) < len(r.buf) { + n, err = r.Reader.Read(r.buf[:len(b)]) + } else { + n, err = r.Reader.Read(r.buf) + } + } else { + if int64(len(b)) < r.leftBytes { + n, err = io.ReadFull(r.Reader, r.buf[:len(b)]) + } else { + n, err = io.ReadFull(r.Reader, r.buf[:r.leftBytes]) + } } - // get msg header - _, err := io.ReadFull(r.Reader, r.buf[:2]) if err != nil { return 0, err } - // final := r.buf[0]&finalBit != 0 - // frameType := int(r.buf[0] & 0xf) - // mask := r.buf[1]&maskBit != 0 - len := int64(r.buf[1] & 0x7f) - switch len { - case 126: - io.ReadFull(r.Reader, r.buf[:2]) - len = int64(binary.BigEndian.Uint16(r.buf[0:])) - case 127: - io.ReadFull(r.Reader, r.buf[:8]) - len = int64(binary.BigEndian.Uint64(r.buf[0:])) - } - - // get payload - _, err = io.ReadFull(r.Reader, r.buf[:len]) - if err != nil { - return 0, err - } - - m := copy(b, r.buf[:len]) - if m < int(len) { - r.leftover = r.buf[m:len] - } + m = copy(b, r.buf[:n]) + r.leftBytes -= int64(m) return m, err }