From 0ed0b4875e424949d6a91f541d3823eb3fe67944 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=9B=BD=E4=BC=9F?= <366193849@qq.com> Date: Fri, 29 Nov 2024 23:22:21 +0800 Subject: [PATCH] Rpc.ExecuteAsync: Reduce memory allocation. --- ...ultMqttRpcClientTopicGenerationStrategy.cs | 2 +- .../MQTTnet.Extensions.Rpc/IMqttRpcClient.cs | 9 +-- .../MQTTnet.Extensions.Rpc.csproj | 2 +- .../MQTTnet.Extensions.Rpc/MqttRpcClient.cs | 45 +++-------- .../MqttRpcClientExtensions.cs | 75 +++++++++++++++++-- 5 files changed, 86 insertions(+), 47 deletions(-) diff --git a/Source/MQTTnet.Extensions.Rpc/DefaultMqttRpcClientTopicGenerationStrategy.cs b/Source/MQTTnet.Extensions.Rpc/DefaultMqttRpcClientTopicGenerationStrategy.cs index b21affed2..95b9b457a 100644 --- a/Source/MQTTnet.Extensions.Rpc/DefaultMqttRpcClientTopicGenerationStrategy.cs +++ b/Source/MQTTnet.Extensions.Rpc/DefaultMqttRpcClientTopicGenerationStrategy.cs @@ -12,7 +12,7 @@ public MqttRpcTopicPair CreateRpcTopics(TopicGenerationContext context) { ArgumentNullException.ThrowIfNull(context); - if (context.MethodName.Contains("/") || context.MethodName.Contains("+") || context.MethodName.Contains("#")) + if (context.MethodName.Contains('/') || context.MethodName.Contains('+') || context.MethodName.Contains('#')) { throw new ArgumentException("The method name cannot contain /, + or #."); } diff --git a/Source/MQTTnet.Extensions.Rpc/IMqttRpcClient.cs b/Source/MQTTnet.Extensions.Rpc/IMqttRpcClient.cs index ccece444a..877370767 100644 --- a/Source/MQTTnet.Extensions.Rpc/IMqttRpcClient.cs +++ b/Source/MQTTnet.Extensions.Rpc/IMqttRpcClient.cs @@ -2,18 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using MQTTnet.Protocol; using System; +using System.Buffers; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; -using MQTTnet.Protocol; namespace MQTTnet.Extensions.Rpc { public interface IMqttRpcClient : IDisposable - { - Task ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null); - - Task ExecuteAsync(string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default); + { + Task> ExecuteAsync(string methodName, ReadOnlySequence payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default); } } \ No newline at end of file diff --git a/Source/MQTTnet.Extensions.Rpc/MQTTnet.Extensions.Rpc.csproj b/Source/MQTTnet.Extensions.Rpc/MQTTnet.Extensions.Rpc.csproj index b38fb489d..f4353d21a 100644 --- a/Source/MQTTnet.Extensions.Rpc/MQTTnet.Extensions.Rpc.csproj +++ b/Source/MQTTnet.Extensions.Rpc/MQTTnet.Extensions.Rpc.csproj @@ -35,7 +35,7 @@ all true low - latest-Recommended + diff --git a/Source/MQTTnet.Extensions.Rpc/MqttRpcClient.cs b/Source/MQTTnet.Extensions.Rpc/MqttRpcClient.cs index c64b93c2d..bebcf7a11 100644 --- a/Source/MQTTnet.Extensions.Rpc/MqttRpcClient.cs +++ b/Source/MQTTnet.Extensions.Rpc/MqttRpcClient.cs @@ -2,17 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using MQTTnet.Exceptions; +using MQTTnet.Formatter; +using MQTTnet.Internal; +using MQTTnet.Protocol; using System; using System.Buffers; using System.Collections.Concurrent; using System.Collections.Generic; -using System.Linq; using System.Threading; using System.Threading.Tasks; -using MQTTnet.Exceptions; -using MQTTnet.Formatter; -using MQTTnet.Internal; -using MQTTnet.Protocol; namespace MQTTnet.Extensions.Rpc { @@ -21,7 +20,7 @@ public sealed class MqttRpcClient : IMqttRpcClient readonly IMqttClient _mqttClient; readonly MqttRpcClientOptions _options; - readonly ConcurrentDictionary> _waitingCalls = new ConcurrentDictionary>(); + readonly ConcurrentDictionary>> _waitingCalls = new(); public MqttRpcClient(IMqttClient mqttClient, MqttRpcClientOptions options) { @@ -43,27 +42,7 @@ public void Dispose() _waitingCalls.Clear(); } - public async Task ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null) - { - using (var timeoutToken = new CancellationTokenSource(timeout)) - { - try - { - return await ExecuteAsync(methodName, payload, qualityOfServiceLevel, parameters, timeoutToken.Token).ConfigureAwait(false); - } - catch (OperationCanceledException exception) - { - if (timeoutToken.IsCancellationRequested) - { - throw new MqttCommunicationTimedOutException(exception); - } - - throw; - } - } - } - - public async Task ExecuteAsync(string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default) + public async Task> ExecuteAsync(string methodName, ReadOnlySequence payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default) { ArgumentNullException.ThrowIfNull(methodName); @@ -94,7 +73,7 @@ public async Task ExecuteAsync(string methodName, byte[] payload, MqttQu try { - var awaitable = new AsyncTaskCompletionSource(); + var awaitable = new AsyncTaskCompletionSource>(); if (!_waitingCalls.TryAdd(responseTopic, awaitable)) { @@ -106,11 +85,7 @@ public async Task ExecuteAsync(string methodName, byte[] payload, MqttQu await _mqttClient.SubscribeAsync(subscribeOptions, cancellationToken).ConfigureAwait(false); await _mqttClient.PublishAsync(requestMessage, cancellationToken).ConfigureAwait(false); - using (cancellationToken.Register( - () => - { - awaitable.TrySetCanceled(); - })) + using (cancellationToken.Register(awaitable.TrySetCanceled)) { return await awaitable.Task.ConfigureAwait(false); } @@ -129,8 +104,8 @@ Task HandleApplicationMessageReceivedAsync(MqttApplicationMessageReceivedEventAr return CompletedTask.Instance; } - var payloadBuffer = eventArgs.ApplicationMessage.Payload.ToArray(); - awaitable.TrySetResult(payloadBuffer); + var payload = eventArgs.ApplicationMessage.Payload; + awaitable.TrySetResult(payload); // Set this message to handled to that other code can avoid execution etc. eventArgs.IsHandled = true; diff --git a/Source/MQTTnet.Extensions.Rpc/MqttRpcClientExtensions.cs b/Source/MQTTnet.Extensions.Rpc/MqttRpcClientExtensions.cs index a0b0ebfee..7174b6ebd 100644 --- a/Source/MQTTnet.Extensions.Rpc/MqttRpcClientExtensions.cs +++ b/Source/MQTTnet.Extensions.Rpc/MqttRpcClientExtensions.cs @@ -2,23 +2,88 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using MQTTnet.Exceptions; +using MQTTnet.Internal; +using MQTTnet.Protocol; using System; +using System.Buffers; using System.Collections.Generic; +using System.IO; +using System.IO.Pipelines; using System.Text; +using System.Threading; using System.Threading.Tasks; -using MQTTnet.Protocol; namespace MQTTnet.Extensions.Rpc { public static class MqttRpcClientExtensions { - public static Task ExecuteAsync(this IMqttRpcClient client, TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null) + public static async Task ExecuteAsync(this IMqttRpcClient client, TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null) + { + var response = await MqttTimeOutAsync(timeout, cancellationToken => client.ExecuteAsync(methodName, payload, qualityOfServiceLevel, parameters, cancellationToken)); + return response.ToArray(); + } + + public static Task ExecuteAsync(this IMqttRpcClient client, TimeSpan timeout, string methodName, ReadOnlyMemory payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null) + { + return MqttTimeOutAsync(timeout, cancellationToken => client.ExecuteAsync(methodName, payload, qualityOfServiceLevel, parameters, cancellationToken)); + } + + public static async Task ExecuteAsync(this IMqttRpcClient client, string methodName, ReadOnlyMemory payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default) + { + var response = await client.ExecuteAsync(methodName, new ReadOnlySequence(payload), qualityOfServiceLevel, parameters, cancellationToken); + return response.ToArray(); + } + + private static async Task MqttTimeOutAsync(TimeSpan timeout, Func> executor) { - if (client == null) throw new ArgumentNullException(nameof(client)); + using var timeoutTokenSource = new CancellationTokenSource(timeout); - var buffer = Encoding.UTF8.GetBytes(payload ?? string.Empty); + try + { + return await executor(timeoutTokenSource.Token); + } + catch (OperationCanceledException exception) + { + if (timeoutTokenSource.IsCancellationRequested) + { + throw new MqttCommunicationTimedOutException(exception); + } + throw; + } + } + + + public static Task> ExecuteAsync(this IMqttRpcClient client, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default) + { + return string.IsNullOrEmpty(payload) + ? client.ExecuteAsync(methodName, ReadOnlySequence.Empty, qualityOfServiceLevel, parameters, cancellationToken) + : client.ExecuteAsync(methodName, WritePayloadAsync, qualityOfServiceLevel, parameters, cancellationToken); + + async ValueTask WritePayloadAsync(PipeWriter writer) + { + Encoding.UTF8.GetBytes(payload, writer); + await writer.FlushAsync(cancellationToken); + } + } - return client.ExecuteAsync(timeout, methodName, buffer, qualityOfServiceLevel, parameters); + public static Task> ExecuteAsync(this IMqttRpcClient client, string methodName, Stream payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(payload); + return client.ExecuteAsync(methodName, WritePayloadAsync, qualityOfServiceLevel, parameters, cancellationToken); + + async ValueTask WritePayloadAsync(PipeWriter writer) + { + await payload.CopyToAsync(writer, cancellationToken); + await writer.FlushAsync(cancellationToken); + } + } + + public static async Task> ExecuteAsync(this IMqttRpcClient client, string methodName, Func payloadFactory, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary parameters = null, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(client); + await using var payloadOwner = await MqttPayloadOwnerFactory.CreateMultipleSegmentAsync(payloadFactory, cancellationToken); + return await client.ExecuteAsync(methodName, payloadOwner.Payload, qualityOfServiceLevel, parameters, cancellationToken); } } } \ No newline at end of file