dns: changed UnmarshalMessage to return *Messaeg

This commit is contained in:
nadoo 2018-07-30 01:05:08 +08:00
parent 5e32133eb9
commit f6a578f849
2 changed files with 31 additions and 25 deletions

View File

@ -36,23 +36,20 @@ func NewClient(dialer proxy.Dialer, upServers ...string) (*Client, error) {
// Exchange handles request msg and returns response msg
// reqBytes = reqLen + reqMsg
func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte, err error) {
reqMsg := reqBytes[2:]
reqM := NewMessage()
err = UnmarshalMessage(reqMsg, reqM)
req, err := UnmarshalMessage(reqBytes[2:])
if err != nil {
return
}
if reqM.Question.QTYPE == QTypeA || reqM.Question.QTYPE == QTypeAAAA {
if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA {
// TODO: if query.QNAME in cache
// get respMsg from cache
// set msg id
// return respMsg, nil
}
dnsServer := c.GetServer(reqM.Question.QNAME)
rc, err := c.dialer.NextDialer(reqM.Question.QNAME+":53").Dial("tcp", dnsServer)
dnsServer := c.GetServer(req.Question.QNAME)
rc, err := c.dialer.NextDialer(req.Question.QNAME+":53").Dial("tcp", dnsServer)
if err != nil {
log.F("[dns] failed to connect to server %v: %v", dnsServer, err)
return
@ -80,21 +77,20 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte,
return
}
if reqM.Question.QTYPE != QTypeA && reqM.Question.QTYPE != QTypeAAAA {
if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA {
return
}
respM := NewMessage()
err = UnmarshalMessage(respMsg, respM)
resp, err := UnmarshalMessage(respMsg)
if err != nil {
return
}
ips := []string{}
for _, answer := range respM.Answers {
if answer.TYPE == QTypeA {
for _, answer := range resp.Answers {
if answer.TYPE == QTypeA || answer.TYPE == QTypeAAAA {
for _, h := range c.Handlers {
h(reqM.Question.QNAME, answer.IP)
h(resp.Question.QNAME, answer.IP)
}
if answer.IP != "" {
@ -107,7 +103,7 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte,
// add to cache
log.F("[dns] %s <-> %s, type: %d, %s: %s",
clientAddr, dnsServer, reqM.Question.QTYPE, reqM.Question.QNAME, strings.Join(ips, ","))
clientAddr, dnsServer, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
return
}

View File

@ -3,7 +3,9 @@ package dns
import (
"bytes"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"math/rand"
"net"
"strings"
@ -111,18 +113,21 @@ func (m *Message) Marshal() ([]byte, error) {
}
// UnmarshalMessage unmarshals []bytes to Message
func UnmarshalMessage(b []byte, msg *Message) error {
func UnmarshalMessage(b []byte) (*Message, error) {
msg := NewMessage()
msg.unMarshaled = b
fmt.Printf("msg.unMarshaled:\n%s\n", hex.Dump(msg.unMarshaled))
err := UnmarshalHeader(b[:HeaderLen], msg.Header)
if err != nil {
return err
return nil, err
}
q := &Question{}
qLen, err := msg.UnmarshalQuestion(b[HeaderLen:], q)
if err != nil {
return err
return nil, err
}
msg.SetQuestion(q)
@ -133,7 +138,7 @@ func UnmarshalMessage(b []byte, msg *Message) error {
rr := &RR{}
rrLen, err := msg.UnmarshalRR(rridx, rr)
if err != nil {
return err
return nil, err
}
msg.AddAnswer(rr)
@ -142,7 +147,7 @@ func UnmarshalMessage(b []byte, msg *Message) error {
msg.Header.SetAncount(len(msg.Answers))
return nil
return msg, nil
}
// Header format
@ -340,13 +345,10 @@ func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) {
return 0, errors.New("unmarshal question must not be nil")
}
// https://tools.ietf.org/html/rfc1035#section-4.1.4
// "Message compression",
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | 1 1| OFFSET |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
p := m.unMarshaled[start:]
fmt.Printf("rr bytes:\n%s\n", hex.Dump(p[:10]))
domain, n := m.GetDomain(p)
rr.NAME = domain
@ -368,6 +370,8 @@ func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) {
n = n + 10 + int(rr.RDLENGTH)
fmt.Printf("rr: %+#v\n", rr)
return n, nil
}
@ -390,9 +394,14 @@ func (m *Message) GetDomain(b []byte) (string, int) {
var labels = []string{}
for {
// https://tools.ietf.org/html/rfc1035#section-4.1.4
// "Message compression",
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | 1 1| OFFSET |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
if b[idx]&0xC0 == 0xC0 {
offset := binary.BigEndian.Uint16(b[idx : idx+2])
lable := m.GetDomainByPoint(int(offset) & 0x3F)
lable := m.GetDomainByPoint(int(offset & 0x3F))
labels = append(labels, lable)
idx += 2
break
@ -414,5 +423,6 @@ func (m *Message) GetDomain(b []byte) (string, int) {
// GetDomainByPoint gets domain from
func (m *Message) GetDomainByPoint(offset int) string {
domain, _ := m.GetDomain(m.unMarshaled[offset:])
fmt.Printf("GetDomainByPoint: %02x\n", offset)
return domain
}