Skip to content

Commit

Permalink
Merge pull request #183 from m-novikov/m-novikov-patch-0
Browse files Browse the repository at this point in the history
Improve startup log
  • Loading branch information
constantinpape committed Jul 19, 2021
2 parents 2a7a0cd + 2c05cdd commit 7d9e835
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 4 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,15 @@ To install tiktorch and start server run:
conda create -n tiktorch-server-env -c ilastik-forge -c conda-forge -c pytorch tiktorch
conda activate tiktorch-server-env
```
To run server locally use
```
tiktorch-server
```
To be able to connect to remote machine use (this will bind to all available addresses)
```
tiktorch-server --addr 0.0.0.0
```

## Development environment

Expand Down
17 changes: 17 additions & 0 deletions tests/test_server/test_device_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
import torch.cuda
import torch.version

from tiktorch.server.device_pool import TorchDevicePool


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
def test_device_pool_with_cuda():
device_pool = TorchDevicePool()
assert device_pool.cuda_version == torch.version.cuda


@pytest.mark.skipif(torch.cuda.is_available(), reason="cuda is available")
def test_device_pool_without_cuda():
device_pool = TorchDevicePool()
assert device_pool.cuda_version is None
17 changes: 16 additions & 1 deletion tiktorch/server/device_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
import uuid
from collections import defaultdict
from typing import List
from typing import List, Optional

import torch

Expand Down Expand Up @@ -60,6 +60,14 @@ def devices(self) -> List[IDevice]:


class IDevicePool(abc.ABC):
@property
@abc.abstractmethod
def cuda_version(self) -> Optional[str]:
"""
Returns CUDA version if available
"""
...

@abc.abstractmethod
def list_devices(self) -> List[IDevice]:
"""
Expand Down Expand Up @@ -112,6 +120,13 @@ def __init__(self):
self.__device_ids_by_lease_id = defaultdict(list)
self.__lock = threading.Lock()

@property
def cuda_version(self) -> Optional[str]:
if torch.cuda.is_available():
return torch.version.cuda # type: ignore
else:
return None

def list_devices(self) -> List[IDevice]:
with self.__lock:
ids = ["cpu"]
Expand Down
22 changes: 20 additions & 2 deletions tiktorch/server/grpc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os
import threading
from concurrent import futures
from typing import Optional
Expand All @@ -7,14 +8,28 @@

from tiktorch.proto import data_store_pb2_grpc, inference_pb2_grpc
from tiktorch.server.data_store import DataStore
from tiktorch.server.device_pool import TorchDevicePool
from tiktorch.server.device_pool import IDevicePool, TorchDevicePool
from tiktorch.server.session_manager import SessionManager

from .data_store_servicer import DataStoreServicer
from .flight_control_servicer import FlightControlServicer
from .inference_servicer import InferenceServicer


def _print_available_devices(device_pool: IDevicePool) -> None:
cuda = device_pool.cuda_version
print()
print("CUDA version:", cuda or "not available")
for env_var, value in os.environ.items():
if env_var.startswith("CUDA_"):
print(env_var, value.strip() or "<empty>")

print()
print("Available devices:")
for device in device_pool.list_devices():
print(f" * {device.id}")


def serve(host, port, *, connection_file_path: Optional[str] = None, kill_timeout: Optional[float] = None):
"""
Starts grpc server on given host and port and writes connection details to json file
Expand All @@ -37,15 +52,18 @@ def serve(host, port, *, connection_file_path: Optional[str] = None, kill_timeou

data_store = DataStore()

inference_svc = InferenceServicer(TorchDevicePool(), SessionManager(), data_store)
device_pool = TorchDevicePool()
inference_svc = InferenceServicer(device_pool, SessionManager(), data_store)
fligh_svc = FlightControlServicer(done_evt=done_evt, kill_timeout=kill_timeout)
data_svc = DataStoreServicer(data_store)
_print_available_devices(device_pool)

inference_pb2_grpc.add_InferenceServicer_to_server(inference_svc, server)
inference_pb2_grpc.add_FlightControlServicer_to_server(fligh_svc, server)
data_store_pb2_grpc.add_DataStoreServicer_to_server(data_svc, server)

acquired_port = server.add_insecure_port(f"{host}:{port}")
print()
print(f"Starting server on {host}:{acquired_port}")
if connection_file_path:
print(f"Writing connection data to {connection_file_path}")
Expand Down

0 comments on commit 7d9e835

Please sign in to comment.