diff --git a/dns/client.go b/dns/client.go index 3ed10c8..3d306c0 100644 --- a/dns/client.go +++ b/dns/client.go @@ -3,7 +3,6 @@ package dns import ( "encoding/binary" "errors" - "fmt" "io" "net" "strings" @@ -308,5 +307,11 @@ func (c *Client) MakeResponse(domain string, ip string) (*Message, error) { } func qKey(q *Question) string { - return fmt.Sprintf("%s/%d", q.QNAME, q.QTYPE) + switch q.QTYPE { + case QTypeA: + return q.QNAME + "/4" + case QTypeAAAA: + return q.QNAME + "/6" + } + return q.QNAME } diff --git a/dns/message.go b/dns/message.go index c76d883..36806c3 100644 --- a/dns/message.go +++ b/dns/message.go @@ -307,12 +307,14 @@ func (m *Message) UnmarshalQuestion(b []byte, q *Question) (n int, err error) { return 0, errors.New("UnmarshalQuestion: not enough data") } - domain, idx, err := m.UnmarshalDomain(b) + sb := new(strings.Builder) + sb.Grow(32) + idx, err := m.UnmarshalDomainTo(sb, b) if err != nil { return 0, err } - q.QNAME = domain + q.QNAME = sb.String() q.QTYPE = binary.BigEndian.Uint16(b[idx : idx+2]) q.QCLASS = binary.BigEndian.Uint16(b[idx+2 : idx+4]) @@ -411,11 +413,14 @@ func (m *Message) UnmarshalRR(start int, rr *RR) (n int, err error) { p := m.unMarshaled[start:] - domain, n, err := m.UnmarshalDomain(p) + sb := new(strings.Builder) + sb.Grow(32) + + n, err = m.UnmarshalDomainTo(sb, p) if err != nil { return 0, err } - rr.NAME = domain + rr.NAME = sb.String() if len(p) <= n+10 { return 0, errors.New("UnmarshalRR: not enough data") @@ -469,67 +474,63 @@ func MarshalDomainTo(w io.Writer, domain string) (n int, err error) { return } -// UnmarshalDomain gets domain from bytes. -func (m *Message) UnmarshalDomain(b []byte) (string, int, error) { +// UnmarshalDomainTo gets domain from bytes to string builder. +func (m *Message) UnmarshalDomainTo(sb *strings.Builder, b []byte) (int, error) { var idx, size int - var labels []string - for { + for len(b[idx:]) != 0 { // https://tools.ietf.org/html/rfc1035#section-4.1.4 // "Message compression", // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ // | 1 1| OFFSET | // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ - if len(b[idx:]) == 0 { - break - - } else if b[idx]&0xC0 == 0xC0 { + if b[idx]&0xC0 == 0xC0 { if len(b[idx:]) < 2 { - return "", 0, errors.New("UnmarshalDomain: not enough size for compressed domain") + return 0, errors.New("UnmarshalDomainTo: not enough size for compressed domain") } offset := binary.BigEndian.Uint16(b[idx : idx+2]) - label, err := m.UnmarshalDomainPoint(int(offset & 0x3FFF)) + err := m.UnmarshalDomainPointTo(sb, int(offset&0x3FFF)) if err != nil { - return "", 0, err + return 0, err } - labels = append(labels, label) idx += 2 break - - } else { - size = int(b[idx]) - idx++ - - // root domain name - if size == 0 { - break - } - - if size > 63 { - return "", 0, errors.New("UnmarshalDomain: label size larger than 63") - } - - if idx+size > len(b) { - return "", 0, errors.New("UnmarshalDomain: label size larger than msg length") - } - - labels = append(labels, string(b[idx:idx+size])) - idx += size } + size = int(b[idx]) + idx++ + + // root domain name + if size == 0 { + break + } + + if size > 63 { + return 0, errors.New("UnmarshalDomainTo: label size larger than 63") + } + + if idx+size > len(b) { + return 0, errors.New("UnmarshalDomainTo: label size larger than msg length") + } + + if sb.Len() > 0 { + sb.WriteByte('.') + } + sb.Write(b[idx : idx+size]) + + idx += size } - domain := strings.Join(labels, ".") - return domain, idx, nil + return idx, nil } -// UnmarshalDomainPoint gets domain from offset point. -func (m *Message) UnmarshalDomainPoint(offset int) (string, error) { +// UnmarshalDomainPointTo gets domain from offset point to string builder. +func (m *Message) UnmarshalDomainPointTo(sb *strings.Builder, offset int) error { if offset > len(m.unMarshaled) { - return "", errors.New("UnmarshalDomainPoint: offset larger than msg length") + return errors.New("UnmarshalDomainPointTo: offset larger than msg length") } - domain, _, err := m.UnmarshalDomain(m.unMarshaled[offset:]) - return domain, err + _, err := m.UnmarshalDomainTo(sb, m.unMarshaled[offset:]) + return err } diff --git a/go.mod b/go.mod index 58250fb..d344f2f 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a golang.org/x/net v0.0.0-20200822124328-c89045814202 // indirect golang.org/x/sys v0.0.0-20200824131525-c12d262b63d8 // indirect - golang.org/x/tools v0.0.0-20200823205832-c024452afbcd // indirect + golang.org/x/tools v0.0.0-20200826040757-bc8aaaa29e06 // indirect gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect ) diff --git a/go.sum b/go.sum index 693dbf1..0feaeb0 100644 --- a/go.sum +++ b/go.sum @@ -124,8 +124,8 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20200425043458-8463f397d07c h1:iHhCR0b26amDCiiO+kBguKZom9aMF+NrFxh9zeKR/XU= golang.org/x/tools v0.0.0-20200425043458-8463f397d07c/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200808161706-5bf02b21f123/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= -golang.org/x/tools v0.0.0-20200823205832-c024452afbcd h1:KNSumuk5eGuQV7zbOrDDZ3MIkwsQr0n5oKiH4oE0/hU= -golang.org/x/tools v0.0.0-20200823205832-c024452afbcd/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= +golang.org/x/tools v0.0.0-20200826040757-bc8aaaa29e06 h1:ChBCbOHeLqK+j+znGPlWCcvx/t2PdxmyPBheVZxXbcc= +golang.org/x/tools v0.0.0-20200826040757-bc8aaaa29e06/go.mod h1:njjCfa9FT2d7l9Bc6FUM5FLjQPp3cFF28FI3qnDFljA= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= diff --git a/ipset/ipset_linux.go b/ipset/ipset_linux.go index 4be93dd..1cf1855 100644 --- a/ipset/ipset_linux.go +++ b/ipset/ipset_linux.go @@ -98,13 +98,18 @@ func NewManager(rules []*rule.Config) (*Manager, error) { m := &Manager{fd: fd, lsa: lsa} - // create ipset + // create ipset, avoid redundant. + sets := make(map[string]struct{}) for _, r := range rules { if r.IPSet != "" { - CreateSet(fd, lsa, r.IPSet) + sets[r.IPSet] = struct{}{} } } + for set := range sets { + CreateSet(fd, lsa, set) + } + // init ipset for _, r := range rules { if r.IPSet != "" { diff --git a/main.go b/main.go index db7d692..9f3327e 100644 --- a/main.go +++ b/main.go @@ -48,7 +48,7 @@ func main() { } // global rule proxy - p := rule.NewProxy(conf.rules, strategy.NewProxy(conf.Forward, &conf.StrategyConfig)) + p := rule.NewProxy(conf.rules, strategy.NewProxy("default", conf.Forward, &conf.StrategyConfig)) // ipset manager ipsetM, _ := ipset.NewManager(conf.rules) diff --git a/proxy/http/server.go b/proxy/http/server.go index 0b2f703..ab0e70d 100644 --- a/proxy/http/server.go +++ b/proxy/http/server.go @@ -130,10 +130,7 @@ func (s *HTTP) servHTTP(req *request, c *conn.Conn) { // copy the left request bytes to remote server. eg. length specificed or chunked body. go func() { if _, err := c.Reader().Peek(1); err == nil { - b := pool.GetBuffer(conn.TCPBufSize) - io.CopyBuffer(rc, c, b) - pool.PutBuffer(b) - + conn.Copy(rc, c) rc.SetDeadline(time.Now()) c.SetDeadline(time.Now()) } @@ -167,7 +164,5 @@ func (s *HTTP) servHTTP(req *request, c *conn.Conn) { log.F("[http] %s <-> %s", c.RemoteAddr(), req.target) c.Write(buf.Bytes()) - b := pool.GetBuffer(conn.TCPBufSize) - io.CopyBuffer(c, r, b) - pool.PutBuffer(b) + conn.Copy(c, r) } diff --git a/rule/rule.go b/rule/rule.go index 07b3cdc..74d76b3 100644 --- a/rule/rule.go +++ b/rule/rule.go @@ -25,7 +25,7 @@ func NewProxy(rules []*Config, proxy *strategy.Proxy) *Proxy { rd := &Proxy{proxy: proxy} for _, r := range rules { - sd := strategy.NewProxy(r.Forward, &r.StrategyConfig) + sd := strategy.NewProxy(r.Name, r.Forward, &r.StrategyConfig) rd.proxies = append(rd.proxies, sd) for _, domain := range r.Domain { diff --git a/strategy/strategy.go b/strategy/strategy.go index 22d5879..68d9674 100644 --- a/strategy/strategy.go +++ b/strategy/strategy.go @@ -47,7 +47,7 @@ type Proxy struct { } // NewProxy returns a new strategy proxy. -func NewProxy(s []string, c *Config) *Proxy { +func NewProxy(name string, s []string, c *Config) *Proxy { var fwdrs []*Forwarder for _, chain := range s { fwdr, err := ForwarderFromURL(chain, c.IntFace, @@ -66,11 +66,11 @@ func NewProxy(s []string, c *Config) *Proxy { c.Strategy = "rr" } - return newProxy(fwdrs, c) + return newProxy(name, fwdrs, c) } // newProxy returns a new Proxy. -func newProxy(fwdrs []*Forwarder, c *Config) *Proxy { +func newProxy(name string, fwdrs []*Forwarder, c *Config) *Proxy { p := &Proxy{fwdrs: fwdrs, config: c} sort.Sort(p.fwdrs) @@ -83,19 +83,19 @@ func newProxy(fwdrs []*Forwarder, c *Config) *Proxy { switch c.Strategy { case "rr": p.next = p.scheduleRR - log.F("[strategy] forward to remote servers in round robin mode.") + log.F("[strategy] %s: forward in round robin mode.", name) case "ha": p.next = p.scheduleHA - log.F("[strategy] forward to remote servers in high availability mode.") + log.F("[strategy] %s: forward in high availability mode.", name) case "lha": p.next = p.scheduleLHA - log.F("[strategy] forward to remote servers in latency based high availability mode.") + log.F("[strategy] %s: forward in latency based high availability mode.", name) case "dh": p.next = p.scheduleDH - log.F("[strategy] forward to remote servers in destination hashing mode.") + log.F("[strategy] %s: forward in destination hashing mode.", name) default: p.next = p.scheduleRR - log.F("[strategy] not supported forward mode '%s', use round robin mode.", c.Strategy) + log.F("[strategy] %s: not supported forward mode '%s', use round robin mode.", name, c.Strategy) } for _, f := range fwdrs {