// Source: glider/dns/client.go (136 lines, 3.0 KiB, Go)
// Last change shown by the source viewer: 2018-07-29 23:44:23 +08:00

package dns
import (
"encoding/binary"
"io"
"strings"
"github.com/nadoo/glider/common/log"
"github.com/nadoo/glider/proxy"
)
// HandleFunc is the callback type invoked for DNS answers of TypeA or
// TypeAAAA. It receives the queried domain name and the answer's IP address
// in string form.
type HandleFunc func(domain, ip string) error
// Client is a DNS client that forwards queries to upstream servers through a
// proxy dialer.
type Client struct {
	// dialer provides the connection used to reach the upstream server.
	dialer proxy.Dialer

	// UPServers lists the default upstream servers; the first entry is used
	// when no domain-specific server matches.
	UPServers []string

	// UPServerMap maps a domain suffix to the upstream servers registered for
	// it via SetServer.
	UPServerMap map[string][]string

	// Handlers are called for every A/AAAA answer found in a response.
	Handlers []HandleFunc

	// tcp is not read by any code in this file.
	// NOTE(review): presumably selects TCP towards the upstream — confirm.
	tcp bool
}
// NewClient creates a DNS client that connects through dialer and uses
// upServers as its default upstream servers. The per-domain server map starts
// out empty. The returned error is always nil.
func NewClient(dialer proxy.Dialer, upServers ...string) (*Client, error) {
	client := new(Client)
	client.dialer = dialer
	client.UPServers = upServers
	client.UPServerMap = make(map[string][]string)
	return client, nil
}
// Exchange handles request msg and returns response msg
// reqBytes = reqLen + reqMsg
func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte, err error) {
req, err := UnmarshalMessage(reqBytes[2:])
2018-07-29 23:44:23 +08:00
if err != nil {
return
}
if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA {
2018-07-29 23:44:23 +08:00
// TODO: if query.QNAME in cache
// get respMsg from cache
// set msg id
// return respMsg, nil
}
dnsServer := c.GetServer(req.Question.QNAME)
rc, err := c.dialer.NextDialer(req.Question.QNAME+":53").Dial("tcp", dnsServer)
2018-07-29 23:44:23 +08:00
if err != nil {
log.F("[dns] failed to connect to server %v: %v", dnsServer, err)
return
}
defer rc.Close()
if err = binary.Write(rc, binary.BigEndian, reqBytes); err != nil {
log.F("[dns] failed to write req message: %v", err)
return
}
var respLen uint16
if err = binary.Read(rc, binary.BigEndian, &respLen); err != nil {
log.F("[dns] failed to read response length: %v", err)
return
}
respBytes = make([]byte, respLen+2)
binary.BigEndian.PutUint16(respBytes[:2], respLen)
respMsg := respBytes[2:]
_, err = io.ReadFull(rc, respMsg)
if err != nil {
log.F("[dns] error in read respMsg %s\n", err)
return
}
if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA {
2018-07-29 23:44:23 +08:00
return
}
resp, err := UnmarshalMessage(respMsg)
2018-07-29 23:44:23 +08:00
if err != nil {
return
}
2018-07-30 00:18:10 +08:00
ips := []string{}
for _, answer := range resp.Answers {
if answer.TYPE == QTypeA || answer.TYPE == QTypeAAAA {
2018-07-29 23:44:23 +08:00
for _, h := range c.Handlers {
h(resp.Question.QNAME, answer.IP)
2018-07-29 23:44:23 +08:00
}
if answer.IP != "" {
2018-07-30 00:18:10 +08:00
ips = append(ips, answer.IP)
2018-07-29 23:44:23 +08:00
}
}
}
// add to cache
log.F("[dns] %s <-> %s, type: %d, %s: %s",
clientAddr, dnsServer, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
2018-07-29 23:44:23 +08:00
return
}
// SetServer registers servers as upstream servers for domain. Servers are
// appended to any that were already registered for the same domain.
func (c *Client) SetServer(domain string, servers ...string) {
	registered := c.UPServerMap[domain]
	c.UPServerMap[domain] = append(registered, servers...)
}
// GetServer returns the upstream server address to use for domain.
//
// Suffixes of domain are looked up in UPServerMap, starting with the last two
// labels (e.g. "example.com" for "www.example.com") and growing towards the
// full name; the first suffix with at least one registered server wins and
// its first server is returned. A single trailing label on its own is never
// looked up.
//
// If no suffix matches, the first entry of UPServers is returned, or "" when
// no default upstream server is configured.
func (c *Client) GetServer(domain string) string {
	labels := strings.Split(domain, ".")
	for i := len(labels) - 2; i >= 0; i-- {
		suffix := strings.Join(labels[i:], ".")
		// An entry may exist with an empty list (SetServer called without
		// servers); skip it instead of panicking on servers[0].
		if servers := c.UPServerMap[suffix]; len(servers) > 0 {
			return servers[0]
		}
	}

	// TODO: only the first default upstream server is ever used.
	if len(c.UPServers) == 0 {
		return ""
	}
	return c.UPServers[0]
}
// AddHandler appends h to the list of handlers that are called for A/AAAA
// answers.
func (c *Client) AddHandler(h HandleFunc) {
	handlers := c.Handlers
	handlers = append(handlers, h)
	c.Handlers = handlers
}