From 49ebb17997968e1a95ecbd758cbd4521d32e4bc8 Mon Sep 17 00:00:00 2001
From: Brendan Maguire <1093243+brendanmaguire@users.noreply.github.com>
Date: Fri, 30 Aug 2024 16:39:22 +0100
Subject: [PATCH] Add par_map methods to collections

* `par_map` facilitates running an async function across all items in the collection
* Resolves #213
---
 expression/collections/array.py | 30 +++++++++++++++++++++++++++++-
 expression/collections/block.py | 20 +++++++++++++++++++-
 expression/collections/map.py   | 22 +++++++++++++++++++++-
 expression/collections/seq.py   | 19 ++++++++++++++++++-
 tests/test_array.py             | 18 ++++++++++++++++++
 tests/test_block.py             | 19 +++++++++++++++++++
 tests/test_map.py               | 20 ++++++++++++++++++++
 tests/test_seq.py               | 19 +++++++++++++++++++
 8 files changed, 163 insertions(+), 4 deletions(-)

diff --git a/expression/collections/array.py b/expression/collections/array.py
index 19b0a16..05edd5a 100644
--- a/expression/collections/array.py
+++ b/expression/collections/array.py
@@ -10,9 +10,10 @@
 from __future__ import annotations
 
 import array
+import asyncio
 import builtins
 import functools
-from collections.abc import Callable, Iterable, Iterator, MutableSequence
+from collections.abc import Awaitable, Callable, Iterable, Iterator, MutableSequence
 from enum import Enum
 from typing import Any, TypeVar, cast
 
@@ -185,9 +186,36 @@ def __init__(
         self.typecode = typecode
 
     def map(self, mapping: Callable[[_TSource], _TResult]) -> TypedArray[_TResult]:
+        """Map array.
+
+        Builds a new array whose elements are the results of applying
+        the given function to each of the elements of the array.
+
+        Args:
+            mapping: A function to transform items from the input array.
+
+        Returns:
+            The resulting array.
+        """
         result = builtins.map(mapping, self.value)
         return TypedArray(result)
 
+    async def par_map(self, mapping: Callable[[_TSource], Awaitable[_TResult]]) -> TypedArray[_TResult]:
+        """Map array asynchronously.
+
+        Builds a new array whose elements are the awaited results of
+        applying the given asynchronous function to each element of the
+        array. The awaitables are run together via `asyncio.gather`.
+
+        Args:
+            mapping: An async function to transform items from the input array.
+
+        Returns:
+            The resulting array.
+        """
+        result = await asyncio.gather(*[mapping(item) for item in self])
+        return TypedArray(result)
+
     def choose(self, chooser: Callable[[_TSource], Option[_TResult]]) -> TypedArray[_TResult]:
         """Choose items from the list.
 
diff --git a/expression/collections/block.py b/expression/collections/block.py
index 715f385..557edf5 100644
--- a/expression/collections/block.py
+++ b/expression/collections/block.py
@@ -20,10 +20,11 @@
 
 from __future__ import annotations
 
+import asyncio
 import builtins
 import functools
 import itertools
-from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
+from collections.abc import Awaitable, Callable, Collection, Iterable, Iterator, Sequence
 from typing import TYPE_CHECKING, Any, Literal, TypeVar, get_args, overload
 
 from typing_extensions import TypeVarTuple, Unpack
@@ -239,6 +240,23 @@ def map(self, mapping: Callable[[_TSource], _TResult]) -> Block[_TResult]:
         """
         return Block((*builtins.map(mapping, self),))
 
+    async def par_map(self, mapping: Callable[[_TSource], Awaitable[_TResult]]) -> Block[_TResult]:
+        """Map list asynchronously.
+
+        Builds a new collection whose elements are the awaited results
+        of applying the given asynchronous function to each element of
+        the collection, run together via `asyncio.gather`.
+
+        Args:
+            mapping: The async function to transform elements from the
+                input list.
+
+        Returns:
+            The list of transformed elements.
+        """
+        result = await asyncio.gather(*[mapping(item) for item in self])
+        return Block(result)
+
     def starmap(self: Block[tuple[Unpack[_P]]], mapping: Callable[[Unpack[_P]], _TResult]) -> Block[_TResult]:
         """Starmap source sequence.
 
diff --git a/expression/collections/map.py b/expression/collections/map.py
index 90aabb7..d4f5534 100644
--- a/expression/collections/map.py
+++ b/expression/collections/map.py
@@ -16,7 +16,8 @@
 # - https://github.com/fsharp/fsharp/blob/master/src/fsharp/FSharp.Core/map.fs
 from __future__ import annotations
 
-from collections.abc import Callable, ItemsView, Iterable, Iterator, Mapping
+import asyncio
+from collections.abc import Awaitable, Callable, ItemsView, Iterable, Iterator, Mapping
 from typing import Any, TypeVar, cast
 
 from expression.core import Option, PipeMixin, SupportsLessThan, curry_flip, pipe
@@ -114,6 +115,25 @@ def map(self, mapping: Callable[[_Key, _Value], _Result]) -> Map[_Key, _Result]:
         """
         return Map(maptree.map(mapping, self._tree))
 
+    async def par_map(self, mapping: Callable[[_Key, _Value], Awaitable[_Result]]) -> Map[_Key, _Result]:
+        """Map the mapping asynchronously.
+
+        Builds a new collection whose elements are the results of
+        applying the given asynchronous function to each of the elements
+        of the collection. The key passed to the function indicates the
+        key of the element being transformed.
+
+        Args:
+            mapping: The function to transform the key/value pairs.
+
+        Returns:
+            The resulting map of keys and transformed values.
+        """
+        keys_and_values = tuple(self.to_seq())  # materialize once: iterated twice below
+        result = await asyncio.gather(*(mapping(key, value) for key, value in keys_and_values))
+        keys = [key for key, _ in keys_and_values]
+        return Map.of_seq(zip(keys, result))
+
     def partition(self, predicate: Callable[[_Key, _Value], bool]) -> tuple[Map[_Key, _Value], Map[_Key, _Value]]:
         r1, r2 = maptree.partition(predicate, self._tree)
         return Map(r1), Map(r2)
diff --git a/expression/collections/seq.py b/expression/collections/seq.py
index 415da1c..c939bbb 100644
--- a/expression/collections/seq.py
+++ b/expression/collections/seq.py
@@ -24,10 +24,11 @@
 
 from __future__ import annotations
 
+import asyncio
 import builtins
 import functools
 import itertools
-from collections.abc import Callable, Iterable, Iterator
+from collections.abc import Awaitable, Callable, Iterable, Iterator
 from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
 
 from expression.core import (
@@ -175,6 +176,22 @@ def map(self, mapper: Callable[[_TSource], _TResult]) -> Seq[_TResult]:
         """
         return Seq(pipe(self, map(mapper)))
 
+    async def par_map(self, mapper: Callable[[_TSource], Awaitable[_TResult]]) -> Seq[_TResult]:
+        """Map sequence asynchronously.
+
+        Builds a new collection whose elements are the results of
+        applying the given asynchronous function to each of the elements
+        of the collection.
+
+        Args:
+            mapper: A function to transform items from the input sequence.
+
+        Returns:
+            The result sequence.
+        """
+        result = await asyncio.gather(*[mapper(item) for item in self])
+        return Seq(result)
+
     @overload
     def starmap(self: Seq[tuple[_T1, _T2]], mapping: Callable[[_T1, _T2], _TResult]) -> Seq[_TResult]: ...
 
diff --git a/tests/test_array.py b/tests/test_array.py
index 7b7ee33..e162494 100644
--- a/tests/test_array.py
+++ b/tests/test_array.py
@@ -1,3 +1,4 @@
+import asyncio
 import functools
 from collections.abc import Callable
 from typing import Any
@@ -428,3 +429,20 @@ def test_array_monad_law_associativity_iterable(xs: list[int]):
 
     m = array.of_seq(xs)
     assert m.collect(f).collect(g) == m.collect(lambda x: f(x).collect(g))
+
+@pytest.mark.asyncio
+async def test_par_map():
+    async def async_fn(i: int):
+        await asyncio.sleep(0.1)
+        return i * 2
+
+    xs = TypedArray(range(1, 10))
+
+    start_time = asyncio.get_running_loop().time()
+    ys = await xs.par_map(async_fn)
+    end_time = asyncio.get_running_loop().time()
+
+    assert ys == TypedArray(i * 2 for i in range(1, 10))
+
+    time_taken = end_time - start_time
+    assert time_taken < 0.2, "par_map took too long"
diff --git a/tests/test_block.py b/tests/test_block.py
index 222bc4c..40a50e3 100644
--- a/tests/test_block.py
+++ b/tests/test_block.py
@@ -1,3 +1,4 @@
+import asyncio
 import functools
 from builtins import list as list
 from collections.abc import Callable
@@ -7,6 +8,7 @@
 from hypothesis import strategies as st
 from pydantic import BaseModel, Field, GetCoreSchemaHandler
 from pydantic_core import CoreSchema, core_schema
+import pytest
 
 from expression import Nothing, Option, Some, pipe
 from expression.collections import Block, block
@@ -458,3 +460,20 @@ def test_serialize_block_works():
     assert model_.annotated_type_empty == block.empty
     assert model_.custom_type == Block(["a", "b", "c"])
     assert model_.custom_type_empty == block.empty
+
+@pytest.mark.asyncio
+async def test_par_map():
+    async def async_fn(i: int):
+        await asyncio.sleep(0.1)
+        return i * 2
+
+    xs = Block(range(1, 10))
+
+    start_time = asyncio.get_running_loop().time()
+    ys = await xs.par_map(async_fn)
+    end_time = asyncio.get_running_loop().time()
+
+    assert ys == Block(i * 2 for i in range(1, 10))
+
+    time_taken = end_time - start_time
+    assert time_taken < 0.2, "par_map took too long"
diff --git a/tests/test_map.py b/tests/test_map.py
index 553a36f..8db8a31 100644
--- a/tests/test_map.py
+++ b/tests/test_map.py
@@ -1,5 +1,7 @@
+import asyncio
 from collections.abc import Callable, ItemsView, Iterable
 
+import pytest
 from hypothesis import given  # type: ignore
 from hypothesis import strategies as st
 
@@ -150,3 +152,21 @@ def test_expression_issue_105():
     m = m.add("1", 1).add("2", 2).add("3", 3).add("4", 4)
     m = m.change("2", lambda x: x)
     m = m.change("3", lambda x: x)
+
+
+@pytest.mark.asyncio
+async def test_par_map():
+    async def async_fn(key: str, value: int) -> int:
+        await asyncio.sleep(0.1)
+        return int(key) * value
+
+    xs = Map.of_seq((str(i), i) for i in range(1, 10))
+
+    start_time = asyncio.get_running_loop().time()
+    ys = await xs.par_map(async_fn)
+    end_time = asyncio.get_running_loop().time()
+
+    assert ys == Map.of_seq((str(i), i * i) for i in range(1, 10))
+
+    time_taken = end_time - start_time
+    assert time_taken < 0.2, "par_map took too long"
diff --git a/tests/test_seq.py b/tests/test_seq.py
index 7305020..318d55e 100644
--- a/tests/test_seq.py
+++ b/tests/test_seq.py
@@ -1,3 +1,4 @@
+import asyncio
 import functools
 from collections.abc import Callable, Iterable
 from itertools import accumulate
@@ -382,3 +383,21 @@ def test_seq_monad_law_associativity_empty(value: int):
     # Empty list
     m = empty
     assert list(m.collect(f).collect(g)) == list(m.collect(lambda x: f(x).collect(g)))
+
+
+@pytest.mark.asyncio
+async def test_par_map():
+    async def async_fn(i: int):
+        await asyncio.sleep(0.1)
+        return i * 2
+
+    xs = seq.of_iterable(range(1, 10))
+
+    start_time = asyncio.get_running_loop().time()
+    ys = await xs.par_map(async_fn)
+    end_time = asyncio.get_running_loop().time()
+
+    assert list(ys) == [i * 2 for i in range(1, 10)]
+
+    time_taken = end_time - start_time
+    assert time_taken < 0.2, "par_map took too long"