diff --git a/.github/workflows/ci-e2e.yml b/.github/workflows/ci-e2e.yml index ace74939..472e33d0 100644 --- a/.github/workflows/ci-e2e.yml +++ b/.github/workflows/ci-e2e.yml @@ -30,8 +30,15 @@ jobs: - name: Install regctl uses: iarekylew00t/regctl-installer@v1 - - name: Getr Version of Machine.py - run: echo "MACHINE_PY_IMAGE=ghcr.io/sillsdev/machine.py:$(regctl image config ghcr.io/sillsdev/machine.py | jq -r ".config.Labels[\"org.opencontainers.image.version\"]")" >> $GITHUB_ENV + - name: Set proper version of Machine.py + run: | + export MACHINE_PY_IMAGE=ghcr.io/sillsdev/machine.py:$(regctl image config ghcr.io/sillsdev/machine.py | jq -r ".config.Labels[\"org.opencontainers.image.version\"]") && \ + echo "MACHINE_PY_IMAGE=$MACHINE_PY_IMAGE" >> $GITHUB_ENV && \ + echo "MACHINE_PY_CPU_IMAGE=$MACHINE_PY_IMAGE.cpu_only" >> $GITHUB_ENV + + - name: Confirm proper version of Machine.py + run: | + echo $MACHINE_PY_IMAGE $MACHINE_PY_CPU_IMAGE - name: Setup .NET uses: actions/setup-dotnet@v3 @@ -50,6 +57,9 @@ jobs: - name: Test run: dotnet test --no-build --verbosity normal --filter "TestCategory!=slow&TestCategory=E2E" --collect:"Xplat Code Coverage" + - name: Debug network again + run: docker ps -a && docker logs --since 10m serval_cntr && docker logs --since 10m echo_cntr && docker logs --since 10m machine-engine-cntr && docker logs --since 10m serval-mongo-1 && docker logs --since 10m machine-job-cntr + - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v3 env: diff --git a/src/Machine/src/Serval.Machine.JobServer/appsettings.json b/src/Machine/src/Serval.Machine.JobServer/appsettings.json index d1934bc0..d5aada0d 100644 --- a/src/Machine/src/Serval.Machine.JobServer/appsettings.json +++ b/src/Machine/src/Serval.Machine.JobServer/appsettings.json @@ -21,7 +21,7 @@ { "TranslationEngineType": "SmtTransfer", "ModelType": "thot", - "Queue": "jobs_backlog", + "Queue": "jobs_backlog.cpu_only", "DockerImage": "ghcr.io/sillsdev/machine.py:latest" } ], diff --git a/src/Machine/src/Serval.Machine.Shared/Configuration/DistributedReaderWriterLockOptions.cs b/src/Machine/src/Serval.Machine.Shared/Configuration/DistributedReaderWriterLockOptions.cs new file mode 100644 index 00000000..62817dbc --- /dev/null +++ b/src/Machine/src/Serval.Machine.Shared/Configuration/DistributedReaderWriterLockOptions.cs @@ -0,0 +1,8 @@ +namespace Serval.Machine.Shared.Configuration; + +public class DistributedReaderWriterLockOptions +{ + public const string Key = "DistributedReaderWriterLock"; + + public TimeSpan DefaultLifetime { get; set; } = TimeSpan.FromSeconds(56); // must be less than DefaultHttpRequestTimeout +} diff --git a/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs b/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs index d67afb90..5a577cb5 100644 --- a/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs +++ b/src/Machine/src/Serval.Machine.Shared/Configuration/IMachineBuilderExtensions.cs @@ -49,6 +49,24 @@ public static IMachineBuilder AddClearMLOptions(this IMachineBuilder builder, IC return builder; } + public static IMachineBuilder AddDistributedReaderWriterLockOptions( + this IMachineBuilder build, + Action configureOptions + ) + { + build.Services.Configure(configureOptions); + return build; + } + + public static IMachineBuilder AddDistributedReaderWriterLockOptions( + this IMachineBuilder build, + IConfiguration config + ) + { + build.Services.Configure(config); + return build; + } + public static IMachineBuilder AddMessageOutboxOptions( this IMachineBuilder builder, Action configureOptions @@ -360,6 +378,7 @@ public static IMachineBuilder AddServalTranslationEngineService( { options.Interceptors.Add(); options.Interceptors.Add(); + options.Interceptors.Add(); }); builder.AddServalPlatformService(connectionString); diff --git a/src/Machine/src/Serval.Machine.Shared/Configuration/IServiceCollectionExtensions.cs b/src/Machine/src/Serval.Machine.Shared/Configuration/IServiceCollectionExtensions.cs index 7463e6ac..9ae176d8 100644 --- a/src/Machine/src/Serval.Machine.Shared/Configuration/IServiceCollectionExtensions.cs +++ b/src/Machine/src/Serval.Machine.Shared/Configuration/IServiceCollectionExtensions.cs @@ -28,6 +28,7 @@ public static IMachineBuilder AddMachine(this IServiceCollection services, IConf builder.AddSharedFileOptions(o => { }); builder.AddSmtTransferEngineOptions(o => { }); builder.AddClearMLOptions(o => { }); + builder.AddDistributedReaderWriterLockOptions(o => { }); builder.AddBuildJobOptions(o => { }); builder.AddMessageOutboxOptions(o => { }); } @@ -37,6 +38,9 @@ public static IMachineBuilder AddMachine(this IServiceCollection services, IConf builder.AddSharedFileOptions(configuration.GetSection(SharedFileOptions.Key)); builder.AddSmtTransferEngineOptions(configuration.GetSection(SmtTransferEngineOptions.Key)); builder.AddClearMLOptions(configuration.GetSection(ClearMLOptions.Key)); + builder.AddDistributedReaderWriterLockOptions( + configuration.GetSection(DistributedReaderWriterLockOptions.Key) + ); builder.AddBuildJobOptions(configuration.GetSection(BuildJobOptions.Key)); builder.AddMessageOutboxOptions(configuration.GetSection(MessageOutboxOptions.Key)); } diff --git a/src/Machine/src/Serval.Machine.Shared/Configuration/SmtTransferEngineOptions.cs b/src/Machine/src/Serval.Machine.Shared/Configuration/SmtTransferEngineOptions.cs index 15002604..58a425ff 100644 --- a/src/Machine/src/Serval.Machine.Shared/Configuration/SmtTransferEngineOptions.cs +++ b/src/Machine/src/Serval.Machine.Shared/Configuration/SmtTransferEngineOptions.cs @@ -7,4 +7,6 @@ public class SmtTransferEngineOptions public string EnginesDir { get; set; } = "translation_engines"; public TimeSpan EngineCommitFrequency { get; set; } = TimeSpan.FromMinutes(5); public TimeSpan InactiveEngineTimeout { get; set; } = TimeSpan.FromMinutes(10); + public TimeSpan SaveModelTimeout { get; set; } = TimeSpan.FromMinutes(5); + public TimeSpan EngineCommitTimeout { get; set; } = TimeSpan.FromMinutes(2); } diff --git a/src/Machine/src/Serval.Machine.Shared/Models/Lock.cs b/src/Machine/src/Serval.Machine.Shared/Models/Lock.cs index 39ceae87..505126d2 100644 --- a/src/Machine/src/Serval.Machine.Shared/Models/Lock.cs +++ b/src/Machine/src/Serval.Machine.Shared/Models/Lock.cs @@ -3,6 +3,7 @@ public record Lock { public required string Id { get; init; } - public DateTime? ExpiresAt { get; init; } + + public DateTime ExpiresAt { get; init; } public required string HostId { get; init; } } diff --git a/src/Machine/src/Serval.Machine.Shared/Models/RWLock.cs b/src/Machine/src/Serval.Machine.Shared/Models/RWLock.cs index 2271aa9b..1a462602 100644 --- a/src/Machine/src/Serval.Machine.Shared/Models/RWLock.cs +++ b/src/Machine/src/Serval.Machine.Shared/Models/RWLock.cs @@ -11,15 +11,14 @@ public record RWLock : IEntity public bool IsAvailableForReading() { var now = DateTime.UtcNow; - return (WriterLock is null || WriterLock.ExpiresAt is not null && WriterLock.ExpiresAt <= now) - && WriterQueue.Count == 0; + return (WriterLock is null || WriterLock.ExpiresAt <= now) && WriterQueue.Count == 0; } public bool IsAvailableForWriting(string? lockId = null) { var now = DateTime.UtcNow; - return (WriterLock is null || WriterLock.ExpiresAt is not null && WriterLock.ExpiresAt <= now) - && !ReaderLocks.Any(l => l.ExpiresAt is null || l.ExpiresAt > now) + return (WriterLock is null || WriterLock.ExpiresAt <= now) + && !ReaderLocks.Any(l => l.ExpiresAt > now) && (lockId is null || WriterQueue.Count > 0 && WriterQueue[0].Id == lockId); } } diff --git a/src/Machine/src/Serval.Machine.Shared/Models/TranslationEngine.cs b/src/Machine/src/Serval.Machine.Shared/Models/TranslationEngine.cs index 80b1f648..e3143a3c 100644 --- a/src/Machine/src/Serval.Machine.Shared/Models/TranslationEngine.cs +++ b/src/Machine/src/Serval.Machine.Shared/Models/TranslationEngine.cs @@ -11,4 +11,5 @@ public record TranslationEngine : IEntity public required bool IsModelPersisted { get; init; } public int BuildRevision { get; init; } public Build? CurrentBuild { get; init; } + public bool? CollectTrainSegmentPairs { get; init; } } diff --git a/src/Machine/src/Serval.Machine.Shared/Serval.Machine.Shared.csproj b/src/Machine/src/Serval.Machine.Shared/Serval.Machine.Shared.csproj index 4ea74e68..fc1cd005 100644 --- a/src/Machine/src/Serval.Machine.Shared/Serval.Machine.Shared.csproj +++ b/src/Machine/src/Serval.Machine.Shared/Serval.Machine.Shared.csproj @@ -36,9 +36,9 @@ - - - + + + diff --git a/src/Machine/src/Serval.Machine.Shared/Services/BuildJobService.cs b/src/Machine/src/Serval.Machine.Shared/Services/BuildJobService.cs index 244aa04a..da670439 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/BuildJobService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/BuildJobService.cs @@ -59,6 +59,7 @@ public async Task DeleteEngineAsync(string engineId, CancellationToken cancellat public async Task StartBuildJobAsync( BuildJobRunnerType runnerType, + TranslationEngineType engineType, string engineId, string buildId, BuildStage stage, @@ -67,18 +68,9 @@ public async Task StartBuildJobAsync( CancellationToken cancellationToken = default ) { - TranslationEngine? engine = await _engines.GetAsync( - e => - e.EngineId == engineId - && (e.CurrentBuild == null || e.CurrentBuild.JobState != BuildJobState.Canceling), - cancellationToken - ); - if (engine is null) - return false; - IBuildJobRunner runner = _runners[runnerType]; string jobId = await runner.CreateJobAsync( - engine.Type, + engineType, engineId, buildId, stage, @@ -88,8 +80,17 @@ public async Task StartBuildJobAsync( ); try { - await _engines.UpdateAsync( - e => e.EngineId == engineId, + TranslationEngine? engine = await _engines.UpdateAsync( + e => + e.EngineId == engineId + && ( + (stage == BuildStage.Preprocess && e.CurrentBuild == null) + || ( + stage != BuildStage.Preprocess + && e.CurrentBuild != null + && e.CurrentBuild.JobState != BuildJobState.Canceling + ) + ), u => u.Set( e => e.CurrentBuild, @@ -105,6 +106,11 @@ await _engines.UpdateAsync( ), cancellationToken: cancellationToken ); + if (engine is null) + { + await runner.DeleteJobAsync(jobId, CancellationToken.None); + return false; + } await runner.EnqueueJobAsync(jobId, engine.Type, cancellationToken); return true; } @@ -120,44 +126,36 @@ await _engines.UpdateAsync( CancellationToken cancellationToken = default ) { - TranslationEngine? engine = await _engines.GetAsync( - e => e.EngineId == engineId && e.CurrentBuild != null, - cancellationToken + // cancel a job that hasn't started yet + TranslationEngine? engine = await _engines.UpdateAsync( + e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.JobState == BuildJobState.Pending, + u => + { + u.Unset(b => b.CurrentBuild); + u.Set(e => e.CollectTrainSegmentPairs, false); + }, + returnOriginal: true, + cancellationToken: cancellationToken ); - if (engine is null || engine.CurrentBuild is null) - return (null, BuildJobState.None); - - IBuildJobRunner runner = _runners[engine.CurrentBuild.BuildJobRunner]; - - if (engine.CurrentBuild.JobState is BuildJobState.Pending) + if (engine is not null && engine.CurrentBuild is not null) { - // cancel a job that hasn't started yet - engine = await _engines.UpdateAsync( - e => e.EngineId == engineId && e.CurrentBuild != null, - u => u.Unset(b => b.CurrentBuild), - returnOriginal: true, - cancellationToken: cancellationToken - ); - if (engine is not null && engine.CurrentBuild is not null) - { - // job will be deleted from the queue - await runner.StopJobAsync(engine.CurrentBuild.JobId, CancellationToken.None); - return (engine.CurrentBuild.BuildId, BuildJobState.None); - } + // job will be deleted from the queue + IBuildJobRunner runner = _runners[engine.CurrentBuild.BuildJobRunner]; + await runner.StopJobAsync(engine.CurrentBuild.JobId, CancellationToken.None); + return (engine.CurrentBuild.BuildId, BuildJobState.None); } - else if (engine.CurrentBuild.JobState is BuildJobState.Active) + + // cancel a job that is already running + engine = await _engines.UpdateAsync( + e => e.EngineId == engineId && e.CurrentBuild != null && e.CurrentBuild.JobState == BuildJobState.Active, + u => u.Set(e => e.CurrentBuild!.JobState, BuildJobState.Canceling), + cancellationToken: cancellationToken + ); + if (engine is not null && engine.CurrentBuild is not null) { - // cancel a job that is already running - engine = await _engines.UpdateAsync( - e => e.EngineId == engineId && e.CurrentBuild != null, - u => u.Set(e => e.CurrentBuild!.JobState, BuildJobState.Canceling), - cancellationToken: cancellationToken - ); - if (engine is not null && engine.CurrentBuild is not null) - { - await runner.StopJobAsync(engine.CurrentBuild.JobId, CancellationToken.None); - return (engine.CurrentBuild.BuildId, BuildJobState.Canceling); - } + IBuildJobRunner runner = _runners[engine.CurrentBuild.BuildJobRunner]; + await runner.StopJobAsync(engine.CurrentBuild.JobId, CancellationToken.None); + return (engine.CurrentBuild.BuildId, BuildJobState.Canceling); } return (null, BuildJobState.None); @@ -193,6 +191,7 @@ public Task BuildJobFinishedAsync( u => { u.Unset(e => e.CurrentBuild); + u.Set(e => e.CollectTrainSegmentPairs, false); if (buildComplete) u.Inc(e => e.BuildRevision); }, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ClearMLMonitorService.cs b/src/Machine/src/Serval.Machine.Shared/Services/ClearMLMonitorService.cs index f577fdce..c527a03f 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ClearMLMonitorService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ClearMLMonitorService.cs @@ -85,7 +85,6 @@ await _clearMLService.GetTasksForQueueAsync(_queuePerEngineType[engineType], can var dataAccessContext = scope.ServiceProvider.GetRequiredService(); var platformService = scope.ServiceProvider.GetRequiredService(); - var lockFactory = scope.ServiceProvider.GetRequiredService(); foreach (TranslationEngine engine in trainingEngines) { if (engine.CurrentBuild is null || !tasks.TryGetValue(engine.CurrentBuild.JobId, out ClearMLTask? task)) @@ -119,7 +118,6 @@ or ClearMLTaskStatus.Completed { bool canceled = !await TrainJobStartedAsync( dataAccessContext, - lockFactory, buildJobService, platformService, engine.EngineId, @@ -159,8 +157,8 @@ await UpdateTrainJobStatus( cancellationToken ); bool canceling = !await TrainJobCompletedAsync( - lockFactory, buildJobService, + engine.Type, engine.EngineId, engine.CurrentBuild.BuildId, (int)GetMetric(task, SummaryMetric, TrainCorpusSizeVariant), @@ -172,7 +170,6 @@ await UpdateTrainJobStatus( { await TrainJobCanceledAsync( dataAccessContext, - lockFactory, buildJobService, platformService, engine.EngineId, @@ -187,7 +184,6 @@ await TrainJobCanceledAsync( { await TrainJobCanceledAsync( dataAccessContext, - lockFactory, buildJobService, platformService, engine.EngineId, @@ -201,7 +197,6 @@ await TrainJobCanceledAsync( { await TrainJobFaultedAsync( dataAccessContext, - lockFactory, buildJobService, platformService, engine.EngineId, @@ -223,7 +218,6 @@ await TrainJobFaultedAsync( private async Task TrainJobStartedAsync( IDataAccessContext dataAccessContext, - IDistributedReaderWriterLockFactory lockFactory, IBuildJobService buildJobService, IPlatformService platformService, string engineId, @@ -231,29 +225,24 @@ private async Task TrainJobStartedAsync( CancellationToken cancellationToken = default ) { - bool success; - IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - success = await dataAccessContext.WithTransactionAsync( - async (ct) => - { - if (!await buildJobService.BuildJobStartedAsync(engineId, buildId, ct)) - return false; - await platformService.BuildStartedAsync(buildId, CancellationToken.None); - return true; - }, - cancellationToken: cancellationToken - ); - } + bool success = await dataAccessContext.WithTransactionAsync( + async (ct) => + { + if (!await buildJobService.BuildJobStartedAsync(engineId, buildId, ct)) + return false; + await platformService.BuildStartedAsync(buildId, CancellationToken.None); + return true; + }, + cancellationToken: cancellationToken + ); await UpdateTrainJobStatus(platformService, buildId, new ProgressStatus(0), 0, cancellationToken); _logger.LogInformation("Build started ({BuildId})", buildId); return success; } private async Task TrainJobCompletedAsync( - IDistributedReaderWriterLockFactory lockFactory, IBuildJobService buildJobService, + TranslationEngineType engineType, string engineId, string buildId, int corpusSize, @@ -264,19 +253,16 @@ CancellationToken cancellationToken { try { - IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - return await buildJobService.StartBuildJobAsync( - BuildJobRunnerType.Hangfire, - engineId, - buildId, - BuildStage.Postprocess, - (corpusSize, confidence), - buildOptions, - cancellationToken - ); - } + return await buildJobService.StartBuildJobAsync( + BuildJobRunnerType.Hangfire, + engineType, + engineId, + buildId, + BuildStage.Postprocess, + (corpusSize, confidence), + buildOptions, + cancellationToken + ); } finally { @@ -286,7 +272,6 @@ CancellationToken cancellationToken private async Task TrainJobFaultedAsync( IDataAccessContext dataAccessContext, - IDistributedReaderWriterLockFactory lockFactory, IBuildJobService buildJobService, IPlatformService platformService, string engineId, @@ -297,23 +282,19 @@ CancellationToken cancellationToken { try { - IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - await dataAccessContext.WithTransactionAsync( - async (ct) => - { - await platformService.BuildFaultedAsync(buildId, message, ct); - await buildJobService.BuildJobFinishedAsync( - engineId, - buildId, - buildComplete: false, - CancellationToken.None - ); - }, - cancellationToken: cancellationToken - ); - } + await dataAccessContext.WithTransactionAsync( + async (ct) => + { + await platformService.BuildFaultedAsync(buildId, message, ct); + await buildJobService.BuildJobFinishedAsync( + engineId, + buildId, + buildComplete: false, + CancellationToken.None + ); + }, + cancellationToken: cancellationToken + ); _logger.LogError("Build faulted ({BuildId}). Error: {ErrorMessage}", buildId, message); } finally @@ -324,7 +305,6 @@ await buildJobService.BuildJobFinishedAsync( private async Task TrainJobCanceledAsync( IDataAccessContext dataAccessContext, - IDistributedReaderWriterLockFactory lockFactory, IBuildJobService buildJobService, IPlatformService platformService, string engineId, @@ -334,23 +314,19 @@ CancellationToken cancellationToken { try { - IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - await dataAccessContext.WithTransactionAsync( - async (ct) => - { - await platformService.BuildCanceledAsync(buildId, ct); - await buildJobService.BuildJobFinishedAsync( - engineId, - buildId, - buildComplete: false, - CancellationToken.None - ); - }, - cancellationToken: cancellationToken - ); - } + await dataAccessContext.WithTransactionAsync( + async (ct) => + { + await platformService.BuildCanceledAsync(buildId, ct); + await buildJobService.BuildJobFinishedAsync( + engineId, + buildId, + buildComplete: false, + CancellationToken.None + ); + }, + cancellationToken: cancellationToken + ); _logger.LogInformation("Build canceled ({BuildId})", buildId); } finally diff --git a/src/Machine/src/Serval.Machine.Shared/Services/DistributedReaderWriterLock.cs b/src/Machine/src/Serval.Machine.Shared/Services/DistributedReaderWriterLock.cs index 7ea8679f..d5a64e28 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/DistributedReaderWriterLock.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/DistributedReaderWriterLock.cs @@ -1,20 +1,71 @@ namespace Serval.Machine.Shared.Services; -public class DistributedReaderWriterLock(string hostId, IRepository locks, IIdGenerator idGenerator, string id) - : IDistributedReaderWriterLock +public class DistributedReaderWriterLock( + string hostId, + IRepository locks, + IIdGenerator idGenerator, + string id, + DistributedReaderWriterLockOptions lockOptions +) : IDistributedReaderWriterLock { private readonly string _hostId = hostId; private readonly IRepository _locks = locks; private readonly IIdGenerator _idGenerator = idGenerator; private readonly string _id = id; + private readonly DistributedReaderWriterLockOptions _lockOptions = lockOptions; - public async Task ReaderLockAsync( - TimeSpan? lifetime = default, + public Task ReaderLockAsync( + Func action, + TimeSpan? lifetime = null, CancellationToken cancellationToken = default ) { + return ReaderLockAsync( + async ct => + { + await action(ct); + return null; + }, + lifetime, + cancellationToken + ); + } + + public Task WriterLockAsync( + Func action, + TimeSpan? lifetime = null, + CancellationToken cancellationToken = default + ) + { + return WriterLockAsync( + async ct => + { + await action(ct); + return null; + }, + lifetime, + cancellationToken + ); + } + + public async Task ReaderLockAsync( + Func> action, + TimeSpan? lifetime = null, + CancellationToken cancellationToken = default + ) + { + if (lifetime < TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException( + nameof(lifetime), + "The lifetime must be greater than or equal to zero." + ); + } + + TimeSpan resolvedLifetime = lifetime ?? _lockOptions.DefaultLifetime; string lockId = _idGenerator.GenerateId(); - if (!await TryAcquireReaderLock(lockId, lifetime, cancellationToken)) + (bool acquired, DateTime expiresAt) = await TryAcquireReaderLock(lockId, resolvedLifetime, cancellationToken); + if (!acquired) { using ISubscription sub = await _locks.SubscribeAsync(rwl => rwl.Id == _id, cancellationToken); do @@ -32,18 +83,47 @@ public async Task ReaderLockAsync( if (timeout != TimeSpan.Zero) await sub.WaitForChangeAsync(timeout, cancellationToken); } - } while (!await TryAcquireReaderLock(lockId, lifetime, cancellationToken)); + (acquired, expiresAt) = await TryAcquireReaderLock(lockId, resolvedLifetime, cancellationToken); + } while (!acquired); + } + + try + { + (bool completed, T? result) = await TaskEx.Timeout(action, expiresAt - DateTime.UtcNow, cancellationToken); + if (!completed) + throw new TimeoutException($"A reader lock for the distributed lock '{_id}' expired."); + // if the task sucssfully completed, then the result will be populated + return result!; + } + finally + { + Expression> filter = rwl => rwl.Id == _id && rwl.ReaderLocks.Any(l => l.Id == lockId); + await _locks.UpdateAsync( + filter, + u => u.RemoveAll(rwl => rwl.ReaderLocks, l => l.Id == lockId), + cancellationToken: CancellationToken.None + ); } - return new ReaderLockReleaser(this, lockId); } - public async Task WriterLockAsync( - TimeSpan? lifetime = default, + public async Task WriterLockAsync( + Func> action, + TimeSpan? lifetime = null, CancellationToken cancellationToken = default ) { + if (lifetime < TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException( + nameof(lifetime), + "The lifetime must be greater than or equal to zero." + ); + } + + TimeSpan resolvedLifetime = lifetime ?? _lockOptions.DefaultLifetime; string lockId = _idGenerator.GenerateId(); - if (!await TryAcquireWriterLock(lockId, lifetime, cancellationToken)) + (bool acquired, DateTime expiresAt) = await TryAcquireWriterLock(lockId, resolvedLifetime, cancellationToken); + if (!acquired) { await _locks.UpdateAsync( _id, @@ -58,12 +138,9 @@ await _locks.UpdateAsync( RWLock? rwLock = sub.Change.Entity; if (rwLock is not null && !rwLock.IsAvailableForWriting(lockId)) { - var dateTimes = rwLock - .ReaderLocks.Where(l => l.ExpiresAt.HasValue) - .Select(l => l.ExpiresAt.GetValueOrDefault()) - .ToList(); + var dateTimes = rwLock.ReaderLocks.Select(l => l.ExpiresAt).ToList(); if (rwLock.WriterLock?.ExpiresAt is not null) - dateTimes.Add(rwLock.WriterLock.ExpiresAt.Value); + dateTimes.Add(rwLock.WriterLock.ExpiresAt); TimeSpan? timeout = default; if (dateTimes.Count > 0) { @@ -74,7 +151,8 @@ await _locks.UpdateAsync( if (timeout != TimeSpan.Zero) await sub.WaitForChangeAsync(timeout, cancellationToken); } - } while (!await TryAcquireWriterLock(lockId, lifetime, cancellationToken)); + (acquired, expiresAt) = await TryAcquireWriterLock(lockId, resolvedLifetime, cancellationToken); + } while (!acquired); } catch { @@ -86,20 +164,39 @@ await _locks.UpdateAsync( throw; } } - return new WriterLockReleaser(this, lockId); + + try + { + (bool completed, T? result) = await TaskEx.Timeout(action, expiresAt - DateTime.UtcNow, cancellationToken); + if (!completed) + throw new TimeoutException($"A writer lock for the distributed lock '{_id}' expired."); + // if the task sucssfully completed, then the result will be populated + return result!; + } + finally + { + Expression> filter = rwl => + rwl.Id == _id && rwl.WriterLock != null && rwl.WriterLock.Id == lockId; + await _locks.UpdateAsync( + filter, + u => u.Unset(rwl => rwl.WriterLock), + cancellationToken: CancellationToken.None + ); + } } - private async Task TryAcquireWriterLock( + private async Task<(bool, DateTime)> TryAcquireWriterLock( string lockId, - TimeSpan? lifetime, + TimeSpan lifetime, CancellationToken cancellationToken ) { - var now = DateTime.UtcNow; + DateTime now = DateTime.UtcNow; + DateTime expiresAt = now + lifetime; Expression> filter = rwl => rwl.Id == _id - && (rwl.WriterLock == null || rwl.WriterLock.ExpiresAt != null && rwl.WriterLock.ExpiresAt <= now) - && !rwl.ReaderLocks.Any(l => l.ExpiresAt == null || l.ExpiresAt > now) + && (rwl.WriterLock == null || rwl.WriterLock.ExpiresAt <= now) + && !rwl.ReaderLocks.Any(l => l.ExpiresAt > now) && (!rwl.WriterQueue.Any() || rwl.WriterQueue[0].Id == lockId); void Update(IUpdateBuilder u) { @@ -108,27 +205,26 @@ void Update(IUpdateBuilder u) new Lock { Id = lockId, - ExpiresAt = lifetime is null ? null : now + lifetime, + ExpiresAt = expiresAt, HostId = _hostId } ); u.RemoveAll(rwl => rwl.WriterQueue, l => l.Id == lockId); } RWLock? rwLock = await _locks.UpdateAsync(filter, Update, cancellationToken: cancellationToken); - return rwLock is not null; + return (rwLock is not null, expiresAt); } - private async Task TryAcquireReaderLock( + private async Task<(bool, DateTime)> TryAcquireReaderLock( string lockId, - TimeSpan? lifetime, + TimeSpan lifetime, CancellationToken cancellationToken ) { - var now = DateTime.UtcNow; + DateTime now = DateTime.UtcNow; + DateTime expiresAt = now + lifetime; Expression> filter = rwl => - rwl.Id == _id - && (rwl.WriterLock == null || rwl.WriterLock.ExpiresAt != null && rwl.WriterLock.ExpiresAt <= now) - && !rwl.WriterQueue.Any(); + rwl.Id == _id && (rwl.WriterLock == null || rwl.WriterLock.ExpiresAt <= now) && !rwl.WriterQueue.Any(); void Update(IUpdateBuilder u) { u.Add( @@ -136,42 +232,13 @@ void Update(IUpdateBuilder u) new Lock { Id = lockId, - ExpiresAt = lifetime is null ? null : now + lifetime, + ExpiresAt = expiresAt, HostId = _hostId } ); } RWLock? rwLock = await _locks.UpdateAsync(filter, Update, cancellationToken: cancellationToken); - return rwLock is not null; - } - - private class WriterLockReleaser(DistributedReaderWriterLock distributedLock, string lockId) : AsyncDisposableBase - { - private readonly DistributedReaderWriterLock _distributedLock = distributedLock; - private readonly string _lockId = lockId; - - protected override async ValueTask DisposeAsyncCore() - { - Expression> filter = rwl => - rwl.Id == _distributedLock._id && rwl.WriterLock != null && rwl.WriterLock.Id == _lockId; - await _distributedLock._locks.UpdateAsync(filter, u => u.Unset(rwl => rwl.WriterLock)); - } - } - - private class ReaderLockReleaser(DistributedReaderWriterLock distributedLock, string lockId) : AsyncDisposableBase - { - private readonly DistributedReaderWriterLock _distributedLock = distributedLock; - private readonly string _lockId = lockId; - - protected override async ValueTask DisposeAsyncCore() - { - Expression> filter = rwl => - rwl.Id == _distributedLock._id && rwl.ReaderLocks.Any(l => l.Id == _lockId); - await _distributedLock._locks.UpdateAsync( - filter, - u => u.RemoveAll(rwl => rwl.ReaderLocks, l => l.Id == _lockId) - ); - } + return (rwLock is not null, expiresAt); } } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/DistributedReaderWriterLockFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/DistributedReaderWriterLockFactory.cs index 81810fb1..e0d44795 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/DistributedReaderWriterLockFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/DistributedReaderWriterLockFactory.cs @@ -2,11 +2,13 @@ public class DistributedReaderWriterLockFactory( IOptions serviceOptions, + IOptions lockOptions, IRepository locks, IIdGenerator idGenerator ) : IDistributedReaderWriterLockFactory { private readonly ServiceOptions _serviceOptions = serviceOptions.Value; + private readonly DistributedReaderWriterLockOptions _lockOptions = lockOptions.Value; private readonly IIdGenerator _idGenerator = idGenerator; private readonly IRepository _locks = locks; @@ -39,7 +41,7 @@ await _locks.InsertAsync( // the lock is already made - no new one needs to be made // This is done instead of checking if it exists first to prevent race conditions. } - return new DistributedReaderWriterLock(_serviceOptions.ServiceId, _locks, _idGenerator, id); + return new DistributedReaderWriterLock(_serviceOptions.ServiceId, _locks, _idGenerator, id, _lockOptions); } public async Task DeleteAsync(string id, CancellationToken cancellationToken = default) diff --git a/src/Machine/src/Serval.Machine.Shared/Services/HangfireBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/HangfireBuildJob.cs index 26fe58ed..13fc9add 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/HangfireBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/HangfireBuildJob.cs @@ -3,11 +3,10 @@ public abstract class HangfireBuildJob( IPlatformService platformService, IRepository engines, - IDistributedReaderWriterLockFactory lockFactory, IDataAccessContext dataAccessContext, IBuildJobService buildJobService, ILogger logger -) : HangfireBuildJob(platformService, engines, lockFactory, dataAccessContext, buildJobService, logger) +) : HangfireBuildJob(platformService, engines, dataAccessContext, buildJobService, logger) { public virtual Task RunAsync( string engineId, @@ -23,7 +22,6 @@ CancellationToken cancellationToken public abstract class HangfireBuildJob( IPlatformService platformService, IRepository engines, - IDistributedReaderWriterLockFactory lockFactory, IDataAccessContext dataAccessContext, IBuildJobService buildJobService, ILogger> logger @@ -31,7 +29,6 @@ ILogger> logger { protected IPlatformService PlatformService { get; } = platformService; protected IRepository Engines { get; } = engines; - protected IDistributedReaderWriterLockFactory LockFactory { get; } = lockFactory; protected IDataAccessContext DataAccessContext { get; } = dataAccessContext; protected IBuildJobService BuildJobService { get; } = buildJobService; protected ILogger> Logger { get; } = logger; @@ -44,21 +41,17 @@ public virtual async Task RunAsync( CancellationToken cancellationToken ) { - IDistributedReaderWriterLock @lock = await LockFactory.CreateAsync(engineId, cancellationToken); JobCompletionStatus completionStatus = JobCompletionStatus.Completed; try { - await InitializeAsync(engineId, buildId, data, @lock, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) + await InitializeAsync(engineId, buildId, data, cancellationToken); + if (!await BuildJobService.BuildJobStartedAsync(engineId, buildId, cancellationToken)) { - if (!await BuildJobService.BuildJobStartedAsync(engineId, buildId, cancellationToken)) - { - completionStatus = JobCompletionStatus.Canceled; - return; - } + completionStatus = JobCompletionStatus.Canceled; + return; } - await DoWorkAsync(engineId, buildId, data, buildOptions, @lock, cancellationToken); + await DoWorkAsync(engineId, buildId, data, buildOptions, cancellationToken); } catch (OperationCanceledException) { @@ -70,22 +63,19 @@ CancellationToken cancellationToken if (engine?.CurrentBuild?.JobState is BuildJobState.Canceling) { completionStatus = JobCompletionStatus.Canceled; - await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) - { - await DataAccessContext.WithTransactionAsync( - async (ct) => - { - await PlatformService.BuildCanceledAsync(buildId, CancellationToken.None); - await BuildJobService.BuildJobFinishedAsync( - engineId, - buildId, - buildComplete: false, - CancellationToken.None - ); - }, - cancellationToken: CancellationToken.None - ); - } + await DataAccessContext.WithTransactionAsync( + async (ct) => + { + await PlatformService.BuildCanceledAsync(buildId, CancellationToken.None); + await BuildJobService.BuildJobFinishedAsync( + engineId, + buildId, + buildComplete: false, + CancellationToken.None + ); + }, + cancellationToken: CancellationToken.None + ); Logger.LogInformation("Build canceled ({0})", buildId); } else if (engine is not null) @@ -93,17 +83,14 @@ await BuildJobService.BuildJobFinishedAsync( // the build was canceled, because of a server shutdown // switch state back to pending completionStatus = JobCompletionStatus.Restarting; - await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) - { - await DataAccessContext.WithTransactionAsync( - async (ct) => - { - await PlatformService.BuildRestartingAsync(buildId, CancellationToken.None); - await BuildJobService.BuildJobRestartingAsync(engineId, buildId, CancellationToken.None); - }, - cancellationToken: CancellationToken.None - ); - } + await DataAccessContext.WithTransactionAsync( + async (ct) => + { + await PlatformService.BuildRestartingAsync(buildId, CancellationToken.None); + await BuildJobService.BuildJobRestartingAsync(engineId, buildId, CancellationToken.None); + }, + cancellationToken: CancellationToken.None + ); throw; } else @@ -114,38 +101,29 @@ await DataAccessContext.WithTransactionAsync( catch (Exception e) { completionStatus = JobCompletionStatus.Faulted; - await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) - { - await DataAccessContext.WithTransactionAsync( - async (ct) => - { - await PlatformService.BuildFaultedAsync(buildId, e.Message, CancellationToken.None); - await BuildJobService.BuildJobFinishedAsync( - engineId, - buildId, - buildComplete: false, - CancellationToken.None - ); - }, - cancellationToken: CancellationToken.None - ); - } + await DataAccessContext.WithTransactionAsync( + async (ct) => + { + await PlatformService.BuildFaultedAsync(buildId, e.Message, CancellationToken.None); + await BuildJobService.BuildJobFinishedAsync( + engineId, + buildId, + buildComplete: false, + CancellationToken.None + ); + }, + cancellationToken: CancellationToken.None + ); Logger.LogError(0, e, "Build faulted ({0})", buildId); throw; } finally { - await CleanupAsync(engineId, buildId, data, @lock, completionStatus); + await CleanupAsync(engineId, buildId, data, completionStatus); } } - protected virtual Task InitializeAsync( - string engineId, - string buildId, - T data, - IDistributedReaderWriterLock @lock, - CancellationToken cancellationToken - ) + protected virtual Task InitializeAsync(string engineId, string buildId, T data, CancellationToken cancellationToken) { return Task.CompletedTask; } @@ -155,17 +133,10 @@ protected abstract Task DoWorkAsync( string buildId, T data, string? buildOptions, - IDistributedReaderWriterLock @lock, CancellationToken cancellationToken ); - protected virtual Task CleanupAsync( - string engineId, - string buildId, - T data, - IDistributedReaderWriterLock @lock, - JobCompletionStatus completionStatus - ) + protected virtual Task CleanupAsync(string engineId, string buildId, T data, JobCompletionStatus completionStatus) { return Task.CompletedTask; } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/IBuildJobService.cs b/src/Machine/src/Serval.Machine.Shared/Services/IBuildJobService.cs index c9ddf983..61c6122e 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/IBuildJobService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/IBuildJobService.cs @@ -14,7 +14,8 @@ Task> GetBuildingEnginesAsync( Task DeleteEngineAsync(string engineId, CancellationToken cancellationToken = default); Task StartBuildJobAsync( - BuildJobRunnerType jobType, + BuildJobRunnerType runnerType, + TranslationEngineType engineType, string engineId, string buildId, BuildStage stage, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/IDistributedReaderWriterLock.cs b/src/Machine/src/Serval.Machine.Shared/Services/IDistributedReaderWriterLock.cs index 026aff28..7edf79f7 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/IDistributedReaderWriterLock.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/IDistributedReaderWriterLock.cs @@ -2,6 +2,25 @@ public interface IDistributedReaderWriterLock { - Task ReaderLockAsync(TimeSpan? lifetime = default, CancellationToken cancellationToken = default); - Task WriterLockAsync(TimeSpan? lifetime = default, CancellationToken cancellationToken = default); + Task ReaderLockAsync( + Func action, + TimeSpan? lifetime = default, + CancellationToken cancellationToken = default + ); + Task WriterLockAsync( + Func action, + TimeSpan? lifetime = default, + CancellationToken cancellationToken = default + ); + + Task ReaderLockAsync( + Func> action, + TimeSpan? lifetime = default, + CancellationToken cancellationToken = default + ); + Task WriterLockAsync( + Func> action, + TimeSpan? lifetime = default, + CancellationToken cancellationToken = default + ); } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ISmtModelFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/ISmtModelFactory.cs index 6612e11e..01776084 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ISmtModelFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ISmtModelFactory.cs @@ -2,21 +2,19 @@ public interface ISmtModelFactory { - Task CreateAsync( + IInteractiveTranslationModel Create( string engineDir, IRangeTokenizer tokenizer, IDetokenizer detokenizer, - ITruecaser truecaser, - CancellationToken cancellationToken = default + ITruecaser truecaser ); - Task CreateTrainerAsync( + ITrainer CreateTrainer( string engineDir, IRangeTokenizer tokenizer, - IParallelTextCorpus corpus, - CancellationToken cancellationToken = default + IParallelTextCorpus corpus ); - Task InitNewAsync(string engineDir, CancellationToken cancellationToken = default); - Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default); + void InitNew(string engineDir); + void Cleanup(string engineDir); Task UpdateEngineFromAsync(string engineDir, Stream source, CancellationToken cancellationToken = default); Task SaveEngineToAsync(string engineDir, Stream destination, CancellationToken cancellationToken = default); } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ITransferEngineFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/ITransferEngineFactory.cs index c76b8e91..7ac3eb5b 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ITransferEngineFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ITransferEngineFactory.cs @@ -2,13 +2,12 @@ public interface ITransferEngineFactory { - Task CreateAsync( + ITranslationEngine? Create( string engineDir, IRangeTokenizer tokenizer, IDetokenizer detokenizer, - ITruecaser truecaser, - CancellationToken cancellationToken = default + ITruecaser truecaser ); - Task InitNewAsync(string engineDir, CancellationToken cancellationToken = default); - Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default); + void InitNew(string engineDir); + void Cleanup(string engineDir); } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ITruecaserFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/ITruecaserFactory.cs index e83337d3..c4470925 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ITruecaserFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ITruecaserFactory.cs @@ -2,12 +2,7 @@ public interface ITruecaserFactory { - Task CreateAsync(string engineDir, CancellationToken cancellationToken = default); - Task CreateTrainerAsync( - string engineDir, - ITokenizer tokenizer, - ITextCorpus corpus, - CancellationToken cancellationToken = default - ); - Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default); + ITruecaser Create(string engineDir); + ITrainer CreateTrainer(string engineDir, ITokenizer tokenizer, ITextCorpus corpus); + void Cleanup(string engineDir); } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/NmtEngineService.cs b/src/Machine/src/Serval.Machine.Shared/Services/NmtEngineService.cs index 5a2fb912..fc1c2c95 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/NmtEngineService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/NmtEngineService.cs @@ -2,7 +2,6 @@ public class NmtEngineService( IPlatformService platformService, - IDistributedReaderWriterLockFactory lockFactory, IDataAccessContext dataAccessContext, IRepository engines, IBuildJobService buildJobService, @@ -11,7 +10,6 @@ public class NmtEngineService( ISharedFileService sharedFileService ) : ITranslationEngineService { - private readonly IDistributedReaderWriterLockFactory _lockFactory = lockFactory; private readonly IPlatformService _platformService = platformService; private readonly IDataAccessContext _dataAccessContext = dataAccessContext; private readonly IRepository _engines = engines; @@ -61,15 +59,10 @@ public async Task CreateAsync( public async Task DeleteAsync(string engineId, CancellationToken cancellationToken = default) { - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - await CancelBuildJobAsync(engineId, cancellationToken); + await CancelBuildJobAsync(engineId, cancellationToken); - await _engines.DeleteAsync(e => e.EngineId == engineId, cancellationToken); - await _buildJobService.DeleteEngineAsync(engineId, CancellationToken.None); - } - await _lockFactory.DeleteAsync(engineId, CancellationToken.None); + await _engines.DeleteAsync(e => e.EngineId == engineId, cancellationToken); + await _buildJobService.DeleteEngineAsync(engineId, CancellationToken.None); } public async Task StartBuildAsync( @@ -80,33 +73,26 @@ public async Task StartBuildAsync( CancellationToken cancellationToken = default ) { - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - // If there is a pending/running build, then no need to start a new one. - if (await _buildJobService.IsEngineBuilding(engineId, cancellationToken)) - throw new InvalidOperationException("The engine is already building or in the process of canceling."); - - await _buildJobService.StartBuildJobAsync( - BuildJobRunnerType.Hangfire, - engineId, - buildId, - BuildStage.Preprocess, - corpora, - buildOptions, - cancellationToken - ); - } + bool building = !await _buildJobService.StartBuildJobAsync( + BuildJobRunnerType.Hangfire, + TranslationEngineType.Nmt, + engineId, + buildId, + BuildStage.Preprocess, + corpora, + buildOptions, + cancellationToken + ); + // If there is a pending/running build, then no need to start a new one. + if (building) + throw new InvalidOperationException("The engine is already building or in the process of canceling."); } public async Task CancelBuildAsync(string engineId, CancellationToken cancellationToken = default) { - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - if (!await CancelBuildJobAsync(engineId, cancellationToken)) - throw new InvalidOperationException("The engine is not currently building."); - } + bool building = await CancelBuildJobAsync(engineId, cancellationToken); + if (!building) + throw new InvalidOperationException("The engine is not currently building."); } public async Task GetModelDownloadUrlAsync( diff --git a/src/Machine/src/Serval.Machine.Shared/Services/NmtPreprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/NmtPreprocessBuildJob.cs index b4c61648..3c46a34e 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/NmtPreprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/NmtPreprocessBuildJob.cs @@ -3,7 +3,6 @@ public class NmtPreprocessBuildJob( IPlatformService platformService, IRepository engines, - IDistributedReaderWriterLockFactory lockFactory, IDataAccessContext dataAccessContext, ILogger logger, IBuildJobService buildJobService, @@ -14,7 +13,6 @@ ILanguageTagService languageTagService : PreprocessBuildJob( platformService, engines, - lockFactory, dataAccessContext, logger, buildJobService, diff --git a/src/Machine/src/Serval.Machine.Shared/Services/PostprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/PostprocessBuildJob.cs index ff96570d..8237295a 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/PostprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/PostprocessBuildJob.cs @@ -3,23 +3,19 @@ public class PostprocessBuildJob( IPlatformService platformService, IRepository engines, - IDistributedReaderWriterLockFactory lockFactory, IDataAccessContext dataAccessContext, IBuildJobService buildJobService, ILogger logger, - ISharedFileService sharedFileService, - IOptionsMonitor options -) : HangfireBuildJob<(int, double)>(platformService, engines, lockFactory, dataAccessContext, buildJobService, logger) + ISharedFileService sharedFileService +) : HangfireBuildJob<(int, double)>(platformService, engines, dataAccessContext, buildJobService, logger) { protected ISharedFileService SharedFileService { get; } = sharedFileService; - private readonly BuildJobOptions _options = options.CurrentValue; protected override async Task DoWorkAsync( string engineId, string buildId, (int, double) data, string? buildOptions, - IDistributedReaderWriterLock @lock, CancellationToken cancellationToken ) { @@ -35,28 +31,20 @@ CancellationToken cancellationToken await PlatformService.InsertPretranslationsAsync(engineId, pretranslationsStream, cancellationToken); } - await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) - { - await DataAccessContext.WithTransactionAsync( - async (ct) => - { - int additionalCorpusSize = await SaveModelAsync(engineId, buildId); - await PlatformService.BuildCompletedAsync( - buildId, - corpusSize + additionalCorpusSize, - Math.Round(confidence, 2, MidpointRounding.AwayFromZero), - CancellationToken.None - ); - await BuildJobService.BuildJobFinishedAsync( - engineId, - buildId, - buildComplete: true, - CancellationToken.None - ); - }, - cancellationToken: CancellationToken.None - ); - } + int additionalCorpusSize = await SaveModelAsync(engineId, buildId); + await DataAccessContext.WithTransactionAsync( + async (ct) => + { + await PlatformService.BuildCompletedAsync( + buildId, + corpusSize + additionalCorpusSize, + Math.Round(confidence, 2, MidpointRounding.AwayFromZero), + ct + ); + await BuildJobService.BuildJobFinishedAsync(engineId, buildId, buildComplete: true, ct); + }, + cancellationToken: CancellationToken.None + ); Logger.LogInformation("Build completed ({0}).", buildId); } @@ -70,14 +58,13 @@ protected override async Task CleanupAsync( string engineId, string buildId, (int, double) data, - IDistributedReaderWriterLock @lock, JobCompletionStatus completionStatus ) { if (completionStatus is JobCompletionStatus.Restarting) return; - if (_options.PreserveBuildFiles) + if (_buildJobOptions.PreserveBuildFiles) return; try diff --git a/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs index d15e5a69..97e5fc77 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs @@ -14,14 +14,13 @@ public class PreprocessBuildJob : HangfireBuildJob> public PreprocessBuildJob( IPlatformService platformService, IRepository engines, - IDistributedReaderWriterLockFactory lockFactory, IDataAccessContext dataAccessContext, ILogger logger, IBuildJobService buildJobService, ISharedFileService sharedFileService, ICorpusService corpusService ) - : base(platformService, engines, lockFactory, dataAccessContext, buildJobService, logger) + : base(platformService, engines, dataAccessContext, buildJobService, logger) { _sharedFileService = sharedFileService; _corpusService = corpusService; @@ -46,7 +45,6 @@ protected override async Task DoWorkAsync( string buildId, IReadOnlyList data, string? buildOptions, - IDistributedReaderWriterLock @lock, CancellationToken cancellationToken ) { @@ -86,19 +84,17 @@ CancellationToken cancellationToken cancellationToken.ThrowIfCancellationRequested(); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - bool canceling = !await BuildJobService.StartBuildJobAsync( - TrainJobRunnerType, - engineId, - buildId, - BuildStage.Train, - buildOptions: buildOptions, - cancellationToken: cancellationToken - ); - if (canceling) - throw new OperationCanceledException(); - } + bool canceling = !await BuildJobService.StartBuildJobAsync( + TrainJobRunnerType, + engine.Type, + engineId, + buildId, + BuildStage.Train, + buildOptions: buildOptions, + cancellationToken: cancellationToken + ); + if (canceling) + throw new OperationCanceledException(); } private async Task<(int TrainCount, int PretranslateCount)> WriteDataFilesAsync( @@ -209,7 +205,6 @@ protected override async Task CleanupAsync( string engineId, string buildId, IReadOnlyList data, - IDistributedReaderWriterLock @lock, JobCompletionStatus completionStatus ) { diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferBuildJob.cs deleted file mode 100644 index c83f0703..00000000 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferBuildJob.cs +++ /dev/null @@ -1,166 +0,0 @@ -namespace Serval.Machine.Shared.Services; - -public class SmtTransferBuildJob( - IPlatformService platformService, - IRepository engines, - IDistributedReaderWriterLockFactory lockFactory, - IDataAccessContext dataAccessContext, - IBuildJobService buildJobService, - ILogger logger, - IRepository trainSegmentPairs, - ITruecaserFactory truecaserFactory, - ISmtModelFactory smtModelFactory, - ICorpusService corpusService -) - : HangfireBuildJob>( - platformService, - engines, - lockFactory, - dataAccessContext, - buildJobService, - logger - ) -{ - private readonly IRepository _trainSegmentPairs = trainSegmentPairs; - private readonly ITruecaserFactory _truecaserFactory = truecaserFactory; - private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; - private readonly ICorpusService _corpusService = corpusService; - - protected override Task InitializeAsync( - string engineId, - string buildId, - IReadOnlyList data, - IDistributedReaderWriterLock @lock, - CancellationToken cancellationToken - ) - { - return _trainSegmentPairs.DeleteAllAsync(p => p.TranslationEngineRef == engineId, cancellationToken); - } - - protected override async Task DoWorkAsync( - string engineId, - string buildId, - IReadOnlyList data, - string? buildOptions, - IDistributedReaderWriterLock @lock, - CancellationToken cancellationToken - ) - { - await PlatformService.BuildStartedAsync(buildId, cancellationToken); - Logger.LogInformation("Build started ({0})", buildId); - var stopwatch = new Stopwatch(); - stopwatch.Start(); - - cancellationToken.ThrowIfCancellationRequested(); - - JsonObject? buildOptionsObject = null; - if (buildOptions is not null) - buildOptionsObject = JsonSerializer.Deserialize(buildOptions); - - var targetCorpora = new List(); - var parallelCorpora = new List(); - foreach (Corpus corpus in data) - { - ITextCorpus? sourceTextCorpus = _corpusService.CreateTextCorpora(corpus.SourceFiles).FirstOrDefault(); - ITextCorpus? targetTextCorpus = _corpusService.CreateTextCorpora(corpus.TargetFiles).FirstOrDefault(); - if (sourceTextCorpus is null || targetTextCorpus is null) - continue; - - targetCorpora.Add(targetTextCorpus); - parallelCorpora.Add(sourceTextCorpus.AlignRows(targetTextCorpus)); - - if ((bool?)buildOptionsObject?["use_key_terms"] ?? true) - { - ITextCorpus? sourceTermCorpus = _corpusService.CreateTermCorpora(corpus.SourceFiles).FirstOrDefault(); - ITextCorpus? targetTermCorpus = _corpusService.CreateTermCorpora(corpus.TargetFiles).FirstOrDefault(); - if (sourceTermCorpus is not null && targetTermCorpus is not null) - { - IParallelTextCorpus parallelKeyTermsCorpus = sourceTermCorpus.AlignRows(targetTermCorpus); - parallelCorpora.Add(parallelKeyTermsCorpus); - } - } - } - - IParallelTextCorpus parallelCorpus = parallelCorpora.Flatten(); - ITextCorpus targetCorpus = targetCorpora.Flatten(); - - var tokenizer = new LatinWordTokenizer(); - var detokenizer = new LatinWordDetokenizer(); - - using ITrainer smtModelTrainer = await _smtModelFactory.CreateTrainerAsync( - engineId, - tokenizer, - parallelCorpus, - cancellationToken - ); - using ITrainer truecaseTrainer = await _truecaserFactory.CreateTrainerAsync( - engineId, - tokenizer, - targetCorpus, - cancellationToken - ); - - cancellationToken.ThrowIfCancellationRequested(); - - var progress = new BuildProgress(PlatformService, buildId); - await smtModelTrainer.TrainAsync(progress, cancellationToken); - await truecaseTrainer.TrainAsync(cancellationToken: cancellationToken); - - TranslationEngine? engine = await Engines.GetAsync(e => e.EngineId == engineId, cancellationToken); - if (engine is null) - throw new OperationCanceledException(); - - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - cancellationToken.ThrowIfCancellationRequested(); - await smtModelTrainer.SaveAsync(CancellationToken.None); - await truecaseTrainer.SaveAsync(CancellationToken.None); - ITruecaser truecaser = await _truecaserFactory.CreateAsync(engineId, CancellationToken.None); - IReadOnlyList segmentPairs = await _trainSegmentPairs.GetAllAsync( - p => p.TranslationEngineRef == engine.Id, - CancellationToken.None - ); - using ( - IInteractiveTranslationModel smtModel = await _smtModelFactory.CreateAsync( - engineId, - tokenizer, - detokenizer, - truecaser, - CancellationToken.None - ) - ) - { - foreach (TrainSegmentPair segmentPair in segmentPairs) - { - await smtModel.TrainSegmentAsync( - segmentPair.Source, - segmentPair.Target, - cancellationToken: CancellationToken.None - ); - } - } - - await DataAccessContext.WithTransactionAsync( - async (ct) => - { - await PlatformService.BuildCompletedAsync( - buildId, - smtModelTrainer.Stats.TrainCorpusSize + segmentPairs.Count, - smtModelTrainer.Stats.Metrics["bleu"] * 100.0, - CancellationToken.None - ); - await BuildJobService.BuildJobFinishedAsync( - engineId, - buildId, - buildComplete: true, - CancellationToken.None - ); - }, - cancellationToken: CancellationToken.None - ); - } - - stopwatch.Stop(); - Logger.LogInformation("Build completed in {0}s ({1})", stopwatch.Elapsed.TotalSeconds, buildId); - } -} diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineService.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineService.cs index bdda5353..5789d67d 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineService.cs @@ -40,7 +40,7 @@ public async Task CreateAsync( } TranslationEngine translationEngine = await _dataAccessContext.WithTransactionAsync( - async (ct) => + async ct => { var translationEngine = new TranslationEngine { @@ -57,38 +57,30 @@ public async Task CreateAsync( cancellationToken: cancellationToken ); - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, CancellationToken.None); - await using (await @lock.WriterLockAsync(cancellationToken: CancellationToken.None)) - { - SmtTransferEngineState state = _stateService.Get(engineId); - await state.InitNewAsync(CancellationToken.None); - } + SmtTransferEngineState state = _stateService.Get(engineId); + state.InitNew(); return translationEngine; } public async Task DeleteAsync(string engineId, CancellationToken cancellationToken = default) { - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - await CancelBuildJobAsync(engineId, cancellationToken); - - await _dataAccessContext.WithTransactionAsync( - async (ct) => - { - await _engines.DeleteAsync(e => e.EngineId == engineId, ct); - await _trainSegmentPairs.DeleteAllAsync(p => p.TranslationEngineRef == engineId, ct); - }, - cancellationToken: cancellationToken - ); - await _buildJobService.DeleteEngineAsync(engineId, CancellationToken.None); + await CancelBuildJobAsync(engineId, cancellationToken); - if (_stateService.TryRemove(engineId, out SmtTransferEngineState? state)) + await _dataAccessContext.WithTransactionAsync( + async ct => { - await state.DeleteDataAsync(); - await state.DisposeAsync(); - } - } + await _engines.DeleteAsync(e => e.EngineId == engineId, ct); + await _trainSegmentPairs.DeleteAllAsync(p => p.TranslationEngineRef == engineId, ct); + }, + cancellationToken: cancellationToken + ); + await _buildJobService.DeleteEngineAsync(engineId, CancellationToken.None); + + SmtTransferEngineState state = _stateService.Get(engineId); + _stateService.Remove(engineId); + // there is no way to cancel this call + state.DeleteData(); + state.Dispose(); await _lockFactory.DeleteAsync(engineId, CancellationToken.None); } @@ -99,16 +91,22 @@ public async Task> TranslateAsync( CancellationToken cancellationToken = default ) { + TranslationEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken); + SmtTransferEngineState state = _stateService.Get(engineId); + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.ReaderLockAsync(cancellationToken: cancellationToken)) - { - TranslationEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken); - SmtTransferEngineState state = _stateService.Get(engineId); - HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision); - IReadOnlyList results = await hybridEngine.TranslateAsync(n, segment, cancellationToken); - state.LastUsedTime = DateTime.Now; - return results; - } + IReadOnlyList results = await @lock.ReaderLockAsync( + async ct => + { + HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision, ct); + // there is no way to cancel this call + return hybridEngine.Translate(n, segment); + }, + cancellationToken: cancellationToken + ); + + state.Touch(); + return results; } public async Task GetWordGraphAsync( @@ -117,16 +115,22 @@ public async Task GetWordGraphAsync( CancellationToken cancellationToken = default ) { + TranslationEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken); + SmtTransferEngineState state = _stateService.Get(engineId); + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.ReaderLockAsync(cancellationToken: cancellationToken)) - { - TranslationEngine engine = await GetBuiltEngineAsync(engineId, cancellationToken); - SmtTransferEngineState state = _stateService.Get(engineId); - HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision); - WordGraph result = await hybridEngine.GetWordGraphAsync(segment, cancellationToken); - state.LastUsedTime = DateTime.Now; - return result; - } + WordGraph result = await @lock.ReaderLockAsync( + async ct => + { + HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision, ct); + // there is no way to cancel this call + return hybridEngine.GetWordGraph(segment); + }, + cancellationToken: cancellationToken + ); + + state.Touch(); + return result; } public async Task TrainSegmentPairAsync( @@ -137,47 +141,39 @@ public async Task TrainSegmentPairAsync( CancellationToken cancellationToken = default ) { - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - TranslationEngine engine = await GetEngineAsync(engineId, cancellationToken); + SmtTransferEngineState state = _stateService.Get(engineId); - async Task TrainSubroutineAsync(SmtTransferEngineState state, CancellationToken ct) + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); + await @lock.WriterLockAsync( + async ct => { - HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision); - await hybridEngine.TrainSegmentAsync(sourceSegment, targetSegment, sentenceStart, ct); - await _platformService.IncrementTrainSizeAsync(engineId, cancellationToken: CancellationToken.None); - } + TranslationEngine engine = await GetEngineAsync(engineId, ct); + + HybridTranslationEngine hybridEngine = await state.GetHybridEngineAsync(engine.BuildRevision, ct); + // there is no way to cancel this call + hybridEngine.TrainSegment(sourceSegment, targetSegment, sentenceStart); - SmtTransferEngineState state = _stateService.Get(engineId); - await _dataAccessContext.WithTransactionAsync( - async (ct) => + if (engine.CollectTrainSegmentPairs ?? false) { - if (engine.CurrentBuild?.JobState is BuildJobState.Active) - { - await _trainSegmentPairs.InsertAsync( - new TrainSegmentPair - { - TranslationEngineRef = engineId, - Source = sourceSegment, - Target = targetSegment, - SentenceStart = sentenceStart - }, - CancellationToken.None - ); - await TrainSubroutineAsync(state, CancellationToken.None); - } - else - { - await TrainSubroutineAsync(state, ct); - } - }, - cancellationToken: cancellationToken - ); + await _trainSegmentPairs.InsertAsync( + new TrainSegmentPair + { + TranslationEngineRef = engineId, + Source = sourceSegment, + Target = targetSegment, + SentenceStart = sentenceStart + }, + CancellationToken.None + ); + } - state.IsUpdated = true; - state.LastUsedTime = DateTime.Now; - } + state.IsUpdated = true; + }, + cancellationToken: cancellationToken + ); + + await _platformService.IncrementTrainSizeAsync(engineId, cancellationToken: CancellationToken.None); + state.Touch(); } public async Task StartBuildAsync( @@ -188,37 +184,32 @@ public async Task StartBuildAsync( CancellationToken cancellationToken = default ) { - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - // If there is a pending/running build, then no need to start a new one. - if (await _buildJobService.IsEngineBuilding(engineId, cancellationToken)) - throw new InvalidOperationException("The engine is already building or in the process of canceling."); + bool building = !await _buildJobService.StartBuildJobAsync( + BuildJobRunnerType.Hangfire, + TranslationEngineType.SmtTransfer, + engineId, + buildId, + BuildStage.Preprocess, + corpora, + buildOptions, + cancellationToken + ); + // If there is a pending/running build, then no need to start a new one. + if (building) + throw new InvalidOperationException("The engine is already building or in the process of canceling."); - await _buildJobService.StartBuildJobAsync( - BuildJobRunnerType.Hangfire, - engineId, - buildId, - BuildStage.Preprocess, - corpora, - buildOptions, - cancellationToken - ); - SmtTransferEngineState state = _stateService.Get(engineId); - state.LastUsedTime = DateTime.UtcNow; - } + SmtTransferEngineState state = _stateService.Get(engineId); + state.Touch(); } public async Task CancelBuildAsync(string engineId, CancellationToken cancellationToken = default) { - IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - if (!await CancelBuildJobAsync(engineId, cancellationToken)) - throw new InvalidOperationException("The engine is not currently building."); - SmtTransferEngineState state = _stateService.Get(engineId); - state.LastUsedTime = DateTime.UtcNow; - } + bool building = await CancelBuildJobAsync(engineId, cancellationToken); + if (!building) + throw new InvalidOperationException("The engine is not currently building."); + + SmtTransferEngineState state = _stateService.Get(engineId); + state.Touch(); } public int GetQueueSize() @@ -235,7 +226,7 @@ private async Task CancelBuildJobAsync(string engineId, CancellationToken { string? buildId = null; await _dataAccessContext.WithTransactionAsync( - async (ct) => + async ct => { (buildId, BuildJobState jobState) = await _buildJobService.CancelBuildJobAsync(engineId, ct); if (buildId is not null && jobState is BuildJobState.None) diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineState.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineState.cs index a5f4300a..a0072b3d 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineState.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineState.cs @@ -1,4 +1,6 @@ -namespace Serval.Machine.Shared.Services; +using SIL.ObjectModel; + +namespace Serval.Machine.Shared.Services; public class SmtTransferEngineState( ISmtModelFactory smtModelFactory, @@ -6,7 +8,7 @@ public class SmtTransferEngineState( ITruecaserFactory truecaserFactory, IOptionsMonitor options, string engineId -) : AsyncDisposableBase +) : DisposableBase { private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; private readonly ITransferEngineFactory _transferEngineFactory = transferEngineFactory; @@ -21,34 +23,37 @@ string engineId public bool IsUpdated { get; set; } public int CurrentBuildRevision { get; set; } = -1; - public DateTime LastUsedTime { get; set; } = DateTime.UtcNow; + public DateTime LastUsedTime { get; private set; } = DateTime.UtcNow; public bool IsLoaded => _hybridEngine != null; private string EngineDir => Path.Combine(_options.CurrentValue.EnginesDir, EngineId); - public async Task InitNewAsync(CancellationToken cancellationToken = default) + public void InitNew() { - await _smtModelFactory.InitNewAsync(EngineDir, cancellationToken); - await _transferEngineFactory.InitNewAsync(EngineDir, cancellationToken); + _smtModelFactory.InitNew(EngineDir); + _transferEngineFactory.InitNew(EngineDir); } - public async Task GetHybridEngineAsync(int buildRevision) + public async Task GetHybridEngineAsync( + int buildRevision, + CancellationToken cancellationToken = default + ) { - using (await _lock.LockAsync()) + using (await _lock.LockAsync(cancellationToken)) { if (_hybridEngine is not null && CurrentBuildRevision != -1 && buildRevision != CurrentBuildRevision) { IsUpdated = false; - await UnloadAsync(); + Unload(); } if (_hybridEngine is null) { LatinWordTokenizer tokenizer = new(); LatinWordDetokenizer detokenizer = new(); - ITruecaser truecaser = await _truecaserFactory.CreateAsync(EngineDir); - _smtModel = await _smtModelFactory.CreateAsync(EngineDir, tokenizer, detokenizer, truecaser); - ITranslationEngine? transferEngine = await _transferEngineFactory.CreateAsync( + ITruecaser truecaser = _truecaserFactory.Create(EngineDir); + _smtModel = _smtModelFactory.Create(EngineDir, tokenizer, detokenizer, truecaser); + ITranslationEngine? transferEngine = _transferEngineFactory.Create( EngineDir, tokenizer, detokenizer, @@ -64,19 +69,15 @@ public async Task GetHybridEngineAsync(int buildRevisio } } - public async Task DeleteDataAsync() + public void DeleteData() { - await UnloadAsync(); - await _smtModelFactory.CleanupAsync(EngineDir); - await _transferEngineFactory.CleanupAsync(EngineDir); - await _truecaserFactory.CleanupAsync(EngineDir); + Unload(); + _smtModelFactory.Cleanup(EngineDir); + _transferEngineFactory.Cleanup(EngineDir); + _truecaserFactory.Cleanup(EngineDir); } - public async Task CommitAsync( - int buildRevision, - TimeSpan inactiveTimeout, - CancellationToken cancellationToken = default - ) + public void Commit(int buildRevision, TimeSpan inactiveTimeout) { if (_hybridEngine is null) return; @@ -85,34 +86,39 @@ public async Task CommitAsync( CurrentBuildRevision = buildRevision; if (buildRevision != CurrentBuildRevision) { - await UnloadAsync(cancellationToken); + Unload(); CurrentBuildRevision = buildRevision; } - else if (DateTime.Now - LastUsedTime > inactiveTimeout) + else if (DateTime.UtcNow - LastUsedTime > inactiveTimeout) { - await UnloadAsync(cancellationToken); + Unload(); } else { - await SaveModelAsync(cancellationToken); + SaveModel(); } } - private async Task SaveModelAsync(CancellationToken cancellationToken = default) + public void Touch() + { + LastUsedTime = DateTime.UtcNow; + } + + private void SaveModel() { if (_smtModel is not null && IsUpdated) { - await _smtModel.SaveAsync(cancellationToken); + _smtModel.Save(); IsUpdated = false; } } - private async Task UnloadAsync(CancellationToken cancellationToken = default) + private void Unload() { if (_hybridEngine is null) return; - await SaveModelAsync(cancellationToken); + SaveModel(); _hybridEngine.Dispose(); @@ -121,8 +127,8 @@ private async Task UnloadAsync(CancellationToken cancellationToken = default) CurrentBuildRevision = -1; } - protected override async ValueTask DisposeAsyncCore() + protected override void DisposeManagedResources() { - await UnloadAsync(); + Unload(); } } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineStateService.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineStateService.cs index 03ef2ad8..9b97e004 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineStateService.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferEngineStateService.cs @@ -1,16 +1,20 @@ -namespace Serval.Machine.Shared.Services; +using SIL.ObjectModel; + +namespace Serval.Machine.Shared.Services; public class SmtTransferEngineStateService( ISmtModelFactory smtModelFactory, ITransferEngineFactory transferEngineFactory, ITruecaserFactory truecaserFactory, - IOptionsMonitor options -) : AsyncDisposableBase + IOptionsMonitor options, + ILogger logger +) : DisposableBase { private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; private readonly ITransferEngineFactory _transferEngineFactory = transferEngineFactory; private readonly ITruecaserFactory _truecaserFactory = truecaserFactory; private readonly IOptionsMonitor _options = options; + private readonly ILogger _logger = logger; private readonly ConcurrentDictionary _engineStates = new ConcurrentDictionary(); @@ -20,9 +24,9 @@ public SmtTransferEngineState Get(string engineId) return _engineStates.GetOrAdd(engineId, CreateState); } - public bool TryRemove(string engineId, [MaybeNullWhen(false)] out SmtTransferEngineState state) + public void Remove(string engineId) { - return _engineStates.TryRemove(engineId, out state); + _engineStates.TryRemove(engineId, out _); } public async Task CommitAsync( @@ -34,20 +38,24 @@ public async Task CommitAsync( { foreach (SmtTransferEngineState state in _engineStates.Values) { - IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(state.EngineId, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) + try { - TranslationEngine? engine = await engines.GetAsync( - e => e.EngineId == state.EngineId, - cancellationToken + IDistributedReaderWriterLock @lock = await lockFactory.CreateAsync(state.EngineId, cancellationToken); + await @lock.WriterLockAsync( + async ct => + { + TranslationEngine? engine = await engines.GetAsync(state.EngineId, ct); + if (engine is not null && !(engine.CollectTrainSegmentPairs ?? false)) + // there is no way to cancel this call + state.Commit(engine.BuildRevision, inactiveTimeout); + }, + _options.CurrentValue.EngineCommitTimeout, + cancellationToken: cancellationToken ); - if ( - engine is not null - && (engine.CurrentBuild is null || engine.CurrentBuild.JobState is BuildJobState.Pending) - ) - { - await state.CommitAsync(engine.BuildRevision, inactiveTimeout, cancellationToken); - } + } + catch (Exception e) + { + _logger.LogError(e, "Error occurred while committing SMT transfer engine {EngineId}.", state.EngineId); } } } @@ -63,10 +71,10 @@ private SmtTransferEngineState CreateState(string engineId) ); } - protected override async ValueTask DisposeAsyncCore() + protected override void DisposeManagedResources() { foreach (SmtTransferEngineState state in _engineStates.Values) - await state.DisposeAsync(); + state.Dispose(); _engineStates.Clear(); } } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferHangfireBuildJobFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferHangfireBuildJobFactory.cs index 9f532b2b..2d9bf00c 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferHangfireBuildJobFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferHangfireBuildJobFactory.cs @@ -11,7 +11,7 @@ public Job CreateJob(string engineId, string buildId, BuildStage stage, object? return stage switch { BuildStage.Preprocess - => CreateJob>( + => CreateJob>( engineId, buildId, "smt_transfer", diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPostprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPostprocessBuildJob.cs index cfa7cc32..8d2c12ca 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPostprocessBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPostprocessBuildJob.cs @@ -3,64 +3,66 @@ public class SmtTransferPostprocessBuildJob( IPlatformService platformService, IRepository engines, - IDistributedReaderWriterLockFactory lockFactory, IDataAccessContext dataAccessContext, IBuildJobService buildJobService, ILogger logger, ISharedFileService sharedFileService, + IDistributedReaderWriterLockFactory lockFactory, IRepository trainSegmentPairs, ISmtModelFactory smtModelFactory, ITruecaserFactory truecaserFactory, - IOptionsMonitor buildOptions, IOptionsMonitor engineOptions -) - : PostprocessBuildJob( - platformService, - engines, - lockFactory, - dataAccessContext, - buildJobService, - logger, - sharedFileService, - buildOptions - ) +) : PostprocessBuildJob(platformService, engines, dataAccessContext, buildJobService, logger, sharedFileService) { private readonly ISmtModelFactory _smtModelFactory = smtModelFactory; private readonly ITruecaserFactory _truecaserFactory = truecaserFactory; private readonly IRepository _trainSegmentPairs = trainSegmentPairs; private readonly IOptionsMonitor _engineOptions = engineOptions; + private readonly IDistributedReaderWriterLockFactory _lockFactory = lockFactory; protected override async Task SaveModelAsync(string engineId, string buildId) { - await using ( - Stream engineStream = await SharedFileService.OpenReadAsync( - $"builds/{buildId}/model.tar.gz", - CancellationToken.None - ) - ) - { - await _smtModelFactory.UpdateEngineFromAsync( - Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId), - engineStream, - CancellationToken.None - ); - } - return await TrainOnNewSegmentPairsAsync(engineId); + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId); + return await @lock.WriterLockAsync( + async ct => + { + await using ( + Stream engineStream = await SharedFileService.OpenReadAsync($"builds/{buildId}/model.tar.gz", ct) + ) + { + await _smtModelFactory.UpdateEngineFromAsync( + Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId), + engineStream, + ct + ); + } + IReadOnlyList segmentPairs = await _trainSegmentPairs.GetAllAsync( + p => p.TranslationEngineRef == engineId, + ct + ); + TrainOnNewSegmentPairs(engineId, segmentPairs, ct); + await Engines.UpdateAsync( + engineId, + u => u.Set(e => e.CollectTrainSegmentPairs, false), + cancellationToken: ct + ); + return segmentPairs.Count; + }, + _engineOptions.CurrentValue.SaveModelTimeout + ); } - private async Task TrainOnNewSegmentPairsAsync(string engineId) + private void TrainOnNewSegmentPairs( + string engineId, + IReadOnlyList segmentPairs, + CancellationToken cancellationToken + ) { - IReadOnlyList segmentPairs = await _trainSegmentPairs.GetAllAsync(p => - p.TranslationEngineRef == engineId - ); - if (segmentPairs.Count == 0) - return segmentPairs.Count; - string engineDir = Path.Combine(_engineOptions.CurrentValue.EnginesDir, engineId); var tokenizer = new LatinWordTokenizer(); var detokenizer = new LatinWordDetokenizer(); - ITruecaser truecaser = await _truecaserFactory.CreateAsync(engineDir); - using IInteractiveTranslationModel smtModel = await _smtModelFactory.CreateAsync( + ITruecaser truecaser = _truecaserFactory.Create(engineDir); + using IInteractiveTranslationModel smtModel = _smtModelFactory.Create( engineDir, tokenizer, detokenizer, @@ -68,9 +70,10 @@ private async Task TrainOnNewSegmentPairsAsync(string engineId) ); foreach (TrainSegmentPair segmentPair in segmentPairs) { - await smtModel.TrainSegmentAsync(segmentPair.Source, segmentPair.Target); + cancellationToken.ThrowIfCancellationRequested(); + smtModel.TrainSegment(segmentPair.Source, segmentPair.Target); } - await smtModel.SaveAsync(); - return segmentPairs.Count; + cancellationToken.ThrowIfCancellationRequested(); + smtModel.Save(); } } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPreprocessBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPreprocessBuildJob.cs new file mode 100644 index 00000000..9e14037a --- /dev/null +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferPreprocessBuildJob.cs @@ -0,0 +1,48 @@ +namespace Serval.Machine.Shared.Services; + +public class SmtTransferPreprocessBuildJob( + IPlatformService platformService, + IRepository engines, + IDataAccessContext dataAccessContext, + ILogger logger, + IBuildJobService buildJobService, + ISharedFileService sharedFileService, + ICorpusService corpusService, + IDistributedReaderWriterLockFactory lockFactory, + IRepository trainSegmentPairs +) + : PreprocessBuildJob( + platformService, + engines, + dataAccessContext, + logger, + buildJobService, + sharedFileService, + corpusService + ) +{ + private readonly IDistributedReaderWriterLockFactory _lockFactory = lockFactory; + private readonly IRepository _trainSegmentPairs = trainSegmentPairs; + + protected override async Task InitializeAsync( + string engineId, + string buildId, + IReadOnlyList data, + CancellationToken cancellationToken + ) + { + IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken); + await @lock.WriterLockAsync( + async ct => + { + await _trainSegmentPairs.DeleteAllAsync(p => p.TranslationEngineRef == engineId, ct); + await Engines.UpdateAsync( + engineId, + u => u.Set(e => e.CollectTrainSegmentPairs, true), + cancellationToken: ct + ); + }, + cancellationToken: cancellationToken + ); + } +} diff --git a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferTrainBuildJob.cs b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferTrainBuildJob.cs index bb4870c1..e81fc354 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferTrainBuildJob.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/SmtTransferTrainBuildJob.cs @@ -3,7 +3,6 @@ public class SmtTransferTrainBuildJob( IPlatformService platformService, IRepository engines, - IDistributedReaderWriterLockFactory lockFactory, IDataAccessContext dataAccessContext, IBuildJobService buildJobService, ILogger logger, @@ -11,7 +10,7 @@ public class SmtTransferTrainBuildJob( ITruecaserFactory truecaserFactory, ISmtModelFactory smtModelFactory, ITransferEngineFactory transferEngineFactory -) : HangfireBuildJob(platformService, engines, lockFactory, dataAccessContext, buildJobService, logger) +) : HangfireBuildJob(platformService, engines, dataAccessContext, buildJobService, logger) { private static readonly JsonWriterOptions PretranslateWriterOptions = new() { Indented = true }; private static readonly JsonSerializerOptions JsonSerializerOptions = @@ -28,7 +27,6 @@ protected override async Task DoWorkAsync( string buildId, object? data, string? buildOptions, - IDistributedReaderWriterLock @lock, CancellationToken cancellationToken ) { @@ -55,27 +53,24 @@ CancellationToken cancellationToken await GeneratePretranslationsAsync(buildId, engineDir, cancellationToken); - await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken)) - { - bool canceling = !await BuildJobService.StartBuildJobAsync( - BuildJobRunnerType.Hangfire, - engineId, - buildId, - BuildStage.Postprocess, - data: (trainCorpusSize, confidence), - buildOptions: buildOptions, - cancellationToken: cancellationToken - ); - if (canceling) - throw new OperationCanceledException(); - } + bool canceling = !await BuildJobService.StartBuildJobAsync( + BuildJobRunnerType.Hangfire, + TranslationEngineType.SmtTransfer, + engineId, + buildId, + BuildStage.Postprocess, + data: (trainCorpusSize, confidence), + buildOptions: buildOptions, + cancellationToken: cancellationToken + ); + if (canceling) + throw new OperationCanceledException(); } protected override async Task CleanupAsync( string engineId, string buildId, object? data, - IDistributedReaderWriterLock @lock, JobCompletionStatus completionStatus ) { @@ -118,22 +113,12 @@ private async Task DownloadDataAsync(string buildId, string corpusDir, Cancellat CancellationToken cancellationToken ) { - await _smtModelFactory.InitNewAsync(engineDir, cancellationToken); + _smtModelFactory.InitNew(engineDir); LatinWordTokenizer tokenizer = new(); int trainCorpusSize; double confidence; - using ITrainer smtModelTrainer = await _smtModelFactory.CreateTrainerAsync( - engineDir, - tokenizer, - parallelCorpus, - cancellationToken - ); - using ITrainer truecaseTrainer = await _truecaserFactory.CreateTrainerAsync( - engineDir, - tokenizer, - targetCorpus, - cancellationToken - ); + using ITrainer smtModelTrainer = _smtModelFactory.CreateTrainer(engineDir, tokenizer, parallelCorpus); + using ITrainer truecaseTrainer = _truecaserFactory.CreateTrainer(engineDir, tokenizer, targetCorpus); cancellationToken.ThrowIfCancellationRequested(); var progress = new BuildProgress(PlatformService, buildId); @@ -179,20 +164,18 @@ CancellationToken cancellationToken LatinWordTokenizer tokenizer = new(); LatinWordDetokenizer detokenizer = new(); - ITruecaser truecaser = await _truecaserFactory.CreateAsync(engineDir, CancellationToken.None); - using IInteractiveTranslationModel smtModel = await _smtModelFactory.CreateAsync( + ITruecaser truecaser = _truecaserFactory.Create(engineDir); + using IInteractiveTranslationModel smtModel = _smtModelFactory.Create( engineDir, tokenizer, detokenizer, - truecaser, - cancellationToken + truecaser ); - using ITranslationEngine? transferEngine = await _transferEngineFactory.CreateAsync( + using ITranslationEngine? transferEngine = _transferEngineFactory.Create( engineDir, tokenizer, detokenizer, - truecaser, - cancellationToken + truecaser ); HybridTranslationEngine hybridEngine = new(smtModel, transferEngine) { TargetDetokenizer = detokenizer }; diff --git a/src/Machine/src/Serval.Machine.Shared/Services/ThotSmtModelFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/ThotSmtModelFactory.cs index 031891c4..03f4ab5d 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/ThotSmtModelFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/ThotSmtModelFactory.cs @@ -4,12 +4,11 @@ public class ThotSmtModelFactory(IOptionsMonitor options) : { private readonly IOptionsMonitor _options = options; - public Task CreateAsync( + public IInteractiveTranslationModel Create( string engineDir, IRangeTokenizer tokenizer, IDetokenizer detokenizer, - ITruecaser truecaser, - CancellationToken cancellationToken = default + ITruecaser truecaser ) { string smtConfigFileName = Path.Combine(engineDir, "smt.cfg"); @@ -22,14 +21,13 @@ public Task CreateAsync( LowercaseTarget = true, Truecaser = truecaser }; - return Task.FromResult(model); + return model; } - public Task CreateTrainerAsync( + public ITrainer CreateTrainer( string engineDir, IRangeTokenizer tokenizer, - IParallelTextCorpus corpus, - CancellationToken cancellationToken = default + IParallelTextCorpus corpus ) { string smtConfigFileName = Path.Combine(engineDir, "smt.cfg"); @@ -40,21 +38,20 @@ public Task CreateTrainerAsync( LowercaseSource = true, LowercaseTarget = true }; - return Task.FromResult(trainer); + return trainer; } - public Task InitNewAsync(string engineDir, CancellationToken cancellationToken = default) + public void InitNew(string engineDir) { if (!Directory.Exists(engineDir)) Directory.CreateDirectory(engineDir); ZipFile.ExtractToDirectory(_options.CurrentValue.NewModelFile, engineDir); - return Task.CompletedTask; } - public Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default) + public void Cleanup(string engineDir) { if (!Directory.Exists(engineDir)) - return Task.CompletedTask; + return; DirectoryHelper.DeleteDirectoryRobust(Path.Combine(engineDir, "lm")); DirectoryHelper.DeleteDirectoryRobust(Path.Combine(engineDir, "tm")); string smtConfigFileName = Path.Combine(engineDir, "smt.cfg"); @@ -62,7 +59,6 @@ public Task CleanupAsync(string engineDir, CancellationToken cancellationToken = File.Delete(smtConfigFileName); if (!Directory.EnumerateFileSystemEntries(engineDir).Any()) Directory.Delete(engineDir); - return Task.CompletedTask; } public async Task UpdateEngineFromAsync( diff --git a/src/Machine/src/Serval.Machine.Shared/Services/TimeoutInterceptor.cs b/src/Machine/src/Serval.Machine.Shared/Services/TimeoutInterceptor.cs new file mode 100644 index 00000000..8f33674d --- /dev/null +++ b/src/Machine/src/Serval.Machine.Shared/Services/TimeoutInterceptor.cs @@ -0,0 +1,23 @@ +namespace Serval.Machine.Shared.Services; + +public class TimeoutInterceptor(ILogger logger) : Interceptor +{ + private readonly ILogger _logger = logger; + + public override async Task UnaryServerHandler( + TRequest request, + ServerCallContext context, + UnaryServerMethod continuation + ) + { + try + { + return await continuation(request, context); + } + catch (TimeoutException te) + { + _logger.LogError(te, "The method {Method} took too long to complete.", context.Method); + throw new RpcException(new Status(StatusCode.Unavailable, "The method took too long to complete.")); + } + } +} diff --git a/src/Machine/src/Serval.Machine.Shared/Services/TransferEngineFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/TransferEngineFactory.cs index a140792b..7834bd73 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/TransferEngineFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/TransferEngineFactory.cs @@ -2,12 +2,11 @@ public class TransferEngineFactory : ITransferEngineFactory { - public Task CreateAsync( + public ITranslationEngine? Create( string engineDir, IRangeTokenizer tokenizer, IDetokenizer detokenizer, - ITruecaser truecaser, - CancellationToken cancellationToken = default + ITruecaser truecaser ) { string hcSrcConfigFileName = Path.Combine(engineDir, "src-hc.xml"); @@ -35,19 +34,18 @@ public class TransferEngineFactory : ITransferEngineFactory Truecaser = truecaser }; } - return Task.FromResult(transferEngine); + return transferEngine; } - public Task InitNewAsync(string engineDir, CancellationToken cancellationToken = default) + public void InitNew(string engineDir) { // TODO: generate source and target config files - return Task.CompletedTask; } - public Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default) + public void Cleanup(string engineDir) { if (!Directory.Exists(engineDir)) - return Task.CompletedTask; + return; string hcSrcConfigFileName = Path.Combine(engineDir, "src-hc.xml"); if (File.Exists(hcSrcConfigFileName)) File.Delete(hcSrcConfigFileName); @@ -56,6 +54,5 @@ public Task CleanupAsync(string engineDir, CancellationToken cancellationToken = File.Delete(hcTrgConfigFileName); if (!Directory.EnumerateFileSystemEntries(engineDir).Any()) Directory.Delete(engineDir); - return Task.CompletedTask; } } diff --git a/src/Machine/src/Serval.Machine.Shared/Services/UnigramTruecaserFactory.cs b/src/Machine/src/Serval.Machine.Shared/Services/UnigramTruecaserFactory.cs index cbf9c8b5..0821c10e 100644 --- a/src/Machine/src/Serval.Machine.Shared/Services/UnigramTruecaserFactory.cs +++ b/src/Machine/src/Serval.Machine.Shared/Services/UnigramTruecaserFactory.cs @@ -2,32 +2,26 @@ public class UnigramTruecaserFactory : ITruecaserFactory { - public async Task CreateAsync(string engineDir, CancellationToken cancellationToken = default) + public ITruecaser Create(string engineDir) { var truecaser = new UnigramTruecaser(); string path = GetModelPath(engineDir); - await truecaser.LoadAsync(path); + truecaser.Load(path); return truecaser; } - public Task CreateTrainerAsync( - string engineDir, - ITokenizer tokenizer, - ITextCorpus corpus, - CancellationToken cancellationToken = default - ) + public ITrainer CreateTrainer(string engineDir, ITokenizer tokenizer, ITextCorpus corpus) { string path = GetModelPath(engineDir); ITrainer trainer = new UnigramTruecaserTrainer(path, corpus) { Tokenizer = tokenizer }; - return Task.FromResult(trainer); + return trainer; } - public Task CleanupAsync(string engineDir, CancellationToken cancellationToken = default) + public void Cleanup(string engineDir) { string path = GetModelPath(engineDir); if (File.Exists(path)) File.Delete(path); - return Task.CompletedTask; } private static string GetModelPath(string engineDir) diff --git a/src/Machine/src/Serval.Machine.Shared/Usings.cs b/src/Machine/src/Serval.Machine.Shared/Usings.cs index 159f4f01..8d75abec 100644 --- a/src/Machine/src/Serval.Machine.Shared/Usings.cs +++ b/src/Machine/src/Serval.Machine.Shared/Usings.cs @@ -1,7 +1,6 @@ global using System.Collections.Concurrent; global using System.Data; global using System.Diagnostics; -global using System.Diagnostics.CodeAnalysis; global using System.Formats.Tar; global using System.Globalization; global using System.IO.Compression; diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Serval.Machine.Shared.Tests.csproj b/src/Machine/test/Serval.Machine.Shared.Tests/Serval.Machine.Shared.Tests.csproj index b8c398f9..ab0b38dd 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Serval.Machine.Shared.Tests.csproj +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Serval.Machine.Shared.Tests.csproj @@ -24,6 +24,7 @@ all + all diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/DistributedReaderWriterLockFactoryTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/DistributedReaderWriterLockFactoryTests.cs index d9389a69..84f61fae 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Services/DistributedReaderWriterLockFactoryTests.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/DistributedReaderWriterLockFactoryTests.cs @@ -70,6 +70,7 @@ public TestEnvironment() ServiceOptions serviceOptions = new() { ServiceId = "this_service" }; Factory = new DistributedReaderWriterLockFactory( new OptionsWrapper(serviceOptions), + new OptionsWrapper(new DistributedReaderWriterLockOptions()), Locks, new ObjectIdGenerator() ); diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/DistributedReaderWriterLockTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/DistributedReaderWriterLockTests.cs index dae41b35..28894901 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Services/DistributedReaderWriterLockTests.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/DistributedReaderWriterLockTests.cs @@ -6,342 +6,374 @@ public class DistributedReaderWriterLockTests [Test] public async Task ReaderLockAsync_NoLockAcquired() { - var env = new TestEnvironment(); + TestEnvironment env = new(); IDistributedReaderWriterLock rwLock = await env.Factory.CreateAsync("test"); - RWLock entity; - await using (await rwLock.ReaderLockAsync()) + await rwLock.ReaderLockAsync(ct => { - entity = env.Locks.Get("test"); - Assert.Multiple(() => - { - Assert.That(entity.IsAvailableForReading(), Is.True); - Assert.That(entity.IsAvailableForWriting(), Is.False); - }); - } - - entity = env.Locks.Get("test"); - Assert.Multiple(() => - { - Assert.That(entity.IsAvailableForReading(), Is.True); - Assert.That(entity.IsAvailableForWriting(), Is.True); + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.True); + Assert.That(lockEntity.IsAvailableForWriting(), Is.False); + return Task.CompletedTask; }); + + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.True); + Assert.That(lockEntity.IsAvailableForWriting(), Is.True); } [Test] public async Task ReaderLockAsync_ReaderLockAcquired() { - var env = new TestEnvironment(); + TestEnvironment env = new(); IDistributedReaderWriterLock rwLock = await env.Factory.CreateAsync("test"); - RWLock entity; - await using (await rwLock.ReaderLockAsync()) + await rwLock.ReaderLockAsync(async ct => { - await using (await rwLock.ReaderLockAsync()) - { - entity = env.Locks.Get("test"); - Assert.Multiple(() => + await rwLock.ReaderLockAsync( + ct => { - Assert.That(entity.IsAvailableForReading(), Is.True); - Assert.That(entity.IsAvailableForWriting(), Is.False); - }); - } - } - - entity = env.Locks.Get("test"); - Assert.Multiple(() => - { - Assert.That(entity.IsAvailableForReading(), Is.True); - Assert.That(entity.IsAvailableForWriting(), Is.True); + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.True); + Assert.That(lockEntity.IsAvailableForWriting(), Is.False); + return Task.CompletedTask; + }, + cancellationToken: ct + ); }); + + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.True); + Assert.That(lockEntity.IsAvailableForWriting(), Is.True); } [Test] public async Task ReaderLockAsync_WriterLockAcquiredAndNotReleased() { - var env = new TestEnvironment(); + TestEnvironment env = new(); IDistributedReaderWriterLock rwLock = await env.Factory.CreateAsync("test"); - await rwLock.WriterLockAsync(); - var task = rwLock.ReaderLockAsync(); - await AssertNeverCompletesAsync(task); + using CancellationTokenSource cts = new(); + Task task1 = rwLock.WriterLockAsync( + ct => Task.Delay(Timeout.InfiniteTimeSpan, ct), + cancellationToken: cts.Token + ); + Task task2 = rwLock.ReaderLockAsync( + ct => Task.Delay(Timeout.InfiniteTimeSpan, ct), + cancellationToken: cts.Token + ); + + await AssertNeverCompletesAsync(task2); + + cts.Cancel(); + Assert.ThrowsAsync(async () => await task1); + Assert.ThrowsAsync(async () => await task2); } [Test] public async Task ReaderLockAsync_WriterLockAcquiredAndReleased() { - var env = new TestEnvironment(); + TestEnvironment env = new(); IDistributedReaderWriterLock rwLock = await env.Factory.CreateAsync("test"); - Task task; - await using (await rwLock.WriterLockAsync()) + AsyncManualResetEvent @event = new(false); + Task outerTask = rwLock.WriterLockAsync(async ct => { - task = rwLock.ReaderLockAsync(); + Task task = rwLock.ReaderLockAsync( + ct => + { + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.True); + Assert.That(lockEntity.IsAvailableForWriting(), Is.False); + return Task.CompletedTask; + }, + cancellationToken: ct + ); Assert.That(task.IsCompleted, Is.False); - } - - RWLock entity; - await using (await task) - { - entity = env.Locks.Get("test"); - Assert.Multiple(() => - { - Assert.That(entity.IsAvailableForReading(), Is.True); - Assert.That(entity.IsAvailableForWriting(), Is.False); - }); - } - - entity = env.Locks.Get("test"); - Assert.Multiple(() => - { - Assert.That(entity.IsAvailableForReading(), Is.True); - Assert.That(entity.IsAvailableForWriting(), Is.True); + await @event.WaitAsync(ct); + return task; }); + + @event.Set(); + Task innerTask = await outerTask; + await innerTask; + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.True); + Assert.That(lockEntity.IsAvailableForWriting(), Is.True); } [Test] public async Task ReaderLockAsync_WriterLockAcquiredAndExpired() { - var env = new TestEnvironment(); + TestEnvironment env = new(); IDistributedReaderWriterLock rwLock = await env.Factory.CreateAsync("test"); - RWLock entity; - await using (await rwLock.WriterLockAsync(TimeSpan.FromMilliseconds(400))) - { - var task = rwLock.ReaderLockAsync(); - await Task.Delay(500); - await using (await task) + Task? innerTask = null; + Task outerTask = rwLock.WriterLockAsync( + async ct => { - entity = env.Locks.Get("test"); - Assert.Multiple(() => - { - Assert.That(entity.IsAvailableForReading(), Is.True); - Assert.That(entity.IsAvailableForWriting(), Is.False); - }); - } - } - - entity = env.Locks.Get("test"); - Assert.Multiple(() => - { - Assert.That(entity.IsAvailableForReading(), Is.True); - Assert.That(entity.IsAvailableForWriting(), Is.True); - }); + innerTask = rwLock.ReaderLockAsync( + ct => + { + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.True); + Assert.That(lockEntity.IsAvailableForWriting(), Is.False); + return Task.CompletedTask; + }, + cancellationToken: CancellationToken.None + ); + await Task.Delay(500, ct); + }, + lifetime: TimeSpan.FromMilliseconds(400) + ); + + Assert.ThrowsAsync(async () => await outerTask); + Assert.That(innerTask, Is.Not.Null); + await innerTask; + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.True); + Assert.That(lockEntity.IsAvailableForWriting(), Is.True); } [Test] public async Task ReaderLockAsync_Cancelled() { - var env = new TestEnvironment(); + TestEnvironment env = new(); IDistributedReaderWriterLock rwLock = await env.Factory.CreateAsync("test"); - Task task; - await using (await rwLock.WriterLockAsync()) + await rwLock.WriterLockAsync(ct => { - var cts = new CancellationTokenSource(); - task = rwLock.ReaderLockAsync(cancellationToken: cts.Token); + using CancellationTokenSource cts = new(); + Task task = rwLock.ReaderLockAsync( + ct => Task.Delay(Timeout.InfiniteTimeSpan, ct), + cancellationToken: cts.Token + ); cts.Cancel(); Assert.CatchAsync(async () => await task); - } - - RWLock entity; - await using (await rwLock.ReaderLockAsync()) - { - entity = env.Locks.Get("test"); - Assert.Multiple(() => - { - Assert.That(entity.IsAvailableForReading(), Is.True); - Assert.That(entity.IsAvailableForWriting(), Is.False); - }); - } + return Task.CompletedTask; + }); - entity = env.Locks.Get("test"); - Assert.Multiple(() => + await rwLock.ReaderLockAsync(ct => { - Assert.That(entity.IsAvailableForReading(), Is.True); - Assert.That(entity.IsAvailableForWriting(), Is.True); + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.True); + Assert.That(lockEntity.IsAvailableForWriting(), Is.False); + return Task.CompletedTask; }); + + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.True); + Assert.That(lockEntity.IsAvailableForWriting(), Is.True); } [Test] public async Task WriterLockAsync_NoLockAcquired() { - var env = new TestEnvironment(); + TestEnvironment env = new(); IDistributedReaderWriterLock rwLock = await env.Factory.CreateAsync("test"); - RWLock entity; - await using (await rwLock.WriterLockAsync()) - { - entity = env.Locks.Get("test"); - Assert.Multiple(() => - { - Assert.That(entity.IsAvailableForReading(), Is.False); - Assert.That(entity.IsAvailableForWriting(), Is.False); - }); - } - - entity = env.Locks.Get("test"); - Assert.Multiple(() => + await rwLock.WriterLockAsync(ct => { - Assert.That(entity.IsAvailableForReading(), Is.True); - Assert.That(entity.IsAvailableForWriting(), Is.True); + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.False); + Assert.That(lockEntity.IsAvailableForWriting(), Is.False); + return Task.CompletedTask; }); + + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.True); + Assert.That(lockEntity.IsAvailableForWriting(), Is.True); } [Test] public async Task WriterLockAsync_ReaderLockAcquiredAndNotReleased() { - var env = new TestEnvironment(); + TestEnvironment env = new(); IDistributedReaderWriterLock rwLock = await env.Factory.CreateAsync("test"); - await rwLock.ReaderLockAsync(); - var task = rwLock.WriterLockAsync(); - await AssertNeverCompletesAsync(task); + using CancellationTokenSource cts = new(); + Task task1 = rwLock.ReaderLockAsync( + ct => Task.Delay(Timeout.InfiniteTimeSpan, ct), + cancellationToken: cts.Token + ); + Task task2 = rwLock.WriterLockAsync( + ct => Task.Delay(Timeout.InfiniteTimeSpan, ct), + cancellationToken: cts.Token + ); + + await AssertNeverCompletesAsync(task2); + + cts.Cancel(); + Assert.ThrowsAsync(async () => await task1); + Assert.ThrowsAsync(async () => await task2); } [Test] public async Task WriterLockAsync_ReaderLockAcquiredAndReleased() { - var env = new TestEnvironment(); + TestEnvironment env = new(); IDistributedReaderWriterLock rwLock = await env.Factory.CreateAsync("test"); - Task task; - await using (await rwLock.ReaderLockAsync()) + AsyncManualResetEvent @event = new(false); + Task outerTask = rwLock.ReaderLockAsync(async ct => { - task = rwLock.WriterLockAsync(); + Task task = rwLock.WriterLockAsync( + ct => + { + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.False); + Assert.That(lockEntity.IsAvailableForWriting(), Is.False); + return Task.CompletedTask; + }, + cancellationToken: ct + ); Assert.That(task.IsCompleted, Is.False); - } - - RWLock entity; - await using (await task) - { - entity = env.Locks.Get("test"); - Assert.Multiple(() => - { - Assert.That(entity.IsAvailableForReading(), Is.False); - Assert.That(entity.IsAvailableForWriting(), Is.False); - }); - } - - entity = env.Locks.Get("test"); - Assert.Multiple(() => - { - Assert.That(entity.IsAvailableForReading(), Is.True); - Assert.That(entity.IsAvailableForWriting(), Is.True); + await @event.WaitAsync(ct); + return task; }); + + @event.Set(); + Task innerTask = await outerTask; + await innerTask; + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.True); + Assert.That(lockEntity.IsAvailableForWriting(), Is.True); } [Test] public async Task WriterLockAsync_WriterLockAcquiredAndNeverReleased() { - var env = new TestEnvironment(); + TestEnvironment env = new(); IDistributedReaderWriterLock rwLock = await env.Factory.CreateAsync("test"); - await rwLock.WriterLockAsync(); - var task = rwLock.WriterLockAsync(); - await AssertNeverCompletesAsync(task); + using CancellationTokenSource cts = new(); + Task task1 = rwLock.WriterLockAsync( + ct => Task.Delay(Timeout.InfiniteTimeSpan, ct), + cancellationToken: cts.Token + ); + Task task2 = rwLock.WriterLockAsync( + ct => Task.Delay(Timeout.InfiniteTimeSpan, ct), + cancellationToken: cts.Token + ); + + await AssertNeverCompletesAsync(task2); + + cts.Cancel(); + Assert.ThrowsAsync(async () => await task1); + Assert.ThrowsAsync(async () => await task2); } [Test] public async Task WriterLockAsync_WriterLockAcquiredAndReleased() { - var env = new TestEnvironment(); + TestEnvironment env = new(); IDistributedReaderWriterLock rwLock = await env.Factory.CreateAsync("test"); - Task task; - await using (await rwLock.WriterLockAsync()) + AsyncManualResetEvent @event = new(false); + Task outerTask = rwLock.WriterLockAsync(async ct => { - task = rwLock.WriterLockAsync(); + Task task = rwLock.WriterLockAsync( + ct => + { + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.False); + Assert.That(lockEntity.IsAvailableForWriting(), Is.False); + return Task.CompletedTask; + }, + cancellationToken: ct + ); Assert.That(task.IsCompleted, Is.False); - } - - RWLock entity; - await using (await task) - { - entity = env.Locks.Get("test"); - Assert.Multiple(() => - { - Assert.That(entity.IsAvailableForReading(), Is.False); - Assert.That(entity.IsAvailableForWriting(), Is.False); - }); - } - - entity = env.Locks.Get("test"); - Assert.Multiple(() => - { - Assert.That(entity.IsAvailableForReading(), Is.True); - Assert.That(entity.IsAvailableForWriting(), Is.True); + await @event.WaitAsync(ct); + return task; }); + + @event.Set(); + Task innerTask = await outerTask; + await innerTask; + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.True); + Assert.That(lockEntity.IsAvailableForWriting(), Is.True); } [Test] public async Task WriterLockAsync_WriterLockTakesPriorityOverReaderLock() { - var env = new TestEnvironment(); + TestEnvironment env = new(); IDistributedReaderWriterLock rwLock = await env.Factory.CreateAsync("test"); - Task writeTask, - readTask; - await using (await rwLock.WriterLockAsync()) + int value = 1; + AsyncManualResetEvent @event = new(false); + Task<(Task, Task)> outerTask = rwLock.WriterLockAsync(async ct => { - readTask = rwLock.ReaderLockAsync(); + Task readTask = rwLock.ReaderLockAsync(ct => Task.FromResult(value++), cancellationToken: ct); Assert.That(readTask.IsCompleted, Is.False); - writeTask = rwLock.WriterLockAsync(); + Task writeTask = rwLock.WriterLockAsync(ct => Task.FromResult(value++), cancellationToken: ct); Assert.That(writeTask.IsCompleted, Is.False); - } + await @event.WaitAsync(ct); + return (writeTask, readTask); + }); - await writeTask; - await AssertNeverCompletesAsync(readTask); + @event.Set(); + (Task writeTask, Task readTask) = await outerTask; + Assert.That(await writeTask, Is.EqualTo(1)); + Assert.That(await readTask, Is.EqualTo(2)); } [Test] public async Task WriterLockAsync_FirstWriterLockHasPriority() { - var env = new TestEnvironment(); + TestEnvironment env = new(); IDistributedReaderWriterLock rwLock = await env.Factory.CreateAsync("test"); - Task task1, - task2; - await using (await rwLock.WriterLockAsync()) + int value = 1; + AsyncManualResetEvent @event = new(false); + Task<(Task, Task)> outerTask = rwLock.WriterLockAsync(async ct => { - task1 = rwLock.WriterLockAsync(); + Task task1 = rwLock.WriterLockAsync(ct => Task.FromResult(value++), cancellationToken: ct); Assert.That(task1.IsCompleted, Is.False); - task2 = rwLock.WriterLockAsync(); + Task task2 = rwLock.WriterLockAsync(ct => Task.FromResult(value++), cancellationToken: ct); Assert.That(task2.IsCompleted, Is.False); - } + await @event.WaitAsync(ct); + return (task1, task2); + }); - await task1; - await AssertNeverCompletesAsync(task2); + @event.Set(); + (Task task1, Task task2) = await outerTask; + Assert.That(await task1, Is.EqualTo(1)); + Assert.That(await task2, Is.EqualTo(2)); } [Test] public async Task WriterLockAsync_WriterLockAcquiredAndExpired() { - var env = new TestEnvironment(); + TestEnvironment env = new(); IDistributedReaderWriterLock rwLock = await env.Factory.CreateAsync("test"); - RWLock entity; - await using (await rwLock.WriterLockAsync(TimeSpan.FromMilliseconds(400))) - { - var task = rwLock.WriterLockAsync(); - await Task.Delay(500); - await using (await task) + Task? innerTask = null; + Task outerTask = rwLock.WriterLockAsync( + async ct => { - entity = env.Locks.Get("test"); - Assert.Multiple(() => - { - Assert.That(entity.IsAvailableForReading(), Is.False); - Assert.That(entity.IsAvailableForWriting(), Is.False); - }); - } - } - - entity = env.Locks.Get("test"); - Assert.Multiple(() => - { - Assert.That(entity.IsAvailableForReading(), Is.True); - Assert.That(entity.IsAvailableForWriting(), Is.True); - }); + innerTask = rwLock.WriterLockAsync( + ct => + { + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.False); + Assert.That(lockEntity.IsAvailableForWriting(), Is.False); + return Task.CompletedTask; + }, + cancellationToken: CancellationToken.None + ); + await Task.Delay(500, ct); + }, + lifetime: TimeSpan.FromMilliseconds(400) + ); + + Assert.ThrowsAsync(async () => await outerTask); + Assert.That(innerTask, Is.Not.Null); + await innerTask; + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.True); + Assert.That(lockEntity.IsAvailableForWriting(), Is.True); } [Test] @@ -350,32 +382,29 @@ public async Task WriterLockAsync_Cancelled() var env = new TestEnvironment(); IDistributedReaderWriterLock rwLock = await env.Factory.CreateAsync("test"); - Task task; - await using (await rwLock.WriterLockAsync()) + await rwLock.WriterLockAsync(ct => { - var cts = new CancellationTokenSource(); - task = rwLock.WriterLockAsync(cancellationToken: cts.Token); + using CancellationTokenSource cts = new(); + Task task = rwLock.WriterLockAsync( + ct => Task.Delay(Timeout.InfiniteTimeSpan, ct), + cancellationToken: cts.Token + ); cts.Cancel(); Assert.CatchAsync(async () => await task); - } - - RWLock entity; - await using (await rwLock.WriterLockAsync()) - { - entity = env.Locks.Get("test"); - Assert.Multiple(() => - { - Assert.That(entity.IsAvailableForReading(), Is.False); - Assert.That(entity.IsAvailableForWriting(), Is.False); - }); - } + return Task.CompletedTask; + }); - entity = env.Locks.Get("test"); - Assert.Multiple(() => + await rwLock.WriterLockAsync(ct => { - Assert.That(entity.IsAvailableForReading(), Is.True); - Assert.That(entity.IsAvailableForWriting(), Is.True); + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.False); + Assert.That(lockEntity.IsAvailableForWriting(), Is.False); + return Task.CompletedTask; }); + + RWLock lockEntity = env.Locks.Get("test"); + Assert.That(lockEntity.IsAvailableForReading(), Is.True); + Assert.That(lockEntity.IsAvailableForWriting(), Is.True); } private static async Task AssertNeverCompletesAsync(Task task, int timeout = 100) @@ -385,7 +414,14 @@ private static async Task AssertNeverCompletesAsync(Task task, int timeout = 100 Task completedTask = await Task.WhenAny(task, Task.Delay(timeout)).ConfigureAwait(false); if (completedTask == task) Assert.Fail("Task completed unexpectedly."); - var _ = task.ContinueWith(_ => Assert.Fail("Task completed unexpectedly."), TaskScheduler.Default); + var _ = task.ContinueWith( + t => + { + if (!t.IsCanceled) + Assert.Fail("Task completed unexpectedly."); + }, + TaskScheduler.Default + ); } private class TestEnvironment @@ -394,9 +430,11 @@ public TestEnvironment() { Locks = new MemoryRepository(); var idGenerator = new ObjectIdGenerator(); - var options = Substitute.For>(); - options.Value.Returns(new ServiceOptions { ServiceId = "host" }); - Factory = new DistributedReaderWriterLockFactory(options, Locks, idGenerator); + var serviceOptions = Substitute.For>(); + serviceOptions.Value.Returns(new ServiceOptions { ServiceId = "host" }); + var lockOptions = Substitute.For>(); + lockOptions.Value.Returns(new DistributedReaderWriterLockOptions()); + Factory = new DistributedReaderWriterLockFactory(serviceOptions, lockOptions, Locks, idGenerator); } public DistributedReaderWriterLockFactory Factory { get; } diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/NmtEngineServiceTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/NmtEngineServiceTests.cs index fd05ada8..f601e9f5 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Services/NmtEngineServiceTests.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/NmtEngineServiceTests.cs @@ -100,6 +100,7 @@ public TestEnvironment() PlatformService = Substitute.For(); _lockFactory = new DistributedReaderWriterLockFactory( new OptionsWrapper(new ServiceOptions { ServiceId = "host" }), + new OptionsWrapper(new DistributedReaderWriterLockOptions()), new MemoryRepository(), new ObjectIdGenerator() ); @@ -211,7 +212,6 @@ private NmtEngineService CreateService() { return new NmtEngineService( PlatformService, - _lockFactory, new MemoryDataAccessContext(), Engines, BuildJobService, @@ -262,6 +262,7 @@ private async Task RunNormalTrainJob() await BuildJobService.StartBuildJobAsync( BuildJobRunnerType.Hangfire, + TranslationEngineType.Nmt, "engine1", "build1", BuildStage.Postprocess, @@ -296,7 +297,6 @@ public override object ActivateJob(Type jobType) return new NmtPreprocessBuildJob( _env.PlatformService, _env.Engines, - _env._lockFactory, new MemoryDataAccessContext(), Substitute.For>(), _env.BuildJobService, @@ -307,17 +307,15 @@ public override object ActivateJob(Type jobType) } if (jobType == typeof(PostprocessBuildJob)) { - var options = Substitute.For>(); - options.CurrentValue.Returns(new BuildJobOptions()); + var buildJobOptions = Substitute.For>(); + buildJobOptions.CurrentValue.Returns(new BuildJobOptions()); return new PostprocessBuildJob( _env.PlatformService, _env.Engines, - _env._lockFactory, new MemoryDataAccessContext(), _env.BuildJobService, Substitute.For>(), - _env.SharedFileService, - options + _env.SharedFileService ); } return base.ActivateJob(jobType); diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs index 7c347603..df7498ee 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs @@ -369,6 +369,7 @@ private class TestEnvironment : DisposableBase public ICorpusService CorpusService { get; set; } public IPlatformService PlatformService { get; } public MemoryRepository Engines { get; } + public MemoryRepository TrainSegmentPairs { get; } public IDistributedReaderWriterLockFactory LockFactory { get; } public IBuildJobService BuildJobService { get; } public IClearMLService ClearMLService { get; } @@ -495,10 +496,12 @@ public TestEnvironment() } } ); + TrainSegmentPairs = new MemoryRepository(); CorpusService = new CorpusService(); PlatformService = Substitute.For(); LockFactory = new DistributedReaderWriterLockFactory( new OptionsWrapper(new ServiceOptions { ServiceId = "host" }), + new OptionsWrapper(new DistributedReaderWriterLockOptions()), new MemoryRepository(), new ObjectIdGenerator() ); @@ -576,7 +579,6 @@ public PreprocessBuildJob GetBuildJob(TranslationEngineType engineType) return new NmtPreprocessBuildJob( PlatformService, Engines, - LockFactory, new MemoryDataAccessContext(), Substitute.For>(), BuildJobService, @@ -590,15 +592,16 @@ public PreprocessBuildJob GetBuildJob(TranslationEngineType engineType) } case TranslationEngineType.SmtTransfer: { - return new PreprocessBuildJob( + return new SmtTransferPreprocessBuildJob( PlatformService, Engines, - LockFactory, new MemoryDataAccessContext(), Substitute.For>(), BuildJobService, SharedFileService, - CorpusService + CorpusService, + LockFactory, + TrainSegmentPairs ) { Seed = 1234 diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/SmtTransferEngineServiceTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/SmtTransferEngineServiceTests.cs index 49bf65b3..095f9448 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Services/SmtTransferEngineServiceTests.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/SmtTransferEngineServiceTests.cs @@ -22,8 +22,8 @@ public async Task CreateAsync() Assert.That(engine?.IsModelPersisted, Is.True); }); string engineDir = Path.Combine("translation_engines", EngineId2); - _ = env.SmtModelFactory.Received().InitNewAsync(engineDir); - _ = env.TransferEngineFactory.Received().InitNewAsync(engineDir); + env.SmtModelFactory.Received().InitNew(engineDir); + env.TransferEngineFactory.Received().InitNew(engineDir); } [TestCase(BuildJobRunnerType.Hangfire)] @@ -77,7 +77,7 @@ public async Task CancelBuildAsync_Building(BuildJobRunnerType trainJobRunnerTyp using var env = new TestEnvironment(trainJobRunnerType); env.UseInfiniteTrainJob(); - await env.Service.StartBuildAsync(EngineId1, BuildId1, "{}", Array.Empty()); + await env.Service.StartBuildAsync(EngineId1, BuildId1, buildOptions: "{}", corpora: []); await env.WaitForTrainingToStartAsync(); TranslationEngine engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Not.Null); @@ -159,7 +159,7 @@ public async Task TrainSegmentPairAsync(BuildJobRunnerType trainJobRunnerType) engine = env.Engines.Get(EngineId1); Assert.That(engine.CurrentBuild, Is.Null); Assert.That(engine.BuildRevision, Is.EqualTo(2)); - _ = env.SmtModel.Received(2).TrainSegmentAsync("esto es una prueba.", "this is a test.", true); + env.SmtModel.Received(2).TrainSegment("esto es una prueba.", "this is a test.", true); } [Test] @@ -169,7 +169,7 @@ public async Task CommitAsync_LoadedInactive() await env.Service.TrainSegmentPairAsync(EngineId1, "esto es una prueba.", "this is a test.", true); await Task.Delay(10); await env.CommitAsync(TimeSpan.Zero); - _ = env.SmtModel.Received().SaveAsync(); + env.SmtModel.Received().Save(); Assert.That(env.StateService.Get(EngineId1).IsLoaded, Is.False); } @@ -179,7 +179,7 @@ public async Task CommitAsync_LoadedActive() using var env = new TestEnvironment(); await env.Service.TrainSegmentPairAsync(EngineId1, "esto es una prueba.", "this is a test.", true); await env.CommitAsync(TimeSpan.FromHours(1)); - _ = env.SmtModel.Received().SaveAsync(); + env.SmtModel.Received().Save(); Assert.That(env.StateService.Get(EngineId1).IsLoaded, Is.True); } @@ -247,6 +247,7 @@ public TestEnvironment(BuildJobRunnerType trainJobRunnerType = BuildJobRunnerTyp _truecaserFactory = CreateTruecaserFactory(); _lockFactory = new DistributedReaderWriterLockFactory( new OptionsWrapper(new ServiceOptions { ServiceId = "host" }), + new OptionsWrapper(new DistributedReaderWriterLockOptions()), new MemoryRepository(), new ObjectIdGenerator() ); @@ -395,7 +396,8 @@ private SmtTransferEngineStateService CreateStateService() SmtModelFactory, TransferEngineFactory, _truecaserFactory, - options + options, + Substitute.For>() ); } @@ -439,70 +441,64 @@ private ISmtModelFactory CreateSmtModelFactory() }, [new Phrase(Range.Create(0, 5), 5)] ); + SmtModel.Translate(1, Arg.Any()).Returns([translationResult]); SmtModel - .TranslateAsync(1, Arg.Any()) - .Returns(Task.FromResult>([translationResult])); - SmtModel - .GetWordGraphAsync(Arg.Any()) + .GetWordGraph(Arg.Any()) .Returns( - Task.FromResult( - new WordGraph( - "esto es una prueba .".Split(), - new[] - { - new WordGraphArc( - 0, - 1, - 1.0, - "this is".Split(), - new WordAlignmentMatrix(2, 2) { [0, 0] = true, [1, 1] = true }, - Range.Create(0, 2), - GetSources(2, false), - [1.0, 1.0] - ), - new WordGraphArc( - 1, - 2, - 1.0, - "a test".Split(), - new WordAlignmentMatrix(2, 2) { [0, 0] = true, [1, 1] = true }, - Range.Create(2, 4), - GetSources(2, false), - [1.0, 1.0] - ), - new WordGraphArc( - 2, - 3, - 1.0, - ".".Split(), - new WordAlignmentMatrix(1, 1) { [0, 0] = true }, - Range.Create(4, 5), - GetSources(1, false), - [1.0] - ) - }, - [3] - ) + new WordGraph( + "esto es una prueba .".Split(), + new[] + { + new WordGraphArc( + 0, + 1, + 1.0, + "this is".Split(), + new WordAlignmentMatrix(2, 2) { [0, 0] = true, [1, 1] = true }, + Range.Create(0, 2), + GetSources(2, false), + [1.0, 1.0] + ), + new WordGraphArc( + 1, + 2, + 1.0, + "a test".Split(), + new WordAlignmentMatrix(2, 2) { [0, 0] = true, [1, 1] = true }, + Range.Create(2, 4), + GetSources(2, false), + [1.0, 1.0] + ), + new WordGraphArc( + 2, + 3, + 1.0, + ".".Split(), + new WordAlignmentMatrix(1, 1) { [0, 0] = true }, + Range.Create(4, 5), + GetSources(1, false), + [1.0] + ) + }, + [3] ) ); factory - .CreateAsync( + .Create( Arg.Any(), Arg.Any>(), Arg.Any>(), - Arg.Any(), - Arg.Any() + Arg.Any() ) - .Returns(Task.FromResult(SmtModel)); + .Returns(SmtModel); factory - .CreateTrainerAsync( + .CreateTrainer( Arg.Any(), Arg.Any>(), - Arg.Any(), - Arg.Any() + Arg.Any() ) - .Returns(Task.FromResult(SmtBatchTrainer)); + .Returns(SmtBatchTrainer); return factory; } @@ -511,57 +507,49 @@ private static ITransferEngineFactory CreateTransferEngineFactory() ITransferEngineFactory factory = Substitute.For(); ITranslationEngine engine = Substitute.For(); engine - .TranslateAsync(Arg.Any()) + .Translate(Arg.Any()) .Returns( - Task.FromResult( - new TranslationResult( - "this is a TEST.", - "esto es una prueba .".Split(), - "this is a TEST .".Split(), - [1.0, 1.0, 1.0, 1.0, 1.0], - [ - TranslationSources.Transfer, - TranslationSources.Transfer, - TranslationSources.Transfer, - TranslationSources.Transfer, - TranslationSources.Transfer - ], - new WordAlignmentMatrix(5, 5) - { - [0, 0] = true, - [1, 1] = true, - [2, 2] = true, - [3, 3] = true, - [4, 4] = true - }, - [new Phrase(Range.Create(0, 5), 5)] - ) + new TranslationResult( + "this is a TEST.", + "esto es una prueba .".Split(), + "this is a TEST .".Split(), + [1.0, 1.0, 1.0, 1.0, 1.0], + [ + TranslationSources.Transfer, + TranslationSources.Transfer, + TranslationSources.Transfer, + TranslationSources.Transfer, + TranslationSources.Transfer + ], + new WordAlignmentMatrix(5, 5) + { + [0, 0] = true, + [1, 1] = true, + [2, 2] = true, + [3, 3] = true, + [4, 4] = true + }, + [new Phrase(Range.Create(0, 5), 5)] ) ); factory - .CreateAsync( + .Create( Arg.Any(), Arg.Any>(), Arg.Any>(), - Arg.Any(), - Arg.Any() + Arg.Any() ) - .Returns(Task.FromResult(engine)); + .Returns(engine); return factory; } private ITruecaserFactory CreateTruecaserFactory() { ITruecaserFactory factory = Substitute.For(); - factory.CreateAsync(Arg.Any()).Returns(Task.FromResult(Truecaser)); + factory.Create(Arg.Any()).Returns(Truecaser); factory - .CreateTrainerAsync( - Arg.Any(), - Arg.Any>(), - Arg.Any(), - Arg.Any() - ) - .Returns(Task.FromResult(TruecaserTrainer)); + .CreateTrainer(Arg.Any(), Arg.Any>(), Arg.Any()) + .Returns(TruecaserTrainer); return factory; } @@ -631,23 +619,13 @@ private async Task RunTrainJob() await BuildJobService.BuildJobStartedAsync("engine1", "build1", _cancellationTokenSource.Token); string engineDir = Path.Combine("translation_engines", EngineId1); - await SmtModelFactory.InitNewAsync(engineDir, _cancellationTokenSource.Token); + SmtModelFactory.InitNew(engineDir); ITextCorpus sourceCorpus = new DictionaryTextCorpus(); ITextCorpus targetCorpus = new DictionaryTextCorpus(); IParallelTextCorpus parallelCorpus = sourceCorpus.AlignRows(targetCorpus); LatinWordTokenizer tokenizer = new(); - using ITrainer smtModelTrainer = await SmtModelFactory.CreateTrainerAsync( - engineDir, - tokenizer, - parallelCorpus, - _cancellationTokenSource.Token - ); - using ITrainer truecaseTrainer = await _truecaserFactory.CreateTrainerAsync( - engineDir, - tokenizer, - targetCorpus, - _cancellationTokenSource.Token - ); + using ITrainer smtModelTrainer = SmtModelFactory.CreateTrainer(engineDir, tokenizer, parallelCorpus); + using ITrainer truecaseTrainer = _truecaserFactory.CreateTrainer(engineDir, tokenizer, targetCorpus); await smtModelTrainer.TrainAsync(null, _cancellationTokenSource.Token); await truecaseTrainer.TrainAsync(cancellationToken: _cancellationTokenSource.Token); @@ -666,6 +644,7 @@ private async Task RunTrainJob() await BuildJobService.StartBuildJobAsync( BuildJobRunnerType.Hangfire, + TranslationEngineType.SmtTransfer, EngineId1, BuildId1, BuildStage.Postprocess, @@ -684,17 +663,18 @@ private class EnvActivator(TestEnvironment env) : JobActivator public override object ActivateJob(Type jobType) { - if (jobType == typeof(PreprocessBuildJob)) + if (jobType == typeof(SmtTransferPreprocessBuildJob)) { - return new PreprocessBuildJob( + return new SmtTransferPreprocessBuildJob( _env.PlatformService, _env.Engines, - _env._lockFactory, new MemoryDataAccessContext(), Substitute.For>(), _env.BuildJobService, _env.SharedFileService, - Substitute.For() + Substitute.For(), + _env._lockFactory, + _env.TrainSegmentPairs ) { TrainJobRunnerType = _env._trainJobRunnerType @@ -704,20 +684,19 @@ public override object ActivateJob(Type jobType) { var engineOptions = Substitute.For>(); engineOptions.CurrentValue.Returns(new SmtTransferEngineOptions()); - var buildOptions = Substitute.For>(); - buildOptions.CurrentValue.Returns(new BuildJobOptions()); + var buildJobOptions = Substitute.For>(); + buildJobOptions.CurrentValue.Returns(new BuildJobOptions()); return new SmtTransferPostprocessBuildJob( _env.PlatformService, _env.Engines, - _env._lockFactory, new MemoryDataAccessContext(), _env.BuildJobService, Substitute.For>(), _env.SharedFileService, + _env._lockFactory, _env.TrainSegmentPairs, _env.SmtModelFactory, _env._truecaserFactory, - buildOptions, engineOptions ); } @@ -726,7 +705,6 @@ public override object ActivateJob(Type jobType) return new SmtTransferTrainBuildJob( _env.PlatformService, _env.Engines, - _env._lockFactory, new MemoryDataAccessContext(), _env.BuildJobService, Substitute.For>(), diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Usings.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Usings.cs index 26908576..f58cb973 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Usings.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Usings.cs @@ -11,6 +11,7 @@ global using Microsoft.Extensions.Hosting.Internal; global using Microsoft.Extensions.Logging; global using Microsoft.Extensions.Options; +global using Nito.AsyncEx; global using NSubstitute; global using NSubstitute.ClearExtensions; global using NSubstitute.ExceptionExtensions; diff --git a/src/Serval/src/Serval.ApiServer/Serval.ApiServer.csproj b/src/Serval/src/Serval.ApiServer/Serval.ApiServer.csproj index d40fb933..2d8c5622 100644 --- a/src/Serval/src/Serval.ApiServer/Serval.ApiServer.csproj +++ b/src/Serval/src/Serval.ApiServer/Serval.ApiServer.csproj @@ -21,9 +21,9 @@ - - - + + + all diff --git a/src/Serval/src/Serval.ApiServer/Startup.cs b/src/Serval/src/Serval.ApiServer/Startup.cs index 27e142e2..0831e75c 100644 --- a/src/Serval/src/Serval.ApiServer/Startup.cs +++ b/src/Serval/src/Serval.ApiServer/Startup.cs @@ -10,6 +10,13 @@ public void ConfigureServices(IServiceCollection services) { services.AddFeatureManagement(); services.AddRouting(o => o.LowercaseUrls = true); + + var apiOptions = new ApiOptions(); + Configuration.GetSection(ApiOptions.Key).Bind(apiOptions); + services.AddRequestTimeouts(o => + { + o.DefaultPolicy = new RequestTimeoutPolicy { Timeout = apiOptions.DefaultHttpRequestTimeout }; + }); services.AddOutputCache(options => { options.DefaultExpirationTimeSpan = TimeSpan.FromSeconds(10); @@ -215,6 +222,7 @@ public void Configure(IApplicationBuilder app, IWebHostEnvironment env) app.UseAuthentication(); app.UseRouting(); + app.UseRequestTimeouts(); app.UseOutputCache(); app.UseAuthorization(); app.UseEndpoints(x => diff --git a/src/Serval/src/Serval.ApiServer/Usings.cs b/src/Serval/src/Serval.ApiServer/Usings.cs index 377cd261..8f4b6446 100644 --- a/src/Serval/src/Serval.ApiServer/Usings.cs +++ b/src/Serval/src/Serval.ApiServer/Usings.cs @@ -10,6 +10,7 @@ global using MassTransit.Mediator; global using Microsoft.AspNetCore.Authentication.JwtBearer; global using Microsoft.AspNetCore.Authorization; +global using Microsoft.AspNetCore.Http.Timeouts; global using Microsoft.AspNetCore.Mvc; global using Microsoft.AspNetCore.OutputCaching; global using Microsoft.Extensions.Diagnostics.HealthChecks; diff --git a/src/Serval/src/Serval.Assessment/Controllers/AssessmentEnginesController.cs b/src/Serval/src/Serval.Assessment/Controllers/AssessmentEnginesController.cs index 3a139a36..459d3b34 100644 --- a/src/Serval/src/Serval.Assessment/Controllers/AssessmentEnginesController.cs +++ b/src/Serval/src/Serval.Assessment/Controllers/AssessmentEnginesController.cs @@ -311,10 +311,10 @@ CancellationToken cancellationToken await AuthorizeAsync(id, cancellationToken); if (minRevision != null) { - EntityChange change = await TaskEx.Timeout( + (_, EntityChange change) = await TaskEx.Timeout( ct => _jobService.GetNewerRevisionAsync(jobId, minRevision.Value, ct), _apiOptions.CurrentValue.LongPollTimeout, - cancellationToken + cancellationToken: cancellationToken ); return change.Type switch { diff --git a/src/Serval/src/Serval.Assessment/Serval.Assessment.csproj b/src/Serval/src/Serval.Assessment/Serval.Assessment.csproj index 96391e79..81838382 100644 --- a/src/Serval/src/Serval.Assessment/Serval.Assessment.csproj +++ b/src/Serval/src/Serval.Assessment/Serval.Assessment.csproj @@ -19,6 +19,7 @@ + diff --git a/src/Serval/src/Serval.Assessment/Usings.cs b/src/Serval/src/Serval.Assessment/Usings.cs index 17020327..29d2b2f7 100644 --- a/src/Serval/src/Serval.Assessment/Usings.cs +++ b/src/Serval/src/Serval.Assessment/Usings.cs @@ -29,3 +29,4 @@ global using Serval.Shared.Utils; global using SIL.DataAccess; global using SIL.Scripture; +global using SIL.ServiceToolkit.Utils; diff --git a/src/Serval/src/Serval.Shared/Configuration/ApiOptions.cs b/src/Serval/src/Serval.Shared/Configuration/ApiOptions.cs index b238316c..9a998b72 100644 --- a/src/Serval/src/Serval.Shared/Configuration/ApiOptions.cs +++ b/src/Serval/src/Serval.Shared/Configuration/ApiOptions.cs @@ -4,5 +4,6 @@ public class ApiOptions { public const string Key = "Api"; - public TimeSpan LongPollTimeout { get; set; } = TimeSpan.FromSeconds(40); + public TimeSpan DefaultHttpRequestTimeout { get; set; } = TimeSpan.FromSeconds(58); // must be less than 60 seconds Cloudflare timeout + public TimeSpan LongPollTimeout { get; set; } = TimeSpan.FromSeconds(40); // must be less than DefaultHttpRequestTimeout } diff --git a/src/Serval/src/Serval.Shared/Controllers/OperationCancelledExceptionFilter.cs b/src/Serval/src/Serval.Shared/Controllers/OperationCancelledExceptionFilter.cs index 7fc82635..40b494d1 100644 --- a/src/Serval/src/Serval.Shared/Controllers/OperationCancelledExceptionFilter.cs +++ b/src/Serval/src/Serval.Shared/Controllers/OperationCancelledExceptionFilter.cs @@ -11,7 +11,11 @@ context.Exception is OperationCanceledException || context.Exception is RpcException rpcEx && rpcEx.StatusCode == StatusCode.Cancelled ) { - _logger.LogInformation("Request was cancelled"); + _logger.LogInformation( + "Request {RequestMethod}:{RequestPath} was cancelled", + context.HttpContext.Request.Method, + context.HttpContext.Request.Path + ); context.ExceptionHandled = true; context.Result = new StatusCodeResult(499); } diff --git a/src/Serval/src/Serval.Shared/Utils/TaskEx.cs b/src/Serval/src/Serval.Shared/Utils/TaskEx.cs deleted file mode 100644 index a9dd7cba..00000000 --- a/src/Serval/src/Serval.Shared/Utils/TaskEx.cs +++ /dev/null @@ -1,48 +0,0 @@ -namespace Serval.Shared.Utils; - -public static class TaskEx -{ - public static async Task Timeout( - Func> action, - TimeSpan timeout, - CancellationToken cancellationToken = default - ) - { - if (timeout == System.Threading.Timeout.InfiniteTimeSpan) - return await action(cancellationToken); - - var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - Task task = action(cts.Token); - var completedTask = await Task.WhenAny(task, Delay(timeout, cancellationToken)); - if (task != completedTask) - cts.Cancel(); - return await completedTask; - } - - public static async Task Timeout( - Func action, - TimeSpan timeout, - CancellationToken cancellationToken = default - ) - { - if (timeout == System.Threading.Timeout.InfiniteTimeSpan) - { - await action(cancellationToken); - } - else - { - var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); - Task task = action(cts.Token); - Task completedTask = await Task.WhenAny(task, Task.Delay(timeout, cancellationToken)); - if (task != completedTask) - cts.Cancel(); - await completedTask; - } - } - - public static async Task Delay(TimeSpan timeout, CancellationToken cancellationToken = default) - { - await Task.Delay(timeout, cancellationToken); - return default; - } -} diff --git a/src/Serval/src/Serval.Translation/Controllers/TranslationEnginesController.cs b/src/Serval/src/Serval.Translation/Controllers/TranslationEnginesController.cs index 6bbcd5f3..052d7ede 100644 --- a/src/Serval/src/Serval.Translation/Controllers/TranslationEnginesController.cs +++ b/src/Serval/src/Serval.Translation/Controllers/TranslationEnginesController.cs @@ -775,10 +775,10 @@ CancellationToken cancellationToken await AuthorizeAsync(id, cancellationToken); if (minRevision != null) { - EntityChange change = await TaskEx.Timeout( + (_, EntityChange change) = await TaskEx.Timeout( ct => _buildService.GetNewerRevisionAsync(buildId, minRevision.Value, ct), _apiOptions.CurrentValue.LongPollTimeout, - cancellationToken + cancellationToken: cancellationToken ); return change.Type switch { @@ -885,10 +885,10 @@ CancellationToken cancellationToken await AuthorizeAsync(id, cancellationToken); if (minRevision != null) { - EntityChange change = await TaskEx.Timeout( + (_, EntityChange change) = await TaskEx.Timeout( ct => _buildService.GetActiveNewerRevisionAsync(id, minRevision.Value, ct), _apiOptions.CurrentValue.LongPollTimeout, - cancellationToken + cancellationToken: cancellationToken ); return change.Type switch { diff --git a/src/Serval/src/Serval.Translation/Serval.Translation.csproj b/src/Serval/src/Serval.Translation/Serval.Translation.csproj index 96391e79..81838382 100644 --- a/src/Serval/src/Serval.Translation/Serval.Translation.csproj +++ b/src/Serval/src/Serval.Translation/Serval.Translation.csproj @@ -19,6 +19,7 @@ + diff --git a/src/Serval/src/Serval.Translation/Usings.cs b/src/Serval/src/Serval.Translation/Usings.cs index 77bb4439..1d3800f3 100644 --- a/src/Serval/src/Serval.Translation/Usings.cs +++ b/src/Serval/src/Serval.Translation/Usings.cs @@ -27,3 +27,4 @@ global using Serval.Translation.Models; global using Serval.Translation.Services; global using SIL.DataAccess; +global using SIL.ServiceToolkit.Utils; diff --git a/src/ServiceToolkit/src/SIL.ServiceToolkit/Utils/TaskEx.cs b/src/ServiceToolkit/src/SIL.ServiceToolkit/Utils/TaskEx.cs new file mode 100644 index 00000000..bfc73e21 --- /dev/null +++ b/src/ServiceToolkit/src/SIL.ServiceToolkit/Utils/TaskEx.cs @@ -0,0 +1,55 @@ +namespace SIL.ServiceToolkit.Utils; + +public static class TaskEx +{ + public static async Task<(bool, T?)> Timeout( + Func> action, + TimeSpan timeout, + CancellationToken cancellationToken = default + ) + { + if (timeout == System.Threading.Timeout.InfiniteTimeSpan) + return (true, await action(cancellationToken)); + + using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + Task task = action(cts.Token); + Task completedTask = await Task.WhenAny(task as Task, Delay(timeout, cancellationToken)); + T? result = await completedTask; + if (completedTask == task) + return (true, result); + + cts.Cancel(); + return (false, result); + } + + public static async Task Timeout( + Func action, + TimeSpan timeout, + CancellationToken cancellationToken = default + ) + { + if (timeout == System.Threading.Timeout.InfiniteTimeSpan) + { + await action(cancellationToken); + return true; + } + else + { + using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + Task task = action(cts.Token); + Task completedTask = await Task.WhenAny(task, Task.Delay(timeout, cancellationToken)); + await completedTask; + if (completedTask == task) + return true; + + cts.Cancel(); + return false; + } + } + + private static async Task Delay(TimeSpan timeout, CancellationToken cancellationToken = default) + { + await Task.Delay(timeout, cancellationToken); + return default; + } +}