// MIT License @LiamHaworth // https://github.com/LiamHaworth/go-tproxy/blob/master/tproxy_udp.go package tproxy import ( "bytes" "encoding/binary" "fmt" "net" "os" "strconv" "syscall" "unsafe" ) var nativeEndian binary.ByteOrder func init() { buf := [2]byte{} *(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0xABCD) switch buf { case [2]byte{0xCD, 0xAB}: nativeEndian = binary.LittleEndian case [2]byte{0xAB, 0xCD}: nativeEndian = binary.BigEndian default: panic("Could not determine native endianness.") } } // ListenUDP will construct a new UDP listener // socket with the Linux IP_TRANSPARENT option // set on the underlying socket func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) { listener, err := net.ListenUDP(network, laddr) if err != nil { return nil, err } fileDescriptorSource, err := listener.File() if err != nil { return nil, &net.OpError{Op: "listen", Net: network, Source: nil, Addr: laddr, Err: fmt.Errorf("get file descriptor: %s", err)} } defer fileDescriptorSource.Close() fileDescriptor := int(fileDescriptorSource.Fd()) if err = syscall.SetsockoptInt(fileDescriptor, syscall.SOL_IP, syscall.IP_TRANSPARENT, 1); err != nil { return nil, &net.OpError{Op: "listen", Net: network, Source: nil, Addr: laddr, Err: fmt.Errorf("set socket option: IP_TRANSPARENT: %s", err)} } if err = syscall.SetsockoptInt(fileDescriptor, syscall.SOL_IP, syscall.IP_RECVORIGDSTADDR, 1); err != nil { return nil, &net.OpError{Op: "listen", Net: network, Source: nil, Addr: laddr, Err: fmt.Errorf("set socket option: IP_RECVORIGDSTADDR: %s", err)} } return listener, nil } // ReadFromUDP reads a UDP packet from c, copying the payload into b. // It returns the number of bytes copied into b and the return address // that was on the packet. // // Out-of-band data is also read in so that the original destination // address can be identified and parsed. 
func ReadFromUDP(conn *net.UDPConn, b []byte) (int, *net.UDPAddr, *net.UDPAddr, error) { oob := make([]byte, 1024) n, oobn, _, addr, err := conn.ReadMsgUDP(b, oob) if err != nil { return 0, nil, nil, err } msgs, err := syscall.ParseSocketControlMessage(oob[:oobn]) if err != nil { return 0, nil, nil, fmt.Errorf("parsing socket control message: %s", err) } var originalDst *net.UDPAddr for _, msg := range msgs { if msg.Header.Level == syscall.SOL_IP && msg.Header.Type == syscall.IP_RECVORIGDSTADDR { originalDstRaw := &syscall.RawSockaddrInet4{} if err = binary.Read(bytes.NewReader(msg.Data), nativeEndian, originalDstRaw); err != nil { return 0, nil, nil, fmt.Errorf("reading original destination address: %s", err) } switch originalDstRaw.Family { case syscall.AF_INET: pp := (*syscall.RawSockaddrInet4)(unsafe.Pointer(originalDstRaw)) p := (*[2]byte)(unsafe.Pointer(&pp.Port)) originalDst = &net.UDPAddr{ IP: net.IPv4(pp.Addr[0], pp.Addr[1], pp.Addr[2], pp.Addr[3]), Port: int(p[0])<<8 + int(p[1]), } case syscall.AF_INET6: pp := (*syscall.RawSockaddrInet6)(unsafe.Pointer(originalDstRaw)) p := (*[2]byte)(unsafe.Pointer(&pp.Port)) originalDst = &net.UDPAddr{ IP: net.IP(pp.Addr[:]), Port: int(p[0])<<8 + int(p[1]), Zone: strconv.Itoa(int(pp.Scope_id)), } default: return 0, nil, nil, fmt.Errorf("original destination is an unsupported network family") } } } if originalDst == nil { return 0, nil, nil, fmt.Errorf("unable to obtain original destination: %s", err) } return n, addr, originalDst, nil } // DialUDP connects to the remote address raddr on the network net, // which must be "udp", "udp4", or "udp6". If laddr is not nil, it is // used as the local address for the connection. 
func DialUDP(network string, laddr *net.UDPAddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
	remoteSocketAddress, err := udpAddrToSocketAddr(raddr)
	if err != nil {
		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("build destination socket address: %w", err)}
	}

	// A nil laddr means no explicit local address: the bind step is skipped.
	var localSocketAddress syscall.Sockaddr
	if laddr != nil {
		if localSocketAddress, err = udpAddrToSocketAddr(laddr); err != nil {
			return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("build local socket address: %w", err)}
		}
	}

	fileDescriptor, err := syscall.Socket(udpAddrFamily(network, laddr, raddr), syscall.SOCK_DGRAM, 0)
	if err != nil {
		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket open: %s", err)}
	}

	// Hand the descriptor to an os.File straight away so that it is closed
	// exactly once on every return path. net.FileConn works on its own
	// duplicate, so closing fdFile does not affect the returned connection.
	fdFile := os.NewFile(uintptr(fileDescriptor), fmt.Sprintf("net-udp-dial-%s", raddr.String()))
	defer fdFile.Close()

	if err = syscall.SetsockoptInt(fileDescriptor, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil {
		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: SO_REUSEADDR: %w", err)}
	}

	if err = syscall.SetsockoptInt(fileDescriptor, syscall.SOL_IP, syscall.IP_TRANSPARENT, 1); err != nil {
		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("set socket option: IP_TRANSPARENT: %w", err)}
	}

	if localSocketAddress != nil {
		if err = syscall.Bind(fileDescriptor, localSocketAddress); err != nil {
			return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket bind: %w", err)}
		}
	}

	if err = syscall.Connect(fileDescriptor, remoteSocketAddress); err != nil {
		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("socket connect: %w", err)}
	}

	remoteConn, err := net.FileConn(fdFile)
	if err != nil {
		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("convert file descriptor to connection: %w", err)}
	}

	udpConn, ok := remoteConn.(*net.UDPConn)
	if !ok {
		remoteConn.Close()
		return nil, &net.OpError{Op: "dial", Err: fmt.Errorf("convert file descriptor to connection: unexpected connection type %T", remoteConn)}
	}
	return udpConn, nil
}

// udpAddrToSocketAddr will convert a UDPAddr into a Sockaddr that may be
// used when connecting and binding sockets.
//
// An address whose IP has a 4-byte form becomes a SockaddrInet4; any
// other address, including one with an empty IP, becomes a SockaddrInet6.
// A nil addr is reported as an error.
func udpAddrToSocketAddr(addr *net.UDPAddr) (syscall.Sockaddr, error) {
	if addr == nil {
		return nil, fmt.Errorf("missing address")
	}

	if ip4 := addr.IP.To4(); ip4 != nil {
		sockAddr := &syscall.SockaddrInet4{Port: addr.Port}
		copy(sockAddr.Addr[:], ip4)
		return sockAddr, nil
	}

	zoneID, err := zoneToScopeID(addr.Zone)
	if err != nil {
		return nil, err
	}
	sockAddr := &syscall.SockaddrInet6{Port: addr.Port, ZoneId: zoneID}
	copy(sockAddr.Addr[:], addr.IP.To16())
	return sockAddr, nil
}

// zoneToScopeID converts the Zone of a UDPAddr into the numeric zone id
// stored in a SockaddrInet6. An empty zone maps to 0, a decimal zone is
// used as is, and anything else is looked up as a network interface name.
func zoneToScopeID(zone string) (uint32, error) {
	if zone == "" {
		return 0, nil
	}
	if id, err := strconv.ParseUint(zone, 10, 32); err == nil {
		return uint32(id), nil
	}
	iface, err := net.InterfaceByName(zone)
	if err != nil {
		return 0, fmt.Errorf("resolve zone %q: %w", zone, err)
	}
	return uint32(iface.Index), nil
}

// udpAddrFamily will attempt to work out the address family based on the
// network and UDP addresses.
//
// A network name ending in '4' or '6' decides the family outright.
// Otherwise AF_INET is chosen when every non-nil address has a 4-byte
// form, and AF_INET6 in all other cases.
func udpAddrFamily(network string, laddr, raddr *net.UDPAddr) int {
	if n := len(network); n > 0 {
		switch network[n-1] {
		case '4':
			return syscall.AF_INET
		case '6':
			return syscall.AF_INET6
		}
	}

	if (laddr == nil || laddr.IP.To4() != nil) &&
		(raddr == nil || raddr.IP.To4() != nil) {
		return syscall.AF_INET
	}
	return syscall.AF_INET6
}