diff --git a/src/OpenClaw.Shared/OpenClawGatewayClient.cs b/src/OpenClaw.Shared/OpenClawGatewayClient.cs index 411ae5e..0e21836 100644 --- a/src/OpenClaw.Shared/OpenClawGatewayClient.cs +++ b/src/OpenClaw.Shared/OpenClawGatewayClient.cs @@ -1,27 +1,13 @@ using System; using System.Collections.Generic; -using System.IO; -using System.Net.WebSockets; -using System.Text; using System.Text.Json; using System.Threading; using System.Threading.Tasks; namespace OpenClaw.Shared; -public class OpenClawGatewayClient : IDisposable +public class OpenClawGatewayClient : WebSocketClientBase { - private ClientWebSocket? _webSocket; - private readonly string _gatewayUrl; - private readonly string _gatewayUrlForDisplay; - private readonly string _token; - private readonly string? _credentials; - private readonly IOpenClawLogger _logger; - private CancellationTokenSource _cts; - private bool _disposed; - private int _reconnectAttempts; - private static readonly int[] BackoffMs = { 1000, 2000, 4000, 8000, 15000, 30000, 60000 }; - // Tracked state private readonly Dictionary _sessions = new(); private readonly Dictionary _nodes = new(); @@ -45,8 +31,32 @@ private void ResetUnsupportedMethodFlags() _nodeListUnsupported = false; } + protected override int ReceiveBufferSize => 16384; + protected override string ClientRole => "gateway"; + + protected override Task ProcessMessageAsync(string json) + { + ProcessMessage(json); + return Task.CompletedTask; + } + + protected override Task OnConnectedAsync() + { + ResetUnsupportedMethodFlags(); + return Task.CompletedTask; + } + + protected override void OnDisconnected() + { + ClearPendingRequests(); + } + + protected override void OnDisposing() + { + ClearPendingRequests(); + } + // Events - public event EventHandler? StatusChanged; public event EventHandler? NotificationReceived; public event EventHandler? ActivityChanged; public event EventHandler? ChannelHealthUpdated; @@ -59,63 +69,17 @@ private void ResetUnsupportedMethodFlags() public event EventHandler? SessionCommandCompleted; public OpenClawGatewayClient(string gatewayUrl, string token, IOpenClawLogger? logger = null) + : base(gatewayUrl, token, logger) { - _gatewayUrl = GatewayUrlHelper.NormalizeForWebSocket(gatewayUrl); - _gatewayUrlForDisplay = GatewayUrlHelper.SanitizeForDisplay(_gatewayUrl); - _token = token; - _credentials = GatewayUrlHelper.ExtractCredentials(gatewayUrl); - _logger = logger ?? NullLogger.Instance; - _cts = new CancellationTokenSource(); - } - - public async Task ConnectAsync() - { - try - { - StatusChanged?.Invoke(this, ConnectionStatus.Connecting); - _logger.Info($"Connecting to gateway: {_gatewayUrlForDisplay}"); - - _webSocket = new ClientWebSocket(); - _webSocket.Options.KeepAliveInterval = TimeSpan.FromSeconds(30); - - // Set Origin header based on gateway URL (convert ws/wss to http/https) - var uri = new Uri(_gatewayUrl); - var originScheme = uri.Scheme == "wss" ? "https" : "http"; - var origin = $"{originScheme}://{uri.Host}:{uri.Port}"; - _webSocket.Options.SetRequestHeader("Origin", origin); - - if (!string.IsNullOrEmpty(_credentials)) - { - var credentialsToEncode = GatewayUrlHelper.DecodeCredentials(_credentials); - - _webSocket.Options.SetRequestHeader( - "Authorization", - $"Basic {Convert.ToBase64String(Encoding.UTF8.GetBytes(credentialsToEncode))}"); - } - - await _webSocket.ConnectAsync(uri, _cts.Token); - - ResetUnsupportedMethodFlags(); - _reconnectAttempts = 0; - _logger.Info("Gateway connected, waiting for challenge..."); - - // Don't send connect yet - wait for challenge event in ListenForMessagesAsync - _ = Task.Run(() => ListenForMessagesAsync(), _cts.Token); - } - catch (Exception ex) - { - _logger.Error("Connection failed", ex); - StatusChanged?.Invoke(this, ConnectionStatus.Error); - } } public async Task DisconnectAsync() { - if (_webSocket?.State == WebSocketState.Open) + if (IsConnected) { try { - await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Disconnecting", CancellationToken.None); + await CloseWebSocketAsync(); } catch (Exception ex) { @@ -123,13 +87,13 @@ public async Task DisconnectAsync() } } ClearPendingRequests(); - StatusChanged?.Invoke(this, ConnectionStatus.Disconnected); + RaiseStatusChanged(ConnectionStatus.Disconnected); _logger.Info("Disconnected"); } public async Task CheckHealthAsync() { - if (_webSocket?.State != WebSocketState.Open) + if (!IsConnected) { await ReconnectWithBackoffAsync(); return; @@ -149,14 +113,14 @@ public async Task CheckHealthAsync() catch (Exception ex) { _logger.Error("Health check failed", ex); - StatusChanged?.Invoke(this, ConnectionStatus.Error); + RaiseStatusChanged(ConnectionStatus.Error); await ReconnectWithBackoffAsync(); } } public async Task SendChatMessageAsync(string message) { - if (_webSocket?.State != WebSocketState.Open) + if (!IsConnected) throw new InvalidOperationException("Gateway connection is not open"); var req = new @@ -179,7 +143,7 @@ public async Task RequestSessionsAsync() /// Request usage/context info from gateway (may not be supported on all gateways). public async Task RequestUsageAsync() { - if (_webSocket?.State != WebSocketState.Open) return; + if (!IsConnected) return; try { if (_usageStatusUnsupported) @@ -270,7 +234,7 @@ public Task CompactSessionAsync(string key, int maxLines = 400) /// Start a channel (telegram, whatsapp, etc). public async Task StartChannelAsync(string channelName) { - if (_webSocket?.State != WebSocketState.Open) return false; + if (!IsConnected) return false; try { var req = new @@ -294,7 +258,7 @@ public async Task StartChannelAsync(string channelName) /// Stop a channel (telegram, whatsapp, etc). public async Task StopChannelAsync(string channelName) { - if (_webSocket?.State != WebSocketState.Open) return false; + if (!IsConnected) return false; try { var req = new @@ -315,31 +279,6 @@ public async Task StopChannelAsync(string channelName) } } - // --- Connection management --- - - private async Task ReconnectWithBackoffAsync() - { - var delay = BackoffMs[Math.Min(_reconnectAttempts, BackoffMs.Length - 1)]; - _reconnectAttempts++; - _logger.Warn($"Reconnecting in {delay}ms (attempt {_reconnectAttempts})"); - StatusChanged?.Invoke(this, ConnectionStatus.Connecting); - - try - { - await Task.Delay(delay, _cts.Token); - _webSocket?.Dispose(); - _webSocket = null; - await ConnectAsync(); - } - catch (OperationCanceledException) { } - catch (Exception ex) - { - _logger.Error("Reconnect failed", ex); - StatusChanged?.Invoke(this, ConnectionStatus.Error); - // Don't recurse — the listen loop will trigger reconnect again - } - } - private async Task SendConnectMessageAsync(string? nonce = null) { // Use "cli" client ID for native apps - no browser security checks @@ -373,31 +312,9 @@ private async Task SendConnectMessageAsync(string? nonce = null) await SendRawAsync(JsonSerializer.Serialize(msg)); } - private async Task SendRawAsync(string message) - { - // Capture local reference to avoid TOCTOU race with reconnect/dispose - var ws = _webSocket; - if (ws?.State != WebSocketState.Open) return; - - try - { - var bytes = Encoding.UTF8.GetBytes(message); - await ws.SendAsync(new ArraySegment(bytes), - WebSocketMessageType.Text, true, _cts.Token); - } - catch (ObjectDisposedException) - { - // WebSocket was disposed between state check and send - } - catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.InvalidState) - { - _logger.Warn($"WebSocket send failed (state changed): {ex.Message}"); - } - } - private async Task SendTrackedRequestAsync(string method, object? parameters = null) { - if (_webSocket?.State != WebSocketState.Open) return; + if (!IsConnected) return; var requestId = Guid.NewGuid().ToString(); TrackPendingRequest(requestId, method); @@ -482,60 +399,6 @@ private void ClearPendingRequests() } } - // --- Message loop --- - - private async Task ListenForMessagesAsync() - { - var buffer = new byte[16384]; // Larger buffer for big events - var sb = new StringBuilder(); - - try - { - while (_webSocket?.State == WebSocketState.Open && !_cts.Token.IsCancellationRequested) - { - var result = await _webSocket.ReceiveAsync( - new ArraySegment(buffer), _cts.Token); - - if (result.MessageType == WebSocketMessageType.Text) - { - sb.Append(Encoding.UTF8.GetString(buffer, 0, result.Count)); - if (result.EndOfMessage) - { - ProcessMessage(sb.ToString()); - sb.Clear(); - } - } - else if (result.MessageType == WebSocketMessageType.Close) - { - var closeStatus = _webSocket.CloseStatus?.ToString() ?? "unknown"; - var closeDesc = _webSocket.CloseStatusDescription ?? "no description"; - _logger.Info($"Server closed connection: {closeStatus} - {closeDesc}"); - ClearPendingRequests(); - StatusChanged?.Invoke(this, ConnectionStatus.Disconnected); - break; - } - } - } - catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely) - { - _logger.Warn("Connection closed prematurely"); - ClearPendingRequests(); - StatusChanged?.Invoke(this, ConnectionStatus.Disconnected); - } - catch (OperationCanceledException) { } - catch (Exception ex) - { - _logger.Error("Listen error", ex); - StatusChanged?.Invoke(this, ConnectionStatus.Error); - } - - // Auto-reconnect if not intentionally disposed - if (!_disposed && !_cts.Token.IsCancellationRequested) - { - await ReconnectWithBackoffAsync(); - } - } - // --- Message processing --- private void ProcessMessage(string json) @@ -594,7 +457,7 @@ private void HandleResponse(JsonElement root) if (payload.TryGetProperty("type", out var t) && t.GetString() == "hello-ok") { _logger.Info("Handshake complete (hello-ok)"); - StatusChanged?.Invoke(this, ConnectionStatus.Connected); + RaiseStatusChanged(ConnectionStatus.Connected); // Request initial state after handshake _ = Task.Run(async () => @@ -1738,21 +1601,4 @@ private static string TruncateLabel(string text, int maxLen = 60) if (string.IsNullOrEmpty(text) || text.Length <= maxLen) return text; return text[..(maxLen - 1)] + "…"; } - - public void Dispose() - { - if (_disposed) return; - _disposed = true; - - try { _cts.Cancel(); } catch { } - - ClearPendingRequests(); - - var ws = _webSocket; - _webSocket = null; - try { ws?.Dispose(); } catch { } - - // Don't dispose _cts immediately — listen loop or reconnect may still reference it. - // It will be GC'd after all pending tasks complete. - } } diff --git a/src/OpenClaw.Shared/WebSocketClientBase.cs b/src/OpenClaw.Shared/WebSocketClientBase.cs new file mode 100644 index 0000000..72c4d10 --- /dev/null +++ b/src/OpenClaw.Shared/WebSocketClientBase.cs @@ -0,0 +1,269 @@ +using System; +using System.Net.WebSockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace OpenClaw.Shared; + +/// +/// Abstract base class for WebSocket-based gateway clients. +/// Extracts shared connection lifecycle: connect, listen, reconnect, send, dispose. +/// Subclasses implement message processing and provide configuration via abstract members. +/// +public abstract class WebSocketClientBase : IDisposable +{ + private ClientWebSocket? _webSocket; + private readonly string _gatewayUrl; + private readonly string? _credentials; + private CancellationTokenSource _cts; + private bool _disposed; + private int _reconnectAttempts; + private static readonly int[] BackoffMs = { 1000, 2000, 4000, 8000, 15000, 30000, 60000 }; + + protected readonly string _token; + protected readonly IOpenClawLogger _logger; + + /// Gateway URL with credentials stripped, safe for logging/display. + protected string GatewayUrlForDisplay { get; } + + /// Whether Dispose has been called. + protected bool IsDisposed => _disposed; + + /// Whether the WebSocket is currently open and connected. + protected bool IsConnected => _webSocket?.State == WebSocketState.Open; + + /// Cancellation token tied to this client's lifetime. + protected CancellationToken CancellationToken => _cts.Token; + + // Events + public event EventHandler? StatusChanged; + + // --- Abstract members (subclass MUST implement) --- + + /// + /// Process a received WebSocket text message. Called from the listen loop. + /// Gateway wraps its sync ProcessMessage with Task.CompletedTask; + /// Node directly uses its async implementation. + /// + protected abstract Task ProcessMessageAsync(string json); + + /// Receive buffer size in bytes. Gateway: 16384, Node: 65536. + protected abstract int ReceiveBufferSize { get; } + + /// Client role for log messages, e.g. "gateway" or "node". + protected abstract string ClientRole { get; } + + // --- Virtual hooks (subclass MAY override) --- + + /// Called after WebSocket connects, before the listen loop starts. + protected virtual Task OnConnectedAsync() => Task.CompletedTask; + + /// Called when the server closes the connection or it drops. + protected virtual void OnDisconnected() { } + + /// Called on unrecoverable listen-loop errors. + protected virtual void OnError(Exception ex) { } + + /// Called at the start of Dispose, before CTS cancellation. + protected virtual void OnDisposing() { } + + protected WebSocketClientBase(string gatewayUrl, string token, IOpenClawLogger? logger = null) + { + if (string.IsNullOrEmpty(gatewayUrl)) + throw new ArgumentException("Gateway URL is required.", nameof(gatewayUrl)); + if (string.IsNullOrEmpty(token)) + throw new ArgumentException("Token is required.", nameof(token)); + + _gatewayUrl = GatewayUrlHelper.NormalizeForWebSocket(gatewayUrl); + GatewayUrlForDisplay = GatewayUrlHelper.SanitizeForDisplay(_gatewayUrl); + _token = token; + _credentials = GatewayUrlHelper.ExtractCredentials(gatewayUrl); + _logger = logger ?? NullLogger.Instance; + _cts = new CancellationTokenSource(); + } + + public async Task ConnectAsync() + { + try + { + RaiseStatusChanged(ConnectionStatus.Connecting); + _logger.Info($"Connecting to {ClientRole}: {GatewayUrlForDisplay}"); + + _webSocket = new ClientWebSocket(); + _webSocket.Options.KeepAliveInterval = TimeSpan.FromSeconds(30); + + // Set Origin header (convert ws/wss to http/https) + var uri = new Uri(_gatewayUrl); + var originScheme = uri.Scheme == "wss" ? "https" : "http"; + var origin = $"{originScheme}://{uri.Host}:{uri.Port}"; + _webSocket.Options.SetRequestHeader("Origin", origin); + + if (!string.IsNullOrEmpty(_credentials)) + { + var credentialsToEncode = GatewayUrlHelper.DecodeCredentials(_credentials); + _webSocket.Options.SetRequestHeader( + "Authorization", + $"Basic {Convert.ToBase64String(Encoding.UTF8.GetBytes(credentialsToEncode))}"); + } + + await _webSocket.ConnectAsync(uri, _cts.Token); + + _reconnectAttempts = 0; + _logger.Info($"{ClientRole} connected, waiting for challenge..."); + + await OnConnectedAsync(); + + _ = Task.Run(() => ListenForMessagesAsync(), _cts.Token); + } + catch (Exception ex) + { + _logger.Error($"{ClientRole} connection failed", ex); + RaiseStatusChanged(ConnectionStatus.Error); + } + } + + private async Task ListenForMessagesAsync() + { + var buffer = new byte[ReceiveBufferSize]; + var sb = new StringBuilder(); + + try + { + while (_webSocket?.State == WebSocketState.Open && !_cts.Token.IsCancellationRequested) + { + var result = await _webSocket.ReceiveAsync( + new ArraySegment(buffer), _cts.Token); + + if (result.MessageType == WebSocketMessageType.Text) + { + sb.Append(Encoding.UTF8.GetString(buffer, 0, result.Count)); + if (result.EndOfMessage) + { + await ProcessMessageAsync(sb.ToString()); + sb.Clear(); + } + } + else if (result.MessageType == WebSocketMessageType.Close) + { + var closeStatus = _webSocket.CloseStatus?.ToString() ?? "unknown"; + var closeDesc = _webSocket.CloseStatusDescription ?? "no description"; + _logger.Info($"Server closed connection: {closeStatus} - {closeDesc}"); + OnDisconnected(); + RaiseStatusChanged(ConnectionStatus.Disconnected); + break; + } + } + } + catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely) + { + _logger.Warn("Connection closed prematurely"); + OnDisconnected(); + RaiseStatusChanged(ConnectionStatus.Disconnected); + } + catch (OperationCanceledException) { } + catch (ObjectDisposedException) { /* CTS or WebSocket disposed during shutdown */ } + catch (Exception ex) + { + _logger.Error($"{ClientRole} listen error", ex); + OnError(ex); + RaiseStatusChanged(ConnectionStatus.Error); + } + + // Auto-reconnect if not intentionally disposed + if (!_disposed) + { + try + { + if (!_cts.Token.IsCancellationRequested) + { + await ReconnectWithBackoffAsync(); + } + } + catch (ObjectDisposedException) { /* CTS disposed during check */ } + } + } + + protected async Task ReconnectWithBackoffAsync() + { + var delay = BackoffMs[Math.Min(_reconnectAttempts, BackoffMs.Length - 1)]; + _reconnectAttempts++; + _logger.Warn($"{ClientRole} reconnecting in {delay}ms (attempt {_reconnectAttempts})"); + RaiseStatusChanged(ConnectionStatus.Connecting); + + try + { + await Task.Delay(delay, _cts.Token); + + // Check cancellation after delay + if (_cts.Token.IsCancellationRequested) return; + + // Safely dispose old socket + var oldSocket = _webSocket; + _webSocket = null; + try { oldSocket?.Dispose(); } catch { /* ignore dispose errors */ } + + await ConnectAsync(); + } + catch (OperationCanceledException) { } + catch (Exception ex) + { + _logger.Error($"{ClientRole} reconnect failed", ex); + RaiseStatusChanged(ConnectionStatus.Error); + } + } + + /// Send a text message over the WebSocket. Thread-safe. + protected async Task SendRawAsync(string message) + { + // Capture local reference to avoid TOCTOU race with reconnect/dispose + var ws = _webSocket; + if (ws?.State != WebSocketState.Open) return; + + try + { + var bytes = Encoding.UTF8.GetBytes(message); + await ws.SendAsync(new ArraySegment(bytes), + WebSocketMessageType.Text, true, _cts.Token); + } + catch (ObjectDisposedException) + { + // WebSocket was disposed between state check and send + } + catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.InvalidState) + { + _logger.Warn($"WebSocket send failed (state changed): {ex.Message}"); + } + } + + /// Gracefully close the WebSocket connection. + protected async Task CloseWebSocketAsync() + { + var ws = _webSocket; + if (ws?.State == WebSocketState.Open) + { + await ws.CloseAsync(WebSocketCloseStatus.NormalClosure, "Disconnecting", System.Threading.CancellationToken.None); + } + } + + /// Fire the StatusChanged event. Use this instead of directly invoking the event. + protected void RaiseStatusChanged(ConnectionStatus status) + => StatusChanged?.Invoke(this, status); + + public void Dispose() + { + if (_disposed) return; + _disposed = true; + + OnDisposing(); + + try { _cts.Cancel(); } catch { } + + var ws = _webSocket; + _webSocket = null; + try { ws?.Dispose(); } catch { } + + // Don't dispose _cts immediately — listen loop or reconnect may still reference it. + // It will be GC'd after all pending tasks complete. + } +} diff --git a/src/OpenClaw.Shared/WindowsNodeClient.cs b/src/OpenClaw.Shared/WindowsNodeClient.cs index 6f6bc3b..0f02f82 100644 --- a/src/OpenClaw.Shared/WindowsNodeClient.cs +++ b/src/OpenClaw.Shared/WindowsNodeClient.cs @@ -1,10 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Net.WebSockets; -using System.Text; using System.Text.Json; -using System.Threading; using System.Threading.Tasks; namespace OpenClaw.Shared; @@ -13,19 +10,9 @@ namespace OpenClaw.Shared; /// Windows Node client - extends gateway connection to act as a node /// Supports both operator (existing) and node (new) roles /// -public class WindowsNodeClient : IDisposable +public class WindowsNodeClient : WebSocketClientBase { - private ClientWebSocket? _webSocket; - private readonly string _gatewayUrl; - private readonly string _gatewayUrlForDisplay; - private readonly string _token; - private readonly string? _credentials; - private readonly IOpenClawLogger _logger; private readonly DeviceIdentity _deviceIdentity; - private CancellationTokenSource _cts; - private bool _disposed; - private int _reconnectAttempts; - private static readonly int[] BackoffMs = { 1000, 2000, 4000, 8000, 15000, 30000, 60000 }; // Node capabilities registry private readonly List _capabilities = new(); @@ -38,13 +25,12 @@ public class WindowsNodeClient : IDisposable private bool _isPendingApproval; // True when connected but awaiting pairing approval // Events - public event EventHandler? StatusChanged; public event EventHandler? InvokeReceived; public event EventHandler? PairingStatusChanged; - public bool IsConnected => _isConnected; + public new bool IsConnected => _isConnected; public string? NodeId => _nodeId; - public string GatewayUrl => _gatewayUrlForDisplay; + public string GatewayUrl => GatewayUrlForDisplay; public IReadOnlyList Capabilities => _capabilities; /// True if connected but waiting for pairing approval on gateway @@ -61,15 +47,12 @@ public class WindowsNodeClient : IDisposable /// Full device ID for approval command public string FullDeviceId => _deviceIdentity.DeviceId; + protected override int ReceiveBufferSize => 65536; + protected override string ClientRole => "node"; + public WindowsNodeClient(string gatewayUrl, string token, string dataPath, IOpenClawLogger? logger = null) + : base(gatewayUrl, token, logger) { - _gatewayUrl = GatewayUrlHelper.NormalizeForWebSocket(gatewayUrl); - _gatewayUrlForDisplay = GatewayUrlHelper.SanitizeForDisplay(_gatewayUrl); - _token = token; - _credentials = GatewayUrlHelper.ExtractCredentials(gatewayUrl); - _logger = logger ?? NullLogger.Instance; - _cts = new CancellationTokenSource(); - // Initialize device identity _deviceIdentity = new DeviceIdentity(dataPath, _logger); _deviceIdentity.Initialize(); @@ -77,7 +60,7 @@ public WindowsNodeClient(string gatewayUrl, string token, string dataPath, IOpen // Initialize registration _registration = new NodeRegistration { - Id = _deviceIdentity.DeviceId, // Use device ID from keypair + Id = _deviceIdentity.DeviceId, Version = "1.0.0", Platform = "windows", DisplayName = $"Windows Node ({Environment.MachineName})" @@ -115,132 +98,21 @@ public void SetPermission(string permission, bool value) _registration.Permissions[permission] = value; } - /// - /// Connect to gateway as a node - /// - public async Task ConnectAsync() - { - try - { - StatusChanged?.Invoke(this, ConnectionStatus.Connecting); - _logger.Info($"Connecting to gateway as node: {_gatewayUrlForDisplay}"); - - _webSocket = new ClientWebSocket(); - _webSocket.Options.KeepAliveInterval = TimeSpan.FromSeconds(30); - - // Set Origin header - var uri = new Uri(_gatewayUrl); - var originScheme = uri.Scheme == "wss" ? "https" : "http"; - var origin = $"{originScheme}://{uri.Host}:{uri.Port}"; - _webSocket.Options.SetRequestHeader("Origin", origin); - - if (!string.IsNullOrEmpty(_credentials)) - { - var authCredentials = GatewayUrlHelper.DecodeCredentials(_credentials); - - _webSocket.Options.SetRequestHeader( - "Authorization", - $"Basic {Convert.ToBase64String(Encoding.UTF8.GetBytes(authCredentials))}"); - } - - await _webSocket.ConnectAsync(uri, _cts.Token); - - _reconnectAttempts = 0; - _logger.Info("Node connected, waiting for challenge..."); - - // Start message loop - _ = Task.Run(() => ListenForMessagesAsync(), _cts.Token); - } - catch (Exception ex) - { - _logger.Error("Node connection failed", ex); - StatusChanged?.Invoke(this, ConnectionStatus.Error); - } - } - /// /// Disconnect from gateway /// - public async Task DisconnectAsync() + public Task DisconnectAsync() { _isConnected = false; - if (_webSocket?.State == WebSocketState.Open) - { - try - { - await _webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Disconnecting", CancellationToken.None); - } - catch (Exception ex) - { - _logger.Warn($"Error during disconnect: {ex.Message}"); - } - } - StatusChanged?.Invoke(this, ConnectionStatus.Disconnected); + Dispose(); + RaiseStatusChanged(ConnectionStatus.Disconnected); _logger.Info("Node disconnected"); + return Task.CompletedTask; } // --- Message handling --- - private async Task ListenForMessagesAsync() - { - var buffer = new byte[65536]; // Large buffer for image data - var sb = new StringBuilder(); - - try - { - while (_webSocket?.State == WebSocketState.Open && !_cts.Token.IsCancellationRequested) - { - var result = await _webSocket.ReceiveAsync( - new ArraySegment(buffer), _cts.Token); - - if (result.MessageType == WebSocketMessageType.Text) - { - sb.Append(Encoding.UTF8.GetString(buffer, 0, result.Count)); - if (result.EndOfMessage) - { - await ProcessMessageAsync(sb.ToString()); - sb.Clear(); - } - } - else if (result.MessageType == WebSocketMessageType.Close) - { - _logger.Info("Server closed connection"); - _isConnected = false; - StatusChanged?.Invoke(this, ConnectionStatus.Disconnected); - break; - } - } - } - catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely) - { - _logger.Warn("Connection closed prematurely"); - _isConnected = false; - StatusChanged?.Invoke(this, ConnectionStatus.Disconnected); - } - catch (OperationCanceledException) { } - catch (ObjectDisposedException) { /* CTS was disposed */ } - catch (Exception ex) - { - _logger.Error("Node listen error", ex); - _isConnected = false; - StatusChanged?.Invoke(this, ConnectionStatus.Error); - } - - // Auto-reconnect (with extra safety checks) - if (!_disposed) - { - try - { - if (!_cts.Token.IsCancellationRequested) - { - await ReconnectWithBackoffAsync(); - } - } - catch (ObjectDisposedException) { /* CTS was disposed during check */ } - } - } - - private async Task ProcessMessageAsync(string json) + protected override async Task ProcessMessageAsync(string json) { try { @@ -618,7 +490,7 @@ private void HandleResponse(JsonElement root) _deviceIdentity.DeviceId)); } - StatusChanged?.Invoke(this, ConnectionStatus.Connected); + RaiseStatusChanged(ConnectionStatus.Connected); } // Handle errors @@ -638,7 +510,7 @@ private void HandleResponse(JsonElement root) } } _logger.Error($"Node registration failed: {error} (code: {errorCode})"); - StatusChanged?.Invoke(this, ConnectionStatus.Error); + RaiseStatusChanged(ConnectionStatus.Error); } } @@ -790,70 +662,13 @@ private async Task SendPongAsync(string? requestId) await SendRawAsync(JsonSerializer.Serialize(msg)); } - private async Task SendRawAsync(string message) - { - // Capture local reference to avoid race conditions - var ws = _webSocket; - if (ws?.State != WebSocketState.Open) return; - - try - { - var bytes = Encoding.UTF8.GetBytes(message); - await ws.SendAsync(new ArraySegment(bytes), - WebSocketMessageType.Text, true, _cts.Token); - } - catch (ObjectDisposedException) - { - // WebSocket was disposed between check and send - ignore - } - catch (WebSocketException ex) when (ex.WebSocketErrorCode == WebSocketError.InvalidState) - { - // WebSocket state changed - ignore - _logger.Warn($"WebSocket send failed (state changed): {ex.Message}"); - } - } - - private async Task ReconnectWithBackoffAsync() + protected override void OnDisconnected() { - var delay = BackoffMs[Math.Min(_reconnectAttempts, BackoffMs.Length - 1)]; - _reconnectAttempts++; - _logger.Warn($"Node reconnecting in {delay}ms (attempt {_reconnectAttempts})"); - StatusChanged?.Invoke(this, ConnectionStatus.Connecting); - - try - { - await Task.Delay(delay, _cts.Token); - - // Check cancellation after delay - if (_cts.Token.IsCancellationRequested) return; - - // Safely dispose old socket - var oldSocket = _webSocket; - _webSocket = null; - try { oldSocket?.Dispose(); } catch { /* ignore dispose errors */ } - - await ConnectAsync(); - } - catch (OperationCanceledException) { } - catch (Exception ex) - { - _logger.Error("Node reconnect failed", ex); - StatusChanged?.Invoke(this, ConnectionStatus.Error); - } + _isConnected = false; } - - public void Dispose() + + protected override void OnError(Exception ex) { - if (_disposed) return; - _disposed = true; - - try { _cts.Cancel(); } catch { /* ignore */ } - - var ws = _webSocket; - _webSocket = null; - try { ws?.Dispose(); } catch { /* ignore */ } - - // Don't dispose _cts immediately — reconnect loop may still reference it. - // It will be GC'd after all pending tasks complete. + _isConnected = false; } } diff --git a/tests/OpenClaw.Shared.Tests/OpenClawGatewayClientTests.cs b/tests/OpenClaw.Shared.Tests/OpenClawGatewayClientTests.cs index 4d269d7..424182d 100644 --- a/tests/OpenClaw.Shared.Tests/OpenClawGatewayClientTests.cs +++ b/tests/OpenClaw.Shared.Tests/OpenClawGatewayClientTests.cs @@ -609,7 +609,7 @@ public void Constructor_NormalizesHttpToWs(string inputUrl, string expectedWsUrl { var client = new OpenClawGatewayClient(inputUrl, "test-token"); - var field = typeof(OpenClawGatewayClient).GetField( + var field = typeof(OpenClawGatewayClient).BaseType?.GetField( "_gatewayUrl", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); var actualUrl = field?.GetValue(client) as string; diff --git a/tests/OpenClaw.Shared.Tests/WebSocketClientBaseTests.cs b/tests/OpenClaw.Shared.Tests/WebSocketClientBaseTests.cs new file mode 100644 index 0000000..c5106db --- /dev/null +++ b/tests/OpenClaw.Shared.Tests/WebSocketClientBaseTests.cs @@ -0,0 +1,244 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Xunit; + +namespace OpenClaw.Shared.Tests; + +/// +/// Concrete test double for WebSocketClientBase. +/// Exposes hooks and tracking for unit testing base class behavior. +/// +public class TestWebSocketClient : WebSocketClientBase +{ + public List ProcessedMessages { get; } = new(); + public int OnConnectedCallCount { get; private set; } + public int OnDisconnectedCallCount { get; private set; } + public int OnErrorCallCount { get; private set; } + public Exception? LastError { get; private set; } + public int OnDisposingCallCount { get; private set; } + + protected override int ReceiveBufferSize => 8192; + protected override string ClientRole => "test"; + + public TestWebSocketClient(string gatewayUrl, string token, IOpenClawLogger? logger = null) + : base(gatewayUrl, token, logger) { } + + protected override Task ProcessMessageAsync(string json) + { + ProcessedMessages.Add(json); + return Task.CompletedTask; + } + + protected override Task OnConnectedAsync() + { + OnConnectedCallCount++; + return Task.CompletedTask; + } + + protected override void OnDisconnected() + { + OnDisconnectedCallCount++; + } + + protected override void OnError(Exception ex) + { + OnErrorCallCount++; + LastError = ex; + } + + protected override void OnDisposing() + { + OnDisposingCallCount++; + } + + // Expose protected members for testing + public void TestRaiseStatusChanged(ConnectionStatus status) + => RaiseStatusChanged(status); + + public bool TestIsDisposed => IsDisposed; + public string TestGatewayUrlForDisplay => GatewayUrlForDisplay; + public string TestToken => _token; + public IOpenClawLogger TestLogger => _logger; +} + +public class WebSocketClientBaseTests +{ + private readonly TestLogger _logger = new(); + + [Theory] + [InlineData("http://localhost:18789", "ws://localhost:18789")] + [InlineData("https://gateway.example.com", "wss://gateway.example.com")] + [InlineData("ws://localhost:18789", "ws://localhost:18789")] + [InlineData("wss://gateway.example.com", "wss://gateway.example.com")] + public void Constructor_NormalizesUrl(string input, string _) + { + var client = new TestWebSocketClient(input, "test-token", _logger); + // GatewayUrlForDisplay is the sanitized version — just verify it's set + Assert.NotNull(client.TestGatewayUrlForDisplay); + Assert.DoesNotContain("@", client.TestGatewayUrlForDisplay); // credentials stripped + client.Dispose(); + } + + [Fact] + public void Constructor_StoresToken() + { + var client = new TestWebSocketClient("ws://localhost:18789", "my-token", _logger); + Assert.Equal("my-token", client.TestToken); + client.Dispose(); + } + + [Fact] + public void Constructor_UsesNullLoggerWhenNotProvided() + { + var client = new TestWebSocketClient("ws://localhost:18789", "token"); + Assert.NotNull(client.TestLogger); + client.Dispose(); + } + + [Fact] + public void Constructor_ThrowsOnNullUrl() + { + Assert.Throws(() => + new TestWebSocketClient(null!, "token", _logger)); + } + + [Fact] + public void Constructor_ThrowsOnEmptyUrl() + { + Assert.Throws(() => + new TestWebSocketClient("", "token", _logger)); + } + + [Fact] + public void Constructor_ThrowsOnNullToken() + { + Assert.Throws(() => + new TestWebSocketClient("ws://localhost", null!, _logger)); + } + + [Fact] + public void Constructor_ThrowsOnEmptyToken() + { + Assert.Throws(() => + new TestWebSocketClient("ws://localhost", "", _logger)); + } + + [Fact] + public void Constructor_WithCredentialUrl_StripsFromDisplay() + { + var client = new TestWebSocketClient("ws://user:pass@localhost:18789", "token", _logger); + Assert.DoesNotContain("pass", client.TestGatewayUrlForDisplay); + client.Dispose(); + } + + [Fact] + public void Dispose_SetsIsDisposed() + { + var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger); + Assert.False(client.TestIsDisposed); + client.Dispose(); + Assert.True(client.TestIsDisposed); + } + + [Fact] + public void Dispose_IsIdempotent() + { + var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger); + client.Dispose(); + client.Dispose(); // second call should not throw + Assert.True(client.TestIsDisposed); + Assert.Equal(1, client.OnDisposingCallCount); // hook called only once + } + + [Fact] + public void Dispose_CallsOnDisposingHook() + { + var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger); + client.Dispose(); + Assert.Equal(1, client.OnDisposingCallCount); + } + + [Fact] + public void RaiseStatusChanged_FiresEvent() + { + var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger); + ConnectionStatus? received = null; + client.StatusChanged += (_, status) => received = status; + + client.TestRaiseStatusChanged(ConnectionStatus.Connecting); + + Assert.Equal(ConnectionStatus.Connecting, received); + client.Dispose(); + } + + [Fact] + public void RaiseStatusChanged_WithNoSubscribers_DoesNotThrow() + { + var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger); + client.TestRaiseStatusChanged(ConnectionStatus.Connected); // no subscribers — should not throw + client.Dispose(); + } + + [Fact] + public void RaiseStatusChanged_MultipleSubscribers_AllNotified() + { + var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger); + var statuses = new List(); + client.StatusChanged += (_, s) => statuses.Add(s); + client.StatusChanged += (_, s) => statuses.Add(s); + + client.TestRaiseStatusChanged(ConnectionStatus.Error); + + Assert.Equal(2, statuses.Count); + Assert.All(statuses, s => Assert.Equal(ConnectionStatus.Error, s)); + client.Dispose(); + } + + [Fact] + public void IsConnected_FalseBeforeConnect() + { + var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger); + // Reflection to check IsConnected on the base + var prop = typeof(WebSocketClientBase).GetProperty("IsConnected", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + var isConnected = (bool)prop!.GetValue(client)!; + Assert.False(isConnected); + client.Dispose(); + } + + [Fact] + public void IsConnected_FalseAfterDispose() + { + var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger); + client.Dispose(); + var prop = typeof(WebSocketClientBase).GetProperty("IsConnected", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + var isConnected = (bool)prop!.GetValue(client)!; + Assert.False(isConnected); + } + + [Fact] + public async Task ConnectAsync_RaisesStatusChangedConnecting() + { + var client = new TestWebSocketClient("ws://localhost:18789", "token", _logger); + var statuses = new List(); + client.StatusChanged += (_, s) => statuses.Add(s); + + // ConnectAsync will fail (no real server) but should still fire Connecting then Error + await client.ConnectAsync(); + + Assert.Contains(ConnectionStatus.Connecting, statuses); + Assert.Contains(ConnectionStatus.Error, statuses); + client.Dispose(); + } +} + +public class TestLogger : IOpenClawLogger +{ + public List Logs { get; } = new(); + public void Info(string message) => Logs.Add($"INFO: {message}"); + public void Debug(string message) => Logs.Add($"DEBUG: {message}"); + public void Warn(string message) => Logs.Add($"WARN: {message}"); + public void Error(string message, Exception? ex = null) => Logs.Add($"ERROR: {message}"); +} diff --git a/tests/OpenClaw.Shared.Tests/WindowsNodeClientTests.cs b/tests/OpenClaw.Shared.Tests/WindowsNodeClientTests.cs index 018ce9c..8e9f269 100644 --- a/tests/OpenClaw.Shared.Tests/WindowsNodeClientTests.cs +++ b/tests/OpenClaw.Shared.Tests/WindowsNodeClientTests.cs @@ -19,7 +19,7 @@ public void Constructor_NormalizesGatewayUrl(string inputUrl, string expectedUrl try { using var client = new WindowsNodeClient(inputUrl, "test-token", dataPath); - var field = typeof(WindowsNodeClient).GetField( + var field = typeof(WindowsNodeClient).BaseType?.GetField( "_gatewayUrl", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); var actualUrl = field?.GetValue(client) as string;