glider/proxy/anytls/session/client.go
2026-04-09 18:49:23 +08:00

197 lines
3.8 KiB
Go

package session
import (
"context"
"io"
"net"
"sync"
"sync/atomic"
"time"
)
// DialFunc opens a new outbound connection.
//
// Client calls it from createSession, forwarding the context that was passed
// to CreateStream, each time no reusable idle session is available. A non-nil
// error aborts session creation and is returned to the CreateStream caller
// unchanged.
type DialFunc func(ctx context.Context) (net.Conn, error)
// Client manages a pool of sessions for stream multiplexing.
//
// Streams are opened on top of sessions; when a stream's CloseFunc runs and
// the session is still usable, the session is parked in idleSessions so that
// a later CreateStream can reuse it instead of dialing again. A background
// goroutine started by NewClient periodically closes sessions that have been
// idle for longer than idleTimeout.
//
// A Client must be created with NewClient; the zero value has a nil sessions
// map and nil ctx/cancel.
type Client struct {
// ctx is derived from the context given to NewClient; it is cancelled by
// Close and is used to stop the cleanup loop and reject new streams.
ctx context.Context
// cancel cancels ctx.
cancel context.CancelFunc
// dialOut creates the underlying connection for each new session.
dialOut DialFunc
// sessionCounter hands out the Seq value assigned to each new session.
sessionCounter atomic.Uint64
// idleSessions holds sessions currently not serving a stream. Sessions are
// appended when they become idle and taken from the end, so the most
// recently parked session is reused first. Guarded by idleLock.
idleSessions []*Session
idleLock sync.Mutex
// sessions indexes every session created by this client by its Seq until
// the session's DieHook removes it. Guarded by sessionsLock.
sessions map[uint64]*Session
sessionsLock sync.Mutex
// idleTimeout is how long a session may sit in idleSessions before
// idleCleanup closes it.
idleTimeout time.Duration
}
// NewClient builds a session pool client and starts its idle-cleanup
// goroutine.
//
// ctx is the parent of the client's internal context; cancelling it (or
// calling Close) stops the cleanup goroutine. dialOut is used to establish the
// connection for every new session. idleCheckInterval is the period of the
// cleanup loop and idleTimeout is how long a session may stay idle before it
// is closed; either value below five seconds is replaced by thirty seconds.
func NewClient(ctx context.Context, dialOut DialFunc, idleCheckInterval, idleTimeout time.Duration) *Client {
	const (
		shortestAllowed = 5 * time.Second
		fallback        = 30 * time.Second
	)

	if idleCheckInterval < shortestAllowed {
		idleCheckInterval = fallback
	}
	if idleTimeout < shortestAllowed {
		idleTimeout = fallback
	}

	poolCtx, stop := context.WithCancel(ctx)
	client := &Client{
		ctx:         poolCtx,
		cancel:      stop,
		dialOut:     dialOut,
		sessions:    make(map[uint64]*Session),
		idleTimeout: idleTimeout,
	}

	// The loop exits once poolCtx is cancelled.
	go client.idleCleanupLoop(idleCheckInterval)

	return client
}
// CreateStream opens a new multiplexed stream and returns it as a net.Conn.
//
// An idle session is reused when one is available; otherwise a new session is
// created with ctx passed through to the dial function. If the client has
// already been closed, io.ErrClosedPipe is returned. If opening the stream on
// the chosen session fails, that session is closed and the error is returned.
func (c *Client) CreateStream(ctx context.Context) (net.Conn, error) {
	if c.ctx.Err() != nil {
		return nil, io.ErrClosedPipe
	}

	sess := c.getIdleSession()
	if sess == nil {
		fresh, err := c.createSession(ctx)
		if err != nil {
			return nil, err
		}
		sess = fresh
	}

	stream, err := sess.OpenStream()
	if err != nil {
		sess.Close()
		return nil, err
	}

	// CloseFunc is the stream's close hook (presumably invoked when the
	// stream is closed): it closes the remote side and then decides what to
	// do with the underlying session.
	stream.CloseFunc = func() error {
		closeErr := stream.CloseRemote()

		// A dead session has nothing left to recycle.
		if sess.IsClosed() {
			return closeErr
		}

		// The client is shutting down: tear the session down in the
		// background instead of parking it.
		if c.ctx.Err() != nil {
			go sess.Close()
			return closeErr
		}

		// Park the session at the end of the idle list for reuse.
		c.idleLock.Lock()
		sess.IdleSince = time.Now()
		c.idleSessions = append(c.idleSessions, sess)
		c.idleLock.Unlock()

		return closeErr
	}

	return stream, nil
}
// getIdleSession removes and returns the most recently parked idle session
// that is still open, or nil when the idle list holds no usable session.
//
// Sessions are taken from the end of the idle list. Closed sessions
// encountered on the way are discarded. Each vacated slot is set to nil before
// the slice is shortened so the backing array does not keep the session
// reachable after it has left the pool.
func (c *Client) getIdleSession() *Session {
	c.idleLock.Lock()
	defer c.idleLock.Unlock()

	for len(c.idleSessions) > 0 {
		last := len(c.idleSessions) - 1
		sess := c.idleSessions[last]

		// Clear the slot first: truncating alone would leave the pointer in
		// the backing array and prevent the session from being collected.
		c.idleSessions[last] = nil
		c.idleSessions = c.idleSessions[:last]

		if !sess.IsClosed() {
			return sess
		}
	}
	return nil
}
// createSession dials a new connection with c.dialOut, wraps it in a client
// session, registers the session with the client and starts it.
//
// The session receives the next value of sessionCounter as its Seq. Its
// DieHook is set to remove the session from both the idle list and the
// sessions map. A dial error is returned as-is with a nil session.
func (c *Client) createSession(ctx context.Context) (*Session, error) {
	conn, err := c.dialOut(ctx)
	if err != nil {
		return nil, err
	}

	sess := NewClientSession(conn)
	sess.Seq = c.sessionCounter.Add(1)

	sess.DieHook = func() {
		// Unlink the session from the idle list, if it is parked there.
		c.idleLock.Lock()
		for idx := range c.idleSessions {
			if c.idleSessions[idx] != sess {
				continue
			}
			c.idleSessions = append(c.idleSessions[:idx], c.idleSessions[idx+1:]...)
			break
		}
		c.idleLock.Unlock()

		// Then forget it in the registry of sessions.
		c.sessionsLock.Lock()
		delete(c.sessions, sess.Seq)
		c.sessionsLock.Unlock()
	}

	// Register before starting so Close can find the session.
	c.sessionsLock.Lock()
	c.sessions[sess.Seq] = sess
	c.sessionsLock.Unlock()

	sess.Run()

	return sess, nil
}
// Close shuts down the client: it cancels the client's context, which stops
// the idle-cleanup loop and makes CreateStream return io.ErrClosedPipe, and
// then closes every session registered at that moment. It always returns nil.
func (c *Client) Close() error {
	c.cancel()

	// Detach the current registry and install an empty one while holding the
	// lock. From here on the detached map is referenced only by this call, so
	// it can be walked without the lock while sessions are being closed.
	c.sessionsLock.Lock()
	detached := c.sessions
	c.sessions = make(map[uint64]*Session)
	c.sessionsLock.Unlock()

	for _, s := range detached {
		s.Close()
	}
	return nil
}
// idleCleanupLoop runs idleCleanup once per interval until the client's
// context is cancelled. It is started as a goroutine by NewClient.
func (c *Client) idleCleanupLoop(interval time.Duration) {
	tick := time.NewTicker(interval)
	defer tick.Stop()

	stopped := c.ctx.Done()
	for {
		select {
		case <-stopped:
			return
		case <-tick.C:
			c.idleCleanup()
		}
	}
}
// idleCleanup prunes the idle list: sessions that are already closed are
// dropped, and sessions whose IdleSince is older than c.idleTimeout are
// removed and closed. All other sessions stay in the list in their original
// order.
//
// The list is filtered in place under idleLock. Expired sessions are closed
// only after the lock has been released, because the DieHook installed by
// createSession takes idleLock itself (presumably it runs as part of closing
// a session).
func (c *Client) idleCleanup() {
	deadline := time.Now().Add(-c.idleTimeout)

	var expired []*Session

	c.idleLock.Lock()
	pool := c.idleSessions
	kept := pool[:0]
	for _, sess := range pool {
		switch {
		case sess.IsClosed():
			// Already dead; just drop it from the list.
		case sess.IdleSince.Before(deadline):
			expired = append(expired, sess)
		default:
			kept = append(kept, sess)
		}
	}
	// kept shares pool's backing array. Clear the slots past its end so that
	// dropped sessions are not kept reachable through the array.
	for i := len(kept); i < len(pool); i++ {
		pool[i] = nil
	}
	c.idleSessions = kept
	c.idleLock.Unlock()

	for _, sess := range expired {
		sess.Close()
	}
}