// Package nat provides NAT type detection via UDP and TCP STUN. package nat import ( "encoding/json" "fmt" "net" "time" "github.com/openp2p-cn/inp2p/pkg/protocol" ) const ( detectTimeout = 5 * time.Second ) // DetectResult holds the NAT detection outcome. type DetectResult struct { Type protocol.NATType PublicIP string Port1 int // external port seen on STUN server port 1 Port2 int // external port seen on STUN server port 2 LocalPort int // local UDP port used for detection (for punch bind) } // stunReq is sent to the STUN endpoint. type stunReq struct { ID int `json:"id"` } // stunRsp is received from the STUN endpoint. type stunRsp struct { IP string `json:"ip"` Port int `json:"port"` ID int `json:"id"` } // DetectUDP sends probes from the same local port to two different server // UDP ports. If both return the same external port → Cone; different → Symmetric. func DetectUDP(serverIP string, port1, port2 int) DetectResult { result := DetectResult{Type: protocol.NATUnknown} // Bind a single local UDP port conn, err := net.ListenPacket("udp", ":0") if err != nil { return result } defer conn.Close() if ua, ok := conn.LocalAddr().(*net.UDPAddr); ok { result.LocalPort = ua.Port } r1, err1 := probeUDP(conn, serverIP, port1, 1) r2, err2 := probeUDP(conn, serverIP, port2, 2) if err1 != nil || err2 != nil { return result // timeout → NATUnknown } result.PublicIP = r1.IP result.Port1 = r1.Port result.Port2 = r2.Port if r1.Port == r2.Port { result.Type = protocol.NATCone } else { result.Type = protocol.NATSymmetric } // Check if public IP equals local IP → no NAT localIP := conn.LocalAddr().(*net.UDPAddr).IP.String() if localIP == r1.IP || r1.IP == "" { // might be public } return result } func probeUDP(conn net.PacketConn, serverIP string, port, id int) (stunRsp, error) { addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", serverIP, port)) if err != nil { return stunRsp{}, err } frame, _ := protocol.Encode(protocol.MsgNAT, protocol.SubNATDetectReq, stunReq{ID: id}) conn.SetWriteDeadline(time.Now().Add(detectTimeout)) if _, err := conn.WriteTo(frame, addr); err != nil { return stunRsp{}, err } buf := make([]byte, 1024) conn.SetReadDeadline(time.Now().Add(detectTimeout)) n, _, err := conn.ReadFrom(buf) if err != nil { return stunRsp{}, err } var rsp stunRsp if n > protocol.HeaderSize { json.Unmarshal(buf[protocol.HeaderSize:n], &rsp) } return rsp, nil } // DetectTCP connects to two different TCP ports on the server and compares // the observed external port. This is the fallback when UDP is blocked. func DetectTCP(serverIP string, port1, port2 int) DetectResult { result := DetectResult{Type: protocol.NATUnknown} r1, err1 := probeTCP(serverIP, port1, 1) r2, err2 := probeTCP(serverIP, port2, 2) if err1 != nil || err2 != nil { return result } result.PublicIP = r1.IP result.Port1 = r1.Port result.Port2 = r2.Port if r1.Port == r2.Port { result.Type = protocol.NATCone } else { result.Type = protocol.NATSymmetric } return result } func probeTCP(serverIP string, port, id int) (stunRsp, error) { conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", serverIP, port), detectTimeout) if err != nil { return stunRsp{}, err } defer conn.Close() frame, _ := protocol.Encode(protocol.MsgNAT, protocol.SubNATDetectReq, stunReq{ID: id}) conn.SetWriteDeadline(time.Now().Add(detectTimeout)) if _, err := conn.Write(frame); err != nil { return stunRsp{}, err } buf := make([]byte, 1024) conn.SetReadDeadline(time.Now().Add(detectTimeout)) n, err := conn.Read(buf) if err != nil { return stunRsp{}, err } var rsp stunRsp if n > protocol.HeaderSize { json.Unmarshal(buf[protocol.HeaderSize:n], &rsp) } return rsp, nil } // Detect runs UDP detection first, falls back to TCP if UDP is blocked. func Detect(serverIP string, udpPort1, udpPort2, tcpPort1, tcpPort2 int) DetectResult { result := DetectUDP(serverIP, udpPort1, udpPort2) if result.Type != protocol.NATUnknown { return result } // UDP blocked, fallback to TCP return DetectTCP(serverIP, tcpPort1, tcpPort2) } // ─── Server-side STUN handler ─── // ServeUDPSTUN listens on a UDP port and echoes back the sender's observed IP:port. func ServeUDPSTUN(port int, quit <-chan struct{}) error { addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port)) if err != nil { return err } conn, err := net.ListenUDP("udp", addr) if err != nil { return err } defer conn.Close() go func() { <-quit conn.Close() }() buf := make([]byte, 1024) for { n, remoteAddr, err := conn.ReadFromUDP(buf) if err != nil { select { case <-quit: return nil default: continue } } // Parse request var req stunReq if n > protocol.HeaderSize { json.Unmarshal(buf[protocol.HeaderSize:n], &req) } // Reply with observed address rsp := stunRsp{ IP: remoteAddr.IP.String(), Port: remoteAddr.Port, ID: req.ID, } frame, _ := protocol.Encode(protocol.MsgNAT, protocol.SubNATDetectRsp, rsp) conn.WriteToUDP(frame, remoteAddr) } } // ServeTCPSTUN listens on a TCP port. Each connection: read one req, write one rsp with observed addr. func ServeTCPSTUN(port int, quit <-chan struct{}) error { ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) if err != nil { return err } defer ln.Close() go func() { <-quit ln.Close() }() for { conn, err := ln.Accept() if err != nil { select { case <-quit: return nil default: continue } } go func(c net.Conn) { defer c.Close() remoteAddr := c.RemoteAddr().(*net.TCPAddr) buf := make([]byte, 1024) c.SetReadDeadline(time.Now().Add(10 * time.Second)) n, err := c.Read(buf) if err != nil { return } var req stunReq if n > protocol.HeaderSize { json.Unmarshal(buf[protocol.HeaderSize:n], &req) } rsp := stunRsp{ IP: remoteAddr.IP.String(), Port: remoteAddr.Port, ID: req.ID, } frame, _ := protocol.Encode(protocol.MsgNAT, protocol.SubNATDetectRsp, rsp) c.SetWriteDeadline(time.Now().Add(5 * time.Second)) c.Write(frame) }(conn) } }