dns: fixed a bug in UnmarshalRR

This commit is contained in:
nadoo 2018-07-30 00:18:10 +08:00
parent d5e3ea539a
commit 5e32133eb9
2 changed files with 12 additions and 20 deletions

View File

@ -90,7 +90,7 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte,
return return
} }
ip := "" ips := []string{}
for _, answer := range respM.Answers { for _, answer := range respM.Answers {
if answer.TYPE == QTypeA { if answer.TYPE == QTypeA {
for _, h := range c.Handlers { for _, h := range c.Handlers {
@ -98,17 +98,16 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte,
} }
if answer.IP != "" { if answer.IP != "" {
ip += answer.IP + "," ips = append(ips, answer.IP)
} }
} }
log.F("rr: %+v", answer)
} }
// add to cache // add to cache
log.F("[dns] %s <-> %s, type: %d, %s: %s", log.F("[dns] %s <-> %s, type: %d, %s: %s",
clientAddr, dnsServer, reqM.Question.QTYPE, reqM.Question.QNAME, ip) clientAddr, dnsServer, reqM.Question.QTYPE, reqM.Question.QNAME, strings.Join(ips, ","))
return return
} }

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt"
"math/rand" "math/rand"
"net" "net"
"strings" "strings"
@ -129,15 +128,16 @@ func UnmarshalMessage(b []byte, msg *Message) error {
msg.SetQuestion(q) msg.SetQuestion(q)
// resp answers // resp answers
rrLen := 0 rridx := HeaderLen + qLen
for i := 0; i < int(msg.Header.ANCOUNT); i++ { for i := 0; i < int(msg.Header.ANCOUNT); i++ {
rr := &RR{} rr := &RR{}
rrLen, err = msg.UnmarshalRR(HeaderLen+qLen+rrLen, rr) rrLen, err := msg.UnmarshalRR(rridx, rr)
if err != nil { if err != nil {
return err return err
} }
msg.AddAnswer(rr) msg.AddAnswer(rr)
rrLen += rrLen
rridx += rrLen
} }
msg.Header.SetAncount(len(msg.Answers)) msg.Header.SetAncount(len(msg.Answers))
@ -346,16 +346,9 @@ func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) {
// | 1 1| OFFSET | // | 1 1| OFFSET |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
p := m.unMarshaled[start:] p := m.unMarshaled[start:]
if p[0]>>6 == 3 {
offset := binary.BigEndian.Uint16(p[:2]) domain, n := m.GetDomain(p)
rr.NAME = m.GetDomainByPoint(int(offset) & 0x3F) rr.NAME = domain
n += 2
} else {
// TODO: none compressed query name and Additional records will be ignored
return 0, nil
}
// domain, n := m.GetDomain(p)
// rr.NAME = domain
if len(p) <= n+10 { if len(p) <= n+10 {
return 0, errors.New("not enough data") return 0, errors.New("not enough data")
@ -398,7 +391,6 @@ func (m *Message) GetDomain(b []byte) (string, int) {
for { for {
if b[idx]&0xC0 == 0xC0 { if b[idx]&0xC0 == 0xC0 {
fmt.Println("aaaaaaaaaaaaaa")
offset := binary.BigEndian.Uint16(b[idx : idx+2]) offset := binary.BigEndian.Uint16(b[idx : idx+2])
lable := m.GetDomainByPoint(int(offset) & 0x3F) lable := m.GetDomainByPoint(int(offset) & 0x3F)
labels = append(labels, lable) labels = append(labels, lable)
@ -415,7 +407,8 @@ func (m *Message) GetDomain(b []byte) (string, int) {
} }
} }
return strings.Join(labels, "."), idx domain := strings.Join(labels, ".")
return domain, idx
} }
// GetDomainByPoint gets domain from // GetDomainByPoint gets domain from