diff --git a/README.md b/README.md index b3ab13a2..366bc4bd 100644 --- a/README.md +++ b/README.md @@ -26,9 +26,15 @@ To install tiktorch and start server run: conda create -n tiktorch-server-env -c ilastik-forge -c conda-forge -c pytorch tiktorch conda activate tiktorch-server-env - +``` +To run the server locally, use: +``` tiktorch-server ``` +To make the server reachable from a remote machine, use (this binds the server to all available addresses): +``` +tiktorch-server --addr 0.0.0.0 +``` ## Development environment diff --git a/tests/test_server/test_device_pool.py b/tests/test_server/test_device_pool.py new file mode 100644 index 00000000..9822b269 --- /dev/null +++ b/tests/test_server/test_device_pool.py @@ -0,0 +1,17 @@ +import pytest +import torch.cuda +import torch.version + +from tiktorch.server.device_pool import TorchDevicePool + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available") +def test_device_pool_with_cuda(): + device_pool = TorchDevicePool() + assert device_pool.cuda_version == torch.version.cuda + + +@pytest.mark.skipif(torch.cuda.is_available(), reason="cuda is available") +def test_device_pool_without_cuda(): + device_pool = TorchDevicePool() + assert device_pool.cuda_version is None diff --git a/tiktorch/server/device_pool.py b/tiktorch/server/device_pool.py index 0a3dad21..6487ce01 100644 --- a/tiktorch/server/device_pool.py +++ b/tiktorch/server/device_pool.py @@ -5,7 +5,7 @@ import threading import uuid from collections import defaultdict -from typing import List +from typing import List, Optional import torch @@ -60,6 +60,14 @@ def devices(self) -> List[IDevice]: class IDevicePool(abc.ABC): + @property + @abc.abstractmethod + def cuda_version(self) -> Optional[str]: + """ + Returns CUDA version if available + """ + ... 
+ @abc.abstractmethod def list_devices(self) -> List[IDevice]: """ @@ -112,6 +120,13 @@ def __init__(self): self.__device_ids_by_lease_id = defaultdict(list) self.__lock = threading.Lock() + @property + def cuda_version(self) -> Optional[str]: + if torch.cuda.is_available(): + return torch.version.cuda # type: ignore + else: + return None + def list_devices(self) -> List[IDevice]: with self.__lock: ids = ["cpu"] diff --git a/tiktorch/server/grpc/__init__.py b/tiktorch/server/grpc/__init__.py index 2af48293..a2132a51 100644 --- a/tiktorch/server/grpc/__init__.py +++ b/tiktorch/server/grpc/__init__.py @@ -1,4 +1,5 @@ import json +import os import threading from concurrent import futures from typing import Optional @@ -7,7 +8,7 @@ from tiktorch.proto import data_store_pb2_grpc, inference_pb2_grpc from tiktorch.server.data_store import DataStore -from tiktorch.server.device_pool import TorchDevicePool +from tiktorch.server.device_pool import IDevicePool, TorchDevicePool from tiktorch.server.session_manager import SessionManager from .data_store_servicer import DataStoreServicer @@ -15,6 +16,20 @@ from .inference_servicer import InferenceServicer +def _print_available_devices(device_pool: IDevicePool) -> None: + cuda = device_pool.cuda_version + print() + print("CUDA version:", cuda or "not available") + for env_var, value in os.environ.items(): + if env_var.startswith("CUDA_"): + print(env_var, value.strip() or "") + + print() + print("Available devices:") + for device in device_pool.list_devices(): + print(f" * {device.id}") + + def serve(host, port, *, connection_file_path: Optional[str] = None, kill_timeout: Optional[float] = None): """ Starts grpc server on given host and port and writes connection details to json file @@ -37,15 +52,18 @@ def serve(host, port, *, connection_file_path: Optional[str] = None, kill_timeou data_store = DataStore() - inference_svc = InferenceServicer(TorchDevicePool(), SessionManager(), data_store) + device_pool = TorchDevicePool() + 
inference_svc = InferenceServicer(device_pool, SessionManager(), data_store) fligh_svc = FlightControlServicer(done_evt=done_evt, kill_timeout=kill_timeout) data_svc = DataStoreServicer(data_store) + _print_available_devices(device_pool) inference_pb2_grpc.add_InferenceServicer_to_server(inference_svc, server) inference_pb2_grpc.add_FlightControlServicer_to_server(fligh_svc, server) data_store_pb2_grpc.add_DataStoreServicer_to_server(data_svc, server) acquired_port = server.add_insecure_port(f"{host}:{port}") + print() print(f"Starting server on {host}:{acquired_port}") if connection_file_path: print(f"Writing connection data to {connection_file_path}")