diff --git a/dns/client.go b/dns/client.go index 2c9afbf..8fc5d5f 100644 --- a/dns/client.go +++ b/dns/client.go @@ -67,7 +67,7 @@ func (c *Client) Exchange(reqBytes []byte, clientAddr string, preferTCP bool) ([ } if c.config.NoAAAA && req.Question.QTYPE == QTypeAAAA { - reqBytes[2] |= uint8(Response) << 7 + reqBytes[2] |= uint8(ResponseMsg) << 7 return reqBytes, nil } @@ -321,7 +321,7 @@ func MakeResponse(domain, ip string, ttl uint32) (*Message, error) { rdata = ipb } - m := NewMessage(0, Response) + m := NewMessage(0, ResponseMsg) m.SetQuestion(NewQuestion(qtype, domain)) rr := &RR{NAME: domain, TYPE: qtype, CLASS: ClassINET, TTL: ttl, RDLENGTH: rdlen, RDATA: rdata} diff --git a/dns/message.go b/dns/message.go index 4b24e19..b1b0907 100644 --- a/dns/message.go +++ b/dns/message.go @@ -20,10 +20,13 @@ const UDPMaxLen = 512 // HeaderLen is the length of dns msg header. const HeaderLen = 12 +// MsgType is the dns Message type. +type MsgType byte + // Message types. const ( - Query = 0 - Response = 1 + QueryMsg MsgType = 0 + ResponseMsg MsgType = 1 ) // Query types. @@ -64,7 +67,7 @@ type Message struct { } // NewMessage returns a new message. -func NewMessage(id uint16, msgType int) *Message { +func NewMessage(id uint16, msgType MsgType) *Message { if id == 0 { id = uint16(rand.Uint32()) } @@ -194,7 +197,7 @@ type Header struct { } // SetMsgType sets the message type. -func (h *Header) SetMsgType(qr int) { +func (h *Header) SetMsgType(qr MsgType) { h.Bits |= uint16(qr) << 15 } diff --git a/ipset/ipset_linux.go b/ipset/ipset_linux.go index b5e0111..8cebe8f 100644 --- a/ipset/ipset_linux.go +++ b/ipset/ipset_linux.go @@ -30,8 +30,7 @@ func NewManager(rules []*rule.Config) (*Manager, error) { } for set := range sets { - ipset.Create(set) - ipset.Flush(set) + createSet(set) } // init ipset @@ -42,10 +41,10 @@ func NewManager(rules []*rule.Config) (*Manager, error) { m.domainSet.Store(domain, r.IPSet) } for _, ip := range r.IP { - ipset.Add(r.IPSet, ip) + addToSet(r.IPSet, ip) } for _, cidr := range r.CIDR { - ipset.Add(r.IPSet, cidr) + addToSet(r.IPSet, cidr) } } } @@ -63,9 +62,23 @@ func (m *Manager) AddDomainIP(domain, ip string) error { for i := len(domain); i != -1; { i = strings.LastIndexByte(domain[:i], '.') if setName, ok := m.domainSet.Load(domain[i+1:]); ok { - ipset.Add(setName.(string), ip) + addToSet(setName.(string), ip) } } return nil } + +func createSet(s string) { + ipset.Create(s) + ipset.Flush(s) + ipset.Create(s+"6", ipset.OptIPv6()) + ipset.Flush(s + "6") +} + +func addToSet(s, item string) error { + if strings.IndexByte(item, '.') == -1 { + return ipset.Add(s+"6", item) + } + return ipset.Add(s, item) +} diff --git a/main.go b/main.go index 063719e..a7086b1 100644 --- a/main.go +++ b/main.go @@ -72,7 +72,6 @@ func main() { if err != nil { log.Fatal(err) } - go local.ListenAndServe() }