mirror of
				https://github.com/nadoo/glider.git
				synced 2025-11-03 23:32:37 +08:00 
			
		
		
		
	dns: changed UnmarshalMessage to return *Messaeg
This commit is contained in:
		
							parent
							
								
									5e32133eb9
								
							
						
					
					
						commit
						f6a578f849
					
				@ -36,23 +36,20 @@ func NewClient(dialer proxy.Dialer, upServers ...string) (*Client, error) {
 | 
				
			|||||||
// Exchange handles request msg and returns response msg
 | 
					// Exchange handles request msg and returns response msg
 | 
				
			||||||
// reqBytes = reqLen + reqMsg
 | 
					// reqBytes = reqLen + reqMsg
 | 
				
			||||||
func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte, err error) {
 | 
					func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte, err error) {
 | 
				
			||||||
	reqMsg := reqBytes[2:]
 | 
						req, err := UnmarshalMessage(reqBytes[2:])
 | 
				
			||||||
 | 
					 | 
				
			||||||
	reqM := NewMessage()
 | 
					 | 
				
			||||||
	err = UnmarshalMessage(reqMsg, reqM)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if reqM.Question.QTYPE == QTypeA || reqM.Question.QTYPE == QTypeAAAA {
 | 
						if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA {
 | 
				
			||||||
		// TODO: if query.QNAME in cache
 | 
							// TODO: if query.QNAME in cache
 | 
				
			||||||
		// get respMsg from cache
 | 
							// get respMsg from cache
 | 
				
			||||||
		// set msg id
 | 
							// set msg id
 | 
				
			||||||
		// return respMsg, nil
 | 
							// return respMsg, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	dnsServer := c.GetServer(reqM.Question.QNAME)
 | 
						dnsServer := c.GetServer(req.Question.QNAME)
 | 
				
			||||||
	rc, err := c.dialer.NextDialer(reqM.Question.QNAME+":53").Dial("tcp", dnsServer)
 | 
						rc, err := c.dialer.NextDialer(req.Question.QNAME+":53").Dial("tcp", dnsServer)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.F("[dns] failed to connect to server %v: %v", dnsServer, err)
 | 
							log.F("[dns] failed to connect to server %v: %v", dnsServer, err)
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
@ -80,21 +77,20 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte,
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if reqM.Question.QTYPE != QTypeA && reqM.Question.QTYPE != QTypeAAAA {
 | 
						if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	respM := NewMessage()
 | 
						resp, err := UnmarshalMessage(respMsg)
 | 
				
			||||||
	err = UnmarshalMessage(respMsg, respM)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	ips := []string{}
 | 
						ips := []string{}
 | 
				
			||||||
	for _, answer := range respM.Answers {
 | 
						for _, answer := range resp.Answers {
 | 
				
			||||||
		if answer.TYPE == QTypeA {
 | 
							if answer.TYPE == QTypeA || answer.TYPE == QTypeAAAA {
 | 
				
			||||||
			for _, h := range c.Handlers {
 | 
								for _, h := range c.Handlers {
 | 
				
			||||||
				h(reqM.Question.QNAME, answer.IP)
 | 
									h(resp.Question.QNAME, answer.IP)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if answer.IP != "" {
 | 
								if answer.IP != "" {
 | 
				
			||||||
@ -107,7 +103,7 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte,
 | 
				
			|||||||
	// add to cache
 | 
						// add to cache
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.F("[dns] %s <-> %s, type: %d, %s: %s",
 | 
						log.F("[dns] %s <-> %s, type: %d, %s: %s",
 | 
				
			||||||
		clientAddr, dnsServer, reqM.Question.QTYPE, reqM.Question.QNAME, strings.Join(ips, ","))
 | 
							clientAddr, dnsServer, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -3,7 +3,9 @@ package dns
 | 
				
			|||||||
import (
 | 
					import (
 | 
				
			||||||
	"bytes"
 | 
						"bytes"
 | 
				
			||||||
	"encoding/binary"
 | 
						"encoding/binary"
 | 
				
			||||||
 | 
						"encoding/hex"
 | 
				
			||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
 | 
						"fmt"
 | 
				
			||||||
	"math/rand"
 | 
						"math/rand"
 | 
				
			||||||
	"net"
 | 
						"net"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
@ -111,18 +113,21 @@ func (m *Message) Marshal() ([]byte, error) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// UnmarshalMessage unmarshals []bytes to Message
 | 
					// UnmarshalMessage unmarshals []bytes to Message
 | 
				
			||||||
func UnmarshalMessage(b []byte, msg *Message) error {
 | 
					func UnmarshalMessage(b []byte) (*Message, error) {
 | 
				
			||||||
 | 
						msg := NewMessage()
 | 
				
			||||||
	msg.unMarshaled = b
 | 
						msg.unMarshaled = b
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						fmt.Printf("msg.unMarshaled:\n%s\n", hex.Dump(msg.unMarshaled))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := UnmarshalHeader(b[:HeaderLen], msg.Header)
 | 
						err := UnmarshalHeader(b[:HeaderLen], msg.Header)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	q := &Question{}
 | 
						q := &Question{}
 | 
				
			||||||
	qLen, err := msg.UnmarshalQuestion(b[HeaderLen:], q)
 | 
						qLen, err := msg.UnmarshalQuestion(b[HeaderLen:], q)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	msg.SetQuestion(q)
 | 
						msg.SetQuestion(q)
 | 
				
			||||||
@ -133,7 +138,7 @@ func UnmarshalMessage(b []byte, msg *Message) error {
 | 
				
			|||||||
		rr := &RR{}
 | 
							rr := &RR{}
 | 
				
			||||||
		rrLen, err := msg.UnmarshalRR(rridx, rr)
 | 
							rrLen, err := msg.UnmarshalRR(rridx, rr)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return err
 | 
								return nil, err
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		msg.AddAnswer(rr)
 | 
							msg.AddAnswer(rr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -142,7 +147,7 @@ func UnmarshalMessage(b []byte, msg *Message) error {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	msg.Header.SetAncount(len(msg.Answers))
 | 
						msg.Header.SetAncount(len(msg.Answers))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return nil
 | 
						return msg, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Header format
 | 
					// Header format
 | 
				
			||||||
@ -340,13 +345,10 @@ func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) {
 | 
				
			|||||||
		return 0, errors.New("unmarshal question must not be nil")
 | 
							return 0, errors.New("unmarshal question must not be nil")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// https://tools.ietf.org/html/rfc1035#section-4.1.4
 | 
					 | 
				
			||||||
	// "Message compression",
 | 
					 | 
				
			||||||
	// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 | 
					 | 
				
			||||||
	// | 1  1|                OFFSET                   |
 | 
					 | 
				
			||||||
	// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 | 
					 | 
				
			||||||
	p := m.unMarshaled[start:]
 | 
						p := m.unMarshaled[start:]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						fmt.Printf("rr bytes:\n%s\n", hex.Dump(p[:10]))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	domain, n := m.GetDomain(p)
 | 
						domain, n := m.GetDomain(p)
 | 
				
			||||||
	rr.NAME = domain
 | 
						rr.NAME = domain
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -368,6 +370,8 @@ func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	n = n + 10 + int(rr.RDLENGTH)
 | 
						n = n + 10 + int(rr.RDLENGTH)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						fmt.Printf("rr: %+#v\n", rr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return n, nil
 | 
						return n, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -390,9 +394,14 @@ func (m *Message) GetDomain(b []byte) (string, int) {
 | 
				
			|||||||
	var labels = []string{}
 | 
						var labels = []string{}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for {
 | 
						for {
 | 
				
			||||||
 | 
							// https://tools.ietf.org/html/rfc1035#section-4.1.4
 | 
				
			||||||
 | 
							// "Message compression",
 | 
				
			||||||
 | 
							// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 | 
				
			||||||
 | 
							// | 1  1|                OFFSET                   |
 | 
				
			||||||
 | 
							// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 | 
				
			||||||
		if b[idx]&0xC0 == 0xC0 {
 | 
							if b[idx]&0xC0 == 0xC0 {
 | 
				
			||||||
			offset := binary.BigEndian.Uint16(b[idx : idx+2])
 | 
								offset := binary.BigEndian.Uint16(b[idx : idx+2])
 | 
				
			||||||
			lable := m.GetDomainByPoint(int(offset) & 0x3F)
 | 
								lable := m.GetDomainByPoint(int(offset & 0x3F))
 | 
				
			||||||
			labels = append(labels, lable)
 | 
								labels = append(labels, lable)
 | 
				
			||||||
			idx += 2
 | 
								idx += 2
 | 
				
			||||||
			break
 | 
								break
 | 
				
			||||||
@ -414,5 +423,6 @@ func (m *Message) GetDomain(b []byte) (string, int) {
 | 
				
			|||||||
// GetDomainByPoint gets domain from
 | 
					// GetDomainByPoint gets domain from
 | 
				
			||||||
func (m *Message) GetDomainByPoint(offset int) string {
 | 
					func (m *Message) GetDomainByPoint(offset int) string {
 | 
				
			||||||
	domain, _ := m.GetDomain(m.unMarshaled[offset:])
 | 
						domain, _ := m.GetDomain(m.unMarshaled[offset:])
 | 
				
			||||||
 | 
						fmt.Printf("GetDomainByPoint: %02x\n", offset)
 | 
				
			||||||
	return domain
 | 
						return domain
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user