// glider/proxy/grpc/client.go

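// Package grpc implements a gRPC proxy transport for glider based on the
// Qv2ray gun protocol; this file contains the client (dialer) side.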
package grpc

import (
    "context"
    "crypto/tls"
    "crypto/x509"
    "fmt"
    "io"
    "net"
    "os"
    "sync"
    "time"

    "github.com/Qv2ray/gun/pkg/cert"
    "github.com/Qv2ray/gun/pkg/proto"
    "google.golang.org/grpc"
    "google.golang.org/grpc/backoff"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/connectivity"
    "google.golang.org/grpc/credentials"
    "google.golang.org/grpc/keepalive"
    "google.golang.org/grpc/peer"
    "google.golang.org/grpc/status"

    "github.com/nadoo/glider/pkg/pool"
    "github.com/nadoo/glider/proxy"
)

// https://github.com/v2fly/v2ray-core/blob/v5.0.6/transport/internet/grpc/dial.go

var (
    globalCCMap    map[string]*grpc.ClientConn
    globalCCAccess sync.Mutex
)
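
// ccCanceller closes and forgets the shared *grpc.ClientConn for an address.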
type ccCanceller func()
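
// ClientConn adapts a gun Tun stream to the net.Conn interface. Read and
// write deadlines are emulated with channels that are closed by timers and
// re-made whenever a deadline is moved into the future.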
type ClientConn struct {
    tun    proto.GunService_TunClient
    closer context.CancelFunc

    muReading sync.Mutex // muReading protects reading
    muWriting sync.Mutex // muWriting protects writing
    muRecv    sync.Mutex // muRecv protects recv
    muSend    sync.Mutex // muSend protects send

    buf    []byte
    offset int

    deadlineMu    sync.Mutex
    readDeadline  *time.Timer
    writeDeadline *time.Timer
    readClosed    chan struct{}
    writeClosed   chan struct{}
    closed        chan struct{}
}
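
// NewClientConn wraps an established Tun stream. closer cancels the stream's
// context when the connection is closed.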
func NewClientConn(tun proto.GunService_TunClient, closer context.CancelFunc) *ClientConn {
    return &ClientConn{
        tun:         tun,
        closer:      closer,
        readClosed:  make(chan struct{}),
        writeClosed: make(chan struct{}),
        closed:      make(chan struct{}),
    }
}
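
// RecvResp carries the result of a tun.Recv call across a channel.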
type RecvResp struct {
    hunk *proto.Hunk
    err  error
}
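
// Read implements net.Conn. A hunk larger than p is partially copied and the
// remainder is parked in c.buf for subsequent reads. The blocking tun.Recv
// runs in a goroutine so a deadline or Close can interrupt the wait.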
func (c *ClientConn) Read(p []byte) (n int, err error) {
    select {
    case <-c.readClosed:
        return 0, os.ErrDeadlineExceeded
    case <-c.closed:
        return 0, io.EOF
    default:
    }
    c.muReading.Lock()
    defer c.muReading.Unlock()
    // drain any leftover from the previous hunk first
    if c.buf != nil {
        n = copy(p, c.buf[c.offset:])
        c.offset += n
        if c.offset == len(c.buf) {
            pool.PutBuffer(c.buf)
            c.buf = nil
        }
        return n, nil
    }
    // buffered with capacity 1 so the goroutine below can always deliver its
    // result and exit, even if we stop waiting for it
    readDone := make(chan RecvResp, 1)
    // pass the channel as a parameter to avoid leaking it through the closure
    go func(readDone chan RecvResp) {
        // FIXME: the pending Recv is not actually aborted, so resuming
        // after a deadline can behave unexpectedly
        c.muRecv.Lock()
        defer c.muRecv.Unlock()
        recv, e := c.tun.Recv()
        readDone <- RecvResp{
            hunk: recv,
            err:  e,
        }
    }(readDone)
    select {
    case <-c.readClosed:
        return 0, os.ErrDeadlineExceeded
    case <-c.closed:
        return 0, io.EOF
    case recvResp := <-readDone:
        err = recvResp.err
        if err != nil {
            if code := status.Code(err); code == codes.Unavailable || code == codes.OutOfRange {
                err = io.EOF
            }
            return 0, err
        }
        n = copy(p, recvResp.hunk.Data)
        // park the part of the hunk that did not fit into p
        if n < len(recvResp.hunk.Data) {
            c.buf = pool.GetBuffer(len(recvResp.hunk.Data) - n)
            copy(c.buf, recvResp.hunk.Data[n:])
            c.offset = 0
        }
        return n, nil
    }
}
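
// Write implements net.Conn. The blocking tun.Send runs in a goroutine for
// the same reason as in Read: deadlines and Close must be able to interrupt.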
func (c *ClientConn) Write(p []byte) (n int, err error) {
    select {
    case <-c.writeClosed:
        return 0, os.ErrDeadlineExceeded
    case <-c.closed:
        return 0, io.EOF
    default:
    }
    c.muWriting.Lock()
    defer c.muWriting.Unlock()
    // buffered with capacity 1 so the goroutine below can always deliver its
    // result and exit, even if we stop waiting for it
    sendDone := make(chan error, 1)
    // pass the channel as a parameter to avoid leaking it through the closure
    go func(sendDone chan error) {
        // FIXME: the pending Send is not actually aborted, so resuming
        // after a deadline can behave unexpectedly
        c.muSend.Lock()
        defer c.muSend.Unlock()
        e := c.tun.Send(&proto.Hunk{Data: p})
        sendDone <- e
    }(sendDone)
    select {
    case <-c.writeClosed:
        return 0, os.ErrDeadlineExceeded
    case <-c.closed:
        return 0, io.EOF
    case err = <-sendDone:
        if code := status.Code(err); code == codes.Unavailable || code == codes.OutOfRange {
            err = io.EOF
        }
        if err != nil {
            // the hunk was not sent, so report zero bytes written
            return 0, err
        }
        return len(p), nil
    }
}

func (c *ClientConn) Close() error {
    select {
    case <-c.closed:
    default:
        close(c.closed)
    }
    c.closer()
    return nil
}
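
// CloseWrite half-closes the stream: it tells the peer no more data will be
// sent, while reads remain possible.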
func (c *ClientConn) CloseWrite() error {
    return c.tun.CloseSend()
}

func (c *ClientConn) LocalAddr() net.Addr {
    // FIXME: the local address of the underlying transport is not exposed here
    return nil
}

func (c *ClientConn) RemoteAddr() net.Addr {
    p, ok := peer.FromContext(c.tun.Context())
    if !ok {
        return nil
    }
    return p.Addr
}
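
// SetDeadline implements net.Conn. An expired deadline is signalled by
// closing readClosed/writeClosed; moving a deadline into the future re-makes
// the channel and re-arms the timer, so waiters always observe a channel that
// reflects the latest deadline.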
func (c *ClientConn) SetDeadline(t time.Time) error {
    c.deadlineMu.Lock()
    defer c.deadlineMu.Unlock()
    if now := time.Now(); t.After(now) {
        // the deadline moved into the future: if it had already expired,
        // re-make the closed channels
        select {
        case <-c.readClosed:
            c.readClosed = make(chan struct{})
        default:
        }
        select {
        case <-c.writeClosed:
            c.writeClosed = make(chan struct{})
        default:
        }
        // re-arm the timers so their callbacks close the (possibly re-made)
        // readClosed and writeClosed channels
        if c.readDeadline != nil {
            c.readDeadline.Stop()
        }
        c.readDeadline = time.AfterFunc(t.Sub(now), func() {
            c.deadlineMu.Lock()
            defer c.deadlineMu.Unlock()
            select {
            case <-c.readClosed:
            default:
                close(c.readClosed)
            }
        })
        if c.writeDeadline != nil {
            c.writeDeadline.Stop()
        }
        c.writeDeadline = time.AfterFunc(t.Sub(now), func() {
            c.deadlineMu.Lock()
            defer c.deadlineMu.Unlock()
            select {
            case <-c.writeClosed:
            default:
                close(c.writeClosed)
            }
        })
    } else {
        // a deadline in the past expires the connection immediately
        select {
        case <-c.readClosed:
        default:
            close(c.readClosed)
        }
        select {
        case <-c.writeClosed:
        default:
            close(c.writeClosed)
        }
    }
    return nil
}
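
// SetReadDeadline implements net.Conn; it is the read-only half of SetDeadline.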
func (c *ClientConn) SetReadDeadline(t time.Time) error {
    c.deadlineMu.Lock()
    defer c.deadlineMu.Unlock()
    if now := time.Now(); t.After(now) {
        // the deadline moved into the future: if it had already expired,
        // re-make the closed channel
        select {
        case <-c.readClosed:
            c.readClosed = make(chan struct{})
        default:
        }
        // re-arm the timer so its callback closes the (possibly re-made)
        // readClosed channel
        if c.readDeadline != nil {
            c.readDeadline.Stop()
        }
        c.readDeadline = time.AfterFunc(t.Sub(now), func() {
            c.deadlineMu.Lock()
            defer c.deadlineMu.Unlock()
            select {
            case <-c.readClosed:
            default:
                close(c.readClosed)
            }
        })
    } else {
        select {
        case <-c.readClosed:
        default:
            close(c.readClosed)
        }
    }
    return nil
}
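
// SetWriteDeadline implements net.Conn; it is the write-only half of SetDeadline.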
func (c *ClientConn) SetWriteDeadline(t time.Time) error {
    c.deadlineMu.Lock()
    defer c.deadlineMu.Unlock()
    if now := time.Now(); t.After(now) {
        // the deadline moved into the future: if it had already expired,
        // re-make the closed channel
        select {
        case <-c.writeClosed:
            c.writeClosed = make(chan struct{})
        default:
        }
        // re-arm the timer so its callback closes the (possibly re-made)
        // writeClosed channel
        if c.writeDeadline != nil {
            c.writeDeadline.Stop()
        }
        c.writeDeadline = time.AfterFunc(t.Sub(now), func() {
            c.deadlineMu.Lock()
            defer c.deadlineMu.Unlock()
            select {
            case <-c.writeClosed:
            default:
                close(c.writeClosed)
            }
        })
    } else {
        select {
        case <-c.writeClosed:
        default:
            close(c.writeClosed)
        }
    }
    return nil
}

// NewGRPCDialer returns a gRPC proxy dialer.
func NewGRPCDialer(s string, d proxy.Dialer) (proxy.Dialer, error) {
    w, err := NewGRPC(s, d, nil)
    if err != nil {
        return nil, fmt.Errorf("[gRPC] create instance error: %w", err)
    }
    if w.certFile != "" {
        certData, err := os.ReadFile(w.certFile)
        if err != nil {
            return nil, fmt.Errorf("[gRPC] read cert file error: %w", err)
        }
        certPool := x509.NewCertPool()
        if !certPool.AppendCertsFromPEM(certData) {
            return nil, fmt.Errorf("[gRPC] can not append cert file: %s", w.certFile)
        }
        w.certPool = certPool
    } else {
        w.certPool, err = cert.GetSystemCertPool()
        if err != nil {
            return nil, fmt.Errorf("[gRPC] failed to get system certificate pool: %w", err)
        }
    }
    return w, nil
}
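
// Usage sketch (illustrative, not part of this file): glider typically
// registers this dialer for grpc:// forwarder URLs, so a forwarder such as
//
//	grpc://example.com:443?serviceName=GunService
//
// ends up here roughly as:
//
//	d, err := NewGRPCDialer("grpc://example.com:443?serviceName=GunService", direct)
//	if err != nil {
//		// handle error
//	}
//	conn, err := d.Dial("tcp", "dst.example.org:443")
//
// where direct is any proxy.Dialer. The exact URL parameters are parsed by
// NewGRPC, which is defined elsewhere in this package.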

// Addr returns forwarder's address.
func (g *GRPC) Addr() string {
    if g.addr == "" {
        return g.dialer.Addr()
    }
    return g.addr
}

// Dial connects to the address addr on the network net via the proxy.
func (g *GRPC) Dial(network string, address string) (net.Conn, error) {
    cc, cancel, err := getGrpcClientConn(g.dialer, g.addr, g.serverName, g.certPool, g.skipVerify)
    if err != nil {
        cancel()
        return nil, err
    }
    client := proto.NewGunServiceClient(cc)
    clientX := client.(proto.GunServiceClientX)
    serviceName := g.serviceName
    if serviceName == "" {
        serviceName = "GunService"
    }
    // ctxStream spans the lifetime of the tun stream
    ctxStream, streamCloser := context.WithCancel(context.Background())
    tun, err := clientX.TunCustomName(ctxStream, serviceName)
    if err != nil {
        streamCloser()
        return nil, err
    }
    return NewClientConn(tun, streamCloser), nil
}

// DialUDP connects to the given address via the proxy.
func (g *GRPC) DialUDP(network, addr string) (net.PacketConn, error) {
    return nil, proxy.ErrNotSupported
}
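
// getGrpcClientConn returns a shared *grpc.ClientConn for address, creating
// one if none exists or the cached one has shut down. The returned canceller
// closes the shared connection and removes it from the cache.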
func getGrpcClientConn(dialer proxy.Dialer, address string, serverName string, certPool *x509.CertPool, skipVerify bool) (*grpc.ClientConn, ccCanceller, error) {
    globalCCAccess.Lock()
    if globalCCMap == nil {
        globalCCMap = make(map[string]*grpc.ClientConn)
    }
    globalCCAccess.Unlock()

    canceller := func() {
        globalCCAccess.Lock()
        defer globalCCAccess.Unlock()
        if cc, ok := globalCCMap[address]; ok {
            cc.Close()
            delete(globalCCMap, address)
        }
    }

    // TODO: should support chained proxies to the same destination
    globalCCAccess.Lock()
    if client, found := globalCCMap[address]; found && client.GetState() != connectivity.Shutdown {
        globalCCAccess.Unlock()
        return client, canceller, nil
    }
    globalCCAccess.Unlock()

    // skipVerify keeps TLS but disables certificate verification;
    // grpc.WithInsecure would drop transport security entirely, and combining
    // it with WithTransportCredentials makes grpc.Dial fail
    creds := credentials.NewClientTLSFromCert(certPool, serverName)
    if skipVerify {
        creds = credentials.NewTLS(&tls.Config{ServerName: serverName, InsecureSkipVerify: true})
    }

    dialOptions := []grpc.DialOption{
        grpc.WithTransportCredentials(creds),
        grpc.WithContextDialer(func(ctxGrpc context.Context, s string) (net.Conn, error) {
            return dialer.Dial("tcp", s)
        }),
        grpc.WithConnectParams(grpc.ConnectParams{
            Backoff: backoff.Config{
                BaseDelay:  500 * time.Millisecond,
                Multiplier: 1.5,
                Jitter:     0.2,
                MaxDelay:   19 * time.Second,
            },
            MinConnectTimeout: 5 * time.Second,
        }),
        grpc.WithKeepaliveParams(keepalive.ClientParameters{
            Time:                30 * time.Second,
            Timeout:             10 * time.Second,
            PermitWithoutStream: true,
        }),
    }

    cc, err := grpc.Dial(address, dialOptions...)
    if err != nil {
        return nil, canceller, err
    }
    globalCCAccess.Lock()
    globalCCMap[address] = cc
    globalCCAccess.Unlock()
    return cc, canceller, nil
}