mirror of
https://github.com/nadoo/glider.git
synced 2025-02-23 17:35:40 +08:00
dns: changed UnmarshalMessage to return *Messaeg
This commit is contained in:
parent
5e32133eb9
commit
f6a578f849
@ -36,23 +36,20 @@ func NewClient(dialer proxy.Dialer, upServers ...string) (*Client, error) {
|
||||
// Exchange handles request msg and returns response msg
|
||||
// reqBytes = reqLen + reqMsg
|
||||
func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte, err error) {
|
||||
reqMsg := reqBytes[2:]
|
||||
|
||||
reqM := NewMessage()
|
||||
err = UnmarshalMessage(reqMsg, reqM)
|
||||
req, err := UnmarshalMessage(reqBytes[2:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if reqM.Question.QTYPE == QTypeA || reqM.Question.QTYPE == QTypeAAAA {
|
||||
if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA {
|
||||
// TODO: if query.QNAME in cache
|
||||
// get respMsg from cache
|
||||
// set msg id
|
||||
// return respMsg, nil
|
||||
}
|
||||
|
||||
dnsServer := c.GetServer(reqM.Question.QNAME)
|
||||
rc, err := c.dialer.NextDialer(reqM.Question.QNAME+":53").Dial("tcp", dnsServer)
|
||||
dnsServer := c.GetServer(req.Question.QNAME)
|
||||
rc, err := c.dialer.NextDialer(req.Question.QNAME+":53").Dial("tcp", dnsServer)
|
||||
if err != nil {
|
||||
log.F("[dns] failed to connect to server %v: %v", dnsServer, err)
|
||||
return
|
||||
@ -80,21 +77,20 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte,
|
||||
return
|
||||
}
|
||||
|
||||
if reqM.Question.QTYPE != QTypeA && reqM.Question.QTYPE != QTypeAAAA {
|
||||
if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA {
|
||||
return
|
||||
}
|
||||
|
||||
respM := NewMessage()
|
||||
err = UnmarshalMessage(respMsg, respM)
|
||||
resp, err := UnmarshalMessage(respMsg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ips := []string{}
|
||||
for _, answer := range respM.Answers {
|
||||
if answer.TYPE == QTypeA {
|
||||
for _, answer := range resp.Answers {
|
||||
if answer.TYPE == QTypeA || answer.TYPE == QTypeAAAA {
|
||||
for _, h := range c.Handlers {
|
||||
h(reqM.Question.QNAME, answer.IP)
|
||||
h(resp.Question.QNAME, answer.IP)
|
||||
}
|
||||
|
||||
if answer.IP != "" {
|
||||
@ -107,7 +103,7 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte,
|
||||
// add to cache
|
||||
|
||||
log.F("[dns] %s <-> %s, type: %d, %s: %s",
|
||||
clientAddr, dnsServer, reqM.Question.QTYPE, reqM.Question.QNAME, strings.Join(ips, ","))
|
||||
clientAddr, dnsServer, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -3,7 +3,9 @@ package dns
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strings"
|
||||
@ -111,18 +113,21 @@ func (m *Message) Marshal() ([]byte, error) {
|
||||
}
|
||||
|
||||
// UnmarshalMessage unmarshals []bytes to Message
|
||||
func UnmarshalMessage(b []byte, msg *Message) error {
|
||||
func UnmarshalMessage(b []byte) (*Message, error) {
|
||||
msg := NewMessage()
|
||||
msg.unMarshaled = b
|
||||
|
||||
fmt.Printf("msg.unMarshaled:\n%s\n", hex.Dump(msg.unMarshaled))
|
||||
|
||||
err := UnmarshalHeader(b[:HeaderLen], msg.Header)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := &Question{}
|
||||
qLen, err := msg.UnmarshalQuestion(b[HeaderLen:], q)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg.SetQuestion(q)
|
||||
@ -133,7 +138,7 @@ func UnmarshalMessage(b []byte, msg *Message) error {
|
||||
rr := &RR{}
|
||||
rrLen, err := msg.UnmarshalRR(rridx, rr)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
msg.AddAnswer(rr)
|
||||
|
||||
@ -142,7 +147,7 @@ func UnmarshalMessage(b []byte, msg *Message) error {
|
||||
|
||||
msg.Header.SetAncount(len(msg.Answers))
|
||||
|
||||
return nil
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
// Header format
|
||||
@ -340,13 +345,10 @@ func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) {
|
||||
return 0, errors.New("unmarshal question must not be nil")
|
||||
}
|
||||
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.4
|
||||
// "Message compression",
|
||||
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
// | 1 1| OFFSET |
|
||||
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
p := m.unMarshaled[start:]
|
||||
|
||||
fmt.Printf("rr bytes:\n%s\n", hex.Dump(p[:10]))
|
||||
|
||||
domain, n := m.GetDomain(p)
|
||||
rr.NAME = domain
|
||||
|
||||
@ -368,6 +370,8 @@ func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) {
|
||||
|
||||
n = n + 10 + int(rr.RDLENGTH)
|
||||
|
||||
fmt.Printf("rr: %+#v\n", rr)
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
@ -390,9 +394,14 @@ func (m *Message) GetDomain(b []byte) (string, int) {
|
||||
var labels = []string{}
|
||||
|
||||
for {
|
||||
// https://tools.ietf.org/html/rfc1035#section-4.1.4
|
||||
// "Message compression",
|
||||
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
// | 1 1| OFFSET |
|
||||
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
if b[idx]&0xC0 == 0xC0 {
|
||||
offset := binary.BigEndian.Uint16(b[idx : idx+2])
|
||||
lable := m.GetDomainByPoint(int(offset) & 0x3F)
|
||||
lable := m.GetDomainByPoint(int(offset & 0x3F))
|
||||
labels = append(labels, lable)
|
||||
idx += 2
|
||||
break
|
||||
@ -414,5 +423,6 @@ func (m *Message) GetDomain(b []byte) (string, int) {
|
||||
// GetDomainByPoint gets domain from
|
||||
func (m *Message) GetDomainByPoint(offset int) string {
|
||||
domain, _ := m.GetDomain(m.unMarshaled[offset:])
|
||||
fmt.Printf("GetDomainByPoint: %02x\n", offset)
|
||||
return domain
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user