mirror of
https://github.com/nadoo/glider.git
synced 2025-02-24 01:45:39 +08:00
dns: query in udp when client requests in udp and no forwarder specified
This commit is contained in:
parent
8d20331096
commit
9acaff5b4a
22
dns/cache.go
22
dns/cache.go
@ -42,20 +42,18 @@ func (c *Cache) Len() int {
|
|||||||
return len(c.m)
|
return len(c.m)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Put an item into cache, invalid after ttl seconds, never invalid if ttl=0
|
// Put an item into cache, invalid after ttl seconds
|
||||||
func (c *Cache) Put(k string, v []byte, ttl int) {
|
func (c *Cache) Put(k string, v []byte, ttl int) {
|
||||||
if len(v) == 0 {
|
if len(v) != 0 {
|
||||||
return
|
c.l.Lock()
|
||||||
|
it, ok := c.m[k]
|
||||||
|
if !ok {
|
||||||
|
it = &item{value: v}
|
||||||
|
c.m[k] = it
|
||||||
|
}
|
||||||
|
it.expire = time.Now().Add(time.Duration(ttl) * time.Second)
|
||||||
|
c.l.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
c.l.Lock()
|
|
||||||
it, ok := c.m[k]
|
|
||||||
if !ok {
|
|
||||||
it = &item{value: v}
|
|
||||||
c.m[k] = it
|
|
||||||
}
|
|
||||||
it.expire = time.Now().Add(time.Duration(ttl) * time.Second)
|
|
||||||
c.l.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get an item from cache
|
// Get an item from cache
|
||||||
|
120
dns/client.go
120
dns/client.go
@ -41,7 +41,7 @@ func NewClient(dialer proxy.Dialer, upServers []string) (*Client, error) {
|
|||||||
|
|
||||||
// Exchange handles request msg and returns response msg
|
// Exchange handles request msg and returns response msg
|
||||||
// reqBytes = reqLen + reqMsg
|
// reqBytes = reqLen + reqMsg
|
||||||
func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) {
|
func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([]byte, error) {
|
||||||
req, err := UnmarshalMessage(reqBytes[2:])
|
req, err := UnmarshalMessage(reqBytes[2:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -58,37 +58,14 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsServer := c.GetServer(req.Question.QNAME)
|
dnsServer, network, respBytes, err := c.exchange(req.Question.QNAME, reqBytes, preferTCP)
|
||||||
rc, err := c.dialer.NextDialer(req.Question.QNAME+":53").Dial("tcp", dnsServer)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.F("[dns] failed to connect to server %v: %v", dnsServer, err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rc.Close()
|
|
||||||
|
|
||||||
if err = binary.Write(rc, binary.BigEndian, reqBytes); err != nil {
|
|
||||||
log.F("[dns] failed to write req message: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var respLen uint16
|
|
||||||
if err = binary.Read(rc, binary.BigEndian, &respLen); err != nil {
|
|
||||||
log.F("[dns] failed to read response length: %v", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
respBytes := make([]byte, respLen+2)
|
|
||||||
binary.BigEndian.PutUint16(respBytes[:2], respLen)
|
|
||||||
|
|
||||||
_, err = io.ReadFull(rc, respBytes[2:])
|
|
||||||
if err != nil {
|
|
||||||
log.F("[dns] error in read respMsg %s\n", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA {
|
if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA {
|
||||||
log.F("[dns] %s <-> %s, type: %d, %s",
|
log.F("[dns] %s <-> %s(%s), type: %d, %s",
|
||||||
clientAddr, dnsServer, req.Question.QTYPE, req.Question.QNAME)
|
clientAddr, dnsServer, network, req.Question.QTYPE, req.Question.QNAME)
|
||||||
return respBytes, nil
|
return respBytes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -104,28 +81,101 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) {
|
|||||||
for _, h := range c.handlers {
|
for _, h := range c.handlers {
|
||||||
h(resp.Question.QNAME, answer.IP)
|
h(resp.Question.QNAME, answer.IP)
|
||||||
}
|
}
|
||||||
|
|
||||||
if answer.IP != "" {
|
if answer.IP != "" {
|
||||||
ips = append(ips, answer.IP)
|
ips = append(ips, answer.IP)
|
||||||
}
|
}
|
||||||
|
|
||||||
if answer.TTL != 0 {
|
if answer.TTL != 0 {
|
||||||
ttl = int(answer.TTL)
|
ttl = int(answer.TTL)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// add to cache
|
// add to cache
|
||||||
c.cache.Put(getKey(resp.Question), respBytes, ttl)
|
if len(ips) != 0 {
|
||||||
|
c.cache.Put(getKey(resp.Question), respBytes, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
log.F("[dns] %s <-> %s, type: %d, %s: %s",
|
log.F("[dns] %s <-> %s(%s), type: %d, %s: %s",
|
||||||
clientAddr, dnsServer, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
|
clientAddr, dnsServer, network, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
|
||||||
|
|
||||||
return respBytes, nil
|
return respBytes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// exchange choose a upstream dns server based on qname, communicate with it on the network
|
||||||
|
func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (server, network string, respBytes []byte, err error) {
|
||||||
|
// use tcp to connect upstream server default
|
||||||
|
network = "tcp"
|
||||||
|
|
||||||
|
dialer := c.dialer.NextDialer(qname + ":53")
|
||||||
|
|
||||||
|
// If client uses udp and no forwarders specified, use udp
|
||||||
|
if !preferTCP && dialer.Addr() == "DIRECT" {
|
||||||
|
network = "udp"
|
||||||
|
}
|
||||||
|
|
||||||
|
server = c.GetServer(qname)
|
||||||
|
rc, err := dialer.Dial(network, server)
|
||||||
|
if err != nil {
|
||||||
|
log.F("[dns] failed to connect to server %v: %v", server, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch network {
|
||||||
|
case "tcp":
|
||||||
|
respBytes, err = c.exchangeTCP(rc, reqBytes)
|
||||||
|
case "udp":
|
||||||
|
respBytes, err = c.exchangeUDP(rc, reqBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// exchangeTCP exchange with server over tcp
|
||||||
|
func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) {
|
||||||
|
defer rc.Close()
|
||||||
|
|
||||||
|
if _, err := rc.Write(reqBytes); err != nil {
|
||||||
|
log.F("[dns] failed to write req message: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var respLen uint16
|
||||||
|
if err := binary.Read(rc, binary.BigEndian, &respLen); err != nil {
|
||||||
|
log.F("[dns] failed to read response length: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
respBytes := make([]byte, respLen+2)
|
||||||
|
binary.BigEndian.PutUint16(respBytes[:2], respLen)
|
||||||
|
|
||||||
|
_, err := io.ReadFull(rc, respBytes[2:])
|
||||||
|
if err != nil {
|
||||||
|
log.F("[dns] error in read respMsg %s\n", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return respBytes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// exchangeUDP exchange with server over udp
|
||||||
|
func (c *Client) exchangeUDP(rc net.Conn, reqBytes []byte) ([]byte, error) {
|
||||||
|
defer rc.Close()
|
||||||
|
|
||||||
|
if _, err := rc.Write(reqBytes[2:]); err != nil {
|
||||||
|
log.F("[dns] failed to write req message: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBytes = make([]byte, 2+UDPMaxLen)
|
||||||
|
n, err := rc.Read(reqBytes[2:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
binary.BigEndian.PutUint16(reqBytes[:2], uint16(n))
|
||||||
|
|
||||||
|
return reqBytes[:2+n], nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetServer .
|
// SetServer .
|
||||||
func (c *Client) SetServer(domain string, servers ...string) {
|
func (c *Client) SetServer(domain string, servers ...string) {
|
||||||
c.upServerMap[domain] = append(c.upServerMap[domain], servers...)
|
c.upServerMap[domain] = append(c.upServerMap[domain], servers...)
|
||||||
@ -193,7 +243,7 @@ func (c *Client) GenResponse(domain string, ip string) (*Message, error) {
|
|||||||
|
|
||||||
m := NewMessage(0, Response)
|
m := NewMessage(0, Response)
|
||||||
m.SetQuestion(NewQuestion(qtype, domain))
|
m.SetQuestion(NewQuestion(qtype, domain))
|
||||||
rr := &RR{NAME: domain, TYPE: qtype, CLASS: CLASSIN,
|
rr := &RR{NAME: domain, TYPE: qtype, CLASS: ClassINET,
|
||||||
RDLENGTH: rdlen, RDATA: rdata}
|
RDLENGTH: rdlen, RDATA: rdata}
|
||||||
m.AddAnswer(rr)
|
m.AddAnswer(rr)
|
||||||
|
|
||||||
|
@ -31,8 +31,8 @@ const (
|
|||||||
QTypeAAAA uint16 = 28 ///ipv6
|
QTypeAAAA uint16 = 28 ///ipv6
|
||||||
)
|
)
|
||||||
|
|
||||||
// CLASSIN .
|
// ClassINET .
|
||||||
const CLASSIN uint16 = 1
|
const ClassINET uint16 = 1
|
||||||
|
|
||||||
// Message format
|
// Message format
|
||||||
// https://tools.ietf.org/html/rfc1035#section-4.1
|
// https://tools.ietf.org/html/rfc1035#section-4.1
|
||||||
@ -51,7 +51,7 @@ const CLASSIN uint16 = 1
|
|||||||
// +---------------------+
|
// +---------------------+
|
||||||
// | Additional | RRs holding additional information
|
// | Additional | RRs holding additional information
|
||||||
type Message struct {
|
type Message struct {
|
||||||
*Header
|
Header
|
||||||
// most dns implementation only support 1 question
|
// most dns implementation only support 1 question
|
||||||
Question *Question
|
Question *Question
|
||||||
Answers []*RR
|
Answers []*RR
|
||||||
@ -68,10 +68,10 @@ func NewMessage(id uint16, msgType int) *Message {
|
|||||||
id = uint16(rand.Uint32())
|
id = uint16(rand.Uint32())
|
||||||
}
|
}
|
||||||
|
|
||||||
h := &Header{ID: id}
|
m := &Message{Header: Header{ID: id}}
|
||||||
h.SetMsgType(msgType)
|
m.SetMsgType(msgType)
|
||||||
|
|
||||||
return &Message{Header: h}
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetQuestion sets a question to dns message,
|
// SetQuestion sets a question to dns message,
|
||||||
@ -119,38 +119,35 @@ func (m *Message) Marshal() ([]byte, error) {
|
|||||||
|
|
||||||
// UnmarshalMessage unmarshals []bytes to Message
|
// UnmarshalMessage unmarshals []bytes to Message
|
||||||
func UnmarshalMessage(b []byte) (*Message, error) {
|
func UnmarshalMessage(b []byte) (*Message, error) {
|
||||||
msg := &Message{Header: &Header{}}
|
m := &Message{unMarshaled: b}
|
||||||
msg.unMarshaled = b
|
err := UnmarshalHeader(b[:HeaderLen], &m.Header)
|
||||||
|
|
||||||
err := UnmarshalHeader(b[:HeaderLen], msg.Header)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
q := &Question{}
|
q := &Question{}
|
||||||
qLen, err := msg.UnmarshalQuestion(b[HeaderLen:], q)
|
qLen, err := m.UnmarshalQuestion(b[HeaderLen:], q)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
m.SetQuestion(q)
|
||||||
msg.SetQuestion(q)
|
|
||||||
|
|
||||||
// resp answers
|
// resp answers
|
||||||
rridx := HeaderLen + qLen
|
rrIdx := HeaderLen + qLen
|
||||||
for i := 0; i < int(msg.Header.ANCOUNT); i++ {
|
for i := 0; i < int(m.Header.ANCOUNT); i++ {
|
||||||
rr := &RR{}
|
rr := &RR{}
|
||||||
rrLen, err := msg.UnmarshalRR(rridx, rr)
|
rrLen, err := m.UnmarshalRR(rrIdx, rr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
msg.AddAnswer(rr)
|
m.AddAnswer(rr)
|
||||||
|
|
||||||
rridx += rrLen
|
rrIdx += rrLen
|
||||||
}
|
}
|
||||||
|
|
||||||
msg.Header.SetAncount(len(msg.Answers))
|
m.Header.SetAncount(len(m.Answers))
|
||||||
|
|
||||||
return msg, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Header format
|
// Header format
|
||||||
@ -262,7 +259,7 @@ func NewQuestion(qtype uint16, domain string) *Question {
|
|||||||
return &Question{
|
return &Question{
|
||||||
QNAME: domain,
|
QNAME: domain,
|
||||||
QTYPE: qtype,
|
QTYPE: qtype,
|
||||||
QCLASS: CLASSIN,
|
QCLASS: ClassINET,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -428,7 +425,9 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int, error) {
|
|||||||
if size == 0 {
|
if size == 0 {
|
||||||
idx++
|
idx++
|
||||||
break
|
break
|
||||||
} else if size > 63 {
|
}
|
||||||
|
|
||||||
|
if size > 63 {
|
||||||
return "", 0, errors.New("UnmarshalDomain: label size larger than 63")
|
return "", 0, errors.New("UnmarshalDomain: label size larger than 63")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,11 +61,10 @@ func (s *Server) ListenAndServeUDP() {
|
|||||||
log.F("[dns] not enough message data")
|
log.F("[dns] not enough message data")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
|
binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
respBytes, err := s.Client.Exchange(reqBytes[:2+n], caddr.String())
|
respBytes, err := s.Client.Exchange(reqBytes[:2+n], caddr.String(), false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.F("[dns] error in exchange: %s", err)
|
log.F("[dns] error in exchange: %s", err)
|
||||||
return
|
return
|
||||||
@ -123,7 +122,7 @@ func (s *Server) ServeTCP(c net.Conn) {
|
|||||||
|
|
||||||
binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
|
binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
|
||||||
|
|
||||||
respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String())
|
respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String(), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.F("[dns]-tcp error in exchange: %s", err)
|
log.F("[dns]-tcp error in exchange: %s", err)
|
||||||
return
|
return
|
||||||
|
Loading…
Reference in New Issue
Block a user