mirror of
https://github.com/nadoo/glider.git
synced 2025-02-23 17:35:40 +08:00
dns: query in udp when client requests in udp and no forwarder specified
This commit is contained in:
parent
8d20331096
commit
9acaff5b4a
22
dns/cache.go
22
dns/cache.go
@ -42,20 +42,18 @@ func (c *Cache) Len() int {
|
||||
return len(c.m)
|
||||
}
|
||||
|
||||
// Put an item into cache, invalid after ttl seconds, never invalid if ttl=0
|
||||
// Put an item into cache, invalid after ttl seconds
|
||||
func (c *Cache) Put(k string, v []byte, ttl int) {
|
||||
if len(v) == 0 {
|
||||
return
|
||||
if len(v) != 0 {
|
||||
c.l.Lock()
|
||||
it, ok := c.m[k]
|
||||
if !ok {
|
||||
it = &item{value: v}
|
||||
c.m[k] = it
|
||||
}
|
||||
it.expire = time.Now().Add(time.Duration(ttl) * time.Second)
|
||||
c.l.Unlock()
|
||||
}
|
||||
|
||||
c.l.Lock()
|
||||
it, ok := c.m[k]
|
||||
if !ok {
|
||||
it = &item{value: v}
|
||||
c.m[k] = it
|
||||
}
|
||||
it.expire = time.Now().Add(time.Duration(ttl) * time.Second)
|
||||
c.l.Unlock()
|
||||
}
|
||||
|
||||
// Get an item from cache
|
||||
|
120
dns/client.go
120
dns/client.go
@ -41,7 +41,7 @@ func NewClient(dialer proxy.Dialer, upServers []string) (*Client, error) {
|
||||
|
||||
// Exchange handles request msg and returns response msg
|
||||
// reqBytes = reqLen + reqMsg
|
||||
func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) {
|
||||
func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([]byte, error) {
|
||||
req, err := UnmarshalMessage(reqBytes[2:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -58,37 +58,14 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) {
|
||||
}
|
||||
}
|
||||
|
||||
dnsServer := c.GetServer(req.Question.QNAME)
|
||||
rc, err := c.dialer.NextDialer(req.Question.QNAME+":53").Dial("tcp", dnsServer)
|
||||
dnsServer, network, respBytes, err := c.exchange(req.Question.QNAME, reqBytes, preferTCP)
|
||||
if err != nil {
|
||||
log.F("[dns] failed to connect to server %v: %v", dnsServer, err)
|
||||
return nil, err
|
||||
}
|
||||
defer rc.Close()
|
||||
|
||||
if err = binary.Write(rc, binary.BigEndian, reqBytes); err != nil {
|
||||
log.F("[dns] failed to write req message: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var respLen uint16
|
||||
if err = binary.Read(rc, binary.BigEndian, &respLen); err != nil {
|
||||
log.F("[dns] failed to read response length: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
respBytes := make([]byte, respLen+2)
|
||||
binary.BigEndian.PutUint16(respBytes[:2], respLen)
|
||||
|
||||
_, err = io.ReadFull(rc, respBytes[2:])
|
||||
if err != nil {
|
||||
log.F("[dns] error in read respMsg %s\n", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA {
|
||||
log.F("[dns] %s <-> %s, type: %d, %s",
|
||||
clientAddr, dnsServer, req.Question.QTYPE, req.Question.QNAME)
|
||||
log.F("[dns] %s <-> %s(%s), type: %d, %s",
|
||||
clientAddr, dnsServer, network, req.Question.QTYPE, req.Question.QNAME)
|
||||
return respBytes, nil
|
||||
}
|
||||
|
||||
@ -104,28 +81,101 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) {
|
||||
for _, h := range c.handlers {
|
||||
h(resp.Question.QNAME, answer.IP)
|
||||
}
|
||||
|
||||
if answer.IP != "" {
|
||||
ips = append(ips, answer.IP)
|
||||
}
|
||||
|
||||
if answer.TTL != 0 {
|
||||
ttl = int(answer.TTL)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// add to cache
|
||||
c.cache.Put(getKey(resp.Question), respBytes, ttl)
|
||||
if len(ips) != 0 {
|
||||
c.cache.Put(getKey(resp.Question), respBytes, ttl)
|
||||
}
|
||||
|
||||
log.F("[dns] %s <-> %s, type: %d, %s: %s",
|
||||
clientAddr, dnsServer, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
|
||||
log.F("[dns] %s <-> %s(%s), type: %d, %s: %s",
|
||||
clientAddr, dnsServer, network, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
|
||||
|
||||
return respBytes, nil
|
||||
}
|
||||
|
||||
// exchange choose a upstream dns server based on qname, communicate with it on the network
|
||||
func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (server, network string, respBytes []byte, err error) {
|
||||
// use tcp to connect upstream server default
|
||||
network = "tcp"
|
||||
|
||||
dialer := c.dialer.NextDialer(qname + ":53")
|
||||
|
||||
// If client uses udp and no forwarders specified, use udp
|
||||
if !preferTCP && dialer.Addr() == "DIRECT" {
|
||||
network = "udp"
|
||||
}
|
||||
|
||||
server = c.GetServer(qname)
|
||||
rc, err := dialer.Dial(network, server)
|
||||
if err != nil {
|
||||
log.F("[dns] failed to connect to server %v: %v", server, err)
|
||||
return
|
||||
}
|
||||
|
||||
switch network {
|
||||
case "tcp":
|
||||
respBytes, err = c.exchangeTCP(rc, reqBytes)
|
||||
case "udp":
|
||||
respBytes, err = c.exchangeUDP(rc, reqBytes)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// exchangeTCP exchange with server over tcp
|
||||
func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) {
|
||||
defer rc.Close()
|
||||
|
||||
if _, err := rc.Write(reqBytes); err != nil {
|
||||
log.F("[dns] failed to write req message: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var respLen uint16
|
||||
if err := binary.Read(rc, binary.BigEndian, &respLen); err != nil {
|
||||
log.F("[dns] failed to read response length: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
respBytes := make([]byte, respLen+2)
|
||||
binary.BigEndian.PutUint16(respBytes[:2], respLen)
|
||||
|
||||
_, err := io.ReadFull(rc, respBytes[2:])
|
||||
if err != nil {
|
||||
log.F("[dns] error in read respMsg %s\n", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return respBytes, nil
|
||||
}
|
||||
|
||||
// exchangeUDP exchange with server over udp
|
||||
func (c *Client) exchangeUDP(rc net.Conn, reqBytes []byte) ([]byte, error) {
|
||||
defer rc.Close()
|
||||
|
||||
if _, err := rc.Write(reqBytes[2:]); err != nil {
|
||||
log.F("[dns] failed to write req message: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqBytes = make([]byte, 2+UDPMaxLen)
|
||||
n, err := rc.Read(reqBytes[2:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
binary.BigEndian.PutUint16(reqBytes[:2], uint16(n))
|
||||
|
||||
return reqBytes[:2+n], nil
|
||||
}
|
||||
|
||||
// SetServer .
|
||||
func (c *Client) SetServer(domain string, servers ...string) {
|
||||
c.upServerMap[domain] = append(c.upServerMap[domain], servers...)
|
||||
@ -193,7 +243,7 @@ func (c *Client) GenResponse(domain string, ip string) (*Message, error) {
|
||||
|
||||
m := NewMessage(0, Response)
|
||||
m.SetQuestion(NewQuestion(qtype, domain))
|
||||
rr := &RR{NAME: domain, TYPE: qtype, CLASS: CLASSIN,
|
||||
rr := &RR{NAME: domain, TYPE: qtype, CLASS: ClassINET,
|
||||
RDLENGTH: rdlen, RDATA: rdata}
|
||||
m.AddAnswer(rr)
|
||||
|
||||
|
@ -31,8 +31,8 @@ const (
|
||||
QTypeAAAA uint16 = 28 ///ipv6
|
||||
)
|
||||
|
||||
// CLASSIN .
|
||||
const CLASSIN uint16 = 1
|
||||
// ClassINET .
|
||||
const ClassINET uint16 = 1
|
||||
|
||||
// Message format
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1
|
||||
@ -51,7 +51,7 @@ const CLASSIN uint16 = 1
|
||||
// +---------------------+
|
||||
// | Additional | RRs holding additional information
|
||||
type Message struct {
|
||||
*Header
|
||||
Header
|
||||
// most dns implementation only support 1 question
|
||||
Question *Question
|
||||
Answers []*RR
|
||||
@ -68,10 +68,10 @@ func NewMessage(id uint16, msgType int) *Message {
|
||||
id = uint16(rand.Uint32())
|
||||
}
|
||||
|
||||
h := &Header{ID: id}
|
||||
h.SetMsgType(msgType)
|
||||
m := &Message{Header: Header{ID: id}}
|
||||
m.SetMsgType(msgType)
|
||||
|
||||
return &Message{Header: h}
|
||||
return m
|
||||
}
|
||||
|
||||
// SetQuestion sets a question to dns message,
|
||||
@ -119,38 +119,35 @@ func (m *Message) Marshal() ([]byte, error) {
|
||||
|
||||
// UnmarshalMessage unmarshals []bytes to Message
|
||||
func UnmarshalMessage(b []byte) (*Message, error) {
|
||||
msg := &Message{Header: &Header{}}
|
||||
msg.unMarshaled = b
|
||||
|
||||
err := UnmarshalHeader(b[:HeaderLen], msg.Header)
|
||||
m := &Message{unMarshaled: b}
|
||||
err := UnmarshalHeader(b[:HeaderLen], &m.Header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := &Question{}
|
||||
qLen, err := msg.UnmarshalQuestion(b[HeaderLen:], q)
|
||||
qLen, err := m.UnmarshalQuestion(b[HeaderLen:], q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg.SetQuestion(q)
|
||||
m.SetQuestion(q)
|
||||
|
||||
// resp answers
|
||||
rridx := HeaderLen + qLen
|
||||
for i := 0; i < int(msg.Header.ANCOUNT); i++ {
|
||||
rrIdx := HeaderLen + qLen
|
||||
for i := 0; i < int(m.Header.ANCOUNT); i++ {
|
||||
rr := &RR{}
|
||||
rrLen, err := msg.UnmarshalRR(rridx, rr)
|
||||
rrLen, err := m.UnmarshalRR(rrIdx, rr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg.AddAnswer(rr)
|
||||
m.AddAnswer(rr)
|
||||
|
||||
rridx += rrLen
|
||||
rrIdx += rrLen
|
||||
}
|
||||
|
||||
msg.Header.SetAncount(len(msg.Answers))
|
||||
m.Header.SetAncount(len(m.Answers))
|
||||
|
||||
return msg, nil
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Header format
|
||||
@ -262,7 +259,7 @@ func NewQuestion(qtype uint16, domain string) *Question {
|
||||
return &Question{
|
||||
QNAME: domain,
|
||||
QTYPE: qtype,
|
||||
QCLASS: CLASSIN,
|
||||
QCLASS: ClassINET,
|
||||
}
|
||||
}
|
||||
|
||||
@ -428,7 +425,9 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int, error) {
|
||||
if size == 0 {
|
||||
idx++
|
||||
break
|
||||
} else if size > 63 {
|
||||
}
|
||||
|
||||
if size > 63 {
|
||||
return "", 0, errors.New("UnmarshalDomain: label size larger than 63")
|
||||
}
|
||||
|
||||
|
@ -61,11 +61,10 @@ func (s *Server) ListenAndServeUDP() {
|
||||
log.F("[dns] not enough message data")
|
||||
continue
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
|
||||
|
||||
go func() {
|
||||
respBytes, err := s.Client.Exchange(reqBytes[:2+n], caddr.String())
|
||||
respBytes, err := s.Client.Exchange(reqBytes[:2+n], caddr.String(), false)
|
||||
if err != nil {
|
||||
log.F("[dns] error in exchange: %s", err)
|
||||
return
|
||||
@ -123,7 +122,7 @@ func (s *Server) ServeTCP(c net.Conn) {
|
||||
|
||||
binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
|
||||
|
||||
respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String())
|
||||
respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String(), true)
|
||||
if err != nil {
|
||||
log.F("[dns]-tcp error in exchange: %s", err)
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user