Update client.go

This commit is contained in:
xiaolunzhou 2020-12-07 19:51:02 +08:00
parent 1242f4b500
commit 5f8800d643

View File

@ -5,7 +5,6 @@ import (
"encoding/base64" "encoding/base64"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
@ -156,7 +155,6 @@ func (c *Client) extractAnswer(resp *Message) ([]string, int) {
func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) ( func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (
server, network, dialerAddr string, respBytes []byte, err error) { server, network, dialerAddr string, respBytes []byte, err error) {
ups := c.UpStream(qname) ups := c.UpStream(qname)
fmt.Println(ups)
network = "tcp" network = "tcp"
dialer := c.proxy.NextDialer(qname + ":53") dialer := c.proxy.NextDialer(qname + ":53")
// if we are resolving the dialer's domain, then use Direct to avoid denpency loop // if we are resolving the dialer's domain, then use Direct to avoid denpency loop
@ -191,11 +189,9 @@ func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (
network = "udp" network = "udp"
rc, err = dialer.Dial(network, server) rc, err = dialer.Dial(network, server)
case "dot": case "dot":
network="tcp" rc,err=tls.Dial("tcp",server,&tls.Config{InsecureSkipVerify: false,})
rc,err=tls.Dial(network,server,&tls.Config{InsecureSkipVerify: false,})
case "doh": case "doh":
net.DefaultResolver=&net.Resolver{} net.DefaultResolver=&net.Resolver{}
network="doh"
default: default:
break break
} }
@ -207,14 +203,16 @@ func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (
continue continue
} }
//TODO: if we use DOH (network=="doh") we don't need close connection //TODO: if we use DOH (network=="doh") we don't need close connection
if network!="doh" {defer rc.Close()} if network!="doh"{
defer rc.Close()
}
// TODO: support timeout setting for different upstream server // TODO: support timeout setting for different upstream server
if c.config.Timeout > 0 && network!="doh" { if c.config.Timeout > 0 && network!="doh" {
rc.SetDeadline(time.Now().Add(time.Duration(c.config.Timeout) * time.Second)) rc.SetDeadline(time.Now().Add(time.Duration(c.config.Timeout) * time.Second))
} }
switch network { switch op {
case "tcp","dot": case "tcp","dot":
respBytes, err = c.exchangeTCP(rc, reqBytes) respBytes, err = c.exchangeTCP(rc, reqBytes)
case "udp": case "udp":
@ -222,7 +220,6 @@ func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (
case "doh": case "doh":
respBytes, err = c.exchangeHTTPS(server, reqBytes) respBytes, err = c.exchangeHTTPS(server, reqBytes)
} }
if err == nil { if err == nil {
break break
} }
@ -239,7 +236,7 @@ func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (
c.proxy.Record(dialer, false) c.proxy.Record(dialer, false)
} }
return server, network, dialer.Addr(), respBytes, err return server, op, dialer.Addr(), respBytes, err
} }
//exchangeHTTP exchange with server over https //exchangeHTTP exchange with server over https
func (c*Client) exchangeHTTPS(server string,reqBytes[]byte)(body[]byte,err error){ func (c*Client) exchangeHTTPS(server string,reqBytes[]byte)(body[]byte,err error){