// Package relay implements relay/super-relay node capabilities.
|
|
//
|
|
// Relay flow:
|
|
// 1. Client A asks server for relay (RelayNodeReq)
|
|
// 2. Server finds relay R, generates TOTP/token, responds to A (RelayNodeRsp)
|
|
// 3. Server pushes RelayOffer to R with session info
|
|
// 4. A connects to R:relayPort, sends RelayHandshake{SessionID, Role="from", Token}
|
|
// 5. B connects to R:relayPort, sends RelayHandshake{SessionID, Role="to", Token}
|
|
// (B gets the session info via server push)
|
|
// 6. R verifies both tokens, bridges A↔B
|
|
package relay
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/openp2p-cn/inp2p/pkg/auth"
|
|
)
|
|
|
|
const (
	// handshakeTimeout bounds how long a freshly accepted connection may
	// take to deliver its handshake before being dropped.
	handshakeTimeout = 10 * time.Second
	pairTimeout      = 30 * time.Second // how long to wait for the second peer
	headerLen        = 4                // uint32 LE length prefix for handshake JSON
)
|
|
|
|
// RelayHandshake is sent by each peer when connecting to a relay node.
// It travels as length-prefixed JSON: [4-byte LE length][JSON body].
type RelayHandshake struct {
	SessionID string `json:"sessionID"` // session identifier shared by both peers
	Role      string `json:"role"`      // "from" or "to"
	Token     uint64 `json:"token"`     // TOTP or one-time token
	Node      string `json:"node"`      // sender's node name
}
|
|
|
|
// Node represents a relay-capable node's metadata (used by server).
type Node struct {
	Name       string
	IP         string
	Port       int
	Token      uint64
	Mode       string // "private" (same user), "super" (shared)
	Bandwidth  int
	LastUsed   time.Time
	ActiveLoad int32 // current number of relayed sessions on this node
}
|
|
|
|
// pendingSession waits for both peers to arrive.
// connFrom/connTo are written by handleConn under mu; done is closed once
// the pair is complete so the first-arriving peer's goroutine can proceed.
type pendingSession struct {
	id       string
	from     string // node name of the "from" peer
	to       string // node name of the "to" peer
	token    uint64 // token presented by the first peer to arrive
	connFrom net.Conn
	connTo   net.Conn
	mu       sync.Mutex    // guards from/to/connFrom/connTo
	done     chan struct{} // closed when both peers have connected
	created  time.Time     // used by cleanupLoop to expire stale entries
}
|
|
|
|
// Manager manages relay sessions on this node.
type Manager struct {
	enabled    bool
	superRelay bool // relay for other users too ("super" mode)
	maxLoad    int  // max concurrent sessions accepted
	maxMbps    int  // per-direction bandwidth cap; <=0 means unlimited
	token      uint64 // this node's auth token
	port       int
	listener   net.Listener

	pending  map[string]*pendingSession // sessionID → pending
	pMu      sync.Mutex                 // guards pending
	sessions map[string]*Session        // sessionID → active session
	sMu      sync.RWMutex               // guards sessions
	quit     chan struct{}              // closed by Stop to end the loops
}
|
|
|
|
// Session represents an active relay bridging two peers.
type Session struct {
	ID        string
	From      string
	To        string
	ConnA     net.Conn
	ConnB     net.Conn
	BytesFwd  int64 // total bytes forwarded (both directions); atomic access
	StartTime time.Time
	closed    int32 // CAS flag so Close is idempotent
}
|
|
|
|
// NewManager creates a relay manager.
|
|
func NewManager(port int, enabled, superRelay bool, maxLoad int, token uint64, maxMbps int) *Manager {
|
|
return &Manager{
|
|
enabled: enabled,
|
|
superRelay: superRelay,
|
|
maxLoad: maxLoad,
|
|
maxMbps: maxMbps,
|
|
token: token,
|
|
port: port,
|
|
pending: make(map[string]*pendingSession),
|
|
sessions: make(map[string]*Session),
|
|
quit: make(chan struct{}),
|
|
}
|
|
}
|
|
|
|
// IsEnabled reports whether this node accepts relay sessions.
func (m *Manager) IsEnabled() bool { return m.enabled }

// IsSuperRelay reports whether this node relays for other users (shared mode).
func (m *Manager) IsSuperRelay() bool { return m.superRelay }
|
|
// ActiveSessions returns the number of currently bridged sessions.
func (m *Manager) ActiveSessions() int {
	m.sMu.RLock()
	defer m.sMu.RUnlock()
	return len(m.sessions)
}
|
|
|
|
// CanAcceptRelay reports whether this node can take another relay session
// (enabled and below its configured load limit).
func (m *Manager) CanAcceptRelay() bool {
	return m.enabled && m.ActiveSessions() < m.maxLoad
}
|
|
|
|
// Start begins listening for relay connections.
|
|
func (m *Manager) Start() error {
|
|
if !m.enabled {
|
|
return nil
|
|
}
|
|
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", m.port))
|
|
if err != nil {
|
|
return fmt.Errorf("relay listen :%d: %w", m.port, err)
|
|
}
|
|
m.listener = ln
|
|
log.Printf("[relay] listening on :%d (super=%v, maxLoad=%d)", m.port, m.superRelay, m.maxLoad)
|
|
|
|
go m.acceptLoop()
|
|
go m.cleanupLoop()
|
|
return nil
|
|
}
|
|
|
|
func (m *Manager) acceptLoop() {
|
|
for {
|
|
conn, err := m.listener.Accept()
|
|
if err != nil {
|
|
select {
|
|
case <-m.quit:
|
|
return
|
|
default:
|
|
continue
|
|
}
|
|
}
|
|
go m.handleConn(conn)
|
|
}
|
|
}
|
|
|
|
func (m *Manager) handleConn(conn net.Conn) {
|
|
// Read handshake with timeout
|
|
conn.SetReadDeadline(time.Now().Add(handshakeTimeout))
|
|
|
|
// Length-prefixed JSON: [4B len][JSON]
|
|
var length uint32
|
|
if err := binary.Read(conn, binary.LittleEndian, &length); err != nil {
|
|
log.Printf("[relay] handshake read len: %v", err)
|
|
conn.Close()
|
|
return
|
|
}
|
|
if length > 4096 {
|
|
log.Printf("[relay] handshake too large: %d", length)
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
buf := make([]byte, length)
|
|
if _, err := io.ReadFull(conn, buf); err != nil {
|
|
log.Printf("[relay] handshake read body: %v", err)
|
|
conn.Close()
|
|
return
|
|
}
|
|
conn.SetReadDeadline(time.Time{}) // clear deadline
|
|
|
|
var hs RelayHandshake
|
|
if err := json.Unmarshal(buf, &hs); err != nil {
|
|
log.Printf("[relay] handshake parse: %v", err)
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
// Verify TOTP
|
|
if !auth.VerifyTOTP(hs.Token, m.token, time.Now().Unix()) {
|
|
log.Printf("[relay] handshake denied: %s (TOTP mismatch)", hs.Node)
|
|
sendRelayResult(conn, 1, "auth failed")
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
log.Printf("[relay] handshake ok: session=%s role=%s node=%s", hs.SessionID, hs.Role, hs.Node)
|
|
|
|
// Find or create pending session
|
|
m.pMu.Lock()
|
|
ps, exists := m.pending[hs.SessionID]
|
|
if !exists {
|
|
ps = &pendingSession{
|
|
id: hs.SessionID,
|
|
token: hs.Token,
|
|
done: make(chan struct{}),
|
|
created: time.Now(),
|
|
}
|
|
m.pending[hs.SessionID] = ps
|
|
}
|
|
m.pMu.Unlock()
|
|
|
|
ps.mu.Lock()
|
|
switch hs.Role {
|
|
case "from":
|
|
ps.from = hs.Node
|
|
ps.connFrom = conn
|
|
case "to":
|
|
ps.to = hs.Node
|
|
ps.connTo = conn
|
|
default:
|
|
ps.mu.Unlock()
|
|
log.Printf("[relay] unknown role: %s", hs.Role)
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
// Check if both peers have arrived
|
|
bothReady := ps.connFrom != nil && ps.connTo != nil
|
|
ps.mu.Unlock()
|
|
|
|
if bothReady {
|
|
// Both peers connected — bridge them
|
|
m.pMu.Lock()
|
|
delete(m.pending, hs.SessionID)
|
|
m.pMu.Unlock()
|
|
|
|
sendRelayResult(ps.connFrom, 0, "ok")
|
|
sendRelayResult(ps.connTo, 0, "ok")
|
|
|
|
m.bridge(ps)
|
|
} else {
|
|
// Wait for the other peer
|
|
select {
|
|
case <-ps.done:
|
|
// Woken up by the other peer's arrival
|
|
case <-time.After(pairTimeout):
|
|
log.Printf("[relay] session %s timeout waiting for pair", hs.SessionID)
|
|
m.pMu.Lock()
|
|
delete(m.pending, hs.SessionID)
|
|
m.pMu.Unlock()
|
|
sendRelayResult(conn, 1, "pair timeout")
|
|
conn.Close()
|
|
case <-m.quit:
|
|
conn.Close()
|
|
}
|
|
}
|
|
}
|
|
|
|
// relayResult is sent back to each peer after handshake.
// Error 0 means the relay accepted the peer; any non-zero value is a denial
// with a human-readable reason in Detail.
type relayResult struct {
	Error  int    `json:"error"`
	Detail string `json:"detail,omitempty"`
}
|
|
|
|
func sendRelayResult(conn net.Conn, errCode int, detail string) {
|
|
data, _ := json.Marshal(relayResult{Error: errCode, Detail: detail})
|
|
length := uint32(len(data))
|
|
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
|
binary.Write(conn, binary.LittleEndian, length)
|
|
conn.Write(data)
|
|
conn.SetWriteDeadline(time.Time{})
|
|
}
|
|
|
|
func (m *Manager) bridge(ps *pendingSession) {
|
|
sess := &Session{
|
|
ID: ps.id,
|
|
From: ps.from,
|
|
To: ps.to,
|
|
ConnA: ps.connFrom,
|
|
ConnB: ps.connTo,
|
|
StartTime: time.Now(),
|
|
}
|
|
|
|
m.sMu.Lock()
|
|
m.sessions[ps.id] = sess
|
|
m.sMu.Unlock()
|
|
|
|
log.Printf("[relay] bridging %s ↔ %s (session %s)", ps.from, ps.to, ps.id)
|
|
|
|
go func() {
|
|
defer func() {
|
|
sess.Close()
|
|
m.sMu.Lock()
|
|
delete(m.sessions, ps.id)
|
|
m.sMu.Unlock()
|
|
log.Printf("[relay] session %s ended, %d bytes forwarded", ps.id, atomic.LoadInt64(&sess.BytesFwd))
|
|
}()
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
|
|
copyWithLimit := func(dst, src net.Conn) int64 {
|
|
if m.maxMbps <= 0 {
|
|
n, _ := io.Copy(dst, src)
|
|
return n
|
|
}
|
|
bytesPerSec := int64(m.maxMbps) * 1024 * 1024 / 8
|
|
if bytesPerSec < 1 {
|
|
bytesPerSec = 1
|
|
}
|
|
var total int64
|
|
buf := make([]byte, 32*1024)
|
|
ticker := time.NewTicker(100 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
var allowance = bytesPerSec / 10
|
|
for {
|
|
n, err := src.Read(buf)
|
|
if n > 0 {
|
|
// simple token bucket
|
|
if allowance < int64(n) {
|
|
<-ticker.C
|
|
allowance = bytesPerSec / 10
|
|
}
|
|
allowance -= int64(n)
|
|
w, _ := dst.Write(buf[:n])
|
|
total += int64(w)
|
|
}
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
return total
|
|
}
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
n := copyWithLimit(sess.ConnB, sess.ConnA)
|
|
atomic.AddInt64(&sess.BytesFwd, n)
|
|
}()
|
|
go func() {
|
|
defer wg.Done()
|
|
n := copyWithLimit(sess.ConnA, sess.ConnB)
|
|
atomic.AddInt64(&sess.BytesFwd, n)
|
|
}()
|
|
|
|
wg.Wait()
|
|
}()
|
|
}
|
|
|
|
func (m *Manager) cleanupLoop() {
|
|
ticker := time.NewTicker(30 * time.Second)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
m.pMu.Lock()
|
|
for id, ps := range m.pending {
|
|
if time.Since(ps.created) > pairTimeout {
|
|
delete(m.pending, id)
|
|
if ps.connFrom != nil {
|
|
ps.connFrom.Close()
|
|
}
|
|
if ps.connTo != nil {
|
|
ps.connTo.Close()
|
|
}
|
|
}
|
|
}
|
|
m.pMu.Unlock()
|
|
case <-m.quit:
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Close shuts down a session.
|
|
func (s *Session) Close() {
|
|
if !atomic.CompareAndSwapInt32(&s.closed, 0, 1) {
|
|
return
|
|
}
|
|
if s.ConnA != nil {
|
|
s.ConnA.Close()
|
|
}
|
|
if s.ConnB != nil {
|
|
s.ConnB.Close()
|
|
}
|
|
}
|
|
|
|
// Stop shuts down the relay manager.
|
|
func (m *Manager) Stop() {
|
|
close(m.quit)
|
|
if m.listener != nil {
|
|
m.listener.Close()
|
|
}
|
|
m.sMu.Lock()
|
|
for _, s := range m.sessions {
|
|
s.Close()
|
|
}
|
|
m.sMu.Unlock()
|
|
}
|
|
|
|
// ─── Client-side helper ───
|
|
|
|
// ConnectToRelay connects to a relay node and performs the handshake.
|
|
func ConnectToRelay(relayAddr string, sessionID, role, node string, token uint64) (net.Conn, error) {
|
|
conn, err := net.DialTimeout("tcp", relayAddr, 10*time.Second)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("dial relay %s: %w", relayAddr, err)
|
|
}
|
|
|
|
hs := RelayHandshake{
|
|
SessionID: sessionID,
|
|
Role: role,
|
|
Token: token,
|
|
Node: node,
|
|
}
|
|
data, _ := json.Marshal(hs)
|
|
|
|
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
|
|
length := uint32(len(data))
|
|
if err := binary.Write(conn, binary.LittleEndian, length); err != nil {
|
|
conn.Close()
|
|
return nil, err
|
|
}
|
|
if _, err := conn.Write(data); err != nil {
|
|
conn.Close()
|
|
return nil, err
|
|
}
|
|
|
|
// Read result
|
|
conn.SetReadDeadline(time.Now().Add(pairTimeout + 5*time.Second))
|
|
if err := binary.Read(conn, binary.LittleEndian, &length); err != nil {
|
|
conn.Close()
|
|
return nil, fmt.Errorf("read relay result: %w", err)
|
|
}
|
|
buf := make([]byte, length)
|
|
if _, err := io.ReadFull(conn, buf); err != nil {
|
|
conn.Close()
|
|
return nil, fmt.Errorf("read relay result body: %w", err)
|
|
}
|
|
conn.SetReadDeadline(time.Time{})
|
|
|
|
var result relayResult
|
|
json.Unmarshal(buf, &result)
|
|
if result.Error != 0 {
|
|
conn.Close()
|
|
return nil, fmt.Errorf("relay denied: %s", result.Detail)
|
|
}
|
|
|
|
log.Printf("[relay] connected to relay %s, session=%s role=%s", relayAddr, sessionID, role)
|
|
return conn, nil
|
|
}
|