feat: INP2P v0.1.0 — complete P2P tunneling system
Core modules (M1-M6): - pkg/protocol: message format, encoding, NAT type enums - pkg/config: server/client config structs, env vars, validation - pkg/auth: CRC64 token, TOTP gen/verify, one-time relay tokens - pkg/nat: UDP/TCP STUN client and server - pkg/signal: WSS message dispatch, sync request/response - pkg/punch: UDP/TCP hole punching + priority chain - pkg/mux: stream multiplexer (7B frame: StreamID+Flags+Len) - pkg/tunnel: mux-based port forwarding with stats - pkg/relay: relay manager with TOTP auth + session bridging - internal/server: signaling server (login/heartbeat/report/coordinator) - internal/client: client (NAT detect/login/punch/relay/reconnect) - cmd/inp2ps + cmd/inp2pc: main entrypoints with graceful shutdown All tests pass: 16 tests across 5 packages Code: 3559 lines core + 861 lines tests = 19 source files
This commit is contained in:
92
pkg/auth/auth.go
Normal file
92
pkg/auth/auth.go
Normal file
@@ -0,0 +1,92 @@
|
||||
// Package auth provides TOTP and token authentication for INP2P.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"hash/crc64"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
	// TOTPStep is the time window in seconds for TOTP validity.
	// A code is valid for ±1 step to allow for clock drift.
	TOTPStep int64 = 60
)

var crcTable = crc64.MakeTable(crc64.ECMA)

// MakeToken derives a 64-bit token from the user/password pair via CRC64.
// NOTE(review): CRC64 is not a cryptographic hash; it is fine as an
// identifier but must not be relied on as a secret-strength credential.
func MakeToken(user, password string) uint64 {
	return crc64.Checksum([]byte(user+password), crcTable)
}

// GenTOTP generates a TOTP code for relay authentication. The code is
// HMAC-SHA256(key = big-endian token, msg = big-endian time step),
// truncated to the first 64 bits of the digest.
func GenTOTP(token uint64, ts int64) uint64 {
	var key, msg [8]byte
	binary.BigEndian.PutUint64(key[:], token)
	binary.BigEndian.PutUint64(msg[:], uint64(ts/TOTPStep))

	mac := hmac.New(sha256.New, key[:])
	mac.Write(msg[:])
	digest := mac.Sum(nil)

	return binary.BigEndian.Uint64(digest[:8])
}

// VerifyTOTP reports whether code matches the expected TOTP for token at
// time ts, accepting the previous and next step to tolerate clock drift.
func VerifyTOTP(code uint64, token uint64, ts int64) bool {
	for _, delta := range []int64{-1, 0, 1} {
		if GenTOTP(token, ts+delta*TOTPStep) == code {
			return true
		}
	}
	return false
}
|
||||
|
||||
// RelayToken generates a one-time relay token signed by the server.
|
||||
// Used for cross-user super relay authentication.
|
||||
type RelayToken struct {
|
||||
SessionID string `json:"sessionID"`
|
||||
From string `json:"from"`
|
||||
To string `json:"to"`
|
||||
Relay string `json:"relay"`
|
||||
Expires int64 `json:"expires"`
|
||||
Signature []byte `json:"signature"`
|
||||
}
|
||||
|
||||
// SignRelayToken creates a signed one-time relay token.
|
||||
func SignRelayToken(secret []byte, sessionID, from, to, relay string, ttl time.Duration) RelayToken {
|
||||
rt := RelayToken{
|
||||
SessionID: sessionID,
|
||||
From: from,
|
||||
To: to,
|
||||
Relay: relay,
|
||||
Expires: time.Now().Add(ttl).Unix(),
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("%s:%s:%s:%s:%d", rt.SessionID, rt.From, rt.To, rt.Relay, rt.Expires)
|
||||
mac := hmac.New(sha256.New, secret)
|
||||
mac.Write([]byte(msg))
|
||||
rt.Signature = mac.Sum(nil)
|
||||
|
||||
return rt
|
||||
}
|
||||
|
||||
// VerifyRelayToken validates a signed relay token.
|
||||
func VerifyRelayToken(secret []byte, rt RelayToken) bool {
|
||||
if time.Now().Unix() > rt.Expires {
|
||||
return false
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("%s:%s:%s:%s:%d", rt.SessionID, rt.From, rt.To, rt.Relay, rt.Expires)
|
||||
mac := hmac.New(sha256.New, secret)
|
||||
mac.Write([]byte(msg))
|
||||
expected := mac.Sum(nil)
|
||||
|
||||
return hmac.Equal(rt.Signature, expected)
|
||||
}
|
||||
161
pkg/config/config.go
Normal file
161
pkg/config/config.go
Normal file
@@ -0,0 +1,161 @@
|
||||
// Package config provides shared configuration types.
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
const (
	// Version is the INP2P release version string.
	Version = "0.1.0"

	DefaultWSPort   = 27183 // WSS signaling (TCP; shares the number with DefaultSTUNUDP2, which is UDP — TODO confirm intentional)
	DefaultSTUNUDP1 = 27182 // UDP STUN port 1
	DefaultSTUNUDP2 = 27183 // UDP STUN port 2
	DefaultSTUNTCP1 = 27180 // TCP STUN port 1
	DefaultSTUNTCP2 = 27181 // TCP STUN port 2
	DefaultWebPort  = 10088 // Web console
	DefaultAPIPort  = 10008 // REST API

	// Relay defaults used by DefaultClientConfig.
	DefaultMaxRelayLoad = 20
	DefaultRelayPort    = 27185

	HeartbeatInterval = 30 // seconds
	HeartbeatTimeout  = 90 // seconds — 3x missed heartbeats → offline
)
|
||||
|
||||
// ServerConfig holds inp2ps configuration.
type ServerConfig struct {
	WSPort   int `json:"wsPort"`   // WSS signaling listen port
	STUNUDP1 int `json:"stunUDP1"` // UDP STUN port 1
	STUNUDP2 int `json:"stunUDP2"` // UDP STUN port 2
	STUNTCP1 int `json:"stunTCP1"` // TCP STUN port 1
	STUNTCP2 int `json:"stunTCP2"` // TCP STUN port 2
	WebPort  int `json:"webPort"`  // web console listen port
	APIPort  int `json:"apiPort"`  // REST API listen port
	DBPath   string `json:"dbPath"`
	CertFile string `json:"certFile"` // TLS certificate path (empty → see server setup)
	KeyFile  string `json:"keyFile"`  // TLS key path
	LogLevel int    `json:"logLevel"` // 0=debug, 1=info, 2=warn, 3=error
	Token    uint64 `json:"token"`    // master token for auth (required; see Validate)
	JWTKey   string `json:"jwtKey"`   // auto-generated if empty (see FillFromEnv)

	// Default admin credentials; NOTE(review): the "admin123" default from
	// DefaultServerConfig should be changed in production deployments.
	AdminUser string `json:"adminUser"`
	AdminPass string `json:"adminPass"`
}
|
||||
|
||||
func DefaultServerConfig() ServerConfig {
|
||||
return ServerConfig{
|
||||
WSPort: DefaultWSPort,
|
||||
STUNUDP1: DefaultSTUNUDP1,
|
||||
STUNUDP2: DefaultSTUNUDP2,
|
||||
STUNTCP1: DefaultSTUNTCP1,
|
||||
STUNTCP2: DefaultSTUNTCP2,
|
||||
WebPort: DefaultWebPort,
|
||||
APIPort: DefaultAPIPort,
|
||||
DBPath: "inp2ps.db",
|
||||
LogLevel: 1,
|
||||
AdminUser: "admin",
|
||||
AdminPass: "admin123",
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ServerConfig) FillFromEnv() {
|
||||
if v := os.Getenv("INP2PS_WS_PORT"); v != "" {
|
||||
c.WSPort, _ = strconv.Atoi(v)
|
||||
}
|
||||
if v := os.Getenv("INP2PS_WEB_PORT"); v != "" {
|
||||
c.WebPort, _ = strconv.Atoi(v)
|
||||
}
|
||||
if v := os.Getenv("INP2PS_DB_PATH"); v != "" {
|
||||
c.DBPath = v
|
||||
}
|
||||
if v := os.Getenv("INP2PS_TOKEN"); v != "" {
|
||||
c.Token, _ = strconv.ParseUint(v, 10, 64)
|
||||
}
|
||||
if v := os.Getenv("INP2PS_CERT"); v != "" {
|
||||
c.CertFile = v
|
||||
}
|
||||
if v := os.Getenv("INP2PS_KEY"); v != "" {
|
||||
c.KeyFile = v
|
||||
}
|
||||
if c.JWTKey == "" {
|
||||
b := make([]byte, 32)
|
||||
rand.Read(b)
|
||||
c.JWTKey = hex.EncodeToString(b)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ServerConfig) Validate() error {
|
||||
if c.Token == 0 {
|
||||
return fmt.Errorf("token is required (INP2PS_TOKEN or -token)")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ClientConfig holds inp2pc configuration.
type ClientConfig struct {
	ServerHost string `json:"serverHost"` // signaling server host (required; see Validate)
	ServerPort int    `json:"serverPort"`
	Node       string `json:"node"` // node name; defaults to os.Hostname() in Validate
	Token      uint64 `json:"token"`
	User       string `json:"user,omitempty"`
	Insecure   bool   `json:"insecure"` // skip TLS verify

	// STUN ports (defaults match server defaults)
	STUNUDP1 int `json:"stunUDP1,omitempty"`
	STUNUDP2 int `json:"stunUDP2,omitempty"`
	STUNTCP1 int `json:"stunTCP1,omitempty"`
	STUNTCP2 int `json:"stunTCP2,omitempty"`

	RelayEnabled bool `json:"relayEnabled"` // --relay
	SuperRelay   bool `json:"superRelay"`   // --super
	RelayPort    int  `json:"relayPort"`
	MaxRelayLoad int  `json:"maxRelayLoad"`

	ShareBandwidth int `json:"shareBandwidth"` // Mbps
	LogLevel       int `json:"logLevel"`       // 0=debug, 1=info, 2=warn, 3=error

	Apps []AppConfig `json:"apps"` // port-forwarding app definitions
}
|
||||
|
||||
// AppConfig describes one port-forwarding application: traffic arriving on
// SrcPort is tunneled to DstHost:DstPort via PeerNode.
type AppConfig struct {
	AppName  string `json:"appName"`
	Protocol string `json:"protocol"` // tcp, udp
	SrcPort  int    `json:"srcPort"`  // local listen port
	PeerNode string `json:"peerNode"` // remote node name
	DstHost  string `json:"dstHost"`  // destination host as resolved by the peer
	DstPort  int    `json:"dstPort"`
	Enabled  bool   `json:"enabled"`
}
|
||||
|
||||
func DefaultClientConfig() ClientConfig {
|
||||
return ClientConfig{
|
||||
ServerPort: DefaultWSPort,
|
||||
STUNUDP1: DefaultSTUNUDP1,
|
||||
STUNUDP2: DefaultSTUNUDP2,
|
||||
STUNTCP1: DefaultSTUNTCP1,
|
||||
STUNTCP2: DefaultSTUNTCP2,
|
||||
ShareBandwidth: 10,
|
||||
RelayPort: DefaultRelayPort,
|
||||
MaxRelayLoad: DefaultMaxRelayLoad,
|
||||
LogLevel: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientConfig) Validate() error {
|
||||
if c.ServerHost == "" {
|
||||
return fmt.Errorf("serverHost is required")
|
||||
}
|
||||
if c.Token == 0 {
|
||||
return fmt.Errorf("token is required")
|
||||
}
|
||||
if c.Node == "" {
|
||||
hostname, _ := os.Hostname()
|
||||
c.Node = hostname
|
||||
}
|
||||
return nil
|
||||
}
|
||||
487
pkg/mux/mux.go
Normal file
487
pkg/mux/mux.go
Normal file
@@ -0,0 +1,487 @@
|
||||
// Package mux provides stream multiplexing over a single net.Conn.
|
||||
//
|
||||
// Wire format per frame:
|
||||
//
|
||||
// StreamID (4B, big-endian)
|
||||
// Flags (1B)
|
||||
// Length (2B, big-endian, max 65535)
|
||||
// Data (Length bytes)
|
||||
//
|
||||
// Total header = 7 bytes.
|
||||
//
|
||||
// Flags:
|
||||
//
|
||||
// 0x01 SYN — open a new stream
|
||||
// 0x02 FIN — close a stream
|
||||
// 0x04 DATA — payload data
|
||||
// 0x08 PING — keepalive (StreamID=0)
|
||||
// 0x10 PONG — keepalive response (StreamID=0)
|
||||
// 0x20 RST — reset/abort a stream
|
||||
package mux
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
	headerSize = 7     // StreamID(4) + Flags(1) + Length(2)
	maxPayload = 65535 // Length is a uint16

	FlagSYN  byte = 0x01
	FlagFIN  byte = 0x02
	FlagDATA byte = 0x04
	FlagPING byte = 0x08
	FlagPONG byte = 0x10
	FlagRST  byte = 0x20

	defaultWindowSize = 256 * 1024 // 256KB per stream receive buffer
	pingInterval      = 15 * time.Second
	// NOTE(review): pingTimeout is declared but not enforced anywhere in
	// this file — a dead peer is only detected when a write fails. Confirm
	// whether a pong-deadline check is planned.
	pingTimeout   = 10 * time.Second
	acceptBacklog = 64 // capacity of the Accept channel before SYNs are RST'd
)
|
||||
|
||||
// Sentinel errors returned by Session and Stream operations.
// NOTE(review): ErrTimeout and ErrAcceptBacklog are not returned by any
// code visible in this file — confirm they are used elsewhere.
var (
	ErrSessionClosed = errors.New("mux: session closed")
	ErrStreamClosed  = errors.New("mux: stream closed")
	ErrStreamReset   = errors.New("mux: stream reset by peer")
	ErrTimeout       = errors.New("mux: timeout")
	ErrAcceptBacklog = errors.New("mux: accept backlog full")
)
|
||||
|
||||
// ─── Session ───
// A Session multiplexes many Streams over a single underlying net.Conn.

type Session struct {
	conn     net.Conn
	streams  map[uint32]*Stream // live streams keyed by ID; guarded by mu
	mu       sync.RWMutex
	nextID   uint32 // next ID to allocate (atomic); client uses odd, server uses even
	isServer bool
	acceptCh chan *Stream // peer-opened streams awaiting Accept (capacity acceptBacklog)
	writeMu  sync.Mutex   // serialize frame writes
	closed   int32         // 1 once Close has run (atomic)
	quit     chan struct{} // closed on shutdown; unblocks Accept and Stream.Read
	once     sync.Once     // guards one-time teardown in Close

	// stats (updated atomically; counts include the 7-byte frame headers)
	BytesSent     int64
	BytesReceived int64
}
|
||||
|
||||
// NewSession wraps a net.Conn as a mux session.
|
||||
// isServer determines stream ID allocation: server=even, client=odd.
|
||||
func NewSession(conn net.Conn, isServer bool) *Session {
|
||||
s := &Session{
|
||||
conn: conn,
|
||||
streams: make(map[uint32]*Stream),
|
||||
acceptCh: make(chan *Stream, acceptBacklog),
|
||||
quit: make(chan struct{}),
|
||||
isServer: isServer,
|
||||
}
|
||||
if isServer {
|
||||
s.nextID = 2
|
||||
} else {
|
||||
s.nextID = 1
|
||||
}
|
||||
go s.readLoop()
|
||||
go s.pingLoop()
|
||||
return s
|
||||
}
|
||||
|
||||
// Open creates a new outbound stream and announces it to the peer with a
// SYN frame. Returns ErrSessionClosed if the session is already down.
func (s *Session) Open() (*Stream, error) {
	if s.IsClosed() {
		return nil, ErrSessionClosed
	}

	// AddUint32 returns the new value, so subtracting 2 yields the value
	// before the increment; stepping by 2 preserves the odd/even split
	// between client and server.
	id := atomic.AddUint32(&s.nextID, 2) - 2
	st := newStream(id, s)

	s.mu.Lock()
	s.streams[id] = st
	s.mu.Unlock()

	// Send SYN; on failure roll the registration back so the map does not
	// retain a dead stream.
	if err := s.writeFrame(id, FlagSYN, nil); err != nil {
		s.mu.Lock()
		delete(s.streams, id)
		s.mu.Unlock()
		return nil, err
	}

	return st, nil
}
|
||||
|
||||
// Accept waits for an inbound stream opened by the remote side.
|
||||
func (s *Session) Accept() (*Stream, error) {
|
||||
select {
|
||||
case st := <-s.acceptCh:
|
||||
return st, nil
|
||||
case <-s.quit:
|
||||
return nil, ErrSessionClosed
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down the session exactly once: it marks the session closed,
// signals all waiters via quit, unblocks every stream, and closes the
// underlying conn. Safe to call concurrently; always returns nil.
func (s *Session) Close() error {
	s.once.Do(func() {
		atomic.StoreInt32(&s.closed, 1)
		close(s.quit)

		s.mu.Lock()
		for _, st := range s.streams {
			st.closeLocal()
		}
		s.streams = make(map[uint32]*Stream)
		s.mu.Unlock()

		s.conn.Close()
	})
	return nil
}
|
||||
|
||||
// IsClosed reports if the session is closed.
|
||||
func (s *Session) IsClosed() bool {
|
||||
return atomic.LoadInt32(&s.closed) == 1
|
||||
}
|
||||
|
||||
// NumStreams returns active stream count.
|
||||
func (s *Session) NumStreams() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.streams)
|
||||
}
|
||||
|
||||
// ─── Frame I/O ───
|
||||
|
||||
func (s *Session) writeFrame(streamID uint32, flags byte, data []byte) error {
|
||||
if len(data) > maxPayload {
|
||||
return fmt.Errorf("mux: payload too large: %d > %d", len(data), maxPayload)
|
||||
}
|
||||
|
||||
hdr := make([]byte, headerSize)
|
||||
binary.BigEndian.PutUint32(hdr[0:4], streamID)
|
||||
hdr[4] = flags
|
||||
binary.BigEndian.PutUint16(hdr[5:7], uint16(len(data)))
|
||||
|
||||
s.writeMu.Lock()
|
||||
defer s.writeMu.Unlock()
|
||||
|
||||
s.conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
|
||||
if _, err := s.conn.Write(hdr); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(data) > 0 {
|
||||
if _, err := s.conn.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
atomic.AddInt64(&s.BytesSent, int64(headerSize+len(data)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// readLoop is the session's single reader goroutine: it decodes frames off
// the conn and dispatches them via handleFrame until a read error or
// session close tears the session down.
func (s *Session) readLoop() {
	hdr := make([]byte, headerSize)
	for {
		if _, err := io.ReadFull(s.conn, hdr); err != nil {
			// Suppress the log when the error is our own Close.
			if !s.IsClosed() {
				log.Printf("[mux] read header error: %v", err)
			}
			s.Close()
			return
		}

		streamID := binary.BigEndian.Uint32(hdr[0:4])
		flags := hdr[4]
		length := binary.BigEndian.Uint16(hdr[5:7])

		// length is a uint16, so a frame body is bounded at 65535 bytes.
		var data []byte
		if length > 0 {
			data = make([]byte, length)
			if _, err := io.ReadFull(s.conn, data); err != nil {
				if !s.IsClosed() {
					log.Printf("[mux] read data error: %v", err)
				}
				s.Close()
				return
			}
		}

		atomic.AddInt64(&s.BytesReceived, int64(headerSize+int(length)))
		s.handleFrame(streamID, flags, data)
	}
}
|
||||
|
||||
// handleFrame dispatches one decoded frame: keepalives first, then SYN
// (new inbound stream), then per-stream RST/DATA/FIN. It runs on the
// readLoop goroutine, so anything blocking here stalls all streams.
func (s *Session) handleFrame(streamID uint32, flags byte, data []byte) {
	// Ping/Pong on StreamID 0
	if flags&FlagPING != 0 {
		s.writeFrame(0, FlagPONG, nil)
		return
	}
	if flags&FlagPONG != 0 {
		return // pong received, connection alive
	}

	// SYN — new inbound stream
	// NOTE(review): a duplicate SYN for an ID already in the map silently
	// replaces the existing stream — confirm the peer can never resend SYN.
	if flags&FlagSYN != 0 {
		st := newStream(streamID, s)
		s.mu.Lock()
		s.streams[streamID] = st
		s.mu.Unlock()

		// Non-blocking hand-off to Accept; refuse with RST if the backlog
		// is full so the peer learns the stream never existed.
		select {
		case s.acceptCh <- st:
		default:
			log.Printf("[mux] accept backlog full, dropping stream %d", streamID)
			s.writeFrame(streamID, FlagRST, nil)
			s.mu.Lock()
			delete(s.streams, streamID)
			s.mu.Unlock()
		}
		return
	}

	// Find the stream
	s.mu.RLock()
	st, ok := s.streams[streamID]
	s.mu.RUnlock()

	if !ok {
		// Unknown stream: answer with RST unless the frame is itself an
		// RST (avoids an RST ping-pong between the two sides).
		if flags&FlagRST == 0 {
			s.writeFrame(streamID, FlagRST, nil)
		}
		return
	}

	// RST — peer aborted the stream; drop it from the registry.
	if flags&FlagRST != 0 {
		st.resetByPeer()
		s.mu.Lock()
		delete(s.streams, streamID)
		s.mu.Unlock()
		return
	}

	// DATA
	if flags&FlagDATA != 0 && len(data) > 0 {
		st.pushData(data)
	}

	// FIN — half-close from the peer; readers see EOF once buffered data drains.
	if flags&FlagFIN != 0 {
		st.finByPeer()
	}
}
|
||||
|
||||
func (s *Session) removeStream(id uint32) {
|
||||
s.mu.Lock()
|
||||
delete(s.streams, id)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Session) pingLoop() {
|
||||
ticker := time.NewTicker(pingInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := s.writeFrame(0, FlagPING, nil); err != nil {
|
||||
return
|
||||
}
|
||||
case <-s.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Stream ───
// A Stream is a virtual connection within a Session, implementing net.Conn.

type Stream struct {
	id      uint32
	sess    *Session
	readBuf *ringBuffer   // inbound payload bytes awaiting Read
	readCh  chan struct{} // capacity 1; signaled when data arrives or state changes
	closed  int32         // 1 after local Close / session teardown (atomic)
	finRecv int32         // remote sent FIN (atomic)
	finSent int32         // we sent FIN (atomic)
	reset   int32         // peer sent RST (atomic)
	// NOTE(review): mu is declared but no method in this file locks it —
	// confirm it is needed before relying on it for synchronization.
	mu sync.Mutex
}
|
||||
|
||||
func newStream(id uint32, sess *Session) *Stream {
|
||||
return &Stream{
|
||||
id: id,
|
||||
sess: sess,
|
||||
readBuf: newRingBuffer(defaultWindowSize),
|
||||
readCh: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// Read implements io.Reader. It drains buffered data first; on an empty
// buffer it returns io.EOF after a peer FIN, an error after reset or
// close, and otherwise blocks until pushData signals new bytes.
func (st *Stream) Read(p []byte) (int, error) {
	for {
		if atomic.LoadInt32(&st.reset) == 1 {
			return 0, ErrStreamReset
		}

		n := st.readBuf.Read(p)
		if n > 0 {
			return n, nil
		}

		// Buffer empty — check if FIN received
		if atomic.LoadInt32(&st.finRecv) == 1 {
			return 0, io.EOF
		}

		if atomic.LoadInt32(&st.closed) == 1 {
			return 0, ErrStreamClosed
		}

		// Wait for data. readCh has capacity 1, so a notify sent while we
		// were draining the buffer is retained and not lost.
		select {
		case <-st.readCh:
		case <-st.sess.quit:
			return 0, ErrSessionClosed
		}
	}
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (st *Stream) Write(p []byte) (int, error) {
|
||||
if atomic.LoadInt32(&st.closed) == 1 || atomic.LoadInt32(&st.reset) == 1 {
|
||||
return 0, ErrStreamClosed
|
||||
}
|
||||
|
||||
total := 0
|
||||
for len(p) > 0 {
|
||||
chunk := p
|
||||
if len(chunk) > maxPayload {
|
||||
chunk = p[:maxPayload]
|
||||
}
|
||||
if err := st.sess.writeFrame(st.id, FlagDATA, chunk); err != nil {
|
||||
return total, err
|
||||
}
|
||||
total += len(chunk)
|
||||
p = p[len(chunk):]
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// Close sends FIN (at most once), deregisters the stream from its session,
// and wakes any blocked reader. Idempotent; always returns nil.
func (st *Stream) Close() error {
	if !atomic.CompareAndSwapInt32(&st.closed, 0, 1) {
		return nil // already closed
	}
	if atomic.CompareAndSwapInt32(&st.finSent, 0, 1) {
		st.sess.writeFrame(st.id, FlagFIN, nil)
	}
	st.sess.removeStream(st.id)
	st.notify()
	return nil
}
|
||||
|
||||
// LocalAddr implements net.Conn by delegating to the underlying connection.
func (st *Stream) LocalAddr() net.Addr { return st.sess.conn.LocalAddr() }

// RemoteAddr implements net.Conn by delegating to the underlying connection.
func (st *Stream) RemoteAddr() net.Addr { return st.sess.conn.RemoteAddr() }

// SetDeadline implements net.Conn but is currently a no-op.
// NOTE(review): consumers that rely on deadlines (e.g. net/http servers)
// get no timeouts on a Stream until this is implemented.
func (st *Stream) SetDeadline(t time.Time) error {
	return nil // TODO: implement per-stream deadlines
}

// SetReadDeadline implements net.Conn; currently a no-op (see SetDeadline).
func (st *Stream) SetReadDeadline(t time.Time) error { return nil }

// SetWriteDeadline implements net.Conn; currently a no-op (see SetDeadline).
func (st *Stream) SetWriteDeadline(t time.Time) error { return nil }
|
||||
|
||||
func (st *Stream) pushData(data []byte) {
|
||||
st.readBuf.Write(data)
|
||||
st.notify()
|
||||
}
|
||||
|
||||
// finByPeer records a peer FIN; readers drain remaining data, then see EOF.
func (st *Stream) finByPeer() {
	atomic.StoreInt32(&st.finRecv, 1)
	st.notify()
}

// resetByPeer marks the stream aborted by the remote side; pending and
// future reads fail with ErrStreamReset.
func (st *Stream) resetByPeer() {
	atomic.StoreInt32(&st.reset, 1)
	atomic.StoreInt32(&st.closed, 1)
	st.notify()
}

// closeLocal marks the stream closed during session teardown without
// sending FIN (the underlying conn is going away anyway).
func (st *Stream) closeLocal() {
	atomic.StoreInt32(&st.closed, 1)
	st.notify()
}

// notify wakes a blocked Read without ever blocking itself: readCh has
// capacity 1, and a pending signal makes further signals redundant.
func (st *Stream) notify() {
	select {
	case st.readCh <- struct{}{}:
	default:
	}
}
|
||||
|
||||
// ─── Ring Buffer ───

// ringBuffer is a mutex-guarded circular byte buffer used as the per-stream
// receive window. It stores at most size-1 bytes: one slot stays empty so
// that r == w unambiguously means "empty". Writes beyond capacity are
// truncated and the count actually stored is returned; neither Read nor
// Write ever blocks.
type ringBuffer struct {
	buf  []byte
	r, w int // read/write cursors; r == w ⇒ empty
	mu   sync.Mutex
	size int
}

func newRingBuffer(size int) *ringBuffer {
	return &ringBuffer{
		buf:  make([]byte, size),
		size: size,
	}
}

// Write copies as much of p as fits and returns the number of bytes
// stored. Replaced the original byte-at-a-time loop with at most two
// copy() calls (pre-wrap and post-wrap segments) — same semantics,
// memmove speed instead of a per-byte modulo.
func (rb *ringBuffer) Write(p []byte) int {
	rb.mu.Lock()
	defer rb.mu.Unlock()

	n := 0
	for len(p) > 0 {
		// Contiguous free space starting at w: up to the physical end of
		// buf (minus the reserved slot when r == 0), or up to r-1.
		var free int
		if rb.w >= rb.r {
			free = rb.size - rb.w
			if rb.r == 0 {
				free-- // keep one slot empty
			}
		} else {
			free = rb.r - rb.w - 1
		}
		if free <= 0 {
			break // full
		}
		if free > len(p) {
			free = len(p)
		}
		copy(rb.buf[rb.w:], p[:free])
		rb.w = (rb.w + free) % rb.size
		p = p[free:]
		n += free
	}
	return n
}

// Read copies up to len(p) buffered bytes into p and returns the count
// (0 when empty). Drains at most two contiguous segments per call.
func (rb *ringBuffer) Read(p []byte) int {
	rb.mu.Lock()
	defer rb.mu.Unlock()

	n := 0
	for n < len(p) && rb.r != rb.w {
		end := rb.w
		if rb.r > rb.w {
			end = rb.size // wrapped: drain to the physical end first
		}
		seg := rb.buf[rb.r:end]
		if len(seg) > len(p)-n {
			seg = seg[:len(p)-n]
		}
		copy(p[n:], seg)
		n += len(seg)
		rb.r = (rb.r + len(seg)) % rb.size
	}
	return n
}

// Len reports the number of buffered bytes.
func (rb *ringBuffer) Len() int {
	rb.mu.Lock()
	defer rb.mu.Unlock()
	if rb.w >= rb.r {
		return rb.w - rb.r
	}
	return rb.size - rb.r + rb.w
}
|
||||
266
pkg/mux/mux_test.go
Normal file
266
pkg/mux/mux_test.go
Normal file
@@ -0,0 +1,266 @@
|
||||
package mux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// pipe creates a connected pair of net.Conn using net.Pipe (in-memory,
// synchronous: each Write blocks until the peer Reads).
func pipe() (net.Conn, net.Conn) {
	return net.Pipe()
}
|
||||
|
||||
// TestSessionOpenAccept checks the Open/Accept handshake and that client
// sessions allocate odd stream IDs.
func TestSessionOpenAccept(t *testing.T) {
	c1, c2 := pipe()
	defer c1.Close()
	defer c2.Close()

	client := NewSession(c1, false)
	server := NewSession(c2, true)
	defer client.Close()
	defer server.Close()

	// Client opens a stream
	st1, err := client.Open()
	if err != nil {
		t.Fatal(err)
	}

	// Server accepts
	st2, err := server.Accept()
	if err != nil {
		t.Fatal(err)
	}

	// Verify stream IDs: client=odd, server would be even
	if st1.id%2 != 1 {
		t.Errorf("client stream ID should be odd, got %d", st1.id)
	}
	_ = st2 // server accepted stream has client's ID
}
|
||||
|
||||
// TestStreamReadWrite checks bidirectional data transfer on one stream.
func TestStreamReadWrite(t *testing.T) {
	c1, c2 := pipe()
	client := NewSession(c1, false)
	server := NewSession(c2, true)
	defer client.Close()
	defer server.Close()

	st1, _ := client.Open()
	st2, _ := server.Accept()

	msg := []byte("hello from client to server via mux")

	// Write from client
	n, err := st1.Write(msg)
	if err != nil || n != len(msg) {
		t.Fatalf("write: n=%d err=%v", n, err)
	}

	// Read on server
	buf := make([]byte, 1024)
	n, err = st2.Read(buf)
	if err != nil || n != len(msg) {
		t.Fatalf("read: n=%d err=%v", n, err)
	}
	if !bytes.Equal(buf[:n], msg) {
		t.Fatalf("data mismatch: got %q want %q", buf[:n], msg)
	}

	// Bidirectional: server → client
	reply := []byte("pong")
	st2.Write(reply)
	n, _ = st1.Read(buf)
	if !bytes.Equal(buf[:n], reply) {
		t.Fatalf("reply mismatch: got %q want %q", buf[:n], reply)
	}
}
|
||||
|
||||
// TestMultipleStreams opens many streams concurrently and verifies each
// delivers its payload. Goroutines use t.Errorf (not Fatal), which is the
// required form for failures off the test goroutine.
func TestMultipleStreams(t *testing.T) {
	c1, c2 := pipe()
	client := NewSession(c1, false)
	server := NewSession(c2, true)
	defer client.Close()
	defer server.Close()

	const numStreams = 10
	var wg sync.WaitGroup

	// Client opens N streams concurrently
	wg.Add(numStreams)
	for i := 0; i < numStreams; i++ {
		go func(idx int) {
			defer wg.Done()
			st, err := client.Open()
			if err != nil {
				t.Errorf("open stream %d: %v", idx, err)
				return
			}
			msg := []byte("stream-data")
			st.Write(msg)
		}(i)
	}

	// Server accepts N streams
	for i := 0; i < numStreams; i++ {
		st, err := server.Accept()
		if err != nil {
			t.Fatalf("accept stream %d: %v", i, err)
		}
		buf := make([]byte, 64)
		n, _ := st.Read(buf)
		if string(buf[:n]) != "stream-data" {
			t.Errorf("stream %d data mismatch", i)
		}
	}

	wg.Wait()

	if client.NumStreams() != numStreams {
		t.Errorf("client streams: got %d want %d", client.NumStreams(), numStreams)
	}
}
|
||||
|
||||
// TestStreamClose checks that pre-close data is delivered and that the
// peer's reader observes io.EOF after the FIN arrives.
func TestStreamClose(t *testing.T) {
	c1, c2 := pipe()
	client := NewSession(c1, false)
	server := NewSession(c2, true)
	defer client.Close()
	defer server.Close()

	st1, _ := client.Open()
	st2, _ := server.Accept()

	// Write then close
	st1.Write([]byte("before-close"))
	st1.Close()

	// Server should read data then get EOF
	buf := make([]byte, 64)
	n, _ := st2.Read(buf)
	if string(buf[:n]) != "before-close" {
		t.Errorf("unexpected data: %q", buf[:n])
	}

	// Next read should eventually get EOF (FIN received)
	time.Sleep(50 * time.Millisecond)
	_, err := st2.Read(buf)
	if err != io.EOF {
		t.Errorf("expected EOF after FIN, got %v", err)
	}
}
|
||||
|
||||
// TestLargePayload streams 200KB — larger than maxPayload — verifying the
// automatic frame splitting in Stream.Write reassembles byte-exactly.
// The writer runs in a goroutine because net.Pipe is synchronous.
func TestLargePayload(t *testing.T) {
	c1, c2 := pipe()
	client := NewSession(c1, false)
	server := NewSession(c2, true)
	defer client.Close()
	defer server.Close()

	st1, _ := client.Open()
	st2, _ := server.Accept()

	// Write 200KB — larger than maxPayload (65535), should auto-split
	data := make([]byte, 200*1024)
	for i := range data {
		data[i] = byte(i % 256)
	}

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		n, err := st1.Write(data)
		if err != nil {
			t.Errorf("write large: %v", err)
		}
		if n != len(data) {
			t.Errorf("write large: n=%d want %d", n, len(data))
		}
	}()

	// Read all on server
	received := make([]byte, 0, len(data))
	buf := make([]byte, 32*1024)
	for len(received) < len(data) {
		n, err := st2.Read(buf)
		if err != nil {
			t.Fatalf("read at %d: %v", len(received), err)
		}
		received = append(received, buf[:n]...)
	}

	wg.Wait()

	if !bytes.Equal(received, data) {
		t.Error("large payload data mismatch")
	}
}
|
||||
|
||||
// TestSessionClose checks that stream writes fail once the session closes.
func TestSessionClose(t *testing.T) {
	c1, c2 := pipe()
	client := NewSession(c1, false)
	server := NewSession(c2, true)

	st1, _ := client.Open()
	server.Accept()

	// Close session
	client.Close()

	// Stream operations should fail
	_, err := st1.Write([]byte("x"))
	if err == nil {
		t.Error("write after session close should fail")
	}

	// Server accept should fail
	time.Sleep(50 * time.Millisecond)
	server.Close()
}
|
||||
|
||||
// TestPingPong is a smoke test: the keepalive loops run in the background
// and both sessions must still be alive after a short wait.
func TestPingPong(t *testing.T) {
	c1, c2 := pipe()
	client := NewSession(c1, false)
	server := NewSession(c2, true)
	defer client.Close()
	defer server.Close()

	// Just verify it doesn't crash — ping/pong runs in background
	time.Sleep(100 * time.Millisecond)

	if client.IsClosed() || server.IsClosed() {
		t.Error("sessions should still be alive")
	}
}
|
||||
|
||||
// BenchmarkThroughput measures mux throughput with 4KB frames over net.Pipe.
// NOTE(review): the reader goroutine assumes exactly b.N Reads match b.N
// Writes; if the receive buffer ever drops data the final Read blocks and
// the goroutine leaks past the benchmark — confirm acceptable here.
func BenchmarkThroughput(b *testing.B) {
	c1, c2 := pipe()
	client := NewSession(c1, false)
	server := NewSession(c2, true)
	defer client.Close()
	defer server.Close()

	st1, _ := client.Open()
	st2, _ := server.Accept()

	data := make([]byte, 4096)
	buf := make([]byte, 4096)

	b.SetBytes(int64(len(data)))
	b.ResetTimer()

	go func() {
		for i := 0; i < b.N; i++ {
			st2.Read(buf)
		}
	}()

	for i := 0; i < b.N; i++ {
		st1.Write(data)
	}
}
|
||||
260
pkg/nat/detect.go
Normal file
260
pkg/nat/detect.go
Normal file
@@ -0,0 +1,260 @@
|
||||
// Package nat provides NAT type detection via UDP and TCP STUN.
|
||||
package nat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/openp2p-cn/inp2p/pkg/protocol"
|
||||
)
|
||||
|
||||
const (
	// detectTimeout bounds each individual STUN probe (write + read).
	detectTimeout = 5 * time.Second
)

// DetectResult holds the NAT detection outcome.
type DetectResult struct {
	Type     protocol.NATType
	PublicIP string // external IP as observed by the first STUN probe
	Port1    int    // external port seen on STUN server port 1
	Port2    int    // external port seen on STUN server port 2
}
|
||||
|
||||
// stunReq is sent to the STUN endpoint; ID identifies which probe this is.
type stunReq struct {
	ID int `json:"id"`
}

// stunRsp is received from the STUN endpoint: the sender's address as
// observed by the server, echoing the request ID.
type stunRsp struct {
	IP   string `json:"ip"`
	Port int    `json:"port"`
	ID   int    `json:"id"`
}
|
||||
|
||||
// DetectUDP sends probes from the same local port to two different server
|
||||
// UDP ports. If both return the same external port → Cone; different → Symmetric.
|
||||
func DetectUDP(serverIP string, port1, port2 int) DetectResult {
|
||||
result := DetectResult{Type: protocol.NATUnknown}
|
||||
|
||||
// Bind a single local UDP port
|
||||
conn, err := net.ListenPacket("udp", ":0")
|
||||
if err != nil {
|
||||
return result
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
r1, err1 := probeUDP(conn, serverIP, port1, 1)
|
||||
r2, err2 := probeUDP(conn, serverIP, port2, 2)
|
||||
|
||||
if err1 != nil || err2 != nil {
|
||||
return result // timeout → NATUnknown
|
||||
}
|
||||
|
||||
result.PublicIP = r1.IP
|
||||
result.Port1 = r1.Port
|
||||
result.Port2 = r2.Port
|
||||
|
||||
if r1.Port == r2.Port {
|
||||
result.Type = protocol.NATCone
|
||||
} else {
|
||||
result.Type = protocol.NATSymmetric
|
||||
}
|
||||
|
||||
// Check if public IP equals local IP → no NAT
|
||||
localIP := conn.LocalAddr().(*net.UDPAddr).IP.String()
|
||||
if localIP == r1.IP || r1.IP == "" {
|
||||
// might be public
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func probeUDP(conn net.PacketConn, serverIP string, port, id int) (stunRsp, error) {
|
||||
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", serverIP, port))
|
||||
if err != nil {
|
||||
return stunRsp{}, err
|
||||
}
|
||||
|
||||
frame, _ := protocol.Encode(protocol.MsgNAT, protocol.SubNATDetectReq, stunReq{ID: id})
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(detectTimeout))
|
||||
if _, err := conn.WriteTo(frame, addr); err != nil {
|
||||
return stunRsp{}, err
|
||||
}
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
conn.SetReadDeadline(time.Now().Add(detectTimeout))
|
||||
n, _, err := conn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
return stunRsp{}, err
|
||||
}
|
||||
|
||||
var rsp stunRsp
|
||||
if n > protocol.HeaderSize {
|
||||
json.Unmarshal(buf[protocol.HeaderSize:n], &rsp)
|
||||
}
|
||||
return rsp, nil
|
||||
}
|
||||
|
||||
// DetectTCP connects to two different TCP ports on the server and compares
// the observed external port. This is the fallback when UDP is blocked.
// NOTE(review): unlike the UDP path (one shared local port), the two TCP
// probes use two separate sockets, so distinct external ports here do not
// necessarily imply symmetric NAT — confirm the classification is intended.
func DetectTCP(serverIP string, port1, port2 int) DetectResult {
	result := DetectResult{Type: protocol.NATUnknown}

	r1, err1 := probeTCP(serverIP, port1, 1)
	r2, err2 := probeTCP(serverIP, port2, 2)

	if err1 != nil || err2 != nil {
		return result
	}

	result.PublicIP = r1.IP
	result.Port1 = r1.Port
	result.Port2 = r2.Port

	if r1.Port == r2.Port {
		result.Type = protocol.NATCone
	} else {
		result.Type = protocol.NATSymmetric
	}
	return result
}
|
||||
|
||||
func probeTCP(serverIP string, port, id int) (stunRsp, error) {
|
||||
conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", serverIP, port), detectTimeout)
|
||||
if err != nil {
|
||||
return stunRsp{}, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
frame, _ := protocol.Encode(protocol.MsgNAT, protocol.SubNATDetectReq, stunReq{ID: id})
|
||||
conn.SetWriteDeadline(time.Now().Add(detectTimeout))
|
||||
if _, err := conn.Write(frame); err != nil {
|
||||
return stunRsp{}, err
|
||||
}
|
||||
|
||||
buf := make([]byte, 1024)
|
||||
conn.SetReadDeadline(time.Now().Add(detectTimeout))
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
return stunRsp{}, err
|
||||
}
|
||||
|
||||
var rsp stunRsp
|
||||
if n > protocol.HeaderSize {
|
||||
json.Unmarshal(buf[protocol.HeaderSize:n], &rsp)
|
||||
}
|
||||
return rsp, nil
|
||||
}
|
||||
|
||||
// Detect runs UDP detection first, falls back to TCP if UDP is blocked.
|
||||
func Detect(serverIP string, udpPort1, udpPort2, tcpPort1, tcpPort2 int) DetectResult {
|
||||
result := DetectUDP(serverIP, udpPort1, udpPort2)
|
||||
if result.Type != protocol.NATUnknown {
|
||||
return result
|
||||
}
|
||||
// UDP blocked, fallback to TCP
|
||||
return DetectTCP(serverIP, tcpPort1, tcpPort2)
|
||||
}
|
||||
|
||||
// ─── Server-side STUN handler ───
|
||||
|
||||
// ServeUDPSTUN listens on a UDP port and echoes back the sender's observed IP:port.
// It blocks until quit is closed; it returns nil on shutdown, or a setup error.
func ServeUDPSTUN(port int, quit <-chan struct{}) error {
	addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
	if err != nil {
		return err
	}
	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		return err
	}
	defer conn.Close()

	// Closing the socket when quit fires unblocks ReadFromUDP below so the
	// loop can observe the shutdown.
	go func() {
		<-quit
		conn.Close()
	}()

	buf := make([]byte, 1024)
	for {
		n, remoteAddr, err := conn.ReadFromUDP(buf)
		if err != nil {
			select {
			case <-quit:
				return nil
			default:
				// NOTE(review): a persistent non-shutdown error would spin
				// this loop; consider backoff or returning the error.
				continue
			}
		}

		// Parse request (best-effort; a bad payload just leaves req.ID zero).
		var req stunReq
		if n > protocol.HeaderSize {
			json.Unmarshal(buf[protocol.HeaderSize:n], &req)
		}

		// Reply with the address we observed the datagram coming from.
		rsp := stunRsp{
			IP:   remoteAddr.IP.String(),
			Port: remoteAddr.Port,
			ID:   req.ID,
		}
		frame, _ := protocol.Encode(protocol.MsgNAT, protocol.SubNATDetectRsp, rsp)
		conn.WriteToUDP(frame, remoteAddr)
	}
}
|
||||
|
||||
// ServeTCPSTUN listens on a TCP port. Each connection: read one req, write one rsp with observed addr.
// It blocks until quit is closed; each accepted connection is served in its
// own goroutine with bounded read/write deadlines.
func ServeTCPSTUN(port int, quit <-chan struct{}) error {
	ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
	if err != nil {
		return err
	}
	defer ln.Close()

	// Closing the listener on shutdown unblocks Accept below.
	go func() {
		<-quit
		ln.Close()
	}()

	for {
		conn, err := ln.Accept()
		if err != nil {
			select {
			case <-quit:
				return nil
			default:
				// NOTE(review): a persistent Accept error would spin this
				// loop; consider backoff or returning the error.
				continue
			}
		}
		go func(c net.Conn) {
			defer c.Close()
			// RemoteAddr of a TCP connection is always *net.TCPAddr.
			remoteAddr := c.RemoteAddr().(*net.TCPAddr)

			buf := make([]byte, 1024)
			c.SetReadDeadline(time.Now().Add(10 * time.Second))
			n, err := c.Read(buf)
			if err != nil {
				return
			}

			// Best-effort parse; a bad payload just leaves req.ID zero.
			var req stunReq
			if n > protocol.HeaderSize {
				json.Unmarshal(buf[protocol.HeaderSize:n], &req)
			}

			// Reply with the connection's observed source address.
			rsp := stunRsp{
				IP:   remoteAddr.IP.String(),
				Port: remoteAddr.Port,
				ID:   req.ID,
			}
			frame, _ := protocol.Encode(protocol.MsgNAT, protocol.SubNATDetectRsp, rsp)
			c.SetWriteDeadline(time.Now().Add(5 * time.Second))
			c.Write(frame)
		}(conn)
	}
}
|
||||
276
pkg/protocol/protocol.go
Normal file
276
pkg/protocol/protocol.go
Normal file
@@ -0,0 +1,276 @@
|
||||
// Package protocol defines the INP2P wire protocol.
|
||||
//
|
||||
// Message format: [Header 8B] + [JSON payload]
|
||||
// Header: DataLen(uint32 LE) + MainType(uint16 LE) + SubType(uint16 LE)
|
||||
// DataLen = len(header) + len(payload) = 8 + len(json)
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// HeaderSize is the fixed 8-byte message header.
const HeaderSize = 8

// ─── Main message types ───

const (
	MsgLogin     uint16 = 1
	MsgHeartbeat uint16 = 2
	MsgNAT       uint16 = 3
	MsgPush      uint16 = 4 // signaling push (punch/relay coordination)
	MsgRelay     uint16 = 5
	MsgReport    uint16 = 6
	MsgTunnel    uint16 = 7 // in-tunnel control messages
)

// ─── Sub types: MsgLogin ───

const (
	SubLoginReq uint16 = iota
	SubLoginRsp
)

// ─── Sub types: MsgHeartbeat ───

const (
	SubHeartbeatPing uint16 = iota
	SubHeartbeatPong
)

// ─── Sub types: MsgNAT ───

const (
	SubNATDetectReq uint16 = iota
	SubNATDetectRsp
)

// ─── Sub types: MsgPush ───

const (
	SubPushConnectReq  uint16 = iota // "please connect to peer X"
	SubPushConnectRsp                // peer's punch parameters
	SubPushPunchStart                // coordinate simultaneous punch
	SubPushPunchResult               // report punch outcome
	SubPushRelayOffer                // relay node offers to relay
	SubPushNodeOnline                // notify: destination came online
	SubPushEditApp                   // add/edit tunnel app
	SubPushDeleteApp                 // delete tunnel app
	SubPushReportApps                // request app list
)

// ─── Sub types: MsgRelay ───

const (
	SubRelayNodeReq uint16 = iota
	SubRelayNodeRsp
	SubRelayDataReq // establish data channel through relay
	SubRelayDataRsp
)

// ─── Sub types: MsgReport ───

const (
	SubReportBasic   uint16 = iota // OS, version, MAC, etc.
	SubReportApps                  // running tunnels
	SubReportConnect               // connection result
)

// ─── NAT types ───

// NATType classifies a node's NAT traversal capability as determined by the
// STUN-style probes in pkg/nat.
type NATType int

const (
	NATNone      NATType = 0   // public IP, no NAT
	NATCone      NATType = 1   // full/restricted/port-restricted cone
	NATSymmetric NATType = 2   // symmetric (port changes per dest)
	NATUnknown   NATType = 314 // detection failed / UDP blocked
)

// String returns a human-readable name for the NAT type.
func (n NATType) String() string {
	switch n {
	case NATNone:
		return "None"
	case NATCone:
		return "Cone"
	case NATSymmetric:
		return "Symmetric"
	default:
		return "Unknown"
	}
}

// CanPunch returns true if at least one side is Cone (or has public IP).
// Two symmetric (or unknown) NATs cannot hole-punch and need a relay.
func CanPunch(a, b NATType) bool {
	return a == NATNone || b == NATNone || a == NATCone || b == NATCone
}

// ─── Header ───

// Header is the fixed wire header preceding every message payload.
type Header struct {
	DataLen  uint32 // total frame length: HeaderSize + len(JSON payload)
	MainType uint16 // one of the Msg* constants
	SubType  uint16 // one of the Sub* constants for MainType
}

// ─── Encode / Decode ───

// Encode packs header + JSON payload into a byte slice.
// A nil payload yields a header-only frame (DataLen == HeaderSize).
func Encode(mainType, subType uint16, payload interface{}) ([]byte, error) {
	var jsonData []byte
	if payload != nil {
		var err error
		jsonData, err = json.Marshal(payload)
		if err != nil {
			return nil, fmt.Errorf("marshal payload: %w", err)
		}
	}
	h := Header{
		DataLen:  uint32(HeaderSize + len(jsonData)),
		MainType: mainType,
		SubType:  subType,
	}
	buf := new(bytes.Buffer)
	buf.Grow(int(h.DataLen))
	if err := binary.Write(buf, binary.LittleEndian, h); err != nil {
		return nil, err
	}
	buf.Write(jsonData)
	return buf.Bytes(), nil
}

// DecodeHeader parses the fixed 8-byte header at the start of data.
// (The original doc comment said "reads ... from r", but the function takes a
// byte slice.) Returns io.ErrShortBuffer when data is shorter than HeaderSize.
func DecodeHeader(data []byte) (Header, error) {
	if len(data) < HeaderSize {
		return Header{}, io.ErrShortBuffer
	}
	// Decode fields directly; avoids binary.Read's reflection overhead on
	// what is a per-message hot path.
	return Header{
		DataLen:  binary.LittleEndian.Uint32(data[0:4]),
		MainType: binary.LittleEndian.Uint16(data[4:6]),
		SubType:  binary.LittleEndian.Uint16(data[6:8]),
	}, nil
}

// DecodePayload unmarshals the JSON portion after the header.
// A header-only frame (or shorter) is a valid empty payload: v is untouched.
func DecodePayload(data []byte, v interface{}) error {
	if len(data) <= HeaderSize {
		return nil // empty payload is valid
	}
	return json.Unmarshal(data[HeaderSize:], v)
}
|
||||
|
||||
// ─── Common message structs ───
|
||||
|
||||
// LoginReq is sent by client on WSS connect.
type LoginReq struct {
	Node           string  `json:"node"`
	Token          uint64  `json:"token"`
	User           string  `json:"user,omitempty"`
	Version        string  `json:"version"`
	NATType        NATType `json:"natType"`
	ShareBandwidth int     `json:"shareBandwidth"`
	RelayEnabled   bool    `json:"relayEnabled"` // --relay flag
	SuperRelay     bool    `json:"superRelay"`   // --super flag
	PublicIP       string  `json:"publicIP,omitempty"`
}

// LoginRsp is the server's reply to LoginReq. A non-zero Error means the
// login was rejected; Detail carries the reason.
type LoginRsp struct {
	Error  int    `json:"error"`
	Detail string `json:"detail,omitempty"`
	Ts     int64  `json:"ts"`
	Token  uint64 `json:"token"`
	User   string `json:"user"`
	Node   string `json:"node"`
}

// ReportBasic is the initial system info report after login.
type ReportBasic struct {
	OS              string `json:"os"`
	Mac             string `json:"mac"`
	LanIP           string `json:"lanIP"`
	Version         string `json:"version"`
	HasIPv4         int    `json:"hasIPv4"`
	HasUPNPorNATPMP int    `json:"hasUPNPorNATPMP"`
	IPv6            string `json:"IPv6,omitempty"`
}

// ReportBasicRsp acknowledges a ReportBasic; non-zero Error means rejection.
type ReportBasicRsp struct {
	Error int `json:"error"`
}

// PunchParams carries the information needed for hole-punching.
type PunchParams struct {
	IP       string  `json:"ip"`
	Port     int     `json:"port"`
	NATType  NATType `json:"natType"`
	Token    uint64  `json:"token"` // TOTP for auth
	IPv6     string  `json:"ipv6,omitempty"`
	HasIPv4  int     `json:"hasIPv4"`
	LinkMode string  `json:"linkMode"` // "udp" or "tcp"
}

// ConnectReq is pushed by server to coordinate a connection.
type ConnectReq struct {
	From     string      `json:"from"`
	To       string      `json:"to"`
	FromIP   string      `json:"fromIP"`
	Peer     PunchParams `json:"peer"`
	AppName  string      `json:"appName,omitempty"`
	Protocol string      `json:"protocol"` // "tcp" or "udp"
	SrcPort  int         `json:"srcPort"`
	DstHost  string      `json:"dstHost"`
	DstPort  int         `json:"dstPort"`
}

// ConnectRsp is the peer's reply to a ConnectReq push.
type ConnectRsp struct {
	Error  int    `json:"error"`
	Detail string `json:"detail,omitempty"`
	From   string `json:"from"`
	To     string `json:"to"`
	// NOTE(review): omitempty has no effect on a non-pointer struct field —
	// Peer is always serialized; use *PunchParams if omission is intended.
	Peer PunchParams `json:"peer,omitempty"`
}

// RelayNodeReq asks the server for a relay node.
type RelayNodeReq struct {
	PeerNode string `json:"peerNode"`
}

// RelayNodeRsp names the relay the server selected and the token to present
// to it; non-zero Error means no relay was available.
type RelayNodeRsp struct {
	RelayName  string `json:"relayName"`
	RelayIP    string `json:"relayIP"`
	RelayPort  int    `json:"relayPort"`
	RelayToken uint64 `json:"relayToken"`
	Mode       string `json:"mode"` // "private", "super", "server"
	Error      int    `json:"error"`
}

// AppConfig defines a tunnel application.
type AppConfig struct {
	AppName   string `json:"appName"`
	Protocol  string `json:"protocol"` // "tcp" or "udp"
	SrcPort   int    `json:"srcPort"`
	PeerNode  string `json:"peerNode"`
	DstHost   string `json:"dstHost"`
	DstPort   int    `json:"dstPort"`
	Enabled   int    `json:"enabled"`
	RelayNode string `json:"relayNode,omitempty"` // force specific relay
}

// ReportConnect is the connection result reported to server.
type ReportConnect struct {
	PeerNode    string  `json:"peerNode"`
	NATType     NATType `json:"natType"`
	PeerNATType NATType `json:"peerNatType"`
	LinkMode    string  `json:"linkMode"` // "udppunch", "tcppunch", "relay"
	Error       string  `json:"error,omitempty"`
	RTT         int     `json:"rtt,omitempty"` // milliseconds
	RelayNode   string  `json:"relayNode,omitempty"`
	Protocol    string  `json:"protocol,omitempty"`
	SrcPort     int     `json:"srcPort,omitempty"`
	DstPort     int     `json:"dstPort,omitempty"`
	DstHost     string  `json:"dstHost,omitempty"`
	Version     string  `json:"version,omitempty"`
	// NOTE(review): tag casing differs from LoginReq ("shareBandWidth" vs
	// "shareBandwidth") — confirm which spelling the server expects.
	ShareBandwidth int `json:"shareBandWidth,omitempty"`
}
|
||||
204
pkg/punch/punch.go
Normal file
204
pkg/punch/punch.go
Normal file
@@ -0,0 +1,204 @@
|
||||
// Package punch implements UDP and TCP hole-punching.
|
||||
package punch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/openp2p-cn/inp2p/pkg/protocol"
|
||||
)
|
||||
|
||||
const (
	punchTimeout   = 5 * time.Second // max wait for the peer's first packet
	punchRetries   = 5               // punch packets sent before listening
	handshakeMagic = "INP2P-PUNCH"     // first handshake payload sent to the peer
	handshakeAck   = "INP2P-PUNCH-ACK" // reply confirming the handshake
)
|
||||
|
||||
// Result holds the outcome of a punch attempt.
// On success, Conn is non-nil and Error is nil; on failure only Error is set.
type Result struct {
	Conn     net.Conn      // established connection (nil on failure)
	Mode     string        // "udp" or "tcp"
	RTT      time.Duration // handshake round-trip time
	PeerAddr string        // peer address the connection was made to
	Error    error         // non-nil when the attempt failed
}
|
||||
|
||||
// Config for a punch attempt.
type Config struct {
	PeerIP      string           // peer's public IP as reported by the server
	PeerPort    int              // peer's mapped port to punch toward
	PeerNAT     protocol.NATType // peer's detected NAT type
	SelfNAT     protocol.NATType // our detected NAT type
	SelfPort    int              // local port to bind (0 = auto)
	IsInitiator bool             // NOTE(review): not read by the visible punch code
}
|
||||
|
||||
// AttemptUDP tries to establish a UDP connection via hole-punching.
|
||||
// Both sides must call this simultaneously (coordinated by server).
|
||||
func AttemptUDP(cfg Config) Result {
|
||||
if !protocol.CanPunch(cfg.SelfNAT, cfg.PeerNAT) {
|
||||
return Result{Error: fmt.Errorf("cannot UDP punch: self=%s peer=%s", cfg.SelfNAT, cfg.PeerNAT)}
|
||||
}
|
||||
|
||||
localAddr := &net.UDPAddr{Port: cfg.SelfPort}
|
||||
conn, err := net.ListenUDP("udp", localAddr)
|
||||
if err != nil {
|
||||
return Result{Error: fmt.Errorf("listen UDP: %w", err)}
|
||||
}
|
||||
|
||||
peerAddr := &net.UDPAddr{
|
||||
IP: net.ParseIP(cfg.PeerIP),
|
||||
Port: cfg.PeerPort,
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
|
||||
// Send punch packets
|
||||
for i := 0; i < punchRetries; i++ {
|
||||
conn.SetWriteDeadline(time.Now().Add(time.Second))
|
||||
conn.WriteTo([]byte(handshakeMagic), peerAddr)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
|
||||
// Listen for response
|
||||
buf := make([]byte, 256)
|
||||
conn.SetReadDeadline(time.Now().Add(punchTimeout))
|
||||
n, from, err := conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return Result{Error: fmt.Errorf("UDP punch timeout: %w", err)}
|
||||
}
|
||||
|
||||
// Verify handshake
|
||||
msg := string(buf[:n])
|
||||
if msg != handshakeMagic && msg != handshakeAck {
|
||||
conn.Close()
|
||||
return Result{Error: fmt.Errorf("unexpected punch data: %q", msg)}
|
||||
}
|
||||
|
||||
// Send ack
|
||||
conn.WriteTo([]byte(handshakeAck), from)
|
||||
|
||||
rtt := time.Since(start)
|
||||
log.Printf("[punch] UDP punch ok: peer=%s rtt=%s", from, rtt)
|
||||
|
||||
return Result{
|
||||
Conn: conn,
|
||||
Mode: "udp",
|
||||
RTT: rtt,
|
||||
PeerAddr: from.String(),
|
||||
}
|
||||
}
|
||||
|
||||
// AttemptTCP tries TCP hole-punching using simultaneous SYN.
|
||||
// This works by having both sides dial each other at the same time.
|
||||
func AttemptTCP(cfg Config) Result {
|
||||
if !protocol.CanPunch(cfg.SelfNAT, cfg.PeerNAT) {
|
||||
return Result{Error: fmt.Errorf("cannot TCP punch: self=%s peer=%s", cfg.SelfNAT, cfg.PeerNAT)}
|
||||
}
|
||||
|
||||
peerAddr := fmt.Sprintf("%s:%d", cfg.PeerIP, cfg.PeerPort)
|
||||
start := time.Now()
|
||||
|
||||
// TCP simultaneous open: keep trying to dial the peer
|
||||
var conn net.Conn
|
||||
var err error
|
||||
for i := 0; i < punchRetries*2; i++ {
|
||||
d := net.Dialer{Timeout: time.Second, LocalAddr: &net.TCPAddr{Port: cfg.SelfPort}}
|
||||
conn, err = d.Dial("tcp", peerAddr)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return Result{Error: fmt.Errorf("TCP punch failed: %w", err)}
|
||||
}
|
||||
|
||||
// TCP handshake for INP2P
|
||||
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||
conn.Write([]byte(handshakeMagic))
|
||||
|
||||
buf := make([]byte, 256)
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return Result{Error: fmt.Errorf("TCP handshake read: %w", err)}
|
||||
}
|
||||
|
||||
msg := string(buf[:n])
|
||||
if msg != handshakeMagic && msg != handshakeAck {
|
||||
conn.Close()
|
||||
return Result{Error: fmt.Errorf("TCP unexpected handshake: %q", msg)}
|
||||
}
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||
conn.Write([]byte(handshakeAck))
|
||||
|
||||
rtt := time.Since(start)
|
||||
log.Printf("[punch] TCP punch ok: peer=%s rtt=%s", conn.RemoteAddr(), rtt)
|
||||
|
||||
return Result{
|
||||
Conn: conn,
|
||||
Mode: "tcp",
|
||||
RTT: rtt,
|
||||
PeerAddr: conn.RemoteAddr().String(),
|
||||
}
|
||||
}
|
||||
|
||||
// AttemptDirect tries to directly connect when one side has a public IP.
|
||||
func AttemptDirect(cfg Config) Result {
|
||||
addr := fmt.Sprintf("%s:%d", cfg.PeerIP, cfg.PeerPort)
|
||||
start := time.Now()
|
||||
|
||||
conn, err := net.DialTimeout("tcp", addr, punchTimeout)
|
||||
if err != nil {
|
||||
return Result{Error: fmt.Errorf("direct connect failed: %w", err)}
|
||||
}
|
||||
|
||||
rtt := time.Since(start)
|
||||
log.Printf("[punch] direct connect ok: peer=%s rtt=%s", addr, rtt)
|
||||
|
||||
return Result{
|
||||
Conn: conn,
|
||||
Mode: "tcp-direct",
|
||||
RTT: rtt,
|
||||
PeerAddr: addr,
|
||||
}
|
||||
}
|
||||
|
||||
// Connect tries all punch methods in priority order and returns the first success.
|
||||
func Connect(cfg Config) Result {
|
||||
methods := []struct {
|
||||
name string
|
||||
fn func(Config) Result
|
||||
}{
|
||||
{"UDP-punch", AttemptUDP},
|
||||
{"TCP-punch", AttemptTCP},
|
||||
}
|
||||
|
||||
// If peer has public IP, try direct first
|
||||
if cfg.PeerNAT == protocol.NATNone {
|
||||
r := AttemptDirect(cfg)
|
||||
if r.Error == nil {
|
||||
return r
|
||||
}
|
||||
log.Printf("[punch] direct failed: %v", r.Error)
|
||||
}
|
||||
|
||||
for _, m := range methods {
|
||||
log.Printf("[punch] trying %s to %s:%d", m.name, cfg.PeerIP, cfg.PeerPort)
|
||||
r := m.fn(cfg)
|
||||
if r.Error == nil {
|
||||
return r
|
||||
}
|
||||
log.Printf("[punch] %s failed: %v", m.name, r.Error)
|
||||
}
|
||||
|
||||
return Result{Error: fmt.Errorf("all punch methods exhausted")}
|
||||
}
|
||||
415
pkg/relay/relay.go
Normal file
415
pkg/relay/relay.go
Normal file
@@ -0,0 +1,415 @@
|
||||
// Package relay implements relay/super-relay node capabilities.
|
||||
//
|
||||
// Relay flow:
|
||||
// 1. Client A asks server for relay (RelayNodeReq)
|
||||
// 2. Server finds relay R, generates TOTP/token, responds to A (RelayNodeRsp)
|
||||
// 3. Server pushes RelayOffer to R with session info
|
||||
// 4. A connects to R:relayPort, sends RelayHandshake{SessionID, Role="from", Token}
|
||||
// 5. B connects to R:relayPort, sends RelayHandshake{SessionID, Role="to", Token}
|
||||
// (B gets the session info via server push)
|
||||
// 6. R verifies both tokens, bridges A↔B
|
||||
package relay
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/openp2p-cn/inp2p/pkg/auth"
|
||||
)
|
||||
|
||||
const (
	handshakeTimeout = 10 * time.Second // max time to receive a peer's handshake
	pairTimeout      = 30 * time.Second // how long to wait for the second peer
	// NOTE(review): headerLen is not referenced by the visible code — the
	// length prefix is read/written via binary.Read/Write on a uint32.
	headerLen = 4 // uint32 LE length prefix for handshake JSON
)
|
||||
|
||||
// RelayHandshake is sent by each peer when connecting to a relay node.
// It is framed as a 4-byte little-endian length prefix followed by JSON.
type RelayHandshake struct {
	SessionID string `json:"sessionID"` // pairs the two peers on the relay
	Role      string `json:"role"`      // "from" or "to"
	Token     uint64 `json:"token"`     // TOTP or one-time token
	Node      string `json:"node"`      // sender's node name
}
|
||||
|
||||
// Node represents a relay-capable node's metadata (used by server).
type Node struct {
	Name       string    // node name as registered with the server
	IP         string    // public IP peers connect to
	Port       int       // relay listen port
	Token      uint64    // auth token used to derive the relay TOTP
	Mode       string    // "private" (same user), "super" (shared)
	Bandwidth  int       // shared bandwidth budget
	LastUsed   time.Time // last time the server handed this relay out
	ActiveLoad int32     // current session count (updated atomically)
}
|
||||
|
||||
// pendingSession waits for both peers to arrive.
// mu guards from/to/connFrom/connTo; done signals the first-arriving peer's
// goroutine that the pair is complete.
type pendingSession struct {
	id       string
	from     string // node name of the "from" peer (set on arrival)
	to       string // node name of the "to" peer (set on arrival)
	token    uint64
	connFrom net.Conn // nil until the "from" peer arrives
	connTo   net.Conn // nil until the "to" peer arrives
	mu       sync.Mutex
	done     chan struct{}
	created  time.Time // used by cleanupLoop to expire stale entries
}
|
||||
|
||||
// Manager manages relay sessions on this node.
// pMu guards pending; sMu guards sessions; quit stops the accept and cleanup
// loops.
type Manager struct {
	enabled    bool   // whether this node relays at all
	superRelay bool   // relay for any user, not just our own
	maxLoad    int    // max concurrent bridged sessions
	token      uint64 // this node's auth token
	port       int    // TCP listen port for relay traffic
	listener   net.Listener

	pending  map[string]*pendingSession // sessionID → pending
	pMu      sync.Mutex
	sessions map[string]*Session // sessionID → active session
	sMu      sync.RWMutex
	quit     chan struct{}
}
|
||||
|
||||
// Session represents an active relay bridging two peers.
type Session struct {
	ID        string
	From      string // node name of the "from" peer
	To        string // node name of the "to" peer
	ConnA     net.Conn // the "from" peer's connection
	ConnB     net.Conn // the "to" peer's connection
	BytesFwd  int64    // total bytes forwarded, both directions (atomic)
	StartTime time.Time
	closed    int32 // CAS guard so Close runs its teardown only once
}
|
||||
|
||||
// NewManager creates a relay manager.
|
||||
func NewManager(port int, enabled, superRelay bool, maxLoad int, token uint64) *Manager {
|
||||
return &Manager{
|
||||
enabled: enabled,
|
||||
superRelay: superRelay,
|
||||
maxLoad: maxLoad,
|
||||
token: token,
|
||||
port: port,
|
||||
pending: make(map[string]*pendingSession),
|
||||
sessions: make(map[string]*Session),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) IsEnabled() bool { return m.enabled }
|
||||
func (m *Manager) IsSuperRelay() bool { return m.superRelay }
|
||||
|
||||
func (m *Manager) ActiveSessions() int {
|
||||
m.sMu.RLock()
|
||||
defer m.sMu.RUnlock()
|
||||
return len(m.sessions)
|
||||
}
|
||||
|
||||
func (m *Manager) CanAcceptRelay() bool {
|
||||
return m.enabled && m.ActiveSessions() < m.maxLoad
|
||||
}
|
||||
|
||||
// Start begins listening for relay connections.
|
||||
func (m *Manager) Start() error {
|
||||
if !m.enabled {
|
||||
return nil
|
||||
}
|
||||
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", m.port))
|
||||
if err != nil {
|
||||
return fmt.Errorf("relay listen :%d: %w", m.port, err)
|
||||
}
|
||||
m.listener = ln
|
||||
log.Printf("[relay] listening on :%d (super=%v, maxLoad=%d)", m.port, m.superRelay, m.maxLoad)
|
||||
|
||||
go m.acceptLoop()
|
||||
go m.cleanupLoop()
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptLoop accepts relay connections until Stop closes the listener.
// Each connection is handed to handleConn in its own goroutine.
func (m *Manager) acceptLoop() {
	for {
		conn, err := m.listener.Accept()
		if err != nil {
			select {
			case <-m.quit:
				// Listener was closed by Stop — exit cleanly.
				return
			default:
				// NOTE(review): a persistent Accept error other than
				// shutdown would spin this loop; consider backoff.
				continue
			}
		}
		go m.handleConn(conn)
	}
}
|
||||
|
||||
func (m *Manager) handleConn(conn net.Conn) {
|
||||
// Read handshake with timeout
|
||||
conn.SetReadDeadline(time.Now().Add(handshakeTimeout))
|
||||
|
||||
// Length-prefixed JSON: [4B len][JSON]
|
||||
var length uint32
|
||||
if err := binary.Read(conn, binary.LittleEndian, &length); err != nil {
|
||||
log.Printf("[relay] handshake read len: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
if length > 4096 {
|
||||
log.Printf("[relay] handshake too large: %d", length)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
buf := make([]byte, length)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
log.Printf("[relay] handshake read body: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
conn.SetReadDeadline(time.Time{}) // clear deadline
|
||||
|
||||
var hs RelayHandshake
|
||||
if err := json.Unmarshal(buf, &hs); err != nil {
|
||||
log.Printf("[relay] handshake parse: %v", err)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// Verify TOTP
|
||||
if !auth.VerifyTOTP(hs.Token, m.token, time.Now().Unix()) {
|
||||
log.Printf("[relay] handshake denied: %s (TOTP mismatch)", hs.Node)
|
||||
sendRelayResult(conn, 1, "auth failed")
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[relay] handshake ok: session=%s role=%s node=%s", hs.SessionID, hs.Role, hs.Node)
|
||||
|
||||
// Find or create pending session
|
||||
m.pMu.Lock()
|
||||
ps, exists := m.pending[hs.SessionID]
|
||||
if !exists {
|
||||
ps = &pendingSession{
|
||||
id: hs.SessionID,
|
||||
token: hs.Token,
|
||||
done: make(chan struct{}),
|
||||
created: time.Now(),
|
||||
}
|
||||
m.pending[hs.SessionID] = ps
|
||||
}
|
||||
m.pMu.Unlock()
|
||||
|
||||
ps.mu.Lock()
|
||||
switch hs.Role {
|
||||
case "from":
|
||||
ps.from = hs.Node
|
||||
ps.connFrom = conn
|
||||
case "to":
|
||||
ps.to = hs.Node
|
||||
ps.connTo = conn
|
||||
default:
|
||||
ps.mu.Unlock()
|
||||
log.Printf("[relay] unknown role: %s", hs.Role)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if both peers have arrived
|
||||
bothReady := ps.connFrom != nil && ps.connTo != nil
|
||||
ps.mu.Unlock()
|
||||
|
||||
if bothReady {
|
||||
// Both peers connected — bridge them
|
||||
m.pMu.Lock()
|
||||
delete(m.pending, hs.SessionID)
|
||||
m.pMu.Unlock()
|
||||
|
||||
sendRelayResult(ps.connFrom, 0, "ok")
|
||||
sendRelayResult(ps.connTo, 0, "ok")
|
||||
|
||||
m.bridge(ps)
|
||||
} else {
|
||||
// Wait for the other peer
|
||||
select {
|
||||
case <-ps.done:
|
||||
// Woken up by the other peer's arrival
|
||||
case <-time.After(pairTimeout):
|
||||
log.Printf("[relay] session %s timeout waiting for pair", hs.SessionID)
|
||||
m.pMu.Lock()
|
||||
delete(m.pending, hs.SessionID)
|
||||
m.pMu.Unlock()
|
||||
sendRelayResult(conn, 1, "pair timeout")
|
||||
conn.Close()
|
||||
case <-m.quit:
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// relayResult is sent back to each peer after handshake.
// Error == 0 means the relay accepted the session; otherwise Detail explains.
type relayResult struct {
	Error  int    `json:"error"`
	Detail string `json:"detail,omitempty"`
}
|
||||
|
||||
func sendRelayResult(conn net.Conn, errCode int, detail string) {
|
||||
data, _ := json.Marshal(relayResult{Error: errCode, Detail: detail})
|
||||
length := uint32(len(data))
|
||||
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||
binary.Write(conn, binary.LittleEndian, length)
|
||||
conn.Write(data)
|
||||
conn.SetWriteDeadline(time.Time{})
|
||||
}
|
||||
|
||||
func (m *Manager) bridge(ps *pendingSession) {
|
||||
sess := &Session{
|
||||
ID: ps.id,
|
||||
From: ps.from,
|
||||
To: ps.to,
|
||||
ConnA: ps.connFrom,
|
||||
ConnB: ps.connTo,
|
||||
StartTime: time.Now(),
|
||||
}
|
||||
|
||||
m.sMu.Lock()
|
||||
m.sessions[ps.id] = sess
|
||||
m.sMu.Unlock()
|
||||
|
||||
log.Printf("[relay] bridging %s ↔ %s (session %s)", ps.from, ps.to, ps.id)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
sess.Close()
|
||||
m.sMu.Lock()
|
||||
delete(m.sessions, ps.id)
|
||||
m.sMu.Unlock()
|
||||
log.Printf("[relay] session %s ended, %d bytes forwarded", ps.id, atomic.LoadInt64(&sess.BytesFwd))
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
n, _ := io.Copy(sess.ConnB, sess.ConnA)
|
||||
atomic.AddInt64(&sess.BytesFwd, n)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
n, _ := io.Copy(sess.ConnA, sess.ConnB)
|
||||
atomic.AddInt64(&sess.BytesFwd, n)
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}()
|
||||
}
|
||||
|
||||
func (m *Manager) cleanupLoop() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
m.pMu.Lock()
|
||||
for id, ps := range m.pending {
|
||||
if time.Since(ps.created) > pairTimeout {
|
||||
delete(m.pending, id)
|
||||
if ps.connFrom != nil {
|
||||
ps.connFrom.Close()
|
||||
}
|
||||
if ps.connTo != nil {
|
||||
ps.connTo.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
m.pMu.Unlock()
|
||||
case <-m.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close shuts down a session.
|
||||
func (s *Session) Close() {
|
||||
if !atomic.CompareAndSwapInt32(&s.closed, 0, 1) {
|
||||
return
|
||||
}
|
||||
if s.ConnA != nil {
|
||||
s.ConnA.Close()
|
||||
}
|
||||
if s.ConnB != nil {
|
||||
s.ConnB.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Stop shuts down the relay manager.
|
||||
func (m *Manager) Stop() {
|
||||
close(m.quit)
|
||||
if m.listener != nil {
|
||||
m.listener.Close()
|
||||
}
|
||||
m.sMu.Lock()
|
||||
for _, s := range m.sessions {
|
||||
s.Close()
|
||||
}
|
||||
m.sMu.Unlock()
|
||||
}
|
||||
|
||||
// ─── Client-side helper ───
|
||||
|
||||
// ConnectToRelay connects to a relay node and performs the handshake.
|
||||
func ConnectToRelay(relayAddr string, sessionID, role, node string, token uint64) (net.Conn, error) {
|
||||
conn, err := net.DialTimeout("tcp", relayAddr, 10*time.Second)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial relay %s: %w", relayAddr, err)
|
||||
}
|
||||
|
||||
hs := RelayHandshake{
|
||||
SessionID: sessionID,
|
||||
Role: role,
|
||||
Token: token,
|
||||
Node: node,
|
||||
}
|
||||
data, _ := json.Marshal(hs)
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
||||
length := uint32(len(data))
|
||||
if err := binary.Write(conn, binary.LittleEndian, length); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
if _, err := conn.Write(data); err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read result
|
||||
conn.SetReadDeadline(time.Now().Add(pairTimeout + 5*time.Second))
|
||||
if err := binary.Read(conn, binary.LittleEndian, &length); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("read relay result: %w", err)
|
||||
}
|
||||
buf := make([]byte, length)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("read relay result body: %w", err)
|
||||
}
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
|
||||
var result relayResult
|
||||
json.Unmarshal(buf, &result)
|
||||
if result.Error != 0 {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("relay denied: %s", result.Detail)
|
||||
}
|
||||
|
||||
log.Printf("[relay] connected to relay %s, session=%s role=%s", relayAddr, sessionID, role)
|
||||
return conn, nil
|
||||
}
|
||||
189
pkg/relay/relay_test.go
Normal file
189
pkg/relay/relay_test.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/openp2p-cn/inp2p/pkg/auth"
|
||||
)
|
||||
|
||||
// TestRelayBridge starts a relay manager on a fixed local port, connects two
// peers under the same session ID, and verifies bidirectional forwarding.
// NOTE(review): the hard-coded port 29700 will collide with parallel runs or
// an occupied port — consider listening on :0 and exposing the bound port.
func TestRelayBridge(t *testing.T) {
	token := auth.MakeToken("test", "pass")
	mgr := NewManager(29700, true, false, 10, token)
	if err := mgr.Start(); err != nil {
		t.Fatal(err)
	}
	defer mgr.Stop()

	sessionID := "test-session-1"
	totp := auth.GenTOTP(token, time.Now().Unix())

	var wg sync.WaitGroup
	var connA, connB net.Conn
	var errA, errB error

	// Peer A connects as "from"
	wg.Add(1)
	go func() {
		defer wg.Done()
		connA, errA = ConnectToRelay(
			fmt.Sprintf("127.0.0.1:%d", 29700),
			sessionID, "from", "nodeA", totp,
		)
	}()

	// Peer B connects as "to" after a short delay
	wg.Add(1)
	go func() {
		defer wg.Done()
		time.Sleep(200 * time.Millisecond)
		connB, errB = ConnectToRelay(
			fmt.Sprintf("127.0.0.1:%d", 29700),
			sessionID, "to", "nodeB", totp,
		)
	}()

	// wg.Wait establishes happens-before, so reading connA/connB/errA/errB
	// after it is race-free.
	wg.Wait()

	if errA != nil {
		t.Fatalf("connA error: %v", errA)
	}
	if errB != nil {
		t.Fatalf("connB error: %v", errB)
	}
	defer connA.Close()
	defer connB.Close()

	// Test data flow: A → B
	msg := []byte("hello through relay")
	connA.Write(msg)

	buf := make([]byte, 256)
	connB.SetReadDeadline(time.Now().Add(3 * time.Second))
	n, err := connB.Read(buf)
	if err != nil {
		t.Fatalf("read from B: %v", err)
	}
	if string(buf[:n]) != string(msg) {
		t.Errorf("got %q, want %q", buf[:n], msg)
	}

	// Test data flow: B → A
	reply := []byte("relay pong")
	connB.Write(reply)

	connA.SetReadDeadline(time.Now().Add(3 * time.Second))
	n, err = connA.Read(buf)
	if err != nil {
		t.Fatalf("read from A: %v", err)
	}
	if string(buf[:n]) != string(reply) {
		t.Errorf("got %q, want %q", buf[:n], reply)
	}

	// Verify session count
	if mgr.ActiveSessions() != 1 {
		t.Errorf("active sessions: got %d want 1", mgr.ActiveSessions())
	}

	t.Logf("✅ Relay bridge OK: A↔B bidirectional, %d active sessions", mgr.ActiveSessions())
}
|
||||
|
||||
func TestRelayLargeData(t *testing.T) {
|
||||
token := auth.MakeToken("test", "pass")
|
||||
mgr := NewManager(29701, true, false, 10, token)
|
||||
if err := mgr.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer mgr.Stop()
|
||||
|
||||
sessionID := "test-large-data"
|
||||
totp := auth.GenTOTP(token, time.Now().Unix())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var connA, connB net.Conn
|
||||
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
var err error
|
||||
connA, err = ConnectToRelay("127.0.0.1:29701", sessionID, "from", "bigA", totp)
|
||||
if err != nil {
|
||||
t.Errorf("connA: %v", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
var err error
|
||||
connB, err = ConnectToRelay("127.0.0.1:29701", sessionID, "to", "bigB", totp)
|
||||
if err != nil {
|
||||
t.Errorf("connB: %v", err)
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
if connA == nil || connB == nil {
|
||||
t.Fatal("connection failed")
|
||||
}
|
||||
defer connA.Close()
|
||||
defer connB.Close()
|
||||
|
||||
// Send 1MB through relay
|
||||
const dataSize = 1024 * 1024
|
||||
data := make([]byte, dataSize)
|
||||
for i := range data {
|
||||
data[i] = byte(i % 256)
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
connA.Write(data)
|
||||
}()
|
||||
|
||||
// Read exact amount on B side
|
||||
received := make([]byte, dataSize)
|
||||
total := 0
|
||||
connB.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
for total < dataSize {
|
||||
n, err := connB.Read(received[total:])
|
||||
if err != nil {
|
||||
t.Fatalf("read at %d: %v", total, err)
|
||||
}
|
||||
total += n
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if len(received) != len(data) {
|
||||
t.Fatalf("size mismatch: got %d want %d", len(received), len(data))
|
||||
}
|
||||
for i := 0; i < len(data); i++ {
|
||||
if received[i] != data[i] {
|
||||
t.Fatalf("data mismatch at byte %d", i)
|
||||
break
|
||||
}
|
||||
}
|
||||
t.Logf("✅ 1MB relay transfer OK")
|
||||
}
|
||||
|
||||
func TestRelayAuthDenied(t *testing.T) {
|
||||
token := auth.MakeToken("real", "token")
|
||||
mgr := NewManager(29702, true, false, 10, token)
|
||||
if err := mgr.Start(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer mgr.Stop()
|
||||
|
||||
// Use wrong TOTP
|
||||
wrongToken := auth.GenTOTP(auth.MakeToken("wrong", "creds"), time.Now().Unix())
|
||||
_, err := ConnectToRelay("127.0.0.1:29702", "bad-session", "from", "badNode", wrongToken)
|
||||
if err == nil {
|
||||
t.Fatal("expected auth denied, got success")
|
||||
}
|
||||
t.Logf("✅ Auth denied correctly: %v", err)
|
||||
}
|
||||
180
pkg/signal/conn.go
Normal file
180
pkg/signal/conn.go
Normal file
@@ -0,0 +1,180 @@
|
||||
// Package signal provides the WSS signaling connection between client and server.
|
||||
package signal
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/openp2p-cn/inp2p/pkg/protocol"
|
||||
)
|
||||
|
||||
// Conn wraps a WebSocket connection with message framing.
// It serializes writes, dispatches incoming frames to registered handlers,
// and supports synchronous request/response via one-shot waiter channels.
type Conn struct {
	ws       *websocket.Conn
	writeMu  sync.Mutex // serializes writes to the websocket
	handlers map[msgKey]Handler
	hMu      sync.RWMutex  // guards handlers
	quit     chan struct{} // closed exactly once by Close
	once     sync.Once     // ensures Close runs its body only once
	Node     string // peer node name — set by the owner elsewhere; unused in this file's methods
	Token    uint64 // auth token — set by the owner elsewhere; unused in this file's methods

	// waiters for synchronous request-response; keyed by the expected
	// response (MainType, SubType). Each channel is one-shot (cap 1).
	waiters map[msgKey]chan []byte
	wMu     sync.Mutex // guards waiters
}
|
||||
|
||||
// msgKey identifies a message type by its (MainType, SubType) pair.
// It is the lookup key for both persistent handlers and one-shot waiters.
type msgKey struct {
	main uint16
	sub  uint16
}
|
||||
|
||||
// Handler processes an incoming message. data includes the full frame:
// the protocol header followed by the payload. A non-nil error is logged
// by ReadLoop but does not stop the read loop.
type Handler func(data []byte) error
|
||||
|
||||
// NewConn wraps an existing websocket.
|
||||
func NewConn(ws *websocket.Conn) *Conn {
|
||||
return &Conn{
|
||||
ws: ws,
|
||||
handlers: make(map[msgKey]Handler),
|
||||
waiters: make(map[msgKey]chan []byte),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// OnMessage registers a handler for a specific (MainType, SubType).
|
||||
func (c *Conn) OnMessage(mainType, subType uint16, h Handler) {
|
||||
c.hMu.Lock()
|
||||
c.handlers[msgKey{mainType, subType}] = h
|
||||
c.hMu.Unlock()
|
||||
}
|
||||
|
||||
// Write sends a message with the given type and JSON payload.
|
||||
func (c *Conn) Write(mainType, subType uint16, payload interface{}) error {
|
||||
frame, err := protocol.Encode(mainType, subType, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.WriteRaw(frame)
|
||||
}
|
||||
|
||||
// WriteRaw sends raw bytes.
|
||||
func (c *Conn) WriteRaw(data []byte) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
c.ws.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
||||
return c.ws.WriteMessage(websocket.BinaryMessage, data)
|
||||
}
|
||||
|
||||
// Request sends a message and waits for a response of type (rspMain, rspSub).
// It returns the raw response frame (header + payload), or an error on send
// failure, timeout, or connection close.
func (c *Conn) Request(mainType, subType uint16, payload interface{},
	rspMain, rspSub uint16, timeout time.Duration) ([]byte, error) {

	ch := make(chan []byte, 1)
	key := msgKey{rspMain, rspSub}

	// Register the waiter BEFORE sending, so a fast response cannot arrive
	// between the send and the registration and be dropped.
	// NOTE(review): a second concurrent Request expecting the same response
	// key replaces this entry, and the first caller will time out — confirm
	// callers never issue overlapping requests for the same key.
	c.wMu.Lock()
	c.waiters[key] = ch
	c.wMu.Unlock()

	// Always deregister, whether we got a response, timed out, or failed.
	defer func() {
		c.wMu.Lock()
		delete(c.waiters, key)
		c.wMu.Unlock()
	}()

	if err := c.Write(mainType, subType, payload); err != nil {
		return nil, err
	}

	select {
	case data := <-ch:
		return data, nil
	case <-time.After(timeout):
		return nil, fmt.Errorf("request timeout %d:%d → %d:%d", mainType, subType, rspMain, rspSub)
	case <-c.quit:
		return nil, fmt.Errorf("connection closed")
	}
}
|
||||
|
||||
// ReadLoop reads messages and dispatches to handlers. Blocks until error or Close().
// Returns nil after an orderly Close(), or the transport error otherwise.
// Dispatch priority: a pending synchronous waiter for the frame's
// (MainType, SubType) consumes it; otherwise a registered handler runs.
func (c *Conn) ReadLoop() error {
	for {
		_, msg, err := c.ws.ReadMessage()
		if err != nil {
			// Distinguish a planned shutdown (quit closed) from a real
			// transport failure.
			select {
			case <-c.quit:
				return nil
			default:
				return err
			}
		}
		// Drop frames too short to carry a protocol header.
		if len(msg) < protocol.HeaderSize {
			continue
		}
		h, err := protocol.DecodeHeader(msg)
		if err != nil {
			continue // malformed header: skip silently
		}

		key := msgKey{h.MainType, h.SubType}

		// Check waiters first (synchronous request-response).
		// The waiter is one-shot: it is removed before delivery.
		c.wMu.Lock()
		if ch, ok := c.waiters[key]; ok {
			delete(c.waiters, key)
			c.wMu.Unlock()
			// Non-blocking send: the channel is buffered (cap 1), so this
			// only falls to default if the buffer is already full.
			select {
			case ch <- msg:
			default:
			}
			continue
		}
		c.wMu.Unlock()

		// Dispatch to registered handler
		c.hMu.RLock()
		handler, ok := c.handlers[key]
		c.hMu.RUnlock()

		if ok {
			// Handler errors are logged but never terminate the loop.
			if err := handler(msg); err != nil {
				log.Printf("[signal] handler %d:%d error: %v", h.MainType, h.SubType, err)
			}
		}
	}
}
|
||||
|
||||
// Close gracefully shuts down the connection.
|
||||
func (c *Conn) Close() {
|
||||
c.once.Do(func() {
|
||||
close(c.quit)
|
||||
c.ws.Close()
|
||||
})
|
||||
}
|
||||
|
||||
// IsClosed reports whether the connection has been closed.
|
||||
func (c *Conn) IsClosed() bool {
|
||||
select {
|
||||
case <-c.quit:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Helpers ───
|
||||
|
||||
// ParsePayload is a convenience to unmarshal JSON from a raw message.
|
||||
func ParsePayload[T any](data []byte) (T, error) {
|
||||
var v T
|
||||
if len(data) <= protocol.HeaderSize {
|
||||
return v, nil
|
||||
}
|
||||
err := json.Unmarshal(data[protocol.HeaderSize:], &v)
|
||||
return v, err
|
||||
}
|
||||
233
pkg/tunnel/tunnel.go
Normal file
233
pkg/tunnel/tunnel.go
Normal file
@@ -0,0 +1,233 @@
|
||||
// Package tunnel provides P2P tunnel with mux-based port forwarding.
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/openp2p-cn/inp2p/pkg/mux"
|
||||
)
|
||||
|
||||
// Tunnel represents a P2P tunnel that multiplexes port forwards over one connection.
type Tunnel struct {
	PeerNode string        // remote node name
	PeerIP   string        // remote address as reported by the underlying conn
	LinkMode string        // "udppunch", "tcppunch", "relay", "direct"
	RTT      time.Duration // round-trip time supplied by the caller at construction

	sess      *mux.Session
	listeners map[int]*forwarder // srcPort → forwarder
	mu        sync.Mutex         // guards listeners
	closed    int32              // 1 after Close; accessed atomically
	stats     Stats              // counters, updated atomically
}
|
||||
|
||||
// forwarder pairs one local listener with its fixed remote destination.
type forwarder struct {
	listener net.Listener
	dstHost  string
	dstPort  int
	quit     chan struct{} // closed by Tunnel.Close so acceptLoop exits quietly
}
|
||||
|
||||
// Stats tracks tunnel traffic. All fields are updated atomically; use
// Tunnel.GetStats to read a snapshot.
type Stats struct {
	BytesSent     int64 // bytes copied from the local side toward the peer
	BytesReceived int64 // bytes copied from the peer toward the local side
	Connections   int64 // total local connections accepted by forwarders
	ActiveStreams int32 // mux streams currently being bridged
}
|
||||
|
||||
// New creates a tunnel from an established P2P connection.
|
||||
// isInitiator: the side that opened the P2P connection is the mux client.
|
||||
func New(peerNode string, conn net.Conn, linkMode string, rtt time.Duration, isInitiator bool) *Tunnel {
|
||||
return &Tunnel{
|
||||
PeerNode: peerNode,
|
||||
PeerIP: conn.RemoteAddr().String(),
|
||||
LinkMode: linkMode,
|
||||
RTT: rtt,
|
||||
sess: mux.NewSession(conn, !isInitiator), // initiator=client, responder=server
|
||||
listeners: make(map[int]*forwarder),
|
||||
}
|
||||
}
|
||||
|
||||
// ListenAndForward starts a local listener that forwards connections through the tunnel.
|
||||
// Each accepted connection opens a mux stream to the peer, which connects to dstHost:dstPort.
|
||||
func (t *Tunnel) ListenAndForward(protocol string, srcPort int, dstHost string, dstPort int) error {
|
||||
addr := fmt.Sprintf(":%d", srcPort)
|
||||
ln, err := net.Listen(protocol, addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen %s %s: %w", protocol, addr, err)
|
||||
}
|
||||
|
||||
fwd := &forwarder{
|
||||
listener: ln,
|
||||
dstHost: dstHost,
|
||||
dstPort: dstPort,
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
|
||||
t.mu.Lock()
|
||||
t.listeners[srcPort] = fwd
|
||||
t.mu.Unlock()
|
||||
|
||||
log.Printf("[tunnel] LISTEN %s:%d → %s(%s:%d) via %s", protocol, srcPort, t.PeerNode, dstHost, dstPort, t.LinkMode)
|
||||
|
||||
go t.acceptLoop(fwd)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Tunnel) acceptLoop(fwd *forwarder) {
|
||||
for {
|
||||
conn, err := fwd.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-fwd.quit:
|
||||
return
|
||||
default:
|
||||
if atomic.LoadInt32(&t.closed) == 1 {
|
||||
return
|
||||
}
|
||||
log.Printf("[tunnel] accept error: %v", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
atomic.AddInt64(&t.stats.Connections, 1)
|
||||
go t.handleLocalConn(conn, fwd.dstHost, fwd.dstPort)
|
||||
}
|
||||
}
|
||||
|
||||
// handleLocalConn forwards one accepted local connection over a fresh mux
// stream: it sends a "host:port\n" destination header first, then copies
// bytes in both directions until the bridge finishes.
func (t *Tunnel) handleLocalConn(local net.Conn, dstHost string, dstPort int) {
	defer local.Close()

	// Open a mux stream
	stream, err := t.sess.Open()
	if err != nil {
		log.Printf("[tunnel] mux open error: %v", err)
		return
	}
	defer stream.Close()

	atomic.AddInt32(&t.stats.ActiveStreams, 1)
	defer atomic.AddInt32(&t.stats.ActiveStreams, -1)

	// Send destination info as first message on the stream.
	// Format: "host:port\n" — parsed by handleRemoteStream on the peer side.
	header := fmt.Sprintf("%s:%d\n", dstHost, dstPort)
	if _, err := stream.Write([]byte(header)); err != nil {
		log.Printf("[tunnel] stream write header: %v", err)
		return
	}

	// Bidirectional copy; returns only when both directions are done.
	t.bridge(local, stream)
}
|
||||
|
||||
// AcceptAndConnect handles incoming mux streams (called on the responder side).
|
||||
// It reads the destination header and connects to the local target.
|
||||
func (t *Tunnel) AcceptAndConnect() {
|
||||
for {
|
||||
stream, err := t.sess.Accept()
|
||||
if err != nil {
|
||||
if !t.sess.IsClosed() {
|
||||
log.Printf("[tunnel] mux accept error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
go t.handleRemoteStream(stream)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tunnel) handleRemoteStream(stream *mux.Stream) {
|
||||
defer stream.Close()
|
||||
|
||||
atomic.AddInt32(&t.stats.ActiveStreams, 1)
|
||||
defer atomic.AddInt32(&t.stats.ActiveStreams, -1)
|
||||
|
||||
// Read destination header: "host:port\n"
|
||||
buf := make([]byte, 256)
|
||||
n := 0
|
||||
for n < len(buf) {
|
||||
nn, err := stream.Read(buf[n : n+1])
|
||||
if err != nil {
|
||||
log.Printf("[tunnel] read dest header: %v", err)
|
||||
return
|
||||
}
|
||||
n += nn
|
||||
if buf[n-1] == '\n' {
|
||||
break
|
||||
}
|
||||
}
|
||||
dest := string(buf[:n-1]) // trim \n
|
||||
|
||||
// Connect to local destination
|
||||
conn, err := net.DialTimeout("tcp", dest, 5*time.Second)
|
||||
if err != nil {
|
||||
log.Printf("[tunnel] connect to %s failed: %v", dest, err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
log.Printf("[tunnel] stream → %s connected", dest)
|
||||
|
||||
// Bidirectional copy
|
||||
t.bridge(conn, stream)
|
||||
}
|
||||
|
||||
func (t *Tunnel) bridge(a, b io.ReadWriter) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
copyAndCount := func(dst io.Writer, src io.Reader, counter *int64) {
|
||||
defer wg.Done()
|
||||
n, _ := io.Copy(dst, src)
|
||||
atomic.AddInt64(counter, n)
|
||||
}
|
||||
|
||||
go copyAndCount(a, b, &t.stats.BytesReceived)
|
||||
go copyAndCount(b, a, &t.stats.BytesSent)
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Close shuts down the tunnel and all listeners. Safe to call multiple
// times; only the first call does any work.
func (t *Tunnel) Close() {
	// CAS makes Close idempotent and signals acceptLoop to exit quietly.
	if !atomic.CompareAndSwapInt32(&t.closed, 0, 1) {
		return
	}

	t.mu.Lock()
	for port, fwd := range t.listeners {
		close(fwd.quit)      // mark the stop as planned for acceptLoop
		fwd.listener.Close() // unblock the pending Accept
		log.Printf("[tunnel] stopped :%d", port)
	}
	t.mu.Unlock()

	t.sess.Close()
	log.Printf("[tunnel] closed → %s", t.PeerNode)
}
|
||||
|
||||
// GetStats returns traffic statistics.
|
||||
func (t *Tunnel) GetStats() Stats {
|
||||
return Stats{
|
||||
BytesSent: atomic.LoadInt64(&t.stats.BytesSent),
|
||||
BytesReceived: atomic.LoadInt64(&t.stats.BytesReceived),
|
||||
Connections: atomic.LoadInt64(&t.stats.Connections),
|
||||
ActiveStreams: atomic.LoadInt32(&t.stats.ActiveStreams),
|
||||
}
|
||||
}
|
||||
|
||||
// IsAlive returns true if the tunnel is open.
|
||||
func (t *Tunnel) IsAlive() bool {
|
||||
return atomic.LoadInt32(&t.closed) == 0 && !t.sess.IsClosed()
|
||||
}
|
||||
|
||||
// NumStreams returns active mux streams, as counted by the underlying mux
// session itself (which may differ from stats.ActiveStreams, the count this
// package maintains around bridge lifetimes).
func (t *Tunnel) NumStreams() int {
	return t.sess.NumStreams()
}
|
||||
176
pkg/tunnel/tunnel_test.go
Normal file
176
pkg/tunnel/tunnel_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestEndToEndForward wires two tunnels over a net.Pipe and verifies that a
// connection to the initiator's forwarded port reaches the responder-side
// target server and returns its reply.
func TestEndToEndForward(t *testing.T) {
	// 1. Start a "target" TCP server (simulates SSH on the remote side)
	targetLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	defer targetLn.Close()
	targetPort := targetLn.Addr().(*net.TCPAddr).Port

	go func() {
		for {
			conn, err := targetLn.Accept()
			if err != nil {
				return
			}
			go func(c net.Conn) {
				defer c.Close()
				buf := make([]byte, 1024)
				n, _ := c.Read(buf)
				c.Write([]byte("ECHO:" + string(buf[:n])))
			}(conn)
		}
	}()

	// 2. Create a connected pair (simulates a P2P punch connection)
	c1, c2 := net.Pipe()

	// 3. Create tunnels on both sides
	initiator := New("remote-node", c1, "test", 0, true)
	responder := New("local-node", c2, "test", 0, false)
	defer initiator.Close()
	defer responder.Close()

	// Responder accepts incoming mux streams and connects to local targets
	go responder.AcceptAndConnect()

	// 4. Initiator listens on a local port and forwards to remote target.
	// A throwaway listener reserves a free port number, then releases it.
	localLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatal(err)
	}
	localPort := localLn.Addr().(*net.TCPAddr).Port
	localLn.Close() // free the port so tunnel can use it

	err = initiator.ListenAndForward("tcp", localPort, "127.0.0.1", targetPort)
	if err != nil {
		t.Fatal(err)
	}

	// Give the accept loop a moment to start before dialing.
	time.Sleep(50 * time.Millisecond)

	// 5. Connect to the tunnel's local port
	conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", localPort))
	if err != nil {
		t.Fatal(err)
	}
	defer conn.Close()

	// 6. Send data and verify echo
	conn.Write([]byte("hello-tunnel"))
	conn.SetReadDeadline(time.Now().Add(3 * time.Second))
	buf := make([]byte, 1024)
	n, err := conn.Read(buf)
	if err != nil {
		t.Fatal(err)
	}

	got := string(buf[:n])
	want := "ECHO:hello-tunnel"
	if got != want {
		t.Errorf("got %q, want %q", got, want)
	}
}
|
||||
|
||||
func TestMultipleConnections(t *testing.T) {
|
||||
// Target server: echoes back with a prefix
|
||||
targetLn, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||
defer targetLn.Close()
|
||||
targetPort := targetLn.Addr().(*net.TCPAddr).Port
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := targetLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
io.Copy(c, c) // pure echo
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
c1, c2 := net.Pipe()
|
||||
initiator := New("peer", c1, "test", 0, true)
|
||||
responder := New("me", c2, "test", 0, false)
|
||||
defer initiator.Close()
|
||||
defer responder.Close()
|
||||
|
||||
go responder.AcceptAndConnect()
|
||||
|
||||
localLn, _ := net.Listen("tcp", "127.0.0.1:0")
|
||||
localPort := localLn.Addr().(*net.TCPAddr).Port
|
||||
localLn.Close()
|
||||
|
||||
initiator.ListenAndForward("tcp", localPort, "127.0.0.1", targetPort)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Open 5 concurrent connections through the tunnel
|
||||
const N = 5
|
||||
done := make(chan bool, N)
|
||||
|
||||
for i := 0; i < N; i++ {
|
||||
go func(idx int) {
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", localPort))
|
||||
if err != nil {
|
||||
t.Errorf("conn %d: dial: %v", idx, err)
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
msg := fmt.Sprintf("msg-%d", idx)
|
||||
conn.Write([]byte(msg))
|
||||
conn.SetReadDeadline(time.Now().Add(3 * time.Second))
|
||||
buf := make([]byte, 256)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil || string(buf[:n]) != msg {
|
||||
t.Errorf("conn %d: got %q, want %q, err=%v", idx, buf[:n], msg, err)
|
||||
done <- false
|
||||
return
|
||||
}
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
for i := 0; i < N; i++ {
|
||||
if ok := <-done; !ok {
|
||||
t.Errorf("connection %d failed", i)
|
||||
}
|
||||
}
|
||||
|
||||
stats := initiator.GetStats()
|
||||
if stats.Connections != N {
|
||||
t.Errorf("connections: got %d want %d", stats.Connections, N)
|
||||
}
|
||||
}
|
||||
|
||||
// TestTunnelStats checks liveness reporting: both ends of a fresh tunnel
// report alive, and the initiator reports dead after an explicit Close.
// The deferred Close calls are harmless because Close is idempotent.
func TestTunnelStats(t *testing.T) {
	c1, c2 := net.Pipe()
	initiator := New("peer", c1, "test", 0, true)
	responder := New("me", c2, "test", 0, false)
	defer initiator.Close()
	defer responder.Close()

	if !initiator.IsAlive() || !responder.IsAlive() {
		t.Error("tunnels should be alive")
	}

	initiator.Close()
	// Brief pause so the session teardown propagates before re-checking.
	time.Sleep(50 * time.Millisecond)

	if initiator.IsAlive() {
		t.Error("initiator should be dead after close")
	}
}
|
||||
Reference in New Issue
Block a user