From f6a578f849139831a0ddad19c241c6cac92741fb Mon Sep 17 00:00:00 2001 From: nadoo <287492+nadoo@users.noreply.github.com> Date: Mon, 30 Jul 2018 01:05:08 +0800 Subject: [PATCH] dns: changed UnmarshalMessage to return *Messaeg --- dns/client.go | 24 ++++++++++-------------- dns/message.go | 32 +++++++++++++++++++++----------- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/dns/client.go b/dns/client.go index 47b32fd..a5ae487 100644 --- a/dns/client.go +++ b/dns/client.go @@ -36,23 +36,20 @@ func NewClient(dialer proxy.Dialer, upServers ...string) (*Client, error) { // Exchange handles request msg and returns response msg // reqBytes = reqLen + reqMsg func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte, err error) { - reqMsg := reqBytes[2:] - - reqM := NewMessage() - err = UnmarshalMessage(reqMsg, reqM) + req, err := UnmarshalMessage(reqBytes[2:]) if err != nil { return } - if reqM.Question.QTYPE == QTypeA || reqM.Question.QTYPE == QTypeAAAA { + if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA { // TODO: if query.QNAME in cache // get respMsg from cache // set msg id // return respMsg, nil } - dnsServer := c.GetServer(reqM.Question.QNAME) - rc, err := c.dialer.NextDialer(reqM.Question.QNAME+":53").Dial("tcp", dnsServer) + dnsServer := c.GetServer(req.Question.QNAME) + rc, err := c.dialer.NextDialer(req.Question.QNAME+":53").Dial("tcp", dnsServer) if err != nil { log.F("[dns] failed to connect to server %v: %v", dnsServer, err) return @@ -80,21 +77,20 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte, return } - if reqM.Question.QTYPE != QTypeA && reqM.Question.QTYPE != QTypeAAAA { + if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA { return } - respM := NewMessage() - err = UnmarshalMessage(respMsg, respM) + resp, err := UnmarshalMessage(respMsg) if err != nil { return } ips := []string{} - for _, answer := range respM.Answers { - if answer.TYPE == QTypeA { + for _, answer := range resp.Answers { + if answer.TYPE == QTypeA || answer.TYPE == QTypeAAAA { for _, h := range c.Handlers { - h(reqM.Question.QNAME, answer.IP) + h(resp.Question.QNAME, answer.IP) } if answer.IP != "" { @@ -107,7 +103,7 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte, // add to cache log.F("[dns] %s <-> %s, type: %d, %s: %s", - clientAddr, dnsServer, reqM.Question.QTYPE, reqM.Question.QNAME, strings.Join(ips, ",")) + clientAddr, dnsServer, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ",")) return } diff --git a/dns/message.go b/dns/message.go index 37299e7..2a97cd8 100644 --- a/dns/message.go +++ b/dns/message.go @@ -3,7 +3,9 @@ package dns import ( "bytes" "encoding/binary" + "encoding/hex" "errors" + "fmt" "math/rand" "net" "strings" @@ -111,18 +113,21 @@ func (m *Message) Marshal() ([]byte, error) { } // UnmarshalMessage unmarshals []bytes to Message -func UnmarshalMessage(b []byte, msg *Message) error { +func UnmarshalMessage(b []byte) (*Message, error) { + msg := NewMessage() msg.unMarshaled = b + fmt.Printf("msg.unMarshaled:\n%s\n", hex.Dump(msg.unMarshaled)) + err := UnmarshalHeader(b[:HeaderLen], msg.Header) if err != nil { - return err + return nil, err } q := &Question{} qLen, err := msg.UnmarshalQuestion(b[HeaderLen:], q) if err != nil { - return err + return nil, err } msg.SetQuestion(q) @@ -133,7 +138,7 @@ func UnmarshalMessage(b []byte, msg *Message) error { rr := &RR{} rrLen, err := msg.UnmarshalRR(rridx, rr) if err != nil { - return err + return nil, err } msg.AddAnswer(rr) @@ -142,7 +147,7 @@ func UnmarshalMessage(b []byte, msg *Message) error { msg.Header.SetAncount(len(msg.Answers)) - return nil + return msg, nil } // Header format @@ -340,13 +345,10 @@ func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) { return 0, errors.New("unmarshal question must not be nil") } - // https://tools.ietf.org/html/rfc1035#section-4.1.4 - // "Message compression", - // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - // | 1 1| OFFSET | - // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ p := m.unMarshaled[start:] + fmt.Printf("rr bytes:\n%s\n", hex.Dump(p[:10])) + domain, n := m.GetDomain(p) rr.NAME = domain @@ -368,6 +370,8 @@ func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) { n = n + 10 + int(rr.RDLENGTH) + fmt.Printf("rr: %+#v\n", rr) + return n, nil } @@ -390,9 +394,14 @@ func (m *Message) GetDomain(b []byte) (string, int) { var labels = []string{} for { + // https://tools.ietf.org/html/rfc1035#section-4.1.4 + // "Message compression", + // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ + // | 1 1| OFFSET | + // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ if b[idx]&0xC0 == 0xC0 { offset := binary.BigEndian.Uint16(b[idx : idx+2]) - lable := m.GetDomainByPoint(int(offset) & 0x3F) + lable := m.GetDomainByPoint(int(offset & 0x3F)) labels = append(labels, lable) idx += 2 break @@ -414,5 +423,6 @@ func (m *Message) GetDomain(b []byte) (string, int) { // GetDomainByPoint gets domain from func (m *Message) GetDomainByPoint(offset int) string { domain, _ := m.GetDomain(m.unMarshaled[offset:]) + fmt.Printf("GetDomainByPoint: %02x\n", offset) return domain }