Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add KeepAliveMode and SupportedWebSocketSubProtocols options #1154

Merged
merged 11 commits into from
Oct 27, 2024
2 changes: 1 addition & 1 deletion src/Transports.AspNetCore/GraphQLHttpMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ protected virtual Task WriteJsonResponseAsync<TResult>(HttpContext context, Http
/// <summary>
/// Gets a list of WebSocket sub-protocols supported.
/// </summary>
protected virtual IEnumerable<string> SupportedWebSocketSubProtocols => _supportedSubProtocols;
protected virtual IEnumerable<string> SupportedWebSocketSubProtocols => _options.WebSockets.SupportedWebSocketSubProtocols;

/// <summary>
/// Creates an <see cref="IWebSocketConnection"/>, a WebSocket message pump.
Expand Down
72 changes: 66 additions & 6 deletions src/Transports.AspNetCore/WebSockets/BaseSubscriptionServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,33 @@ protected virtual Task OnNotAuthorizedPolicyAsync(OperationMessage message, Auth
/// <br/><br/>
/// Otherwise, the connection is acknowledged via <see cref="OnConnectionAcknowledgeAsync(OperationMessage)"/>,
/// <see cref="TryInitialize"/> is called to indicate that this WebSocket connection is ready to accept requests,
/// and keep-alive messages are sent via <see cref="OnSendKeepAliveAsync"/> if configured to do so.
/// Keep-alive messages are only sent if no messages have been sent over the WebSockets connection for the
/// length of time configured in <see cref="GraphQLWebSocketOptions.KeepAliveTimeout"/>.
/// and <see cref="OnSendKeepAliveAsync"/> is called to start sending keep-alive messages if configured to do so.
/// </summary>
protected virtual async Task OnConnectionInitAsync(OperationMessage message)
{
if (!await AuthorizeAsync(message))
{
return;
}
await OnConnectionAcknowledgeAsync(message);
if (!TryInitialize())
return;

_ = OnKeepAliveLoopAsync();
}

/// <summary>
/// Executes when the client is attempting to initialize the connection.
/// <br/><br/>
/// By default, this first checks <see cref="AuthorizeAsync(OperationMessage)"/> to validate that the
/// request has passed authentication. If validation fails, the connection is closed with an Access
/// Denied message.
/// <br/><br/>
/// Otherwise, the connection is acknowledged via <see cref="OnConnectionAcknowledgeAsync(OperationMessage)"/>,
/// <see cref="TryInitialize"/> is called to indicate that this WebSocket connection is ready to accept requests,
/// and <see cref="OnSendKeepAliveAsync"/> is called to start sending keep-alive messages if configured to do so.
/// </summary>
[Obsolete($"Please use the {nameof(OnConnectionInitAsync)}(message) and {nameof(OnKeepAliveLoopAsync)} methods instead. This method will be removed in a future version of this library.")]
protected virtual async Task OnConnectionInitAsync(OperationMessage message, bool smartKeepAlive)
{
if (!await AuthorizeAsync(message))
Expand All @@ -277,12 +300,49 @@ protected virtual async Task OnConnectionInitAsync(OperationMessage message, boo
if (keepAliveTimeout > TimeSpan.Zero)
{
if (smartKeepAlive)
_ = StartSmartKeepAliveLoopAsync();
_ = OnKeepAliveLoopAsync(keepAliveTimeout, KeepAliveMode.Timeout);
else
_ = StartKeepAliveLoopAsync();
_ = OnKeepAliveLoopAsync(keepAliveTimeout, KeepAliveMode.Interval);
}
}

/// <summary>
/// Starts sending keep-alive messages if configured to do so. Inspects the configured
/// <see cref="GraphQLWebSocketOptions"/> and passes control to <see cref="OnKeepAliveLoopAsync(TimeSpan, KeepAliveMode)"/>
/// if keep-alive messages are enabled.
/// </summary>
protected virtual Task OnKeepAliveLoopAsync()
{
return OnKeepAliveLoopAsync(
_options.KeepAliveTimeout ?? DefaultKeepAliveTimeout,
_options.KeepAliveMode);
}

/// <summary>
/// Sends keep-alive messages according to the specified timeout period and method.
/// See <see cref="KeepAliveMode"/> for implementation details for each supported mode.
/// </summary>
protected virtual async Task OnKeepAliveLoopAsync(TimeSpan keepAliveTimeout, KeepAliveMode keepAliveMode)
{
if (keepAliveTimeout <= TimeSpan.Zero)
return;

switch (keepAliveMode)
{
case KeepAliveMode.Default:
case KeepAliveMode.Timeout:
await StartSmartKeepAliveLoopAsync();
break;
case KeepAliveMode.Interval:
await StartDumbKeepAliveLoopAsync();
break;
case KeepAliveMode.TimeoutWithPayload:
throw new NotImplementedException($"{nameof(KeepAliveMode.TimeoutWithPayload)} is not implemented within the {nameof(BaseSubscriptionServer)} class.");
default:
throw new ArgumentOutOfRangeException(nameof(keepAliveMode));
}

async Task StartKeepAliveLoopAsync()
async Task StartDumbKeepAliveLoopAsync()
{
while (!CancellationToken.IsCancellationRequested)
{
Expand Down
19 changes: 19 additions & 0 deletions src/Transports.AspNetCore/WebSockets/GraphQLWebSocketOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ public class GraphQLWebSocketOptions
/// </summary>
public TimeSpan? KeepAliveTimeout { get; set; }

/// <summary>
/// Gets or sets the keep-alive mode used for websocket subscriptions.
/// This property is only applicable when using the GraphQLWs protocol.
/// </summary>
public KeepAliveMode KeepAliveMode { get; set; } = KeepAliveMode.Default;

/// <summary>
/// The amount of time to wait to attempt a graceful teardown of the WebSockets protocol.
/// A value of <see langword="null"/> indicates the default value defined by the implementation.
Expand All @@ -42,4 +48,17 @@ public class GraphQLWebSocketOptions
/// Disconnects a subscription from the client in the event of any GraphQL errors during a subscription. The default value is <see langword="false"/>.
/// </summary>
public bool DisconnectAfterAnyError { get; set; }

/// <summary>
/// The list of supported WebSocket sub-protocols.
/// Defaults to <see cref="GraphQLWs.SubscriptionServer.SubProtocol"/> and <see cref="SubscriptionsTransportWs.SubscriptionServer.SubProtocol"/>.
/// Adding other sub-protocols require the <see cref="GraphQLHttpMiddleware.CreateMessageProcessor(IWebSocketConnection, string)"/> method
/// to be overridden to handle the new sub-protocol.
/// </summary>
/// <remarks>
/// When the <see cref="KeepAliveMode"/> is set to <see cref="KeepAliveMode.TimeoutWithPayload"/>, you may wish to remove
/// <see cref="SubscriptionsTransportWs.SubscriptionServer.SubProtocol"/> from this list to prevent clients from using
/// protocols which do not support the <see cref="KeepAliveMode.TimeoutWithPayload"/> keep-alive mode.
/// </remarks>
public List<string> SupportedWebSocketSubProtocols { get; set; } = [GraphQLWs.SubscriptionServer.SubProtocol, SubscriptionsTransportWs.SubscriptionServer.SubProtocol];
}
12 changes: 12 additions & 0 deletions src/Transports.AspNetCore/WebSockets/GraphQLWs/PingPayload.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
namespace GraphQL.Server.Transports.AspNetCore.WebSockets.GraphQLWs;

/// <summary>
/// The payload of the ping message.
/// </summary>
public class PingPayload
{
/// <summary>
/// The unique identifier of the ping message.
/// </summary>
public string? id { get; set; }
gao-artur marked this conversation as resolved.
Show resolved Hide resolved
}
111 changes: 109 additions & 2 deletions src/Transports.AspNetCore/WebSockets/GraphQLWs/SubscriptionServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
public class SubscriptionServer : BaseSubscriptionServer
{
private readonly IWebSocketAuthenticationService? _authenticationService;
private readonly IGraphQLSerializer _serializer;
private readonly GraphQLWebSocketOptions _options;
private DateTime _lastPongReceivedUtc;
private string? _lastPingId;
private readonly object _lastPingLock = new();

/// <summary>
/// The WebSocket sub-protocol used for this protocol.
Expand Down Expand Up @@ -67,6 +72,8 @@
UserContextBuilder = userContextBuilder ?? throw new ArgumentNullException(nameof(userContextBuilder));
Serializer = serializer ?? throw new ArgumentNullException(nameof(serializer));
_authenticationService = authenticationService;
_serializer = serializer;
_options = options;
}

/// <inheritdoc/>
Expand All @@ -90,7 +97,9 @@
}
else
{
#pragma warning disable CS0618 // Type or member is obsolete
await OnConnectionInitAsync(message, true);

Check warning

Code scanning / CodeQL

Call to obsolete method Warning

Call to obsolete method
OnConnectionInitAsync
.
#pragma warning restore CS0618 // Type or member is obsolete
}
return;
}
Expand All @@ -113,6 +122,69 @@
}
}

/// <inheritdoc/>
[Obsolete($"Please use the {nameof(OnConnectionInitAsync)} and {nameof(OnKeepAliveLoopAsync)} methods instead. This method will be removed in a future version of this library.")]
protected override Task OnConnectionInitAsync(OperationMessage message, bool smartKeepAlive)
Comment on lines +126 to +127
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code has been carefully crafted to be fully backwards compatible, just in case of the rare instance that someone derived from this class to customize behavior.

{
if (smartKeepAlive)
return OnConnectionInitAsync(message);
else
return base.OnConnectionInitAsync(message, smartKeepAlive);
Dismissed Show dismissed Hide dismissed
}

/// <inheritdoc/>
protected override Task OnKeepAliveLoopAsync(TimeSpan keepAliveTimeout, KeepAliveMode keepAliveMode)
gao-artur marked this conversation as resolved.
Show resolved Hide resolved
{
if (keepAliveMode == KeepAliveMode.TimeoutWithPayload)
{
if (keepAliveTimeout <= TimeSpan.Zero)
return Task.CompletedTask;
return SecureKeepAliveLoopAsync(keepAliveTimeout, keepAliveTimeout);
}
return base.OnKeepAliveLoopAsync(keepAliveTimeout, keepAliveMode);

// pingInterval is the time since the last pong was received before sending a new ping
// pongInterval is the time to wait for a pong after a ping was sent before forcibly closing the connection
async Task SecureKeepAliveLoopAsync(TimeSpan pingInterval, TimeSpan pongInterval)
{
lock (_lastPingLock)
_lastPongReceivedUtc = DateTime.UtcNow;
Shane32 marked this conversation as resolved.
Show resolved Hide resolved
while (!CancellationToken.IsCancellationRequested)
{
// Wait for the next ping interval
TimeSpan interval;
var now = DateTime.UtcNow;
DateTime lastPongReceivedUtc;
lock (_lastPingLock)
{
lastPongReceivedUtc = _lastPongReceivedUtc;
}
var nextPing = _lastPongReceivedUtc.Add(pingInterval);
Shane32 marked this conversation as resolved.
Show resolved Hide resolved
interval = nextPing.Subtract(now);
if (interval > TimeSpan.Zero) // could easily be zero or less, if pongInterval is equal or greater than pingInterval
await Task.Delay(interval, CancellationToken);

// Send a new ping message
await OnSendKeepAliveAsync();

// Wait for the pong response
await Task.Delay(pongInterval, CancellationToken);
bool abort;
lock (_lastPingLock)
{
abort = _lastPongReceivedUtc == lastPongReceivedUtc;
}
if (abort)
{
// Forcibly close the connection if the client has not responded to the keep-alive message.
// Do not send a close message to the client or wait for a response.
Connection.HttpContext.Abort();
return;
}
}
}
}

/// <summary>
/// Pong is a required response to a ping, and also a unidirectional keep-alive packet,
/// whereas ping is a bidirectional keep-alive packet.
Expand All @@ -131,11 +203,46 @@
/// Executes when a pong message is received.
/// </summary>
protected virtual Task OnPongAsync(OperationMessage message)
=> Task.CompletedTask;
{
if (_options.KeepAliveMode == KeepAliveMode.TimeoutWithPayload)
{
try
{
var pingId = _serializer.ReadNode<PingPayload>(message.Payload)?.id;
lock (_lastPingLock)
{
if (_lastPingId == pingId)
_lastPongReceivedUtc = DateTime.UtcNow;
}
}
catch { } // ignore deserialization errors in case the pong message does not match the expected format
Dismissed Show dismissed Hide dismissed
Dismissed Show dismissed Hide dismissed
}
return Task.CompletedTask;
}

/// <inheritdoc/>
protected override Task OnSendKeepAliveAsync()
=> Connection.SendMessageAsync(_pongMessage);
{
if (_options.KeepAliveMode == KeepAliveMode.TimeoutWithPayload)
{
var lastPingId = Guid.NewGuid().ToString("N");
lock (_lastPingLock)
{
_lastPingId = lastPingId;
}
return Connection.SendMessageAsync(
new()
{
Type = MessageType.Ping,
Payload = new PingPayload { id = lastPingId }
}
);
}
else
{
return Connection.SendMessageAsync(_pongMessage);
}
}

private static readonly OperationMessage _connectionAckMessage = new() { Type = MessageType.ConnectionAck };
/// <inheritdoc/>
Expand Down
36 changes: 36 additions & 0 deletions src/Transports.AspNetCore/WebSockets/KeepAliveMode.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
namespace GraphQL.Server.Transports.AspNetCore.WebSockets;

/// <summary>
/// Specifies the mode of keep-alive behavior.
/// </summary>
public enum KeepAliveMode
{
/// <summary>
/// Same as <see cref="Timeout"/>: Sends a unidirectional keep-alive message when no message has been received within the specified timeout period.
/// </summary>
Default = 0,

/// <summary>
/// Sends a unidirectional keep-alive message when no message has been received within the specified timeout period.
/// </summary>
Timeout = 1,

/// <summary>
/// Sends a unidirectional keep-alive message at a fixed interval, regardless of message activity.
/// </summary>
Interval = 2,

/// <summary>
/// Sends a Ping message with a payload after the specified timeout from the last received Pong,
/// and waits for a corresponding Pong response. Requires that the client reflects the payload
/// in the response. Forcibly disconnects the client if the client does not respond with a Pong
/// message within the specified timeout. This means that a dead connection will be closed after
/// a maximum of double the <see cref="GraphQLWebSocketOptions.KeepAliveTimeout"/> period.
/// </summary>
/// <remarks>
/// This mode is particularly useful when backpressure causes subscription messages to be delayed
/// due to a slow or unresponsive client connection. The server can detect that the client is not
/// processing messages in a timely manner and disconnect the client to free up resources.
/// </remarks>
TimeoutWithPayload = 3,
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@
}
else
{
#pragma warning disable CS0618 // Type or member is obsolete
await OnConnectionInitAsync(message, false);

Check warning

Code scanning / CodeQL

Call to obsolete method Warning

Call to obsolete method
OnConnectionInitAsync
.
#pragma warning restore CS0618 // Type or member is obsolete
}
return;
}
Expand All @@ -108,6 +110,26 @@
}
}

/// <inheritdoc/>
[Obsolete($"Please use the {nameof(OnConnectionInitAsync)} and {nameof(OnKeepAliveLoopAsync)} methods instead. This method will be removed in a future version of this library.")]
protected override Task OnConnectionInitAsync(OperationMessage message, bool smartKeepAlive)
{
if (!smartKeepAlive)
return OnConnectionInitAsync(message);
else
return base.OnConnectionInitAsync(message, smartKeepAlive);
Dismissed Show dismissed Hide dismissed
}

/// <inheritdoc/>
/// <remarks>
/// This implementation overrides <see cref="GraphQLWebSocketOptions.KeepAliveMode"/> to <see cref="KeepAliveMode.Interval"/>
/// as this protocol does not support the other modes. Override this method to support your own implementation.
/// </remarks>
protected override Task OnKeepAliveLoopAsync(TimeSpan keepAliveTimeout, KeepAliveMode keepAliveMode)
=> base.OnKeepAliveLoopAsync(
keepAliveTimeout,
KeepAliveMode.Interval);

private static readonly OperationMessage _keepAliveMessage = new() { Type = MessageType.GQL_CONNECTION_KEEP_ALIVE };
/// <inheritdoc/>
protected override Task OnSendKeepAliveAsync()
Expand Down
Loading
Loading