dns: add experimental codes to specify different remote dns server in rule file

This commit is contained in:
nadoo 2017-08-16 13:20:12 +08:00
parent 5c0c2e926d
commit 5fef071349
4 changed files with 74 additions and 13 deletions

49
dns.go
View File

@ -6,6 +6,7 @@ import (
"encoding/binary"
"io/ioutil"
"net"
"strings"
)
// UDPDNSHeaderLen is the length of UDP dns msg header
@ -23,23 +24,26 @@ const TCPDNSHEADERLen = 2 + UDPDNSHeaderLen
// so we should also serve tcp requests.
const MaxUDPDNSLen = 512
type dns struct {
type DNS struct {
*proxy
raddr string
dnsServer string
dnsServerMap map[string]string
}
// DNSForwarder returns a dns forwarder. client -> dns.udp -> glider -> forwarder -> remote dns addr
func DNSForwarder(addr, raddr string, upProxy Proxy) (Proxy, error) {
s := &dns{
proxy: newProxy(addr, upProxy),
raddr: raddr,
func DNSForwarder(addr, raddr string, upProxy Proxy) (*DNS, error) {
s := &DNS{
proxy: newProxy(addr, upProxy),
dnsServer: raddr,
dnsServerMap: make(map[string]string),
}
return s, nil
}
// ListenAndServe .
func (s *dns) ListenAndServe() {
func (s *DNS) ListenAndServe() {
l, err := net.ListenPacket("udp", s.addr)
if err != nil {
logf("failed to listen on %s: %v", s.addr, err)
@ -62,16 +66,18 @@ func (s *dns) ListenAndServe() {
go func() {
// TODO: check domain rules and get a proper upstream name server.
domain := getDomain(data)
domain := string(getDomain(data))
rc, err := s.GetProxy(s.raddr).Dial("tcp", s.raddr)
dnsServer := s.GetServer(domain)
// TODO: check here
rc, err := s.GetProxy(domain+":53").GetProxy(domain+":53").Dial("tcp", dnsServer)
if err != nil {
logf("failed to connect to server %v: %v", s.raddr, err)
logf("failed to connect to server %v: %v", dnsServer, err)
return
}
defer rc.Close()
logf("proxy-dns %s, %s <-> %s", domain, clientAddr.String(), s.raddr)
logf("proxy-dns %s, %s <-> %s", domain, clientAddr.String(), dnsServer)
// 2 bytes length after tcp header, before dns message
length := make([]byte, 2)
@ -99,6 +105,27 @@ func (s *dns) ListenAndServe() {
}
}
// SetServer .
func (s *DNS) SetServer(domain, server string) {
s.dnsServerMap[domain] = server
}
// GetServer .
func (s *DNS) GetServer(domain string) string {
domainParts := strings.Split(domain, ".")
length := len(domainParts)
for i := length - 2; i >= 0; i-- {
domain := strings.Join(domainParts[i:length], ".")
if server, ok := s.dnsServerMap[domain]; ok {
return server
}
}
return s.dnsServer
}
// getDomain from dns request playload, return []byte like:
// []byte{'w', 'w', 'w', '.', 'm', 's', 'n', '.', 'c', 'o', 'm', '.'}
// []byte("www.msn.com.")

30
main.go
View File

@ -22,6 +22,11 @@ var conf struct {
Listen []string
Forward []string
RuleFile []string
DNS string
DNSServer []string
IPSet string
}
var flag = conflag.New()
@ -118,6 +123,11 @@ func main() {
flag.StringSliceUniqVar(&conf.Forward, "forward", nil, "forward url, format: SCHEMA://[USER|METHOD:PASSWORD@][HOST]:PORT[,SCHEMA://[USER|METHOD:PASSWORD@][HOST]:PORT]")
flag.StringSliceUniqVar(&conf.RuleFile, "rulefile", nil, "rule file path")
flag.StringVar(&conf.DNS, "dns", "", "dns listen address")
flag.StringSliceUniqVar(&conf.DNSServer, "dnsserver", []string{"8.8.8.8:53"}, "remote dns server")
flag.StringVar(&conf.IPSet, "ipset", "glider", "ipset name")
flag.Usage = usage
err := flag.Parse()
if err != nil {
@ -125,7 +135,7 @@ func main() {
return
}
if len(conf.Listen) == 0 {
if len(conf.Listen) == 0 && conf.DNS == "" {
flag.Usage()
fmt.Fprintf(os.Stderr, "ERROR: listen url must be specified.\n")
return
@ -177,6 +187,24 @@ func main() {
}
}
if conf.DNS != "" {
dns, err := DNSForwarder(conf.DNS, conf.DNSServer[0], forwarder)
if err != nil {
log.Fatal(err)
}
// rule
for _, frwder := range ruleForwarders {
for _, domain := range frwder.Domain {
if len(frwder.DNSServer) > 0 {
dns.SetServer(domain, frwder.DNSServer[0])
}
}
}
go dns.ListenAndServe()
}
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh

View File

@ -16,6 +16,9 @@ type ruleForwarder struct {
CheckWebSite string
CheckDuration int
DNSServer []string
IPSet string
Domain []string
IP []string
CIDR []string
@ -34,6 +37,9 @@ func newRuleProxyFromFile(ruleFile string) (*ruleForwarder, error) {
f.StringVar(&p.CheckWebSite, "checkwebsite", "www.apple.com", "proxy check HTTP(NOT HTTPS) website address, format: HOST[:PORT], default port: 80")
f.IntVar(&p.CheckDuration, "checkduration", 30, "proxy check duration(seconds)")
f.StringSliceUniqVar(&p.DNSServer, "dnsserver", nil, "remote dns server")
f.StringVar(&p.IPSet, "ipset", "", "ipset name")
f.StringSliceUniqVar(&p.Domain, "domain", nil, "domain")
f.StringSliceUniqVar(&p.IP, "ip", nil, "ip")
f.StringSliceUniqVar(&p.CIDR, "cidr", nil, "cidr")

View File

@ -42,7 +42,7 @@ func newRulesForwarder(ruleForwarders []*ruleForwarder, globalForwarder Proxy) P
return p
}
func (p *rulesForwarder) Addr() string { return "rule forwarder" }
func (p *rulesForwarder) Addr() string { return "rules forwarder" }
func (p *rulesForwarder) ListenAndServe() {}
func (p *rulesForwarder) Serve(c net.Conn) {}
func (p *rulesForwarder) CurrentProxy() Proxy { return p.globalForwarder.CurrentProxy() }