diff --git a/src/TaskExtras.cs b/src/TaskExtras.cs index 5fbe4fb..d406080 100644 --- a/src/TaskExtras.cs +++ b/src/TaskExtras.cs @@ -1,4 +1,6 @@ using System; +using System.Collections.Generic; +using System.Linq; using System.Threading.Tasks; namespace RLC.TaskChaining; @@ -187,6 +189,37 @@ private static Task DoRetry( )); } + /// + /// Partitions a collection of into faulted and fulfilled lists. + /// + /// The underlying type of the tasks in . + /// The collection of tasks to partition. + /// A tuple of the partitioned tasks. + public static Task<(IEnumerable> Faulted, IEnumerable> Fulfilled)> Partition( + IEnumerable> tasks + ) + { + return tasks.Aggregate, Task<(IEnumerable> Faulted, IEnumerable> Fulfilled)>>( + Task.FromResult(((IEnumerable>)new List>(), (IEnumerable>)new List>())), + ZipTasksWith> Faulted, IEnumerable> Fulfilled), (IEnumerable> Faulted, IEnumerable> Fulfilled), Exception> + ( + (values, v) => (values.Faulted, values.Fulfilled.Append(Task.FromResult(v)).ToList()), + (values, v) => (values.Faulted.Append>(Task.FromException(v)).ToList(), values.Fulfilled) + ) + ); + } + + private static Func, Task, Task> ZipTasksWith( + Func f, + Func g + ) where TException : Exception + { + return (b, a) => a.Then( + valueA => b.Then(valueB => f(valueB, valueA)), + error => b.Then(valueB => g(valueB, (TException)error)) + ); + } + /// /// A function that performs retries of the if it fails. /// diff --git a/tests/unit/TaskExtrasPartitionTests.cs b/tests/unit/TaskExtrasPartitionTests.cs new file mode 100644 index 0000000..5a60c30 --- /dev/null +++ b/tests/unit/TaskExtrasPartitionTests.cs @@ -0,0 +1,32 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; + +using RLC.TaskChaining; +using Xunit; + +namespace RLC.TaskChainingTests; + +public class PartitionTests +{ + [Fact] + public async Task ItShouldPartition() + { + int expectedFulfills = 4; + int expectedFaults = 1; + List> tasks = new() + { + Task.FromResult("abc"), + Task.FromResult("def"), + Task.FromException(new InvalidOperationException()), + Task.FromResult("ghi"), + TaskExtras.Defer(() => Task.FromResult("jkl"), TimeSpan.FromSeconds(1)) + }; + + (IEnumerable> Faulted, IEnumerable> Fulfilled) partition = await TaskExtras.Partition(tasks); + + Assert.Equal(expectedFaults, partition.Faulted.Count()); + Assert.Equal(expectedFulfills, partition.Fulfilled.Count()); + } +}