package tunnel

import (
	"fmt"
	"io"
	"net"
	"testing"
	"time"
)

// TestEndToEndForward sends a single request through a complete tunnel and
// checks the reply produced by the target server.
//
// Data path: test client -> initiator's local listener -> net.Pipe ->
// responder -> target TCP server, and back again.
func TestEndToEndForward(t *testing.T) {
	// 1. Start a "target" TCP server (simulates SSH on the remote side).
	targetLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen for target server: %v", err)
	}
	defer targetLn.Close()
	targetPort := targetLn.Addr().(*net.TCPAddr).Port

	go func() {
		for {
			conn, err := targetLn.Accept()
			if err != nil {
				// Accept fails once the deferred targetLn.Close runs; that
				// is the signal for this goroutine to stop.
				return
			}
			go func(c net.Conn) {
				defer c.Close()
				buf := make([]byte, 1024)
				// Errors are deliberately ignored here: this goroutine
				// cannot fail the test itself, and a failed read or write
				// shows up as a wrong or missing echo on the client side.
				n, _ := c.Read(buf)
				_, _ = c.Write([]byte("ECHO:" + string(buf[:n])))
			}(conn)
		}
	}()

	// 2. Create a connected pair (simulates a P2P punch connection).
	c1, c2 := net.Pipe()

	// 3. Create tunnels on both sides.
	// NOTE(review): the meaning of the "test" and 0 arguments is not visible
	// from this file; the trailing bool appears to select the initiator
	// (true) vs. responder (false) role — confirm against New.
	initiator := New("remote-node", c1, "test", 0, true)
	responder := New("local-node", c2, "test", 0, false)
	defer initiator.Close()
	defer responder.Close()

	// Responder accepts incoming mux streams and connects to local targets.
	go responder.AcceptAndConnect()

	// 4. Initiator listens on a local port and forwards to the remote target.
	// ListenAndForward takes a port number, so an ephemeral port is reserved
	// and released first. Another process could grab it in between; that
	// small race is inherent to this approach.
	localLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("reserve local port: %v", err)
	}
	localPort := localLn.Addr().(*net.TCPAddr).Port
	localLn.Close() // free the port so the tunnel can use it

	if err := initiator.ListenAndForward("tcp", localPort, "127.0.0.1", targetPort); err != nil {
		t.Fatalf("ListenAndForward: %v", err)
	}
	time.Sleep(50 * time.Millisecond)

	// 5. Connect to the tunnel's local port.
	conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", localPort))
	if err != nil {
		t.Fatalf("dial tunnel: %v", err)
	}
	defer conn.Close()

	// 6. Send data and verify the echo.
	const want = "ECHO:hello-tunnel"
	if _, err := conn.Write([]byte("hello-tunnel")); err != nil {
		t.Fatalf("write through tunnel: %v", err)
	}
	if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
		t.Fatalf("set read deadline: %v", err)
	}

	// A single Read on a stream may return only part of the reply, so read
	// until the full expected length has arrived (or the deadline hits).
	buf := make([]byte, len(want))
	n, err := io.ReadFull(conn, buf)
	if err != nil {
		t.Fatalf("read echo: got %q before error: %v", buf[:n], err)
	}
	if got := string(buf); got != want {
		t.Errorf("got %q, want %q", got, want)
	}
}

// TestMultipleConnections opens several concurrent client connections through
// one tunnel to a pure echo server, checks that each connection gets its own
// message back, and then checks the initiator's connection counter.
func TestMultipleConnections(t *testing.T) {
	// Target server: echoes back exactly what it receives.
	targetLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen for target server: %v", err)
	}
	defer targetLn.Close()
	targetPort := targetLn.Addr().(*net.TCPAddr).Port

	go func() {
		for {
			conn, err := targetLn.Accept()
			if err != nil {
				return // listener closed at the end of the test
			}
			go func(c net.Conn) {
				defer c.Close()
				// Pure echo. The copy error is ignored on purpose: a broken
				// echo is detected by the client-side checks below.
				_, _ = io.Copy(c, c)
			}(conn)
		}
	}()

	c1, c2 := net.Pipe()
	initiator := New("peer", c1, "test", 0, true)
	responder := New("me", c2, "test", 0, false)
	defer initiator.Close()
	defer responder.Close()

	go responder.AcceptAndConnect()

	// Reserve an ephemeral port, then release it for the tunnel to bind.
	localLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("reserve local port: %v", err)
	}
	localPort := localLn.Addr().(*net.TCPAddr).Port
	localLn.Close()

	if err := initiator.ListenAndForward("tcp", localPort, "127.0.0.1", targetPort); err != nil {
		t.Fatalf("ListenAndForward: %v", err)
	}
	time.Sleep(50 * time.Millisecond)

	// Open 5 concurrent connections through the tunnel.
	const numConns = 5
	addr := fmt.Sprintf("127.0.0.1:%d", localPort)

	// exchange sends a per-connection message over conn and verifies that
	// the identical bytes come back. It returns an error tagged with idx.
	exchange := func(conn net.Conn, idx int) error {
		msg := fmt.Sprintf("msg-%d", idx)
		if _, err := conn.Write([]byte(msg)); err != nil {
			return fmt.Errorf("conn %d: write: %w", idx, err)
		}
		if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
			return fmt.Errorf("conn %d: set read deadline: %w", idx, err)
		}
		buf := make([]byte, len(msg))
		n, err := io.ReadFull(conn, buf)
		if err != nil {
			return fmt.Errorf("conn %d: read: got %q before error: %w", idx, buf[:n], err)
		}
		if got := string(buf); got != msg {
			return fmt.Errorf("conn %d: got %q, want %q", idx, got, msg)
		}
		return nil
	}

	// Buffered so that no worker blocks on send even if the collector is slow.
	errs := make(chan error, numConns)
	for i := 0; i < numConns; i++ {
		go func(idx int) {
			conn, err := net.Dial("tcp", addr)
			if err != nil {
				errs <- fmt.Errorf("conn %d: dial: %w", idx, err)
				return
			}
			defer conn.Close()
			errs <- exchange(conn, idx)
		}(i)
	}

	// Collect exactly one result per worker; failures are reported from the
	// test goroutine with the worker's own connection index.
	for i := 0; i < numConns; i++ {
		if err := <-errs; err != nil {
			t.Error(err)
		}
	}

	stats := initiator.GetStats()
	if stats.Connections != numConns {
		t.Errorf("connections: got %d want %d", stats.Connections, numConns)
	}
}

// TestTunnelStats checks liveness reporting: both ends of a freshly created
// tunnel pair must report IsAlive, and the initiator must stop reporting it
// shortly after Close. Despite its name, it only exercises IsAlive.
func TestTunnelStats(t *testing.T) {
	c1, c2 := net.Pipe()
	initiator := New("peer", c1, "test", 0, true)
	responder := New("me", c2, "test", 0, false)
	// NOTE(review): the initiator is closed explicitly below and again by
	// this defer, so the test relies on Close being safe to call twice —
	// confirm against the tunnel implementation.
	defer initiator.Close()
	defer responder.Close()

	if !initiator.IsAlive() || !responder.IsAlive() {
		t.Error("tunnels should be alive")
	}

	initiator.Close()
	// Give the tunnel a moment to register the shutdown before checking.
	time.Sleep(50 * time.Millisecond)

	if initiator.IsAlive() {
		t.Error("initiator should be dead after close")
	}
}