diff --git a/SshAgentLib/WindowsOpenSshPipe.cs b/SshAgentLib/WindowsOpenSshPipe.cs index f735c72..9f1ad89 100644 --- a/SshAgentLib/WindowsOpenSshPipe.cs +++ b/SshAgentLib/WindowsOpenSshPipe.cs @@ -1,7 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2017,2022 David Lechner +// Copyright (c) 2017,2022-2023 David Lechner using System; +using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.IO.Pipes; @@ -24,7 +25,7 @@ public sealed class WindowsOpenSshPipe : IDisposable private const int bufferSize = 5 * 1024; // 5 KiB private readonly CancellationTokenSource cancelSource; - private readonly Task listenerTask; + private readonly List listenerTasks = new List(); /// /// Creates a new Windows OpenSSH Agent pipe. @@ -43,7 +44,7 @@ public WindowsOpenSshPipe(ConnectionHandler connectionHandler) } cancelSource = new CancellationTokenSource(); - listenerTask = RunListenerAsync(connectionHandler, cancelSource.Token); + listenerTasks.Add(RunListenerAsync(connectionHandler, cancelSource.Token)); Debug.WriteLine("Started new Windows OpenSSH Pipe"); } @@ -53,91 +54,105 @@ private static extern bool GetNamedPipeClientProcessId( out uint ClientProcessId ); - private static async Task RunListenerAsync( + private async Task RunListenerAsync( ConnectionHandler connectionHandler, CancellationToken cancellationToken ) { - while (!cancellationToken.IsCancellationRequested) - { - var security = new PipeSecurity(); - - // Limit access to the current user. This also has the effect - // of allowing non-elevated processes to access the agent when - // it is running as an elevated process. - security.AddAccessRule( - new PipeAccessRule( - WindowsIdentity.GetCurrent().User, - PipeAccessRights.ReadWrite, - AccessControlType.Allow - ) - ); - - using ( - var server = new NamedPipeServerStream( - agentPipeId, - PipeDirection.InOut, - NamedPipeServerStream.MaxAllowedServerInstances, - PipeTransmissionMode.Byte, - PipeOptions.WriteThrough | PipeOptions.Asynchronous, - bufferSize, - bufferSize, - security - ) + var security = new PipeSecurity(); + + // Limit access to the current user. This also has the effect + // of allowing non-elevated processes to access the agent when + // it is running as an elevated process. + security.AddAccessRule( + new PipeAccessRule( + WindowsIdentity.GetCurrent().User, + PipeAccessRights.ReadWrite | PipeAccessRights.CreateNewInstance, + AccessControlType.Allow ) + ); + + using ( + var server = new NamedPipeServerStream( + agentPipeId, + PipeDirection.InOut, + NamedPipeServerStream.MaxAllowedServerInstances, + PipeTransmissionMode.Byte, + PipeOptions.WriteThrough | PipeOptions.Asynchronous, + bufferSize, + bufferSize, + security + ) + ) + { + await server.WaitForConnectionAsync(cancellationToken).ConfigureAwait(false); + Debug.WriteLine("Received Windows OpenSSH Pipe client connection"); + + lock (listenerTasks) { - await server.WaitForConnectionAsync(cancellationToken).ConfigureAwait(false); - Debug.WriteLine("Received Windows OpenSSH Pipe client connection"); - - if ( - !GetNamedPipeClientProcessId( - server.SafePipeHandle.DangerousGetHandle(), - out var clientPid - ) - ) + if (!cancellationToken.IsCancellationRequested) { - throw new IOException( - "Failed to get client PID", - Marshal.GetHRForLastWin32Error() - ); + // start a new listener for the next connection + listenerTasks.Add(RunListenerAsync(connectionHandler, cancellationToken)); } + } - try - { - var proc = Process.GetProcessById((int)clientPid); + if ( + !GetNamedPipeClientProcessId( + server.SafePipeHandle.DangerousGetHandle(), + out var clientPid + ) + ) + { + throw new IOException( + "Failed to get client PID", + Marshal.GetHRForLastWin32Error() + ); + } - using (cancellationToken.Register(() => server.Disconnect())) - { - await Task.Run(() => connectionHandler(server, proc), cancellationToken) - .ConfigureAwait(false); - } - } - catch (ArgumentException) + try + { + var proc = Process.GetProcessById((int)clientPid); + + using (cancellationToken.Register(() => server.Disconnect())) { - // The SSH client process is gone! Nothing we can do ... - Debug.WriteLine($"OpenSSH pipe client already exited (PID: {clientPid})"); + await Task.Run(() => connectionHandler(server, proc), cancellationToken) + .ConfigureAwait(false); } } + catch (ArgumentException) + { + // The SSH client process is gone! Nothing we can do ... + Debug.WriteLine($"OpenSSH pipe client already exited (PID: {clientPid})"); + } } } public void Dispose() { - // allow multiple calls to dispose - if (listenerTask.IsCompleted) + lock (listenerTasks) { - return; - } + // allow multiple calls to dispose + if (listenerTasks.Count == 0) + { + return; + } - cancelSource.Cancel(); + cancelSource.Cancel(); - try - { - listenerTask.Wait(); - } - catch (AggregateException) - { - // expected since we just canceled the task + foreach (var task in listenerTasks) + { + try + { + task.Wait(); + } + catch (AggregateException) + { + // expected since we just canceled the task + } + } + + listenerTasks.Clear(); } Debug.WriteLine("Stopped Windows OpenSSH Pipe");