mirror of
				https://github.com/nadoo/glider.git
				synced 2025-11-04 07:42:38 +08:00 
			
		
		
		
	dns: query in udp when client requests in udp and no forwarder specified
This commit is contained in:
		
							parent
							
								
									8d20331096
								
							
						
					
					
						commit
						9acaff5b4a
					
				
							
								
								
									
										22
									
								
								dns/cache.go
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								dns/cache.go
									
									
									
									
									
								
							@ -42,20 +42,18 @@ func (c *Cache) Len() int {
 | 
			
		||||
	return len(c.m)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Put an item into cache, invalid after ttl seconds, never invalid if ttl=0
 | 
			
		||||
// Put an item into cache, invalid after ttl seconds
 | 
			
		||||
func (c *Cache) Put(k string, v []byte, ttl int) {
 | 
			
		||||
	if len(v) == 0 {
 | 
			
		||||
		return
 | 
			
		||||
	if len(v) != 0 {
 | 
			
		||||
		c.l.Lock()
 | 
			
		||||
		it, ok := c.m[k]
 | 
			
		||||
		if !ok {
 | 
			
		||||
			it = &item{value: v}
 | 
			
		||||
			c.m[k] = it
 | 
			
		||||
		}
 | 
			
		||||
		it.expire = time.Now().Add(time.Duration(ttl) * time.Second)
 | 
			
		||||
		c.l.Unlock()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.l.Lock()
 | 
			
		||||
	it, ok := c.m[k]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		it = &item{value: v}
 | 
			
		||||
		c.m[k] = it
 | 
			
		||||
	}
 | 
			
		||||
	it.expire = time.Now().Add(time.Duration(ttl) * time.Second)
 | 
			
		||||
	c.l.Unlock()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Get an item from cache
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										120
									
								
								dns/client.go
									
									
									
									
									
								
							
							
						
						
									
										120
									
								
								dns/client.go
									
									
									
									
									
								
							@ -41,7 +41,7 @@ func NewClient(dialer proxy.Dialer, upServers []string) (*Client, error) {
 | 
			
		||||
 | 
			
		||||
// Exchange handles request msg and returns response msg
 | 
			
		||||
// reqBytes = reqLen + reqMsg
 | 
			
		||||
func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) {
 | 
			
		||||
func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([]byte, error) {
 | 
			
		||||
	req, err := UnmarshalMessage(reqBytes[2:])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@ -58,37 +58,14 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) {
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dnsServer := c.GetServer(req.Question.QNAME)
 | 
			
		||||
	rc, err := c.dialer.NextDialer(req.Question.QNAME+":53").Dial("tcp", dnsServer)
 | 
			
		||||
	dnsServer, network, respBytes, err := c.exchange(req.Question.QNAME, reqBytes, preferTCP)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.F("[dns] failed to connect to server %v: %v", dnsServer, err)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer rc.Close()
 | 
			
		||||
 | 
			
		||||
	if err = binary.Write(rc, binary.BigEndian, reqBytes); err != nil {
 | 
			
		||||
		log.F("[dns] failed to write req message: %v", err)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var respLen uint16
 | 
			
		||||
	if err = binary.Read(rc, binary.BigEndian, &respLen); err != nil {
 | 
			
		||||
		log.F("[dns] failed to read response length: %v", err)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	respBytes := make([]byte, respLen+2)
 | 
			
		||||
	binary.BigEndian.PutUint16(respBytes[:2], respLen)
 | 
			
		||||
 | 
			
		||||
	_, err = io.ReadFull(rc, respBytes[2:])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.F("[dns] error in read respMsg %s\n", err)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA {
 | 
			
		||||
		log.F("[dns] %s <-> %s, type: %d, %s",
 | 
			
		||||
			clientAddr, dnsServer, req.Question.QTYPE, req.Question.QNAME)
 | 
			
		||||
		log.F("[dns] %s <-> %s(%s), type: %d, %s",
 | 
			
		||||
			clientAddr, dnsServer, network, req.Question.QTYPE, req.Question.QNAME)
 | 
			
		||||
		return respBytes, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@ -104,28 +81,101 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) {
 | 
			
		||||
			for _, h := range c.handlers {
 | 
			
		||||
				h(resp.Question.QNAME, answer.IP)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if answer.IP != "" {
 | 
			
		||||
				ips = append(ips, answer.IP)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if answer.TTL != 0 {
 | 
			
		||||
				ttl = int(answer.TTL)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// add to cache
 | 
			
		||||
	c.cache.Put(getKey(resp.Question), respBytes, ttl)
 | 
			
		||||
	if len(ips) != 0 {
 | 
			
		||||
		c.cache.Put(getKey(resp.Question), respBytes, ttl)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.F("[dns] %s <-> %s, type: %d, %s: %s",
 | 
			
		||||
		clientAddr, dnsServer, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
 | 
			
		||||
	log.F("[dns] %s <-> %s(%s), type: %d, %s: %s",
 | 
			
		||||
		clientAddr, dnsServer, network, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
 | 
			
		||||
 | 
			
		||||
	return respBytes, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// exchange choose a upstream dns server based on qname, communicate with it on the network
 | 
			
		||||
func (c *Client) exchange(qname string, reqBytes []byte, preferTCP bool) (server, network string, respBytes []byte, err error) {
 | 
			
		||||
	// use tcp to connect upstream server default
 | 
			
		||||
	network = "tcp"
 | 
			
		||||
 | 
			
		||||
	dialer := c.dialer.NextDialer(qname + ":53")
 | 
			
		||||
 | 
			
		||||
	// If client uses udp and no forwarders specified, use udp
 | 
			
		||||
	if !preferTCP && dialer.Addr() == "DIRECT" {
 | 
			
		||||
		network = "udp"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	server = c.GetServer(qname)
 | 
			
		||||
	rc, err := dialer.Dial(network, server)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.F("[dns] failed to connect to server %v: %v", server, err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	switch network {
 | 
			
		||||
	case "tcp":
 | 
			
		||||
		respBytes, err = c.exchangeTCP(rc, reqBytes)
 | 
			
		||||
	case "udp":
 | 
			
		||||
		respBytes, err = c.exchangeUDP(rc, reqBytes)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// exchangeTCP exchange with server over tcp
 | 
			
		||||
func (c *Client) exchangeTCP(rc net.Conn, reqBytes []byte) ([]byte, error) {
 | 
			
		||||
	defer rc.Close()
 | 
			
		||||
 | 
			
		||||
	if _, err := rc.Write(reqBytes); err != nil {
 | 
			
		||||
		log.F("[dns] failed to write req message: %v", err)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var respLen uint16
 | 
			
		||||
	if err := binary.Read(rc, binary.BigEndian, &respLen); err != nil {
 | 
			
		||||
		log.F("[dns] failed to read response length: %v", err)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	respBytes := make([]byte, respLen+2)
 | 
			
		||||
	binary.BigEndian.PutUint16(respBytes[:2], respLen)
 | 
			
		||||
 | 
			
		||||
	_, err := io.ReadFull(rc, respBytes[2:])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.F("[dns] error in read respMsg %s\n", err)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return respBytes, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// exchangeUDP exchange with server over udp
 | 
			
		||||
func (c *Client) exchangeUDP(rc net.Conn, reqBytes []byte) ([]byte, error) {
 | 
			
		||||
	defer rc.Close()
 | 
			
		||||
 | 
			
		||||
	if _, err := rc.Write(reqBytes[2:]); err != nil {
 | 
			
		||||
		log.F("[dns] failed to write req message: %v", err)
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	reqBytes = make([]byte, 2+UDPMaxLen)
 | 
			
		||||
	n, err := rc.Read(reqBytes[2:])
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	binary.BigEndian.PutUint16(reqBytes[:2], uint16(n))
 | 
			
		||||
 | 
			
		||||
	return reqBytes[:2+n], nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetServer .
 | 
			
		||||
func (c *Client) SetServer(domain string, servers ...string) {
 | 
			
		||||
	c.upServerMap[domain] = append(c.upServerMap[domain], servers...)
 | 
			
		||||
@ -193,7 +243,7 @@ func (c *Client) GenResponse(domain string, ip string) (*Message, error) {
 | 
			
		||||
 | 
			
		||||
	m := NewMessage(0, Response)
 | 
			
		||||
	m.SetQuestion(NewQuestion(qtype, domain))
 | 
			
		||||
	rr := &RR{NAME: domain, TYPE: qtype, CLASS: CLASSIN,
 | 
			
		||||
	rr := &RR{NAME: domain, TYPE: qtype, CLASS: ClassINET,
 | 
			
		||||
		RDLENGTH: rdlen, RDATA: rdata}
 | 
			
		||||
	m.AddAnswer(rr)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -31,8 +31,8 @@ const (
 | 
			
		||||
	QTypeAAAA uint16 = 28 ///ipv6
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// CLASSIN .
 | 
			
		||||
const CLASSIN uint16 = 1
 | 
			
		||||
// ClassINET .
 | 
			
		||||
const ClassINET uint16 = 1
 | 
			
		||||
 | 
			
		||||
// Message format
 | 
			
		||||
// https://tools.ietf.org/html/rfc1035#section-4.1
 | 
			
		||||
@ -51,7 +51,7 @@ const CLASSIN uint16 = 1
 | 
			
		||||
//     +---------------------+
 | 
			
		||||
//     |      Additional     | RRs holding additional information
 | 
			
		||||
type Message struct {
 | 
			
		||||
	*Header
 | 
			
		||||
	Header
 | 
			
		||||
	// most dns implementation only support 1 question
 | 
			
		||||
	Question   *Question
 | 
			
		||||
	Answers    []*RR
 | 
			
		||||
@ -68,10 +68,10 @@ func NewMessage(id uint16, msgType int) *Message {
 | 
			
		||||
		id = uint16(rand.Uint32())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h := &Header{ID: id}
 | 
			
		||||
	h.SetMsgType(msgType)
 | 
			
		||||
	m := &Message{Header: Header{ID: id}}
 | 
			
		||||
	m.SetMsgType(msgType)
 | 
			
		||||
 | 
			
		||||
	return &Message{Header: h}
 | 
			
		||||
	return m
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SetQuestion sets a question to dns message,
 | 
			
		||||
@ -119,38 +119,35 @@ func (m *Message) Marshal() ([]byte, error) {
 | 
			
		||||
 | 
			
		||||
// UnmarshalMessage unmarshals []bytes to Message
 | 
			
		||||
func UnmarshalMessage(b []byte) (*Message, error) {
 | 
			
		||||
	msg := &Message{Header: &Header{}}
 | 
			
		||||
	msg.unMarshaled = b
 | 
			
		||||
 | 
			
		||||
	err := UnmarshalHeader(b[:HeaderLen], msg.Header)
 | 
			
		||||
	m := &Message{unMarshaled: b}
 | 
			
		||||
	err := UnmarshalHeader(b[:HeaderLen], &m.Header)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	q := &Question{}
 | 
			
		||||
	qLen, err := msg.UnmarshalQuestion(b[HeaderLen:], q)
 | 
			
		||||
	qLen, err := m.UnmarshalQuestion(b[HeaderLen:], q)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	msg.SetQuestion(q)
 | 
			
		||||
	m.SetQuestion(q)
 | 
			
		||||
 | 
			
		||||
	// resp answers
 | 
			
		||||
	rridx := HeaderLen + qLen
 | 
			
		||||
	for i := 0; i < int(msg.Header.ANCOUNT); i++ {
 | 
			
		||||
	rrIdx := HeaderLen + qLen
 | 
			
		||||
	for i := 0; i < int(m.Header.ANCOUNT); i++ {
 | 
			
		||||
		rr := &RR{}
 | 
			
		||||
		rrLen, err := msg.UnmarshalRR(rridx, rr)
 | 
			
		||||
		rrLen, err := m.UnmarshalRR(rrIdx, rr)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		msg.AddAnswer(rr)
 | 
			
		||||
		m.AddAnswer(rr)
 | 
			
		||||
 | 
			
		||||
		rridx += rrLen
 | 
			
		||||
		rrIdx += rrLen
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	msg.Header.SetAncount(len(msg.Answers))
 | 
			
		||||
	m.Header.SetAncount(len(m.Answers))
 | 
			
		||||
 | 
			
		||||
	return msg, nil
 | 
			
		||||
	return m, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Header format
 | 
			
		||||
@ -262,7 +259,7 @@ func NewQuestion(qtype uint16, domain string) *Question {
 | 
			
		||||
	return &Question{
 | 
			
		||||
		QNAME:  domain,
 | 
			
		||||
		QTYPE:  qtype,
 | 
			
		||||
		QCLASS: CLASSIN,
 | 
			
		||||
		QCLASS: ClassINET,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -428,7 +425,9 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int, error) {
 | 
			
		||||
			if size == 0 {
 | 
			
		||||
				idx++
 | 
			
		||||
				break
 | 
			
		||||
			} else if size > 63 {
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if size > 63 {
 | 
			
		||||
				return "", 0, errors.New("UnmarshalDomain: label size larger than 63")
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -61,11 +61,10 @@ func (s *Server) ListenAndServeUDP() {
 | 
			
		||||
			log.F("[dns] not enough message data")
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
 | 
			
		||||
 | 
			
		||||
		go func() {
 | 
			
		||||
			respBytes, err := s.Client.Exchange(reqBytes[:2+n], caddr.String())
 | 
			
		||||
			respBytes, err := s.Client.Exchange(reqBytes[:2+n], caddr.String(), false)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				log.F("[dns] error in exchange: %s", err)
 | 
			
		||||
				return
 | 
			
		||||
@ -123,7 +122,7 @@ func (s *Server) ServeTCP(c net.Conn) {
 | 
			
		||||
 | 
			
		||||
	binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
 | 
			
		||||
 | 
			
		||||
	respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String())
 | 
			
		||||
	respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String(), true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.F("[dns]-tcp error in exchange: %s", err)
 | 
			
		||||
		return
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user