From dd06bc2dcef0a1bba3919a33f7cc8014c562c693 Mon Sep 17 00:00:00 2001 From: Alessio Bianchini Date: Tue, 31 Dec 2024 12:35:07 +0100 Subject: [PATCH] Improved keep alive logic --- Tunnelize/Services/TunnelManager.cs | 96 +++++++++++++++++++++++++++-- TunnelizeClient/src/tunnelize.js | 23 +++---- 2 files changed, 98 insertions(+), 21 deletions(-) diff --git a/Tunnelize/Services/TunnelManager.cs b/Tunnelize/Services/TunnelManager.cs index b9377f0..2acab42 100644 --- a/Tunnelize/Services/TunnelManager.cs +++ b/Tunnelize/Services/TunnelManager.cs @@ -1,9 +1,11 @@ -using System.Net.WebSockets; +using System; +using System.Net.WebSockets; using System.Text; public class TunnelManager { private readonly Dictionary _tunnels = new(); + private readonly Dictionary _lastActivityTracker = new(); // Track last activity per tunnel public async Task HandleTunnelConnection(string tunnelId, WebSocket webSocket) { @@ -12,8 +14,76 @@ public async Task HandleTunnelConnection(string tunnelId, WebSocket webSocket) if (!_tunnels.ContainsKey(tunnelId)) { _tunnels[tunnelId] = webSocket; + _lastActivityTracker[tunnelId] = DateTime.UtcNow; } + var buffer = new byte[1024 * 1024]; + var cts = new CancellationTokenSource(); + var keepAliveInterval = TimeSpan.FromMinutes(20); + + _ = Task.Run(async () => + { + while (webSocket.State == WebSocketState.Open) + { + try + { + if (_lastActivityTracker.ContainsKey(tunnelId) && DateTime.UtcNow - _lastActivityTracker[tunnelId] >= keepAliveInterval) + { + Console.WriteLine($"[DEBUG] No activity for 5 minutes. Sending ping to client for tunnel {tunnelId}"); + await webSocket.SendAsync(new ArraySegment(Encoding.UTF8.GetBytes("ping")), WebSocketMessageType.Text, true, CancellationToken.None); + + try + { + while (true) + { + var result = await webSocket.ReceiveAsync(new ArraySegment(buffer), cts.Token); + + if (result.MessageType == WebSocketMessageType.Close) + { + await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Client closed connection", CancellationToken.None); + throw new Exception($"WebSocket connection closed by client for tunnel {tunnelId}"); + } + + if (result.MessageType == WebSocketMessageType.Text) + { + var message = Encoding.UTF8.GetString(buffer, 0, result.Count); + + if (message == "pong") + { + Console.WriteLine($"[DEBUG] Pong received from client for tunnel {tunnelId}"); + } + else + { + Console.WriteLine($"[DEBUG] Received message for tunnel {tunnelId}: {message}"); + } + + if (_lastActivityTracker.ContainsKey(tunnelId)) + { + _lastActivityTracker[tunnelId] = DateTime.UtcNow; + } + } + + if (result.EndOfMessage) + { + break; + } + } + } + catch (OperationCanceledException) + { + throw new Exception("Timeout while waiting for pong from WebSocket client."); + } + } + await Task.Delay(TimeSpan.FromSeconds(30), cts.Token); + } + catch (Exception ex) + { + Console.WriteLine($"[ERROR] Ping failed for tunnel {tunnelId}: {ex.Message}"); + break; + } + } + }); + try { while (webSocket.State == WebSocketState.Open) @@ -26,11 +96,18 @@ public async Task HandleTunnelConnection(string tunnelId, WebSocket webSocket) } finally { + cts.Cancel(); + if (_tunnels.ContainsKey(tunnelId)) { _tunnels.Remove(tunnelId); Console.WriteLine($"[INFO] Tunnel {tunnelId} removed."); } + + if (_lastActivityTracker.ContainsKey(tunnelId)) + { + _lastActivityTracker.Remove(tunnelId); + } } } @@ -45,6 +122,11 @@ public async Task ForwardRequestToClient(string tunnelId, string message) var buffer = Encoding.UTF8.GetBytes(message); await webSocket.SendAsync(new ArraySegment(buffer), WebSocketMessageType.Text, true, CancellationToken.None); + + if (_lastActivityTracker.ContainsKey(tunnelId)) + { + _lastActivityTracker[tunnelId] = DateTime.UtcNow; + } } public async Task ForwardRequestToWSClient(string tunnelId, string message) @@ -59,9 +141,14 @@ public async Task ForwardRequestToWSClient(string tunnelId, string messa await webSocket.SendAsync(new ArraySegment(buffer), WebSocketMessageType.Text, true, CancellationToken.None); + if (_lastActivityTracker.ContainsKey(tunnelId)) + { + _lastActivityTracker[tunnelId] = DateTime.UtcNow; + } + var responseBuffer = new byte[1024 * 1024 * 5]; - var cts = new CancellationTokenSource(TimeSpan.FromSeconds(60)); - var completeResponse = new List(); + var cts = new CancellationTokenSource(TimeSpan.FromSeconds(60)); + var completeResponse = new List(); try { @@ -91,5 +178,4 @@ public async Task ForwardRequestToWSClient(string tunnelId, string messa var jsonResponse = Encoding.UTF8.GetString(completeResponse.ToArray()); return jsonResponse; } - -} \ No newline at end of file +} diff --git a/TunnelizeClient/src/tunnelize.js b/TunnelizeClient/src/tunnelize.js index 46722f4..473c491 100644 --- a/TunnelizeClient/src/tunnelize.js +++ b/TunnelizeClient/src/tunnelize.js @@ -15,24 +15,9 @@ function connectToWebSocket(protocol, port, tunnelId = null) { const MAX_BUFFER_SIZE = 1024 * 1024 * 5; let wssUrl = !!tunnelId ? `wss://${url}/ws/${tunnelId}` : `wss://${url}/ws`; const ws = new WebSocket(wssUrl, { maxPayload: MAX_BUFFER_SIZE }); - let isAlive = true; - + ws.on('open', () => { console.info('[INFO] Connection established with the proxy'); - setInterval(() => { - console.info(`[DEBUG] Sending ping. Connection alive: ${isAlive}`); - if (isAlive) { - ws.ping(); - isAlive = false; - } else { - console.error('[ERROR] WebSocket connection appears to be dead. Reconnecting...'); - ws.terminate(); - } - }, 60000); - }); - - ws.on('pong', () => { - isAlive = true; }); ws.on('close', () => { @@ -47,6 +32,12 @@ function connectToWebSocket(protocol, port, tunnelId = null) { ws.on('message', (data) => { const message = data.toString('utf8'); + if (message === 'ping') { + console.debug('[DEBUG] Ping received from proxy. Sending pong...'); + ws.send('pong', { fin: true }); + return; + } + if (isTunnelId(message)) { console.log(`\n✅ Tunnel ID received: ${message}\nYou can use now https://${url}/${message}/*?param=abc`); } else {