ipset: add ipset management. REQUIRE CAP_NET_ADMIN capability.

This commit is contained in:
nadoo 2017-08-28 19:23:32 +08:00
parent 419fa05171
commit ada5cdfcc3
3 changed files with 424 additions and 0 deletions

400
ipset_linux.go Normal file
View File

@ -0,0 +1,400 @@
// Apache License 2.0
// @mdlayher https://github.com/mdlayher/netlink
// Ref: https://github.com/vishvananda/netlink/blob/master/nl/nl_linux.go
package main
import (
"bytes"
"encoding/binary"
"log"
"net"
"strings"
"sync"
"sync/atomic"
"syscall"
"unsafe"
)
const NFNL_SUBSYS_IPSET = 6
// http://git.netfilter.org/ipset/tree/include/libipset/linux_ip_set.h
/* The protocol version */
const IPSET_PROTOCOL = 6
/* The max length of strings including NUL: set and type identifiers */
const IPSET_MAXNAMELEN = 32
/* Message types and commands */
const IPSET_CMD_CREATE = 2
const IPSET_CMD_ADD = 9
const IPSET_CMD_DEL = 10
/* Attributes at command level */
const IPSET_ATTR_PROTOCOL = 1 /* 1: Protocol version */
const IPSET_ATTR_SETNAME = 2 /* 2: Name of the set */
const IPSET_ATTR_TYPENAME = 3 /* 3: Typename */
const IPSET_ATTR_REVISION = 4 /* 4: Settype revision */
const IPSET_ATTR_FAMILY = 5 /* 5: Settype family */
const IPSET_ATTR_DATA = 7 /* 7: Nested attributes */
const IPSET_ATTR_IP = 1
/* IP specific attributes */
const IPSET_ATTR_IPADDR_IPV4 = 1
const IPSET_ATTR_IPADDR_IPV6 = 2
const NLA_F_NESTED = (1 << 15)
const NLA_F_NET_BYTEORDER = (1 << 14)
var nextSeqNr uint32
var nativeEndian binary.ByteOrder
type IPSetManager struct {
fd int
lsa syscall.SockaddrNetlink
domainSet sync.Map
}
func NewIPSetManager(rules []*RuleConf) (*IPSetManager, error) {
fd, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, syscall.NETLINK_NETFILTER)
if err != nil {
logf("%s", err)
return nil, err
}
// defer syscall.Close(fd)
lsa := syscall.SockaddrNetlink{
Family: syscall.AF_NETLINK,
}
if err = syscall.Bind(fd, &lsa); err != nil {
logf("%s", err)
return nil, err
}
var domainSet sync.Map
for _, r := range rules {
CreateSet(fd, lsa, r.IPSet)
for _, domain := range r.Domain {
domainSet.Store(domain, r.IPSet)
}
for _, ip := range r.IP {
AddToSet(fd, lsa, r.IPSet, ip)
}
// TODO: add all ip in cidr to ipset
// for _, s := range r.CIDR {
// if _, cidr, err := net.ParseCIDR(s); err == nil {
// rd.cidrMap.Store(cidr, sd)
// }
// }
}
return &IPSetManager{fd: fd, lsa: lsa, domainSet: domainSet}, nil
}
// AddDomainIP used to update ipset according to domainSet rule
func (m *IPSetManager) AddDomainIP(domain, ip string) error {
if ip != "" {
logf("domain: %s, ip: %s\n", domain, ip)
domainParts := strings.Split(domain, ".")
length := len(domainParts)
for i := length - 2; i >= 0; i-- {
domain := strings.Join(domainParts[i:length], ".")
// find in domainMap
if ipset, ok := m.domainSet.Load(domain); ok {
AddToSet(m.fd, m.lsa, ipset.(string), ip)
}
}
}
return nil
}
func CreateSet(fd int, lsa syscall.SockaddrNetlink, setName string) {
if len(setName) > IPSET_MAXNAMELEN {
log.Fatal("ipset name too long")
}
// req := NewNetlinkRequest(1538, syscall.NLM_F_REQUEST)
req := NewNetlinkRequest(IPSET_CMD_CREATE|(NFNL_SUBSYS_IPSET<<8), syscall.NLM_F_REQUEST)
// TODO: support AF_INET6
nfgenMsg := NewNfGenMsg(syscall.AF_INET, 0, 0)
req.AddData(nfgenMsg)
attrProto := NewRtAttr(IPSET_ATTR_PROTOCOL, Uint8Attr(IPSET_PROTOCOL))
req.AddData(attrProto)
attrSiteName := NewRtAttr(IPSET_ATTR_SETNAME, ZeroTerminated(setName))
req.AddData(attrSiteName)
attrSiteType := NewRtAttr(IPSET_ATTR_TYPENAME, ZeroTerminated("hash:ip"))
req.AddData(attrSiteType)
attrRev := NewRtAttr(IPSET_ATTR_REVISION, Uint8Attr(1))
req.AddData(attrRev)
attrFamily := NewRtAttr(IPSET_ATTR_FAMILY, Uint8Attr(2))
req.AddData(attrFamily)
attrData := NewRtAttr(IPSET_ATTR_DATA|NLA_F_NESTED, nil)
req.AddData(attrData)
err := syscall.Sendto(fd, req.Serialize(), 0, &lsa)
logf("%s", err)
}
func AddToSet(fd int, lsa syscall.SockaddrNetlink, setName, ipStr string) {
if len(setName) > IPSET_MAXNAMELEN {
logf("ipset name too long")
}
ip := net.ParseIP(ipStr).To4()
req := NewNetlinkRequest(IPSET_CMD_ADD|(NFNL_SUBSYS_IPSET<<8), syscall.NLM_F_REQUEST)
// TODO: support AF_INET6
nfgenMsg := NewNfGenMsg(syscall.AF_INET, 0, 0)
req.AddData(nfgenMsg)
attrProto := NewRtAttr(IPSET_ATTR_PROTOCOL, Uint8Attr(IPSET_PROTOCOL))
req.AddData(attrProto)
attrSiteName := NewRtAttr(IPSET_ATTR_SETNAME, ZeroTerminated(setName))
req.AddData(attrSiteName)
attrNested := NewRtAttr(IPSET_ATTR_DATA|NLA_F_NESTED, nil)
attrIP := NewRtAttrChild(attrNested, IPSET_ATTR_IP|NLA_F_NESTED, nil)
NewRtAttrChild(attrIP, IPSET_ATTR_IPADDR_IPV4|NLA_F_NET_BYTEORDER, ip)
// NewRtAttrChild(attrNested, 9|NLA_F_NET_BYTEORDER, Uint32Attr(0))
req.AddData(attrNested)
err := syscall.Sendto(fd, req.Serialize(), 0, &lsa)
logf("%s", err)
}
// Get native endianness for the system
func NativeEndian() binary.ByteOrder {
if nativeEndian == nil {
var x uint32 = 0x01020304
if *(*byte)(unsafe.Pointer(&x)) == 0x01 {
nativeEndian = binary.BigEndian
} else {
nativeEndian = binary.LittleEndian
}
}
return nativeEndian
}
func rtaAlignOf(attrlen int) int {
return (attrlen + syscall.RTA_ALIGNTO - 1) & ^(syscall.RTA_ALIGNTO - 1)
}
type NetlinkRequestData interface {
Len() int
Serialize() []byte
}
type NfGenMsg struct {
nfgenFamily uint8
version uint8
resID uint16
}
func NewNfGenMsg(nfgenFamily, version, resID int) *NfGenMsg {
return &NfGenMsg{
nfgenFamily: uint8(nfgenFamily),
version: uint8(version),
resID: uint16(resID),
}
}
func (m *NfGenMsg) Len() int {
return rtaAlignOf(4)
}
func (m *NfGenMsg) Serialize() []byte {
native := NativeEndian()
length := m.Len()
buf := make([]byte, rtaAlignOf(length))
buf[0] = m.nfgenFamily
buf[1] = m.version
native.PutUint16(buf[2:4], m.resID)
return buf
}
// Extend RtAttr to handle data and children
type RtAttr struct {
syscall.RtAttr
Data []byte
children []NetlinkRequestData
}
// Create a new Extended RtAttr object
func NewRtAttr(attrType int, data []byte) *RtAttr {
return &RtAttr{
RtAttr: syscall.RtAttr{
Type: uint16(attrType),
},
children: []NetlinkRequestData{},
Data: data,
}
}
// Create a new RtAttr obj anc add it as a child of an existing object
func NewRtAttrChild(parent *RtAttr, attrType int, data []byte) *RtAttr {
attr := NewRtAttr(attrType, data)
parent.children = append(parent.children, attr)
return attr
}
func (a *RtAttr) Len() int {
if len(a.children) == 0 {
return (syscall.SizeofRtAttr + len(a.Data))
}
l := 0
for _, child := range a.children {
l += rtaAlignOf(child.Len())
}
l += syscall.SizeofRtAttr
return rtaAlignOf(l + len(a.Data))
}
// Serialize the RtAttr into a byte array
// This can't just unsafe.cast because it must iterate through children.
func (a *RtAttr) Serialize() []byte {
native := NativeEndian()
length := a.Len()
buf := make([]byte, rtaAlignOf(length))
next := 4
if a.Data != nil {
copy(buf[next:], a.Data)
next += rtaAlignOf(len(a.Data))
}
if len(a.children) > 0 {
for _, child := range a.children {
childBuf := child.Serialize()
copy(buf[next:], childBuf)
next += rtaAlignOf(len(childBuf))
}
}
if l := uint16(length); l != 0 {
native.PutUint16(buf[0:2], l)
}
native.PutUint16(buf[2:4], a.Type)
return buf
}
type NetlinkRequest struct {
syscall.NlMsghdr
Data []NetlinkRequestData
RawData []byte
}
// Create a new netlink request from proto and flags
// Note the Len value will be inaccurate once data is added until
// the message is serialized
func NewNetlinkRequest(proto, flags int) *NetlinkRequest {
return &NetlinkRequest{
NlMsghdr: syscall.NlMsghdr{
Len: uint32(syscall.SizeofNlMsghdr),
Type: uint16(proto),
Flags: syscall.NLM_F_REQUEST | uint16(flags),
Seq: atomic.AddUint32(&nextSeqNr, 1),
// Pid: uint32(os.Getpid()),
},
}
}
// Serialize the Netlink Request into a byte array
func (req *NetlinkRequest) Serialize() []byte {
length := syscall.SizeofNlMsghdr
dataBytes := make([][]byte, len(req.Data))
for i, data := range req.Data {
dataBytes[i] = data.Serialize()
length = length + len(dataBytes[i])
}
length += len(req.RawData)
req.Len = uint32(length)
b := make([]byte, length)
hdr := (*(*[syscall.SizeofNlMsghdr]byte)(unsafe.Pointer(req)))[:]
next := syscall.SizeofNlMsghdr
copy(b[0:next], hdr)
for _, data := range dataBytes {
for _, dataByte := range data {
b[next] = dataByte
next = next + 1
}
}
// Add the raw data if any
if len(req.RawData) > 0 {
copy(b[next:length], req.RawData)
}
return b
}
func (req *NetlinkRequest) AddData(data NetlinkRequestData) {
if data != nil {
req.Data = append(req.Data, data)
}
}
// AddRawData adds raw bytes to the end of the NetlinkRequest object during serialization
func (req *NetlinkRequest) AddRawData(data []byte) {
if data != nil {
req.RawData = append(req.RawData, data...)
}
}
func Uint8Attr(v uint8) []byte {
return []byte{byte(v)}
}
func Uint16Attr(v uint16) []byte {
native := NativeEndian()
bytes := make([]byte, 2)
native.PutUint16(bytes, v)
return bytes
}
func Uint32Attr(v uint32) []byte {
native := NativeEndian()
bytes := make([]byte, 4)
native.PutUint32(bytes, v)
return bytes
}
func ZeroTerminated(s string) []byte {
bytes := make([]byte, len(s)+1)
for i := 0; i < len(s); i++ {
bytes[i] = s[i]
}
bytes[len(s)] = 0
return bytes
}
func NonZeroTerminated(s string) []byte {
bytes := make([]byte, len(s))
for i := 0; i < len(s); i++ {
bytes[i] = s[i]
}
return bytes
}
func BytesToString(b []byte) string {
n := bytes.Index(b, []byte{0})
return string(b[:n])
}

16
ipset_other.go Normal file
View File

@ -0,0 +1,16 @@
// +build !linux
package main
import "errors"
type IPSetManager struct {
}
func NewIPSetManager(rules []*RuleConf) (*IPSetManager, error) {
return nil, errors.New("ipset not supported on this os")
}
func (m *IPSetManager) AddDomainIP(domain, ip string) error {
return errors.New("ipset not supported on this os")
}

View File

@ -43,6 +43,11 @@ func main() {
go local.ListenAndServe()
}
ipsetM, err := NewIPSetManager(conf.rules)
if err != nil {
logf("ipset error: %s", err)
}
if conf.DNS != "" {
dns, err := NewDNS(conf.DNS, conf.DNSServer[0], sDialer)
if err != nil {
@ -60,6 +65,9 @@ func main() {
// add a handler to update proxy rules when a domain resolved
dns.AddAnswerHandler(sDialer.AddDomainIP)
if ipsetM != nil {
dns.AddAnswerHandler(ipsetM.AddDomainIP)
}
go dns.ListenAndServe()
}