mirror of
https://github.com/nadoo/glider.git
synced 2025-04-21 19:52:07 +08:00
dns: add tcp server support
This commit is contained in:
parent
41ddbb1168
commit
4781e7b472
@ -48,7 +48,6 @@ const (
|
|||||||
// +---------------------+
|
// +---------------------+
|
||||||
// | Additional | RRs holding additional information
|
// | Additional | RRs holding additional information
|
||||||
type Message struct {
|
type Message struct {
|
||||||
// all dns messages should start with a 12 byte dns header
|
|
||||||
*Header
|
*Header
|
||||||
// most dns implementation only support 1 question
|
// most dns implementation only support 1 question
|
||||||
Question *Question
|
Question *Question
|
||||||
@ -99,13 +98,13 @@ func (m *Message) Marshal() ([]byte, error) {
|
|||||||
}
|
}
|
||||||
buf.Write(b)
|
buf.Write(b)
|
||||||
|
|
||||||
// for _, answer := range m.Answers {
|
for _, answer := range m.Answers {
|
||||||
// b, err := answer.Marshal()
|
b, err := answer.Marshal()
|
||||||
// if err != nil {
|
if err != nil {
|
||||||
// return nil, err
|
return nil, err
|
||||||
// }
|
}
|
||||||
// buf.Write(b)
|
buf.Write(b)
|
||||||
// }
|
}
|
||||||
|
|
||||||
return buf.Bytes(), nil
|
return buf.Bytes(), nil
|
||||||
}
|
}
|
||||||
@ -191,6 +190,11 @@ func (h *Header) SetQR(qr int) {
|
|||||||
h.Bits |= uint16(qr) << 15
|
h.Bits |= uint16(qr) << 15
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetTC .
|
||||||
|
func (h *Header) SetTC(tc int) {
|
||||||
|
h.Bits |= uint16(tc) << 9
|
||||||
|
}
|
||||||
|
|
||||||
// SetQdcount sets query count, most dns servers only support 1 query per request
|
// SetQdcount sets query count, most dns servers only support 1 query per request
|
||||||
func (h *Header) SetQdcount(qdcount int) {
|
func (h *Header) SetQdcount(qdcount int) {
|
||||||
h.QDCOUNT = uint16(qdcount)
|
h.QDCOUNT = uint16(qdcount)
|
||||||
@ -282,7 +286,11 @@ func (m *Message) UnmarshalQuestion(b []byte, q *Question) (n int, err error) {
|
|||||||
return 0, errors.New("unmarshal question must not be nil")
|
return 0, errors.New("unmarshal question must not be nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
domain, idx := m.UnmarshalDomain(b)
|
domain, idx, err := m.UnmarshalDomain(b)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
q.QNAME = domain
|
q.QNAME = domain
|
||||||
q.QTYPE = binary.BigEndian.Uint16(b[idx : idx+2])
|
q.QTYPE = binary.BigEndian.Uint16(b[idx : idx+2])
|
||||||
q.QCLASS = binary.BigEndian.Uint16(b[idx+2 : idx+4])
|
q.QCLASS = binary.BigEndian.Uint16(b[idx+2 : idx+4])
|
||||||
@ -335,19 +343,36 @@ func NewRR() *RR {
|
|||||||
return rr
|
return rr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Marshal marshals RR struct to []byte
|
||||||
|
func (rr *RR) Marshal() ([]byte, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
|
||||||
|
buf.Write(MarshalDomain(rr.NAME))
|
||||||
|
binary.Write(&buf, binary.BigEndian, rr.TYPE)
|
||||||
|
binary.Write(&buf, binary.BigEndian, rr.CLASS)
|
||||||
|
binary.Write(&buf, binary.BigEndian, rr.TTL)
|
||||||
|
binary.Write(&buf, binary.BigEndian, rr.RDLENGTH)
|
||||||
|
buf.Write(rr.RDATA)
|
||||||
|
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
// UnmarshalRR unmarshals []bytes to RR
|
// UnmarshalRR unmarshals []bytes to RR
|
||||||
func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) {
|
func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) {
|
||||||
if rr == nil {
|
if rr == nil {
|
||||||
return 0, errors.New("unmarshal question must not be nil")
|
return 0, errors.New("unmarshal rr must not be nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
p := m.unMarshaled[start:]
|
p := m.unMarshaled[start:]
|
||||||
|
|
||||||
domain, n := m.UnmarshalDomain(p)
|
domain, n, err := m.UnmarshalDomain(p)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
rr.NAME = domain
|
rr.NAME = domain
|
||||||
|
|
||||||
if len(p) <= n+10 {
|
if len(p) <= n+10 {
|
||||||
return 0, errors.New("not enough data")
|
return 0, errors.New("UnmarshalRR: not enough data")
|
||||||
}
|
}
|
||||||
|
|
||||||
rr.TYPE = binary.BigEndian.Uint16(p[n:])
|
rr.TYPE = binary.BigEndian.Uint16(p[n:])
|
||||||
@ -381,7 +406,7 @@ func MarshalDomain(domain string) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalDomain gets domain from bytes
|
// UnmarshalDomain gets domain from bytes
|
||||||
func (m *Message) UnmarshalDomain(b []byte) (string, int) {
|
func (m *Message) UnmarshalDomain(b []byte) (string, int, error) {
|
||||||
var idx, size int
|
var idx, size int
|
||||||
var labels = []string{}
|
var labels = []string{}
|
||||||
|
|
||||||
@ -393,7 +418,11 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int) {
|
|||||||
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
|
||||||
if b[idx]&0xC0 == 0xC0 {
|
if b[idx]&0xC0 == 0xC0 {
|
||||||
offset := binary.BigEndian.Uint16(b[idx : idx+2])
|
offset := binary.BigEndian.Uint16(b[idx : idx+2])
|
||||||
lable := m.UnmarshalDomainPoint(int(offset & 0x3FFF))
|
lable, err := m.UnmarshalDomainPoint(int(offset & 0x3FFF))
|
||||||
|
if err != nil {
|
||||||
|
return "", 0, err
|
||||||
|
}
|
||||||
|
|
||||||
labels = append(labels, lable)
|
labels = append(labels, lable)
|
||||||
idx += 2
|
idx += 2
|
||||||
break
|
break
|
||||||
@ -402,18 +431,28 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int) {
|
|||||||
if size == 0 {
|
if size == 0 {
|
||||||
idx++
|
idx++
|
||||||
break
|
break
|
||||||
|
} else if size > 63 {
|
||||||
|
return "", 0, errors.New("UnmarshalDomain: label length more than 63")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if idx+size+1 > len(b) {
|
||||||
|
return "", 0, errors.New("UnmarshalDomain: label length more than 63")
|
||||||
|
}
|
||||||
|
|
||||||
labels = append(labels, string(b[idx+1:idx+size+1]))
|
labels = append(labels, string(b[idx+1:idx+size+1]))
|
||||||
idx += (size + 1)
|
idx += (size + 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
domain := strings.Join(labels, ".")
|
domain := strings.Join(labels, ".")
|
||||||
return domain, idx
|
return domain, idx, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalDomainPoint gets domain from offset point
|
// UnmarshalDomainPoint gets domain from offset point
|
||||||
func (m *Message) UnmarshalDomainPoint(offset int) string {
|
func (m *Message) UnmarshalDomainPoint(offset int) (string, error) {
|
||||||
domain, _ := m.UnmarshalDomain(m.unMarshaled[offset:])
|
if offset > len(m.unMarshaled) {
|
||||||
return domain
|
return "", errors.New("UnmarshalDomainPoint: offset larger than msg length")
|
||||||
|
}
|
||||||
|
domain, _, err := m.UnmarshalDomain(m.unMarshaled[offset:])
|
||||||
|
return domain, err
|
||||||
}
|
}
|
||||||
|
@ -2,12 +2,17 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/nadoo/glider/common/log"
|
"github.com/nadoo/glider/common/log"
|
||||||
"github.com/nadoo/glider/proxy"
|
"github.com/nadoo/glider/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// conn timeout, seconds
|
||||||
|
const timeout = 30
|
||||||
|
|
||||||
// Server is a dns server struct
|
// Server is a dns server struct
|
||||||
type Server struct {
|
type Server struct {
|
||||||
addr string
|
addr string
|
||||||
@ -28,6 +33,12 @@ func NewServer(addr string, dialer proxy.Dialer, upServers ...string) (*Server,
|
|||||||
|
|
||||||
// ListenAndServe .
|
// ListenAndServe .
|
||||||
func (s *Server) ListenAndServe() {
|
func (s *Server) ListenAndServe() {
|
||||||
|
go s.ListenAndServeTCP()
|
||||||
|
s.ListenAndServeUDP()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenAndServeUDP .
|
||||||
|
func (s *Server) ListenAndServeUDP() {
|
||||||
c, err := net.ListenPacket("udp", s.addr)
|
c, err := net.ListenPacket("udp", s.addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.F("[dns] failed to listen on %s, error: %v", s.addr, err)
|
log.F("[dns] failed to listen on %s, error: %v", s.addr, err)
|
||||||
@ -70,3 +81,56 @@ func (s *Server) ListenAndServe() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListenAndServeTCP .
|
||||||
|
func (s *Server) ListenAndServeTCP() {
|
||||||
|
l, err := net.Listen("tcp", s.addr)
|
||||||
|
if err != nil {
|
||||||
|
log.F("[dns]-tcp error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.F("[dns]-tcp listening TCP on %s", s.addr)
|
||||||
|
|
||||||
|
for {
|
||||||
|
c, err := l.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.F("[dns]-tcp error: failed to accept: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
go s.ServeTCP(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeTCP .
|
||||||
|
func (s *Server) ServeTCP(c net.Conn) {
|
||||||
|
defer c.Close()
|
||||||
|
|
||||||
|
c.SetDeadline(time.Now().Add(time.Duration(timeout) * time.Second))
|
||||||
|
|
||||||
|
var reqLen uint16
|
||||||
|
if err := binary.Read(c, binary.BigEndian, &reqLen); err != nil {
|
||||||
|
log.F("[dns]-tcp failed to get request length: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBytes := make([]byte, reqLen+2)
|
||||||
|
_, err := io.ReadFull(c, reqBytes[2:])
|
||||||
|
if err != nil {
|
||||||
|
log.F("[dns]-tcp error in read reqBytes %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
binary.BigEndian.PutUint16(reqBytes[:2], reqLen)
|
||||||
|
|
||||||
|
respBytes, err := s.Exchange(reqBytes, c.RemoteAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
log.F("[dns]-tcp error in exchange: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := binary.Write(c, binary.BigEndian, respBytes); err != nil {
|
||||||
|
log.F("[dns]-tcp error in local write respBytes: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user