Skip to content

Commit

Permalink
Merge pull request #790 from Cysharp/feature/ReduceAllocationOnStream…
Browse files Browse the repository at this point in the history
…ingHubClient

Reduce allocations on StreamingHubClient method calls
  • Loading branch information
mayuki authored Jun 17, 2024
2 parents 6407f91 + 036e813 commit 02bf1ed
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 42 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Runtime.CompilerServices;
using System.Threading.Channels;
using Grpc.Core;

Expand All @@ -14,7 +15,13 @@ public ChannelAsyncStreamReader(Channel<T> channel)
reader = channel.Reader;
}

public async Task<bool> MoveNext(CancellationToken cancellationToken)
public Task<bool> MoveNext(CancellationToken cancellationToken)
{
return MoveNextCore(cancellationToken).AsTask();
}

[AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))]
async ValueTask<bool> MoveNextCore(CancellationToken cancellationToken)
{
if (await reader.WaitToReadAsync())
{
Expand Down
9 changes: 8 additions & 1 deletion perf/Microbenchmark/Microbenchmark.Client/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@
using MagicOnion;
using Microbenchmark.Client;

BenchmarkRunner.Run<HubReceiverBroadcastBenchmarks>();
//BenchmarkRunner.Run<HubReceiverBroadcastBenchmarks>();
BenchmarkRunner.Run<HubMethodBenchmarks>();

#if FALSE
var b = new HubMethodBenchmarks();
for (var i = 0; i < 1000000; i++)
await b.Parameter_Zero_Return_ValueType();
#endif

class MySynchronizationContext : SynchronizationContext;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,28 +86,49 @@ public async Task<TStreamingHub> ConnectAsync(StreamingHubClientOptions options,
);
}

public async Task<ReadOnlyMemory<byte>> ReadRequestRawAsync()
public async ValueTask<ReadOnlyMemory<byte>> ReadRequestRawAsync()
{
var requestPayload = await requestChannel.Reader.ReadAsync();
return requestPayload.Memory;
}

public async Task<(int MessageId, int MethodId, T Request)> ReadRequestAsync<T>()
public async ValueTask<(int MessageId, int MethodId, T Request)> ReadRequestAsync<T>()
{
var requestPayload = await requestChannel.Reader.ReadAsync();
return ReadRequestPayload<T>(requestPayload.Memory);
try
{
return ReadRequestPayload<T>(requestPayload.Memory);
}
finally
{
StreamingHubPayloadPool.Shared.Return(requestPayload);
}
}

public async Task<(int MessageId, int MethodId, ReadOnlyMemory<byte> Request)> ReadRequestNoDeserializeAsync()
public async ValueTask<(int MessageId, int MethodId, ReadOnlyMemory<byte> Request)> ReadRequestNoDeserializeAsync()
{
var requestPayload = await requestChannel.Reader.ReadAsync();
return ReadRequestPayload(requestPayload.Memory);
try
{
return ReadRequestPayload(requestPayload.Memory);
}
finally
{
StreamingHubPayloadPool.Shared.Return(requestPayload);
}
}

public async Task<(int MethodId, T Request)> ReadFireAndForgetRequestAsync<T>()
public async ValueTask<(int MethodId, T Request)> ReadFireAndForgetRequestAsync<T>()
{
var requestPayload = await requestChannel.Reader.ReadAsync();
return ReadFireAndForgetRequestPayload<T>(requestPayload.Memory);
try
{
return ReadFireAndForgetRequestPayload<T>(requestPayload.Memory);
}
finally
{
StreamingHubPayloadPool.Shared.Return(requestPayload);
}
}

public void WriteResponseRaw(ReadOnlySpan<byte> data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public abstract class StreamingHubClientBase<TStreamingHub, TReceiver>
readonly IMagicOnionSerializer messageSerializer;
readonly Method<StreamingHubPayload, StreamingHubPayload> duplexStreamingConnectMethod;
// {messageId, TaskCompletionSource}
readonly ConcurrentDictionary<int, ITaskCompletion> responseFutures = new();
readonly Dictionary<int, ITaskCompletion> responseFutures = new();
readonly TaskCompletionSource<bool> waitForDisconnect = new();
readonly CancellationTokenSource cancellationTokenSource = new();

Expand Down Expand Up @@ -319,27 +319,44 @@ SendOrPostCallback CreateBroadcastCallback(int methodId, int consumed)
void ProcessResponse(SynchronizationContext? syncContext, StreamingHubPayload payload, ref StreamingHubClientMessageReader messageReader)
{
var message = messageReader.ReadResponseMessage();
if (responseFutures.Remove(message.MessageId, out var future))

ITaskCompletion? future;
lock (responseFutures)
{
try
if (!responseFutures.Remove(message.MessageId, out future))
{
OnResponseEvent(message.MethodId, future, message.Body);
StreamingHubPayloadPool.Shared.Return(payload);
return;
}
catch (Exception ex)
}

try
{
OnResponseEvent(message.MethodId, future, message.Body);
StreamingHubPayloadPool.Shared.Return(payload);
}
catch (Exception ex)
{
if (!future.TrySetException(ex))
{
if (!future.TrySetException(ex))
{
throw;
}
throw;
}
}
}

void ProcessResponseWithError(SynchronizationContext? syncContext, StreamingHubPayload payload, ref StreamingHubClientMessageReader messageReader)
{
var message = messageReader.ReadResponseWithErrorMessage();
if (responseFutures.Remove(message.MessageId, out var future))

ITaskCompletion? future;
lock (responseFutures)
{
if (!responseFutures.Remove(message.MessageId, out future))
{
return;
}
}

if (responseFutures.Remove(message.MessageId, out future))
{
RpcException ex;
if (string.IsNullOrWhiteSpace(message.Error))
Expand Down Expand Up @@ -462,7 +479,10 @@ protected Task<TResponse> WriteMessageWithResponseAsync<TRequest, TResponse>(int
TaskCreationOptions.RunContinuationsAsynchronously
#endif
);
responseFutures[mid] = tcs;
lock (responseFutures)
{
responseFutures[mid] = tcs;
}

var v = BuildRequestMessage(methodId, mid, message);
_ = writerQueue.Writer.TryWrite(v);
Expand Down
59 changes: 38 additions & 21 deletions src/MagicOnion.Client/StreamingHubClientBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public abstract class StreamingHubClientBase<TStreamingHub, TReceiver>
readonly IMagicOnionSerializer messageSerializer;
readonly Method<StreamingHubPayload, StreamingHubPayload> duplexStreamingConnectMethod;
// {messageId, TaskCompletionSource}
readonly ConcurrentDictionary<int, ITaskCompletion> responseFutures = new();
readonly Dictionary<int, ITaskCompletion> responseFutures = new();
readonly TaskCompletionSource<bool> waitForDisconnect = new();
readonly CancellationTokenSource cancellationTokenSource = new();

Expand Down Expand Up @@ -319,41 +319,55 @@ SendOrPostCallback CreateBroadcastCallback(int methodId, int consumed)
void ProcessResponse(SynchronizationContext? syncContext, StreamingHubPayload payload, ref StreamingHubClientMessageReader messageReader)
{
var message = messageReader.ReadResponseMessage();
if (responseFutures.Remove(message.MessageId, out var future))

ITaskCompletion? future;
lock (responseFutures)
{
try
if (!responseFutures.Remove(message.MessageId, out future))
{
OnResponseEvent(message.MethodId, future, message.Body);
StreamingHubPayloadPool.Shared.Return(payload);
return;
}
catch (Exception ex)
}

try
{
OnResponseEvent(message.MethodId, future, message.Body);
StreamingHubPayloadPool.Shared.Return(payload);
}
catch (Exception ex)
{
if (!future.TrySetException(ex))
{
if (!future.TrySetException(ex))
{
throw;
}
throw;
}
}
}

void ProcessResponseWithError(SynchronizationContext? syncContext, StreamingHubPayload payload, ref StreamingHubClientMessageReader messageReader)
{
var message = messageReader.ReadResponseWithErrorMessage();
if (responseFutures.Remove(message.MessageId, out var future))

ITaskCompletion? future;
lock (responseFutures)
{
RpcException ex;
if (string.IsNullOrWhiteSpace(message.Error))
{
ex = new RpcException(new Status((StatusCode)message.StatusCode, message.Detail ?? string.Empty));
}
else
if (!responseFutures.Remove(message.MessageId, out future))
{
ex = new RpcException(new Status((StatusCode)message.StatusCode, message.Detail ?? string.Empty), message.Detail + Environment.NewLine + message.Error);
return;
}
}

future.TrySetException(ex);
StreamingHubPayloadPool.Shared.Return(payload);
RpcException ex;
if (string.IsNullOrWhiteSpace(message.Error))
{
ex = new RpcException(new Status((StatusCode)message.StatusCode, message.Detail ?? string.Empty));
}
else
{
ex = new RpcException(new Status((StatusCode)message.StatusCode, message.Detail ?? string.Empty), message.Detail + Environment.NewLine + message.Error);
}

future.TrySetException(ex);
StreamingHubPayloadPool.Shared.Return(payload);
}

void ProcessClientResultRequest(SynchronizationContext? syncContext, StreamingHubPayload payload, ref StreamingHubClientMessageReader messageReader)
Expand Down Expand Up @@ -462,7 +476,10 @@ protected Task<TResponse> WriteMessageWithResponseAsync<TRequest, TResponse>(int
TaskCreationOptions.RunContinuationsAsynchronously
#endif
);
responseFutures[mid] = tcs;
lock (responseFutures)
{
responseFutures[mid] = tcs;
}

var v = BuildRequestMessage(methodId, mid, message);
_ = writerQueue.Writer.TryWrite(v);
Expand Down

0 comments on commit 02bf1ed

Please sign in to comment.