Skip to content

Commit

Permalink
Merge pull request #9 from rafa-be/add_pyproject_toml
Browse files Browse the repository at this point in the history
Add pyproject.toml and fixes Mypy errors.
  • Loading branch information
sharpener6 authored Aug 27, 2024
2 parents 28f880e + 05c9ae4 commit 2e9c4dd
Show file tree
Hide file tree
Showing 32 changed files with 459 additions and 453 deletions.
41 changes: 41 additions & 0 deletions .github/workflows/linter.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

name: Python Linter And Unittest

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]

permissions:
contents: read

jobs:
build:

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- name: Set up Python 3.8
uses: actions/setup-python@v5
with:
python-version: "3.8"
- name: Install dependencies
# TODO: add Scaler as a test dependency once it's published to Pypi
run: |
python -m pip install --upgrade pip
pip install flake8 pyproject-flake8 mypy
pip install -r requirements.txt
pip install pandas dask[distributed]
- name: Lint with flake8
run: |
pflake8 .
- name: Lint with MyPy
run: |
mypy .
- name: Run python unittest
run: |
python -m unittest discover -v tests
2 changes: 1 addition & 1 deletion parfun/about.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "6.0.5"
__version__ = "6.0.6"
12 changes: 6 additions & 6 deletions parfun/backend/dask.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
from contextlib import contextmanager
from threading import BoundedSemaphore
from typing import ContextManager, Optional
from typing import Generator, Optional

try:
from dask.distributed import Client, Future, LocalCluster, worker_client
Expand Down Expand Up @@ -30,15 +30,15 @@ def __enter__(self) -> "DaskSession":
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
return None

def submit(self, fn, *args, **kwargs) -> Optional[Future]:
def submit(self, fn, *args, **kwargs) -> Optional[ProfiledFuture]:
with profile() as submit_duration:
future = ProfiledFuture()

acquired = self._concurrent_task_guard.acquire()
if not acquired:
return None

with self._engine.executor() as executor:
with self._engine.executor() as executor: # type: ignore[var-annotated]
underlying_future = executor.submit(timed_function, fn, *args, **kwargs)

def on_done_callback(underlying_future: Future):
Expand Down Expand Up @@ -82,7 +82,7 @@ def session(self) -> DaskSession:

@abc.abstractmethod
@contextmanager
def executor(self) -> ContextManager[ClientExecutor]:
def executor(self) -> Generator[ClientExecutor, None, None]:
raise NotImplementedError

def allows_nested_tasks(self) -> bool:
Expand All @@ -101,7 +101,7 @@ def __init__(self, scheduler_address: str):
self._executor = self._client.get_executor()

@contextmanager
def executor(self) -> ContextManager[ClientExecutor]:
def executor(self) -> Generator[ClientExecutor, None, None]:
yield self._executor

def get_scheduler_address(self) -> str:
Expand Down Expand Up @@ -141,7 +141,7 @@ def __init__(self, n_workers: int) -> None:
super().__init__(n_workers)

@contextmanager
def executor(self) -> ContextManager[ClientExecutor]:
def executor(self) -> Generator[ClientExecutor, None, None]:
with worker_client() as client:
yield client.get_executor()

Expand Down
4 changes: 2 additions & 2 deletions parfun/backend/local_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class LocalMultiprocessingSession(BackendSession):

def __init__(self, underlying_executor: Executor):
self._underlying_executor = underlying_executor
self._concurrent_task_guard = BoundedSemaphore(underlying_executor._max_workers)
self._concurrent_task_guard = BoundedSemaphore(underlying_executor._max_workers) # type: ignore[attr-defined]

def __enter__(self) -> "LocalMultiprocessingSession":
return self
Expand Down Expand Up @@ -73,7 +73,7 @@ def on_done_callback(underlying_future: Future):
return future


@attrs.define
@attrs.define(init=False)
class LocalMultiprocessingBackend(BackendEngine):
"""
A concurrent engine that shares a similar interface to :py:class:`concurrent.futures.Executor`, but that blocks when
Expand Down
57 changes: 27 additions & 30 deletions parfun/backend/scaler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import inspect
from concurrent.futures import Future
from threading import BoundedSemaphore
from typing import Any, Optional, Set

try:
from scaler import Client, SchedulerClusterCombo
from scaler.client.future import ScalerFuture
from scaler.client.object_reference import ObjectReference
except ImportError:
raise ImportError("Scaler dependency missing. Use `pip install 'parfun[scaler]'` to install Scaler.")
Expand Down Expand Up @@ -34,7 +34,7 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
def preload_value(self, value: Any) -> ObjectReference:
return self.client.send_object(value)

def submit(self, fn, *args, **kwargs) -> Optional[Future]:
def submit(self, fn, *args, **kwargs) -> Optional[ProfiledFuture]:
with profile() as submit_duration:
future = ProfiledFuture()

Expand All @@ -44,7 +44,7 @@ def submit(self, fn, *args, **kwargs) -> Optional[Future]:

underlying_future = self.client.submit(fn, *args, **kwargs)

def on_done_callback(underlying_future: Future):
def on_done_callback(underlying_future: ScalerFuture):
assert submit_duration.value is not None

if underlying_future.cancelled():
Expand All @@ -56,17 +56,7 @@ def on_done_callback(underlying_future: Future):

if exception is None:
result = underlying_future.result()

# New for scaler>=1.5.0: task_duration is removed and replaced with profiling_info()

function_duration = int(
(
underlying_future.task_duration
if hasattr(underlying_future, "task_duration")
else underlying_future.profiling_info().duration_s
)
* 1_000_000_000
)
function_duration = int(underlying_future.profiling_info().cpu_time_s * 1_000_000_000)
else:
function_duration = 0
result = None
Expand Down Expand Up @@ -97,21 +87,28 @@ def __init__(
allows_nested_tasks: bool = True,
**client_kwargs,
):
self._scheduler_address = scheduler_address
self._n_workers = n_workers
self._allows_nested_tasks = allows_nested_tasks
self._client_kwargs = client_kwargs
self.__setstate__(
{
"scheduler_address": scheduler_address,
"n_workers": n_workers,
"allows_nested_tasks": allows_nested_tasks,
"client_kwargs": client_kwargs,
}
)

def __getstate__(self) -> dict:
return {
"scheduler_address": self._scheduler_address,
"n_workers": self._n_workers,
"allows_nested_tasks": self._allows_nested_tasks,
**self._client_kwargs,
"client_kwargs": self._client_kwargs,
}

def __setstate__(self, state: dict) -> None:
self.__init__(**state)
self._scheduler_address = state["scheduler_address"]
self._n_workers = state["n_workers"]
self._allows_nested_tasks = state["allows_nested_tasks"]
self._client_kwargs = state["client_kwargs"]

def session(self) -> ScalerSession:
return ScalerSession(self._scheduler_address, self._n_workers, **self._client_kwargs)
Expand Down Expand Up @@ -149,15 +146,6 @@ def __init__(
scheduler_port = get_available_tcp_port()
scheduler_address = f"tcp://127.0.0.1:{scheduler_port}"

scheduler_cluster_combo_kwargs = self.__get_constructor_arg_names(SchedulerClusterCombo)

self._cluster = SchedulerClusterCombo(
address=scheduler_address,
n_workers=n_workers,
per_worker_queue_size=per_worker_queue_size,
**{kwarg: value for kwarg, value in kwargs.items() if kwarg in scheduler_cluster_combo_kwargs},
)

client_kwargs = self.__get_constructor_arg_names(Client)

super().__init__(
Expand All @@ -167,8 +155,17 @@ def __init__(
**{kwarg: value for kwarg, value in kwargs.items() if kwarg in client_kwargs},
)

scheduler_cluster_combo_kwargs = self.__get_constructor_arg_names(SchedulerClusterCombo)

self._cluster = SchedulerClusterCombo(
address=scheduler_address,
n_workers=n_workers,
per_worker_queue_size=per_worker_queue_size,
**{kwarg: value for kwarg, value in kwargs.items() if kwarg in scheduler_cluster_combo_kwargs},
)

def __setstate__(self, state: dict) -> None:
super().__init__(**state)
super().__setstate__(state)
self._cluster = None # Unserialized instances have no cluster reference.

@property
Expand Down
2 changes: 1 addition & 1 deletion parfun/combine/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def concat_lists(values: Iterable[List[ListValue]]) -> List[ListValue]:
return list_concat(values)


def unzip(iterable: Iterable[Tuple]) -> Tuple[Iterable]:
def unzip(iterable: Iterable[Tuple]) -> Tuple[Iterable, ...]:
"""
Opposite of zip().
Expand Down
10 changes: 6 additions & 4 deletions parfun/decorators.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
"""
A decorator that helps users run their functions in parallel.
"""

import importlib
from functools import wraps
from typing import Callable, Iterable, Optional, Tuple, Union

from parfun.kernel.function_signature import NamedArguments
from parfun.kernel.parallel_function import ParallelFunction
from parfun.object import FunctionInputType, FunctionOutputType, PartitionType
from parfun.partition.object import PartitionFunction
from parfun.partition.object import PartitionFunction, PartitionGenerator
from parfun.partition_size_estimator.linear_regression_estimator import LinearRegessionEstimator
from parfun.partition_size_estimator.mixins import PartitionSizeEstimator


def parfun(
combine_with: Callable[[Iterable[FunctionOutputType]], FunctionOutputType],
split: Optional[PartitionFunction[PartitionType]] = None,
split: Optional[Callable[[NamedArguments], Tuple[NamedArguments, PartitionGenerator[NamedArguments]]]] = None,
partition_on: Optional[Union[str, Tuple[str, ...]]] = None,
partition_with: Optional[PartitionFunction[PartitionType]] = None,
initial_partition_size: Optional[Union[int, Callable[[PartitionType], int]]] = None,
fixed_partition_size: Optional[Union[int, Callable[[PartitionType], int]]] = None,
initial_partition_size: Optional[Union[int, Callable[[FunctionInputType], int]]] = None,
fixed_partition_size: Optional[Union[int, Callable[[FunctionInputType], int]]] = None,
profile: bool = False,
trace_export: Optional[str] = None,
partition_size_estimator_factory: Callable[[], PartitionSizeEstimator] = LinearRegessionEstimator,
Expand Down
4 changes: 2 additions & 2 deletions parfun/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
import logging
import os
from contextvars import ContextVar, Token
from typing import Optional, Union
from typing import Callable, Dict, Optional, Union

from parfun.backend.local_multiprocessing import LocalMultiprocessingBackend
from parfun.backend.local_single_process import LocalSingleProcessBackend
from parfun.backend.mixins import BackendEngine

_backend_engine: ContextVar[Optional[BackendEngine]] = ContextVar("_backend_engine", default=None)

BACKEND_REGISTRY = {
BACKEND_REGISTRY: Dict[str, Callable] = {
"none": lambda *_args, **_kwargs: None,
"local_single_process": LocalSingleProcessBackend,
"local_multiprocessing": LocalMultiprocessingBackend,
Expand Down
6 changes: 3 additions & 3 deletions parfun/functions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import collections
import logging
import time
from typing import Any, Callable, Iterable, Optional, Tuple
from typing import Any, Callable, Deque, Iterable, Optional, Tuple

from parfun.backend.mixins import BackendSession
from parfun.backend.mixins import BackendSession, ProfiledFuture
from parfun.entry_point import get_parallel_backend
from parfun.profiler.object import TraceTime

Expand All @@ -28,7 +28,7 @@ def parallel_timed_map(
# Uses a generator function, so that we can use deque.pop() and thus discard the no longer required futures'
# references as we yield them.
def result_generator(backend_session: BackendSession):
futures = collections.deque()
futures: Deque[ProfiledFuture] = collections.deque()

try:
for args in zip(*iterables):
Expand Down
Loading

0 comments on commit 2e9c4dd

Please sign in to comment.