diff --git a/eng/testing/xunit/xunit.targets b/eng/testing/xunit/xunit.targets index 6b048e6f6a9a4..e72ebd444ad83 100644 --- a/eng/testing/xunit/xunit.targets +++ b/eng/testing/xunit/xunit.targets @@ -6,6 +6,11 @@ Condition="'$(TargetFrameworkIdentifier)' == '.NETCoreApp'" /> + + true + true + + $(OutDir) diff --git a/src/libraries/Common/src/Interop/Windows/Advapi32/Interop.QueryServiceStatusEx.cs b/src/libraries/Common/src/Interop/Windows/Advapi32/Interop.QueryServiceStatusEx.cs new file mode 100644 index 0000000000000..8c38dec4df8eb --- /dev/null +++ b/src/libraries/Common/src/Interop/Windows/Advapi32/Interop.QueryServiceStatusEx.cs @@ -0,0 +1,34 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Win32.SafeHandles; +using System; +using System.Runtime.InteropServices; + +internal static partial class Interop +{ + internal static partial class Advapi32 + { + [StructLayout(LayoutKind.Sequential)] + internal struct SERVICE_STATUS_PROCESS + { + public int dwServiceType; + public int dwCurrentState; + public int dwControlsAccepted; + public int dwWin32ExitCode; + public int dwServiceSpecificExitCode; + public int dwCheckPoint; + public int dwWaitHint; + public int dwProcessId; + public int dwServiceFlags; + } + + private const int SC_STATUS_PROCESS_INFO = 0; + + [LibraryImport(Libraries.Advapi32, SetLastError = true)] + [return: MarshalAs(UnmanagedType.Bool)] + private static unsafe partial bool QueryServiceStatusEx(SafeServiceHandle serviceHandle, int InfoLevel, SERVICE_STATUS_PROCESS* pStatus, int cbBufSize, out int pcbBytesNeeded); + + internal static unsafe bool QueryServiceStatusEx(SafeServiceHandle serviceHandle, SERVICE_STATUS_PROCESS* pStatus) => QueryServiceStatusEx(serviceHandle, SC_STATUS_PROCESS_INFO, pStatus, sizeof(SERVICE_STATUS_PROCESS), out _); + } +} diff --git a/src/libraries/Common/src/Interop/Windows/Interop.Errors.cs b/src/libraries/Common/src/Interop/Windows/Interop.Errors.cs index cde3ae0ac197e..c810603e6300a 100644 --- a/src/libraries/Common/src/Interop/Windows/Interop.Errors.cs +++ b/src/libraries/Common/src/Interop/Windows/Interop.Errors.cs @@ -64,6 +64,8 @@ internal static partial class Errors internal const int ERROR_IO_PENDING = 0x3E5; internal const int ERROR_NO_TOKEN = 0x3f0; internal const int ERROR_SERVICE_DOES_NOT_EXIST = 0x424; + internal const int ERROR_EXCEPTION_IN_SERVICE = 0x428; + internal const int ERROR_PROCESS_ABORTED = 0x42B; internal const int ERROR_NO_UNICODE_TRANSLATION = 0x459; internal const int ERROR_DLL_INIT_FAILED = 0x45A; internal const int ERROR_COUNTER_TIMEOUT = 0x461; diff --git a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/src/WindowsServiceLifetime.cs b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/src/WindowsServiceLifetime.cs index 164e60670fb67..642f770591d39 100644 --- a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/src/WindowsServiceLifetime.cs +++ b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/src/WindowsServiceLifetime.cs @@ -15,8 +15,10 @@ namespace Microsoft.Extensions.Hosting.WindowsServices public class WindowsServiceLifetime : ServiceBase, IHostLifetime { private readonly TaskCompletionSource _delayStart = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly TaskCompletionSource _serviceDispatcherStopped = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); private readonly ManualResetEventSlim _delayStop = new ManualResetEventSlim(); private readonly HostOptions _hostOptions; + private bool _serviceStopRequested; public WindowsServiceLifetime(IHostEnvironment environment, IHostApplicationLifetime applicationLifetime, ILoggerFactory loggerFactory, IOptions optionsAccessor) : this(environment, applicationLifetime, loggerFactory, optionsAccessor, Options.Options.Create(new WindowsServiceLifetimeOptions())) @@ -69,19 +71,30 @@ private void Run() { Run(this); // This blocks until the service is stopped. _delayStart.TrySetException(new InvalidOperationException("Stopped without starting")); + _serviceDispatcherStopped.TrySetResult(null); } catch (Exception ex) { _delayStart.TrySetException(ex); + _serviceDispatcherStopped.TrySetException(ex); } } - public Task StopAsync(CancellationToken cancellationToken) + /// + /// Called from to stop the service if not already stopped, and wait for the service dispatcher to exit. + /// Once this method returns the service is stopped and the process can be terminated at any time. + /// + public async Task StopAsync(CancellationToken cancellationToken) { - // Avoid deadlock where host waits for StopAsync before firing ApplicationStopped, - // and Stop waits for ApplicationStopped. - Task.Run(Stop, CancellationToken.None); - return Task.CompletedTask; + cancellationToken.ThrowIfCancellationRequested(); + + if (!_serviceStopRequested) + { + await Task.Run(Stop, cancellationToken).ConfigureAwait(false); + } + + // When the underlying service is stopped this will cause the ServiceBase.Run method to complete and return, which completes _serviceDispatcherStopped. + await _serviceDispatcherStopped.Task.ConfigureAwait(false); } // Called by base.Run when the service is ready to start. @@ -91,18 +104,28 @@ protected override void OnStart(string[] args) base.OnStart(args); } - // Called by base.Stop. This may be called multiple times by service Stop, ApplicationStopping, and StopAsync. - // That's OK because StopApplication uses a CancellationTokenSource and prevents any recursion. + /// + /// Executes when a Stop command is sent to the service by the Service Control Manager (SCM). + /// Triggers and waits for . + /// Shortly after this method returns, the Service will be marked as stopped in SCM and the process may exit at any point. + /// protected override void OnStop() { + _serviceStopRequested = true; ApplicationLifetime.StopApplication(); // Wait for the host to shutdown before marking service as stopped. _delayStop.Wait(_hostOptions.ShutdownTimeout); base.OnStop(); } + /// + /// Executes when a Shutdown command is sent to the service by the Service Control Manager (SCM). + /// Triggers and waits for . + /// Shortly after this method returns, the Service will be marked as stopped in SCM and the process may exit at any point. + /// protected override void OnShutdown() { + _serviceStopRequested = true; ApplicationLifetime.StopApplication(); // Wait for the host to shutdown before marking service as stopped. _delayStop.Wait(_hostOptions.ShutdownTimeout); diff --git a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/Microsoft.Extensions.Hosting.WindowsServices.Tests.csproj b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/Microsoft.Extensions.Hosting.WindowsServices.Tests.csproj index 93be9b87c967b..ee433d9207d1d 100644 --- a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/Microsoft.Extensions.Hosting.WindowsServices.Tests.csproj +++ b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/Microsoft.Extensions.Hosting.WindowsServices.Tests.csproj @@ -4,12 +4,45 @@ $(NetCoreAppCurrent)-windows;$(NetFrameworkMinimum) true + true + true + true + + + + + + + + + + + + + + + + diff --git a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/UseWindowsServiceTests.cs b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/UseWindowsServiceTests.cs index 1fb2ade8a9407..c18d5037d665a 100644 --- a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/UseWindowsServiceTests.cs +++ b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/UseWindowsServiceTests.cs @@ -2,7 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; -using System.IO; using System.Reflection; using System.ServiceProcess; using Microsoft.Extensions.DependencyInjection; @@ -30,6 +29,26 @@ public void DefaultsToOffOutsideOfService() Assert.IsType(lifetime); } + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))] + public void CanCreateService() + { + using var serviceTester = WindowsServiceTester.Create(() => + { + using IHost host = new HostBuilder() + .UseWindowsService() + .Build(); + host.Run(); + }); + + serviceTester.Start(); + serviceTester.WaitForStatus(ServiceControllerStatus.Running); + serviceTester.Stop(); + serviceTester.WaitForStatus(ServiceControllerStatus.Stopped); + + var status = serviceTester.QueryServiceStatus(); + Assert.Equal(0, status.win32ExitCode); + } + [Fact] public void ServiceCollectionExtensionMethodDefaultsToOffOutsideOfService() { @@ -66,7 +85,7 @@ public void ServiceCollectionExtensionMethodSetsEventLogSourceNameToApplicationN var builder = new HostApplicationBuilder(new HostApplicationBuilderSettings { ApplicationName = appName, - }); + }); // Emulate calling builder.Services.AddWindowsService() from inside a Windows service. AddWindowsServiceLifetime(builder.Services); @@ -82,7 +101,7 @@ public void ServiceCollectionExtensionMethodSetsEventLogSourceNameToApplicationN [Fact] public void ServiceCollectionExtensionMethodCanBeCalledOnDefaultConfiguration() { - var builder = new HostApplicationBuilder(); + var builder = new HostApplicationBuilder(); // Emulate calling builder.Services.AddWindowsService() from inside a Windows service. AddWindowsServiceLifetime(builder.Services); diff --git a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/WindowsServiceLifetimeTests.cs b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/WindowsServiceLifetimeTests.cs new file mode 100644 index 0000000000000..06679b3c48459 --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/WindowsServiceLifetimeTests.cs @@ -0,0 +1,338 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics; +using System.IO; +using System.ServiceProcess; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting.Internal; +using Microsoft.Extensions.Hosting.WindowsServices; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; +using Xunit; + +namespace Microsoft.Extensions.Hosting +{ + public class WindowsServiceLifetimeTests + { + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))] + public void ServiceStops() + { + using var serviceTester = WindowsServiceTester.Create(async () => + { + var applicationLifetime = new ApplicationLifetime(NullLogger.Instance); + using var lifetime = new WindowsServiceLifetime( + new HostingEnvironment(), + applicationLifetime, + NullLoggerFactory.Instance, + new OptionsWrapper(new HostOptions())); + + await lifetime.WaitForStartAsync(CancellationToken.None); + + // would normally occur here, but WindowsServiceLifetime does not depend on it. + // applicationLifetime.NotifyStarted(); + + // will be signaled by WindowsServiceLifetime when SCM stops the service. + applicationLifetime.ApplicationStopping.WaitHandle.WaitOne(); + + // required by WindowsServiceLifetime to identify that app has stopped. + applicationLifetime.NotifyStopped(); + + await lifetime.StopAsync(CancellationToken.None); + }); + + serviceTester.Start(); + serviceTester.WaitForStatus(ServiceControllerStatus.Running); + + var statusEx = serviceTester.QueryServiceStatusEx(); + var serviceProcess = Process.GetProcessById(statusEx.dwProcessId); + + serviceTester.Stop(); + serviceTester.WaitForStatus(ServiceControllerStatus.Stopped); + + serviceProcess.WaitForExit(); + + var status = serviceTester.QueryServiceStatus(); + Assert.Equal(0, status.win32ExitCode); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))] + [SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework, ".NET Framework is missing the fix from https://github.com/dotnet/corefx/commit/3e68d791066ad0fdc6e0b81828afbd9df00dd7f8")] + public void ExceptionOnStartIsPropagated() + { + using var serviceTester = WindowsServiceTester.Create(async () => + { + using (var lifetime = ThrowingWindowsServiceLifetime.Create(throwOnStart: new Exception("Should be thrown"))) + { + Assert.Equal(lifetime.ThrowOnStart, + await Assert.ThrowsAsync(async () => + await lifetime.WaitForStartAsync(CancellationToken.None))); + } + }); + + serviceTester.Start(); + + serviceTester.WaitForStatus(ServiceControllerStatus.Stopped); + var status = serviceTester.QueryServiceStatus(); + Assert.Equal(Interop.Errors.ERROR_EXCEPTION_IN_SERVICE, status.win32ExitCode); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))] + public void ExceptionOnStopIsPropagated() + { + using var serviceTester = WindowsServiceTester.Create(async () => + { + using (var lifetime = ThrowingWindowsServiceLifetime.Create(throwOnStop: new Exception("Should be thrown"))) + { + await lifetime.WaitForStartAsync(CancellationToken.None); + lifetime.ApplicationLifetime.NotifyStopped(); + Assert.Equal(lifetime.ThrowOnStop, + await Assert.ThrowsAsync( async () => + await lifetime.StopAsync(CancellationToken.None))); + } + }); + + serviceTester.Start(); + + serviceTester.WaitForStatus(ServiceControllerStatus.Stopped); + var status = serviceTester.QueryServiceStatus(); + Assert.Equal(Interop.Errors.ERROR_PROCESS_ABORTED, status.win32ExitCode); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))] + public void CancelStopAsync() + { + using var serviceTester = WindowsServiceTester.Create(async () => + { + var applicationLifetime = new ApplicationLifetime(NullLogger.Instance); + using var lifetime = new WindowsServiceLifetime( + new HostingEnvironment(), + applicationLifetime, + NullLoggerFactory.Instance, + new OptionsWrapper(new HostOptions())); + await lifetime.WaitForStartAsync(CancellationToken.None); + + await Assert.ThrowsAsync(async () => await lifetime.StopAsync(new CancellationToken(true))); + }); + + serviceTester.Start(); + + serviceTester.WaitForStatus(ServiceControllerStatus.Stopped); + var status = serviceTester.QueryServiceStatus(); + Assert.Equal(Interop.Errors.ERROR_PROCESS_ABORTED, status.win32ExitCode); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))] + public void ServiceCanStopItself() + { + using (var serviceTester = WindowsServiceTester.Create(async () => + { + FileLogger.InitializeForTestCase(nameof(ServiceCanStopItself)); + using IHost host = new HostBuilder() + .ConfigureServices(services => + { + services.AddHostedService(); + services.AddSingleton(); + }) + .Build(); + + var applicationLifetime = host.Services.GetRequiredService(); + applicationLifetime.ApplicationStarted.Register(() => FileLogger.Log($"lifetime started")); + applicationLifetime.ApplicationStopping.Register(() => FileLogger.Log($"lifetime stopping")); + applicationLifetime.ApplicationStopped.Register(() => FileLogger.Log($"lifetime stopped")); + + FileLogger.Log("host.Start()"); + host.Start(); + + FileLogger.Log("host.Stop()"); + await host.StopAsync(); + FileLogger.Log("host.Stop() complete"); + })) + { + FileLogger.DeleteLog(nameof(ServiceCanStopItself)); + + // service should start cleanly + serviceTester.Start(); + + // service will proceed to stopped without any error + serviceTester.WaitForStatus(ServiceControllerStatus.Stopped); + + var status = serviceTester.QueryServiceStatus(); + Assert.Equal(0, status.win32ExitCode); + + } + + var logText = FileLogger.ReadLog(nameof(ServiceCanStopItself)); + Assert.Equal(""" + host.Start() + WindowsServiceLifetime.OnStart + BackgroundService.StartAsync + lifetime started + host.Stop() + lifetime stopping + BackgroundService.StopAsync + lifetime stopped + WindowsServiceLifetime.OnStop + host.Stop() complete + + """, logText); + } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsPrivilegedProcess))] + public void ServiceSequenceIsCorrect() + { + using (var serviceTester = WindowsServiceTester.Create(() => + { + FileLogger.InitializeForTestCase(nameof(ServiceSequenceIsCorrect)); + using IHost host = new HostBuilder() + .ConfigureServices(services => + { + services.AddHostedService(); + services.AddSingleton(); + }) + .Build(); + + var applicationLifetime = host.Services.GetRequiredService(); + applicationLifetime.ApplicationStarted.Register(() => FileLogger.Log($"lifetime started")); + applicationLifetime.ApplicationStopping.Register(() => FileLogger.Log($"lifetime stopping")); + applicationLifetime.ApplicationStopped.Register(() => FileLogger.Log($"lifetime stopped")); + + FileLogger.Log("host.Run()"); + host.Run(); + FileLogger.Log("host.Run() complete"); + })) + { + + FileLogger.DeleteLog(nameof(ServiceSequenceIsCorrect)); + + serviceTester.Start(); + serviceTester.WaitForStatus(ServiceControllerStatus.Running); + + var statusEx = serviceTester.QueryServiceStatusEx(); + var serviceProcess = Process.GetProcessById(statusEx.dwProcessId); + + // Give a chance for all asynchronous "started" events to be raised, these happen after the service status changes to started + Thread.Sleep(1000); + + serviceTester.Stop(); + serviceTester.WaitForStatus(ServiceControllerStatus.Stopped); + + var status = serviceTester.QueryServiceStatus(); + Assert.Equal(0, status.win32ExitCode); + + } + + var logText = FileLogger.ReadLog(nameof(ServiceSequenceIsCorrect)); + Assert.Equal(""" + host.Run() + WindowsServiceLifetime.OnStart + BackgroundService.StartAsync + lifetime started + WindowsServiceLifetime.OnStop + lifetime stopping + BackgroundService.StopAsync + lifetime stopped + host.Run() complete + + """, logText); + + } + + public class LoggingWindowsServiceLifetime : WindowsServiceLifetime + { + public LoggingWindowsServiceLifetime(IHostEnvironment environment, IHostApplicationLifetime applicationLifetime, ILoggerFactory loggerFactory, IOptions optionsAccessor) : + base(environment, applicationLifetime, loggerFactory, optionsAccessor) + { } + + protected override void OnStart(string[] args) + { + FileLogger.Log("WindowsServiceLifetime.OnStart"); + base.OnStart(args); + } + + protected override void OnStop() + { + FileLogger.Log("WindowsServiceLifetime.OnStop"); + base.OnStop(); + } + } + + public class ThrowingWindowsServiceLifetime : WindowsServiceLifetime + { + public static ThrowingWindowsServiceLifetime Create(Exception throwOnStart = null, Exception throwOnStop = null) => + new ThrowingWindowsServiceLifetime( + new HostingEnvironment(), + new ApplicationLifetime(NullLogger.Instance), + NullLoggerFactory.Instance, + new OptionsWrapper(new HostOptions())) + { + ThrowOnStart = throwOnStart, + ThrowOnStop = throwOnStop + }; + + public ThrowingWindowsServiceLifetime(IHostEnvironment environment, ApplicationLifetime applicationLifetime, ILoggerFactory loggerFactory, IOptions optionsAccessor) : + base(environment, applicationLifetime, loggerFactory, optionsAccessor) + { + ApplicationLifetime = applicationLifetime; + } + + public ApplicationLifetime ApplicationLifetime { get; } + + public Exception ThrowOnStart { get; set; } + protected override void OnStart(string[] args) + { + if (ThrowOnStart != null) + { + throw ThrowOnStart; + } + base.OnStart(args); + } + + public Exception ThrowOnStop { get; set; } + protected override void OnStop() + { + if (ThrowOnStop != null) + { + throw ThrowOnStop; + } + base.OnStop(); + } + } + + public class LoggingBackgroundService : BackgroundService + { +#pragma warning disable CS1998 // Async method lacks 'await' operators and will run synchronously + protected override async Task ExecuteAsync(CancellationToken stoppingToken) => FileLogger.Log("BackgroundService.ExecuteAsync"); + public override async Task StartAsync(CancellationToken stoppingToken) => FileLogger.Log("BackgroundService.StartAsync"); + public override async Task StopAsync(CancellationToken stoppingToken) => FileLogger.Log("BackgroundService.StopAsync"); +#pragma warning restore CS1998 // Async method lacks 'await' operators and will run synchronously + } + + static class FileLogger + { + static string _fileName; + + public static void InitializeForTestCase(string testCaseName) + { + Assert.Null(_fileName); + _fileName = GetLogForTestCase(testCaseName); + } + + private static string GetLogForTestCase(string testCaseName) => Path.Combine(AppContext.BaseDirectory, $"{testCaseName}.log"); + public static void DeleteLog(string testCaseName) => File.Delete(GetLogForTestCase(testCaseName)); + public static string ReadLog(string testCaseName) => File.ReadAllText(GetLogForTestCase(testCaseName)); + public static void Log(string message) + { + Assert.NotNull(_fileName); + lock (_fileName) + { + File.AppendAllText(_fileName, message + Environment.NewLine); + } + } + } + } +} diff --git a/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/WindowsServiceTester.cs b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/WindowsServiceTester.cs new file mode 100644 index 0000000000000..895b4a87108eb --- /dev/null +++ b/src/libraries/Microsoft.Extensions.Hosting.WindowsServices/tests/WindowsServiceTester.cs @@ -0,0 +1,158 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.ComponentModel; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.ServiceProcess; +using System.Threading.Tasks; +using Microsoft.DotNet.RemoteExecutor; +using Microsoft.Win32.SafeHandles; +using Xunit; + +namespace Microsoft.Extensions.Hosting +{ + public class WindowsServiceTester : ServiceController + { + private WindowsServiceTester(SafeServiceHandle serviceHandle, RemoteInvokeHandle remoteInvokeHandle, string serviceName) : base(serviceName) + { + _serviceHandle = serviceHandle; + _remoteInvokeHandle = remoteInvokeHandle; + } + + private SafeServiceHandle _serviceHandle; + private RemoteInvokeHandle _remoteInvokeHandle; + + public new void Start() + { + Start(Array.Empty()); + } + + public new void Start(string[] args) + { + base.Start(args); + + // get the process + _remoteInvokeHandle.Process.Dispose(); + _remoteInvokeHandle.Process = null; + + var statusEx = QueryServiceStatusEx(); + try + { + _remoteInvokeHandle.Process = Process.GetProcessById(statusEx.dwProcessId); + // fetch the process handle so that we can get the exit code later. + var _ = _remoteInvokeHandle.Process.SafeHandle; + } + catch (ArgumentException) + { } + } + + public TimeSpan WaitForStatusTimeout { get; set; } = TimeSpan.FromSeconds(30); + + public new void WaitForStatus(ServiceControllerStatus desiredStatus) => + WaitForStatus(desiredStatus, WaitForStatusTimeout); + + public new void WaitForStatus(ServiceControllerStatus desiredStatus, TimeSpan timeout) + { + base.WaitForStatus(desiredStatus, timeout); + + Assert.Equal(Status, desiredStatus); + } + + // the following overloads are necessary to ensure the compiler will produce the correct signature from a lambda. + public static WindowsServiceTester Create(Func serviceMain, [CallerMemberName] string serviceName = null) => Create(RemoteExecutor.Invoke(serviceMain, remoteInvokeOptions), serviceName); + + public static WindowsServiceTester Create(Func> serviceMain, [CallerMemberName] string serviceName = null) => Create(RemoteExecutor.Invoke(serviceMain, remoteInvokeOptions), serviceName); + + public static WindowsServiceTester Create(Func serviceMain, [CallerMemberName] string serviceName = null) => Create(RemoteExecutor.Invoke(serviceMain, remoteInvokeOptions), serviceName); + + public static WindowsServiceTester Create(Action serviceMain, [CallerMemberName] string serviceName = null) => Create(RemoteExecutor.Invoke(serviceMain, remoteInvokeOptions), serviceName); + + private static RemoteInvokeOptions remoteInvokeOptions = new RemoteInvokeOptions() { Start = false }; + + private static WindowsServiceTester Create(RemoteInvokeHandle remoteInvokeHandle, string serviceName) + { + // create remote executor commandline arguments + var startInfo = remoteInvokeHandle.Process.StartInfo; + string commandLine = startInfo.FileName + " " + startInfo.Arguments; + + // install the service + using (var serviceManagerHandle = new SafeServiceHandle(Interop.Advapi32.OpenSCManager(null, null, Interop.Advapi32.ServiceControllerOptions.SC_MANAGER_ALL))) + { + if (serviceManagerHandle.IsInvalid) + { + throw new InvalidOperationException(); + } + + // delete existing service if it exists + using (var existingServiceHandle = new SafeServiceHandle(Interop.Advapi32.OpenService(serviceManagerHandle, serviceName, Interop.Advapi32.ServiceAccessOptions.ACCESS_TYPE_ALL))) + { + if (!existingServiceHandle.IsInvalid) + { + Interop.Advapi32.DeleteService(existingServiceHandle); + } + } + + var serviceHandle = new SafeServiceHandle( + Interop.Advapi32.CreateService(serviceManagerHandle, + serviceName, + $"{nameof(WindowsServiceTester)} {serviceName} test service", + Interop.Advapi32.ServiceAccessOptions.ACCESS_TYPE_ALL, + Interop.Advapi32.ServiceTypeOptions.SERVICE_WIN32_OWN_PROCESS, + (int)ServiceStartMode.Manual, + Interop.Advapi32.ServiceStartErrorModes.ERROR_CONTROL_NORMAL, + commandLine, + loadOrderGroup: null, + pTagId: IntPtr.Zero, + dependencies: null, + servicesStartName: null, + password: null)); + + if (serviceHandle.IsInvalid) + { + throw new Win32Exception(); + } + + return new WindowsServiceTester(serviceHandle, remoteInvokeHandle, serviceName); + } + } + + internal unsafe Interop.Advapi32.SERVICE_STATUS QueryServiceStatus() + { + Interop.Advapi32.SERVICE_STATUS status = default; + bool success = Interop.Advapi32.QueryServiceStatus(_serviceHandle, &status); + if (!success) + { + throw new Win32Exception(); + } + return status; + } + + internal unsafe Interop.Advapi32.SERVICE_STATUS_PROCESS QueryServiceStatusEx() + { + Interop.Advapi32.SERVICE_STATUS_PROCESS status = default; + bool success = Interop.Advapi32.QueryServiceStatusEx(_serviceHandle, &status); + if (!success) + { + throw new Win32Exception(); + } + return status; + } + + protected override void Dispose(bool disposing) + { + if (_remoteInvokeHandle != null) + { + _remoteInvokeHandle.Dispose(); + } + + if (!_serviceHandle.IsInvalid) + { + // delete the temporary test service + Interop.Advapi32.DeleteService(_serviceHandle); + _serviceHandle.Close(); + } + } + } +} diff --git a/src/libraries/System.ServiceProcess.ServiceController/src/System/ServiceProcess/ServiceBase.cs b/src/libraries/System.ServiceProcess.ServiceController/src/System/ServiceProcess/ServiceBase.cs index ddc3e2ea601c5..59123a336ec3a 100644 --- a/src/libraries/System.ServiceProcess.ServiceController/src/System/ServiceProcess/ServiceBase.cs +++ b/src/libraries/System.ServiceProcess.ServiceController/src/System/ServiceProcess/ServiceBase.cs @@ -31,6 +31,7 @@ public class ServiceBase : Component private bool _commandPropsFrozen; // set to true once we've use the Can... properties. private bool _disposed; private bool _initialized; + private object _stopLock = new object(); private EventLog? _eventLog; /// @@ -501,27 +502,34 @@ private void DeferredSessionChange(int eventType, int sessionId) // This is a problem when multiple services are hosted in a single process. private unsafe void DeferredStop() { - fixed (SERVICE_STATUS* pStatus = &_status) + lock (_stopLock) { - int previousState = _status.currentState; - - _status.checkPoint = 0; - _status.waitHint = 0; - _status.currentState = ServiceControlStatus.STATE_STOP_PENDING; - SetServiceStatus(_statusHandle, pStatus); - try + // never call SetServiceStatus again after STATE_STOPPED is set. + if (_status.currentState != ServiceControlStatus.STATE_STOPPED) { - OnStop(); - WriteLogEntry(SR.StopSuccessful); - _status.currentState = ServiceControlStatus.STATE_STOPPED; - SetServiceStatus(_statusHandle, pStatus); - } - catch (Exception e) - { - _status.currentState = previousState; - SetServiceStatus(_statusHandle, pStatus); - WriteLogEntry(SR.Format(SR.StopFailed, e), EventLogEntryType.Error); - throw; + fixed (SERVICE_STATUS* pStatus = &_status) + { + int previousState = _status.currentState; + + _status.checkPoint = 0; + _status.waitHint = 0; + _status.currentState = ServiceControlStatus.STATE_STOP_PENDING; + SetServiceStatus(_statusHandle, pStatus); + try + { + OnStop(); + WriteLogEntry(SR.StopSuccessful); + _status.currentState = ServiceControlStatus.STATE_STOPPED; + SetServiceStatus(_statusHandle, pStatus); + } + catch (Exception e) + { + _status.currentState = previousState; + SetServiceStatus(_statusHandle, pStatus); + WriteLogEntry(SR.Format(SR.StopFailed, e), EventLogEntryType.Error); + throw; + } + } } } } @@ -533,14 +541,17 @@ private unsafe void DeferredShutdown() OnShutdown(); WriteLogEntry(SR.ShutdownOK); - if (_status.currentState == ServiceControlStatus.STATE_PAUSED || _status.currentState == ServiceControlStatus.STATE_RUNNING) + lock (_stopLock) { - fixed (SERVICE_STATUS* pStatus = &_status) + if (_status.currentState == ServiceControlStatus.STATE_PAUSED || _status.currentState == ServiceControlStatus.STATE_RUNNING) { - _status.checkPoint = 0; - _status.waitHint = 0; - _status.currentState = ServiceControlStatus.STATE_STOPPED; - SetServiceStatus(_statusHandle, pStatus); + fixed (SERVICE_STATUS* pStatus = &_status) + { + _status.checkPoint = 0; + _status.waitHint = 0; + _status.currentState = ServiceControlStatus.STATE_STOPPED; + SetServiceStatus(_statusHandle, pStatus); + } } } } @@ -654,7 +665,7 @@ private void Initialize(bool multipleServices) { if (!_initialized) { - //Cannot register the service with NT service manatger if the object has been disposed, since finalization has been suppressed. + //Cannot register the service with NT service manager if the object has been disposed, since finalization has been suppressed. if (_disposed) throw new ObjectDisposedException(GetType().Name); @@ -923,8 +934,14 @@ public unsafe void ServiceMainCallback(int argCount, IntPtr argPointer) { string errorMessage = new Win32Exception().Message; WriteLogEntry(SR.Format(SR.StartFailed, errorMessage), EventLogEntryType.Error); - _status.currentState = ServiceControlStatus.STATE_STOPPED; - SetServiceStatus(_statusHandle, pStatus); + lock (_stopLock) + { + if (_status.currentState != ServiceControlStatus.STATE_STOPPED) + { + _status.currentState = ServiceControlStatus.STATE_STOPPED; + SetServiceStatus(_statusHandle, pStatus); + } + } } } }