From 5fef07134926a8ef6d621b2738b49feeb26529d1 Mon Sep 17 00:00:00 2001 From: nadoo <287492+nadoo@users.noreply.github.com> Date: Wed, 16 Aug 2017 13:20:12 +0800 Subject: [PATCH] dns: add experimental codes to specify different remote dns server in rule file --- dns.go | 49 ++++++++++++++++++++++++++++++++++++++----------- main.go | 30 +++++++++++++++++++++++++++++- rule.go | 6 ++++++ rules.go | 2 +- 4 files changed, 74 insertions(+), 13 deletions(-) diff --git a/dns.go b/dns.go index a079de0..f2ab29a 100644 --- a/dns.go +++ b/dns.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "io/ioutil" "net" + "strings" ) // UDPDNSHeaderLen is the length of UDP dns msg header @@ -23,23 +24,26 @@ const TCPDNSHEADERLen = 2 + UDPDNSHeaderLen // so we should also serve tcp requests. const MaxUDPDNSLen = 512 -type dns struct { +type DNS struct { *proxy - raddr string + dnsServer string + + dnsServerMap map[string]string } // DNSForwarder returns a dns forwarder. client -> dns.udp -> glider -> forwarder -> remote dns addr -func DNSForwarder(addr, raddr string, upProxy Proxy) (Proxy, error) { - s := &dns{ - proxy: newProxy(addr, upProxy), - raddr: raddr, +func DNSForwarder(addr, raddr string, upProxy Proxy) (*DNS, error) { + s := &DNS{ + proxy: newProxy(addr, upProxy), + dnsServer: raddr, + dnsServerMap: make(map[string]string), } return s, nil } // ListenAndServe . -func (s *dns) ListenAndServe() { +func (s *DNS) ListenAndServe() { l, err := net.ListenPacket("udp", s.addr) if err != nil { logf("failed to listen on %s: %v", s.addr, err) @@ -62,16 +66,18 @@ func (s *dns) ListenAndServe() { go func() { // TODO: check domain rules and get a proper upstream name server. - domain := getDomain(data) + domain := string(getDomain(data)) - rc, err := s.GetProxy(s.raddr).Dial("tcp", s.raddr) + dnsServer := s.GetServer(domain) + // TODO: check here + rc, err := s.GetProxy(domain+":53").GetProxy(domain+":53").Dial("tcp", dnsServer) if err != nil { - logf("failed to connect to server %v: %v", s.raddr, err) + logf("failed to connect to server %v: %v", dnsServer, err) return } defer rc.Close() - logf("proxy-dns %s, %s <-> %s", domain, clientAddr.String(), s.raddr) + logf("proxy-dns %s, %s <-> %s", domain, clientAddr.String(), dnsServer) // 2 bytes length after tcp header, before dns message length := make([]byte, 2) @@ -99,6 +105,27 @@ func (s *dns) ListenAndServe() { } } +// SetServer . +func (s *DNS) SetServer(domain, server string) { + s.dnsServerMap[domain] = server +} + +// GetServer . +func (s *DNS) GetServer(domain string) string { + + domainParts := strings.Split(domain, ".") + length := len(domainParts) + for i := length - 2; i >= 0; i-- { + domain := strings.Join(domainParts[i:length], ".") + + if server, ok := s.dnsServerMap[domain]; ok { + return server + } + } + + return s.dnsServer +} + // getDomain from dns request playload, return []byte like: // []byte{'w', 'w', 'w', '.', 'm', 's', 'n', '.', 'c', 'o', 'm', '.'} // []byte("www.msn.com.") diff --git a/main.go b/main.go index f2b2972..368c1f8 100644 --- a/main.go +++ b/main.go @@ -22,6 +22,11 @@ var conf struct { Listen []string Forward []string RuleFile []string + + DNS string + DNSServer []string + + IPSet string } var flag = conflag.New() @@ -118,6 +123,11 @@ func main() { flag.StringSliceUniqVar(&conf.Forward, "forward", nil, "forward url, format: SCHEMA://[USER|METHOD:PASSWORD@][HOST]:PORT[,SCHEMA://[USER|METHOD:PASSWORD@][HOST]:PORT]") flag.StringSliceUniqVar(&conf.RuleFile, "rulefile", nil, "rule file path") + flag.StringVar(&conf.DNS, "dns", "", "dns listen address") + flag.StringSliceUniqVar(&conf.DNSServer, "dnsserver", []string{"8.8.8.8:53"}, "remote dns server") + + flag.StringVar(&conf.IPSet, "ipset", "glider", "ipset name") + flag.Usage = usage err := flag.Parse() if err != nil { @@ -125,7 +135,7 @@ func main() { return } - if len(conf.Listen) == 0 { + if len(conf.Listen) == 0 && conf.DNS == "" { flag.Usage() fmt.Fprintf(os.Stderr, "ERROR: listen url must be specified.\n") return @@ -177,6 +187,24 @@ func main() { } } + if conf.DNS != "" { + dns, err := DNSForwarder(conf.DNS, conf.DNSServer[0], forwarder) + if err != nil { + log.Fatal(err) + } + + // rule + for _, frwder := range ruleForwarders { + for _, domain := range frwder.Domain { + if len(frwder.DNSServer) > 0 { + dns.SetServer(domain, frwder.DNSServer[0]) + } + } + } + + go dns.ListenAndServe() + } + sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) <-sigCh diff --git a/rule.go b/rule.go index 5b0e27d..afd2081 100644 --- a/rule.go +++ b/rule.go @@ -16,6 +16,9 @@ type ruleForwarder struct { CheckWebSite string CheckDuration int + DNSServer []string + IPSet string + Domain []string IP []string CIDR []string @@ -34,6 +37,9 @@ func newRuleProxyFromFile(ruleFile string) (*ruleForwarder, error) { f.StringVar(&p.CheckWebSite, "checkwebsite", "www.apple.com", "proxy check HTTP(NOT HTTPS) website address, format: HOST[:PORT], default port: 80") f.IntVar(&p.CheckDuration, "checkduration", 30, "proxy check duration(seconds)") + f.StringSliceUniqVar(&p.DNSServer, "dnsserver", nil, "remote dns server") + f.StringVar(&p.IPSet, "ipset", "", "ipset name") + f.StringSliceUniqVar(&p.Domain, "domain", nil, "domain") f.StringSliceUniqVar(&p.IP, "ip", nil, "ip") f.StringSliceUniqVar(&p.CIDR, "cidr", nil, "cidr") diff --git a/rules.go b/rules.go index 64aee7f..c71dc43 100644 --- a/rules.go +++ b/rules.go @@ -42,7 +42,7 @@ func newRulesForwarder(ruleForwarders []*ruleForwarder, globalForwarder Proxy) P return p } -func (p *rulesForwarder) Addr() string { return "rule forwarder" } +func (p *rulesForwarder) Addr() string { return "rules forwarder" } func (p *rulesForwarder) ListenAndServe() {} func (p *rulesForwarder) Serve(c net.Conn) {} func (p *rulesForwarder) CurrentProxy() Proxy { return p.globalForwarder.CurrentProxy() }