diff --git a/common/conn/conn.go b/common/conn/conn.go index aca363b..b96ca66 100644 --- a/common/conn/conn.go +++ b/common/conn/conn.go @@ -45,14 +45,41 @@ func Relay(left, right net.Conn) (int64, int64, error) { go func() { n, err := io.Copy(right, left) - right.SetDeadline(time.Now()) // wake up the other goroutine blocking on right - left.SetDeadline(time.Now()) // wake up the other goroutine blocking on left + + switch src := left.(type) { + case *net.TCPConn: + src.CloseRead() + default: + src.SetDeadline(time.Now()) + } + + switch dst := right.(type) { + case *net.TCPConn: + dst.CloseWrite() + dst.SetDeadline(time.Now().Add(time.Second * 60)) + default: + dst.SetDeadline(time.Now()) + } + ch <- res{n, err} }() n, err := io.Copy(left, right) - right.SetDeadline(time.Now()) // wake up the other goroutine blocking on right - left.SetDeadline(time.Now()) // wake up the other goroutine blocking on left + switch src := left.(type) { + case *net.TCPConn: + src.CloseWrite() + src.SetDeadline(time.Now().Add(time.Second * 60)) + default: + src.SetDeadline(time.Now()) + } + + switch dst := right.(type) { + case *net.TCPConn: + dst.CloseRead() + default: + dst.SetDeadline(time.Now()) + } + rs := <-ch if err == nil {