mirror of
https://github.com/nadoo/glider.git
synced 2025-04-21 03:32:07 +08:00
409 lines
9.6 KiB
Go
409 lines
9.6 KiB
Go
package grpc
|
|
|
|
import (
|
|
"context"
|
|
"crypto/x509"
|
|
"fmt"
|
|
"github.com/Qv2ray/gun/pkg/cert"
|
|
"github.com/Qv2ray/gun/pkg/proto"
|
|
"github.com/nadoo/glider/pkg/pool"
|
|
"github.com/nadoo/glider/proxy"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/backoff"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/connectivity"
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/keepalive"
|
|
"google.golang.org/grpc/peer"
|
|
"google.golang.org/grpc/status"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// https://github.com/v2fly/v2ray-core/blob/v5.0.6/transport/internet/grpc/dial.go
|
|
var (
|
|
globalCCMap map[string]*grpc.ClientConn
|
|
globalCCAccess sync.Mutex
|
|
)
|
|
|
|
type ccCanceller func()
|
|
|
|
type ClientConn struct {
|
|
tun proto.GunService_TunClient
|
|
closer context.CancelFunc
|
|
muReading sync.Mutex // muReading protects reading
|
|
muWriting sync.Mutex // muWriting protects writing
|
|
muRecv sync.Mutex // muReading protects recv
|
|
muSend sync.Mutex // muWriting protects send
|
|
buf []byte
|
|
offset int
|
|
|
|
deadlineMu sync.Mutex
|
|
readDeadline *time.Timer
|
|
writeDeadline *time.Timer
|
|
readClosed chan struct{}
|
|
writeClosed chan struct{}
|
|
closed chan struct{}
|
|
}
|
|
|
|
func NewClientConn(tun proto.GunService_TunClient, closer context.CancelFunc) *ClientConn {
|
|
return &ClientConn{
|
|
tun: tun,
|
|
closer: closer,
|
|
readClosed: make(chan struct{}),
|
|
writeClosed: make(chan struct{}),
|
|
closed: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
type RecvResp struct {
|
|
hunk *proto.Hunk
|
|
err error
|
|
}
|
|
|
|
func (c *ClientConn) Read(p []byte) (n int, err error) {
|
|
select {
|
|
case <-c.readClosed:
|
|
return 0, os.ErrDeadlineExceeded
|
|
case <-c.closed:
|
|
return 0, io.EOF
|
|
default:
|
|
}
|
|
|
|
c.muReading.Lock()
|
|
defer c.muReading.Unlock()
|
|
if c.buf != nil {
|
|
n = copy(p, c.buf[c.offset:])
|
|
c.offset += n
|
|
if c.offset == len(c.buf) {
|
|
pool.PutBuffer(c.buf)
|
|
c.buf = nil
|
|
}
|
|
return n, nil
|
|
}
|
|
// set 1 to avoid channel leak
|
|
readDone := make(chan RecvResp, 1)
|
|
// pass channel to the function to avoid closure leak
|
|
go func(readDone chan RecvResp) {
|
|
// FIXME: not really abort the send so there is some problems when recover
|
|
c.muRecv.Lock()
|
|
defer c.muRecv.Unlock()
|
|
recv, e := c.tun.Recv()
|
|
readDone <- RecvResp{
|
|
hunk: recv,
|
|
err: e,
|
|
}
|
|
}(readDone)
|
|
select {
|
|
case <-c.readClosed:
|
|
return 0, os.ErrDeadlineExceeded
|
|
case <-c.closed:
|
|
return 0, io.EOF
|
|
case recvResp := <-readDone:
|
|
err = recvResp.err
|
|
if err != nil {
|
|
if code := status.Code(err); code == codes.Unavailable || status.Code(err) == codes.OutOfRange {
|
|
err = io.EOF
|
|
}
|
|
return 0, err
|
|
}
|
|
n = copy(p, recvResp.hunk.Data)
|
|
c.buf = pool.GetBuffer(len(recvResp.hunk.Data) - n)
|
|
copy(c.buf, recvResp.hunk.Data[n:])
|
|
c.offset = 0
|
|
return n, nil
|
|
}
|
|
}
|
|
|
|
func (c *ClientConn) Write(p []byte) (n int, err error) {
|
|
select {
|
|
case <-c.writeClosed:
|
|
return 0, os.ErrDeadlineExceeded
|
|
case <-c.closed:
|
|
return 0, io.EOF
|
|
default:
|
|
}
|
|
|
|
c.muWriting.Lock()
|
|
defer c.muWriting.Unlock()
|
|
// set 1 to avoid channel leak
|
|
sendDone := make(chan error, 1)
|
|
// pass channel to the function to avoid closure leak
|
|
go func(sendDone chan error) {
|
|
// FIXME: not really abort the send so there is some problems when recover
|
|
c.muSend.Lock()
|
|
defer c.muSend.Unlock()
|
|
e := c.tun.Send(&proto.Hunk{Data: p})
|
|
sendDone <- e
|
|
}(sendDone)
|
|
select {
|
|
case <-c.writeClosed:
|
|
return 0, os.ErrDeadlineExceeded
|
|
case <-c.closed:
|
|
return 0, io.EOF
|
|
case err = <-sendDone:
|
|
if code := status.Code(err); code == codes.Unavailable || status.Code(err) == codes.OutOfRange {
|
|
err = io.EOF
|
|
}
|
|
return len(p), err
|
|
}
|
|
}
|
|
|
|
func (c *ClientConn) Close() error {
|
|
select {
|
|
case <-c.closed:
|
|
default:
|
|
close(c.closed)
|
|
}
|
|
c.closer()
|
|
return nil
|
|
}
|
|
func (c *ClientConn) CloseWrite() error {
|
|
return c.tun.CloseSend()
|
|
}
|
|
func (c *ClientConn) LocalAddr() net.Addr {
|
|
// FIXME
|
|
return nil
|
|
}
|
|
func (c *ClientConn) RemoteAddr() net.Addr {
|
|
p, _ := peer.FromContext(c.tun.Context())
|
|
return p.Addr
|
|
}
|
|
|
|
func (c *ClientConn) SetDeadline(t time.Time) error {
|
|
c.deadlineMu.Lock()
|
|
defer c.deadlineMu.Unlock()
|
|
if now := time.Now(); t.After(now) {
|
|
// refresh the deadline if the deadline has been exceeded
|
|
select {
|
|
case <-c.readClosed:
|
|
c.readClosed = make(chan struct{})
|
|
default:
|
|
}
|
|
select {
|
|
case <-c.writeClosed:
|
|
c.writeClosed = make(chan struct{})
|
|
default:
|
|
}
|
|
// reset the deadline timer to make the c.readClosed and c.writeClosed with the new pointer (if it is)
|
|
if c.readDeadline != nil {
|
|
c.readDeadline.Stop()
|
|
}
|
|
c.readDeadline = time.AfterFunc(t.Sub(now), func() {
|
|
c.deadlineMu.Lock()
|
|
defer c.deadlineMu.Unlock()
|
|
select {
|
|
case <-c.readClosed:
|
|
default:
|
|
close(c.readClosed)
|
|
}
|
|
})
|
|
if c.writeDeadline != nil {
|
|
c.writeDeadline.Stop()
|
|
}
|
|
c.writeDeadline = time.AfterFunc(t.Sub(now), func() {
|
|
c.deadlineMu.Lock()
|
|
defer c.deadlineMu.Unlock()
|
|
select {
|
|
case <-c.writeClosed:
|
|
default:
|
|
close(c.writeClosed)
|
|
}
|
|
})
|
|
} else {
|
|
select {
|
|
case <-c.readClosed:
|
|
default:
|
|
close(c.readClosed)
|
|
}
|
|
select {
|
|
case <-c.writeClosed:
|
|
default:
|
|
close(c.writeClosed)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *ClientConn) SetReadDeadline(t time.Time) error {
|
|
c.deadlineMu.Lock()
|
|
defer c.deadlineMu.Unlock()
|
|
if now := time.Now(); t.After(now) {
|
|
// refresh the deadline if the deadline has been exceeded
|
|
select {
|
|
case <-c.readClosed:
|
|
c.readClosed = make(chan struct{})
|
|
default:
|
|
}
|
|
// reset the deadline timer to make the c.readClosed and c.writeClosed with the new pointer (if it is)
|
|
if c.readDeadline != nil {
|
|
c.readDeadline.Stop()
|
|
}
|
|
c.readDeadline = time.AfterFunc(t.Sub(now), func() {
|
|
c.deadlineMu.Lock()
|
|
defer c.deadlineMu.Unlock()
|
|
select {
|
|
case <-c.readClosed:
|
|
default:
|
|
close(c.readClosed)
|
|
}
|
|
})
|
|
} else {
|
|
select {
|
|
case <-c.readClosed:
|
|
default:
|
|
close(c.readClosed)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *ClientConn) SetWriteDeadline(t time.Time) error {
|
|
c.deadlineMu.Lock()
|
|
defer c.deadlineMu.Unlock()
|
|
if now := time.Now(); t.After(now) {
|
|
// refresh the deadline if the deadline has been exceeded
|
|
select {
|
|
case <-c.writeClosed:
|
|
c.writeClosed = make(chan struct{})
|
|
default:
|
|
}
|
|
if c.writeDeadline != nil {
|
|
c.writeDeadline.Stop()
|
|
}
|
|
c.writeDeadline = time.AfterFunc(t.Sub(now), func() {
|
|
c.deadlineMu.Lock()
|
|
defer c.deadlineMu.Unlock()
|
|
select {
|
|
case <-c.writeClosed:
|
|
default:
|
|
close(c.writeClosed)
|
|
}
|
|
})
|
|
} else {
|
|
select {
|
|
case <-c.writeClosed:
|
|
default:
|
|
close(c.writeClosed)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// NewGRPCDialer returns a gRPC proxy dialer.
|
|
func NewGRPCDialer(s string, d proxy.Dialer) (proxy.Dialer, error) {
|
|
w, err := NewGRPC(s, d, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("[gRPC] create instance error: %s", err)
|
|
}
|
|
|
|
if w.certFile != "" {
|
|
certData, err := os.ReadFile(w.certFile)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("[gRPC] read cert file error: %s", err)
|
|
}
|
|
|
|
certPool := x509.NewCertPool()
|
|
if !certPool.AppendCertsFromPEM(certData) {
|
|
return nil, fmt.Errorf("[gRPC] can not append cert file: %s", w.certFile)
|
|
}
|
|
w.certPool = certPool
|
|
} else {
|
|
w.certPool, err = cert.GetSystemCertPool()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get system certificate pool")
|
|
}
|
|
}
|
|
|
|
return w, err
|
|
}
|
|
|
|
// Addr returns forwarder's address.
|
|
func (g *GRPC) Addr() string {
|
|
if g.addr == "" {
|
|
return g.dialer.Addr()
|
|
}
|
|
return g.addr
|
|
}
|
|
|
|
// Dial connects to the address addr on the network net via the proxy.
|
|
func (g *GRPC) Dial(network string, address string) (net.Conn, error) {
|
|
cc, cancel, err := getGrpcClientConn(g.dialer, g.addr, g.serverName, g.certPool, g.skipVerify)
|
|
if err != nil {
|
|
cancel()
|
|
return nil, err
|
|
}
|
|
client := proto.NewGunServiceClient(cc)
|
|
|
|
clientX := client.(proto.GunServiceClientX)
|
|
serviceName := g.serviceName
|
|
if serviceName == "" {
|
|
serviceName = "GunService"
|
|
}
|
|
// ctx is the lifetime of the tun
|
|
ctxStream, streamCloser := context.WithCancel(context.Background())
|
|
tun, err := clientX.TunCustomName(ctxStream, serviceName)
|
|
if err != nil {
|
|
streamCloser()
|
|
return nil, err
|
|
}
|
|
return NewClientConn(tun, streamCloser), nil
|
|
}
|
|
|
|
// DialUDP connects to the given address via the proxy.
|
|
func (g *GRPC) DialUDP(network, addr string) (net.PacketConn, error) {
|
|
return nil, proxy.ErrNotSupported
|
|
}
|
|
|
|
func getGrpcClientConn(dialer proxy.Dialer, address string, serverName string, certPool *x509.CertPool, skipVerify bool) (*grpc.ClientConn, ccCanceller, error) {
|
|
globalCCAccess.Lock()
|
|
if globalCCMap == nil {
|
|
globalCCMap = make(map[string]*grpc.ClientConn)
|
|
}
|
|
globalCCAccess.Unlock()
|
|
|
|
canceller := func() {
|
|
globalCCAccess.Lock()
|
|
defer globalCCAccess.Unlock()
|
|
globalCCMap[address].Close()
|
|
delete(globalCCMap, address)
|
|
}
|
|
|
|
// TODO Should support chain proxy to the same destination
|
|
globalCCAccess.Lock()
|
|
if client, found := globalCCMap[address]; found && client.GetState() != connectivity.Shutdown {
|
|
globalCCAccess.Unlock()
|
|
return client, canceller, nil
|
|
}
|
|
globalCCAccess.Unlock()
|
|
dialOptions := []grpc.DialOption{
|
|
grpc.WithTransportCredentials(credentials.NewClientTLSFromCert(certPool, serverName)),
|
|
grpc.WithContextDialer(func(ctxGrpc context.Context, s string) (net.Conn, error) {
|
|
return dialer.Dial("tcp", s)
|
|
}), grpc.WithConnectParams(grpc.ConnectParams{
|
|
Backoff: backoff.Config{
|
|
BaseDelay: 500 * time.Millisecond,
|
|
Multiplier: 1.5,
|
|
Jitter: 0.2,
|
|
MaxDelay: 19 * time.Second,
|
|
},
|
|
MinConnectTimeout: 5 * time.Second,
|
|
}), grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
|
Time: 30 * time.Second,
|
|
Timeout: 10 * time.Second,
|
|
PermitWithoutStream: true,
|
|
}),
|
|
}
|
|
if skipVerify {
|
|
dialOptions = append(dialOptions, grpc.WithInsecure())
|
|
}
|
|
cc, err := grpc.Dial(address, dialOptions...)
|
|
globalCCAccess.Lock()
|
|
globalCCMap[address] = cc
|
|
globalCCAccess.Unlock()
|
|
return cc, canceller, err
|
|
}
|