diff --git a/Directory.Packages.props b/Directory.Packages.props index 17a4a3c9b..18fab1c05 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -53,6 +53,7 @@ + diff --git a/src/fdc3/dotnet/AppDirectory/test/AppDirectory.Tests/AppDirectory.GetApps.Tests.cs b/src/fdc3/dotnet/AppDirectory/test/AppDirectory.Tests/AppDirectory.GetApps.Tests.cs index 159047bce..97ba6efe9 100644 --- a/src/fdc3/dotnet/AppDirectory/test/AppDirectory.Tests/AppDirectory.GetApps.Tests.cs +++ b/src/fdc3/dotnet/AppDirectory/test/AppDirectory.Tests/AppDirectory.GetApps.Tests.cs @@ -41,7 +41,7 @@ public async Task GetApps_loads_the_data_from_a_file( apps.Should().BeEquivalentTo(GetAppsExpectation); } - [Theory (Skip ="Fail"), CombinatorialData] + [Theory, CombinatorialData] public async Task GetApps_reloads_the_data_if_the_source_file_has_changed(bool useApiSchema) { var source = "/apps.json"; @@ -62,8 +62,8 @@ await fileSystem.File.WriteAllTextAsync( useApiSchema ? GetAppsApiResponseChanged : GetAppsJsonArrayChanged, Encoding.UTF8); - await TaskExtensions.WaitForBackgroundTasksAsync(TimeSpan.FromMilliseconds(100)); - + await TaskExtensions.WaitForBackgroundTasksAsync(TimeSpan.FromSeconds(20)); + var apps = await appDirectory.GetApps(); apps.Should().BeEquivalentTo(GetAppsExpectationChanged); diff --git a/src/fdc3/dotnet/DesktopAgent/tests/DesktopAgent.Tests/Infrastructure/Internal/Fdc3DesktopAgentMessageRouterService.Tests.cs b/src/fdc3/dotnet/DesktopAgent/tests/DesktopAgent.Tests/Infrastructure/Internal/Fdc3DesktopAgentMessageRouterService.Tests.cs index 83fc1b419..3b721b2ee 100644 --- a/src/fdc3/dotnet/DesktopAgent/tests/DesktopAgent.Tests/Infrastructure/Internal/Fdc3DesktopAgentMessageRouterService.Tests.cs +++ b/src/fdc3/dotnet/DesktopAgent/tests/DesktopAgent.Tests/Infrastructure/Internal/Fdc3DesktopAgentMessageRouterService.Tests.cs @@ -405,19 +405,19 @@ public async Task StoreIntentResult_succeeds_with_channel() var target = await _mockModuleLoader.Object.StartModule(new("appId4")); var targetFdc3InstanceId = Fdc3InstanceIdRetriever.Get(target); var raiseIntentRequest = new RaiseIntentRequest() - { - MessageId = int.MaxValue, - Fdc3InstanceId = Guid.NewGuid().ToString(), - Intent = "intentMetadata4", - Selected = false, - Context = new Context("context2"), - TargetAppIdentifier = new AppIdentifier() { AppId = "appId4", InstanceId = targetFdc3InstanceId } - }; + { + MessageId = int.MaxValue, + Fdc3InstanceId = Guid.NewGuid().ToString(), + Intent = "intentMetadata4", + Selected = false, + Context = new Context("context2"), + TargetAppIdentifier = new AppIdentifier() { AppId = "appId4", InstanceId = targetFdc3InstanceId } + }; var raiseIntentResult = await _fdc3.HandleRaiseIntent(raiseIntentRequest, new MessageContext()); raiseIntentResult.Should().NotBeNull(); - - raiseIntentResult.Should().NotBeNull(); + raiseIntentResult!.Error.Should().BeNull(); + raiseIntentResult.AppMetadata.Should().NotBeNull(); raiseIntentResult!.AppMetadata.Should().HaveCount(1); var storeIntentRequest = new StoreIntentResultRequest() diff --git a/src/messaging/dotnet/Messaging.sln b/src/messaging/dotnet/Messaging.sln index e56f5423f..218114cd7 100644 --- a/src/messaging/dotnet/Messaging.sln +++ b/src/messaging/dotnet/Messaging.sln @@ -31,7 +31,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "lib", "lib", "{B7E63957-3C1 EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MorganStanley.ComposeUI.Testing", "..\..\shared\dotnet\MorganStanley.ComposeUI.Testing\MorganStanley.ComposeUI.Testing.csproj", "{AE71CBC4-FD4E-4C66-B894-D7C31DE4D1BE}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MorganStanley.ComposeUI.Messaging.Host.Tests", "test\Host.Tests\MorganStanley.ComposeUI.Messaging.Host.Tests.csproj", "{CEF78D3F-C645-4471-BAD2-9C538A0CA763}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MorganStanley.ComposeUI.Messaging.Host.Tests", "test\Host.Tests\MorganStanley.ComposeUI.Messaging.Host.Tests.csproj", "{CEF78D3F-C645-4471-BAD2-9C538A0CA763}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution diff --git a/src/messaging/dotnet/Messaging.sln.DotSettings b/src/messaging/dotnet/Messaging.sln.DotSettings index 249114a50..a919bc949 100644 --- a/src/messaging/dotnet/Messaging.sln.DotSettings +++ b/src/messaging/dotnet/Messaging.sln.DotSettings @@ -10,4 +10,5 @@ to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + True True \ No newline at end of file diff --git a/src/messaging/dotnet/src/Client/Client/MessageRouterClient.cs b/src/messaging/dotnet/src/Client/Client/MessageRouterClient.cs index 6d0a683bf..14e9f4ad3 100644 --- a/src/messaging/dotnet/src/Client/Client/MessageRouterClient.cs +++ b/src/messaging/dotnet/src/Client/Client/MessageRouterClient.cs @@ -17,6 +17,7 @@ using Microsoft.Extensions.Logging.Abstractions; using MorganStanley.ComposeUI.Messaging.Client.Abstractions; using MorganStanley.ComposeUI.Messaging.Exceptions; +using MorganStanley.ComposeUI.Messaging.Instrumentation; using MorganStanley.ComposeUI.Messaging.Protocol; using MorganStanley.ComposeUI.Messaging.Protocol.Messages; using Nito.AsyncEx; @@ -191,8 +192,8 @@ private Topic GetTopic(string topicName) private readonly StateChangeEvents _stateChangeEvents = new(); private readonly AsyncLock _mutex = new(); - private readonly Channel _sendChannel = - Channel.CreateUnbounded(new UnboundedChannelOptions {SingleReader = true}); + private readonly Channel> _sendChannel = + Channel.CreateUnbounded>(new UnboundedChannelOptions {SingleReader = true}); private readonly ConcurrentDictionary> _pendingRequests = new(); private readonly ConcurrentDictionary _endpointHandlers = new(); @@ -235,23 +236,27 @@ private async ValueTask ConnectAsyncCore(CancellationToken cancellationToken) try { + OnConnectStart(); await _connection.ConnectAsync(cancellationToken); _stateChangeEvents.SendReceiveCompleted = Task.WhenAll(SendMessagesAsync(), ReceiveMessagesAsync()); await _sendChannel.Writer.WriteAsync( - new ConnectRequest {AccessToken = _options.AccessToken}, + new MessageWrapper( + new ConnectRequest { AccessToken = _options.AccessToken }) + , cancellationToken); await _stateChangeEvents.Connected.Task; } - catch (MessageRouterException) + catch (MessageRouterException e) { throw; } catch (Exception e) { await CloseAsyncCore(e); + OnConnectStop(e); throw ThrowHelper.ConnectionFailed(e); } @@ -270,18 +275,18 @@ private void HandleMessage(Message message) { /* Design notes - - While the public API supports async subscribers and other callbacks, + + While the public API supports async subscribers and other callbacks, we avoid (make impossible) to block message processing by an async handler. In general, we process everything synchronously until the point where the actual user code is called, and we do not await it. Subscribers have their own - async message queues so that they can't block each other. + async message queues so that they can't block each other. This can lead to memory issues if a badly written subscriber can't process its messages fast enough, and its queue starts to grow infinitely. We might add some configuration options to fine-tune this behavior later, eg. set a max queue size (in that case, we must signal the subscriber in some way, possibly with a flag in MessageContext, or a dedicated callback). - + */ switch (message) @@ -314,7 +319,7 @@ async message queues so that they can't block each other. private void HandleConnectResponse(ConnectResponse message) { - _ = Task.Run( + _ = Task.Factory.StartNew( async () => { using (await _mutex.LockAsync()) @@ -322,21 +327,29 @@ private void HandleConnectResponse(ConnectResponse message) if (message.Error != null) { _connectionState = ConnectionState.Closed; - _stateChangeEvents.Connected.TrySetException(new MessageRouterException(message.Error)); + var exception = new MessageRouterException(message.Error); + _stateChangeEvents.Connected.TrySetException(exception); + OnCloseStart(); + OnCloseStop(exception); + OnConnectStop(exception); } else { _clientId = message.ClientId; _connectionState = ConnectionState.Connected; _stateChangeEvents.Connected.TrySetResult(); + OnConnectStop(); } } - }); + }, + TaskCreationOptions.RunContinuationsAsynchronously); } private void HandleInvokeRequest(InvokeRequest message) { - _ = Task.Run( + OnRequestStart(message); + + _ = Task.Factory.StartNew( async () => { try @@ -385,6 +398,8 @@ private void HandleInvokeRequest(InvokeRequest message) await SendMessageAsync( response, + (state, exception) => OnRequestStop((Message) state!, exception), + message, _stateChangeEvents.CloseRequested.Token); } catch (Exception e) @@ -393,8 +408,10 @@ await SendMessageAsync( e, $"Unhandled exception while processing an {nameof(InvokeRequest)}: {{ExceptionMessage}}", e.Message); + OnRequestStop(message, e); } - }); + }, + TaskCreationOptions.RunContinuationsAsynchronously); } private void HandleResponse(AbstractResponse message) @@ -404,7 +421,8 @@ private void HandleResponse(AbstractResponse message) if (message.Error != null) { - tcs.TrySetException(new MessageRouterException(message.Error)); + var exception = new MessageRouterException(message.Error); + tcs.TrySetException(exception); } else { @@ -414,7 +432,7 @@ private void HandleResponse(AbstractResponse message) private void HandleTopicMessage(Protocol.Messages.TopicMessage message) { - if (!_topics.TryGetValue(message.Topic, out var topic)) + if (!_topics.TryGetValue(message.Topic, out var topic)) return; var topicMessage = new TopicMessage( @@ -427,7 +445,26 @@ private void HandleTopicMessage(Protocol.Messages.TopicMessage message) CorrelationId = message.CorrelationId }); - topic.OnNext(topicMessage); + OnRequestStart(message); + + var wrapper = new MessageWrapper( + topicMessage, + OnRequestStop, + message); + + try + { + if (!topic.OnNext(wrapper)) + { + OnRequestStop(message); + } + } + catch (Exception e) + { + OnRequestStop(message, e); + + throw; + } } private string GenerateRequestId() => Guid.NewGuid().ToString("N"); @@ -464,10 +501,30 @@ private async Task SendRequestAsync( return (TResponse) await tcs.Task; } - private async ValueTask SendMessageAsync(Message message, CancellationToken cancellationToken) + private ValueTask SendMessageAsync(Message message, CancellationToken cancellationToken) { - await ConnectAsync(CancellationToken.None); - await _sendChannel.Writer.WriteAsync(message, CancellationToken.None); + return SendMessageAsync(message, onDequeued: null, state: null, cancellationToken); + } + + private async ValueTask SendMessageAsync(Message message, Action? onDequeued, object state, CancellationToken cancellationToken) + { + await ConnectAsync(cancellationToken); + + var wrapper = new MessageWrapper(message, onDequeued, state); + wrapper.OnQueued(); + + try + { + await _sendChannel.Writer.WriteAsync( + wrapper, + cancellationToken); + } + catch (Exception e) + { + wrapper.OnDequeued(e); + + throw; + } } private async Task SendMessagesAsync() @@ -477,9 +534,11 @@ private async Task SendMessagesAsync() while (await _sendChannel.Reader.WaitToReadAsync(_stateChangeEvents.CloseRequested.Token)) { while (!_stateChangeEvents.CloseRequested.IsCancellationRequested - && _sendChannel.Reader.TryRead(out var message)) + && _sendChannel.Reader.TryRead(out var wrapper)) { - await _connection.SendAsync(message, _stateChangeEvents.CloseRequested.Token); + await _connection.SendAsync(wrapper.Message, _stateChangeEvents.CloseRequested.Token); + OnMessageSent(wrapper.Message); + wrapper.OnDequeued(); } } } @@ -504,7 +563,7 @@ private async Task ReceiveMessagesAsync() while (!_stateChangeEvents.CloseRequested.IsCancellationRequested) { var message = await _connection.ReceiveAsync(_stateChangeEvents.CloseRequested.Token); - + OnMessageReceived(message); HandleMessage(message); } } @@ -539,6 +598,7 @@ private async ValueTask RegisterServiceCore( Descriptor = descriptor, }; + await SendRequestAsync(request, cancellationToken); } @@ -551,15 +611,6 @@ private async ValueTask SubscribeAsyncCore( if (!subscribeResult.NeedsSubscription) return subscribeResult.Subscription; - if (cancellationToken.CanBeCanceled) - { - cancellationToken.Register( - () => - { - subscribeResult.Subscription.Complete(); - }); - } - try { await SendMessageAsync( @@ -593,7 +644,9 @@ private async ValueTask UnregisterServiceCore(string serviceName, CancellationTo private void RequestClose(Exception? exception) { - _ = Task.Run(() => CloseAsyncCore(exception).AsTask()); + _ = Task.Factory.StartNew( + () => CloseAsyncCore(exception).AsTask(), + TaskCreationOptions.RunContinuationsAsynchronously); } private ValueTask CloseAsync(Exception? exception) @@ -609,7 +662,7 @@ private ValueTask CloseAsync(Exception? exception) return CloseAsyncCore(exception); } - private async ValueTask CloseAsyncCore(Exception? exception, ConnectionState? previousState = null) + private async ValueTask CloseAsyncCore(Exception? exception) { ConnectionState oldState; @@ -681,23 +734,37 @@ private async ValueTask CloseAsyncCore(Exception? exception, ConnectionState? pr } } - FailPendingRequests(exception); - FailSubscribers(exception); - _stateChangeEvents.CloseRequested.Cancel(); - await _stateChangeEvents.SendReceiveCompleted!; + OnCloseStart(); try { - await _connection.DisposeAsync(); + FailPendingRequests(exception); + FailSubscribers(exception); + await CloseTopics(); + _stateChangeEvents.CloseRequested.Cancel(); + await _stateChangeEvents.SendReceiveCompleted!; + + try + { + await _connection.DisposeAsync(); + } + catch (Exception e) + { + _logger.LogError(e, "Exception thrown when closing the connection: {ExceptionMessage}", e.Message); + } + + using (await _mutex.LockAsync()) + { + _connectionState = ConnectionState.Closed; + } + + OnCloseStop(); } catch (Exception e) { - _logger.LogError(e, "Exception thrown when closing the connection: {ExceptionMessage}", e.Message); - } + OnCloseStop(e); - using (await _mutex.LockAsync()) - { - _connectionState = ConnectionState.Closed; + throw; } // ReSharper disable once VariableHidesOuterVariable @@ -718,7 +785,11 @@ void FailSubscribers(Exception exception) { topic.Value.OnError(exception); } + } + async Task CloseTopics() + { + await Task.WhenAll(_topics.Select(t => t.Value.CloseAsync())); _topics.Clear(); } } @@ -730,13 +801,94 @@ private ValueTask TryUnsubscribe(Topic topic) : default; } + private void OnConnectStart() + { + if (MessageRouterDiagnosticSource.Log.IsEnabled(MessageRouterEventTypes.ConnectStart)) + { + MessageRouterDiagnosticSource.Log.Write( + MessageRouterEventTypes.ConnectStart, + new MessageRouterEvent(this, MessageRouterEventTypes.ConnectStart)); + } + } + + private void OnConnectStop(Exception? exception = null) + { + if (MessageRouterDiagnosticSource.Log.IsEnabled(MessageRouterEventTypes.ConnectStop)) + { + MessageRouterDiagnosticSource.Log.Write( + MessageRouterEventTypes.ConnectStop, + new MessageRouterEvent(this, MessageRouterEventTypes.ConnectStop, Exception: exception)); + } + } + + private void OnMessageReceived(Message message) + { + if (MessageRouterDiagnosticSource.Log.IsEnabled(MessageRouterEventTypes.MessageReceived)) + { + MessageRouterDiagnosticSource.Log.Write( + MessageRouterEventTypes.MessageReceived, + new MessageRouterEvent(this, MessageRouterEventTypes.MessageReceived, message)); + } + } + + private void OnMessageSent(Message message) + { + if (MessageRouterDiagnosticSource.Log.IsEnabled(MessageRouterEventTypes.MessageSent)) + { + MessageRouterDiagnosticSource.Log.Write( + MessageRouterEventTypes.MessageSent, + new MessageRouterEvent(this, MessageRouterEventTypes.MessageSent, message)); + } + } + + private void OnRequestStart(Message message, Exception? exception = null) + { + if (MessageRouterDiagnosticSource.Log.IsEnabled(MessageRouterEventTypes.RequestStart)) + { + MessageRouterDiagnosticSource.Log.Write( + MessageRouterEventTypes.RequestStart, + new MessageRouterEvent(this, MessageRouterEventTypes.RequestStart, message, exception)); + } + } + + private void OnRequestStop(Message message, Exception? exception = null) + { + if (MessageRouterDiagnosticSource.Log.IsEnabled(MessageRouterEventTypes.RequestStop)) + { + MessageRouterDiagnosticSource.Log.Write( + MessageRouterEventTypes.RequestStop, + new MessageRouterEvent(this, MessageRouterEventTypes.RequestStop, message, exception)); + } + } + + private void OnCloseStart() + { + if (MessageRouterDiagnosticSource.Log.IsEnabled(MessageRouterEventTypes.CloseStart)) + { + MessageRouterDiagnosticSource.Log.Write( + MessageRouterEventTypes.CloseStart, + new MessageRouterEvent(this, MessageRouterEventTypes.CloseStart)); + } + } + + private void OnCloseStop(Exception? exception = null) + { + if (MessageRouterDiagnosticSource.Log.IsEnabled(MessageRouterEventTypes.CloseStop)) + { + MessageRouterDiagnosticSource.Log.Write( + MessageRouterEventTypes.CloseStop, + new MessageRouterEvent(this, MessageRouterEventTypes.CloseStop, Exception: exception)); + } + } + [DebuggerStepThrough] private static void CheckNotOnMainThread() { #if DEBUG if (Thread.CurrentThread.GetApartmentState() == ApartmentState.STA) { - throw new InvalidOperationException("The current thread is the main thread. Awaiting the resulting Task can cause a deadlock."); + throw new InvalidOperationException( + "The current thread is the main thread. Awaiting the resulting Task can cause a deadlock."); } #endif } @@ -752,15 +904,44 @@ private enum ConnectionState private class StateChangeEvents { - public readonly TaskCompletionSource ConnectResponseReceived = - new(TaskCreationOptions.RunContinuationsAsynchronously); - public readonly TaskCompletionSource Connected = new(TaskCreationOptions.RunContinuationsAsynchronously); public readonly CancellationTokenSource CloseRequested = new(); public readonly TaskCompletionSource Closed = new(TaskCreationOptions.RunContinuationsAsynchronously); public Task? SendReceiveCompleted; } + private sealed class MessageWrapper + { + public MessageWrapper( + TMessage message, + Action? onDequeued = null, + TState state = default!) + { + Message = message; + _state = state; + _onDequeued = onDequeued; + } + + public TMessage Message { get; } + + internal void OnQueued() + { + Interlocked.Increment(ref _queuedCount); + } + + internal void OnDequeued(Exception? exception = null) + { + if (Interlocked.Decrement(ref _queuedCount) == 0) + { + _onDequeued?.Invoke(_state, exception); + } + } + + private readonly TState _state; + private readonly Action? _onDequeued; + private int _queuedCount; + } + private class Topic : IAsyncObservable { public Topic(string name, MessageRouterClient messageRouter, ILogger logger) @@ -796,22 +977,25 @@ public SubscribeResult Subscribe(IAsyncObserver subscriber) var needsSubscription = _subscriptions.Count == 0; var subscription = new Subscription(this, subscriber, _logger); _subscriptions.Add(subscription); + _subscriberCount.AddCount(); return new SubscribeResult(subscription, needsSubscription); } } - public void OnNext(TopicMessage value) + public bool OnNext(MessageWrapper value) { lock (_mutex) { - if (_isCompleted) - return; + if (_isCompleted || _subscriptions.Count == 0) + return false; foreach (var subscription in _subscriptions) { subscription.OnNext(value); } + + return true; } } @@ -832,30 +1016,17 @@ public void OnError(Exception exception) } } - public void Complete() - { - lock (_mutex) - { - if (_isCompleted) - return; - - _isCompleted = true; - - foreach (var subscription in _subscriptions) - { - subscription.Complete(); - } - } - } - public void Unsubscribe(Subscription subscription) { lock (_mutex) { - _subscriptions.Remove(subscription); + if (_isCompleted || !_subscriptions.Remove(subscription)) return; + if (_subscriptions.Count == 0) { - Task.Run(() => _messageRouter.TryUnsubscribe(this)); + Task.Factory.StartNew( + () => _messageRouter.TryUnsubscribe(this), + TaskCreationOptions.RunContinuationsAsynchronously); } } } @@ -865,12 +1036,23 @@ public ValueTask SubscribeAsync(IAsyncObserver o return _messageRouter.SubscribeAsyncCore(this, observer, CancellationToken.None); } + public Task CloseAsync() + { + return _subscriberCount.WaitAsync(); + } + + public void OnSubscriberCompleted() + { + _subscriberCount.Signal(); + } + private readonly MessageRouterClient _messageRouter; private readonly ILogger _logger; private readonly object _mutex = new(); - private readonly List _subscriptions = new(); + private readonly HashSet _subscriptions = new(); private bool _isCompleted; private Exception? _exception; + private AsyncCountdownEvent _subscriberCount = new(0); } private class Subscription : IAsyncDisposable @@ -880,13 +1062,22 @@ public Subscription(Topic topic, IAsyncObserver subscriber, ILogge _subscriber = subscriber; _topic = topic; _logger = logger; - _ = Task.Run(ProcessMessages); + _ = Task.Factory.StartNew(ProcessMessages, TaskCreationOptions.RunContinuationsAsynchronously); } - public void OnNext(TopicMessage value) + public void OnNext(MessageWrapper value) { - _queue.Writer.TryWrite( - value); // Since the queue is unbounded, this will succeed unless the channel was completed + // Note the order. If we only invoked OnQueued AFTER the TryWrite call, + // a race condition would allow OnDequeued to be called in between, + // resulting in a false ordering of events (the MessageRouterClient would signal OnMessageProcessed + // before the subscriber actually receiving it). + + value.OnQueued(); + + if (!_queue.Writer.TryWrite(value)) + { + value.OnDequeued(); + } } public void OnError(Exception exception) @@ -894,13 +1085,9 @@ public void OnError(Exception exception) _queue.Writer.TryComplete(exception); } - public void Complete() - { - _queue.Writer.TryComplete(); - } - public ValueTask DisposeAsync() { + _queue.Writer.TryComplete(); _topic.Unsubscribe(this); return default; @@ -916,7 +1103,7 @@ private async Task ProcessMessages() { try { - await _subscriber.OnNextAsync(value); + await _subscriber.OnNextAsync(value.Message); } catch (Exception e) { @@ -926,6 +1113,10 @@ private async Task ProcessMessages() _topic.Name, e.Message); } + finally + { + value.OnDequeued(); + } } } @@ -957,17 +1148,22 @@ private async Task ProcessMessages() e2.Message); } } + finally + { + _topic.OnSubscriberCompleted(); + } } private readonly IAsyncObserver _subscriber; - private readonly Channel _queue = Channel.CreateUnbounded( - new UnboundedChannelOptions - { - AllowSynchronousContinuations = false, - SingleReader = true, - SingleWriter = false - }); + private readonly Channel> _queue = + Channel.CreateUnbounded>( + new UnboundedChannelOptions + { + AllowSynchronousContinuations = false, + SingleReader = true, + SingleWriter = false + }); private readonly Topic _topic; private readonly ILogger _logger; diff --git a/src/messaging/dotnet/src/Client/Client/WebSocket/WebSocketConnection.cs b/src/messaging/dotnet/src/Client/Client/WebSocket/WebSocketConnection.cs index 9f6706f68..c1905ceba 100644 --- a/src/messaging/dotnet/src/Client/Client/WebSocket/WebSocketConnection.cs +++ b/src/messaging/dotnet/src/Client/Client/WebSocket/WebSocketConnection.cs @@ -21,6 +21,7 @@ using Microsoft.Extensions.Options; using MorganStanley.ComposeUI.Messaging.Client.Abstractions; using MorganStanley.ComposeUI.Messaging.Exceptions; +using MorganStanley.ComposeUI.Messaging.Instrumentation; using MorganStanley.ComposeUI.Messaging.Protocol.Json; using MorganStanley.ComposeUI.Messaging.Protocol.Messages; @@ -55,7 +56,7 @@ public async ValueTask ConnectAsync(CancellationToken cancellationToken = defaul { _webSocket = new ClientWebSocket(); await _webSocket.ConnectAsync(_options.Value.Uri, cancellationToken); - _ = Task.Run(ReceiveMessages); + _ = Task.Factory.StartNew(ReceiveMessages, TaskCreationOptions.RunContinuationsAsynchronously); } catch { @@ -77,6 +78,8 @@ await _webSocket.SendAsync( WebSocketMessageType.Text, WebSocketMessageFlags.EndOfMessage, _stopTokenSource.Token); + + OnMessageSent(message); } catch (OperationCanceledException) { @@ -155,6 +158,8 @@ private async Task ReceiveMessages() while (!readBuffer.IsEmpty && TryReadMessage(ref readBuffer, out var message)) { + OnMessageReceived(message); + if (!_receiveChannel.Writer.TryWrite(message)) { break; @@ -177,6 +182,26 @@ private async Task ReceiveMessages() } } + private void OnMessageReceived(Message message) + { + if (MessageRouterDiagnosticSource.Log.IsEnabled(MessageRouterEventTypes.MessageReceived)) + { + MessageRouterDiagnosticSource.Log.Write( + MessageRouterEventTypes.MessageReceived, + new MessageRouterEvent(this, MessageRouterEventTypes.MessageReceived, message)); + } + } + + private void OnMessageSent(Message message) + { + if (MessageRouterDiagnosticSource.Log.IsEnabled(MessageRouterEventTypes.MessageSent)) + { + MessageRouterDiagnosticSource.Log.Write( + MessageRouterEventTypes.MessageSent, + new MessageRouterEvent(this, MessageRouterEventTypes.MessageSent, message)); + } + } + private static bool TryReadMessage(ref ReadOnlySequence buffer, [NotNullWhen(true)] out Message? message) { var innerBuffer = buffer; diff --git a/src/messaging/dotnet/src/Client/MorganStanley.ComposeUI.Messaging.Client.csproj b/src/messaging/dotnet/src/Client/MorganStanley.ComposeUI.Messaging.Client.csproj index 7e972019c..e43eb566e 100644 --- a/src/messaging/dotnet/src/Client/MorganStanley.ComposeUI.Messaging.Client.csproj +++ b/src/messaging/dotnet/src/Client/MorganStanley.ComposeUI.Messaging.Client.csproj @@ -1,4 +1,4 @@ - + net6.0 diff --git a/src/messaging/dotnet/src/Core/Instrumentation/MessageRouterDiagnosticObserver.cs b/src/messaging/dotnet/src/Core/Instrumentation/MessageRouterDiagnosticObserver.cs new file mode 100644 index 000000000..504979b2e --- /dev/null +++ b/src/messaging/dotnet/src/Core/Instrumentation/MessageRouterDiagnosticObserver.cs @@ -0,0 +1,296 @@ +// Morgan Stanley makes this available to you under the Apache License, +// Version 2.0 (the "License"). You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0. +// +// See the NOTICE file distributed with this work for additional information +// regarding copyright ownership. Unless required by applicable law or agreed +// to in writing, software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions +// and limitations under the License. + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using MorganStanley.ComposeUI.Messaging.Protocol.Messages; +using Nito.AsyncEx; + +namespace MorganStanley.ComposeUI.Messaging.Instrumentation; + +/// +/// Provides the mechanisms to wait for asynchronous background operations raised by messaging components deterministically. +/// +public class MessageRouterDiagnosticObserver : IDisposable, IObserver> +{ + /// An optional object which will be used as a filter on the sender of the observed events. + public MessageRouterDiagnosticObserver(object? sender = null) + { + _sender = sender; + _subscription = MessageRouterDiagnosticSource.Log.Subscribe(this); + } + + /// + /// Asynchronously waits until all outstanding background operations complete. + /// + /// + /// + /// + /// Background operations are considered done when the sender has no more work to do, eg. + /// + /// + /// + /// for an incoming , all subscribers were notified via + /// + /// + /// + /// + /// + /// for , the connection has been disposed and all subscribers + /// were + /// notified via + /// + /// + /// + /// + /// for , the response has arrived and the task was completed. + /// + /// + /// + /// This method will also wait for a event for any message + /// registered using . + /// + public Task WaitForCompletionAsync(CancellationToken cancellationToken = default) + { + // ReSharper disable once InconsistentlySynchronizedField + return _outstandingEvents.WaitAsync(cancellationToken); + } + + /// + /// Asynchronously waits until all outstanding background operations complete. + /// + /// A timeout after which this method will throw a + /// + /// + public async Task WaitForCompletionAsync(TimeSpan timeout) + { + if (Debugger.IsAttached) + { + timeout = Timeout.InfiniteTimeSpan; + } + + using var cts = new CancellationTokenSource(timeout); + try + { + await WaitForCompletionAsync(cts.Token); + } + catch (OperationCanceledException e) when (e.CancellationToken == cts.Token) + { + lock (_lock) + { + throw new TimeoutException( + $"The operation has timed out.\nOutstanding events:\n" + + string.Join(separator: '\n', _expectedEvents.Select(exp => exp.Description)), + e); + } + } + } + + /// + /// Registers a message that is expected to be processed by the sender. + /// + /// + /// + /// + /// The unchanged message. + /// + /// + /// For each message registered with this method, or logged by the sender using the + /// event type, will wait + /// for + /// a matching event. + /// + public TMessage RegisterRequest(TMessage message, object? sender = null) where TMessage : Message + { + sender = ValidateSender(sender); + + lock (_lock) + { + var reg = new RegisteredRequest(sender, message); + + if (!_registeredRequests.Add(reg)) return message; + + AddExpectedEvent( + evt => + { + var result = evt.Type == MessageRouterEventTypes.RequestStop + && evt.Sender == sender + && evt.Message == message; + + if (result) + { + _registeredRequests.Remove(reg); + } + + return result; + }, + $"{MessageRouterEventTypes.RequestStop}: {typeof(TMessage).Name}"); + } + + return message; + } + + /// + /// Registers a message that is expected to be sent by the source. + /// + /// A predicate used for recognising the messages + /// The expected sender, if it differs from the one provided in the constructor of the current object + /// + /// will wait for expected messages to be sent by the source. + /// + public void ExpectMessage(Predicate predicate, object? sender = null) + { + sender = ValidateSender(sender); + + AddExpectedEvent( + evt => evt is {Type: MessageRouterEventTypes.MessageSent, Message: not null} + && evt.Sender == sender + && predicate(evt.Message), + $"{MessageRouterEventTypes.MessageSent}: "); + } + + /// + /// Registers a message of type that is expected to be sent by the source. + /// + /// A predicate used for recognising the messages + /// + /// + /// will wait for expected messages to be sent by the source. + /// + public void ExpectMessage(Predicate? predicate = null, object? sender = null) where TMessage : Message + { + sender = ValidateSender(sender); + + Predicate innerPredicate = + predicate is null + ? evt => evt is {Type: MessageRouterEventTypes.MessageSent, Message: TMessage} + && evt.Sender == sender + : evt => evt is {Type: MessageRouterEventTypes.MessageSent, Message: TMessage} + && evt.Sender == sender + && predicate((TMessage) evt.Message); + + AddExpectedEvent(innerPredicate, $"{MessageRouterEventTypes.MessageSent}: {typeof(TMessage).Name}"); + } + + /// + /// Adds an expected event to wait for when is called. + /// + /// + /// + public void ExpectEvent(string eventType, object? sender = null) + { + sender = ValidateSender(sender); + + AddExpectedEvent( + evt => evt.Sender == sender && evt.Type == eventType, + eventType); + } + + /// + public void Dispose() + { + _subscription.Dispose(); + } + + void IObserver>.OnCompleted() + { + lock (_lock) + { + _outstandingEvents.Signal(_outstandingEvents.CurrentCount); + } + } + + void IObserver>.OnError(Exception error) + { + lock (_lock) + { + _outstandingEvents.Signal(_outstandingEvents.CurrentCount); + } + } + + void IObserver>.OnNext(KeyValuePair value) + { + if (value.Value is not MessageRouterEvent evt) return; + + lock (_lock) + { + if (TryRemoveExpectedEvent(evt)) return; + + switch (evt.Type) + { + case MessageRouterEventTypes.RequestStart: + ArgumentNullException.ThrowIfNull(evt.Message); + RegisterRequest(evt.Message, evt.Sender); + break; + + case MessageRouterEventTypes.CloseStart: + ExpectEvent(MessageRouterEventTypes.CloseStop, evt.Sender); + break; + + case MessageRouterEventTypes.ConnectStart: + ExpectEvent(MessageRouterEventTypes.ConnectStop, evt.Sender); + break; + } + } + } + + private readonly object? _sender; + + private object ValidateSender(object? sender, [CallerMemberName] string? callerName = null) => + sender ?? _sender + ?? throw new ArgumentNullException( + nameof(sender), + $"sender must be specified either for the constructor of {nameof(MessageRouterDiagnosticObserver)} or {callerName}"); + + private void AddExpectedEvent(Predicate predicate, string description) + { + lock (_lock) + { + _expectedEvents.Add(new Expectation(predicate, description)); + _outstandingEvents.AddCount(); + } + } + + private bool TryRemoveExpectedEvent(MessageRouterEvent evt) + { + lock (_lock) + { + var index = _expectedEvents.FindIndex(expectation => expectation.Predicate(evt)); + if (index < 0) return false; + + _expectedEvents.RemoveAt(index); + _outstandingEvents.Signal(); + + return true; + } + } + + private readonly object _lock = new(); + private readonly IDisposable _subscription; + private readonly AsyncCountdownEvent _outstandingEvents = new(0); + private readonly HashSet _registeredRequests = new(); + private readonly List _expectedEvents = new(); + + private sealed record RegisteredRequest(object Sender, Message Message); + + private sealed class Expectation + { + public Expectation(Predicate predicate, string description) + { + Predicate = predicate; + Description = description; + } + + public Predicate Predicate { get; } + public string Description { get; } + } +} \ No newline at end of file diff --git a/src/messaging/dotnet/src/Core/Instrumentation/MessageRouterDiagnosticSource.cs b/src/messaging/dotnet/src/Core/Instrumentation/MessageRouterDiagnosticSource.cs new file mode 100644 index 000000000..b8492440f --- /dev/null +++ b/src/messaging/dotnet/src/Core/Instrumentation/MessageRouterDiagnosticSource.cs @@ -0,0 +1,21 @@ +// Morgan Stanley makes this available to you under the Apache License, +// Version 2.0 (the "License"). You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0. +// +// See the NOTICE file distributed with this work for additional information +// regarding copyright ownership. Unless required by applicable law or agreed +// to in writing, software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions +// and limitations under the License. + +using System.Diagnostics; + +namespace MorganStanley.ComposeUI.Messaging.Instrumentation; + +public static class MessageRouterDiagnosticSource +{ + public const string Name = "MorganStanley.ComposeUI.Messaging"; + public static DiagnosticListener Log { get; } = new (Name); +} diff --git a/src/messaging/dotnet/src/Core/Instrumentation/MessageRouterEvent.cs b/src/messaging/dotnet/src/Core/Instrumentation/MessageRouterEvent.cs new file mode 100644 index 000000000..ebffc0468 --- /dev/null +++ b/src/messaging/dotnet/src/Core/Instrumentation/MessageRouterEvent.cs @@ -0,0 +1,21 @@ +// Morgan Stanley makes this available to you under the Apache License, +// Version 2.0 (the "License"). You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0. +// +// See the NOTICE file distributed with this work for additional information +// regarding copyright ownership. Unless required by applicable law or agreed +// to in writing, software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions +// and limitations under the License. + +using MorganStanley.ComposeUI.Messaging.Protocol.Messages; + +namespace MorganStanley.ComposeUI.Messaging.Instrumentation; + +public record MessageRouterEvent( + object Sender, + string Type, + Message? Message = null, + Exception? Exception = null); \ No newline at end of file diff --git a/src/messaging/dotnet/src/Core/Instrumentation/MessageRouterEventTypes.cs b/src/messaging/dotnet/src/Core/Instrumentation/MessageRouterEventTypes.cs new file mode 100644 index 000000000..c8a2d36d2 --- /dev/null +++ b/src/messaging/dotnet/src/Core/Instrumentation/MessageRouterEventTypes.cs @@ -0,0 +1,25 @@ +// Morgan Stanley makes this available to you under the Apache License, +// Version 2.0 (the "License"). You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0. +// +// See the NOTICE file distributed with this work for additional information +// regarding copyright ownership. Unless required by applicable law or agreed +// to in writing, software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions +// and limitations under the License. + +namespace MorganStanley.ComposeUI.Messaging.Instrumentation; + +public static class MessageRouterEventTypes +{ + public const string ConnectStart = nameof(ConnectStart); + public const string ConnectStop = nameof(ConnectStop); + public const string MessageReceived = nameof(MessageReceived); + public const string MessageSent = nameof(MessageSent); + public const string RequestStart = nameof(RequestStart); + public const string RequestStop = nameof(RequestStop); + public const string CloseStart = nameof(CloseStart); + public const string CloseStop = nameof(CloseStop); +} \ No newline at end of file diff --git a/src/messaging/dotnet/src/Core/MorganStanley.ComposeUI.Messaging.Core.csproj b/src/messaging/dotnet/src/Core/MorganStanley.ComposeUI.Messaging.Core.csproj index a42396d9c..7386a76e6 100644 --- a/src/messaging/dotnet/src/Core/MorganStanley.ComposeUI.Messaging.Core.csproj +++ b/src/messaging/dotnet/src/Core/MorganStanley.ComposeUI.Messaging.Core.csproj @@ -1,4 +1,4 @@ - + net6.0 @@ -15,6 +15,7 @@ + diff --git a/src/messaging/dotnet/src/Server/Server/MessageRouterServer.cs b/src/messaging/dotnet/src/Server/Server/MessageRouterServer.cs index 3b97e1fab..2f9bbb387 100644 --- a/src/messaging/dotnet/src/Server/Server/MessageRouterServer.cs +++ b/src/messaging/dotnet/src/Server/Server/MessageRouterServer.cs @@ -15,6 +15,7 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using MorganStanley.ComposeUI.Messaging.Exceptions; +using MorganStanley.ComposeUI.Messaging.Instrumentation; using MorganStanley.ComposeUI.Messaging.Protocol; using MorganStanley.ComposeUI.Messaging.Protocol.Messages; using MorganStanley.ComposeUI.Messaging.Server.Abstractions; @@ -36,6 +37,7 @@ public async ValueTask DisposeAsync() _stopTokenSource.Cancel(); await Task.WhenAll(_clients.Values.Select(client => client.Connection.CloseAsync().AsTask())); + await Task.WhenAll(_tasks); } public ValueTask ClientConnected(IClientConnection connection) @@ -45,7 +47,10 @@ public ValueTask ClientConnected(IClientConnection connection) _logger.LogInformation("Client '{ClientId}' connected", client.ClientId); - _ = Task.Run(() => ProcessMessages(client, _stopTokenSource.Token)); + _tasks.Add( + Task.Factory.StartNew( + () => ProcessMessages(client, _stopTokenSource.Token), + TaskCreationOptions.RunContinuationsAsynchronously)); return default; } @@ -71,6 +76,7 @@ public ValueTask ClientDisconnected(IClientConnection connection) private readonly ConcurrentDictionary _serviceInvocations = new(); private readonly ConcurrentDictionary _serviceRegistrations = new(); private readonly CancellationTokenSource _stopTokenSource = new(); + private readonly HashSet _tasks = new(); // Tasks that must be awaited on DisposeAsync private readonly ConcurrentDictionary _topics = new(); private readonly ConcurrentDictionary _connectionToClient = new(); private readonly IAccessTokenValidator? _accessTokenValidator; @@ -80,6 +86,8 @@ private async Task HandleConnectRequest( ConnectRequest message, CancellationToken cancellationToken) { + OnRequestStart(message); + try { if (_accessTokenValidator != null) @@ -95,15 +103,23 @@ await client.Connection.SendAsync( CancellationToken.None); _clients.TryAdd(client.ClientId, client); + OnRequestStop(message); } catch (Exception e) { - await client.Connection.SendAsync( - new ConnectResponse - { - Error = new Error(e), - }, - CancellationToken.None); + try + { + await client.Connection.SendAsync( + new ConnectResponse + { + Error = new Error(e), + }, + CancellationToken.None); + } + finally + { + OnRequestStop(message, e); + } } } @@ -112,6 +128,8 @@ private async Task HandleInvokeRequest( InvokeRequest message, CancellationToken cancellationToken) { + OnRequestStart(message); + try { Client? serviceClient = null; @@ -131,19 +149,19 @@ private async Task HandleInvokeRequest( throw ThrowHelper.UnknownEndpoint(message.Endpoint); } - var request = new ServiceInvocation( - message.RequestId, + var invocation = new ServiceInvocation( + message, Guid.NewGuid().ToString(), client.ClientId, serviceClient.ClientId); - if (!_serviceInvocations.TryAdd(request.ServiceRequestId, request)) + if (!_serviceInvocations.TryAdd(invocation.ServiceRequestId, invocation)) throw ThrowHelper.DuplicateRequestId(); await serviceClient.Connection.SendAsync( new InvokeRequest { - RequestId = request.ServiceRequestId, + RequestId = invocation.ServiceRequestId, Endpoint = message.Endpoint, Payload = message.Payload, CorrelationId = message.CorrelationId, @@ -158,13 +176,20 @@ await serviceClient.Connection.SendAsync( e.GetType().FullName, message.Endpoint); - await client.Connection.SendAsync( - new InvokeResponse - { - RequestId = message.RequestId, - Error = new Error(e), - }, - CancellationToken.None); + try + { + await client.Connection.SendAsync( + new InvokeResponse + { + RequestId = message.RequestId, + Error = new Error(e), + }, + CancellationToken.None); + } + finally + { + OnRequestStop(message, e); + } } } @@ -173,20 +198,30 @@ private async Task HandleInvokeResponse( InvokeResponse message, CancellationToken cancellationToken) { - if (!_serviceInvocations.TryRemove(message.RequestId, out var request)) + if (!_serviceInvocations.TryRemove(message.RequestId, out var invocation)) return; // TODO: Log warning - if (!_clients.TryGetValue(request.CallerClientId, out var caller)) - return; // TODO: Log warning + try + { + if (!_clients.TryGetValue(invocation.CallerClientId, out var caller)) + return; // TODO: Log warning + + var response = new InvokeResponse + { + RequestId = invocation.Request.RequestId, + Payload = message.Payload, + Error = message.Error, + }; - var response = new InvokeResponse + await caller.Connection.SendAsync(response, CancellationToken.None); + OnRequestStop(invocation.Request); + } + catch (Exception e) { - RequestId = request.CallerRequestId, - Payload = message.Payload, - Error = message.Error, - }; + OnRequestStop(invocation.Request, e); - await caller.Connection.SendAsync(response, CancellationToken.None); + throw; + } } private async Task HandlePublishMessage( @@ -194,27 +229,40 @@ private async Task HandlePublishMessage( PublishMessage message, CancellationToken cancellationToken) { - if (string.IsNullOrWhiteSpace(message.Topic)) - return; + OnRequestStart(message); + + try + { + if (string.IsNullOrWhiteSpace(message.Topic)) + return; - var topic = _topics.GetOrAdd(message.Topic, topicName => new Topic(topicName, ImmutableHashSet.Empty)); + var topic = _topics.GetOrAdd(message.Topic, topicName => new Topic(topicName, ImmutableHashSet.Empty)); - var outgoingMessage = new Protocol.Messages.TopicMessage + var outgoingMessage = new Protocol.Messages.TopicMessage + { + Topic = message.Topic, + Payload = message.Payload, + Scope = message.Scope, + SourceId = client.ClientId, + CorrelationId = message.CorrelationId, + }; + + await Task.WhenAll( + topic.Subscribers.Select( + async subscriberId => + { + if (_clients.TryGetValue(subscriberId, out var subscriber)) + await subscriber.Connection.SendAsync(outgoingMessage, cancellationToken); + })); + + OnRequestStop(message); + } + catch (Exception e) { - Topic = message.Topic, - Payload = message.Payload, - Scope = message.Scope, - SourceId = client.ClientId, - CorrelationId = message.CorrelationId, - }; - - await Task.WhenAll( - topic.Subscribers.Select( - async subscriberId => - { - if (_clients.TryGetValue(subscriberId, out var subscriber)) - await subscriber.Connection.SendAsync(outgoingMessage, cancellationToken); - })); + OnRequestStop(message, e); + + throw; + } } private async Task HandleRegisterServiceRequest( @@ -222,6 +270,8 @@ private async Task HandleRegisterServiceRequest( RegisterServiceRequest message, CancellationToken cancellationToken) { + OnRequestStart(message); + try { Endpoint.Validate(message.Endpoint); @@ -235,16 +285,25 @@ await client.Connection.SendAsync( RequestId = message.RequestId, }, CancellationToken.None); + + OnRequestStop(message); } catch (Exception e) { - await client.Connection.SendAsync( - new RegisterServiceResponse - { - RequestId = message.RequestId, - Error = new Error(e), - }, - CancellationToken.None); + try + { + await client.Connection.SendAsync( + new RegisterServiceResponse + { + RequestId = message.RequestId, + Error = new Error(e), + }, + CancellationToken.None); + } + finally + { + OnRequestStop(message, e); + } } } @@ -253,38 +312,56 @@ private Task HandleSubscribeMessage( SubscribeMessage message, CancellationToken cancellationToken) { - if (!Protocol.Topic.IsValidTopicName(message.Topic)) - return Task.CompletedTask; + OnRequestStart(message); - var topic = _topics.AddOrUpdate( - message.Topic, - // ReSharper disable once VariableHidesOuterVariable - static (topicName, client) => new Topic(topicName, ImmutableHashSet.Empty.Add(client.ClientId)), - // ReSharper disable once VariableHidesOuterVariable - static (topicName, topic, client) => - { - topic.Subscribers = topic.Subscribers.Add(client.ClientId); + try + { + if (!Protocol.Topic.IsValidTopicName(message.Topic)) + return Task.CompletedTask; + + var topic = _topics.AddOrUpdate( + message.Topic, + // ReSharper disable once VariableHidesOuterVariable + static (topicName, client) => new Topic(topicName, ImmutableHashSet.Empty.Add(client.ClientId)), + // ReSharper disable once VariableHidesOuterVariable + static (topicName, topic, client) => + { + topic.Subscribers = topic.Subscribers.Add(client.ClientId); - return topic; - }, - client); + return topic; + }, + client); - return Task.CompletedTask; + return Task.CompletedTask; + } + finally + { + OnRequestStop(message); + } } - private async Task HandleUnregisterServiceMessage( + private async Task HandleUnregisterServiceRequest( Client client, UnregisterServiceRequest request, CancellationToken cancellationToken) { + OnRequestStart(request); + _serviceRegistrations.TryRemove(new KeyValuePair(request.Endpoint, client.ClientId)); - await client.Connection.SendAsync( - new UnregisterServiceResponse - { - RequestId = request.RequestId, - }, - CancellationToken.None); + try + { + await client.Connection.SendAsync( + new UnregisterServiceResponse + { + RequestId = request.RequestId, + }, + CancellationToken.None); + } + finally + { + OnRequestStop(request); + } } private Task HandleUnsubscribeMessage( @@ -292,23 +369,32 @@ private Task HandleUnsubscribeMessage( UnsubscribeMessage message, CancellationToken cancellationToken) { - if (string.IsNullOrWhiteSpace(message.Topic)) - return Task.CompletedTask; + OnRequestStart(message); - var topic = _topics.AddOrUpdate( - message.Topic, - // ReSharper disable once VariableHidesOuterVariable - static (topicName, client) => new Topic(topicName, ImmutableHashSet.Empty), - // ReSharper disable once VariableHidesOuterVariable - static (topicName, topic, client) => - { - topic.Subscribers = topic.Subscribers.Remove(client.ClientId); + try + { + if (string.IsNullOrWhiteSpace(message.Topic)) + return Task.CompletedTask; + + var topic = _topics.AddOrUpdate( + message.Topic, + // ReSharper disable once VariableHidesOuterVariable + static (topicName, client) => new Topic(topicName, ImmutableHashSet.Empty), + // ReSharper disable once VariableHidesOuterVariable + static (topicName, topic, client) => + { + topic.Subscribers = topic.Subscribers.Remove(client.ClientId); - return topic; - }, - client); + return topic; + }, + client); - return Task.CompletedTask; + return Task.CompletedTask; + } + finally + { + OnRequestStop(message); + } } private async Task ProcessMessages(Client client, CancellationToken cancellationToken) @@ -318,6 +404,7 @@ private async Task ProcessMessages(Client client, CancellationToken cancellation while (!client.StopTokenSource.IsCancellationRequested) { var message = await client.Connection.ReceiveAsync(cancellationToken); + OnMessageReceived(message); if (_logger.IsEnabled(LogLevel.Debug)) { @@ -370,7 +457,7 @@ await HandleRegisterServiceRequest( break; case MessageType.UnregisterService: - await HandleUnregisterServiceMessage( + await HandleUnregisterServiceRequest( client, (UnregisterServiceRequest) message, cancellationToken); @@ -410,6 +497,46 @@ await HandleUnregisterServiceMessage( client.StopTaskSource.TrySetResult(); } + private void OnMessageReceived(Message message) + { + if (MessageRouterDiagnosticSource.Log.IsEnabled(MessageRouterEventTypes.MessageReceived)) + { + MessageRouterDiagnosticSource.Log.Write( + MessageRouterEventTypes.MessageReceived, + new MessageRouterEvent(this, MessageRouterEventTypes.MessageReceived, message)); + } + } + + private void OnMessageSent(Message message) + { + if (MessageRouterDiagnosticSource.Log.IsEnabled(MessageRouterEventTypes.MessageSent)) + { + MessageRouterDiagnosticSource.Log.Write( + MessageRouterEventTypes.MessageSent, + new MessageRouterEvent(this, MessageRouterEventTypes.MessageSent, message)); + } + } + + private void OnRequestStart(Message message) + { + if (MessageRouterDiagnosticSource.Log.IsEnabled(MessageRouterEventTypes.RequestStart)) + { + MessageRouterDiagnosticSource.Log.Write( + MessageRouterEventTypes.RequestStart, + new MessageRouterEvent(this, MessageRouterEventTypes.RequestStart, message)); + } + } + + private void OnRequestStop(Message message, Exception? exception = null) + { + if (MessageRouterDiagnosticSource.Log.IsEnabled(MessageRouterEventTypes.RequestStop)) + { + MessageRouterDiagnosticSource.Log.Write( + MessageRouterEventTypes.RequestStop, + new MessageRouterEvent(this, MessageRouterEventTypes.RequestStop, message)); + } + } + private class Client { public Client(IClientConnection connection) @@ -433,24 +560,24 @@ public Topic(string name, ImmutableHashSet subscribers) public string Name { get; } - public ImmutableHashSet Subscribers { get; set; } = ImmutableHashSet.Empty; + public ImmutableHashSet Subscribers { get; set; } } private class ServiceInvocation { public ServiceInvocation( - string callerRequestId, + InvokeRequest request, string serviceRequestId, string callerClientId, string serviceClientId) { - CallerRequestId = callerRequestId; + Request = request; CallerClientId = callerClientId; ServiceClientId = serviceClientId; ServiceRequestId = serviceRequestId; } - public string CallerRequestId { get; } + public InvokeRequest Request { get; } public string ServiceRequestId { get; } public string CallerClientId { get; } public string ServiceClientId { get; } diff --git a/src/messaging/dotnet/src/Server/Server/WebSocket/WebSocketListenerService.cs b/src/messaging/dotnet/src/Server/Server/WebSocket/WebSocketListenerService.cs index d93e16c29..58c5409f2 100644 --- a/src/messaging/dotnet/src/Server/Server/WebSocket/WebSocketListenerService.cs +++ b/src/messaging/dotnet/src/Server/Server/WebSocket/WebSocketListenerService.cs @@ -40,7 +40,7 @@ public WebSocketListenerService( public Task StartAsync(CancellationToken cancellationToken) { - Task.Run(StartAsyncCore); + Task.Factory.StartNew(StartAsyncCore, TaskCreationOptions.RunContinuationsAsynchronously); return _startTaskSource.Task; } diff --git a/src/messaging/dotnet/test/Client.Tests/Client/MessageRouterClient.Tests.cs b/src/messaging/dotnet/test/Client.Tests/Client/MessageRouterClient.Tests.cs index 7a8c248a6..d60c8c5b4 100644 --- a/src/messaging/dotnet/test/Client.Tests/Client/MessageRouterClient.Tests.cs +++ b/src/messaging/dotnet/test/Client.Tests/Client/MessageRouterClient.Tests.cs @@ -12,10 +12,11 @@ using System.Linq.Expressions; using MorganStanley.ComposeUI.Messaging.Client.Abstractions; +using MorganStanley.ComposeUI.Messaging.Instrumentation; using MorganStanley.ComposeUI.Messaging.Protocol; using MorganStanley.ComposeUI.Messaging.Protocol.Messages; using MorganStanley.ComposeUI.Messaging.TestUtils; -using TaskExtensions = MorganStanley.ComposeUI.Testing.TaskExtensions; +using Nito.AsyncEx; namespace MorganStanley.ComposeUI.Messaging.Client; @@ -24,9 +25,9 @@ public class MessageRouterClientTests : IAsyncLifetime [Fact] public async Task DisposeAsync_does_not_invoke_the_connection_when_called_before_connecting() { - var messageRouter = CreateMessageRouter(); - await messageRouter.DisposeAsync(); - + await _messageRouter.DisposeAsync(); + + await WaitForCompletionAsync(); _connectionMock.Verify(_ => _.DisposeAsync()); _connectionMock.VerifyNoOtherCalls(); } @@ -34,40 +35,40 @@ public async Task DisposeAsync_does_not_invoke_the_connection_when_called_before [Fact] public async Task DisposeAsync_disposes_the_connection() { - var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); - await messageRouter.DisposeAsync(); - + await _messageRouter.ConnectAsync(); + await _messageRouter.DisposeAsync(); + + await WaitForCompletionAsync(); _connectionMock.Verify(_ => _.DisposeAsync()); } [Fact] public async Task DisposeAsync_does_not_throw_if_the_client_was_already_closed() { - var messageRouter = CreateMessageRouter(); - await messageRouter.DisposeAsync(); - await messageRouter.DisposeAsync(); + await _messageRouter.DisposeAsync(); + await _messageRouter.DisposeAsync(); } [Fact] public async Task DisposeAsync_disposes_the_connection_exactly_once() { - var messageRouter = CreateMessageRouter(); - await messageRouter.DisposeAsync(); - await messageRouter.DisposeAsync(); + await _messageRouter.DisposeAsync(); + await _messageRouter.DisposeAsync(); + await WaitForCompletionAsync(); _connectionMock.Verify(_ => _.DisposeAsync(), Times.Once); } - [Fact (Skip ="Ci fail")] + [Fact] public async Task DisposeAsync_calls_OnError_on_active_subscribers() { - var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); var subscriber = new Mock>(); - await messageRouter.SubscribeAsync("test-topic", subscriber.Object); + await _messageRouter.SubscribeAsync("test-topic", subscriber.Object); - await messageRouter.DisposeAsync(); + await _messageRouter.DisposeAsync(); + + await WaitForCompletionAsync(); subscriber.Verify( _ => _.OnErrorAsync(It.Is(e => e.Name == MessageRouterErrors.ConnectionClosed))); @@ -77,22 +78,22 @@ public async Task DisposeAsync_calls_OnError_on_active_subscribers() [Fact] public async Task DisposeAsync_completes_pending_requests_with_a_MessageRouterException() { - var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); - var invokeTask = messageRouter.InvokeAsync("test-endpoint"); - await messageRouter.DisposeAsync(); + await _messageRouter.ConnectAsync(); + var invokeTask = _messageRouter.InvokeAsync("test-endpoint").AsTask(); + await _messageRouter.DisposeAsync(); - var exception = await Assert.ThrowsAsync(async () => await invokeTask); + await WaitForCompletionAsync(); + var exception = await Assert.ThrowsAsync(async () => await invokeTask).WaitAsync(TestTimeout); + exception.Name.Should().Be(MessageRouterErrors.ConnectionClosed); } [Fact] public async Task ConnectAsync_throws_a_MessageRouterException_if_the_client_was_previously_closed() { - var messageRouter = CreateMessageRouter(); - await messageRouter.DisposeAsync(); + await _messageRouter.DisposeAsync(); var exception = - await Assert.ThrowsAsync(async () => await messageRouter.ConnectAsync()); + await Assert.ThrowsAsync(async () => await _messageRouter.ConnectAsync()).WaitAsync(TestTimeout); exception.Name.Should().Be(MessageRouterErrors.ConnectionClosed); } @@ -101,17 +102,16 @@ public async Task ConnectAsync_sends_a_ConnectRequest_and_waits_for_a_ConnectRes { var connectRequestReceived = new TaskCompletionSource(); _connectionMock.Handle(_ => connectRequestReceived.SetResult()); - await using var messageRouter = CreateMessageRouter(); - var connectTask = messageRouter.ConnectAsync(); + var connectTask = _messageRouter.ConnectAsync().AsTask(); await connectRequestReceived.Task; connectTask.IsCompleted.Should().BeFalse(); await _connectionMock.SendToClient(new ConnectResponse {ClientId = "client-id"}); - await connectTask; + await connectTask.WaitAsync(TestTimeout); - messageRouter.ClientId.Should().Be("client-id"); + _messageRouter.ClientId.Should().Be("client-id"); } [Fact] @@ -119,16 +119,17 @@ public async Task ConnectAsync_throws_a_MessageRouterException_if_the_ConnectRes { var connectRequestReceived = new TaskCompletionSource(); _connectionMock.Handle(_ => connectRequestReceived.SetResult()); - await using var messageRouter = CreateMessageRouter(); - var connectTask = messageRouter.ConnectAsync(); - await connectRequestReceived.Task; + var connectTask = _messageRouter.ConnectAsync(); + await connectRequestReceived.Task.WaitAsync(TestTimeout); connectTask.IsCompleted.Should().BeFalse(); await _connectionMock.SendToClient(new ConnectResponse {Error = new Error("Error", "Fail")}); - var exception = await Assert.ThrowsAsync(async () => await connectTask); + await WaitForCompletionAsync(); + + var exception = await Assert.ThrowsAsync(async () => await connectTask).WaitAsync(TestTimeout); exception.Name.Should().Be("Error"); exception.Message.Should().Be("Fail"); } @@ -136,27 +137,30 @@ public async Task ConnectAsync_throws_a_MessageRouterException_if_the_ConnectRes [Fact] public async Task PublishAsync_throws_a_MessageRouterException_if_the_client_was_previously_closed() { - var messageRouter = CreateMessageRouter(); - await messageRouter.DisposeAsync(); + await _messageRouter.DisposeAsync(); var exception = await Assert.ThrowsAsync( - async () => await messageRouter.PublishAsync("test-topic")); + async () => await _messageRouter.PublishAsync("test-topic")).WaitAsync(TestTimeout); + exception.Name.Should().Be(MessageRouterErrors.ConnectionClosed); } - [Fact (Skip="CI fail")] + [Fact] public async Task PublishAsync_sends_a_PublishMessage() { - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); + _diagnosticObserver.ExpectMessage(); - await messageRouter.PublishAsync( + await _messageRouter.PublishAsync( "test-topic", "test-payload", new PublishOptions {CorrelationId = "test-correlation-id", Scope = MessageScope.FromClientId("other-client")}); + + await WaitForCompletionAsync(); + _connectionMock.Expect( msg => msg.Topic == "test-topic" && msg.Payload != null @@ -168,37 +172,41 @@ await messageRouter.PublishAsync( [Fact] public async Task SubscribeAsync_throws_a_MessageRouterException_if_the_client_was_previously_closed() { - var messageRouter = CreateMessageRouter(); - await messageRouter.DisposeAsync(); + await _messageRouter.DisposeAsync(); var exception = await Assert.ThrowsAsync( - async () => await messageRouter.SubscribeAsync( + async () => await _messageRouter.SubscribeAsync( "test-topic", - new Mock>().Object)); + new Mock>().Object)).WaitAsync(TestTimeout); + exception.Name.Should().Be(MessageRouterErrors.ConnectionClosed); } - [Fact (Skip="CI Fail")] + [Fact] public async Task SubscribeAsync_sends_a_Subscribe_message() { - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); + + await _messageRouter.SubscribeAsync("test-topic", new Mock>().Object); - await messageRouter.SubscribeAsync("test-topic", new Mock>().Object); + _diagnosticObserver.ExpectMessage(); + await WaitForCompletionAsync(); _connectionMock.Expect(msg => msg.Topic == "test-topic"); } - [Fact (Skip="CI Fail")] + [Fact] public async Task SubscribeAsync_only_sends_a_Subscribe_message_on_the_first_subscription() { - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); - await messageRouter.SubscribeAsync("test-topic", new Mock>().Object); - await messageRouter.SubscribeAsync("test-topic", new Mock>().Object); - await messageRouter.SubscribeAsync("test-topic", new Mock>().Object); + _diagnosticObserver.ExpectMessage(); + await _messageRouter.SubscribeAsync("test-topic", new Mock>().Object); + await _messageRouter.SubscribeAsync("test-topic", new Mock>().Object); + await _messageRouter.SubscribeAsync("test-topic", new Mock>().Object); + + await WaitForCompletionAsync(); _connectionMock.Expect(msg => msg.Topic == "test-topic", Times.Once); } @@ -206,24 +214,24 @@ public async Task SubscribeAsync_only_sends_a_Subscribe_message_on_the_first_sub [Fact] public async Task When_Topic_message_received_it_invokes_the_subscribers() { - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); var sub1 = new Mock>(); var sub2 = new Mock>(); - await messageRouter.SubscribeAsync("test-topic", sub1.Object); - await messageRouter.SubscribeAsync("test-topic", sub2.Object); + await _messageRouter.SubscribeAsync("test-topic", sub1.Object); + await _messageRouter.SubscribeAsync("test-topic", sub2.Object); await _connectionMock.SendToClient( - new Protocol.Messages.TopicMessage - { - Topic = "test-topic", - Payload = MessageBuffer.Create("test-payload"), - SourceId = "other-client", - CorrelationId = "test-correlation-id", - }); + RegisterRequest( + new Protocol.Messages.TopicMessage + { + Topic = "test-topic", + Payload = MessageBuffer.Create("test-payload"), + SourceId = "other-client", + CorrelationId = "test-correlation-id", + })); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); Expression, ValueTask>> expectedInvocation = _ => _.OnNextAsync( @@ -241,32 +249,47 @@ await _connectionMock.SendToClient( [Fact] public async Task When_Topic_message_received_it_keeps_processing_messages_if_the_subscriber_calls_InvokeAsync() { - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); + + // Register two subscribers, the first one will invoke a service that completes when + // the second subscriber has been called twice. If the first subscriber could block + // the pipeline with the InvokeAsync call, this test would fail, because the second + // subscriber would never get the second message. + + var countdown = new AsyncCountdownEvent(2); + _connectionMock.Handle(_ => new ValueTask(countdown.WaitAsync())); var sub1 = new Mock>(); sub1.Setup(_ => _.OnNextAsync(It.IsAny())) - .Returns(async (TopicMessage msg) => await messageRouter.InvokeAsync("test-service")); - var sub2 = new Mock>(); + .Returns(async (TopicMessage msg) => await _messageRouter.InvokeAsync("test-service")); - _connectionMock.Handle(req => { }); // Swallow the request, let the caller wait forever + var sub2 = new Mock>(); + sub1.Setup(_ => _.OnNextAsync(It.IsAny())) + .Returns((TopicMessage msg) => + { + countdown.Signal(); + return default; + }); - await messageRouter.SubscribeAsync( + await _messageRouter.SubscribeAsync( "test-topic", sub1.Object); - await messageRouter.SubscribeAsync( + await _messageRouter.SubscribeAsync( "test-topic", sub2.Object); await _connectionMock.SendToClient( - new Protocol.Messages.TopicMessage - {Topic = "test-topic", Payload = MessageBuffer.Create("payload1")}); + RegisterRequest( + new Protocol.Messages.TopicMessage + {Topic = "test-topic", Payload = MessageBuffer.Create("payload1")})); + await _connectionMock.SendToClient( - new Protocol.Messages.TopicMessage - {Topic = "test-topic", Payload = MessageBuffer.Create("payload2")}); + RegisterRequest( + new Protocol.Messages.TopicMessage + {Topic = "test-topic", Payload = MessageBuffer.Create("payload2")})); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); sub1.Verify( _ => _.OnNextAsync( @@ -285,35 +308,35 @@ await _connectionMock.SendToClient( [Fact] public async Task Topic_extension_sends_a_Subscribe_message_on_first_subscription() { - await using var messageRouter = CreateMessageRouter(); - - var topic = messageRouter.Topic("test-topic"); + var topic = _messageRouter.Topic("test-topic"); + _diagnosticObserver.ExpectMessage(); await using var sub1 = await topic.SubscribeAsync(_ => { }); + await WaitForCompletionAsync(); _connectionMock.Expect(msg => msg.Topic == "test-topic", Times.Once); _connectionMock.Invocations.Clear(); await using var sub2 = await topic.SubscribeAsync(_ => { }); - + await WaitForCompletionAsync(); _connectionMock.Expect(msg => msg.Topic == "test-topic", Times.Never); } [Fact] public async Task Topic_extension_sends_an_Unsubscribe_message_after_the_last_subscription_is_disposed() { - await using var messageRouter = CreateMessageRouter(); - - var topic = messageRouter.Topic("test-topic"); + var topic = _messageRouter.Topic("test-topic"); + _diagnosticObserver.ExpectMessage(); var sub1 = await topic.SubscribeAsync(_ => { }); var sub2 = await topic.SubscribeAsync(_ => { }); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); await sub1.DisposeAsync(); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); _connectionMock.Expect(msg => msg.Topic == "test-topic", Times.Never); + _diagnosticObserver.ExpectMessage(); await sub2.DisposeAsync(); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); _connectionMock.Expect(msg => msg.Topic == "test-topic", Times.Once); } @@ -321,20 +344,20 @@ public async Task Topic_extension_sends_an_Unsubscribe_message_after_the_last_su [Fact] public async Task When_the_last_subscription_is_disposed_it_sends_an_Unsubscribe_message() { - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); var subscriber = new Mock>(); - var sub1 = await messageRouter.SubscribeAsync("test-topic", subscriber.Object); - var sub2 = await messageRouter.SubscribeAsync("test-topic", subscriber.Object); - var sub3 = await messageRouter.SubscribeAsync("test-topic", subscriber.Object); + var sub1 = await _messageRouter.SubscribeAsync("test-topic", subscriber.Object); + var sub2 = await _messageRouter.SubscribeAsync("test-topic", subscriber.Object); + var sub3 = await _messageRouter.SubscribeAsync("test-topic", subscriber.Object); await sub1.DisposeAsync(); await sub2.DisposeAsync(); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); _connectionMock.Expect(Times.Never); + _diagnosticObserver.ExpectMessage(); await sub3.DisposeAsync(); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); _connectionMock.Expect(msg => msg.Topic == "test-topic", Times.Once); } @@ -342,12 +365,11 @@ public async Task When_the_last_subscription_is_disposed_it_sends_an_Unsubscribe [Fact] public async Task InvokeAsync_throws_a_MessageRouterException_if_the_client_was_previously_closed() { - var messageRouter = CreateMessageRouter(); - await messageRouter.DisposeAsync(); + await _messageRouter.DisposeAsync(); var exception = await Assert.ThrowsAsync( - async () => await messageRouter.InvokeAsync("test-service")); + async () => await _messageRouter.InvokeAsync("test-service")); exception.Name.Should().Be(MessageRouterErrors.ConnectionClosed); } @@ -355,12 +377,11 @@ await Assert.ThrowsAsync( public async Task InvokeAsync_sends_an_InvokeRequest_and_waits_for_an_InvokeResponse() { var invokeRequestReceived = new TaskCompletionSource(); - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); _connectionMock.Handle(request => invokeRequestReceived.SetResult(request)); - var invokeTask = messageRouter.InvokeAsync("test-service", MessageBuffer.Create("test-payload")); + var invokeTask = _messageRouter.InvokeAsync("test-service", MessageBuffer.Create("test-payload")).AsTask(); invokeTask.IsCompleted.Should().BeFalse(); @@ -369,7 +390,7 @@ public async Task InvokeAsync_sends_an_InvokeRequest_and_waits_for_an_InvokeResp await _connectionMock.SendToClient( new InvokeResponse {RequestId = request.RequestId, Payload = MessageBuffer.Create("test-response")}); - var response = await invokeTask; + var response = await invokeTask.WaitAsync(TestTimeout); response.Should().NotBeNull(); response!.GetString().Should().Be("test-response"); @@ -379,16 +400,15 @@ await _connectionMock.SendToClient( public async Task InvokeAsync_throws_if_the_InvokeResponse_contains_an_error() { var invokeRequestReceived = new TaskCompletionSource(); - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); _connectionMock.Handle((InvokeRequest request) => invokeRequestReceived.SetResult(request)); - var invokeTask = messageRouter.InvokeAsync("test-service", MessageBuffer.Create("test-payload")); + var invokeTask = _messageRouter.InvokeAsync("test-service", MessageBuffer.Create("test-payload")).AsTask(); invokeTask.IsCompleted.Should().BeFalse(); - var request = await invokeRequestReceived.Task; + var request = await invokeRequestReceived.Task.WaitAsync(TestTimeout); await _connectionMock.SendToClient( new InvokeResponse {RequestId = request.RequestId, Error = new Error("Error", "Invoke failed")}); @@ -401,12 +421,11 @@ await _connectionMock.SendToClient( [Fact] public async Task RegisterServiceAsync_throws_a_MessageRouterException_if_the_client_was_previously_closed() { - var messageRouter = CreateMessageRouter(); - await messageRouter.DisposeAsync(); + await _messageRouter.DisposeAsync(); var exception = await Assert.ThrowsAsync( - async () => await messageRouter.RegisterServiceAsync("test-service", Mock.Of())); + async () => await _messageRouter.RegisterServiceAsync("test-service", Mock.Of())); exception.Name.Should().Be(MessageRouterErrors.ConnectionClosed); } @@ -415,47 +434,45 @@ await Assert.ThrowsAsync( public async Task RegisterServiceAsync_Sends_a_RegisterService_request_and_waits_for_the_response() { var registerServiceRequestReceived = new TaskCompletionSource(); - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); _connectionMock.Handle(request => registerServiceRequestReceived.SetResult(request)); - var registerServiceTask = messageRouter.RegisterServiceAsync("test-service", Mock.Of()); + var registerServiceTask = _messageRouter.RegisterServiceAsync("test-service", Mock.Of()).AsTask(); registerServiceTask.IsCompleted.Should().BeFalse(); - var request = await registerServiceRequestReceived.Task; + var request = await registerServiceRequestReceived.Task.WaitAsync(TestTimeout); await _connectionMock.SendToClient(new RegisterServiceResponse {RequestId = request.RequestId}); - await registerServiceTask; + await registerServiceTask.WaitAsync(TestTimeout); } [Fact] public async Task RegisterServiceAsync_throws_a_MessageRouterException_if_the_endpoint_is_already_registered() { - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); _connectionMock.Handle(); - await messageRouter.RegisterServiceAsync("test-service", Mock.Of()); + await _messageRouter.RegisterServiceAsync("test-service", Mock.Of()); var exception = await Assert.ThrowsAsync( - async () => await messageRouter.RegisterServiceAsync("test-service", Mock.Of())); + async () => await _messageRouter.RegisterServiceAsync("test-service", Mock.Of())).WaitAsync(TestTimeout); + exception.Name.Should().Be(MessageRouterErrors.DuplicateEndpoint); } [Fact] public async Task RegisterServiceAsync_throws_if_the_response_contains_an_error() { - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); + _connectionMock.Handle( request => new RegisterServiceResponse { - RequestId = - request.RequestId, - Error = new Error(MessageRouterErrors.DuplicateEndpoint, null) + RequestId = request.RequestId, + Error = new Error(MessageRouterErrors.DuplicateEndpoint, message: null) }); var exception = await Assert.ThrowsAsync( - async () => await messageRouter.RegisterServiceAsync("test-service", Mock.Of())); + async () => await _messageRouter.RegisterServiceAsync("test-service", Mock.Of())).WaitAsync(TestTimeout); exception.Name.Should().Be(MessageRouterErrors.DuplicateEndpoint); } @@ -463,12 +480,11 @@ public async Task RegisterServiceAsync_throws_if_the_response_contains_an_error( [Fact] public async Task UnregisterServiceAsync_throws_a_MessageRouterException_if_the_client_was_previously_closed() { - var messageRouter = CreateMessageRouter(); - await messageRouter.DisposeAsync(); + await _messageRouter.DisposeAsync(); var exception = await Assert.ThrowsAsync( - async () => await messageRouter.UnregisterServiceAsync("test-service")); + async () => await _messageRouter.UnregisterServiceAsync("test-service")); exception.Name.Should().Be(MessageRouterErrors.ConnectionClosed); } @@ -477,18 +493,17 @@ await Assert.ThrowsAsync( public async Task UnregisterServiceAsync_sends_an_UnregisterServiceRequest_and_waits_for_the_response() { var unregisterServiceRequestReceived = new TaskCompletionSource(); - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); _connectionMock.Handle(); _connectionMock.Handle(msg => unregisterServiceRequestReceived.SetResult(msg)); - await messageRouter.RegisterServiceAsync("test-service", Mock.Of()); - var unregisterServiceTask = messageRouter.UnregisterServiceAsync("test-service"); + await _messageRouter.RegisterServiceAsync("test-service", Mock.Of()); + var unregisterServiceTask = _messageRouter.UnregisterServiceAsync("test-service").AsTask(); unregisterServiceTask.IsCompleted.Should().BeFalse(); - var request = await unregisterServiceRequestReceived.Task; + var request = await unregisterServiceRequestReceived.Task.WaitAsync(TestTimeout); await _connectionMock.SendToClient(new UnregisterServiceResponse {RequestId = request.RequestId}); await unregisterServiceTask; } @@ -497,8 +512,7 @@ public async Task UnregisterServiceAsync_sends_an_UnregisterServiceRequest_and_w public async Task When_responding_to_Invoke_messages_it_invokes_the_registered_handler_and_responds_with_InvokeResponse() { - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); var invokeResponseReceived = new TaskCompletionSource(); _connectionMock.Handle(); _connectionMock.Handle(msg => invokeResponseReceived.SetResult(msg)); @@ -513,17 +527,18 @@ public async Task .ReturnsAsync(() => MessageBuffer.Create("test-response")) .Verifiable(); - await messageRouter.RegisterServiceAsync("test-service", messageHandler.Object); + await _messageRouter.RegisterServiceAsync("test-service", messageHandler.Object); await _connectionMock.SendToClient( - new InvokeRequest - { - RequestId = "1", - Endpoint = "test-service", - Payload = MessageBuffer.Create("test-payload"), - SourceId = "other-client", - CorrelationId = "test-correlation-id" - }); + RegisterRequest( + new InvokeRequest + { + RequestId = "1", + Endpoint = "test-service", + Payload = MessageBuffer.Create("test-payload"), + SourceId = "other-client", + CorrelationId = "test-correlation-id" + })); var response = await invokeResponseReceived.Task; @@ -535,8 +550,7 @@ await _connectionMock.SendToClient( public async Task When_responding_to_Invoke_messages_it_sends_an_InvokeResponse_with_error_if_the_handler_threw_an_exception() { - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); var invokeResponseReceived = new TaskCompletionSource(); _connectionMock.Handle(); _connectionMock.Handle(msg => invokeResponseReceived.SetResult(msg)); @@ -549,7 +563,7 @@ public async Task It.IsAny())) .Callback(() => throw new InvalidOperationException("Invoke failed")); - await messageRouter.RegisterServiceAsync("test-service", messageHandler.Object); + await _messageRouter.RegisterServiceAsync("test-service", messageHandler.Object); await _connectionMock.SendToClient( new InvokeRequest @@ -561,7 +575,7 @@ await _connectionMock.SendToClient( CorrelationId = "test-correlation-id" }); - var response = await invokeResponseReceived.Task; + var response = await invokeResponseReceived.Task.WaitAsync(TestTimeout); response.Error.Should().NotBeNull(); response.Error!.Message.Should().Be("Invoke failed"); } @@ -570,11 +584,10 @@ await _connectionMock.SendToClient( public async Task When_responding_to_Invoke_messages_it_sends_an_InvokeResponse_with_error_if_the_endpoint_is_not_registered() { - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); var invokeResponseReceived = new TaskCompletionSource(); _connectionMock.Handle(msg => invokeResponseReceived.SetResult(msg)); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); await _connectionMock.SendToClient( new InvokeRequest @@ -583,7 +596,7 @@ await _connectionMock.SendToClient( Endpoint = "unknown-service", }); - var response = await invokeResponseReceived.Task; + var response = await invokeResponseReceived.Task.WaitAsync(TestTimeout); response.Error.Should().NotBeNull(); response.Error!.Name.Should().Be(MessageRouterErrors.UnknownEndpoint); } @@ -592,37 +605,46 @@ await _connectionMock.SendToClient( public async Task When_responding_to_Invoke_messages_it_repeatedly_calls_the_registered_handler_without_waiting_for_it_to_complete() { - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); _connectionMock.Handle(); + var countdown = new AsyncCountdownEvent(3); var messageHandler = new Mock(); messageHandler.Setup(_ => _("test-service", It.IsAny(), It.IsAny())) - .Returns(() => new ValueTask(new TaskCompletionSource().Task)); + .Returns(async () => + { + countdown.Signal(); + await countdown.WaitAsync(); - await messageRouter.RegisterServiceAsync("test-service", messageHandler.Object); + return null; + }); + + await _messageRouter.RegisterServiceAsync("test-service", messageHandler.Object); await _connectionMock.SendToClient( - new InvokeRequest - { - RequestId = "1", - Endpoint = "test-service", - }); + RegisterRequest( + new InvokeRequest + { + RequestId = "1", + Endpoint = "test-service", + })); await _connectionMock.SendToClient( - new InvokeRequest - { - RequestId = "2", - Endpoint = "test-service", - }); + RegisterRequest( + new InvokeRequest + { + RequestId = "2", + Endpoint = "test-service", + })); await _connectionMock.SendToClient( - new InvokeRequest - { - RequestId = "3", - Endpoint = "test-service", - }); + RegisterRequest( + new InvokeRequest + { + RequestId = "3", + Endpoint = "test-service", + })); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); messageHandler.Verify( _ => _("test-service", It.IsAny(), It.IsAny()), @@ -632,14 +654,17 @@ await _connectionMock.SendToClient( [Fact] public async Task When_the_connection_closes_it_calls_OnErrorAsync_on_active_subscribers() { - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); + await _messageRouter.ConnectAsync(); + + var subscriberCalled = new AsyncManualResetEvent(); var subscriber = new Mock>(); - await messageRouter.SubscribeAsync("test-topic", subscriber.Object); + subscriber.Setup(_ => _.OnErrorAsync(It.IsAny())).Callback(() => subscriberCalled.Set()); + await _messageRouter.SubscribeAsync("test-topic", subscriber.Object); _connectionMock.Close(new MessageRouterException(MessageRouterErrors.ConnectionAborted, "")); - - await TaskExtensions.WaitForBackgroundTasksAsync(); + + await WaitForCompletionAsync(); + await subscriberCalled.WaitAsync().WaitAsync(TestTimeout); subscriber.Verify( _ => _.OnErrorAsync( @@ -649,14 +674,12 @@ public async Task When_the_connection_closes_it_calls_OnErrorAsync_on_active_sub [Fact] public async Task When_the_connection_closes_it_fails_pending_requests() { - - await using var messageRouter = CreateMessageRouter(); - await messageRouter.ConnectAsync(); - var invokeTask = messageRouter.InvokeAsync("test-service"); + await _messageRouter.ConnectAsync(); + var invokeTask = _messageRouter.InvokeAsync("test-service"); _connectionMock.Close(new MessageRouterException(MessageRouterErrors.ConnectionAborted, "")); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); var exception = await Assert.ThrowsAsync(async () => await invokeTask); exception.Name.Should().Be(MessageRouterErrors.ConnectionAborted); @@ -666,6 +689,10 @@ public MessageRouterClientTests() { _connectionMock = new MockConnection(); _connectionMock.AcceptConnections(); + var connectionFactory = new Mock(); + connectionFactory.Setup(_ => _.CreateConnection()).Returns(_connectionMock.Object); + _messageRouter = new MessageRouterClient(connectionFactory.Object, new MessageRouterOptions()); + _diagnosticObserver = new MessageRouterDiagnosticObserver(_messageRouter); } public Task InitializeAsync() @@ -678,14 +705,19 @@ public Task DisposeAsync() return Task.CompletedTask; } + public static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(1); + private readonly MockConnection _connectionMock; - private static readonly MessageRouterOptions DefaultOptions = new MessageRouterOptions(); + private readonly MessageRouterClient _messageRouter; + private readonly MessageRouterDiagnosticObserver _diagnosticObserver; - private MessageRouterClient CreateMessageRouter(MessageRouterOptions? options = null) + private TMessage RegisterRequest(TMessage message) where TMessage: Message { - var connectionFactory = new Mock(); - connectionFactory.Setup(_ => _.CreateConnection()).Returns(_connectionMock.Object); + return _diagnosticObserver.RegisterRequest(message); + } - return new MessageRouterClient(connectionFactory.Object, options ?? DefaultOptions); + private async Task WaitForCompletionAsync() + { + await _diagnosticObserver.WaitForCompletionAsync(TestTimeout); } } \ No newline at end of file diff --git a/src/messaging/dotnet/test/Core.Tests/Instrumentation/MessageRouterDiagnosticObserver.Tests.cs b/src/messaging/dotnet/test/Core.Tests/Instrumentation/MessageRouterDiagnosticObserver.Tests.cs new file mode 100644 index 000000000..6159f3982 --- /dev/null +++ b/src/messaging/dotnet/test/Core.Tests/Instrumentation/MessageRouterDiagnosticObserver.Tests.cs @@ -0,0 +1,134 @@ +// Morgan Stanley makes this available to you under the Apache License, +// Version 2.0 (the "License"). You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0. +// +// See the NOTICE file distributed with this work for additional information +// regarding copyright ownership. Unless required by applicable law or agreed +// to in writing, software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions +// and limitations under the License. + +using MorganStanley.ComposeUI.Messaging.Protocol.Messages; + +namespace MorganStanley.ComposeUI.Messaging.Instrumentation; + +using TopicMessage = Protocol.Messages.TopicMessage; + +public class MessageRouterDiagnosticObserverTests +{ + [Fact] + public async Task Wait_in_initial_state() + { + await _observer.WaitForCompletionAsync().WaitAsync(TestTimeout); + } + + [Fact] + public async Task Wait_for_ConnectStop_after_ConnectStart() + { + WriteEvent(new MessageRouterEvent(this, MessageRouterEventTypes.ConnectStart)); + + var task = _observer.WaitForCompletionAsync(); + task.IsCompleted.Should().BeFalse(); + + WriteEvent(new MessageRouterEvent(this, MessageRouterEventTypes.ConnectStop)); + await task.WaitAsync(TestTimeout); + } + + [Fact] + public async Task Wait_for_CloseStop_after_CloseStart() + { + WriteEvent(new MessageRouterEvent(this, MessageRouterEventTypes.CloseStart)); + + var task = _observer.WaitForCompletionAsync(); + task.IsCompleted.Should().BeFalse(); + + WriteEvent(new MessageRouterEvent(this, MessageRouterEventTypes.CloseStop)); + await task.WaitAsync(TestTimeout); + } + + [Fact] + public async Task Wait_for_RequestStop_after_RequestStart() + { + var message = new TopicMessage(); + WriteEvent(new MessageRouterEvent(this, MessageRouterEventTypes.RequestStart, message)); + + var task = _observer.WaitForCompletionAsync(); + task.IsCompleted.Should().BeFalse(); + + WriteEvent(new MessageRouterEvent(this, MessageRouterEventTypes.RequestStop, message)); + await task.WaitAsync(TestTimeout); + } + + [Fact] + public async Task Wait_for_RequestStop_after_calling_RegisterRequest() + { + var message1 = new TopicMessage(); + _observer.RegisterRequest(message1); + var message2 = new TopicMessage(); + _observer.RegisterRequest(message2); + + var task = _observer.WaitForCompletionAsync(); + task.IsCompleted.Should().BeFalse(); + + WriteEvent(new MessageRouterEvent(this, MessageRouterEventTypes.RequestStart, message1)); + task.IsCompleted.Should().BeFalse(); + + WriteEvent(new MessageRouterEvent(this, MessageRouterEventTypes.RequestStart, message2)); + task.IsCompleted.Should().BeFalse(); + + WriteEvent(new MessageRouterEvent(this, MessageRouterEventTypes.RequestStop, message1)); + task.IsCompleted.Should().BeFalse(); + + WriteEvent(new MessageRouterEvent(this, MessageRouterEventTypes.RequestStop, message2)); + + await task.WaitAsync(TestTimeout); + } + + [Fact] + public async Task Wait_for_MessageSent_after_calling_ExpectMessage() + { + var message = new InvokeRequest(); + _observer.ExpectMessage(msg => msg == message); + + var task = _observer.WaitForCompletionAsync(); + task.IsCompleted.Should().BeFalse(); + + WriteEvent(new MessageRouterEvent(this, MessageRouterEventTypes.MessageSent, message)); + + await task.WaitAsync(TestTimeout); + } + + [Fact] + public async Task Wait_for_expected_events_to_be_written() + { + _observer.ExpectEvent(MessageRouterEventTypes.MessageSent); + _observer.ExpectEvent(MessageRouterEventTypes.MessageSent); + + var task = _observer.WaitForCompletionAsync(); + task.IsCompleted.Should().BeFalse(); + + WriteEvent(new MessageRouterEvent(this, MessageRouterEventTypes.MessageSent)); + task.IsCompleted.Should().BeFalse(); + + WriteEvent(new MessageRouterEvent(this, MessageRouterEventTypes.MessageSent)); + + await task.WaitAsync(TestTimeout); + } + + // TODO: Add test cases when the sender does not match the one specified in constructor + + public MessageRouterDiagnosticObserverTests() + { + _observer = new MessageRouterDiagnosticObserver(this); + } + + private static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(1); + private readonly MessageRouterDiagnosticObserver _observer; + + private void WriteEvent(MessageRouterEvent evt) + { + MessageRouterDiagnosticSource.Log.Write(evt.Type, evt); + } +} \ No newline at end of file diff --git a/src/messaging/dotnet/test/Core.Tests/MessageBuffer.Tests.cs b/src/messaging/dotnet/test/Core.Tests/MessageBuffer.Tests.cs index 0d09a597b..9c8a69861 100644 --- a/src/messaging/dotnet/test/Core.Tests/MessageBuffer.Tests.cs +++ b/src/messaging/dotnet/test/Core.Tests/MessageBuffer.Tests.cs @@ -11,7 +11,6 @@ // and limitations under the License. using System.Buffers; -using System.Buffers.Text; using System.Text; using MorganStanley.ComposeUI.Messaging.TestUtils; diff --git a/src/messaging/dotnet/test/Core.Tests/MorganStanley.ComposeUI.Messaging.Core.Tests.csproj b/src/messaging/dotnet/test/Core.Tests/MorganStanley.ComposeUI.Messaging.Core.Tests.csproj index 8cb94a9be..c7a16cfe3 100644 --- a/src/messaging/dotnet/test/Core.Tests/MorganStanley.ComposeUI.Messaging.Core.Tests.csproj +++ b/src/messaging/dotnet/test/Core.Tests/MorganStanley.ComposeUI.Messaging.Core.Tests.csproj @@ -16,4 +16,8 @@ + + + + \ No newline at end of file diff --git a/src/messaging/dotnet/test/IntegrationTests/EndToEndTestsBase.cs b/src/messaging/dotnet/test/IntegrationTests/EndToEndTestsBase.cs index 86afded86..a738e5dae 100644 --- a/src/messaging/dotnet/test/IntegrationTests/EndToEndTestsBase.cs +++ b/src/messaging/dotnet/test/IntegrationTests/EndToEndTestsBase.cs @@ -38,7 +38,7 @@ public async Task Client_can_subscribe_and_receive_messages() observerMock.Setup(x => x.OnNext(Capture.In(receivedMessages))); await subscriber.SubscribeAsync(topic: "test-topic", observerMock.Object); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await Task.Delay(TimeSpan.FromSeconds(2)); var publishedPayload = new TestPayload { @@ -50,8 +50,7 @@ await publisher.PublishAsync( topic: "test-topic", MessageBuffer.Create(JsonSerializer.SerializeToUtf8Bytes(publishedPayload))); - // TODO: Investigate why WaitForBackgroundTasksAsync is unreliable in this particular scenario - await TaskExtensions.WaitForBackgroundTasksAsync(TimeSpan.FromMilliseconds(100)); + await Task.Delay(TimeSpan.FromSeconds(2)); var receivedPayload = JsonSerializer.Deserialize(receivedMessages.Single().Payload!.GetSpan()); @@ -176,7 +175,7 @@ await subscriber.SubscribeAsync( { using (await semaphore.LockAsync(new CancellationTokenSource(TimeSpan.Zero).Token)) { - await TaskExtensions.WaitForBackgroundTasksAsync(); + await Task.Delay(TimeSpan.FromSeconds(1)); } if (msg.Payload?.GetString() == "done") diff --git a/src/messaging/dotnet/test/Server.Tests/Server/MessageRouterServer.Tests.cs b/src/messaging/dotnet/test/Server.Tests/Server/MessageRouterServer.Tests.cs index c56cb2de7..e8cc69e94 100644 --- a/src/messaging/dotnet/test/Server.Tests/Server/MessageRouterServer.Tests.cs +++ b/src/messaging/dotnet/test/Server.Tests/Server/MessageRouterServer.Tests.cs @@ -10,12 +10,13 @@ // or implied. See the License for the specific language governing permissions // and limitations under the License. +using System.Diagnostics; using System.Linq.Expressions; +using MorganStanley.ComposeUI.Messaging.Instrumentation; using MorganStanley.ComposeUI.Messaging.Protocol; using MorganStanley.ComposeUI.Messaging.Protocol.Messages; using MorganStanley.ComposeUI.Messaging.Server.Abstractions; using MorganStanley.ComposeUI.Messaging.TestUtils; -using TaskExtensions = MorganStanley.ComposeUI.Testing.TaskExtensions; namespace MorganStanley.ComposeUI.Messaging.Server; @@ -25,11 +26,10 @@ public class MessageRouterServerTests public async Task It_responds_to_ConnectRequest_with_ConnectResponse() { var connectResponseReceived = new TaskCompletionSource(); - await using var server = CreateServer(); var client = CreateClient(); client.Handle(connectResponseReceived.SetResult); - await server.ClientConnected(client.Object); + await _server.ClientConnected(client.Object); await client.SendToServer(new ConnectRequest()); var connectResponse = await connectResponseReceived.Task; @@ -41,12 +41,13 @@ public async Task It_responds_to_ConnectRequest_with_ConnectResponse() public async Task It_accepts_connection_with_valid_token() { var connectResponseReceived = new TaskCompletionSource(); - var validator = new Mock(); - await using var server = CreateServer(validator.Object); + _accessTokenValidator + .Setup(_ => _.Validate(It.IsAny(), It.IsAny())) + .Returns(default(ValueTask)); var client = CreateClient(); client.Handle(connectResponseReceived.SetResult); - await server.ClientConnected(client.Object); + await _server.ClientConnected(client.Object); await client.SendToServer(new ConnectRequest {AccessToken = "token"}); var connectResponse = await connectResponseReceived.Task; @@ -58,10 +59,10 @@ public async Task It_accepts_connection_with_valid_token() public async Task It_accepts_connection_without_token_if_validator_is_not_registered() { var connectResponseReceived = new TaskCompletionSource(); - await using var server = CreateServer(); var client = CreateClient(); client.Handle(connectResponseReceived.SetResult); + var server = CreateServer(null); await server.ClientConnected(client.Object); await client.SendToServer(new ConnectRequest()); var connectResponse = await connectResponseReceived.Task; @@ -78,16 +79,14 @@ public async Task It_accepts_connection_without_token_if_validator_is_not_regist public async Task It_rejects_connections_with_invalid_token(string? token) { var connectResponseReceived = new TaskCompletionSource(); - var validator = new Mock(); - validator.Setup(_ => _.Validate(It.IsAny(), It.IsAny())) + _accessTokenValidator.Setup(_ => _.Validate(It.IsAny(), It.IsAny())) .Throws(new InvalidOperationException("Invalid token")); - await using var server = CreateServer(validator.Object); var client = CreateClient(); client.Handle(connectResponseReceived.SetResult); - await server.ClientConnected(client.Object); + await _server.ClientConnected(client.Object); await client.SendToServer(new ConnectRequest {AccessToken = token}); var connectResponse = await connectResponseReceived.Task; @@ -95,25 +94,27 @@ public async Task It_rejects_connections_with_invalid_token(string? token) } - [Fact (Skip = "CI Fail")] + [Fact] public async Task When_Publish_message_received_it_dispatches_the_message_to_the_subscribers() { - await using var server = CreateServer(); - var client1 = await CreateAndConnectClient(server); - var client2 = await CreateAndConnectClient(server); + var client1 = await CreateAndConnectClient(); + var client2 = await CreateAndConnectClient(); + + await client1.SendToServer(RegisterRequest(new SubscribeMessage {Topic = "test-topic"})); + await client2.SendToServer(RegisterRequest(new SubscribeMessage {Topic = "test-topic"})); - await client1.SendToServer(new SubscribeMessage {Topic = "test-topic"}); - await client2.SendToServer(new SubscribeMessage {Topic = "test-topic"}); + await WaitForCompletionAsync(); await client1.SendToServer( - new PublishMessage - { - Topic = "test-topic", - Payload = MessageBuffer.Create("test-payload"), - CorrelationId = "test-correlation-id" - }); + RegisterRequest( + new PublishMessage + { + Topic = "test-topic", + Payload = MessageBuffer.Create("test-payload"), + CorrelationId = "test-correlation-id" + })); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); Expression> expectation = msg => msg.Topic == "test-topic" @@ -129,31 +130,28 @@ await client1.SendToServer( [Fact] public async Task It_does_not_dispatch_Topic_message_if_the_client_has_unsubscribed() { - await using var server = CreateServer(); - var client = await CreateAndConnectClient(server); + var client = await CreateAndConnectClient(); - await client.SendToServer(new SubscribeMessage {Topic = "test-topic"}); + await client.SendToServer( + RegisterRequest(new SubscribeMessage {Topic = "test-topic"})); await client.SendToServer( - new PublishMessage - { - Topic = "test-topic" - }); + RegisterRequest(new PublishMessage {Topic = "test-topic"})); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); client.Expect(msg => msg.Topic == "test-topic", Times.Once); client.Invocations.Clear(); - await client.SendToServer(new UnsubscribeMessage {Topic = "test-topic"}); + await client.SendToServer( + RegisterRequest( + new UnsubscribeMessage {Topic = "test-topic"})); await client.SendToServer( - new PublishMessage - { - Topic = "test-topic" - }); + RegisterRequest( + new PublishMessage {Topic = "test-topic"})); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); client.Expect(msg => msg.Topic == "test-topic", Times.Never); } @@ -161,11 +159,10 @@ await client.SendToServer( [Fact] public async Task Client_can_register_itself_as_a_service() { - await using var server = CreateServer(); - var client = await CreateAndConnectClient(server); + var client = await CreateAndConnectClient(); - await client.SendToServer(new RegisterServiceRequest {Endpoint = "test-service"}); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await client.SendToServer(RegisterRequest(new RegisterServiceRequest {Endpoint = "test-service"})); + await WaitForCompletionAsync(); var registerServiceResponse = client.Received.OfType().First(); registerServiceResponse.Error.Should().BeNull(); @@ -174,12 +171,11 @@ public async Task Client_can_register_itself_as_a_service() [Fact] public async Task Client_can_unregister_itself_as_a_service() { - await using var server = CreateServer(); - var client = await CreateAndConnectClient(server); + var client = await CreateAndConnectClient(); - await client.SendToServer(new RegisterServiceRequest { Endpoint = "test-service" }); - await client.SendToServer(new UnregisterServiceRequest { Endpoint = "test-service" }); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await client.SendToServer(RegisterRequest(new RegisterServiceRequest {Endpoint = "test-service"})); + await client.SendToServer(RegisterRequest(new UnregisterServiceRequest {Endpoint = "test-service"})); + await WaitForCompletionAsync(); var registerServiceResponse = client.Received.OfType().Single(); registerServiceResponse.Error.Should().BeNull(); @@ -190,27 +186,27 @@ public async Task Client_can_unregister_itself_as_a_service() [Fact] public async Task It_handles_service_invocation() { - await using var server = CreateServer(); - var service = await CreateAndConnectClient(server); + var service = await CreateAndConnectClient(); service.Handle( (InvokeRequest req) => new InvokeResponse {RequestId = req.RequestId, Payload = MessageBuffer.Create("test-response")}); - var caller = await CreateAndConnectClient(server); + var caller = await CreateAndConnectClient(); - await service.SendToServer(new RegisterServiceRequest {Endpoint = "test-service"}); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await service.SendToServer(RegisterRequest(new RegisterServiceRequest {Endpoint = "test-service"})); + await WaitForCompletionAsync(); await caller.SendToServer( - new InvokeRequest - { - RequestId = "1", - Endpoint = "test-service", - Payload = MessageBuffer.Create("test-payload") - }); + RegisterRequest( + new InvokeRequest + { + RequestId = "1", + Endpoint = "test-service", + Payload = MessageBuffer.Create("test-payload") + })); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); service.Expect( msg => msg.Endpoint == "test-service" && msg.Payload!.GetString() == "test-payload", @@ -224,8 +220,7 @@ await caller.SendToServer( [Fact] public async Task It_handles_service_invocation_with_error() { - await using var server = CreateServer(); - var service = await CreateAndConnectClient(server); + var service = await CreateAndConnectClient(); service.Handle( (InvokeRequest req) => new InvokeResponse @@ -234,20 +229,21 @@ public async Task It_handles_service_invocation_with_error() Error = new Error("Error", "Invoke failed") }); - var caller = await CreateAndConnectClient(server); + var caller = await CreateAndConnectClient(); - await service.SendToServer(new RegisterServiceRequest {Endpoint = "test-service"}); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await service.SendToServer(RegisterRequest(new RegisterServiceRequest {Endpoint = "test-service"})); + await WaitForCompletionAsync(); await caller.SendToServer( - new InvokeRequest - { - RequestId = "1", - Endpoint = "test-service", - Payload = MessageBuffer.Create("test-payload") - }); + RegisterRequest( + new InvokeRequest + { + RequestId = "1", + Endpoint = "test-service", + Payload = MessageBuffer.Create("test-payload") + })); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); service.Expect( msg => msg.Endpoint == "test-service" && msg.Payload!.GetString() == "test-payload", @@ -263,18 +259,17 @@ await caller.SendToServer( [Fact] public async Task It_fails_the_service_invocation_if_the_service_is_not_registered() { - await using var server = CreateServer(); - - var client = await CreateAndConnectClient(server); + var client = await CreateAndConnectClient(); await client.SendToServer( - new InvokeRequest - { - RequestId = "1", - Endpoint = "test-service" - }); + RegisterRequest( + new InvokeRequest + { + RequestId = "1", + Endpoint = "test-service" + })); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); client.Expect( msg => msg.RequestId == "1" @@ -285,40 +280,48 @@ await client.SendToServer( [Fact] public async Task It_fails_the_service_invocation_if_the_service_has_unregistered_itself() { - await using var server = CreateServer(); - - var service = await CreateAndConnectClient(server); + var service = await CreateAndConnectClient(); service.Handle(); - var caller = await CreateAndConnectClient(server); + var caller = await CreateAndConnectClient(); + + await service.SendToServer( + RegisterRequest( + new RegisterServiceRequest + { + RequestId = "1", + Endpoint = "test-service" + })); - await service.SendToServer(new RegisterServiceRequest {RequestId = "1", Endpoint = "test-service"}); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); await caller.SendToServer( - new InvokeRequest - { - RequestId = "1", - Endpoint = "test-service" - }); + RegisterRequest( + new InvokeRequest + { + RequestId = "1", + Endpoint = "test-service" + })); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); caller.Expect( msg => msg.RequestId == "1", Times.Once); caller.Invocations.Clear(); - await service.SendToServer(new UnregisterServiceRequest {RequestId = "2", Endpoint = "test-service"}); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await service.SendToServer( + RegisterRequest(new UnregisterServiceRequest {RequestId = "2", Endpoint = "test-service"})); + await WaitForCompletionAsync(); await caller.SendToServer( - new InvokeRequest - { - RequestId = "2", - Endpoint = "test-service" - }); + RegisterRequest( + new InvokeRequest + { + RequestId = "2", + Endpoint = "test-service" + })); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); caller.Expect( msg => msg.RequestId == "2" && msg.Error != null && msg.Error.Name == MessageRouterErrors.UnknownEndpoint, @@ -328,19 +331,18 @@ await caller.SendToServer( [Fact] public async Task It_fails_direct_invocation_if_the_client_is_not_found() { - await using var server = CreateServer(); - - var client = await CreateAndConnectClient(server); + var client = await CreateAndConnectClient(); await client.SendToServer( - new InvokeRequest - { - RequestId = "1", - Endpoint = "test-endpoint", - Scope = MessageScope.FromClientId("unknown-client") - }); + RegisterRequest( + new InvokeRequest + { + RequestId = "1", + Endpoint = "test-endpoint", + Scope = MessageScope.FromClientId("unknown-client") + })); - await TaskExtensions.WaitForBackgroundTasksAsync(); + await WaitForCompletionAsync(); client.Expect( msg => msg.RequestId == "1" @@ -351,33 +353,49 @@ await client.SendToServer( [Fact] public async Task When_disposed_it_calls_CloseAsync_on_active_connections() { - var server = CreateServer(); - var connection = new Mock(); - - connection.SetupSequence(_ => _.ReceiveAsync(It.IsAny())) - .Returns(new ValueTask( - new ConnectRequest())) - .Returns(new ValueTask( - Task.Delay(1000).ContinueWith(_ => (Message)new PublishMessage {Topic = "dummy"}))); - - await server.ClientConnected(connection.Object); - await TaskExtensions.WaitForBackgroundTasksAsync(); - await server.DisposeAsync(); + var connection = await CreateAndConnectClient(); + await connection.SendToServer(RegisterRequest(new PublishMessage {Topic = "test-topic"})); + + await WaitForCompletionAsync(); + + await _server.DisposeAsync(); connection.Verify(_ => _.CloseAsync(), Times.Once); } - private MessageRouterServer CreateServer(IAccessTokenValidator? accessTokenValidator = null) => + public MessageRouterServerTests() + { + _server = CreateServer(_accessTokenValidator.Object); + _diagnosticObserver = new MessageRouterDiagnosticObserver(_server); + } + + private MessageRouterServer _server; + private Mock _accessTokenValidator = new(); + private MessageRouterDiagnosticObserver _diagnosticObserver; + + private MessageRouterServer CreateServer(IAccessTokenValidator? accessTokenValidator) => new MessageRouterServer(new MessageRouterServerDependencies(accessTokenValidator)); - private MockClientConnection CreateClient() => new MockClientConnection(); + private MockClientConnection CreateClient() => new(); - private async Task CreateAndConnectClient(MessageRouterServer server) + private async Task CreateAndConnectClient() { var client = CreateClient(); - await server.ClientConnected(client.Object); + await _server.ClientConnected(client.Object); await client.Connect(); return client; } + + private static readonly TimeSpan TestTimeout = TimeSpan.FromSeconds(2); + + private TMessage RegisterRequest(TMessage message) where TMessage : Message + { + return _diagnosticObserver.RegisterRequest(message); + } + + private async Task WaitForCompletionAsync() + { + await _diagnosticObserver.WaitForCompletionAsync(TestTimeout); + } } \ No newline at end of file diff --git a/src/messaging/dotnet/test/Server.Tests/TestUtils/MockClientConnection.cs b/src/messaging/dotnet/test/Server.Tests/TestUtils/MockClientConnection.cs index cde1040eb..1bc83ae83 100644 --- a/src/messaging/dotnet/test/Server.Tests/TestUtils/MockClientConnection.cs +++ b/src/messaging/dotnet/test/Server.Tests/TestUtils/MockClientConnection.cs @@ -24,7 +24,7 @@ public MockClientConnection() Setup(_ => _.SendAsync(Capture.In(Received), It.IsAny())); Setup(_ => _.ReceiveAsync(It.IsAny())) - .Returns((CancellationToken ct) => _sendChannel.Reader.ReadAsync(ct)); + .Returns(async (CancellationToken ct) => await _sendChannel.Reader.ReadAsync(ct)); Setup(_ => _.CloseAsync()) .Callback( @@ -122,13 +122,8 @@ public void Expect(Expression> expectation, Times Verify(_ => _.SendAsync(It.Is(expectation), It.IsAny()), times); } - public void Expect(Func times) where TMessage : Message - { - Expect(msg => true, times()); - } - public ValueTask SendToServer(Message message) { return _sendChannel.Writer.WriteAsync(message); } -} \ No newline at end of file +} diff --git a/src/shared/dotnet/MorganStanley.ComposeUI.Testing/MorganStanley.ComposeUI.Testing.csproj b/src/shared/dotnet/MorganStanley.ComposeUI.Testing/MorganStanley.ComposeUI.Testing.csproj index bde7239fc..cb0fcd5aa 100644 --- a/src/shared/dotnet/MorganStanley.ComposeUI.Testing/MorganStanley.ComposeUI.Testing.csproj +++ b/src/shared/dotnet/MorganStanley.ComposeUI.Testing/MorganStanley.ComposeUI.Testing.csproj @@ -8,6 +8,7 @@ + diff --git a/src/shared/dotnet/MorganStanley.ComposeUI.Testing/TaskExtensions.cs b/src/shared/dotnet/MorganStanley.ComposeUI.Testing/TaskExtensions.cs index d2bb9f4b2..e5c8dda78 100644 --- a/src/shared/dotnet/MorganStanley.ComposeUI.Testing/TaskExtensions.cs +++ b/src/shared/dotnet/MorganStanley.ComposeUI.Testing/TaskExtensions.cs @@ -13,12 +13,14 @@ // ReSharper disable UnusedMember.Global using System.Diagnostics; +using Xunit.Sdk; namespace MorganStanley.ComposeUI.Testing; public static class TaskExtensions { - public static Task WaitForBackgroundTasksAsync(CancellationToken cancellationToken = default) + [Obsolete("Don't use, this method is not reliable. Add proper instrumentation to your code instead.")] + public static async Task WaitForBackgroundTasksAsync(CancellationToken cancellationToken = default) { // Quick and dirty method of waiting for background tasks to finish. // We try to schedule enough tasks so that the thread pool is fully utilized. @@ -32,14 +34,18 @@ public static Task WaitForBackgroundTasksAsync(CancellationToken cancellationTok .Select(async _ => await gate.WaitAsync(cancellationToken))); // Let the tasks complete - Task.Delay(1, cancellationToken) - .ContinueWith( - _ => gate.Release(taskCount), - TaskContinuationOptions.RunContinuationsAsynchronously); + await Task.Yield(); + gate.Release(taskCount); - return task; + await task; + + if (SynchronizationContext.Current is AsyncTestSyncContext asyncContext) + { + await asyncContext.WaitForCompletionAsync(); + } } + [Obsolete("Don't use, this method is not reliable. Add proper instrumentation to your code instead.")] public static async Task WaitForBackgroundTasksAsync(TimeSpan minimumWaitTime, CancellationToken cancellationToken = default) { var stopwatch = Stopwatch.StartNew();