glider/proxy/anytls/session.go
2026-06-23 23:17:09 +08:00

234 lines
4.2 KiB
Go

package anytls
import (
"errors"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
)
type session struct {
conn net.Conn
writeMu sync.Mutex
mu sync.Mutex
streams map[uint32]*stream
synack map[uint32]chan synackResult
nextID uint32
incoming chan *stream
done chan struct{}
closeOnce sync.Once
err atomic.Value
settingsSeen bool
}
type synackResult struct {
data []byte
}
func newSession(conn net.Conn) *session {
return &session{
conn: conn,
streams: map[uint32]*stream{},
synack: map[uint32]chan synackResult{},
nextID: 1,
incoming: make(chan *stream, 32),
done: make(chan struct{}),
}
}
func (s *session) start() {
go s.readLoop()
}
func (s *session) acceptStream() (*stream, error) {
select {
case st, ok := <-s.incoming:
if !ok {
return nil, s.Err()
}
return st, nil
case <-s.done:
return nil, s.Err()
}
}
func (s *session) openStream() (*stream, error) {
id := atomic.AddUint32(&s.nextID, 1) - 1
st := newStream(id, s)
s.mu.Lock()
s.streams[id] = st
s.synack[id] = make(chan synackResult, 1)
s.mu.Unlock()
if err := s.writeFrame(frame{command: cmdSYN, streamID: id}); err != nil {
s.removeStream(id)
return nil, err
}
return st, nil
}
func (s *session) waitSYNACK(id uint32, timeout time.Duration) error {
s.mu.Lock()
ch := s.synack[id]
s.mu.Unlock()
if ch == nil {
return nil
}
var timer <-chan time.Time
if timeout > 0 {
t := time.NewTimer(timeout)
defer t.Stop()
timer = t.C
}
select {
case r, ok := <-ch:
if !ok {
return s.Err()
}
if len(r.data) > 0 {
return fmt.Errorf("stream open failed: %s", string(r.data))
}
return nil
case <-timer:
return errors.New("timeout waiting for SYNACK")
case <-s.done:
return s.Err()
}
}
func (s *session) writeFrame(f frame) error {
s.writeMu.Lock()
defer s.writeMu.Unlock()
return writeFrame(s.conn, f)
}
func (s *session) readLoop() {
for {
f, err := readFrame(s.conn)
if err != nil {
if !errors.Is(err, io.EOF) {
s.setErr(err)
}
s.Close()
return
}
if err := s.handleFrame(f); err != nil {
s.setErr(err)
s.Close()
return
}
}
}
func (s *session) handleFrame(f frame) error {
switch f.command {
case cmdWaste:
return nil
case cmdHeartRequest:
return s.writeFrame(frame{command: cmdHeartResponse, streamID: f.streamID})
case cmdHeartResponse:
return nil
case cmdSettings:
m := parseSettings(f.data)
s.settingsSeen = true
if settingsVersion(m) >= 2 {
return s.writeFrame(frame{command: cmdServerSettings, data: serverSettings()})
}
case cmdServerSettings:
return nil
case cmdAlert:
return errors.New("alert: " + string(f.data))
case cmdUpdatePaddingScheme:
return nil
case cmdSYN:
if !s.settingsSeen {
_ = s.writeFrame(frame{command: cmdAlert, data: []byte("cmdSYN received before cmdSettings")})
return errors.New("cmdSYN received before cmdSettings")
}
st := newStream(f.streamID, s)
s.mu.Lock()
s.streams[f.streamID] = st
s.mu.Unlock()
select {
case s.incoming <- st:
case <-s.done:
}
case cmdSYNACK:
s.mu.Lock()
ch := s.synack[f.streamID]
delete(s.synack, f.streamID)
s.mu.Unlock()
if ch != nil {
ch <- synackResult{data: f.data}
close(ch)
}
case cmdPSH:
st := s.getStream(f.streamID)
if st != nil {
st.push(f.data)
}
case cmdFIN:
st := s.getStream(f.streamID)
if st != nil {
st.closeRead()
s.removeStream(f.streamID)
}
default:
return errors.New("unknown command")
}
return nil
}
func (s *session) getStream(id uint32) *stream {
s.mu.Lock()
defer s.mu.Unlock()
return s.streams[id]
}
func (s *session) removeStream(id uint32) {
s.mu.Lock()
delete(s.streams, id)
if ch := s.synack[id]; ch != nil {
delete(s.synack, id)
close(ch)
}
s.mu.Unlock()
}
func (s *session) setErr(err error) {
if err != nil && s.err.Load() == nil {
s.err.Store(err)
}
}
func (s *session) Err() error {
if v := s.err.Load(); v != nil {
return v.(error)
}
return net.ErrClosed
}
func (s *session) Close() error {
s.closeOnce.Do(func() {
close(s.done)
_ = s.conn.Close()
s.mu.Lock()
for _, st := range s.streams {
st.closeRead()
}
s.streams = map[uint32]*stream{}
for id, ch := range s.synack {
delete(s.synack, id)
close(ch)
}
close(s.incoming)
s.mu.Unlock()
})
return nil
}