Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve startup log #183

Merged
merged 3 commits into from
Jul 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,15 @@ To install tiktorch and start server run:
conda create -n tiktorch-server-env -c ilastik-forge -c conda-forge -c pytorch tiktorch

conda activate tiktorch-server-env

```
To run the server locally, use
```
tiktorch-server
```
To be able to connect to the server from a remote machine, use the following (this will bind to all available addresses)
```
tiktorch-server --addr 0.0.0.0
```

## Development environment

Expand Down
17 changes: 17 additions & 0 deletions tests/test_server/test_device_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest
import torch.cuda
import torch.version

from tiktorch.server.device_pool import TorchDevicePool


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
def test_device_pool_with_cuda():
    """On a CUDA-capable host the pool reports torch's CUDA version string."""
    pool = TorchDevicePool()
    assert pool.cuda_version == torch.version.cuda


@pytest.mark.skipif(torch.cuda.is_available(), reason="cuda is available")
def test_device_pool_without_cuda():
    """Without CUDA the pool reports no version at all."""
    pool = TorchDevicePool()
    assert pool.cuda_version is None
17 changes: 16 additions & 1 deletion tiktorch/server/device_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
import uuid
from collections import defaultdict
from typing import List
from typing import List, Optional

import torch

Expand Down Expand Up @@ -60,6 +60,14 @@ def devices(self) -> List[IDevice]:


class IDevicePool(abc.ABC):
@property
@abc.abstractmethod
def cuda_version(self) -> Optional[str]:
    """
    Return the CUDA version string if CUDA is available, otherwise None.
    """
    ...

@abc.abstractmethod
def list_devices(self) -> List[IDevice]:
"""
Expand Down Expand Up @@ -112,6 +120,13 @@ def __init__(self):
self.__device_ids_by_lease_id = defaultdict(list)
self.__lock = threading.Lock()

@property
def cuda_version(self) -> Optional[str]:
if torch.cuda.is_available():
return torch.version.cuda # type: ignore
else:
return None

def list_devices(self) -> List[IDevice]:
with self.__lock:
ids = ["cpu"]
Expand Down
22 changes: 20 additions & 2 deletions tiktorch/server/grpc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import os
import threading
from concurrent import futures
from typing import Optional
Expand All @@ -7,14 +8,28 @@

from tiktorch.proto import data_store_pb2_grpc, inference_pb2_grpc
from tiktorch.server.data_store import DataStore
from tiktorch.server.device_pool import TorchDevicePool
from tiktorch.server.device_pool import IDevicePool, TorchDevicePool
from tiktorch.server.session_manager import SessionManager

from .data_store_servicer import DataStoreServicer
from .flight_control_servicer import FlightControlServicer
from .inference_servicer import InferenceServicer


def _print_available_devices(device_pool: IDevicePool) -> None:
    """Print the CUDA version, CUDA_* environment variables and the device list.

    Called once at server startup so operators can see at a glance what
    hardware the server detected.
    """
    version = device_pool.cuda_version
    print()
    print(f"CUDA version: {version or 'not available'}")
    # Surface every CUDA-related environment variable so a misconfigured
    # deployment is visible directly in the startup log.
    for name in os.environ:
        if name.startswith("CUDA_"):
            print(f"{name} {os.environ[name].strip() or '<empty>'}")

    print()
    print("Available devices:")
    for dev in device_pool.list_devices():
        print(f" * {dev.id}")


def serve(host, port, *, connection_file_path: Optional[str] = None, kill_timeout: Optional[float] = None):
"""
Starts grpc server on given host and port and writes connection details to json file
Expand All @@ -37,15 +52,18 @@ def serve(host, port, *, connection_file_path: Optional[str] = None, kill_timeou

data_store = DataStore()

inference_svc = InferenceServicer(TorchDevicePool(), SessionManager(), data_store)
device_pool = TorchDevicePool()
inference_svc = InferenceServicer(device_pool, SessionManager(), data_store)
fligh_svc = FlightControlServicer(done_evt=done_evt, kill_timeout=kill_timeout)
data_svc = DataStoreServicer(data_store)
_print_available_devices(device_pool)

inference_pb2_grpc.add_InferenceServicer_to_server(inference_svc, server)
inference_pb2_grpc.add_FlightControlServicer_to_server(fligh_svc, server)
data_store_pb2_grpc.add_DataStoreServicer_to_server(data_svc, server)

acquired_port = server.add_insecure_port(f"{host}:{port}")
print()
print(f"Starting server on {host}:{acquired_port}")
if connection_file_path:
print(f"Writing connection data to {connection_file_path}")
Expand Down