mirror of
https://github.com/nadoo/glider.git
synced 2025-04-21 19:52:07 +08:00
dns: add tcp server support
This commit is contained in:
parent
41ddbb1168
commit
4781e7b472
@ -48,7 +48,6 @@ const (
|
||||
// +---------------------+
|
||||
// | Additional | RRs holding additional information
|
||||
type Message struct {
|
||||
// all dns messages should start with a 12 byte dns header
|
||||
*Header
|
||||
// most dns implementation only support 1 question
|
||||
Question *Question
|
||||
@ -99,13 +98,13 @@ func (m *Message) Marshal() ([]byte, error) {
|
||||
}
|
||||
buf.Write(b)
|
||||
|
||||
// for _, answer := range m.Answers {
|
||||
// b, err := answer.Marshal()
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// buf.Write(b)
|
||||
// }
|
||||
for _, answer := range m.Answers {
|
||||
b, err := answer.Marshal()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf.Write(b)
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
@ -191,6 +190,11 @@ func (h *Header) SetQR(qr int) {
|
||||
h.Bits |= uint16(qr) << 15
|
||||
}
|
||||
|
||||
// SetTC .
|
||||
func (h *Header) SetTC(tc int) {
|
||||
h.Bits |= uint16(tc) << 9
|
||||
}
|
||||
|
||||
// SetQdcount sets query count, most dns servers only support 1 query per request
|
||||
func (h *Header) SetQdcount(qdcount int) {
|
||||
h.QDCOUNT = uint16(qdcount)
|
||||
@ -282,7 +286,11 @@ func (m *Message) UnmarshalQuestion(b []byte, q *Question) (n int, err error) {
|
||||
return 0, errors.New("unmarshal question must not be nil")
|
||||
}
|
||||
|
||||
domain, idx := m.UnmarshalDomain(b)
|
||||
domain, idx, err := m.UnmarshalDomain(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
q.QNAME = domain
|
||||
q.QTYPE = binary.BigEndian.Uint16(b[idx : idx+2])
|
||||
q.QCLASS = binary.BigEndian.Uint16(b[idx+2 : idx+4])
|
||||
@ -335,19 +343,36 @@ func NewRR() *RR {
|
||||
return rr
|
||||
}
|
||||
|
||||
// Marshal marshals RR struct to []byte
|
||||
func (rr *RR) Marshal() ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
buf.Write(MarshalDomain(rr.NAME))
|
||||
binary.Write(&buf, binary.BigEndian, rr.TYPE)
|
||||
binary.Write(&buf, binary.BigEndian, rr.CLASS)
|
||||
binary.Write(&buf, binary.BigEndian, rr.TTL)
|
||||
binary.Write(&buf, binary.BigEndian, rr.RDLENGTH)
|
||||
buf.Write(rr.RDATA)
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// UnmarshalRR unmarshals []bytes to RR
|
||||
func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) {
|
||||
if rr == nil {
|
||||
return 0, errors.New("unmarshal question must not be nil")
|
||||
return 0, errors.New("unmarshal rr must not be nil")
|
||||
}
|
||||
|
||||
p := m.unMarshaled[start:]
|
||||
|
||||
domain, n := m.UnmarshalDomain(p)
|
||||
domain, n, err := m.UnmarshalDomain(p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
rr.NAME = domain
|
||||
|
||||
if len(p) <= n+10 {
|
||||
return 0, errors.New("not enough data")
|
||||
return 0, errors.New("UnmarshalRR: not enough data")
|
||||
}
|
||||
|
||||
rr.TYPE = binary.BigEndian.Uint16(p[n:])
|
||||
@ -381,7 +406,7 @@ func MarshalDomain(domain string) []byte {
|
||||
}
|
||||
|
||||
// UnmarshalDomain gets domain from bytes
|
||||
func (m *Message) UnmarshalDomain(b []byte) (string, int) {
|
||||
func (m *Message) UnmarshalDomain(b []byte) (string, int, error) {
|
||||
var idx, size int
|
||||
var labels = []string{}
|
||||
|
||||
@ -393,7 +418,11 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int) {
|
||||
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||
if b[idx]&0xC0 == 0xC0 {
|
||||
offset := binary.BigEndian.Uint16(b[idx : idx+2])
|
||||
lable := m.UnmarshalDomainPoint(int(offset & 0x3FFF))
|
||||
lable, err := m.UnmarshalDomainPoint(int(offset & 0x3FFF))
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
labels = append(labels, lable)
|
||||
idx += 2
|
||||
break
|
||||
@ -402,18 +431,28 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int) {
|
||||
if size == 0 {
|
||||
idx++
|
||||
break
|
||||
} else if size > 63 {
|
||||
return "", 0, errors.New("UnmarshalDomain: label length more than 63")
|
||||
}
|
||||
|
||||
if idx+size+1 > len(b) {
|
||||
return "", 0, errors.New("UnmarshalDomain: label length more than 63")
|
||||
}
|
||||
|
||||
labels = append(labels, string(b[idx+1:idx+size+1]))
|
||||
idx += (size + 1)
|
||||
}
|
||||
}
|
||||
|
||||
domain := strings.Join(labels, ".")
|
||||
return domain, idx
|
||||
return domain, idx, nil
|
||||
}
|
||||
|
||||
// UnmarshalDomainPoint gets domain from offset point
|
||||
func (m *Message) UnmarshalDomainPoint(offset int) string {
|
||||
domain, _ := m.UnmarshalDomain(m.unMarshaled[offset:])
|
||||
return domain
|
||||
func (m *Message) UnmarshalDomainPoint(offset int) (string, error) {
|
||||
if offset > len(m.unMarshaled) {
|
||||
return "", errors.New("UnmarshalDomainPoint: offset larger than msg length")
|
||||
}
|
||||
domain, _, err := m.UnmarshalDomain(m.unMarshaled[offset:])
|
||||
return domain, err
|
||||
}
|
||||
|
@ -2,12 +2,17 @@ package dns
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/nadoo/glider/common/log"
|
||||
"github.com/nadoo/glider/proxy"
|
||||
)
|
||||
|
||||
// conn timeout, seconds
|
||||
const timeout = 30
|
||||
|
||||
// Server is a dns server struct
|
||||
type Server struct {
|
||||
addr string
|
||||
@ -28,6 +33,12 @@ func NewServer(addr string, dialer proxy.Dialer, upServers ...string) (*Server,
|
||||
|
||||
// ListenAndServe .
|
||||
func (s *Server) ListenAndServe() {
|
||||
go s.ListenAndServeTCP()
|
||||
s.ListenAndServeUDP()
|
||||
}
|
||||
|
||||
// ListenAndServeUDP .
|
||||
func (s *Server) ListenAndServeUDP() {
|
||||
c, err := net.ListenPacket("udp", s.addr)
|
||||
if err != nil {
|
||||
log.F("[dns] failed to listen on %s, error: %v", s.addr, err)
|
||||
@ -70,3 +81,56 @@ func (s *Server) ListenAndServe() {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// ListenAndServeTCP .
|
||||
func (s *Server) ListenAndServeTCP() {
|
||||
l, err := net.Listen("tcp", s.addr)
|
||||
if err != nil {
|
||||
log.F("[dns]-tcp error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.F("[dns]-tcp listening TCP on %s", s.addr)
|
||||
|
||||
for {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
log.F("[dns]-tcp error: failed to accept: %v", err)
|
||||
continue
|
||||
}
|
||||
go s.ServeTCP(c)
|
||||
}
|
||||
}
|
||||
|
||||
// ServeTCP .
|
||||
func (s *Server) ServeTCP(c net.Conn) {
|
||||
defer c.Close()
|
||||
|
||||
c.SetDeadline(time.Now().Add(time.Duration(timeout) * time.Second))
|
||||
|
||||
var reqLen uint16
|
||||
if err := binary.Read(c, binary.BigEndian, &reqLen); err != nil {
|
||||
log.F("[dns]-tcp failed to get request length: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
reqBytes := make([]byte, reqLen+2)
|
||||
_, err := io.ReadFull(c, reqBytes[2:])
|
||||
if err != nil {
|
||||
log.F("[dns]-tcp error in read reqBytes %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
|
||||
|
||||
respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String())
|
||||
if err != nil {
|
||||
log.F("[dns]-tcp error in exchange: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := binary.Write(c, binary.BigEndian, respBytes); err != nil {
|
||||
log.F("[dns]-tcp error in local write respBytes: %s", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user