glider/proxy/anytls/session.go

390 lines
7.5 KiB
Go

package anytls
import (
"encoding/binary"
"fmt"
"io"
"slices"
"strconv"
"sync"
"sync/atomic"
"time"
)
// Frame command identifiers. Their numeric values are part of the wire
// format, so they are spelled out explicitly rather than derived from iota.
const (
	cmdWaste               byte = 0  // padding frame; recvLoop discards its payload
	cmdSYN                 byte = 1  // sent by OpenStream for a newly allocated stream ID
	cmdPSH                 byte = 2  // stream data (see writeDataFrame)
	cmdFIN                 byte = 3  // stream close, in either direction
	cmdSettings            byte = 4  // client settings, sent by Run
	cmdAlert               byte = 5  // fatal alert; recvLoop stops on it
	cmdUpdatePaddingScheme byte = 6  // payload is a new padding scheme
	cmdSYNACK              byte = 7  // non-empty payload reports a remote error for the stream
	cmdHeartRequest        byte = 8  // answered with cmdHeartResponse
	cmdHeartResponse       byte = 9  // ignored by recvLoop
	cmdServerSettings      byte = 10 // server settings; key "v" is the peer version
)
// headerOverHeadSize is the size of every frame header on the wire:
// command (1 byte) + stream ID (4 bytes, big-endian) + payload length
// (2 bytes, big-endian).
const headerOverHeadSize = 1 + 4 + 2

// frame is a control frame waiting to be encoded by writeControlFrame.
type frame struct {
	cmd  byte   // one of the cmd* constants
	sid  uint32 // target stream; Run uses 0 for the session-level settings frame
	data []byte // optional payload
}

// newFrame builds a payload-less frame for command cmd on stream sid.
func newFrame(cmd byte, sid uint32) frame {
	var f frame
	f.cmd, f.sid = cmd, sid
	return f
}

// rawHeader holds one frame header exactly as it was read from the connection.
type rawHeader [headerOverHeadSize]byte

// Cmd returns the command byte at offset 0.
func (h rawHeader) Cmd() byte { return h[0] }

// StreamID decodes the big-endian stream ID stored in bytes 1..4.
func (h rawHeader) StreamID() uint32 {
	idField := h[1:5]
	return binary.BigEndian.Uint32(idField)
}

// Length decodes the big-endian payload length stored in bytes 5..6.
func (h rawHeader) Length() uint16 {
	lenField := h[5:headerOverHeadSize]
	return binary.BigEndian.Uint16(lenField)
}
// Session multiplexes anytls streams over a single connection: OpenStream
// allocates stream IDs and recvLoop dispatches incoming frames to streams by ID.
type Session struct {
// conn is the underlying transport; recvLoop reads frames from it and
// writeConn writes to it.
conn io.ReadWriteCloser
// connLock is held by writeConn while it touches buffering, buffer,
// sendPadding and conn.
// NOTE(review): Run and OpenStream assign buffering without holding it.
connLock sync.Mutex
// streams maps stream ID to its open Stream; guarded by streamLock.
streams map[uint32]*Stream
// streamID is the last stream ID handed out; OpenStream increments it
// before use, so the first stream gets ID 1.
streamID atomic.Uint32
// streamLock guards streams.
streamLock sync.RWMutex
// dieOnce makes Close's teardown run only once.
dieOnce sync.Once
// die is closed by Close; IsClosed polls it.
die chan struct{}
// dieHook, if set, is called once by Close and then cleared.
dieHook func()
// NOTE(review): seq and idleSince are not referenced in this file;
// presumably used by a session pool elsewhere in the package — confirm.
seq uint64
idleSince time.Time
// peerVersion is the "v" value from the server's cmdServerSettings frame;
// written by recvLoop without synchronization.
peerVersion byte
// isClient is set by NewClientSession; not read in this file.
isClient bool
// sendPadding: while true, writeConn shapes writes with the padding scheme.
// It is cleared permanently once the scheme's Stop count is reached or no
// scheme is loaded.
sendPadding bool
// buffering: while true, writeConn queues bytes in buffer instead of writing
// them. Run sets it and OpenStream clears it.
buffering bool
// buffer holds bytes queued while buffering; they are prepended to the next
// real write.
buffer []byte
// pktCounter counts the writes made against the padding scheme (compared
// with the scheme's Stop in writeConn).
pktCounter atomic.Uint32
}
// NewClientSession wraps conn in a client-side session with an empty stream
// table and padding enabled. Nothing is sent until Run is called.
func NewClientSession(conn io.ReadWriteCloser) *Session {
	s := new(Session)
	s.conn = conn
	s.streams = map[uint32]*Stream{}
	s.die = make(chan struct{})
	s.isClient = true
	s.sendPadding = true
	return s
}
// Run queues the client settings frame and starts recvLoop in its own
// goroutine. The frame carries the protocol version, the client name and the
// MD5 of the currently loaded padding scheme.
//
// Buffering is switched on before the frame is written, so writeConn only
// appends it to s.buffer. It reaches the wire with the next write issued
// after OpenStream clears s.buffering. A buffered write cannot fail, which is
// why the result of writeControlFrame is ignored here.
func (s *Session) Run() {
	paddingMD5 := loadPaddingFactory().Md5
	hello := newFrame(cmdSettings, 0)
	hello.data = []byte("v=" + protocolVersion +
		"\nclient=" + programVersionName +
		"\npadding-md5=" + paddingMD5)

	s.buffering = true
	_, _ = s.writeControlFrame(hello)

	go s.recvLoop()
}
// IsClosed reports, without blocking, whether Close has already closed die.
func (s *Session) IsClosed() bool {
	closed := false
	select {
	case <-s.die:
		closed = true
	default:
	}
	return closed
}
// Close tears the session down. Only the first call does any work, in this
// order:
//   - it closes die, so IsClosed reports true;
//   - it runs and then clears dieHook;
//   - it closes every registered stream locally and resets the stream table;
//   - it closes the underlying connection and returns that error.
//
// Every later call returns io.ErrClosedPipe.
func (s *Session) Close() error {
	firstCall := false
	s.dieOnce.Do(func() {
		close(s.die)
		firstCall = true
	})
	if !firstCall {
		return io.ErrClosedPipe
	}

	if hook := s.dieHook; hook != nil {
		hook()
		s.dieHook = nil
	}

	s.streamLock.Lock()
	for _, st := range s.streams {
		st.closeLocally()
	}
	s.streams = map[uint32]*Stream{}
	s.streamLock.Unlock()

	return s.conn.Close()
}
// OpenStream allocates the next stream ID, registers a new Stream for it and
// sends cmdSYN to the peer. It returns io.ErrClosedPipe if the session is
// already closed, or the write error if the SYN could not be sent.
//
// The stream is registered before the SYN is written. Any frame the peer
// sends in reply (a SYNACK error or a FIN) therefore finds it in s.streams.
// If the write fails, the registration is rolled back. A failing
// writeControlFrame also closes the whole session.
//
// This call also ends the initial buffering phase started by Run. The first
// SYN is only queued, like the settings frame, and the queued bytes are
// flushed by the next write on the session.
func (s *Session) OpenStream() (*Stream, error) {
	if s.IsClosed() {
		return nil, io.ErrClosedPipe
	}
	sid := s.streamID.Add(1)
	stream := newStream(sid, s)

	s.streamLock.Lock()
	s.streams[sid] = stream
	s.streamLock.Unlock()

	if _, err := s.writeControlFrame(newFrame(cmdSYN, sid)); err != nil {
		s.streamLock.Lock()
		delete(s.streams, sid)
		s.streamLock.Unlock()
		return nil, err
	}

	// writeConn reads buffering under connLock. Clear it under the same lock
	// so that concurrent writers (e.g. recvLoop answering a heartbeat) do not
	// race with this assignment.
	s.connLock.Lock()
	s.buffering = false
	s.connLock.Unlock()

	return stream, nil
}
// recvLoop reads frames from s.conn and dispatches them by command. It stops
// when any of these happens:
//   - the session is closed;
//   - a read or a heartbeat reply fails;
//   - the server sends cmdAlert.
//
// The session is closed when the loop exits. The returned error says why the
// loop stopped (Run discards it).
//
// The payload announced in every header is consumed before dispatching,
// whatever the command. Frames that normally carry no payload (FIN,
// heartbeats, unknown commands) therefore cannot desynchronize the frame
// stream if a peer attaches data to them.
func (s *Session) recvLoop() error {
	defer s.Close()

	var hdr rawHeader
	for {
		if s.IsClosed() {
			return io.ErrClosedPipe
		}
		if _, err := io.ReadFull(s.conn, hdr[:]); err != nil {
			return err
		}
		sid := hdr.StreamID()
		payload, err := s.readPayload(int(hdr.Length()))
		if err != nil {
			return err
		}

		switch hdr.Cmd() {
		case cmdPSH:
			if len(payload) == 0 {
				continue
			}
			s.streamLock.RLock()
			stream := s.streams[sid]
			s.streamLock.RUnlock()
			if stream != nil {
				stream.feed(payload)
			}
		case cmdSYNACK:
			// An empty SYNACK needs no action. A non-empty payload is
			// reported as a remote error on the stream.
			if len(payload) == 0 {
				continue
			}
			s.streamLock.RLock()
			stream := s.streams[sid]
			s.streamLock.RUnlock()
			if stream != nil {
				stream.closeWithError(fmt.Errorf("remote: %s", string(payload)))
			}
		case cmdFIN:
			s.streamLock.Lock()
			stream := s.streams[sid]
			delete(s.streams, sid)
			s.streamLock.Unlock()
			if stream != nil {
				stream.closeLocally()
			}
		case cmdAlert:
			if len(payload) == 0 {
				return io.ErrUnexpectedEOF
			}
			return fmt.Errorf("[anytls] alert from server: %s", string(payload))
		case cmdUpdatePaddingScheme:
			if len(payload) > 0 {
				UpdatePaddingScheme(payload)
			}
		case cmdHeartRequest:
			if _, err := s.writeControlFrame(newFrame(cmdHeartResponse, sid)); err != nil {
				return err
			}
		case cmdServerSettings:
			if len(payload) == 0 {
				continue
			}
			if version, err := strconv.Atoi(stringMapFromBytes(payload)["v"]); err == nil {
				s.peerVersion = byte(version)
			}
		default:
			// cmdWaste, cmdHeartResponse and unknown commands need no action.
			// Their payload has already been consumed above.
		}
	}
}
// readPayload reads exactly length bytes from s.conn. A zero length reads
// nothing and yields a nil slice. If the read falls short, the buffer is
// returned together with the error from io.ReadFull.
func (s *Session) readPayload(length int) ([]byte, error) {
	var buf []byte
	if length != 0 {
		buf = make([]byte, length)
		if _, err := io.ReadFull(s.conn, buf); err != nil {
			return buf, err
		}
	}
	return buf, nil
}
// streamClosed sends cmdFIN for stream sid and removes the stream from
// s.streams. The stream is removed even if the FIN could not be written; in
// that case the write error is returned. If the session is already closed,
// nothing is sent or removed and io.ErrClosedPipe is returned.
func (s *Session) streamClosed(sid uint32) error {
	if s.IsClosed() {
		return io.ErrClosedPipe
	}
	fin := newFrame(cmdFIN, sid)
	_, writeErr := s.writeControlFrame(fin)

	s.streamLock.Lock()
	delete(s.streams, sid)
	s.streamLock.Unlock()

	return writeErr
}
// writeDataFrame sends data on stream sid as cmdPSH frames. It returns
// len(data) on success, or 0 and the error from writeConn.
//
// The header's length field is 16 bits. Data longer than 65535 bytes is
// therefore split across consecutive PSH frames rather than sent with a
// truncated length, which would corrupt the framing for the peer. All frames
// are encoded into one buffer and handed to writeConn in a single call, so the
// padding packet counter still advances once per write. Empty data produces a
// single zero-length PSH frame.
func (s *Session) writeDataFrame(sid uint32, data []byte) (int, error) {
	const maxFramePayload = 0xFFFF // largest value the uint16 length field can hold

	frames := (len(data) + maxFramePayload - 1) / maxFramePayload
	if frames == 0 {
		frames = 1
	}
	buffer := make([]byte, 0, frames*headerOverHeadSize+len(data))

	rest := data
	for {
		chunk := len(rest)
		if chunk > maxFramePayload {
			chunk = maxFramePayload
		}
		buffer = append(buffer, cmdPSH)
		buffer = binary.BigEndian.AppendUint32(buffer, sid)
		buffer = binary.BigEndian.AppendUint16(buffer, uint16(chunk))
		buffer = append(buffer, rest[:chunk]...)
		rest = rest[chunk:]
		if len(rest) == 0 {
			break
		}
	}

	if _, err := s.writeConn(buffer); err != nil {
		return 0, err
	}
	return len(data), nil
}
// writeControlFrame encodes frame (header plus payload) and hands it to
// writeConn. It returns the payload length on success.
//
// If the connection supports SetWriteDeadline, the write is bounded by a
// 5-second deadline that is cleared again afterwards. Both deadline calls are
// best-effort and their errors are ignored. A failed write is treated as
// fatal: the session is closed before the error is returned.
//
// The header's length field is 16 bits, so a payload longer than 65535 bytes
// is rejected up front. Sending it with a truncated length would
// desynchronize the peer's frame parser. In that case nothing is written and
// the session stays open.
func (s *Session) writeControlFrame(frame frame) (int, error) {
	if len(frame.data) > 0xFFFF {
		return 0, fmt.Errorf("[anytls] control frame payload too large: %d bytes", len(frame.data))
	}

	buffer := make([]byte, headerOverHeadSize+len(frame.data))
	buffer[0] = frame.cmd
	binary.BigEndian.PutUint32(buffer[1:5], frame.sid)
	binary.BigEndian.PutUint16(buffer[5:7], uint16(len(frame.data)))
	copy(buffer[headerOverHeadSize:], frame.data)

	if conn, ok := s.conn.(interface{ SetWriteDeadline(time.Time) error }); ok {
		_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
		defer func() { _ = conn.SetWriteDeadline(time.Time{}) }()
	}
	if _, err := s.writeConn(buffer); err != nil {
		_ = s.Close() // the write error is what the caller needs to see
		return 0, err
	}
	return len(frame.data), nil
}
// writeConn is the single path for bytes going to s.conn; it holds connLock
// for its whole duration.
//
//   - While buffering is set, b is only appended to s.buffer and reported as
//     fully written.
//   - Otherwise any queued bytes are prepended to b and the queue is cleared.
//   - While sendPadding is set and a padding scheme is loaded, each call
//     counts as one packet. Packets numbered below the scheme's Stop go
//     through writeWithPadding.
//   - Once Stop is reached, or if no scheme is loaded, padding is switched
//     off for good.
//
// When queued bytes were flushed, the returned count includes them.
func (s *Session) writeConn(b []byte) (int, error) {
	s.connLock.Lock()
	defer s.connLock.Unlock()

	if s.buffering {
		s.buffer = append(s.buffer, b...)
		return len(b), nil
	}

	out := b
	if pending := s.buffer; len(pending) > 0 {
		out = slices.Concat(pending, b)
		s.buffer = nil
	}

	if !s.sendPadding {
		return s.conn.Write(out)
	}
	if scheme := loadPaddingFactory(); scheme != nil {
		if pkt := s.pktCounter.Add(1); pkt < scheme.Stop {
			return s.writeWithPadding(out, scheme.GenerateRecordPayloadSizes(pkt))
		}
	}
	s.sendPadding = false
	return s.conn.Write(out)
}
// writeWithPadding writes payload to s.conn as a series of records whose sizes
// come from sizes, the record sizes the padding scheme produced for the
// current packet. writeConn, its only caller in this file, invokes it with
// connLock held.
//
// For each entry in sizes:
//   - checkMark: stop if the payload is already fully written, otherwise skip
//     the entry;
//   - more payload left than size: write the next size bytes of payload;
//   - payload left fits (0 < remain <= size): write the rest of the payload,
//     followed by a cmdWaste frame that pads the record to exactly size bytes
//     when there is room for the 7-byte header plus at least one padding byte;
//   - no payload left: write a pure cmdWaste frame carrying size padding bytes
//     (headerOverHeadSize+size bytes on the wire).
//
// Any payload still left after the last entry is written in one final call.
//
// It returns the number of payload bytes written (padding excluded). On a
// failed write inside the loop it returns 0 and the error, even if some bytes
// were already sent.
// NOTE(review): assumes every non-checkMark entry in sizes is non-negative;
// a negative size would panic on slicing/make — confirm with the padding scheme.
func (s *Session) writeWithPadding(payload []byte, sizes []int) (int, error) {
n := 0
b := payload
for _, size := range sizes {
remain := len(b)
if size == checkMark {
// This break leaves the for loop: the payload has been fully sent.
if remain == 0 {
break
}
continue
}
switch {
case remain > size:
// The record is entirely payload.
written, err := s.conn.Write(b[:size])
n += written
if err != nil {
return 0, err
}
b = b[size:]
case remain > 0:
// Last piece of payload. Pad the record up to size bytes with a cmdWaste
// frame (stream ID bytes left zero).
paddingLen := size - remain - headerOverHeadSize
if paddingLen > 0 {
padding := make([]byte, headerOverHeadSize+paddingLen)
padding[0] = cmdWaste
binary.BigEndian.PutUint16(padding[5:7], uint16(paddingLen))
b = slices.Concat(b, padding)
}
written, err := s.conn.Write(b)
// Count only payload bytes, not the appended padding.
n += min(written, remain)
if err != nil {
return 0, err
}
b = nil
case remain == 0:
// The record is entirely padding: a cmdWaste frame with size payload bytes.
padding := make([]byte, headerOverHeadSize+size)
padding[0] = cmdWaste
binary.BigEndian.PutUint16(padding[5:7], uint16(size))
if _, err := s.conn.Write(padding); err != nil {
return 0, err
}
}
}
// sizes ran out before the payload did: send the remainder unpadded.
if len(b) == 0 {
return n, nil
}
written, err := s.conn.Write(b)
n += min(written, len(b))
return n, err
}
// min returns the smaller of two ints. On Go 1.21+ this package-level helper
// shadows the builtin min inside the package.
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}