From 1b484cca8f7077fd5459c5969d0914a919e82e99 Mon Sep 17 00:00:00 2001 From: mzz Date: Thu, 3 Dec 2020 15:27:52 +0800 Subject: [PATCH] feat(vmess): support forward/udp (#199) --- proxy/vmess/client.go | 16 +++++++++------- proxy/vmess/packet.go | 22 ++++++++++++++++++++++ proxy/vmess/vmess.go | 12 ++++++++++-- 3 files changed, 41 insertions(+), 9 deletions(-) create mode 100644 proxy/vmess/packet.go diff --git a/proxy/vmess/client.go b/proxy/vmess/client.go index d823192..b6988e5 100644 --- a/proxy/vmess/client.go +++ b/proxy/vmess/client.go @@ -35,9 +35,11 @@ const ( ) // CMD types +type CmdType byte + const ( - CmdTCP byte = 1 - CmdUDP byte = 2 + CmdTCP CmdType = 1 + CmdUDP CmdType = 2 ) // Client is a vmess client. @@ -107,7 +109,7 @@ func NewClient(uuidStr, security string, alterID int) (*Client, error) { } // NewConn returns a new vmess conn. -func (c *Client) NewConn(rc net.Conn, target string) (*Conn, error) { +func (c *Client) NewConn(rc net.Conn, target string, cmd CmdType) (*Conn, error) { r := rand.Intn(c.count) conn := &Conn{user: c.users[r], opt: c.opt, security: c.security, Conn: rc} @@ -135,7 +137,7 @@ func (c *Client) NewConn(rc net.Conn, target string) (*Conn, error) { } // Request - err = conn.Request() + err = conn.Request(cmd) if err != nil { return nil, err } @@ -158,7 +160,7 @@ func (c *Conn) Auth() error { } // Request sends request to server. -func (c *Conn) Request() error { +func (c *Conn) Request(cmd CmdType) error { buf := pool.GetBytesBuffer() defer pool.PutBytesBuffer(buf) @@ -174,8 +176,8 @@ func (c *Conn) Request() error { pSec := byte(paddingLen<<4) | c.security // P(4bit) and Sec(4bit) buf.WriteByte(pSec) - buf.WriteByte(0) // reserved - buf.WriteByte(CmdTCP) // cmd + buf.WriteByte(0) // reserved + buf.WriteByte(byte(cmd)) // cmd // target err := binary.Write(buf, binary.BigEndian, uint16(c.port)) // port diff --git a/proxy/vmess/packet.go b/proxy/vmess/packet.go new file mode 100644 index 0000000..b44e30f --- /dev/null +++ b/proxy/vmess/packet.go @@ -0,0 +1,22 @@ +package vmess + +import ( + "net" +) + +// PktConn is a udp Packet.Conn. +type PktConn struct{ net.Conn } + +// NewPktConn returns a PktConn. +func NewPktConn(c net.Conn) *PktConn { return &PktConn{Conn: c} } + +// ReadFrom implements the necessary function of net.PacketConn. +func (pc *PktConn) ReadFrom(b []byte) (int, net.Addr, error) { + n, err := pc.Read(b) + return n, nil, err +} + +// WriteTo implements the necessary function of net.PacketConn. +func (pc *PktConn) WriteTo(b []byte, addr net.Addr) (int, error) { + return pc.Write(b) +} diff --git a/proxy/vmess/vmess.go b/proxy/vmess/vmess.go index ae5f3fe..0d0791a 100644 --- a/proxy/vmess/vmess.go +++ b/proxy/vmess/vmess.go @@ -92,10 +92,18 @@ func (s *VMess) Dial(network, addr string) (net.Conn, error) { return nil, err } - return s.client.NewConn(rc, addr) + return s.client.NewConn(rc, addr, CmdTCP) } // DialUDP connects to the given address via the proxy. func (s *VMess) DialUDP(network, addr string) (net.PacketConn, net.Addr, error) { - return nil, nil, proxy.ErrNotSupported + rc, err := s.dialer.Dial("tcp", s.addr) + if err != nil { + return nil, nil, err + } + rc, err = s.client.NewConn(rc, addr, CmdUDP) + if err != nil { + return nil, nil, err + } + return NewPktConn(rc), nil, err }