Skip to content

Commit

Permalink
Rpc.ExecuteAsync: Reduce memory allocation.
Browse files Browse the repository at this point in the history
  • Loading branch information
xljiulang committed Nov 29, 2024
1 parent d1792a5 commit 0ed0b48
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public MqttRpcTopicPair CreateRpcTopics(TopicGenerationContext context)
{
ArgumentNullException.ThrowIfNull(context);

if (context.MethodName.Contains("/") || context.MethodName.Contains("+") || context.MethodName.Contains("#"))
if (context.MethodName.Contains('/') || context.MethodName.Contains('+') || context.MethodName.Contains('#'))
{
throw new ArgumentException("The method name cannot contain /, + or #.");
}
Expand Down
9 changes: 4 additions & 5 deletions Source/MQTTnet.Extensions.Rpc/IMqttRpcClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,17 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using MQTTnet.Protocol;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Protocol;

namespace MQTTnet.Extensions.Rpc
{
public interface IMqttRpcClient : IDisposable
{
Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary<string,object> parameters = null);

Task<byte[]> ExecuteAsync(string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary<string, object> parameters = null, CancellationToken cancellationToken = default);
{
Task<ReadOnlySequence<byte>> ExecuteAsync(string methodName, ReadOnlySequence<byte> payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary<string, object> parameters = null, CancellationToken cancellationToken = default);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
<NuGetAuditMode>all</NuGetAuditMode>
<NuGetAudit>true</NuGetAudit>
<NuGetAuditLevel>low</NuGetAuditLevel>
<AnalysisLevel>latest-Recommended</AnalysisLevel>
<!--<AnalysisLevel>latest-Recommended</AnalysisLevel>-->
</PropertyGroup>

<ItemGroup>
Expand Down
45 changes: 10 additions & 35 deletions Source/MQTTnet.Extensions.Rpc/MqttRpcClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using MQTTnet.Exceptions;
using MQTTnet.Formatter;
using MQTTnet.Internal;
using MQTTnet.Protocol;
using System;
using System.Buffers;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Exceptions;
using MQTTnet.Formatter;
using MQTTnet.Internal;
using MQTTnet.Protocol;

namespace MQTTnet.Extensions.Rpc
{
Expand All @@ -21,7 +20,7 @@ public sealed class MqttRpcClient : IMqttRpcClient
readonly IMqttClient _mqttClient;
readonly MqttRpcClientOptions _options;

readonly ConcurrentDictionary<string, AsyncTaskCompletionSource<byte[]>> _waitingCalls = new ConcurrentDictionary<string, AsyncTaskCompletionSource<byte[]>>();
readonly ConcurrentDictionary<string, AsyncTaskCompletionSource<ReadOnlySequence<byte>>> _waitingCalls = new();

public MqttRpcClient(IMqttClient mqttClient, MqttRpcClientOptions options)
{
Expand All @@ -43,27 +42,7 @@ public void Dispose()
_waitingCalls.Clear();
}

public async Task<byte[]> ExecuteAsync(TimeSpan timeout, string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary<string, object> parameters = null)
{
using (var timeoutToken = new CancellationTokenSource(timeout))
{
try
{
return await ExecuteAsync(methodName, payload, qualityOfServiceLevel, parameters, timeoutToken.Token).ConfigureAwait(false);
}
catch (OperationCanceledException exception)
{
if (timeoutToken.IsCancellationRequested)
{
throw new MqttCommunicationTimedOutException(exception);
}

throw;
}
}
}

public async Task<byte[]> ExecuteAsync(string methodName, byte[] payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary<string, object> parameters = null, CancellationToken cancellationToken = default)
public async Task<ReadOnlySequence<byte>> ExecuteAsync(string methodName, ReadOnlySequence<byte> payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary<string, object> parameters = null, CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(methodName);

Expand Down Expand Up @@ -94,7 +73,7 @@ public async Task<byte[]> ExecuteAsync(string methodName, byte[] payload, MqttQu

try
{
var awaitable = new AsyncTaskCompletionSource<byte[]>();
var awaitable = new AsyncTaskCompletionSource<ReadOnlySequence<byte>>();

if (!_waitingCalls.TryAdd(responseTopic, awaitable))
{
Expand All @@ -106,11 +85,7 @@ public async Task<byte[]> ExecuteAsync(string methodName, byte[] payload, MqttQu
await _mqttClient.SubscribeAsync(subscribeOptions, cancellationToken).ConfigureAwait(false);
await _mqttClient.PublishAsync(requestMessage, cancellationToken).ConfigureAwait(false);

using (cancellationToken.Register(
() =>
{
awaitable.TrySetCanceled();
}))
using (cancellationToken.Register(awaitable.TrySetCanceled))
{
return await awaitable.Task.ConfigureAwait(false);
}
Expand All @@ -129,8 +104,8 @@ Task HandleApplicationMessageReceivedAsync(MqttApplicationMessageReceivedEventAr
return CompletedTask.Instance;
}

var payloadBuffer = eventArgs.ApplicationMessage.Payload.ToArray();
awaitable.TrySetResult(payloadBuffer);
var payload = eventArgs.ApplicationMessage.Payload;
awaitable.TrySetResult(payload);

// Set this message to handled to that other code can avoid execution etc.
eventArgs.IsHandled = true;
Expand Down
75 changes: 70 additions & 5 deletions Source/MQTTnet.Extensions.Rpc/MqttRpcClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,88 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using MQTTnet.Exceptions;
using MQTTnet.Internal;
using MQTTnet.Protocol;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using MQTTnet.Protocol;

namespace MQTTnet.Extensions.Rpc
{
public static class MqttRpcClientExtensions
{
public static Task<byte[]> ExecuteAsync(this IMqttRpcClient client, TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary<string,object> parameters = null)
public static async Task<byte[]> ExecuteAsync(this IMqttRpcClient client, TimeSpan timeout, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary<string, object> parameters = null)
{
var response = await MqttTimeOutAsync(timeout, cancellationToken => client.ExecuteAsync(methodName, payload, qualityOfServiceLevel, parameters, cancellationToken));
return response.ToArray();
}

public static Task<byte[]> ExecuteAsync(this IMqttRpcClient client, TimeSpan timeout, string methodName, ReadOnlyMemory<byte> payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary<string, object> parameters = null)
{
return MqttTimeOutAsync(timeout, cancellationToken => client.ExecuteAsync(methodName, payload, qualityOfServiceLevel, parameters, cancellationToken));
}

public static async Task<byte[]> ExecuteAsync(this IMqttRpcClient client, string methodName, ReadOnlyMemory<byte> payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary<string, object> parameters = null, CancellationToken cancellationToken = default)
{
var response = await client.ExecuteAsync(methodName, new ReadOnlySequence<byte>(payload), qualityOfServiceLevel, parameters, cancellationToken);
return response.ToArray();
}

private static async Task<T> MqttTimeOutAsync<T>(TimeSpan timeout, Func<CancellationToken, Task<T>> executor)
{
if (client == null) throw new ArgumentNullException(nameof(client));
using var timeoutTokenSource = new CancellationTokenSource(timeout);

var buffer = Encoding.UTF8.GetBytes(payload ?? string.Empty);
try
{
return await executor(timeoutTokenSource.Token);
}
catch (OperationCanceledException exception)
{
if (timeoutTokenSource.IsCancellationRequested)
{
throw new MqttCommunicationTimedOutException(exception);
}
throw;
}
}


public static Task<ReadOnlySequence<byte>> ExecuteAsync(this IMqttRpcClient client, string methodName, string payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary<string, object> parameters = null, CancellationToken cancellationToken = default)
{
return string.IsNullOrEmpty(payload)
? client.ExecuteAsync(methodName, ReadOnlySequence<byte>.Empty, qualityOfServiceLevel, parameters, cancellationToken)
: client.ExecuteAsync(methodName, WritePayloadAsync, qualityOfServiceLevel, parameters, cancellationToken);

async ValueTask WritePayloadAsync(PipeWriter writer)
{
Encoding.UTF8.GetBytes(payload, writer);
await writer.FlushAsync(cancellationToken);
}
}

return client.ExecuteAsync(timeout, methodName, buffer, qualityOfServiceLevel, parameters);
public static Task<ReadOnlySequence<byte>> ExecuteAsync(this IMqttRpcClient client, string methodName, Stream payload, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary<string, object> parameters = null, CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(payload);
return client.ExecuteAsync(methodName, WritePayloadAsync, qualityOfServiceLevel, parameters, cancellationToken);

async ValueTask WritePayloadAsync(PipeWriter writer)
{
await payload.CopyToAsync(writer, cancellationToken);
await writer.FlushAsync(cancellationToken);
}
}

public static async Task<ReadOnlySequence<byte>> ExecuteAsync(this IMqttRpcClient client, string methodName, Func<PipeWriter, ValueTask> payloadFactory, MqttQualityOfServiceLevel qualityOfServiceLevel, IDictionary<string, object> parameters = null, CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(client);
await using var payloadOwner = await MqttPayloadOwnerFactory.CreateMultipleSegmentAsync(payloadFactory, cancellationToken);
return await client.ExecuteAsync(methodName, payloadOwner.Payload, qualityOfServiceLevel, parameters, cancellationToken);
}
}
}

0 comments on commit 0ed0b48

Please sign in to comment.