Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: multiproc mapper max threads and default numprocess #112

Merged
merged 2 commits into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ repos:
- id: check-ast
- id: check-case-conflict
- id: check-docstring-first
- repo: https://github.com/python-poetry/poetry
rev: "1.6"
hooks:
- id: poetry-check
903 changes: 432 additions & 471 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pynumaflow/mapper/_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Iterator, Sequence
from collections.abc import Iterator, Sequence, Awaitable
from dataclasses import dataclass
from datetime import datetime
from typing import TypeVar, Callable
Expand Down Expand Up @@ -163,3 +163,4 @@ def watermark(self) -> datetime:


MapCallable = Callable[[list[str], Datum], Messages]
MapAsyncCallable = Callable[[list[str], Datum], Awaitable[Messages]]
6 changes: 3 additions & 3 deletions pynumaflow/mapper/async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
MAP_SOCK_PATH,
)
from pynumaflow.mapper import Datum
from pynumaflow.mapper._dtypes import MapCallable
from pynumaflow.mapper._dtypes import MapAsyncCallable
from pynumaflow.mapper.proto import map_pb2
from pynumaflow.mapper.proto import map_pb2_grpc
from pynumaflow.types import NumaflowServicerContext
Expand Down Expand Up @@ -58,12 +58,12 @@ class AsyncMapper(map_pb2_grpc.MapServicer):

def __init__(
self,
handler: MapCallable,
handler: MapAsyncCallable,
sock_path=MAP_SOCK_PATH,
max_message_size=MAX_MESSAGE_SIZE,
max_threads=MAX_THREADS,
):
self.__map_handler: MapCallable = handler
self.__map_handler: MapAsyncCallable = handler
self.sock_path = f"unix://{sock_path}"
self._max_message_size = max_message_size
self._max_threads = max_threads
Expand Down
29 changes: 19 additions & 10 deletions pynumaflow/mapper/multiproc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,13 @@ class MultiProcMapper(map_pb2_grpc.MapServicer):
handler: Function callable following the type signature of MapCallable
sock_path: The TCP port to bind to
max_message_size: The max message size in bytes the server can receive and send
max_threads: The max number of threads to be spawned;
defaults to number of processors x4

Example invocation:
>>> from typing import Iterator
>>> from pynumaflow.mapper import Messages, Message, \
... Datum, MultiProcMapper
...
>>> def map_handler(key: [str], datum: Datum) -> Messages:
>>> def map_handler(keys: list[str], datum: Datum) -> Messages:
... val = datum.value
... _ = datum.event_time
... _ = datum.watermark
Expand All @@ -65,6 +63,15 @@ class MultiProcMapper(map_pb2_grpc.MapServicer):
>>> grpc_server.start()
"""

__slots__ = (
"__map_handler",
"_max_message_size",
"_server_options",
"_sock_path",
"_process_count",
"_threads_per_proc",
)

def __init__(
self,
handler: MapCallable,
Expand All @@ -81,10 +88,8 @@ def __init__(
("grpc.so_reuseaddr", 1),
]
self._sock_path = sock_path
self._process_count = int(
os.getenv("NUM_CPU_MULTIPROC") or os.getenv("NUMAFLOW_CPU_LIMIT", 1)
)
self._thread_concurrency = int(os.getenv("MAX_THREADS", 0)) or (self._process_count * 4)
self._process_count = int(os.getenv("NUM_CPU_MULTIPROC") or os.cpu_count())
self._threads_per_proc = int(os.getenv("MAX_THREADS", "4"))

def MapFn(
self, request: map_pb2.MapRequest, context: NumaflowServicerContext
Expand Down Expand Up @@ -127,12 +132,16 @@ def IsReady(
"""
return map_pb2.ReadyResponse(ready=True)

def _run_server(self, bind_address):
def _run_server(self, bind_address: str) -> None:
"""Start a server in a subprocess."""
_LOGGER.info("Starting new server.")
_LOGGER.info(
"Starting new server with num_procs: %s, num_threads/proc: %s",
self._process_count,
self._threads_per_proc,
)
server = grpc.server(
futures.ThreadPoolExecutor(
max_workers=self._thread_concurrency,
max_workers=self._threads_per_proc,
),
options=self._server_options,
)
Expand Down
4 changes: 2 additions & 2 deletions pynumaflow/reducer/_dtypes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from asyncio import Task
from collections.abc import Iterator, Sequence
from collections.abc import Iterator, Sequence, Awaitable
from dataclasses import dataclass
from datetime import datetime
from typing import TypeVar, Callable
Expand Down Expand Up @@ -232,4 +232,4 @@ def keys(self) -> list[str]:
return self._key


ReduceCallable = Callable[[list[str], AsyncIterable[Datum], Metadata], Messages]
ReduceCallable = Callable[[list[str], AsyncIterable[Datum], Metadata], Awaitable[Messages]]
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pynumaflow"
version = "0.5.1"
version = "0.5.2"
description = "Provides the interfaces of writing Python User Defined Functions and Sinks for NumaFlow."
authors = ["NumaFlow Developers"]
readme = "README.md"
Expand All @@ -26,7 +26,7 @@ grpcio-tools = "^1.48.1"
google-cloud = "^0.34.0"
google-api-core = "^2.11.0"
protobuf = ">=3.20,<5.0"
aiorun = "^2022.11.1"
aiorun = "^2023.7"

[tool.poetry.group.dev]
optional = true
Expand Down
7 changes: 2 additions & 5 deletions tests/map/test_async_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,8 @@ def startup_callable(loop):
loop.run_forever()


def NewAsyncMapper(
map_handler=async_map_handler,
):
def new_async_mapper():
udfs = AsyncMapper(handler=async_map_handler)

return udfs


Expand All @@ -88,7 +85,7 @@ def setUpClass(cls) -> None:
_loop = loop
_thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True)
_thread.start()
udfs = NewAsyncMapper()
udfs = new_async_mapper()
asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop)
while True:
try:
Expand Down
3 changes: 2 additions & 1 deletion tests/map/test_multiproc_mapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import unittest
from unittest import mock
from unittest.mock import patch, Mock

import grpc
from google.protobuf import empty_pb2 as _empty_pb2
Expand Down Expand Up @@ -34,7 +35,7 @@ def test_multiproc_init(self) -> None:
self.assertEqual(server._sock_path, 55551)
self.assertEqual(server._process_count, 3)

@mockenv(NUMAFLOW_CPU_LIMIT="4")
@patch("os.cpu_count", Mock(return_value=4))
def test_multiproc_process_count(self) -> None:
server = MultiProcMapper(handler=map_handler)
self.assertEqual(server._sock_path, 55551)
Expand Down