From a92efe65c5d0b0a48634596fddb9ea3e0da153c4 Mon Sep 17 00:00:00 2001 From: 0xec Date: Thu, 30 Apr 2026 23:40:42 +0800 Subject: [PATCH] feat(anytls): implement AnyTLS protocol --- feature.go | 1 + proxy/anytls/addr.go | 102 +++++++++++ proxy/anytls/anytls.go | 254 ++++++++++++++++++++++++++ proxy/anytls/client.go | 211 ++++++++++++++++++++++ proxy/anytls/packet.go | 76 ++++++++ proxy/anytls/padding.go | 127 +++++++++++++ proxy/anytls/session.go | 389 ++++++++++++++++++++++++++++++++++++++++ proxy/anytls/stream.go | 187 +++++++++++++++++++ proxy/anytls/util.go | 19 ++ proxy/vmess/client.go | 2 +- 10 files changed, 1367 insertions(+), 1 deletion(-) create mode 100644 proxy/anytls/addr.go create mode 100644 proxy/anytls/anytls.go create mode 100644 proxy/anytls/client.go create mode 100644 proxy/anytls/packet.go create mode 100644 proxy/anytls/padding.go create mode 100644 proxy/anytls/session.go create mode 100644 proxy/anytls/stream.go create mode 100644 proxy/anytls/util.go diff --git a/feature.go b/feature.go index 21b49fc..9489a92 100644 --- a/feature.go +++ b/feature.go @@ -5,6 +5,7 @@ import ( // _ "github.com/nadoo/glider/service/xxx" // comment out the protocols you don't need to make the compiled binary smaller. + _ "github.com/nadoo/glider/proxy/anytls" _ "github.com/nadoo/glider/proxy/http" _ "github.com/nadoo/glider/proxy/kcp" _ "github.com/nadoo/glider/proxy/mixed" diff --git a/proxy/anytls/addr.go b/proxy/anytls/addr.go new file mode 100644 index 0000000..0f498de --- /dev/null +++ b/proxy/anytls/addr.go @@ -0,0 +1,102 @@ +package anytls + +import ( + "net" + "net/netip" + "strconv" + "strings" + + "github.com/nadoo/glider/pkg/socks" +) + +const ( + uotAddrIPv4 = 0x00 + uotAddrIPv6 = 0x01 + uotAddrFQDN = 0x02 +) + +type addrPort struct { + Addr netip.Addr + FQDN string + Port uint16 +} + +func (a addrPort) IsValid() bool { + return a.Addr.IsValid() || a.FQDN != "" +} + +func (a addrPort) UDPAddr() *net.UDPAddr { + if a.Addr.IsValid() { + return net.UDPAddrFromAddrPort(netip.AddrPortFrom(a.Addr, a.Port)) + } + addr, _ := net.ResolveUDPAddr("udp", net.JoinHostPort(a.FQDN, strconv.Itoa(int(a.Port)))) + return addr +} + +func parseAddrPort(value string) addrPort { + host, port, err := net.SplitHostPort(value) + if err != nil { + return addrPort{} + } + + portNum, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return addrPort{} + } + + if ip, err := netip.ParseAddr(host); err == nil { + return addrPort{Addr: ip, Port: uint16(portNum)} + } + + return addrPort{FQDN: strings.TrimSuffix(host, "."), Port: uint16(portNum)} +} + +func addrPortLen(addr addrPort) int { + if !addr.IsValid() { + return 1 + } + if addr.Addr.IsValid() { + if addr.Addr.Is4() { + return 1 + 4 + 2 + } + return 1 + 16 + 2 + } + return 1 + 1 + len(addr.FQDN) + 2 +} + +func serializeAddrPort(addr addrPort) []byte { + if !addr.IsValid() { + return []byte{0} + } + + if addr.Addr.IsValid() { + if addr.Addr.Is4() { + buf := make([]byte, 1+4+2) + buf[0] = uotAddrIPv4 + copy(buf[1:5], addr.Addr.AsSlice()) + binaryPort(buf[5:7], addr.Port) + return buf + } + buf := make([]byte, 1+16+2) + buf[0] = uotAddrIPv6 + copy(buf[1:17], addr.Addr.AsSlice()) + binaryPort(buf[17:19], addr.Port) + return buf + } + + buf := make([]byte, 1+1+len(addr.FQDN)+2) + buf[0] = uotAddrFQDN + buf[1] = byte(len(addr.FQDN)) + copy(buf[2:2+len(addr.FQDN)], addr.FQDN) + binaryPort(buf[2+len(addr.FQDN):], addr.Port) + return buf +} + +func binaryPort(dst []byte, port uint16) { + dst[0] = byte(port >> 8) + dst[1] = byte(port) +} + +func socksAddr(addr string) socks.Addr { + return socks.ParseAddr(addr) +} diff --git a/proxy/anytls/anytls.go b/proxy/anytls/anytls.go new file mode 100644 index 0000000..cf4c46c --- /dev/null +++ b/proxy/anytls/anytls.go @@ -0,0 +1,254 @@ +package anytls + +import ( + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "fmt" + "net" + "net/url" + "os" + "strconv" + "strings" + "time" + + "github.com/nadoo/glider/proxy" +) + +const ( + defaultIdleSessionCheckInterval = 30 * time.Second + defaultIdleSessionTimeout = 30 * time.Second + defaultMinIdleSession = 5 + protocolVersion = "2" + programVersionName = "glider-anytls" + uotMagicAddress = "sp.v2.udp-over-tcp.arpa" +) + +type AnyTLS struct { + dialer proxy.Dialer + addr string + + passwordSHA256 [32]byte + tlsConfig *tls.Config + + client *Client +} + +func init() { + proxy.RegisterDialer("anytls", NewAnyTLSDialer) + proxy.AddUsage("anytls", ` +AnyTLS client scheme: + anytls://password@host:port[?sni=SERVERNAME][&insecure=1][&cert=PATH] + anytls://password@host:port[?serverName=SERVERNAME][&skipVerify=true][&cert=PATH] + anytls://password@host:port[?minIdleSession=5][&idleSessionCheckInterval=30s][&idleSessionTimeout=30s] +`) +} + +func NewAnyTLSDialer(s string, d proxy.Dialer) (proxy.Dialer, error) { + a, err := NewAnyTLS(s, d) + if err != nil { + return nil, err + } + return a, nil +} + +func NewAnyTLS(s string, d proxy.Dialer) (*AnyTLS, error) { + u, err := url.Parse(s) + if err != nil { + return nil, fmt.Errorf("[anytls] parse url err: %w", err) + } + + password := "" + if u.User != nil { + password = u.User.Username() + } + if password == "" { + return nil, fmt.Errorf("[anytls] password must be specified") + } + + addr := u.Host + if addr == "" { + return nil, fmt.Errorf("[anytls] server address must be specified") + } + if _, port, _ := net.SplitHostPort(addr); port == "" { + addr = net.JoinHostPort(addr, "443") + } + + query := u.Query() + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("[anytls] invalid server address %q: %w", addr, err) + } + + serverName := firstNonEmpty(query.Get("sni"), query.Get("serverName")) + if serverName == "" { + serverName = host + } + if net.ParseIP(serverName) != nil { + serverName = "" + } + + skipVerify := isTrue(query.Get("insecure")) || isTrue(query.Get("skipVerify")) + certFile := query.Get("cert") + + tlsConfig := &tls.Config{ + ServerName: serverName, + InsecureSkipVerify: skipVerify, + MinVersion: tls.VersionTLS12, + } + + if certFile != "" { + certData, err := os.ReadFile(certFile) + if err != nil { + return nil, fmt.Errorf("[anytls] read cert file error: %w", err) + } + + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(certData) { + return nil, fmt.Errorf("[anytls] can not append cert file: %s", certFile) + } + tlsConfig.RootCAs = certPool + } + + idleCheckInterval := defaultIdleSessionCheckInterval + if value := query.Get("idleSessionCheckInterval"); value != "" { + idleCheckInterval, err = time.ParseDuration(value) + if err != nil { + return nil, fmt.Errorf("[anytls] invalid idleSessionCheckInterval: %w", err) + } + } + + idleTimeout := defaultIdleSessionTimeout + if value := query.Get("idleSessionTimeout"); value != "" { + idleTimeout, err = time.ParseDuration(value) + if err != nil { + return nil, fmt.Errorf("[anytls] invalid idleSessionTimeout: %w", err) + } + } + + minIdleSession := defaultMinIdleSession + if value := query.Get("minIdleSession"); value != "" { + minIdleSession, err = strconv.Atoi(value) + if err != nil { + return nil, fmt.Errorf("[anytls] invalid minIdleSession: %w", err) + } + if minIdleSession < 0 { + minIdleSession = 0 + } + } + + a := &AnyTLS{ + dialer: d, + addr: addr, + passwordSHA256: sha256.Sum256([]byte(password)), + tlsConfig: tlsConfig, + } + a.client = NewClient(a.createAuthenticatedConn, idleCheckInterval, idleTimeout, minIdleSession) + + return a, nil +} + +func (a *AnyTLS) Addr() string { + if a.addr == "" { + return a.dialer.Addr() + } + return a.addr +} + +func (a *AnyTLS) Dial(network, addr string) (net.Conn, error) { + stream, err := a.client.CreateStream() + if err != nil { + return nil, err + } + + target := socksAddr(addr) + if target == nil { + stream.Close() + return nil, fmt.Errorf("[anytls] invalid target address: %s", addr) + } + + if _, err := stream.Write(target); err != nil { + stream.Close() + return nil, err + } + + return stream, nil +} + +func (a *AnyTLS) DialUDP(network, addr string) (net.PacketConn, error) { + target := parseAddrPort(addr) + if !target.IsValid() { + return nil, fmt.Errorf("[anytls] invalid udp target address: %s", addr) + } + + stream, err := a.client.CreateStream() + if err != nil { + return nil, err + } + + proxyTarget := socksAddr(net.JoinHostPort(uotMagicAddress, "0")) + if proxyTarget == nil { + stream.Close() + return nil, fmt.Errorf("[anytls] invalid uot target") + } + + if _, err := stream.Write(proxyTarget); err != nil { + stream.Close() + return nil, err + } + + pc := NewPktConn(stream, target) + if err := pc.writeRequest(); err != nil { + stream.Close() + return nil, err + } + + return pc, nil +} + +func (a *AnyTLS) createAuthenticatedConn() (net.Conn, error) { + rawConn, err := a.dialer.Dial("tcp", a.addr) + if err != nil { + return nil, err + } + + tlsConn := tls.Client(rawConn, a.tlsConfig) + if err := tlsConn.Handshake(); err != nil { + rawConn.Close() + return nil, err + } + + paddingLen := 0 + if padding := loadPaddingFactory(); padding != nil { + sizes := padding.GenerateRecordPayloadSizes(0) + if len(sizes) > 0 && sizes[0] > 0 { + paddingLen = sizes[0] + } + } + + auth := make([]byte, 32+2+paddingLen) + copy(auth[:32], a.passwordSHA256[:]) + auth[32] = byte(paddingLen >> 8) + auth[33] = byte(paddingLen) + + if _, err := tlsConn.Write(auth); err != nil { + tlsConn.Close() + return nil, err + } + + return tlsConn, nil +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if value != "" { + return value + } + } + return "" +} + +func isTrue(value string) bool { + value = strings.ToLower(value) + return value == "1" || value == "true" || value == "yes" || value == "on" +} diff --git a/proxy/anytls/client.go b/proxy/anytls/client.go new file mode 100644 index 0000000..3a27077 --- /dev/null +++ b/proxy/anytls/client.go @@ -0,0 +1,211 @@ +package anytls + +import ( + "io" + "math" + "net" + "sort" + "sync" + "sync/atomic" + "time" +) + +type Client struct { + dialOut func() (net.Conn, error) + + sessionCounter atomic.Uint64 + + idleSessionLock sync.Mutex + idleSessions []*Session + + sessionsLock sync.Mutex + sessions map[uint64]*Session + + idleSessionTimeout time.Duration + minIdleSession int + closed atomic.Bool + stopCleanup chan struct{} +} + +func NewClient(dialOut func() (net.Conn, error), idleSessionCheckInterval, idleSessionTimeout time.Duration, minIdleSession int) *Client { + if idleSessionCheckInterval <= 5*time.Second { + idleSessionCheckInterval = defaultIdleSessionCheckInterval + } + if idleSessionTimeout <= 5*time.Second { + idleSessionTimeout = defaultIdleSessionTimeout + } + + c := &Client{ + dialOut: dialOut, + sessions: make(map[uint64]*Session), + idleSessionTimeout: idleSessionTimeout, + minIdleSession: minIdleSession, + stopCleanup: make(chan struct{}), + } + + go c.idleCleanupLoop(idleSessionCheckInterval) + return c +} + +func (c *Client) CreateStream() (*Stream, error) { + if c.closed.Load() { + return nil, io.ErrClosedPipe + } + + var ( + session *Session + err error + ) + + session = c.getIdleSession() + if session == nil { + session, err = c.createSession() + } + if session == nil { + if err == nil { + err = io.ErrClosedPipe + } + return nil, err + } + + stream, err := session.OpenStream() + if err != nil { + session.Close() + return nil, err + } + + stream.dieHook = func() { + if c.closed.Load() || session.IsClosed() { + session.Close() + return + } + + c.idleSessionLock.Lock() + session.idleSince = time.Now() + c.idleSessions = append(c.idleSessions, session) + sort.Slice(c.idleSessions, func(i, j int) bool { + return c.idleSessions[i].seq > c.idleSessions[j].seq + }) + c.idleSessionLock.Unlock() + } + + return stream, nil +} + +func (c *Client) Close() error { + if !c.closed.CompareAndSwap(false, true) { + return io.ErrClosedPipe + } + + close(c.stopCleanup) + + c.sessionsLock.Lock() + sessions := make([]*Session, 0, len(c.sessions)) + for _, session := range c.sessions { + sessions = append(sessions, session) + } + c.sessions = make(map[uint64]*Session) + c.sessionsLock.Unlock() + + for _, session := range sessions { + session.Close() + } + + return nil +} + +func (c *Client) getIdleSession() *Session { + c.idleSessionLock.Lock() + defer c.idleSessionLock.Unlock() + + for len(c.idleSessions) > 0 { + session := c.idleSessions[0] + c.idleSessions = c.idleSessions[1:] + if session != nil && !session.IsClosed() { + return session + } + } + + return nil +} + +func (c *Client) createSession() (*Session, error) { + underlying, err := c.dialOut() + if err != nil { + return nil, err + } + + session := NewClientSession(underlying) + session.seq = c.sessionCounter.Add(1) + session.dieHook = func() { + c.idleSessionLock.Lock() + filtered := c.idleSessions[:0] + for _, idle := range c.idleSessions { + if idle != session { + filtered = append(filtered, idle) + } + } + c.idleSessions = filtered + c.idleSessionLock.Unlock() + + c.sessionsLock.Lock() + delete(c.sessions, session.seq) + c.sessionsLock.Unlock() + } + + c.sessionsLock.Lock() + c.sessions[session.seq] = session + c.sessionsLock.Unlock() + + session.Run() + return session, nil +} + +func (c *Client) idleCleanupLoop(interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + c.idleCleanup(time.Now().Add(-c.idleSessionTimeout)) + case <-c.stopCleanup: + return + } + } +} + +func (c *Client) idleCleanup(expireBefore time.Time) { + var toClose []*Session + + c.idleSessionLock.Lock() + activeCount := 0 + kept := c.idleSessions[:0] + for _, session := range c.idleSessions { + if session == nil || session.IsClosed() { + continue + } + if !session.idleSince.Before(expireBefore) { + activeCount++ + kept = append(kept, session) + continue + } + if activeCount < max(c.minIdleSession, 0) { + activeCount++ + session.idleSince = time.Now() + kept = append(kept, session) + continue + } + toClose = append(toClose, session) + } + c.idleSessions = kept + c.idleSessionLock.Unlock() + + for _, session := range toClose { + session.Close() + } +} + +func max(a, b int) int { + return int(math.Max(float64(a), float64(b))) +} diff --git a/proxy/anytls/packet.go b/proxy/anytls/packet.go new file mode 100644 index 0000000..45396f0 --- /dev/null +++ b/proxy/anytls/packet.go @@ -0,0 +1,76 @@ +package anytls + +import ( + "encoding/binary" + "errors" + "io" + "net" +) + +type PktConn struct { + net.Conn + target addrPort + init bool +} + +func NewPktConn(conn net.Conn, target addrPort) *PktConn { + return &PktConn{Conn: conn, target: target} +} + +func (pc *PktConn) writeRequest() error { + if pc.init { + return nil + } + + req := make([]byte, 0, 1+addrPortLen(pc.target)) + req = append(req, 1) + req = append(req, serializeAddrPort(pc.target)...) + if _, err := pc.Conn.Write(req); err != nil { + return err + } + pc.init = true + return nil +} + +func (pc *PktConn) ReadFrom(b []byte) (int, net.Addr, error) { + if len(b) < 2 { + return 0, pc.target.UDPAddr(), errors.New("buf size is not enough") + } + + if _, err := io.ReadFull(pc.Conn, b[:2]); err != nil { + return 0, pc.target.UDPAddr(), err + } + length := int(binary.BigEndian.Uint16(b[:2])) + if len(b) < length { + return 0, pc.target.UDPAddr(), errors.New("buf size is not enough") + } + + n, err := io.ReadFull(pc.Conn, b[:length]) + return n, pc.target.UDPAddr(), err +} + +func (pc *PktConn) WriteTo(b []byte, addr net.Addr) (int, error) { + target := pc.target + if addr != nil { + target = parseAddrPort(addr.String()) + } + if !target.IsValid() { + return 0, errors.New("invalid addr") + } + + if !pc.init { + if err := pc.writeRequest(); err != nil { + return 0, err + } + } + + frame := make([]byte, 2+len(b)) + binary.BigEndian.PutUint16(frame[:2], uint16(len(b))) + copy(frame[2:], b) + + n, err := pc.Conn.Write(frame) + if n > 2 { + return n - 2, err + } + return 0, err +} diff --git a/proxy/anytls/padding.go b/proxy/anytls/padding.go new file mode 100644 index 0000000..7a72757 --- /dev/null +++ b/proxy/anytls/padding.go @@ -0,0 +1,127 @@ +package anytls + +import ( + "crypto/md5" + "crypto/rand" + "fmt" + "math/big" + "strconv" + "strings" + "sync/atomic" +) + +const ( + checkMark = -1 +) + +var defaultPaddingScheme = []byte(`stop=8 +0=30-30 +1=100-400 +2=400-500,c,500-1000,c,500-1000,c,500-1000,c,500-1000 +3=9-9,500-1000 +4=500-1000 +5=500-1000 +6=500-1000 +7=500-1000`) + +type PaddingFactory struct { + scheme map[string]string + RawScheme []byte + Stop uint32 + Md5 string +} + +var defaultPaddingFactory atomic.Pointer[PaddingFactory] + +func init() { + UpdatePaddingScheme(defaultPaddingScheme) +} + +func UpdatePaddingScheme(rawScheme []byte) bool { + padding := NewPaddingFactory(rawScheme) + if padding == nil { + return false + } + defaultPaddingFactory.Store(padding) + return true +} + +func NewPaddingFactory(rawScheme []byte) *PaddingFactory { + scheme := stringMapFromBytes(rawScheme) + if len(scheme) == 0 { + return nil + } + + stop, err := strconv.Atoi(scheme["stop"]) + if err != nil || stop <= 0 { + return nil + } + + rawCopy := append([]byte(nil), rawScheme...) + return &PaddingFactory{ + scheme: scheme, + RawScheme: rawCopy, + Stop: uint32(stop), + Md5: fmt.Sprintf("%x", md5.Sum(rawCopy)), + } +} + +func loadPaddingFactory() *PaddingFactory { + return defaultPaddingFactory.Load() +} + +func (p *PaddingFactory) GenerateRecordPayloadSizes(pkt uint32) []int { + if p == nil { + return nil + } + + raw, ok := p.scheme[strconv.Itoa(int(pkt))] + if !ok { + return nil + } + + parts := strings.Split(raw, ",") + sizes := make([]int, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "c" { + sizes = append(sizes, checkMark) + continue + } + + bounds := strings.SplitN(part, "-", 2) + if len(bounds) != 2 { + continue + } + + minSize, err := strconv.ParseInt(bounds[0], 10, 64) + if err != nil { + continue + } + maxSize, err := strconv.ParseInt(bounds[1], 10, 64) + if err != nil { + continue + } + + if minSize > maxSize { + minSize, maxSize = maxSize, minSize + } + if minSize <= 0 || maxSize <= 0 { + continue + } + + if minSize == maxSize { + sizes = append(sizes, int(minSize)) + continue + } + + delta, err := rand.Int(rand.Reader, big.NewInt(maxSize-minSize+1)) + if err != nil { + sizes = append(sizes, int(minSize)) + continue + } + sizes = append(sizes, int(minSize+delta.Int64())) + } + + return sizes +} diff --git a/proxy/anytls/session.go b/proxy/anytls/session.go new file mode 100644 index 0000000..5091bc6 --- /dev/null +++ b/proxy/anytls/session.go @@ -0,0 +1,389 @@ +package anytls + +import ( + "encoding/binary" + "fmt" + "io" + "slices" + "strconv" + "sync" + "sync/atomic" + "time" +) + +const ( + cmdWaste byte = iota + cmdSYN + cmdPSH + cmdFIN + cmdSettings + cmdAlert + cmdUpdatePaddingScheme + cmdSYNACK + cmdHeartRequest + cmdHeartResponse + cmdServerSettings +) + +const headerOverHeadSize = 1 + 4 + 2 + +type frame struct { + cmd byte + sid uint32 + data []byte +} + +func newFrame(cmd byte, sid uint32) frame { + return frame{cmd: cmd, sid: sid} +} + +type rawHeader [headerOverHeadSize]byte + +func (h rawHeader) Cmd() byte { + return h[0] +} + +func (h rawHeader) StreamID() uint32 { + return binary.BigEndian.Uint32(h[1:5]) +} + +func (h rawHeader) Length() uint16 { + return binary.BigEndian.Uint16(h[5:7]) +} + +type Session struct { + conn io.ReadWriteCloser + connLock sync.Mutex + + streams map[uint32]*Stream + streamID atomic.Uint32 + streamLock sync.RWMutex + + dieOnce sync.Once + die chan struct{} + dieHook func() + + seq uint64 + idleSince time.Time + + peerVersion byte + isClient bool + sendPadding bool + buffering bool + buffer []byte + pktCounter atomic.Uint32 +} + +func NewClientSession(conn io.ReadWriteCloser) *Session { + return &Session{ + conn: conn, + streams: make(map[uint32]*Stream), + die: make(chan struct{}), + isClient: true, + sendPadding: true, + } +} + +func (s *Session) Run() { + settings := []byte("v=" + protocolVersion + "\nclient=" + programVersionName + "\npadding-md5=" + loadPaddingFactory().Md5) + frame := newFrame(cmdSettings, 0) + frame.data = settings + s.buffering = true + _, _ = s.writeControlFrame(frame) + go s.recvLoop() +} + +func (s *Session) IsClosed() bool { + select { + case <-s.die: + return true + default: + return false + } +} + +func (s *Session) Close() error { + var once bool + s.dieOnce.Do(func() { + close(s.die) + once = true + }) + if !once { + return io.ErrClosedPipe + } + + if s.dieHook != nil { + s.dieHook() + s.dieHook = nil + } + + s.streamLock.Lock() + for _, stream := range s.streams { + stream.closeLocally() + } + s.streams = make(map[uint32]*Stream) + s.streamLock.Unlock() + + return s.conn.Close() +} + +func (s *Session) OpenStream() (*Stream, error) { + if s.IsClosed() { + return nil, io.ErrClosedPipe + } + + sid := s.streamID.Add(1) + stream := newStream(sid, s) + + if _, err := s.writeControlFrame(newFrame(cmdSYN, sid)); err != nil { + return nil, err + } + s.buffering = false + + s.streamLock.Lock() + s.streams[sid] = stream + s.streamLock.Unlock() + + return stream, nil +} + +func (s *Session) recvLoop() error { + defer s.Close() + + var hdr rawHeader + for { + if s.IsClosed() { + return io.ErrClosedPipe + } + + if _, err := io.ReadFull(s.conn, hdr[:]); err != nil { + return err + } + + sid := hdr.StreamID() + length := int(hdr.Length()) + + switch hdr.Cmd() { + case cmdPSH: + if length == 0 { + continue + } + payload := make([]byte, length) + if _, err := io.ReadFull(s.conn, payload); err != nil { + return err + } + s.streamLock.RLock() + stream := s.streams[sid] + s.streamLock.RUnlock() + if stream != nil { + stream.feed(payload) + } + case cmdSYNACK: + payload, err := s.readPayload(length) + if err != nil { + return err + } + if len(payload) == 0 { + continue + } + s.streamLock.RLock() + stream := s.streams[sid] + s.streamLock.RUnlock() + if stream != nil { + stream.closeWithError(fmt.Errorf("remote: %s", string(payload))) + } + case cmdFIN: + s.streamLock.Lock() + stream := s.streams[sid] + delete(s.streams, sid) + s.streamLock.Unlock() + if stream != nil { + stream.closeLocally() + } + case cmdWaste: + if _, err := s.readPayload(length); err != nil { + return err + } + case cmdAlert: + payload, err := s.readPayload(length) + if err != nil { + return err + } + if len(payload) == 0 { + return io.ErrUnexpectedEOF + } + return fmt.Errorf("[anytls] alert from server: %s", string(payload)) + case cmdUpdatePaddingScheme: + payload, err := s.readPayload(length) + if err != nil { + return err + } + if len(payload) > 0 { + UpdatePaddingScheme(payload) + } + case cmdHeartRequest: + if _, err := s.writeControlFrame(newFrame(cmdHeartResponse, sid)); err != nil { + return err + } + case cmdHeartResponse: + if _, err := s.readPayload(length); err != nil { + return err + } + case cmdServerSettings: + payload, err := s.readPayload(length) + if err != nil { + return err + } + if len(payload) == 0 { + continue + } + if version, err := strconv.Atoi(stringMapFromBytes(payload)["v"]); err == nil { + s.peerVersion = byte(version) + } + default: + if _, err := s.readPayload(length); err != nil { + return err + } + } + } +} + +func (s *Session) readPayload(length int) ([]byte, error) { + if length == 0 { + return nil, nil + } + payload := make([]byte, length) + _, err := io.ReadFull(s.conn, payload) + return payload, err +} + +func (s *Session) streamClosed(sid uint32) error { + if s.IsClosed() { + return io.ErrClosedPipe + } + _, err := s.writeControlFrame(newFrame(cmdFIN, sid)) + s.streamLock.Lock() + delete(s.streams, sid) + s.streamLock.Unlock() + return err +} + +func (s *Session) writeDataFrame(sid uint32, data []byte) (int, error) { + buffer := make([]byte, headerOverHeadSize+len(data)) + buffer[0] = cmdPSH + binary.BigEndian.PutUint32(buffer[1:5], sid) + binary.BigEndian.PutUint16(buffer[5:7], uint16(len(data))) + copy(buffer[7:], data) + + if _, err := s.writeConn(buffer); err != nil { + return 0, err + } + return len(data), nil +} + +func (s *Session) writeControlFrame(frame frame) (int, error) { + buffer := make([]byte, headerOverHeadSize+len(frame.data)) + buffer[0] = frame.cmd + binary.BigEndian.PutUint32(buffer[1:5], frame.sid) + binary.BigEndian.PutUint16(buffer[5:7], uint16(len(frame.data))) + copy(buffer[7:], frame.data) + + if conn, ok := s.conn.(interface{ SetWriteDeadline(time.Time) error }); ok { + _ = conn.SetWriteDeadline(time.Now().Add(5 * time.Second)) + defer conn.SetWriteDeadline(time.Time{}) + } + + if _, err := s.writeConn(buffer); err != nil { + s.Close() + return 0, err + } + + return len(frame.data), nil +} + +func (s *Session) writeConn(b []byte) (int, error) { + s.connLock.Lock() + defer s.connLock.Unlock() + + if s.buffering { + s.buffer = append(s.buffer, b...) + return len(b), nil + } + if len(s.buffer) > 0 { + b = slices.Concat(s.buffer, b) + s.buffer = nil + } + + if s.sendPadding { + padding := loadPaddingFactory() + if padding != nil { + pkt := s.pktCounter.Add(1) + if pkt < padding.Stop { + return s.writeWithPadding(b, padding.GenerateRecordPayloadSizes(pkt)) + } + } + s.sendPadding = false + } + + return s.conn.Write(b) +} + +func (s *Session) writeWithPadding(payload []byte, sizes []int) (int, error) { + n := 0 + b := payload + for _, size := range sizes { + remain := len(b) + if size == checkMark { + if remain == 0 { + break + } + continue + } + + switch { + case remain > size: + written, err := s.conn.Write(b[:size]) + n += written + if err != nil { + return 0, err + } + b = b[size:] + case remain > 0: + paddingLen := size - remain - headerOverHeadSize + if paddingLen > 0 { + padding := make([]byte, headerOverHeadSize+paddingLen) + padding[0] = cmdWaste + binary.BigEndian.PutUint16(padding[5:7], uint16(paddingLen)) + b = slices.Concat(b, padding) + } + written, err := s.conn.Write(b) + n += min(written, remain) + if err != nil { + return 0, err + } + b = nil + case remain == 0: + padding := make([]byte, headerOverHeadSize+size) + padding[0] = cmdWaste + binary.BigEndian.PutUint16(padding[5:7], uint16(size)) + if _, err := s.conn.Write(padding); err != nil { + return 0, err + } + } + } + + if len(b) == 0 { + return n, nil + } + + written, err := s.conn.Write(b) + n += min(written, len(b)) + return n, err +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/proxy/anytls/stream.go b/proxy/anytls/stream.go new file mode 100644 index 0000000..ecdfd63 --- /dev/null +++ b/proxy/anytls/stream.go @@ -0,0 +1,187 @@ +package anytls + +import ( + "bytes" + "io" + "net" + "os" + "sync" + "sync/atomic" + "time" +) + +type Stream struct { + id uint32 + sess *Session + + mu sync.Mutex + readBuf bytes.Buffer + readNotify chan struct{} + closed chan struct{} + dieOnce sync.Once + dieErr error + dieHook func() + + readDeadline atomic.Value + writeDeadline atomic.Value +} + +func newStream(id uint32, sess *Session) *Stream { + return &Stream{ + id: id, + sess: sess, + readNotify: make(chan struct{}, 1), + closed: make(chan struct{}), + } +} + +func (s *Stream) Read(b []byte) (int, error) { + for { + s.mu.Lock() + if s.readBuf.Len() > 0 { + n, _ := s.readBuf.Read(b) + err := error(nil) + if n == 0 && s.dieErr != nil { + err = s.dieErr + } + s.mu.Unlock() + return n, err + } + err := s.dieErr + s.mu.Unlock() + + if err != nil { + return 0, err + } + + deadline := s.loadDeadline(&s.readDeadline) + if deadline.IsZero() { + select { + case <-s.readNotify: + case <-s.closed: + } + continue + } + + wait := time.Until(deadline) + if wait <= 0 { + return 0, os.ErrDeadlineExceeded + } + timer := time.NewTimer(wait) + select { + case <-s.readNotify: + timer.Stop() + case <-s.closed: + timer.Stop() + case <-timer.C: + return 0, os.ErrDeadlineExceeded + } + } +} + +func (s *Stream) Write(b []byte) (int, error) { + if deadline := s.loadDeadline(&s.writeDeadline); !deadline.IsZero() && time.Until(deadline) <= 0 { + return 0, os.ErrDeadlineExceeded + } + + s.mu.Lock() + err := s.dieErr + s.mu.Unlock() + if err != nil { + return 0, err + } + + return s.sess.writeDataFrame(s.id, b) +} + +func (s *Stream) Close() error { + return s.closeWithError(io.ErrClosedPipe) +} + +func (s *Stream) closeLocally() { + var once bool + s.dieOnce.Do(func() { + s.mu.Lock() + s.dieErr = net.ErrClosed + s.mu.Unlock() + close(s.closed) + once = true + }) + if once && s.dieHook != nil { + s.dieHook() + s.dieHook = nil + } +} + +func (s *Stream) closeWithError(err error) error { + var once bool + s.dieOnce.Do(func() { + s.mu.Lock() + s.dieErr = err + s.mu.Unlock() + close(s.closed) + once = true + }) + if !once { + s.mu.Lock() + defer s.mu.Unlock() + return s.dieErr + } + if s.dieHook != nil { + s.dieHook() + s.dieHook = nil + } + return s.sess.streamClosed(s.id) +} + +func (s *Stream) SetReadDeadline(t time.Time) error { + s.readDeadline.Store(t) + return nil +} + +func (s *Stream) SetWriteDeadline(t time.Time) error { + s.writeDeadline.Store(t) + return nil +} + +func (s *Stream) SetDeadline(t time.Time) error { + _ = s.SetReadDeadline(t) + _ = s.SetWriteDeadline(t) + return nil +} + +func (s *Stream) LocalAddr() net.Addr { + if conn, ok := s.sess.conn.(interface{ LocalAddr() net.Addr }); ok { + return conn.LocalAddr() + } + return nil +} + +func (s *Stream) RemoteAddr() net.Addr { + if conn, ok := s.sess.conn.(interface{ RemoteAddr() net.Addr }); ok { + return conn.RemoteAddr() + } + return nil +} + +func (s *Stream) feed(data []byte) { + s.mu.Lock() + defer s.mu.Unlock() + if s.dieErr != nil { + return + } + _, _ = s.readBuf.Write(data) + select { + case s.readNotify <- struct{}{}: + default: + } +} + +func (s *Stream) loadDeadline(value *atomic.Value) time.Time { + v := value.Load() + if v == nil { + return time.Time{} + } + deadline, _ := v.(time.Time) + return deadline +} diff --git a/proxy/anytls/util.go b/proxy/anytls/util.go new file mode 100644 index 0000000..c9ccc03 --- /dev/null +++ b/proxy/anytls/util.go @@ -0,0 +1,19 @@ +package anytls + +import "strings" + +func stringMapFromBytes(raw []byte) map[string]string { + result := make(map[string]string) + for _, line := range strings.Split(string(raw), "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + key, value, ok := strings.Cut(line, "=") + if !ok { + continue + } + result[key] = value + } + return result +} diff --git a/proxy/vmess/client.go b/proxy/vmess/client.go index e03afd2..27c7b3e 100644 --- a/proxy/vmess/client.go +++ b/proxy/vmess/client.go @@ -108,7 +108,7 @@ func NewClient(uuidStr, security string, alterID int, aead bool) (*Client, error case "zero": c.security = SecurityNone c.opt = OptBasicFormat - case "": + case "", "auto": c.security = SecurityChacha20Poly1305 if runtime.GOARCH == "amd64" || runtime.GOARCH == "s390x" || runtime.GOARCH == "arm64" { c.security = SecurityAES128GCM