backtrace/bk/trace_common.go
2025-04-08 13:54:11 +00:00

631 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package backtrace
import (
"context"
"errors"
"fmt"
"net"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
. "github.com/oneclickvirt/defaultset"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// DefaultConfig is the default configuration for Tracer.
var DefaultConfig = Config{
Delay: 50 * time.Millisecond,
Timeout: 500 * time.Millisecond,
MaxHops: 15,
Count: 1,
Networks: []string{"ip4:icmp", "ip4:ip", "ip6:ipv6-icmp", "ip6:ip"},
}
// DefaultTracer is a tracer with DefaultConfig.
var DefaultTracer = &Tracer{
Config: DefaultConfig,
}
// Config is a configuration for Tracer.
type Config struct {
Delay time.Duration
Timeout time.Duration
MaxHops int
Count int
Networks []string
Addr *net.IPAddr
}
// Tracer is a traceroute tool based on raw IP packets.
// It can handle multiple sessions simultaneously.
type Tracer struct {
Config
once sync.Once
conn *net.IPConn // Ipv4连接
ipv6conn *ipv6.PacketConn // IPv6连接
err error
mu sync.RWMutex
sess map[string][]*Session
seq uint32
}
// Trace starts sending IP packets increasing TTL until MaxHops and calls h for each reply.
func (t *Tracer) Trace(ctx context.Context, ip net.IP, h func(reply *Reply)) error {
sess, err := t.NewSession(ip)
if err != nil {
return err
}
defer sess.Close()
delay := time.NewTicker(t.Delay)
defer delay.Stop()
max := t.MaxHops
for n := 0; n < t.Count; n++ {
for ttl := 1; ttl <= t.MaxHops && ttl <= max; ttl++ {
err = sess.Ping(ttl)
if err != nil {
return err
}
select {
case <-delay.C:
case r := <-sess.Receive():
if max > r.Hops && ip.Equal(r.IP) {
max = r.Hops
}
h(r)
case <-ctx.Done():
return ctx.Err()
}
}
}
if sess.isDone(max) {
return nil
}
deadline := time.After(t.Timeout)
for {
select {
case r := <-sess.Receive():
if max > r.Hops && ip.Equal(r.IP) {
max = r.Hops
}
h(r)
if sess.isDone(max) {
return nil
}
case <-deadline:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
}
// NewSession returns new tracer session.
func (t *Tracer) NewSession(ip net.IP) (*Session, error) {
t.once.Do(t.init)
if t.err != nil {
return nil, t.err
}
return newSession(t, shortIP(ip)), nil
}
func (t *Tracer) init() {
// 初始化IPv4连接
for _, network := range t.Networks {
if strings.HasPrefix(network, "ip4") {
t.conn, t.err = t.listen(network, t.Addr)
if t.err == nil {
go t.serve(t.conn)
break
}
}
}
// 初始化IPv6连接
for _, network := range t.Networks {
if strings.HasPrefix(network, "ip6") {
conn, err := net.ListenIP(network, t.Addr)
if err == nil {
t.ipv6conn = ipv6.NewPacketConn(conn)
err = t.ipv6conn.SetControlMessage(ipv6.FlagHopLimit|ipv6.FlagSrc|ipv6.FlagDst|ipv6.FlagInterface, true)
if err != nil {
if EnableLoger {
InitLogger()
defer Logger.Sync()
Logger.Info("设置IPv6控制消息失败: " + err.Error())
}
t.ipv6conn.Close()
continue
}
go t.serveIPv6(t.ipv6conn)
break
}
}
}
}
// Close closes listening socket.
// Tracer can not be used after Close is called.
func (t *Tracer) Close() {
t.mu.Lock()
defer t.mu.Unlock()
if t.conn != nil {
t.conn.Close()
}
if t.ipv6conn != nil {
t.ipv6conn.Close()
}
}
func (t *Tracer) serve(conn *net.IPConn) error {
defer conn.Close()
buf := make([]byte, 1500)
for {
n, from, err := conn.ReadFromIP(buf)
if err != nil {
return err
}
err = t.serveData(from.IP, buf[:n])
if err != nil {
continue
}
}
}
func (t *Tracer) serveData(from net.IP, b []byte) error {
if from.To4() == nil {
// IPv6处理
msg, err := icmp.ParseMessage(ProtocolIPv6ICMP, b)
if err != nil {
if EnableLoger {
Logger.Warn("解析IPv6 ICMP消息失败: " + err.Error())
}
return err
}
// 记录所有收到的消息类型,帮助调试
if EnableLoger {
Logger.Info(fmt.Sprintf("收到IPv6 ICMP消息: 类型=%v, 代码=%v", msg.Type, msg.Code))
}
// 处理不同类型的ICMP消息
if msg.Type == ipv6.ICMPTypeEchoReply {
if echo, ok := msg.Body.(*icmp.Echo); ok {
if EnableLoger {
Logger.Info(fmt.Sprintf("处理IPv6回显应答: ID=%d, Seq=%d", echo.ID, echo.Seq))
}
return t.serveReply(from, &packet{from, uint16(echo.ID), 1, time.Now()})
}
} else if msg.Type == ipv6.ICMPTypeTimeExceeded {
// 时间超过这是traceroute的关键响应类型
b = getReplyData(msg)
if len(b) < ipv6.HeaderLen {
if EnableLoger {
Logger.Warn("IPv6时间超过消息太短")
}
return errMessageTooShort
}
// 解析原始IPv6包头
if b[0]>>4 == ipv6.Version {
ip, err := ipv6.ParseHeader(b)
if err != nil {
if EnableLoger {
Logger.Warn("解析IPv6头部失败: " + err.Error())
}
return err
}
if EnableLoger {
Logger.Info(fmt.Sprintf("处理IPv6时间超过: 目标=%v, FlowLabel=%d, HopLimit=%d",
ip.Dst, ip.FlowLabel, ip.HopLimit))
}
return t.serveReply(ip.Dst, &packet{from, uint16(ip.FlowLabel), ip.HopLimit, time.Now()})
}
}
} else {
// 原有的IPv4处理逻辑
msg, err := icmp.ParseMessage(ProtocolICMP, b)
if err != nil {
return err
}
if msg.Type == ipv4.ICMPTypeEchoReply {
echo := msg.Body.(*icmp.Echo)
return t.serveReply(from, &packet{from, uint16(echo.ID), 1, time.Now()})
}
b = getReplyData(msg)
if len(b) < ipv4.HeaderLen {
return errMessageTooShort
}
switch b[0] >> 4 {
case ipv4.Version:
ip, err := ipv4.ParseHeader(b)
if err != nil {
return err
}
return t.serveReply(ip.Dst, &packet{from, uint16(ip.ID), ip.TTL, time.Now()})
default:
return errUnsupportedProtocol
}
}
return nil
}
func (t *Tracer) sendRequest(dst net.IP, ttl int) (*packet, error) {
id := uint16(atomic.AddUint32(&t.seq, 1))
var b []byte
req := &packet{dst, id, ttl, time.Now()}
if dst.To4() == nil {
// IPv6
b := newPacketV6(id, dst, ttl)
if t.ipv6conn != nil {
cm := &ipv6.ControlMessage{
HopLimit: ttl,
}
_, err := t.ipv6conn.WriteTo(b, cm, &net.IPAddr{IP: dst})
if err != nil {
if EnableLoger {
InitLogger()
defer Logger.Sync()
Logger.Info("发送IPv6请求失败: " + err.Error())
}
return nil, err
}
return req, nil
}
return nil, errors.New("IPv6连接不可用")
} else {
// IPv4
b = newPacketV4(id, dst, ttl)
_, err := t.conn.WriteToIP(b, &net.IPAddr{IP: dst})
if err != nil {
return nil, err
}
return req, nil
}
}
func (t *Tracer) addSession(s *Session) {
t.mu.Lock()
defer t.mu.Unlock()
if t.sess == nil {
t.sess = make(map[string][]*Session)
}
t.sess[string(s.ip)] = append(t.sess[string(s.ip)], s)
}
func (t *Tracer) removeSession(s *Session) {
t.mu.Lock()
defer t.mu.Unlock()
a := t.sess[string(s.ip)]
for i, it := range a {
if it == s {
t.sess[string(s.ip)] = append(a[:i], a[i+1:]...)
return
}
}
}
// func (t *Tracer) serveReply(dst net.IP, res *packet) error {
// t.mu.RLock()
// defer t.mu.RUnlock()
// a := t.sess[string(shortIP(dst))]
// for _, s := range a {
// s.handle(res)
// }
// return nil
// }
func (t *Tracer) serveReply(dst net.IP, res *packet) error {
if EnableLoger {
Logger.Info(fmt.Sprintf("处理回复: 目标=%v, 来源=%v, ID=%d, TTL=%d",
dst, res.IP, res.ID, res.TTL))
}
// 确保使用正确的IP格式进行查找
shortDst := shortIP(dst)
t.mu.RLock()
defer t.mu.RUnlock()
// 调试输出会话信息
if EnableLoger && len(t.sess) > 0 {
for ip, sessions := range t.sess {
Logger.Info(fmt.Sprintf("会话信息: IP=%v, 会话数=%d",
net.IP([]byte(ip)), len(sessions)))
}
}
// 查找对应的会话
a := t.sess[string(shortDst)]
if len(a) == 0 && EnableLoger {
Logger.Warn(fmt.Sprintf("找不到目标IP=%v的会话", dst))
}
for _, s := range a {
if EnableLoger {
Logger.Info(fmt.Sprintf("处理会话响应: 会话目标=%v", s.ip))
}
s.handle(res)
}
return nil
}
// Session is a tracer session.
type Session struct {
t *Tracer
ip net.IP
ch chan *Reply
mu sync.RWMutex
probes []*packet
}
// NewSession returns new session.
func NewSession(ip net.IP) (*Session, error) {
return DefaultTracer.NewSession(ip)
}
func newSession(t *Tracer, ip net.IP) *Session {
s := &Session{
t: t,
ip: ip,
ch: make(chan *Reply, 64),
}
t.addSession(s)
return s
}
// Ping sends single ICMP packet with specified TTL.
func (s *Session) Ping(ttl int) error {
req, err := s.t.sendRequest(s.ip, ttl+1)
if err != nil {
return err
}
s.mu.Lock()
s.probes = append(s.probes, req)
s.mu.Unlock()
return nil
}
// Receive returns channel to receive ICMP replies.
func (s *Session) Receive() <-chan *Reply {
return s.ch
}
// isDone returns true if session does not have unresponsed requests with TTL <= ttl.
func (s *Session) isDone(ttl int) bool {
s.mu.RLock()
defer s.mu.RUnlock()
for _, r := range s.probes {
if r.TTL <= ttl {
return false
}
}
return true
}
// func (s *Session) handle(res *packet) {
// now := res.Time
// n := 0
// var req *packet
// s.mu.Lock()
// for _, r := range s.probes {
// if now.Sub(r.Time) > s.t.Timeout {
// continue
// }
// if r.ID == res.ID {
// req = r
// continue
// }
// s.probes[n] = r
// n++
// }
// s.probes = s.probes[:n]
// s.mu.Unlock()
// if req == nil {
// return
// }
// hops := req.TTL - res.TTL + 1
// if hops < 1 {
// hops = 1
// }
// select {
// case s.ch <- &Reply{
// IP: res.IP,
// RTT: res.Time.Sub(req.Time),
// Hops: hops,
// }:
// default:
// }
// }
func (s *Session) handle(res *packet) {
now := res.Time
n := 0
var req *packet
if EnableLoger {
Logger.Info(fmt.Sprintf("处理会话响应: 会话目标=%v, 响应源=%v, ID=%d, TTL=%d",
s.ip, res.IP, res.ID, res.TTL))
}
s.mu.Lock()
// 打印出所有待处理的探测包
if EnableLoger && len(s.probes) > 0 {
Logger.Info(fmt.Sprintf("当前会话有 %d 个待处理的探测包", len(s.probes)))
for i, probe := range s.probes {
Logger.Info(fmt.Sprintf("探测包 #%d: ID=%d, TTL=%d, 时间=%v",
i, probe.ID, probe.TTL, probe.Time))
}
}
// 查找匹配的请求包
for _, r := range s.probes {
if now.Sub(r.Time) > s.t.Timeout {
if EnableLoger {
Logger.Info(fmt.Sprintf("探测包超时: ID=%d, TTL=%d", r.ID, r.TTL))
}
continue
}
// 对于IPv6可能需要松散匹配
if r.ID == res.ID || (res.IP.To4() == nil && res.ID == 0) {
if EnableLoger {
Logger.Info(fmt.Sprintf("找到匹配的探测包: ID=%d, TTL=%d", r.ID, r.TTL))
}
req = r
continue
}
s.probes[n] = r
n++
}
s.probes = s.probes[:n]
s.mu.Unlock()
if req == nil {
if EnableLoger {
Logger.Warn(fmt.Sprintf("未找到匹配的探测包: 响应ID=%d", res.ID))
}
return
}
hops := req.TTL - res.TTL + 1
if hops < 1 {
hops = 1
}
if EnableLoger {
Logger.Info(fmt.Sprintf("创建响应: IP=%v, RTT=%v, Hops=%d",
res.IP, res.Time.Sub(req.Time), hops))
}
select {
case s.ch <- &Reply{
IP: res.IP,
RTT: res.Time.Sub(req.Time),
Hops: hops,
}:
if EnableLoger {
Logger.Info("响应已发送到通道")
}
default:
if EnableLoger {
Logger.Warn("发送响应到通道失败,通道已满")
}
}
}
// Close closes tracer session.
func (s *Session) Close() {
s.t.removeSession(s)
}
type packet struct {
IP net.IP
ID uint16
TTL int
Time time.Time
}
func shortIP(ip net.IP) net.IP {
if v := ip.To4(); v != nil {
return v
}
return ip
}
func getReplyData(msg *icmp.Message) []byte {
switch b := msg.Body.(type) {
case *icmp.TimeExceeded:
return b.Data
case *icmp.DstUnreach:
return b.Data
case *icmp.ParamProb:
return b.Data
}
return nil
}
var (
errMessageTooShort = errors.New("message too short")
errUnsupportedProtocol = errors.New("unsupported protocol")
errNoReplyData = errors.New("no reply data")
)
// IANA Assigned Internet Protocol Numbers
const (
ProtocolICMP = 1
ProtocolTCP = 6
ProtocolUDP = 17
ProtocolIPv6ICMP = 58
)
// Reply is a reply packet.
type Reply struct {
IP net.IP
RTT time.Duration
Hops int
}
// Node is a detected network node.
type Node struct {
IP net.IP
RTT []time.Duration
}
// Hop is a set of detected nodes.
type Hop struct {
Nodes []*Node
Distance int
}
// Add adds node from r.
func (h *Hop) Add(r *Reply) *Node {
var node *Node
for _, it := range h.Nodes {
if it.IP.Equal(r.IP) {
node = it
break
}
}
if node == nil {
node = &Node{IP: r.IP}
h.Nodes = append(h.Nodes, node)
}
node.RTT = append(node.RTT, r.RTT)
return node
}
// Trace is a simple traceroute tool using DefaultTracer.
func Trace(ip net.IP) ([]*Hop, error) {
hops := make([]*Hop, 0, DefaultTracer.MaxHops)
touch := func(dist int) *Hop {
for _, h := range hops {
if h.Distance == dist {
return h
}
}
h := &Hop{Distance: dist}
hops = append(hops, h)
return h
}
err := DefaultTracer.Trace(context.Background(), ip, func(r *Reply) {
touch(r.Hops).Add(r)
})
if err != nil && err != context.DeadlineExceeded {
return nil, err
}
sort.Slice(hops, func(i, j int) bool {
return hops[i].Distance < hops[j].Distance
})
last := len(hops) - 1
for i := last; i >= 0; i-- {
h := hops[i]
if len(h.Nodes) == 1 && ip.Equal(h.Nodes[0].IP) {
continue
}
if i == last {
break
}
i++
node := hops[i].Nodes[0]
i++
for _, it := range hops[i:] {
node.RTT = append(node.RTT, it.Nodes[0].RTT...)
}
hops = hops[:i]
break
}
return hops, nil
}