diff --git a/MaxLib.WebServer.Test/WebSocket/FrameParsing.cs b/MaxLib.WebServer.Test/WebSocket/FrameParsing.cs new file mode 100644 index 0000000..e20b3df --- /dev/null +++ b/MaxLib.WebServer.Test/WebSocket/FrameParsing.cs @@ -0,0 +1,111 @@ +using MaxLib.WebServer.WebSocket; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; +using System.Threading.Tasks; + +namespace MaxLib.WebServer.Test.WebSocket +{ + [TestClass] + public class FrameParsing + { + [TestMethod] + public async Task ReadSingleFrameUnmaskedTextMessage() + { + var m = new MemoryStream(new byte[] { 0x81, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f }); + var frame = await Frame.TryRead(m); + Assert.IsNotNull(frame); + Assert.IsTrue(frame!.FinalFrame); + Assert.AreEqual(OpCode.Text, frame.OpCode); + Assert.IsFalse(frame.HasMaskingKey); + Assert.AreEqual("Hello", frame.TextPayload); + } + + [TestMethod] + public async Task ReadSingleFrameMaskedTextMessage() + { + var m = new MemoryStream(new byte[] { 0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58 }); + var frame = await Frame.TryRead(m); + Assert.IsNotNull(frame); + Assert.IsTrue(frame!.FinalFrame); + Assert.AreEqual(OpCode.Text, frame.OpCode); + Assert.IsTrue(frame.HasMaskingKey); + frame.UnapplyMask(); + Assert.IsFalse(frame.HasMaskingKey); + Assert.AreEqual("Hello", frame.TextPayload); + } + + [TestMethod] + public async Task ReadFragmentedUnmaskedTextMessage() + { + var m = new MemoryStream(new byte[] { 0x01, 0x03, 0x48, 0x65, 0x6c }); + var frame = await Frame.TryRead(m); + Assert.IsNotNull(frame); + Assert.IsFalse(frame!.FinalFrame); + Assert.AreEqual(OpCode.Text, frame.OpCode); + Assert.IsFalse(frame.HasMaskingKey); + Assert.AreEqual("Hel", frame.TextPayload); + + m = new MemoryStream(new byte[] { 0x80, 0x02, 0x6c, 0x6f }); + frame = await Frame.TryRead(m); + Assert.IsNotNull(frame); + Assert.IsTrue(frame!.FinalFrame); + Assert.AreEqual(OpCode.Continuation, frame.OpCode); + Assert.IsFalse(frame.HasMaskingKey); + Assert.AreEqual("lo", frame.TextPayload); + } + + + [TestMethod] + public async Task ReadUnmaskedPingAndMaskedPongMessage() + { + var m = new MemoryStream(new byte[] { 0x89, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f }); + var frame = await Frame.TryRead(m); + Assert.IsNotNull(frame); + Assert.IsTrue(frame!.FinalFrame); + Assert.AreEqual(OpCode.Ping, frame.OpCode); + Assert.IsFalse(frame.HasMaskingKey); + Assert.AreEqual("Hello", frame.TextPayload); + + m = new MemoryStream(new byte[] { 0x8a, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58 }); + frame = await Frame.TryRead(m); + Assert.IsNotNull(frame); + Assert.IsTrue(frame!.FinalFrame); + Assert.AreEqual(OpCode.Pong, frame.OpCode); + Assert.IsTrue(frame.HasMaskingKey); + frame.UnapplyMask(); + Assert.IsFalse(frame.HasMaskingKey); + Assert.AreEqual("Hello", frame.TextPayload); + } + + [TestMethod] + public async Task Read256ByteUnmaskedMessage() + { + Memory data = new byte[4 + 256]; + (new byte[] { 0x82, 0x7E, 0x01, 0x00 }).CopyTo(data[..4]); + var m = new MemoryStream(data.ToArray()); + var frame = await Frame.TryRead(m); + Assert.IsNotNull(frame); + Assert.IsTrue(frame!.FinalFrame); + Assert.AreEqual(OpCode.Binary, frame.OpCode); + Assert.IsFalse(frame.HasMaskingKey); + Assert.AreEqual(256, frame.Payload.Length); + } + + [TestMethod] + public async Task Read64KiByteUnmaskedMessage() + { + Memory data = new byte[10 + 65536]; + (new byte[] { 0x82, 0x7F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00 }).CopyTo(data[..10]); + var m = new MemoryStream(data.ToArray()); + var frame = await Frame.TryRead(m); + Assert.IsNotNull(frame); + Assert.IsTrue(frame!.FinalFrame); + Assert.AreEqual(OpCode.Binary, frame.OpCode); + Assert.IsFalse(frame.HasMaskingKey); + Assert.AreEqual(65536, frame.Payload.Length); + } + } +} diff --git a/MaxLib.WebServer.WebSocket.Echo/EchoConnection.cs b/MaxLib.WebServer.WebSocket.Echo/EchoConnection.cs new file mode 100644 index 0000000..a95b8b5 --- /dev/null +++ b/MaxLib.WebServer.WebSocket.Echo/EchoConnection.cs @@ -0,0 +1,31 @@ +using System.IO; +using System.Threading.Tasks; + +#nullable enable + +namespace MaxLib.WebServer.WebSocket.Echo +{ + public class EchoConnection : WebSocketConnection + { + public EchoConnection(Stream networkStream) + : base(networkStream) + { + } + + protected override async Task ReceiveClose(CloseReason? reason, string? info) + { + WebServerLog.Add(ServerLogType.Information, GetType(), "WebSocket", $"client close websocket ({reason}): {info}"); + if (!SendCloseSignal) + await Close(); + } + + protected override async Task ReceivedFrame(Frame frame) + { + await SendFrame(new Frame + { + OpCode = frame.OpCode, + Payload = frame.Payload + }); + } + } +} diff --git a/MaxLib.WebServer.WebSocket.Echo/EchoEndpoint.cs b/MaxLib.WebServer.WebSocket.Echo/EchoEndpoint.cs new file mode 100644 index 0000000..499a36b --- /dev/null +++ b/MaxLib.WebServer.WebSocket.Echo/EchoEndpoint.cs @@ -0,0 +1,16 @@ +using System.IO; + +#nullable enable + +namespace MaxLib.WebServer.WebSocket.Echo +{ + public class EchoEndpoint : WebSocketEndpoint + { + public override string? Protocol => null; + + protected override EchoConnection CreateConnection(Stream stream, HttpRequestHeader header) + { + return new EchoConnection(stream); + } + } +} diff --git a/MaxLib.WebServer.WebSocket.Echo/MaxLib.WebServer.WebSocket.Echo.csproj b/MaxLib.WebServer.WebSocket.Echo/MaxLib.WebServer.WebSocket.Echo.csproj new file mode 100644 index 0000000..fc78de2 --- /dev/null +++ b/MaxLib.WebServer.WebSocket.Echo/MaxLib.WebServer.WebSocket.Echo.csproj @@ -0,0 +1,12 @@ + + + + Exe + netcoreapp3.1 + + + + + + + diff --git a/MaxLib.WebServer.WebSocket.Echo/Program.cs b/MaxLib.WebServer.WebSocket.Echo/Program.cs new file mode 100644 index 0000000..39987d1 --- /dev/null +++ b/MaxLib.WebServer.WebSocket.Echo/Program.cs @@ -0,0 +1,38 @@ +using MaxLib.WebServer.Services; +using System; + +#nullable enable + +namespace MaxLib.WebServer.WebSocket.Echo +{ + class Program + { + static void Main() + { + WebServerLog.LogAdded += WebServerLog_LogAdded; + var server = new Server(new WebServerSettings(8000, 5000)); + // add services + server.AddWebService(new HttpRequestParser()); + server.AddWebService(new HttpHeaderPostParser()); + server.AddWebService(new HttpHeaderSpecialAction()); + server.AddWebService(new HttpResponseCreator()); + server.AddWebService(new HttpSender()); + // setup web socket + var websocket = new WebSocketService(); + websocket.Add(new EchoEndpoint()); + server.AddWebService(websocket); + // start server + server.Start(); + // wait for console quit + while (Console.ReadKey().Key != ConsoleKey.Q) ; + // close + server.Stop(); + websocket.Dispose(); + } + + private static void WebServerLog_LogAdded(ServerLogItem item) + { + Console.WriteLine($"[{item.Date}] [{item.Type}] ({item.InfoType}) {item.SenderType}: {item.Information}"); + } + } +} diff --git a/MaxLib.WebServer.sln b/MaxLib.WebServer.sln index ef915f0..b7611c4 100644 --- a/MaxLib.WebServer.sln +++ b/MaxLib.WebServer.sln @@ -7,10 +7,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MaxLib.WebServer", "MaxLib. EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MaxLib.WebServer.Test", "MaxLib.WebServer.Test\MaxLib.WebServer.Test.csproj", "{60225D92-5742-4BC0-A3A5-206123F2129D}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MaxLib.WebServer.Example", "example\MaxLib.WebServer.Example\MaxLib.WebServer.Example.csproj", "{6616448A-1AD7-4897-9124-F1560FB12461}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MaxLib.WebServer.Example", "example\MaxLib.WebServer.Example\MaxLib.WebServer.Example.csproj", "{6616448A-1AD7-4897-9124-F1560FB12461}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Example", "Example", "{39740782-6F07-470C-92B8-C0A07A5C0DFB}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MaxLib.WebServer.WebSocket.Echo", "MaxLib.WebServer.WebSocket.Echo\MaxLib.WebServer.WebSocket.Echo.csproj", "{01C302FF-8E99-4D54-8234-D4BA59862176}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -57,12 +59,25 @@ Global {6616448A-1AD7-4897-9124-F1560FB12461}.Release|x64.Build.0 = Release|Any CPU {6616448A-1AD7-4897-9124-F1560FB12461}.Release|x86.ActiveCfg = Release|Any CPU {6616448A-1AD7-4897-9124-F1560FB12461}.Release|x86.Build.0 = Release|Any CPU + {01C302FF-8E99-4D54-8234-D4BA59862176}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {01C302FF-8E99-4D54-8234-D4BA59862176}.Debug|Any CPU.Build.0 = Debug|Any CPU + {01C302FF-8E99-4D54-8234-D4BA59862176}.Debug|x64.ActiveCfg = Debug|Any CPU + {01C302FF-8E99-4D54-8234-D4BA59862176}.Debug|x64.Build.0 = Debug|Any CPU + {01C302FF-8E99-4D54-8234-D4BA59862176}.Debug|x86.ActiveCfg = Debug|Any CPU + {01C302FF-8E99-4D54-8234-D4BA59862176}.Debug|x86.Build.0 = Debug|Any CPU + {01C302FF-8E99-4D54-8234-D4BA59862176}.Release|Any CPU.ActiveCfg = Release|Any CPU + {01C302FF-8E99-4D54-8234-D4BA59862176}.Release|Any CPU.Build.0 = Release|Any CPU + {01C302FF-8E99-4D54-8234-D4BA59862176}.Release|x64.ActiveCfg = Release|Any CPU + {01C302FF-8E99-4D54-8234-D4BA59862176}.Release|x64.Build.0 = Release|Any CPU + {01C302FF-8E99-4D54-8234-D4BA59862176}.Release|x86.ActiveCfg = Release|Any CPU + {01C302FF-8E99-4D54-8234-D4BA59862176}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE EndGlobalSection GlobalSection(NestedProjects) = preSolution {6616448A-1AD7-4897-9124-F1560FB12461} = {39740782-6F07-470C-92B8-C0A07A5C0DFB} + {01C302FF-8E99-4D54-8234-D4BA59862176} = {39740782-6F07-470C-92B8-C0A07A5C0DFB} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {E8A0EB3B-E5B8-4120-92DD-1C11CEE7017F} diff --git a/MaxLib.WebServer/MaxLib.WebServer.csproj.include b/MaxLib.WebServer/MaxLib.WebServer.csproj.include index 30bed8f..2cf7dbf 100644 --- a/MaxLib.WebServer/MaxLib.WebServer.csproj.include +++ b/MaxLib.WebServer/MaxLib.WebServer.csproj.include @@ -2,7 +2,7 @@ - 2.3.1 + 2.4.0 $(Version).0 $(Version).0 diff --git a/MaxLib.WebServer/Server.cs b/MaxLib.WebServer/Server.cs index b62bd8e..351e359 100644 --- a/MaxLib.WebServer/Server.cs +++ b/MaxLib.WebServer/Server.cs @@ -204,6 +204,15 @@ protected virtual async Task ClientStartListen(HttpConnection connection) await ExecuteTaskChain(task); + if (task.SwitchProtocolHandler != null) + { + KeepAliveConnections.Remove(connection); + AllConnections.Remove(connection); + task.Dispose(); + _ = task.SwitchProtocolHandler(); + return; + } + if (task.Request.FieldConnection == HttpConnectionType.KeepAlive) { if (!KeepAliveConnections.Contains(connection)) diff --git a/MaxLib.WebServer/WebProgressTask.cs b/MaxLib.WebServer/WebProgressTask.cs index feb767e..c4da6dd 100644 --- a/MaxLib.WebServer/WebProgressTask.cs +++ b/MaxLib.WebServer/WebProgressTask.cs @@ -1,5 +1,6 @@ using System.Net; using System; +using System.Threading.Tasks; #nullable enable @@ -38,5 +39,23 @@ public void Dispose() { Document?.Dispose(); } + + internal Func? SwitchProtocolHandler { get; private set; } = null; + + /// + /// A call to this method notify the web server that this connection will switch protocols + /// after all steps are finished. The web server will remove this connection from its + /// watch list and call after its finished. + ///
+ /// You as the caller are responsible to safely cleanup the connection it is no more + /// used. + ///
+ /// + /// This handler will be called after the server has no more control of this connection. + /// + public void SwitchProtocols(Func handler) + { + SwitchProtocolHandler = handler; + } } } diff --git a/MaxLib.WebServer/WebSocket/CloseReason.cs b/MaxLib.WebServer/WebSocket/CloseReason.cs new file mode 100644 index 0000000..5c481c4 --- /dev/null +++ b/MaxLib.WebServer/WebSocket/CloseReason.cs @@ -0,0 +1,67 @@ +#nullable enable + +namespace MaxLib.WebServer.WebSocket +{ + public enum CloseReason : ushort + { + /// + /// 1000 indicates a normal closure, meaning that the purpose for + /// which the connection was established has been fulfilled. + /// + NormalClose = 1000, + /// + /// 1001 indicates that an endpoint is "going away", such as a server + /// going down or a browser having navigated away from a page. + /// + GoingAway = 1001, + /// + /// 1002 indicates that an endpoint is terminating the connection due + /// to a protocol error. + /// + ProtocolError = 1002, + /// + /// 1003 indicates that an endpoint is terminating the connection + /// because it has received a type of data it cannot accept (e.g., an + /// endpoint that understands only text data MAY send this if it + /// receives a binary message). + /// + CannotAccept = 1003, + /// + /// 1007 indicates that an endpoint is terminating the connection + /// because it has received data within a message that was not + /// consistent with the type of the message (e.g., non-UTF-8 [RFC3629] + /// data within a text message). + /// + InvalidMessageContent = 1007, + /// + /// 1008 indicates that an endpoint is terminating the connection + /// because it has received a message that violates its policy. This + /// is a generic status code that can be returned when there is no + /// other more suitable status code (e.g., 1003 or 1009) or if there + /// is a need to hide specific details about the policy. + /// + PolicyError = 1008, + /// + /// 1009 indicates that an endpoint is terminating the connection + /// because it has received a message that is too big for it to + /// process. + /// + TooBigMessage = 1009, + /// + /// 1010 indicates that an endpoint (client) is terminating the + /// connection because it has expected the server to negotiate one or + /// more extension, but the server didn't return them in the response + /// message of the WebSocket handshake. The list of extensions that + /// are needed SHOULD appear in the /reason/ part of the Close frame. + /// Note that this status code is not used by the server, because it + /// can fail the WebSocket handshake instead. + /// + MissingExtension = 1010, + /// + /// 1011 indicates that a server is terminating the connection because + /// it encountered an unexpected condition that prevented it from + /// fulfilling the request. + /// + UnexpectedCondition = 1011, + } +} diff --git a/MaxLib.WebServer/WebSocket/Frame.cs b/MaxLib.WebServer/WebSocket/Frame.cs new file mode 100644 index 0000000..aab51df --- /dev/null +++ b/MaxLib.WebServer/WebSocket/Frame.cs @@ -0,0 +1,170 @@ +using System; +using System.IO; +using System.Text; +using System.Threading.Tasks; + +#nullable enable + +namespace MaxLib.WebServer.WebSocket +{ + public class Frame + { + public bool FinalFrame { get; set; } = true; + + public OpCode OpCode { get; set; } + + public bool HasMaskingKey { get; set; } + + public Memory MaskingKey { get; } = new byte[4]; + + public Memory Payload { get; set; } + + public string TextPayload + { + get => Encoding.UTF8.GetString(Payload.Span); + set => Payload = Encoding.UTF8.GetBytes(value ?? throw new ArgumentNullException(nameof(value))); + } + + public async Task Write(Stream output) + { + Memory buffer = new byte[8]; + buffer.Span[0] = (byte)((byte)OpCode | (FinalFrame ? 0x80 : 0x00)); + buffer.Span[1] = (byte)(Payload.Length < 126 ? Payload.Length : + (Payload.Length <= ushort.MaxValue ? 126 : 127) + ); + await output.WriteAsync(buffer[ .. 2]); + if (Payload.Length >= 126 && Payload.Length <= ushort.MaxValue) + { + ToNetworkByteOrder(BitConverter.GetBytes((ushort)Payload.Length), buffer.Span[0..2]); + await output.WriteAsync(buffer[..2]); + } + if (Payload.Length > ushort.MaxValue) + { + ToNetworkByteOrder(BitConverter.GetBytes((ulong)Payload.Length), buffer.Span); + await output.WriteAsync(buffer); + } + if (HasMaskingKey) + await output.WriteAsync(MaskingKey); + await output.WriteAsync(Payload); + } + + public static async Task TryRead(Stream input, bool throwLargePayload = false) + { + try + { + Memory buffer = new byte[8]; + if (await input.ReadAsync(buffer[0..2]) != 2) + return null; + var frame = new Frame + { + FinalFrame = (buffer.Span[0] & 0x80) == 0x80, + OpCode = (OpCode)(buffer.Span[0] & 0x0f), + HasMaskingKey = (buffer.Span[1] & 0x80) == 0x80, + }; + var lengthIndicator = buffer.Span[1] & 0x7f; + ulong length = (ulong)lengthIndicator; + if (lengthIndicator == 126) + { + if (await input.ReadAsync(buffer[0..2]) != 2) + return null; + ToLocalByteOrder(buffer.Span[..2]); + length = BitConverter.ToUInt16(buffer.Span[..2]); + } + if (lengthIndicator == 127) + { + if (await input.ReadAsync(buffer) != 8) + return null; + ToLocalByteOrder(buffer.Span); + length = BitConverter.ToUInt64(buffer.Span); + } + if (length > int.MaxValue) + { + if (throwLargePayload) + throw new TooLargePayloadException(); + else return null; + } + + if (frame.HasMaskingKey) + { + if (await input.ReadAsync(buffer[..4]) != 4) + return null; + buffer[..4].CopyTo(frame.MaskingKey); + } + + frame.Payload = new byte[(int)length]; + if (await input.ReadAsync(frame.Payload) != frame.Payload.Length) + return null; + + return frame; + } + catch (TooLargePayloadException) + { + throw; + } + catch (Exception e) + { + WebServerLog.Add(ServerLogType.Information, typeof(Frame), "WebSocket", $"cannot read frame: {e}"); + return null; + } + } + + public static void ToNetworkByteOrder(ReadOnlySpan input, Span buffer) + { + if (input.Length != buffer.Length) + throw new InvalidOperationException(); + if (BitConverter.IsLittleEndian) + for (int i = 0; i < input.Length; ++i) + buffer[input.Length - i - 1] = input[i]; + else input.CopyTo(buffer); + } + + public static void ToLocalByteOrder(Span buffer) + { + if (BitConverter.IsLittleEndian) + buffer.Reverse(); + } + + protected void ToBytes(ushort value, Span buffer) + { + var result = BitConverter.GetBytes(value); + if (result.Length > buffer.Length) + throw new InvalidOperationException(); + if (BitConverter.IsLittleEndian) + for (int i = 0; i < result.Length; ++i) + buffer[result.Length - i - 1] = result[i]; + else result.CopyTo(buffer); + } + + public void ApplyMask() + { + if (HasMaskingKey) + return; + var span = Payload.Span; + var mask = MaskingKey.Span; + for (int i = 0; i < span.Length; ++i) + span[i] ^= mask[i & 0x3]; + HasMaskingKey = true; + } + + public void UnapplyMask() + { + if (!HasMaskingKey) + return; + var span = Payload.Span; + var mask = MaskingKey.Span; + for (int i = 0; i < span.Length; ++i) + span[i] ^= mask[i & 0x3]; + HasMaskingKey = false; + } + } + + public enum OpCode : byte + { + Continuation = 0x0, + Text = 0x1, + Binary = 0x2, + Close = 0x8, + Ping = 0x9, + Pong = 0xa, + } +} diff --git a/MaxLib.WebServer/WebSocket/TooLargePayloadException.cs b/MaxLib.WebServer/WebSocket/TooLargePayloadException.cs new file mode 100644 index 0000000..63ccf52 --- /dev/null +++ b/MaxLib.WebServer/WebSocket/TooLargePayloadException.cs @@ -0,0 +1,17 @@ +using System; + +#nullable enable + +namespace MaxLib.WebServer.WebSocket +{ + [Serializable] + public class TooLargePayloadException : Exception + { + public TooLargePayloadException() { } + public TooLargePayloadException(string message) : base(message) { } + public TooLargePayloadException(string message, Exception inner) : base(message, inner) { } + protected TooLargePayloadException( + System.Runtime.Serialization.SerializationInfo info, + System.Runtime.Serialization.StreamingContext context) : base(info, context) { } + } +} diff --git a/MaxLib.WebServer/WebSocket/WebSocketConnection.cs b/MaxLib.WebServer/WebSocket/WebSocketConnection.cs new file mode 100644 index 0000000..d361428 --- /dev/null +++ b/MaxLib.WebServer/WebSocket/WebSocketConnection.cs @@ -0,0 +1,190 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading.Tasks; +using System.Linq; +using System.Threading; +using System.Text; + +#nullable enable + +namespace MaxLib.WebServer.WebSocket +{ + public abstract class WebSocketConnection : IDisposable, IAsyncDisposable + { + public Stream NetworkStream { get; } + private readonly SemaphoreSlim lockStream = new SemaphoreSlim(0, 1); + + public bool ReceivedCloseSignal { get; private set; } + + public bool SendCloseSignal { get; private set; } + + public DateTime LastPong { get; private set; } + + public event EventHandler? Closed; + + public WebSocketConnection(Stream networkStream) + { + NetworkStream = networkStream ?? throw new ArgumentNullException(nameof(networkStream)); + } + + public virtual void Dispose() + { + NetworkStream.Dispose(); + lockStream.Dispose(); + } + + public virtual async ValueTask DisposeAsync() + { + await NetworkStream.DisposeAsync(); + lockStream.Dispose(); + } + + public async Task Close(CloseReason reason = CloseReason.NormalClose, string? info = null) + { + Memory payload = new byte[2 + (info == null ? 0 : Encoding.UTF8.GetByteCount(info))]; + Frame.ToNetworkByteOrder(BitConverter.GetBytes((ushort)reason), payload.Span[..2]); + int size = payload.Length; + if (info != null) + size = 2 + Encoding.UTF8.GetBytes(info, payload.Span[2..]); + await SendFrame(new Frame + { + OpCode = OpCode.Close, + Payload = payload[..size], + }); + } + + /// + /// This function is called after the handshake is finished + /// + public async Task HandshakeFinished() + { + lockStream.Release(); + // receiving end + var receiver = Task.Run(async () => + { + var payloadQueue = new Queue>(); + OpCode code = OpCode.Binary; + while (!ReceivedCloseSignal) + { + Frame? frame; + try + { + frame = await Frame.TryRead(NetworkStream); + } + catch (TooLargePayloadException) + { + await Close(CloseReason.TooBigMessage, $"Payload is larger then the allowed {int.MaxValue} bytes"); + return; + } + if (frame == null) + return; + + frame.UnapplyMask(); + + if (!frame.FinalFrame) + { + code = frame.OpCode; + payloadQueue.Enqueue(frame.Payload); + continue; + } + + switch (frame.OpCode) + { + case OpCode.Close: + CloseReason? reason = null; + string? info = null; + if (frame.Payload.Length >= 2) + { + Frame.ToLocalByteOrder(frame.Payload.Span[0..2]); + reason = (CloseReason)BitConverter.ToUInt16(frame.Payload.Span[0..2]); + } + if (frame.Payload.Length > 2) + { + info = Encoding.UTF8.GetString(frame.Payload.Span[2..]); + } + ReceivedCloseSignal = true; + await ReceiveClose(reason, info); + break; + case OpCode.Ping: + frame.OpCode = OpCode.Pong; + await SendFrame(frame); + break; + case OpCode.Pong: + LastPong = DateTime.UtcNow; + break; + default: + if (payloadQueue.Count == 0) + await ReceivedFrame(frame); + else + { + payloadQueue.Enqueue(frame.Payload); + long maxSize = payloadQueue.Sum(x => (long)x.Length); + if (maxSize > int.MaxValue) + { + await Close(CloseReason.TooBigMessage, + $"the payload of all frames add up to {maxSize}. Only {int.MaxValue} is allowed." + ); + } + Memory payload = new byte[maxSize]; + int start = 0; + while (payloadQueue.Count > 0) + { + var item = payloadQueue.Dequeue(); + item.CopyTo(payload.Slice(start, item.Length)); + start += item.Length; + } + frame.Payload = payload; + frame.OpCode = code; + await ReceivedFrame(frame); + } + break; + } + } + }); + + // ping + var pinger = Task.Run(async () => + { + while (!SendCloseSignal) + { + await Task.Delay(TimeSpan.FromSeconds(10)); + await SendFrame(new Frame + { + OpCode = OpCode.Ping + }); + } + }); + + await Task.WhenAll(receiver, pinger); + Closed?.Invoke(this, EventArgs.Empty); + } + + protected virtual async Task SendFrame(Frame frame) + { + if (SendCloseSignal) + return; + await lockStream.WaitAsync(); + if (frame.OpCode == OpCode.Close) + SendCloseSignal = true; + try + { + await frame.Write(NetworkStream); + lockStream.Release(); + } + catch (IOException e) + { + WebServerLog.Add(ServerLogType.Information, GetType(), "WebSocket", $"Unexpected network error: {e}"); + var alreadyReceived = ReceivedCloseSignal; + ReceivedCloseSignal = true; + SendCloseSignal = true; + if (!alreadyReceived) + await ReceiveClose(null, null); + } + } + + protected abstract Task ReceiveClose(CloseReason? reason, string? info); + + protected abstract Task ReceivedFrame(Frame frame); + } +} diff --git a/MaxLib.WebServer/WebSocket/WebSocketEndpoint.cs b/MaxLib.WebServer/WebSocket/WebSocketEndpoint.cs new file mode 100644 index 0000000..cfc31fc --- /dev/null +++ b/MaxLib.WebServer/WebSocket/WebSocketEndpoint.cs @@ -0,0 +1,81 @@ +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +#nullable enable + +namespace MaxLib.WebServer.WebSocket +{ + public interface IWebSocketEndpoint : IDisposable, IAsyncDisposable + { + Task Create(Stream stream, HttpRequestHeader header); + + /// + /// The Protocol of this endpoint. If the client sends a selection of protocols with + /// Sec-WebSocket-Protocol than this protocol must be a member of it to select + /// this endpoint. If this property is null the clients is required to send an empty + /// list (or doesn't send the Sec-WebSocket-Protocol header at all) to select + /// this endpoint. + /// + string? Protocol { get; } + } + + public abstract class WebSocketEndpoint : IWebSocketEndpoint + where T : WebSocketConnection + { + readonly List connections = new List(); + readonly SemaphoreSlim connectionLock = new SemaphoreSlim(1, 1); + + public abstract string? Protocol { get; } + + public void Dispose() + { + connectionLock.Wait(); + foreach (var connection in connections) + connection.Dispose(); + connections.Clear(); + connectionLock.Dispose(); + } + + public async ValueTask DisposeAsync() + { + await connectionLock.WaitAsync(); + foreach (var connection in connections) + await connection.DisposeAsync(); + connections.Clear(); + connectionLock.Dispose(); + } + + public async Task Create(Stream stream, HttpRequestHeader header) + { + var connection = CreateConnection(stream, header); + if (connection == null) + return null; + await connectionLock.WaitAsync(); + connections.Add(connection); + connectionLock.Release(); + connection.Closed += Connection_Closed; + return connection; + } + + private void Connection_Closed(object sender, EventArgs eventArgs) + { + if (sender is T connection) + { + _ = RemoveConnection(connection); + } + } + + protected abstract T? CreateConnection(Stream stream, HttpRequestHeader header); + + public async Task RemoveConnection(T connection) + { + await connectionLock.WaitAsync(); + connections.Remove(connection); + connectionLock.Release(); + connection.Closed -= Connection_Closed; + } + } +} diff --git a/MaxLib.WebServer/WebSocket/WebSocketService.cs b/MaxLib.WebServer/WebSocket/WebSocketService.cs new file mode 100644 index 0000000..b5ee9eb --- /dev/null +++ b/MaxLib.WebServer/WebSocket/WebSocketService.cs @@ -0,0 +1,110 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +#nullable enable + +namespace MaxLib.WebServer.WebSocket +{ + public class WebSocketService : WebService, IDisposable, IAsyncDisposable + { + public WebSocketService() + : base(ServerStage.ParseRequest) + { + } + + public ICollection Endpoints { get; } + = new List(); + + public void Add(WebSocketEndpoint endpoint) + where T : WebSocketConnection + { + Endpoints.Add(endpoint ?? throw new ArgumentNullException(nameof(endpoint))); + } + + public override bool CanWorkWith(WebProgressTask task) + { + return task.Request.GetHeader("Upgrade") == "websocket" && + (task.Request.GetHeader("Connection")?.ToLower().Contains("upgrade") ?? false); + } + + public void Dispose() + { + foreach (var endpoint in Endpoints) + endpoint.Dispose(); + } + + public async ValueTask DisposeAsync() + { + await Task.WhenAll(Endpoints.Select(async x => await x.DisposeAsync())); + } + + public override async Task ProgressTask(WebProgressTask task) + { + if (task.NetworkStream == null) + return; + + var protocols = (task.Request.GetHeader("Sec-WebSocket-Protocol")?.ToLower() ?? "") + .Split(new char[] { ' ', ',' }, StringSplitOptions.RemoveEmptyEntries); + + var key = task.Request.GetHeader("Sec-WebSocket-Key"); + var version = task.Request.GetHeader("Sec-WebSocket-Version"); // MUST be 13 according RFC 6455 + + if (key == null || version != "13") + { + task.Response.StatusCode = HttpStateCode.BadRequest; + task.Response.SetHeader("Sec-WebSocket-Version", "13"); + task.NextStage = ServerStage.CreateResponse; + return; + } + + var responseKey = Convert.ToBase64String( + System.Security.Cryptography.SHA1.Create().ComputeHash( + Encoding.UTF8.GetBytes( + $"{key.Trim()}258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + ) + ) + ); + + + foreach (var endpoint in Endpoints) + { + if (protocols.Length > 0 && (endpoint.Protocol == null || !protocols.Contains(endpoint.Protocol))) + { + continue; + } + + var connection = await endpoint.Create(task.NetworkStream, task.Request); + if (connection == null) + continue; + + task.Response.StatusCode = HttpStateCode.SwitchingProtocols; + task.Response.SetHeader( + ("Access-Control-Allow-Origin", "*"), + ("Upgrade", "websocket"), + ("Connection", "Upgrade"), + ("Sec-WebSocket-Accept", responseKey), + ("Sec-WebSocket-Protocol", endpoint.Protocol) + ); + + task.SwitchProtocols(async () => + { + if (System.Diagnostics.Debugger.IsAttached) + await connection.HandshakeFinished(); + else + try + { + await connection.HandshakeFinished(); + } + catch (Exception e) + { + WebServerLog.Add(ServerLogType.Error, GetType(), "handshake", $"handshake error: {e}"); + } + }); + task.NextStage = ServerStage.SendResponse; + } + } + } +}