From 121f3680e22d709267442ac9ed4d1c1d83ee0108 Mon Sep 17 00:00:00 2001 From: mzz2017 Date: Mon, 7 Dec 2020 14:56:01 +0800 Subject: [PATCH] feat: support socks4a --- proxy/socks4/socks4.go | 38 +++++++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/proxy/socks4/socks4.go b/proxy/socks4/socks4.go index f5aff9e..87e626f 100644 --- a/proxy/socks4/socks4.go +++ b/proxy/socks4/socks4.go @@ -25,12 +25,14 @@ const ( // SOCKS4 is a base socks4 struct. type SOCKS4 struct { - dialer proxy.Dialer - addr string + dialer proxy.Dialer + addr string + socks4a bool } func init() { proxy.RegisterDialer("socks4", NewSocks4Dialer) + proxy.RegisterDialer("socks4a", NewSocks4Dialer) } // NewSOCKS4 returns a socks4 proxy. @@ -42,8 +44,9 @@ func NewSOCKS4(s string, dialer proxy.Dialer) (*SOCKS4, error) { } h := &SOCKS4{ - dialer: dialer, - addr: u.Host, + dialer: dialer, + addr: u.Host, + socks4a: u.Scheme == "socks4a", } return h, nil @@ -123,19 +126,36 @@ func (s *SOCKS4) connect(conn net.Conn, target string) error { return errors.New("[socks4] port number out of range: " + portStr) } - ip, err := s.lookupIP(host) - if err != nil { - return err + bufSize := 8 + var ip net.IP + if ip = net.ParseIP(host); ip == nil { + if s.socks4a { + // The client should set the first three bytes of DSTIP to NULL + // and the last byte to a non-zero value. + ip = net.ParseIP("0.0.0.1") + bufSize += len(host) + 1 + } else { + ip, err = s.lookupIP(host) + if err != nil { + return err + } + } } - // taken from https://github.com/h12w/socks/blob/master/socks.go - buf := []byte{ + // taken from https://github.com/h12w/socks/blob/master/socks.go and https://en.wikipedia.org/wiki/SOCKS + buf := pool.GetBuffer(bufSize) + defer pool.PutBuffer(buf) + copy(buf, []byte{ Version, ConnectCommand, byte(port >> 8), // higher byte of destination port byte(port), // lower byte of destination port (big endian) ip[0], ip[1], ip[2], ip[3], 0, // user id + }) + if s.socks4a { + copy(buf[8:], host) + buf[len(buf)-1] = 0 } resp := pool.GetBuffer(8)