From 7fda4f8710ebc253e23b3461a2203be3065dea12 Mon Sep 17 00:00:00 2001 From: nadoo <287492+nadoo@users.noreply.github.com> Date: Thu, 28 Jun 2018 11:20:48 +0800 Subject: [PATCH] dns: optimize code --- dns/dns.go | 19 +++++++++---------- ipset_linux.go | 14 +++++++++++++- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/dns/dns.go b/dns/dns.go index 3ccea19..bb4f71e 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -291,10 +291,7 @@ func (s *DNS) Exchange(reqLen uint16, reqMsg []byte, addr string) (respLen uint1 return } - dnsServer := s.DNSServer - if !s.Tunnel { - dnsServer = s.GetServer(query.QNAME) - } + dnsServer := s.GetServer(query.QNAME) rc, err := s.dialer.NextDialer(query.QNAME+":53").Dial("tcp", dnsServer) if err != nil { @@ -366,13 +363,15 @@ func (s *DNS) SetServer(domain, server string) { // GetServer . func (s *DNS) GetServer(domain string) string { - domainParts := strings.Split(domain, ".") - length := len(domainParts) - for i := length - 2; i >= 0; i-- { - domain := strings.Join(domainParts[i:length], ".") + if !s.Tunnel { + domainParts := strings.Split(domain, ".") + length := len(domainParts) + for i := length - 2; i >= 0; i-- { + domain := strings.Join(domainParts[i:length], ".") - if server, ok := s.DNSServerMap[domain]; ok { - return server + if server, ok := s.DNSServerMap[domain]; ok { + return server + } } } diff --git a/ipset_linux.go b/ipset_linux.go index f3cd02d..ba93d6f 100644 --- a/ipset_linux.go +++ b/ipset_linux.go @@ -21,8 +21,8 @@ import ( // https://github.com/torvalds/linux/blob/9e66317d3c92ddaab330c125dfe9d06eee268aff/include/uapi/linux/netfilter/nfnetlink.h#L56 const NFNL_SUBSYS_IPSET = 6 -// http://git.netfilter.org/ipset/tree/include/libipset/linux_ip_set.h // IPSET_PROTOCOL The protocol version +// http://git.netfilter.org/ipset/tree/include/libipset/linux_ip_set.h const IPSET_PROTOCOL = 6 // IPSET_MAXNAMELEN The max length of strings including NUL: set and type identifiers @@ -146,6 +146,7 @@ func (m *IPSetManager) AddDomainIP(domain, ip string) error { return nil } +// CreateSet create a ipset func CreateSet(fd int, lsa syscall.SockaddrNetlink, setName string) { if setName == "" { return @@ -189,6 +190,7 @@ func CreateSet(fd int, lsa syscall.SockaddrNetlink, setName string) { FlushSet(fd, lsa, setName) } +// FlushSet flush a ipset func FlushSet(fd int, lsa syscall.SockaddrNetlink, setName string) { log.F("ipset flush %s", setName) @@ -206,6 +208,7 @@ func FlushSet(fd int, lsa syscall.SockaddrNetlink, setName string) { } +// AddToSet adds an entry to ipset func AddToSet(fd int, lsa syscall.SockaddrNetlink, setName, entry string) { if setName == "" { return @@ -280,6 +283,7 @@ func rtaAlignOf(attrlen int) int { return (attrlen + syscall.RTA_ALIGNTO - 1) & ^(syscall.RTA_ALIGNTO - 1) } +// NetlinkRequestData . type NetlinkRequestData interface { Len() int Serialize() []byte @@ -291,6 +295,7 @@ type NfGenMsg struct { resID uint16 } +// NewNfGenMsg . func NewNfGenMsg(nfgenFamily, version, resID int) *NfGenMsg { return &NfGenMsg{ nfgenFamily: uint8(nfgenFamily), @@ -429,6 +434,7 @@ func (req *NetlinkRequest) Serialize() []byte { return b } +// AddData add data to request func (req *NetlinkRequest) AddData(data NetlinkRequestData) { if data != nil { req.Data = append(req.Data, data) @@ -442,10 +448,12 @@ func (req *NetlinkRequest) AddRawData(data []byte) { } } +// Uint8Attr . func Uint8Attr(v uint8) []byte { return []byte{byte(v)} } +// Uint16Attr . func Uint16Attr(v uint16) []byte { native := NativeEndian() bytes := make([]byte, 2) @@ -453,6 +461,7 @@ func Uint16Attr(v uint16) []byte { return bytes } +// Uint32Attr . func Uint32Attr(v uint32) []byte { native := NativeEndian() bytes := make([]byte, 4) @@ -460,6 +469,7 @@ func Uint32Attr(v uint32) []byte { return bytes } +// ZeroTerminated . func ZeroTerminated(s string) []byte { bytes := make([]byte, len(s)+1) for i := 0; i < len(s); i++ { @@ -469,6 +479,7 @@ func ZeroTerminated(s string) []byte { return bytes } +// NonZeroTerminated . func NonZeroTerminated(s string) []byte { bytes := make([]byte, len(s)) for i := 0; i < len(s); i++ { @@ -477,6 +488,7 @@ func NonZeroTerminated(s string) []byte { return bytes } +// BytesToString . func BytesToString(b []byte) string { n := bytes.Index(b, []byte{0}) return string(b[:n])