mirror of
https://github.com/nadoo/glider.git
synced 2025-02-24 01:45:39 +08:00
136 lines
3.0 KiB
Go
136 lines
3.0 KiB
Go
package dns
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"io"
|
|
"strings"
|
|
|
|
"github.com/nadoo/glider/common/log"
|
|
"github.com/nadoo/glider/proxy"
|
|
)
|
|
|
|
// HandleFunc is the callback signature for handlers registered on a Client.
// It receives the queried domain name and one IP string taken from a TypeA or
// TypeAAAA answer, and may report a failure through its error result.
type HandleFunc func(domain, ip string) error
|
|
|
|
// Client is a dns client struct
|
|
type Client struct {
|
|
dialer proxy.Dialer
|
|
UPServers []string
|
|
UPServerMap map[string][]string
|
|
Handlers []HandleFunc
|
|
|
|
tcp bool
|
|
}
|
|
|
|
// NewClient returns a new dns client
|
|
func NewClient(dialer proxy.Dialer, upServers ...string) (*Client, error) {
|
|
c := &Client{
|
|
dialer: dialer,
|
|
UPServers: upServers,
|
|
UPServerMap: make(map[string][]string),
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
// Exchange handles request msg and returns response msg
|
|
// reqBytes = reqLen + reqMsg
|
|
func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte, err error) {
|
|
req, err := UnmarshalMessage(reqBytes[2:])
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA {
|
|
// TODO: if query.QNAME in cache
|
|
// get respMsg from cache
|
|
// set msg id
|
|
// return respMsg, nil
|
|
}
|
|
|
|
dnsServer := c.GetServer(req.Question.QNAME)
|
|
rc, err := c.dialer.NextDialer(req.Question.QNAME+":53").Dial("tcp", dnsServer)
|
|
if err != nil {
|
|
log.F("[dns] failed to connect to server %v: %v", dnsServer, err)
|
|
return
|
|
}
|
|
defer rc.Close()
|
|
|
|
if err = binary.Write(rc, binary.BigEndian, reqBytes); err != nil {
|
|
log.F("[dns] failed to write req message: %v", err)
|
|
return
|
|
}
|
|
|
|
var respLen uint16
|
|
if err = binary.Read(rc, binary.BigEndian, &respLen); err != nil {
|
|
log.F("[dns] failed to read response length: %v", err)
|
|
return
|
|
}
|
|
|
|
respBytes = make([]byte, respLen+2)
|
|
binary.BigEndian.PutUint16(respBytes[:2], respLen)
|
|
|
|
respMsg := respBytes[2:]
|
|
_, err = io.ReadFull(rc, respMsg)
|
|
if err != nil {
|
|
log.F("[dns] error in read respMsg %s\n", err)
|
|
return
|
|
}
|
|
|
|
if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA {
|
|
return
|
|
}
|
|
|
|
resp, err := UnmarshalMessage(respMsg)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
ips := []string{}
|
|
for _, answer := range resp.Answers {
|
|
if answer.TYPE == QTypeA || answer.TYPE == QTypeAAAA {
|
|
for _, h := range c.Handlers {
|
|
h(resp.Question.QNAME, answer.IP)
|
|
}
|
|
|
|
if answer.IP != "" {
|
|
ips = append(ips, answer.IP)
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
// add to cache
|
|
|
|
log.F("[dns] %s <-> %s, type: %d, %s: %s",
|
|
clientAddr, dnsServer, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
|
|
|
|
return
|
|
}
|
|
|
|
// SetServer .
|
|
func (c *Client) SetServer(domain string, servers ...string) {
|
|
c.UPServerMap[domain] = append(c.UPServerMap[domain], servers...)
|
|
}
|
|
|
|
// GetServer .
|
|
func (c *Client) GetServer(domain string) string {
|
|
domainParts := strings.Split(domain, ".")
|
|
length := len(domainParts)
|
|
for i := length - 2; i >= 0; i-- {
|
|
domain := strings.Join(domainParts[i:length], ".")
|
|
|
|
if servers, ok := c.UPServerMap[domain]; ok {
|
|
return servers[0]
|
|
}
|
|
}
|
|
|
|
// TODO:
|
|
return c.UPServers[0]
|
|
}
|
|
|
|
// AddHandler .
|
|
func (c *Client) AddHandler(h HandleFunc) {
|
|
c.Handlers = append(c.Handlers, h)
|
|
}
|