diff --git a/tests/test_server/test_device_pool.py b/tests/test_server/test_device_pool.py
new file mode 100644
index 00000000..e44c1b64
--- /dev/null
+++ b/tests/test_server/test_device_pool.py
@@ -0,0 +1,19 @@
+import pytest
+import torch.cuda
+import torch.version
+
+from tiktorch.server.device_pool import TorchDevicePool
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
+def test_device_pool_with_cuda():
+    """With CUDA available, the pool's cuda_version equals torch.version.cuda."""
+    device_pool = TorchDevicePool()
+    assert device_pool.cuda_version == torch.version.cuda
+
+
+@pytest.mark.skipif(torch.cuda.is_available(), reason="cuda is available")
+def test_device_pool_without_cuda():
+    """Without CUDA, the pool's cuda_version is None."""
+    device_pool = TorchDevicePool()
+    assert device_pool.cuda_version is None
diff --git a/tiktorch/server/device_pool.py b/tiktorch/server/device_pool.py
index 0a3dad21..6487ce01 100644
--- a/tiktorch/server/device_pool.py
+++ b/tiktorch/server/device_pool.py
@@ -5,7 +5,7 @@
 import threading
 import uuid
 from collections import defaultdict
-from typing import List
+from typing import List, Optional
 
 import torch
 
@@ -60,6 +60,14 @@ def devices(self) -> List[IDevice]:
 
 
 class IDevicePool(abc.ABC):
+    @property
+    @abc.abstractmethod
+    def cuda_version(self) -> Optional[str]:
+        """
+        Returns CUDA version if available
+        """
+        ...
+
     @abc.abstractmethod
     def list_devices(self) -> List[IDevice]:
         """
@@ -112,6 +120,14 @@ def __init__(self):
         self.__device_ids_by_lease_id = defaultdict(list)
         self.__lock = threading.Lock()
 
+    @property
+    def cuda_version(self) -> Optional[str]:
+        """Return torch.version.cuda, or None when CUDA is not available."""
+        if torch.cuda.is_available():
+            return torch.version.cuda  # type: ignore
+        else:
+            return None
+
     def list_devices(self) -> List[IDevice]:
         with self.__lock:
             ids = ["cpu"]
diff --git a/tiktorch/server/grpc/__init__.py b/tiktorch/server/grpc/__init__.py
index f40190d5..cd1cad2c 100644
--- a/tiktorch/server/grpc/__init__.py
+++ b/tiktorch/server/grpc/__init__.py
@@ -1,4 +1,5 @@
 import json
+import os
 import threading
 from concurrent import futures
 from typing import Optional
@@ -16,6 +17,13 @@
 
 
 def _print_available_devices(device_pool: IDevicePool) -> None:
+    cuda = device_pool.cuda_version
+    print()
+    print("CUDA version:", cuda or "not available")
+    for env_var in ["CUDA_PATH", "CUDA_HOME", "CUDA_VISIBLE_DEVICES"]:
+        print(env_var, os.getenv(env_var, None))
+
+    print()
     print("Available devices:")
     for device in device_pool.list_devices():
         print(f" * {device.id}")
@@ -47,19 +55,20 @@ def serve(host, port, *, connection_file_path: Optional[str] = None, kill_timeou
     inference_svc = InferenceServicer(device_pool, SessionManager(), data_store)
     fligh_svc = FlightControlServicer(done_evt=done_evt, kill_timeout=kill_timeout)
     data_svc = DataStoreServicer(data_store)
+    _print_available_devices(device_pool)
 
     inference_pb2_grpc.add_InferenceServicer_to_server(inference_svc, server)
     inference_pb2_grpc.add_FlightControlServicer_to_server(fligh_svc, server)
     data_store_pb2_grpc.add_DataStoreServicer_to_server(data_svc, server)
 
     acquired_port = server.add_insecure_port(f"{host}:{port}")
+    print()
     print(f"Starting server on {host}:{acquired_port}")
     if connection_file_path:
         print(f"Writing connection data to {connection_file_path}")
         with open(connection_file_path, "w") as conn_file:
             json.dump({"addr": host, "port": acquired_port}, conn_file)
 
-    _print_available_devices(device_pool)
     server.start()
 
     done_evt.wait()