diff --git a/dns/client.go b/dns/client.go index 1f7d201..47b32fd 100644 --- a/dns/client.go +++ b/dns/client.go @@ -90,7 +90,7 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte, return } - ip := "" + ips := []string{} for _, answer := range respM.Answers { if answer.TYPE == QTypeA { for _, h := range c.Handlers { @@ -98,17 +98,16 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte, } if answer.IP != "" { - ip += answer.IP + "," + ips = append(ips, answer.IP) } } - log.F("rr: %+v", answer) } // add to cache log.F("[dns] %s <-> %s, type: %d, %s: %s", - clientAddr, dnsServer, reqM.Question.QTYPE, reqM.Question.QNAME, ip) + clientAddr, dnsServer, reqM.Question.QTYPE, reqM.Question.QNAME, strings.Join(ips, ",")) return } diff --git a/dns/message.go b/dns/message.go index 82566bc..37299e7 100644 --- a/dns/message.go +++ b/dns/message.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/binary" "errors" - "fmt" "math/rand" "net" "strings" @@ -129,15 +128,16 @@ func UnmarshalMessage(b []byte, msg *Message) error { msg.SetQuestion(q) // resp answers - rrLen := 0 + rridx := HeaderLen + qLen for i := 0; i < int(msg.Header.ANCOUNT); i++ { rr := &RR{} - rrLen, err = msg.UnmarshalRR(HeaderLen+qLen+rrLen, rr) + rrLen, err := msg.UnmarshalRR(rridx, rr) if err != nil { return err } msg.AddAnswer(rr) - rrLen += rrLen + + rridx += rrLen } msg.Header.SetAncount(len(msg.Answers)) @@ -346,16 +346,9 @@ func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) { // | 1 1| OFFSET | // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ p := m.unMarshaled[start:] - if p[0]>>6 == 3 { - offset := binary.BigEndian.Uint16(p[:2]) - rr.NAME = m.GetDomainByPoint(int(offset) & 0x3F) - n += 2 - } else { - // TODO: none compressed query name and Additional records will be ignored - return 0, nil - } - // domain, n := m.GetDomain(p) - // rr.NAME = domain + + domain, n := m.GetDomain(p) + rr.NAME = domain if len(p) <= n+10 { return 0, errors.New("not enough data") @@ -398,7 +391,6 @@ func (m *Message) GetDomain(b []byte) (string, int) { for { if b[idx]&0xC0 == 0xC0 { - fmt.Println("aaaaaaaaaaaaaa") offset := binary.BigEndian.Uint16(b[idx : idx+2]) lable := m.GetDomainByPoint(int(offset) & 0x3F) labels = append(labels, lable) @@ -415,7 +407,8 @@ func (m *Message) GetDomain(b []byte) (string, int) { } } - return strings.Join(labels, "."), idx + domain := strings.Join(labels, ".") + return domain, idx } // GetDomainByPoint gets domain from