Update client.go

This commit is contained in:
xiaolunzhou 2020-12-07 21:25:00 +08:00
parent 085ba97a6a
commit 8184bd3cd0

View File

@ -31,7 +31,6 @@ type Config struct {
Records []string Records []string
AlwaysTCP bool AlwaysTCP bool
CacheSize int CacheSize int
} }
// Client is a dns client struct. // Client is a dns client struct.
@ -42,7 +41,7 @@ type Client struct {
upStream *UPStream upStream *UPStream
upStreamMap map[string]*UPStream upStreamMap map[string]*UPStream
handlers []AnswerHandler handlers []AnswerHandler
httpClient *http.Client httpClient *http.Client
} }
// NewClient returns a new dns client. // NewClient returns a new dns client.
@ -168,52 +167,54 @@ func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (
if !preferTCP && !c.config.AlwaysTCP && dialer.Addr() == "DIRECT" { if !preferTCP && !c.config.AlwaysTCP && dialer.Addr() == "DIRECT" {
network = "udp" network = "udp"
} }
//init conn and option //init conn and scheme
var rc net.Conn var rc net.Conn
var op string var scheme string
for i := 0; i < ups.Len(); i++ { for i := 0; i < ups.Len(); i++ {
u, err := url.Parse(ups.Server()) u, err := url.Parse(ups.Server())
if err!=nil{ if err != nil {
server=ups.Server() server = ups.Server()
op=network scheme = network
}else{ } else {
server=u.Host server = u.Host
op=u.Scheme scheme = u.Scheme
} }
//if not set option use network else use special option //if not set option use network else use special scheme
switch op{ var e error
switch scheme {
case "tcp": case "tcp":
network = "tcp" rc, e = dialer.Dial("tcp", server)
rc, err = dialer.Dial(network, server)
case "udp": case "udp":
network = "udp" rc, e = dialer.Dial("udp", server)
rc, err = dialer.Dial(network, server)
case "dot": case "dot":
rc,err=tls.Dial("tcp",server,&tls.Config{InsecureSkipVerify: false,}) rc, e = tls.Dial("tcp", server, &tls.Config{InsecureSkipVerify: false})
case "doh": case "doh":
net.DefaultResolver=&net.Resolver{} net.DefaultResolver = &net.Resolver{}
default: default:
scheme=network
break break
} }
if err != nil { if e != nil {
newServer := ups.SwitchIf(server) newServer := ups.SwitchIf(server)
log.F("[dns] error in resolving %s, failed to connect to server %v via %s: %v, next server: %s", log.F("[dns] error in resolving %s, failed to connect to server %v via %s: %v, next server: %s",
qname, server, dialer.Addr(), err, newServer) qname, server, dialer.Addr(), err, newServer)
server = newServer server = newServer
continue continue
} }
//TODO: if we use DOH (op=="doh") we don't need close connection //TODO: if we use DOH (scheme=="doh") we don't need close connection
if op!="doh"{ if scheme != "doh" {
defer rc.Close() defer rc.Close()
} }
// TODO: support timeout setting for different upstream server // TODO: support timeout setting for different upstream server
if c.config.Timeout > 0 && op!="doh" { if c.config.Timeout > 0 && scheme != "doh" {
rc.SetDeadline(time.Now().Add(time.Duration(c.config.Timeout) * time.Second)) rc.SetDeadline(time.Now().Add(time.Duration(c.config.Timeout) * time.Second))
} }
switch op { switch scheme {
case "tcp","dot": case "tcp", "dot":
respBytes, err = c.exchangeTCP(rc, reqBytes) respBytes, err = c.exchangeTCP(rc, reqBytes)
case "udp": case "udp":
respBytes, err = c.exchangeUDP(rc, reqBytes) respBytes, err = c.exchangeUDP(rc, reqBytes)
@ -236,10 +237,11 @@ func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (
c.proxy.Record(dialer, false) c.proxy.Record(dialer, false)
} }
return server, op, dialer.Addr(), respBytes, err return server, scheme, dialer.Addr(), respBytes, err
} }
//exchangeHTTP exchange with server over https //exchangeHTTP exchange with server over https
func (c*Client) exchangeHTTPS(server string,reqBytes[]byte)(body[]byte,err error){ func (c *Client) exchangeHTTPS(server string, reqBytes []byte) (body []byte, err error) {
query := strings.Replace(base64.URLEncoding.EncodeToString(reqBytes), "=", "", -1) query := strings.Replace(base64.URLEncoding.EncodeToString(reqBytes), "=", "", -1)
urls := "https://" + server + "/dns-query?dns=" + query urls := "https://" + server + "/dns-query?dns=" + query
res, err := c.httpClient.Get(urls) res, err := c.httpClient.Get(urls)
@ -253,6 +255,7 @@ func (c*Client) exchangeHTTPS(server string,reqBytes[]byte)(body[]byte,err error
} }
return return
} }
// exchangeTCP exchange with server over tcp. // exchangeTCP exchange with server over tcp.
func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) { func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) {
lenBuf := pool.GetBuffer(2) lenBuf := pool.GetBuffer(2)