feat(anytls): implement AnyTLS protocol

This commit is contained in:
0xec 2026-04-30 23:40:42 +08:00
parent 5ca07d4464
commit a92efe65c5
10 changed files with 1367 additions and 1 deletions

View File

@ -5,6 +5,7 @@ import (
// _ "github.com/nadoo/glider/service/xxx" // _ "github.com/nadoo/glider/service/xxx"
// comment out the protocols you don't need to make the compiled binary smaller. // comment out the protocols you don't need to make the compiled binary smaller.
_ "github.com/nadoo/glider/proxy/anytls"
_ "github.com/nadoo/glider/proxy/http" _ "github.com/nadoo/glider/proxy/http"
_ "github.com/nadoo/glider/proxy/kcp" _ "github.com/nadoo/glider/proxy/kcp"
_ "github.com/nadoo/glider/proxy/mixed" _ "github.com/nadoo/glider/proxy/mixed"

102
proxy/anytls/addr.go Normal file
View File

@ -0,0 +1,102 @@
package anytls
import (
"net"
"net/netip"
"strconv"
"strings"
"github.com/nadoo/glider/pkg/socks"
)
const (
uotAddrIPv4 = 0x00
uotAddrIPv6 = 0x01
uotAddrFQDN = 0x02
)
type addrPort struct {
Addr netip.Addr
FQDN string
Port uint16
}
func (a addrPort) IsValid() bool {
return a.Addr.IsValid() || a.FQDN != ""
}
func (a addrPort) UDPAddr() *net.UDPAddr {
if a.Addr.IsValid() {
return net.UDPAddrFromAddrPort(netip.AddrPortFrom(a.Addr, a.Port))
}
addr, _ := net.ResolveUDPAddr("udp", net.JoinHostPort(a.FQDN, strconv.Itoa(int(a.Port))))
return addr
}
func parseAddrPort(value string) addrPort {
host, port, err := net.SplitHostPort(value)
if err != nil {
return addrPort{}
}
portNum, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return addrPort{}
}
if ip, err := netip.ParseAddr(host); err == nil {
return addrPort{Addr: ip, Port: uint16(portNum)}
}
return addrPort{FQDN: strings.TrimSuffix(host, "."), Port: uint16(portNum)}
}
func addrPortLen(addr addrPort) int {
if !addr.IsValid() {
return 1
}
if addr.Addr.IsValid() {
if addr.Addr.Is4() {
return 1 + 4 + 2
}
return 1 + 16 + 2
}
return 1 + 1 + len(addr.FQDN) + 2
}
func serializeAddrPort(addr addrPort) []byte {
if !addr.IsValid() {
return []byte{0}
}
if addr.Addr.IsValid() {
if addr.Addr.Is4() {
buf := make([]byte, 1+4+2)
buf[0] = uotAddrIPv4
copy(buf[1:5], addr.Addr.AsSlice())
binaryPort(buf[5:7], addr.Port)
return buf
}
buf := make([]byte, 1+16+2)
buf[0] = uotAddrIPv6
copy(buf[1:17], addr.Addr.AsSlice())
binaryPort(buf[17:19], addr.Port)
return buf
}
buf := make([]byte, 1+1+len(addr.FQDN)+2)
buf[0] = uotAddrFQDN
buf[1] = byte(len(addr.FQDN))
copy(buf[2:2+len(addr.FQDN)], addr.FQDN)
binaryPort(buf[2+len(addr.FQDN):], addr.Port)
return buf
}
func binaryPort(dst []byte, port uint16) {
dst[0] = byte(port >> 8)
dst[1] = byte(port)
}
func socksAddr(addr string) socks.Addr {
return socks.ParseAddr(addr)
}

254
proxy/anytls/anytls.go Normal file
View File

@ -0,0 +1,254 @@
package anytls
import (
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/nadoo/glider/proxy"
)
const (
defaultIdleSessionCheckInterval = 30 * time.Second
defaultIdleSessionTimeout = 30 * time.Second
defaultMinIdleSession = 5
protocolVersion = "2"
programVersionName = "glider-anytls"
uotMagicAddress = "sp.v2.udp-over-tcp.arpa"
)
type AnyTLS struct {
dialer proxy.Dialer
addr string
passwordSHA256 [32]byte
tlsConfig *tls.Config
client *Client
}
func init() {
proxy.RegisterDialer("anytls", NewAnyTLSDialer)
proxy.AddUsage("anytls", `
AnyTLS client scheme:
anytls://password@host:port[?sni=SERVERNAME][&insecure=1][&cert=PATH]
anytls://password@host:port[?serverName=SERVERNAME][&skipVerify=true][&cert=PATH]
anytls://password@host:port[?minIdleSession=5][&idleSessionCheckInterval=30s][&idleSessionTimeout=30s]
`)
}
func NewAnyTLSDialer(s string, d proxy.Dialer) (proxy.Dialer, error) {
a, err := NewAnyTLS(s, d)
if err != nil {
return nil, err
}
return a, nil
}
func NewAnyTLS(s string, d proxy.Dialer) (*AnyTLS, error) {
u, err := url.Parse(s)
if err != nil {
return nil, fmt.Errorf("[anytls] parse url err: %w", err)
}
password := ""
if u.User != nil {
password = u.User.Username()
}
if password == "" {
return nil, fmt.Errorf("[anytls] password must be specified")
}
addr := u.Host
if addr == "" {
return nil, fmt.Errorf("[anytls] server address must be specified")
}
if _, port, _ := net.SplitHostPort(addr); port == "" {
addr = net.JoinHostPort(addr, "443")
}
query := u.Query()
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("[anytls] invalid server address %q: %w", addr, err)
}
serverName := firstNonEmpty(query.Get("sni"), query.Get("serverName"))
if serverName == "" {
serverName = host
}
if net.ParseIP(serverName) != nil {
serverName = ""
}
skipVerify := isTrue(query.Get("insecure")) || isTrue(query.Get("skipVerify"))
certFile := query.Get("cert")
tlsConfig := &tls.Config{
ServerName: serverName,
InsecureSkipVerify: skipVerify,
MinVersion: tls.VersionTLS12,
}
if certFile != "" {
certData, err := os.ReadFile(certFile)
if err != nil {
return nil, fmt.Errorf("[anytls] read cert file error: %w", err)
}
certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM(certData) {
return nil, fmt.Errorf("[anytls] can not append cert file: %s", certFile)
}
tlsConfig.RootCAs = certPool
}
idleCheckInterval := defaultIdleSessionCheckInterval
if value := query.Get("idleSessionCheckInterval"); value != "" {
idleCheckInterval, err = time.ParseDuration(value)
if err != nil {
return nil, fmt.Errorf("[anytls] invalid idleSessionCheckInterval: %w", err)
}
}
idleTimeout := defaultIdleSessionTimeout
if value := query.Get("idleSessionTimeout"); value != "" {
idleTimeout, err = time.ParseDuration(value)
if err != nil {
return nil, fmt.Errorf("[anytls] invalid idleSessionTimeout: %w", err)
}
}
minIdleSession := defaultMinIdleSession
if value := query.Get("minIdleSession"); value != "" {
minIdleSession, err = strconv.Atoi(value)
if err != nil {
return nil, fmt.Errorf("[anytls] invalid minIdleSession: %w", err)
}
if minIdleSession < 0 {
minIdleSession = 0
}
}
a := &AnyTLS{
dialer: d,
addr: addr,
passwordSHA256: sha256.Sum256([]byte(password)),
tlsConfig: tlsConfig,
}
a.client = NewClient(a.createAuthenticatedConn, idleCheckInterval, idleTimeout, minIdleSession)
return a, nil
}
func (a *AnyTLS) Addr() string {
if a.addr == "" {
return a.dialer.Addr()
}
return a.addr
}
func (a *AnyTLS) Dial(network, addr string) (net.Conn, error) {
stream, err := a.client.CreateStream()
if err != nil {
return nil, err
}
target := socksAddr(addr)
if target == nil {
stream.Close()
return nil, fmt.Errorf("[anytls] invalid target address: %s", addr)
}
if _, err := stream.Write(target); err != nil {
stream.Close()
return nil, err
}
return stream, nil
}
func (a *AnyTLS) DialUDP(network, addr string) (net.PacketConn, error) {
target := parseAddrPort(addr)
if !target.IsValid() {
return nil, fmt.Errorf("[anytls] invalid udp target address: %s", addr)
}
stream, err := a.client.CreateStream()
if err != nil {
return nil, err
}
proxyTarget := socksAddr(net.JoinHostPort(uotMagicAddress, "0"))
if proxyTarget == nil {
stream.Close()
return nil, fmt.Errorf("[anytls] invalid uot target")
}
if _, err := stream.Write(proxyTarget); err != nil {
stream.Close()
return nil, err
}
pc := NewPktConn(stream, target)
if err := pc.writeRequest(); err != nil {
stream.Close()
return nil, err
}
return pc, nil
}
func (a *AnyTLS) createAuthenticatedConn() (net.Conn, error) {
rawConn, err := a.dialer.Dial("tcp", a.addr)
if err != nil {
return nil, err
}
tlsConn := tls.Client(rawConn, a.tlsConfig)
if err := tlsConn.Handshake(); err != nil {
rawConn.Close()
return nil, err
}
paddingLen := 0
if padding := loadPaddingFactory(); padding != nil {
sizes := padding.GenerateRecordPayloadSizes(0)
if len(sizes) > 0 && sizes[0] > 0 {
paddingLen = sizes[0]
}
}
auth := make([]byte, 32+2+paddingLen)
copy(auth[:32], a.passwordSHA256[:])
auth[32] = byte(paddingLen >> 8)
auth[33] = byte(paddingLen)
if _, err := tlsConn.Write(auth); err != nil {
tlsConn.Close()
return nil, err
}
return tlsConn, nil
}
func firstNonEmpty(values ...string) string {
for _, value := range values {
if value != "" {
return value
}
}
return ""
}
func isTrue(value string) bool {
value = strings.ToLower(value)
return value == "1" || value == "true" || value == "yes" || value == "on"
}

211
proxy/anytls/client.go Normal file
View File

@ -0,0 +1,211 @@
package anytls
import (
"io"
"math"
"net"
"sort"
"sync"
"sync/atomic"
"time"
)
type Client struct {
dialOut func() (net.Conn, error)
sessionCounter atomic.Uint64
idleSessionLock sync.Mutex
idleSessions []*Session
sessionsLock sync.Mutex
sessions map[uint64]*Session
idleSessionTimeout time.Duration
minIdleSession int
closed atomic.Bool
stopCleanup chan struct{}
}
func NewClient(dialOut func() (net.Conn, error), idleSessionCheckInterval, idleSessionTimeout time.Duration, minIdleSession int) *Client {
if idleSessionCheckInterval <= 5*time.Second {
idleSessionCheckInterval = defaultIdleSessionCheckInterval
}
if idleSessionTimeout <= 5*time.Second {
idleSessionTimeout = defaultIdleSessionTimeout
}
c := &Client{
dialOut: dialOut,
sessions: make(map[uint64]*Session),
idleSessionTimeout: idleSessionTimeout,
minIdleSession: minIdleSession,
stopCleanup: make(chan struct{}),
}
go c.idleCleanupLoop(idleSessionCheckInterval)
return c
}
func (c *Client) CreateStream() (*Stream, error) {
if c.closed.Load() {
return nil, io.ErrClosedPipe
}
var (
session *Session
err error
)
session = c.getIdleSession()
if session == nil {
session, err = c.createSession()
}
if session == nil {
if err == nil {
err = io.ErrClosedPipe
}
return nil, err
}
stream, err := session.OpenStream()
if err != nil {
session.Close()
return nil, err
}
stream.dieHook = func() {
if c.closed.Load() || session.IsClosed() {
session.Close()
return
}
c.idleSessionLock.Lock()
session.idleSince = time.Now()
c.idleSessions = append(c.idleSessions, session)
sort.Slice(c.idleSessions, func(i, j int) bool {
return c.idleSessions[i].seq > c.idleSessions[j].seq
})
c.idleSessionLock.Unlock()
}
return stream, nil
}
func (c *Client) Close() error {
if !c.closed.CompareAndSwap(false, true) {
return io.ErrClosedPipe
}
close(c.stopCleanup)
c.sessionsLock.Lock()
sessions := make([]*Session, 0, len(c.sessions))
for _, session := range c.sessions {
sessions = append(sessions, session)
}
c.sessions = make(map[uint64]*Session)
c.sessionsLock.Unlock()
for _, session := range sessions {
session.Close()
}
return nil
}
func (c *Client) getIdleSession() *Session {
c.idleSessionLock.Lock()
defer c.idleSessionLock.Unlock()
for len(c.idleSessions) > 0 {
session := c.idleSessions[0]
c.idleSessions = c.idleSessions[1:]
if session != nil && !session.IsClosed() {
return session
}
}
return nil
}
func (c *Client) createSession() (*Session, error) {
underlying, err := c.dialOut()
if err != nil {
return nil, err
}
session := NewClientSession(underlying)
session.seq = c.sessionCounter.Add(1)
session.dieHook = func() {
c.idleSessionLock.Lock()
filtered := c.idleSessions[:0]
for _, idle := range c.idleSessions {
if idle != session {
filtered = append(filtered, idle)
}
}
c.idleSessions = filtered
c.idleSessionLock.Unlock()
c.sessionsLock.Lock()
delete(c.sessions, session.seq)
c.sessionsLock.Unlock()
}
c.sessionsLock.Lock()
c.sessions[session.seq] = session
c.sessionsLock.Unlock()
session.Run()
return session, nil
}
func (c *Client) idleCleanupLoop(interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
c.idleCleanup(time.Now().Add(-c.idleSessionTimeout))
case <-c.stopCleanup:
return
}
}
}
func (c *Client) idleCleanup(expireBefore time.Time) {
var toClose []*Session
c.idleSessionLock.Lock()
activeCount := 0
kept := c.idleSessions[:0]
for _, session := range c.idleSessions {
if session == nil || session.IsClosed() {
continue
}
if !session.idleSince.Before(expireBefore) {
activeCount++
kept = append(kept, session)
continue
}
if activeCount < max(c.minIdleSession, 0) {
activeCount++
session.idleSince = time.Now()
kept = append(kept, session)
continue
}
toClose = append(toClose, session)
}
c.idleSessions = kept
c.idleSessionLock.Unlock()
for _, session := range toClose {
session.Close()
}
}
func max(a, b int) int {
return int(math.Max(float64(a), float64(b)))
}

76
proxy/anytls/packet.go Normal file
View File

@ -0,0 +1,76 @@
package anytls
import (
"encoding/binary"
"errors"
"io"
"net"
)
type PktConn struct {
net.Conn
target addrPort
init bool
}
func NewPktConn(conn net.Conn, target addrPort) *PktConn {
return &PktConn{Conn: conn, target: target}
}
func (pc *PktConn) writeRequest() error {
if pc.init {
return nil
}
req := make([]byte, 0, 1+addrPortLen(pc.target))
req = append(req, 1)
req = append(req, serializeAddrPort(pc.target)...)
if _, err := pc.Conn.Write(req); err != nil {
return err
}
pc.init = true
return nil
}
func (pc *PktConn) ReadFrom(b []byte) (int, net.Addr, error) {
if len(b) < 2 {
return 0, pc.target.UDPAddr(), errors.New("buf size is not enough")
}
if _, err := io.ReadFull(pc.Conn, b[:2]); err != nil {
return 0, pc.target.UDPAddr(), err
}
length := int(binary.BigEndian.Uint16(b[:2]))
if len(b) < length {
return 0, pc.target.UDPAddr(), errors.New("buf size is not enough")
}
n, err := io.ReadFull(pc.Conn, b[:length])
return n, pc.target.UDPAddr(), err
}
func (pc *PktConn) WriteTo(b []byte, addr net.Addr) (int, error) {
target := pc.target
if addr != nil {
target = parseAddrPort(addr.String())
}
if !target.IsValid() {
return 0, errors.New("invalid addr")
}
if !pc.init {
if err := pc.writeRequest(); err != nil {
return 0, err
}
}
frame := make([]byte, 2+len(b))
binary.BigEndian.PutUint16(frame[:2], uint16(len(b)))
copy(frame[2:], b)
n, err := pc.Conn.Write(frame)
if n > 2 {
return n - 2, err
}
return 0, err
}

127
proxy/anytls/padding.go Normal file
View File

@ -0,0 +1,127 @@
package anytls
import (
"crypto/md5"
"crypto/rand"
"fmt"
"math/big"
"strconv"
"strings"
"sync/atomic"
)
const (
checkMark = -1
)
var defaultPaddingScheme = []byte(`stop=8
0=30-30
1=100-400
2=400-500,c,500-1000,c,500-1000,c,500-1000,c,500-1000
3=9-9,500-1000
4=500-1000
5=500-1000
6=500-1000
7=500-1000`)
type PaddingFactory struct {
scheme map[string]string
RawScheme []byte
Stop uint32
Md5 string
}
var defaultPaddingFactory atomic.Pointer[PaddingFactory]
func init() {
UpdatePaddingScheme(defaultPaddingScheme)
}
func UpdatePaddingScheme(rawScheme []byte) bool {
padding := NewPaddingFactory(rawScheme)
if padding == nil {
return false
}
defaultPaddingFactory.Store(padding)
return true
}
func NewPaddingFactory(rawScheme []byte) *PaddingFactory {
scheme := stringMapFromBytes(rawScheme)
if len(scheme) == 0 {
return nil
}
stop, err := strconv.Atoi(scheme["stop"])
if err != nil || stop <= 0 {
return nil
}
rawCopy := append([]byte(nil), rawScheme...)
return &PaddingFactory{
scheme: scheme,
RawScheme: rawCopy,
Stop: uint32(stop),
Md5: fmt.Sprintf("%x", md5.Sum(rawCopy)),
}
}
func loadPaddingFactory() *PaddingFactory {
return defaultPaddingFactory.Load()
}
func (p *PaddingFactory) GenerateRecordPayloadSizes(pkt uint32) []int {
if p == nil {
return nil
}
raw, ok := p.scheme[strconv.Itoa(int(pkt))]
if !ok {
return nil
}
parts := strings.Split(raw, ",")
sizes := make([]int, 0, len(parts))
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "c" {
sizes = append(sizes, checkMark)
continue
}
bounds := strings.SplitN(part, "-", 2)
if len(bounds) != 2 {
continue
}
minSize, err := strconv.ParseInt(bounds[0], 10, 64)
if err != nil {
continue
}
maxSize, err := strconv.ParseInt(bounds[1], 10, 64)
if err != nil {
continue
}
if minSize > maxSize {
minSize, maxSize = maxSize, minSize
}
if minSize <= 0 || maxSize <= 0 {
continue
}
if minSize == maxSize {
sizes = append(sizes, int(minSize))
continue
}
delta, err := rand.Int(rand.Reader, big.NewInt(maxSize-minSize+1))
if err != nil {
sizes = append(sizes, int(minSize))
continue
}
sizes = append(sizes, int(minSize+delta.Int64()))
}
return sizes
}

389
proxy/anytls/session.go Normal file
View File

@ -0,0 +1,389 @@
package anytls
import (
"encoding/binary"
"fmt"
"io"
"slices"
"strconv"
"sync"
"sync/atomic"
"time"
)
const (
cmdWaste byte = iota
cmdSYN
cmdPSH
cmdFIN
cmdSettings
cmdAlert
cmdUpdatePaddingScheme
cmdSYNACK
cmdHeartRequest
cmdHeartResponse
cmdServerSettings
)
const headerOverHeadSize = 1 + 4 + 2
type frame struct {
cmd byte
sid uint32
data []byte
}
func newFrame(cmd byte, sid uint32) frame {
return frame{cmd: cmd, sid: sid}
}
type rawHeader [headerOverHeadSize]byte
func (h rawHeader) Cmd() byte {
return h[0]
}
func (h rawHeader) StreamID() uint32 {
return binary.BigEndian.Uint32(h[1:5])
}
func (h rawHeader) Length() uint16 {
return binary.BigEndian.Uint16(h[5:7])
}
type Session struct {
conn io.ReadWriteCloser
connLock sync.Mutex
streams map[uint32]*Stream
streamID atomic.Uint32
streamLock sync.RWMutex
dieOnce sync.Once
die chan struct{}
dieHook func()
seq uint64
idleSince time.Time
peerVersion byte
isClient bool
sendPadding bool
buffering bool
buffer []byte
pktCounter atomic.Uint32
}
func NewClientSession(conn io.ReadWriteCloser) *Session {
return &Session{
conn: conn,
streams: make(map[uint32]*Stream),
die: make(chan struct{}),
isClient: true,
sendPadding: true,
}
}
func (s *Session) Run() {
settings := []byte("v=" + protocolVersion + "\nclient=" + programVersionName + "\npadding-md5=" + loadPaddingFactory().Md5)
frame := newFrame(cmdSettings, 0)
frame.data = settings
s.buffering = true
_, _ = s.writeControlFrame(frame)
go s.recvLoop()
}
func (s *Session) IsClosed() bool {
select {
case <-s.die:
return true
default:
return false
}
}
func (s *Session) Close() error {
var once bool
s.dieOnce.Do(func() {
close(s.die)
once = true
})
if !once {
return io.ErrClosedPipe
}
if s.dieHook != nil {
s.dieHook()
s.dieHook = nil
}
s.streamLock.Lock()
for _, stream := range s.streams {
stream.closeLocally()
}
s.streams = make(map[uint32]*Stream)
s.streamLock.Unlock()
return s.conn.Close()
}
func (s *Session) OpenStream() (*Stream, error) {
if s.IsClosed() {
return nil, io.ErrClosedPipe
}
sid := s.streamID.Add(1)
stream := newStream(sid, s)
if _, err := s.writeControlFrame(newFrame(cmdSYN, sid)); err != nil {
return nil, err
}
s.buffering = false
s.streamLock.Lock()
s.streams[sid] = stream
s.streamLock.Unlock()
return stream, nil
}
func (s *Session) recvLoop() error {
defer s.Close()
var hdr rawHeader
for {
if s.IsClosed() {
return io.ErrClosedPipe
}
if _, err := io.ReadFull(s.conn, hdr[:]); err != nil {
return err
}
sid := hdr.StreamID()
length := int(hdr.Length())
switch hdr.Cmd() {
case cmdPSH:
if length == 0 {
continue
}
payload := make([]byte, length)
if _, err := io.ReadFull(s.conn, payload); err != nil {
return err
}
s.streamLock.RLock()
stream := s.streams[sid]
s.streamLock.RUnlock()
if stream != nil {
stream.feed(payload)
}
case cmdSYNACK:
payload, err := s.readPayload(length)
if err != nil {
return err
}
if len(payload) == 0 {
continue
}
s.streamLock.RLock()
stream := s.streams[sid]
s.streamLock.RUnlock()
if stream != nil {
stream.closeWithError(fmt.Errorf("remote: %s", string(payload)))
}
case cmdFIN:
s.streamLock.Lock()
stream := s.streams[sid]
delete(s.streams, sid)
s.streamLock.Unlock()
if stream != nil {
stream.closeLocally()
}
case cmdWaste:
if _, err := s.readPayload(length); err != nil {
return err
}
case cmdAlert:
payload, err := s.readPayload(length)
if err != nil {
return err
}
if len(payload) == 0 {
return io.ErrUnexpectedEOF
}
return fmt.Errorf("[anytls] alert from server: %s", string(payload))
case cmdUpdatePaddingScheme:
payload, err := s.readPayload(length)
if err != nil {
return err
}
if len(payload) > 0 {
UpdatePaddingScheme(payload)
}
case cmdHeartRequest:
if _, err := s.writeControlFrame(newFrame(cmdHeartResponse, sid)); err != nil {
return err
}
case cmdHeartResponse:
if _, err := s.readPayload(length); err != nil {
return err
}
case cmdServerSettings:
payload, err := s.readPayload(length)
if err != nil {
return err
}
if len(payload) == 0 {
continue
}
if version, err := strconv.Atoi(stringMapFromBytes(payload)["v"]); err == nil {
s.peerVersion = byte(version)
}
default:
if _, err := s.readPayload(length); err != nil {
return err
}
}
}
}
func (s *Session) readPayload(length int) ([]byte, error) {
if length == 0 {
return nil, nil
}
payload := make([]byte, length)
_, err := io.ReadFull(s.conn, payload)
return payload, err
}
func (s *Session) streamClosed(sid uint32) error {
if s.IsClosed() {
return io.ErrClosedPipe
}
_, err := s.writeControlFrame(newFrame(cmdFIN, sid))
s.streamLock.Lock()
delete(s.streams, sid)
s.streamLock.Unlock()
return err
}
func (s *Session) writeDataFrame(sid uint32, data []byte) (int, error) {
buffer := make([]byte, headerOverHeadSize+len(data))
buffer[0] = cmdPSH
binary.BigEndian.PutUint32(buffer[1:5], sid)
binary.BigEndian.PutUint16(buffer[5:7], uint16(len(data)))
copy(buffer[7:], data)
if _, err := s.writeConn(buffer); err != nil {
return 0, err
}
return len(data), nil
}
func (s *Session) writeControlFrame(frame frame) (int, error) {
buffer := make([]byte, headerOverHeadSize+len(frame.data))
buffer[0] = frame.cmd
binary.BigEndian.PutUint32(buffer[1:5], frame.sid)
binary.BigEndian.PutUint16(buffer[5:7], uint16(len(frame.data)))
copy(buffer[7:], frame.data)
if conn, ok := s.conn.(interface{ SetWriteDeadline(time.Time) error }); ok {
_ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
defer conn.SetWriteDeadline(time.Time{})
}
if _, err := s.writeConn(buffer); err != nil {
s.Close()
return 0, err
}
return len(frame.data), nil
}
func (s *Session) writeConn(b []byte) (int, error) {
s.connLock.Lock()
defer s.connLock.Unlock()
if s.buffering {
s.buffer = append(s.buffer, b...)
return len(b), nil
}
if len(s.buffer) > 0 {
b = slices.Concat(s.buffer, b)
s.buffer = nil
}
if s.sendPadding {
padding := loadPaddingFactory()
if padding != nil {
pkt := s.pktCounter.Add(1)
if pkt < padding.Stop {
return s.writeWithPadding(b, padding.GenerateRecordPayloadSizes(pkt))
}
}
s.sendPadding = false
}
return s.conn.Write(b)
}
func (s *Session) writeWithPadding(payload []byte, sizes []int) (int, error) {
n := 0
b := payload
for _, size := range sizes {
remain := len(b)
if size == checkMark {
if remain == 0 {
break
}
continue
}
switch {
case remain > size:
written, err := s.conn.Write(b[:size])
n += written
if err != nil {
return 0, err
}
b = b[size:]
case remain > 0:
paddingLen := size - remain - headerOverHeadSize
if paddingLen > 0 {
padding := make([]byte, headerOverHeadSize+paddingLen)
padding[0] = cmdWaste
binary.BigEndian.PutUint16(padding[5:7], uint16(paddingLen))
b = slices.Concat(b, padding)
}
written, err := s.conn.Write(b)
n += min(written, remain)
if err != nil {
return 0, err
}
b = nil
case remain == 0:
padding := make([]byte, headerOverHeadSize+size)
padding[0] = cmdWaste
binary.BigEndian.PutUint16(padding[5:7], uint16(size))
if _, err := s.conn.Write(padding); err != nil {
return 0, err
}
}
}
if len(b) == 0 {
return n, nil
}
written, err := s.conn.Write(b)
n += min(written, len(b))
return n, err
}
func min(a, b int) int {
if a < b {
return a
}
return b
}

187
proxy/anytls/stream.go Normal file
View File

@ -0,0 +1,187 @@
package anytls
import (
"bytes"
"io"
"net"
"os"
"sync"
"sync/atomic"
"time"
)
type Stream struct {
id uint32
sess *Session
mu sync.Mutex
readBuf bytes.Buffer
readNotify chan struct{}
closed chan struct{}
dieOnce sync.Once
dieErr error
dieHook func()
readDeadline atomic.Value
writeDeadline atomic.Value
}
func newStream(id uint32, sess *Session) *Stream {
return &Stream{
id: id,
sess: sess,
readNotify: make(chan struct{}, 1),
closed: make(chan struct{}),
}
}
func (s *Stream) Read(b []byte) (int, error) {
for {
s.mu.Lock()
if s.readBuf.Len() > 0 {
n, _ := s.readBuf.Read(b)
err := error(nil)
if n == 0 && s.dieErr != nil {
err = s.dieErr
}
s.mu.Unlock()
return n, err
}
err := s.dieErr
s.mu.Unlock()
if err != nil {
return 0, err
}
deadline := s.loadDeadline(&s.readDeadline)
if deadline.IsZero() {
select {
case <-s.readNotify:
case <-s.closed:
}
continue
}
wait := time.Until(deadline)
if wait <= 0 {
return 0, os.ErrDeadlineExceeded
}
timer := time.NewTimer(wait)
select {
case <-s.readNotify:
timer.Stop()
case <-s.closed:
timer.Stop()
case <-timer.C:
return 0, os.ErrDeadlineExceeded
}
}
}
func (s *Stream) Write(b []byte) (int, error) {
if deadline := s.loadDeadline(&s.writeDeadline); !deadline.IsZero() && time.Until(deadline) <= 0 {
return 0, os.ErrDeadlineExceeded
}
s.mu.Lock()
err := s.dieErr
s.mu.Unlock()
if err != nil {
return 0, err
}
return s.sess.writeDataFrame(s.id, b)
}
func (s *Stream) Close() error {
return s.closeWithError(io.ErrClosedPipe)
}
func (s *Stream) closeLocally() {
var once bool
s.dieOnce.Do(func() {
s.mu.Lock()
s.dieErr = net.ErrClosed
s.mu.Unlock()
close(s.closed)
once = true
})
if once && s.dieHook != nil {
s.dieHook()
s.dieHook = nil
}
}
func (s *Stream) closeWithError(err error) error {
var once bool
s.dieOnce.Do(func() {
s.mu.Lock()
s.dieErr = err
s.mu.Unlock()
close(s.closed)
once = true
})
if !once {
s.mu.Lock()
defer s.mu.Unlock()
return s.dieErr
}
if s.dieHook != nil {
s.dieHook()
s.dieHook = nil
}
return s.sess.streamClosed(s.id)
}
func (s *Stream) SetReadDeadline(t time.Time) error {
s.readDeadline.Store(t)
return nil
}
func (s *Stream) SetWriteDeadline(t time.Time) error {
s.writeDeadline.Store(t)
return nil
}
func (s *Stream) SetDeadline(t time.Time) error {
_ = s.SetReadDeadline(t)
_ = s.SetWriteDeadline(t)
return nil
}
func (s *Stream) LocalAddr() net.Addr {
if conn, ok := s.sess.conn.(interface{ LocalAddr() net.Addr }); ok {
return conn.LocalAddr()
}
return nil
}
func (s *Stream) RemoteAddr() net.Addr {
if conn, ok := s.sess.conn.(interface{ RemoteAddr() net.Addr }); ok {
return conn.RemoteAddr()
}
return nil
}
func (s *Stream) feed(data []byte) {
s.mu.Lock()
defer s.mu.Unlock()
if s.dieErr != nil {
return
}
_, _ = s.readBuf.Write(data)
select {
case s.readNotify <- struct{}{}:
default:
}
}
func (s *Stream) loadDeadline(value *atomic.Value) time.Time {
v := value.Load()
if v == nil {
return time.Time{}
}
deadline, _ := v.(time.Time)
return deadline
}

19
proxy/anytls/util.go Normal file
View File

@ -0,0 +1,19 @@
package anytls
import "strings"
func stringMapFromBytes(raw []byte) map[string]string {
result := make(map[string]string)
for _, line := range strings.Split(string(raw), "\n") {
line = strings.TrimSpace(line)
if line == "" {
continue
}
key, value, ok := strings.Cut(line, "=")
if !ok {
continue
}
result[key] = value
}
return result
}

View File

@ -108,7 +108,7 @@ func NewClient(uuidStr, security string, alterID int, aead bool) (*Client, error
case "zero": case "zero":
c.security = SecurityNone c.security = SecurityNone
c.opt = OptBasicFormat c.opt = OptBasicFormat
case "": case "", "auto":
c.security = SecurityChacha20Poly1305 c.security = SecurityChacha20Poly1305
if runtime.GOARCH == "amd64" || runtime.GOARCH == "s390x" || runtime.GOARCH == "arm64" { if runtime.GOARCH == "amd64" || runtime.GOARCH == "s390x" || runtime.GOARCH == "arm64" {
c.security = SecurityAES128GCM c.security = SecurityAES128GCM