Skip to content

Commit

Permalink
Add cuda info to startup log
Browse files Browse the repository at this point in the history
  • Loading branch information
m-novikov committed Jul 17, 2021
1 parent 8d1a452 commit 6ea0ba5
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
17 changes: 17 additions & 0 deletions tests/test_server/test_device_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
import torch.cuda
import torch.version

from tiktorch.server.device_pool import TorchDevicePool


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
def test_device_pool_with_cuda():
device_pool = TorchDevicePool()
assert device_pool.cuda_version == torch.version.cuda


@pytest.mark.skipif(torch.cuda.is_available(), reason="cuda is avaible")
def test_device_pool_without_cuda():
device_pool = TorchDevicePool()
assert device_pool.cuda_version is None
17 changes: 16 additions & 1 deletion tiktorch/server/device_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
import uuid
from collections import defaultdict
from typing import List
from typing import List, Optional

import torch

Expand Down Expand Up @@ -60,6 +60,14 @@ def devices(self) -> List[IDevice]:


class IDevicePool(abc.ABC):
@property
@abc.abstractmethod
def cuda_version(self) -> Optional[str]:
"""
Returns CUDA version if available
"""
...

@abc.abstractmethod
def list_devices(self) -> List[IDevice]:
"""
Expand Down Expand Up @@ -112,6 +120,13 @@ def __init__(self):
self.__device_ids_by_lease_id = defaultdict(list)
self.__lock = threading.Lock()

@property
def cuda_version(self) -> Optional[str]:
if torch.cuda.is_available():
return torch.version.cuda # type: ignore
else:
return None

def list_devices(self) -> List[IDevice]:
with self.__lock:
ids = ["cpu"]
Expand Down
11 changes: 10 additions & 1 deletion tiktorch/server/grpc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os
import threading
from concurrent import futures
from typing import Optional
Expand All @@ -16,6 +17,13 @@


def _print_available_devices(device_pool: IDevicePool) -> None:
cuda = device_pool.cuda_version
print()
print("CUDA version:", cuda or "not available")
for env_var in ["CUDA_PATH", "CUDA_HOME", "CUDA_VISIBLE_DEVICES"]:
print(env_var, os.getenv(env_var, None))

print()
print("Available devices:")
for device in device_pool.list_devices():
print(f" * {device.id}")
Expand Down Expand Up @@ -47,19 +55,20 @@ def serve(host, port, *, connection_file_path: Optional[str] = None, kill_timeou
inference_svc = InferenceServicer(device_pool, SessionManager(), data_store)
fligh_svc = FlightControlServicer(done_evt=done_evt, kill_timeout=kill_timeout)
data_svc = DataStoreServicer(data_store)
_print_available_devices(device_pool)

inference_pb2_grpc.add_InferenceServicer_to_server(inference_svc, server)
inference_pb2_grpc.add_FlightControlServicer_to_server(fligh_svc, server)
data_store_pb2_grpc.add_DataStoreServicer_to_server(data_svc, server)

acquired_port = server.add_insecure_port(f"{host}:{port}")
print()
print(f"Starting server on {host}:{acquired_port}")
if connection_file_path:
print(f"Writing connection data to {connection_file_path}")
with open(connection_file_path, "w") as conn_file:
json.dump({"addr": host, "port": acquired_port}, conn_file)

_print_available_devices(device_pool)
server.start()

done_evt.wait()
Expand Down

0 comments on commit 6ea0ba5

Please sign in to comment.