mirror of
				https://github.com/nadoo/glider.git
				synced 2025-11-04 07:42:38 +08:00 
			
		
		
		
	dns: 1. support cache; 2. support custom records; #35
This commit is contained in:
		
							parent
							
								
									a2a67df771
								
							
						
					
					
						commit
						a226637bfb
					
				
							
								
								
									
										2
									
								
								conf.go
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								conf.go
									
									
									
									
									
								
							@ -25,6 +25,7 @@ var conf struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	DNS       string
 | 
						DNS       string
 | 
				
			||||||
	DNSServer []string
 | 
						DNSServer []string
 | 
				
			||||||
 | 
						DNSRecord []string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	IPSet string
 | 
						IPSet string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -43,6 +44,7 @@ func confInit() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	flag.StringVar(&conf.DNS, "dns", "", "dns forwarder server listen address")
 | 
						flag.StringVar(&conf.DNS, "dns", "", "dns forwarder server listen address")
 | 
				
			||||||
	flag.StringSliceUniqVar(&conf.DNSServer, "dnsserver", []string{"8.8.8.8:53"}, "remote dns server")
 | 
						flag.StringSliceUniqVar(&conf.DNSServer, "dnsserver", []string{"8.8.8.8:53"}, "remote dns server")
 | 
				
			||||||
 | 
						flag.StringSliceUniqVar(&conf.DNSRecord, "dnsrecord", nil, "custom dns record")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	flag.StringVar(&conf.IPSet, "ipset", "", "ipset name")
 | 
						flag.StringVar(&conf.IPSet, "ipset", "", "ipset name")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										69
									
								
								dns/cache.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								dns/cache.go
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,69 @@
 | 
				
			|||||||
 | 
					package dns
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import (
 | 
				
			||||||
 | 
						"sync"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// HundredYears is one hundred years duration in seconds, used for none-expired items
 | 
				
			||||||
 | 
					const HundredYears = 100 * 365 * 24 * 3600
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type item struct {
 | 
				
			||||||
 | 
						value  []byte
 | 
				
			||||||
 | 
						expire time.Time
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Cache is the struct of cache
 | 
				
			||||||
 | 
					type Cache struct {
 | 
				
			||||||
 | 
						m map[string]*item
 | 
				
			||||||
 | 
						l sync.RWMutex
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// NewCache returns a new cache
 | 
				
			||||||
 | 
					func NewCache() (c *Cache) {
 | 
				
			||||||
 | 
						c = &Cache{m: make(map[string]*item)}
 | 
				
			||||||
 | 
						go func() {
 | 
				
			||||||
 | 
							for now := range time.Tick(time.Second) {
 | 
				
			||||||
 | 
								c.l.Lock()
 | 
				
			||||||
 | 
								for k, v := range c.m {
 | 
				
			||||||
 | 
									if now.After(v.expire) {
 | 
				
			||||||
 | 
										delete(c.m, k)
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								c.l.Unlock()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Len returns the length of cache
 | 
				
			||||||
 | 
					func (c *Cache) Len() int {
 | 
				
			||||||
 | 
						return len(c.m)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Put an item into cache, invalid after ttl seconds, never invalid if ttl=0
 | 
				
			||||||
 | 
					func (c *Cache) Put(k string, v []byte, ttl int) {
 | 
				
			||||||
 | 
						if len(v) == 0 {
 | 
				
			||||||
 | 
							return
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						c.l.Lock()
 | 
				
			||||||
 | 
						it, ok := c.m[k]
 | 
				
			||||||
 | 
						if !ok {
 | 
				
			||||||
 | 
							it = &item{value: v}
 | 
				
			||||||
 | 
							c.m[k] = it
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						it.expire = time.Now().Add(time.Duration(ttl) * time.Second)
 | 
				
			||||||
 | 
						c.l.Unlock()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Get an item from cache
 | 
				
			||||||
 | 
					func (c *Cache) Get(k string) (v []byte) {
 | 
				
			||||||
 | 
						c.l.RLock()
 | 
				
			||||||
 | 
						if it, ok := c.m[k]; ok {
 | 
				
			||||||
 | 
							v = it.value
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						c.l.RUnlock()
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										139
									
								
								dns/client.go
									
									
									
									
									
								
							
							
						
						
									
										139
									
								
								dns/client.go
									
									
									
									
									
								
							@ -1,33 +1,39 @@
 | 
				
			|||||||
package dns
 | 
					package dns
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"bytes"
 | 
				
			||||||
	"encoding/binary"
 | 
						"encoding/binary"
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
 | 
						"net"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/nadoo/glider/common/log"
 | 
						"github.com/nadoo/glider/common/log"
 | 
				
			||||||
	"github.com/nadoo/glider/proxy"
 | 
						"github.com/nadoo/glider/proxy"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// DefaultTTL is default ttl in seconds
 | 
				
			||||||
 | 
					const DefaultTTL = 600
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// HandleFunc function handles the dns TypeA or TypeAAAA answer
 | 
					// HandleFunc function handles the dns TypeA or TypeAAAA answer
 | 
				
			||||||
type HandleFunc func(Domain, ip string) error
 | 
					type HandleFunc func(Domain, ip string) error
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Client is a dns client struct
 | 
					// Client is a dns client struct
 | 
				
			||||||
type Client struct {
 | 
					type Client struct {
 | 
				
			||||||
	dialer      proxy.Dialer
 | 
						dialer      proxy.Dialer
 | 
				
			||||||
	UPServers   []string
 | 
						cache       *Cache
 | 
				
			||||||
	UPServerMap map[string][]string
 | 
						upServers   []string
 | 
				
			||||||
	Handlers    []HandleFunc
 | 
						upServerMap map[string][]string
 | 
				
			||||||
 | 
						handlers    []HandleFunc
 | 
				
			||||||
	tcp bool
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewClient returns a new dns client
 | 
					// NewClient returns a new dns client
 | 
				
			||||||
func NewClient(dialer proxy.Dialer, upServers ...string) (*Client, error) {
 | 
					func NewClient(dialer proxy.Dialer, upServers []string) (*Client, error) {
 | 
				
			||||||
	c := &Client{
 | 
						c := &Client{
 | 
				
			||||||
		dialer:      dialer,
 | 
							dialer:      dialer,
 | 
				
			||||||
		UPServers:   upServers,
 | 
							cache:       NewCache(),
 | 
				
			||||||
		UPServerMap: make(map[string][]string),
 | 
							upServers:   upServers,
 | 
				
			||||||
 | 
							upServerMap: make(map[string][]string),
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return c, nil
 | 
						return c, nil
 | 
				
			||||||
@ -35,82 +41,96 @@ func NewClient(dialer proxy.Dialer, upServers ...string) (*Client, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// Exchange handles request msg and returns response msg
 | 
					// Exchange handles request msg and returns response msg
 | 
				
			||||||
// reqBytes = reqLen + reqMsg
 | 
					// reqBytes = reqLen + reqMsg
 | 
				
			||||||
func (c *Client) Exchange(reqBytes []byte, clientAddr string) (respBytes []byte, err error) {
 | 
					func (c *Client) Exchange(reqBytes []byte, clientAddr string) ([]byte, error) {
 | 
				
			||||||
	req, err := UnmarshalMessage(reqBytes[2:])
 | 
						req, err := UnmarshalMessage(reqBytes[2:])
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA {
 | 
						if req.Question.QTYPE == QTypeA || req.Question.QTYPE == QTypeAAAA {
 | 
				
			||||||
		// TODO: if query.QNAME in cache
 | 
							v := c.cache.Get(getKey(req.Question))
 | 
				
			||||||
		// get respMsg from cache
 | 
							if v != nil {
 | 
				
			||||||
		// set msg id
 | 
								binary.BigEndian.PutUint16(v[2:4], req.ID)
 | 
				
			||||||
		// return respMsg, nil
 | 
								log.F("[dns] %s <-> cache, type: %d, %s",
 | 
				
			||||||
 | 
									clientAddr, req.Question.QTYPE, req.Question.QNAME)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								return v, nil
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	dnsServer := c.GetServer(req.Question.QNAME)
 | 
						dnsServer := c.GetServer(req.Question.QNAME)
 | 
				
			||||||
	rc, err := c.dialer.NextDialer(req.Question.QNAME+":53").Dial("tcp", dnsServer)
 | 
						rc, err := c.dialer.NextDialer(req.Question.QNAME+":53").Dial("tcp", dnsServer)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.F("[dns] failed to connect to server %v: %v", dnsServer, err)
 | 
							log.F("[dns] failed to connect to server %v: %v", dnsServer, err)
 | 
				
			||||||
		return
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	defer rc.Close()
 | 
						defer rc.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if err = binary.Write(rc, binary.BigEndian, reqBytes); err != nil {
 | 
						if err = binary.Write(rc, binary.BigEndian, reqBytes); err != nil {
 | 
				
			||||||
		log.F("[dns] failed to write req message: %v", err)
 | 
							log.F("[dns] failed to write req message: %v", err)
 | 
				
			||||||
		return
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var respLen uint16
 | 
						var respLen uint16
 | 
				
			||||||
	if err = binary.Read(rc, binary.BigEndian, &respLen); err != nil {
 | 
						if err = binary.Read(rc, binary.BigEndian, &respLen); err != nil {
 | 
				
			||||||
		log.F("[dns] failed to read response length: %v", err)
 | 
							log.F("[dns] failed to read response length: %v", err)
 | 
				
			||||||
		return
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	respBytes = make([]byte, respLen+2)
 | 
						respBytes := make([]byte, respLen+2)
 | 
				
			||||||
	binary.BigEndian.PutUint16(respBytes[:2], respLen)
 | 
						binary.BigEndian.PutUint16(respBytes[:2], respLen)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	respMsg := respBytes[2:]
 | 
						_, err = io.ReadFull(rc, respBytes[2:])
 | 
				
			||||||
	_, err = io.ReadFull(rc, respMsg)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		log.F("[dns] error in read respMsg %s\n", err)
 | 
							log.F("[dns] error in read respMsg %s\n", err)
 | 
				
			||||||
		return
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA {
 | 
						if req.Question.QTYPE != QTypeA && req.Question.QTYPE != QTypeAAAA {
 | 
				
			||||||
		return
 | 
							log.F("[dns] %s <-> %s, type: %d, %s",
 | 
				
			||||||
 | 
								clientAddr, dnsServer, req.Question.QTYPE, req.Question.QNAME)
 | 
				
			||||||
 | 
							return respBytes, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	resp, err := UnmarshalMessage(respMsg)
 | 
						resp, err := UnmarshalMessage(respBytes[2:])
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return
 | 
							return respBytes, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						ttl := 0
 | 
				
			||||||
	ips := []string{}
 | 
						ips := []string{}
 | 
				
			||||||
	for _, answer := range resp.Answers {
 | 
						for _, answer := range resp.Answers {
 | 
				
			||||||
		if answer.TYPE == QTypeA || answer.TYPE == QTypeAAAA {
 | 
							if answer.TYPE == QTypeA || answer.TYPE == QTypeAAAA {
 | 
				
			||||||
			for _, h := range c.Handlers {
 | 
								for _, h := range c.handlers {
 | 
				
			||||||
				h(resp.Question.QNAME, answer.IP)
 | 
									h(resp.Question.QNAME, answer.IP)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if answer.IP != "" {
 | 
								if answer.IP != "" {
 | 
				
			||||||
				ips = append(ips, answer.IP)
 | 
									ips = append(ips, answer.IP)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								ttl = int(answer.TTL)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// if ttl in packet is 0, set it to default value
 | 
				
			||||||
 | 
						if ttl == 0 {
 | 
				
			||||||
 | 
							ttl = DefaultTTL
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// add to cache
 | 
						// add to cache
 | 
				
			||||||
 | 
						c.cache.Put(getKey(resp.Question), respBytes, ttl)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	log.F("[dns] %s <-> %s, type: %d, %s: %s",
 | 
						log.F("[dns] %s <-> %s, type: %d, %s: %s",
 | 
				
			||||||
		clientAddr, dnsServer, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
 | 
							clientAddr, dnsServer, resp.Question.QTYPE, resp.Question.QNAME, strings.Join(ips, ","))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return
 | 
						return respBytes, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// SetServer .
 | 
					// SetServer .
 | 
				
			||||||
func (c *Client) SetServer(domain string, servers ...string) {
 | 
					func (c *Client) SetServer(domain string, servers ...string) {
 | 
				
			||||||
	c.UPServerMap[domain] = append(c.UPServerMap[domain], servers...)
 | 
						c.upServerMap[domain] = append(c.upServerMap[domain], servers...)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// GetServer .
 | 
					// GetServer .
 | 
				
			||||||
@ -120,16 +140,75 @@ func (c *Client) GetServer(domain string) string {
 | 
				
			|||||||
	for i := length - 2; i >= 0; i-- {
 | 
						for i := length - 2; i >= 0; i-- {
 | 
				
			||||||
		domain := strings.Join(domainParts[i:length], ".")
 | 
							domain := strings.Join(domainParts[i:length], ".")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if servers, ok := c.UPServerMap[domain]; ok {
 | 
							if servers, ok := c.upServerMap[domain]; ok {
 | 
				
			||||||
			return servers[0]
 | 
								return servers[0]
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// TODO:
 | 
						// TODO:
 | 
				
			||||||
	return c.UPServers[0]
 | 
						return c.upServers[0]
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// AddHandler .
 | 
					// AddHandler .
 | 
				
			||||||
func (c *Client) AddHandler(h HandleFunc) {
 | 
					func (c *Client) AddHandler(h HandleFunc) {
 | 
				
			||||||
	c.Handlers = append(c.Handlers, h)
 | 
						c.handlers = append(c.handlers, h)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// AddRecord adds custom record to dns cache, format:
 | 
				
			||||||
 | 
					// www.example.com/1.2.3.4 or www.example.com/2606:2800:220:1:248:1893:25c8:1946
 | 
				
			||||||
 | 
					func (c *Client) AddRecord(record string) error {
 | 
				
			||||||
 | 
						r := strings.Split(record, "/")
 | 
				
			||||||
 | 
						domain, ip := r[0], r[1]
 | 
				
			||||||
 | 
						m, err := c.GenResponse(domain, ip)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						b, _ := m.Marshal()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var buf bytes.Buffer
 | 
				
			||||||
 | 
						binary.Write(&buf, binary.BigEndian, uint16(len(b)))
 | 
				
			||||||
 | 
						buf.Write(b)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						c.cache.Put(getKey(m.Question), buf.Bytes(), HundredYears)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// GenResponse .
 | 
				
			||||||
 | 
					func (c *Client) GenResponse(domain string, ip string) (*Message, error) {
 | 
				
			||||||
 | 
						ipb := net.ParseIP(ip)
 | 
				
			||||||
 | 
						if ipb == nil {
 | 
				
			||||||
 | 
							return nil, errors.New("GenResponse: invalid ip format")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var rdata []byte
 | 
				
			||||||
 | 
						var qtype, rdlen uint16
 | 
				
			||||||
 | 
						if rdata = ipb.To4(); rdata != nil {
 | 
				
			||||||
 | 
							qtype = QTypeA
 | 
				
			||||||
 | 
							rdlen = net.IPv4len
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							qtype = QTypeAAAA
 | 
				
			||||||
 | 
							rdlen = net.IPv6len
 | 
				
			||||||
 | 
							rdata = ipb
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						m := NewMessage(0, Response)
 | 
				
			||||||
 | 
						m.SetQuestion(NewQuestion(qtype, domain))
 | 
				
			||||||
 | 
						rr := &RR{NAME: domain, TYPE: qtype, CLASS: CLASSIN,
 | 
				
			||||||
 | 
							RDLENGTH: rdlen, RDATA: rdata}
 | 
				
			||||||
 | 
						m.AddAnswer(rr)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return m, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func getKey(q *Question) string {
 | 
				
			||||||
 | 
						qtype := ""
 | 
				
			||||||
 | 
						switch q.QTYPE {
 | 
				
			||||||
 | 
						case QTypeA:
 | 
				
			||||||
 | 
							qtype = "A"
 | 
				
			||||||
 | 
						case QTypeAAAA:
 | 
				
			||||||
 | 
							qtype = "AAAA"
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return q.QNAME + "/" + qtype
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
@ -19,10 +19,10 @@ const UDPMaxLen = 512
 | 
				
			|||||||
// HeaderLen is the length of dns msg header
 | 
					// HeaderLen is the length of dns msg header
 | 
				
			||||||
const HeaderLen = 12
 | 
					const HeaderLen = 12
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// QR types
 | 
					// Message types
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	QRQuery    = 0
 | 
						Query    = 0
 | 
				
			||||||
	QRResponse = 1
 | 
						Response = 1
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// QType .
 | 
					// QType .
 | 
				
			||||||
@ -31,6 +31,9 @@ const (
 | 
				
			|||||||
	QTypeAAAA uint16 = 28 ///ipv6
 | 
						QTypeAAAA uint16 = 28 ///ipv6
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// CLASSIN .
 | 
				
			||||||
 | 
					const CLASSIN uint16 = 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Message format
 | 
					// Message format
 | 
				
			||||||
// https://tools.ietf.org/html/rfc1035#section-4.1
 | 
					// https://tools.ietf.org/html/rfc1035#section-4.1
 | 
				
			||||||
// All communications inside of the domain protocol are carried in a single
 | 
					// All communications inside of the domain protocol are carried in a single
 | 
				
			||||||
@ -60,10 +63,15 @@ type Message struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewMessage returns a new message
 | 
					// NewMessage returns a new message
 | 
				
			||||||
func NewMessage() *Message {
 | 
					func NewMessage(id uint16, msgType int) *Message {
 | 
				
			||||||
	return &Message{
 | 
						if id == 0 {
 | 
				
			||||||
		Header: &Header{},
 | 
							id = uint16(rand.Uint32())
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						h := &Header{ID: id}
 | 
				
			||||||
 | 
						h.SetMsgType(msgType)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return &Message{Header: h}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// SetQuestion sets a question to dns message,
 | 
					// SetQuestion sets a question to dns message,
 | 
				
			||||||
@ -111,7 +119,7 @@ func (m *Message) Marshal() ([]byte, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// UnmarshalMessage unmarshals []bytes to Message
 | 
					// UnmarshalMessage unmarshals []bytes to Message
 | 
				
			||||||
func UnmarshalMessage(b []byte) (*Message, error) {
 | 
					func UnmarshalMessage(b []byte) (*Message, error) {
 | 
				
			||||||
	msg := NewMessage()
 | 
						msg := &Message{Header: &Header{}}
 | 
				
			||||||
	msg.unMarshaled = b
 | 
						msg.unMarshaled = b
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	err := UnmarshalHeader(b[:HeaderLen], msg.Header)
 | 
						err := UnmarshalHeader(b[:HeaderLen], msg.Header)
 | 
				
			||||||
@ -174,19 +182,8 @@ type Header struct {
 | 
				
			|||||||
	ARCOUNT uint16
 | 
						ARCOUNT uint16
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewHeader returns a new dns header
 | 
					// SetMsgType .
 | 
				
			||||||
func NewHeader(id uint16, qr int) *Header {
 | 
					func (h *Header) SetMsgType(qr int) {
 | 
				
			||||||
	if id == 0 {
 | 
					 | 
				
			||||||
		id = uint16(rand.Uint32())
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	h := &Header{ID: id}
 | 
					 | 
				
			||||||
	h.SetQR(qr)
 | 
					 | 
				
			||||||
	return h
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// SetQR .
 | 
					 | 
				
			||||||
func (h *Header) SetQR(qr int) {
 | 
					 | 
				
			||||||
	h.Bits |= uint16(qr) << 15
 | 
						h.Bits |= uint16(qr) << 15
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -265,7 +262,7 @@ func NewQuestion(qtype uint16, domain string) *Question {
 | 
				
			|||||||
	return &Question{
 | 
						return &Question{
 | 
				
			||||||
		QNAME:  domain,
 | 
							QNAME:  domain,
 | 
				
			||||||
		QTYPE:  qtype,
 | 
							QTYPE:  qtype,
 | 
				
			||||||
		QCLASS: 1,
 | 
							QCLASS: CLASSIN,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -418,12 +415,12 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int, error) {
 | 
				
			|||||||
		// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 | 
							// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
 | 
				
			||||||
		if b[idx]&0xC0 == 0xC0 {
 | 
							if b[idx]&0xC0 == 0xC0 {
 | 
				
			||||||
			offset := binary.BigEndian.Uint16(b[idx : idx+2])
 | 
								offset := binary.BigEndian.Uint16(b[idx : idx+2])
 | 
				
			||||||
			lable, err := m.UnmarshalDomainPoint(int(offset & 0x3FFF))
 | 
								label, err := m.UnmarshalDomainPoint(int(offset & 0x3FFF))
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				return "", 0, err
 | 
									return "", 0, err
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			labels = append(labels, lable)
 | 
								labels = append(labels, label)
 | 
				
			||||||
			idx += 2
 | 
								idx += 2
 | 
				
			||||||
			break
 | 
								break
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
@ -432,11 +429,11 @@ func (m *Message) UnmarshalDomain(b []byte) (string, int, error) {
 | 
				
			|||||||
				idx++
 | 
									idx++
 | 
				
			||||||
				break
 | 
									break
 | 
				
			||||||
			} else if size > 63 {
 | 
								} else if size > 63 {
 | 
				
			||||||
				return "", 0, errors.New("UnmarshalDomain: label length more than 63")
 | 
									return "", 0, errors.New("UnmarshalDomain: label size larger than 63")
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if idx+size+1 > len(b) {
 | 
								if idx+size+1 > len(b) {
 | 
				
			||||||
				return "", 0, errors.New("UnmarshalDomain: label length more than 63")
 | 
									return "", 0, errors.New("UnmarshalDomain: label size larger than msg length")
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			labels = append(labels, string(b[idx+1:idx+size+1]))
 | 
								labels = append(labels, string(b[idx+1:idx+size+1]))
 | 
				
			||||||
 | 
				
			|||||||
@ -21,8 +21,8 @@ type Server struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewServer returns a new dns server
 | 
					// NewServer returns a new dns server
 | 
				
			||||||
func NewServer(addr string, dialer proxy.Dialer, upServers ...string) (*Server, error) {
 | 
					func NewServer(addr string, dialer proxy.Dialer, upServers []string) (*Server, error) {
 | 
				
			||||||
	c, err := NewClient(dialer, upServers...)
 | 
						c, err := NewClient(dialer, upServers)
 | 
				
			||||||
	s := &Server{
 | 
						s := &Server{
 | 
				
			||||||
		addr:   addr,
 | 
							addr:   addr,
 | 
				
			||||||
		Client: c,
 | 
							Client: c,
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										7
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								main.go
									
									
									
									
									
								
							@ -57,11 +57,16 @@ func main() {
 | 
				
			|||||||
	dialer := NewRuleDialer(conf.rules, dialerFromConf())
 | 
						dialer := NewRuleDialer(conf.rules, dialerFromConf())
 | 
				
			||||||
	ipsetM, _ := NewIPSetManager(conf.IPSet, conf.rules)
 | 
						ipsetM, _ := NewIPSetManager(conf.IPSet, conf.rules)
 | 
				
			||||||
	if conf.DNS != "" {
 | 
						if conf.DNS != "" {
 | 
				
			||||||
		d, err := dns.NewServer(conf.DNS, dialer, conf.DNSServer...)
 | 
							d, err := dns.NewServer(conf.DNS, dialer, conf.DNSServer)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			log.Fatal(err)
 | 
								log.Fatal(err)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// custom records
 | 
				
			||||||
 | 
							for _, record := range conf.DNSRecord {
 | 
				
			||||||
 | 
								d.AddRecord(record)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// rule
 | 
							// rule
 | 
				
			||||||
		for _, r := range conf.rules {
 | 
							for _, r := range conf.rules {
 | 
				
			||||||
			for _, domain := range r.Domain {
 | 
								for _, domain := range r.Domain {
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user