Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor fixes #210

Merged
merged 18 commits into from
Mar 15, 2024
12 changes: 8 additions & 4 deletions .github/workflows/integration_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ name: Integration tests

on:
workflow_call:
schedule:
# run this workflow every day at 2:00 AM UTC
- cron: '0 2 * * *'
inputs:
BRANCH:
type: string
description: Git branch to be used in run
default: main

jobs:
examples:
Expand Down Expand Up @@ -38,6 +40,8 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
with:
ref: ${{ inputs.BRANCH }}

- name: Start Goth
env:
Expand Down Expand Up @@ -88,7 +92,7 @@ jobs:
if: always()
uses: actions/upload-artifact@v4
with:
name: logs-example-${{ matrix.example_name }}
name: logs-example-${{ inputs.BRANCH }}-${{ matrix.example_name }}
path: |
/root/.local/share/ray_on_golem/webserver_debug.log
/root/.local/share/ray_on_golem/yagna.log
Expand Down
17 changes: 17 additions & 0 deletions .github/workflows/on_schedule.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
name: On schedule

on:
schedule:
# run this workflow every day at 2:00 AM UTC
- cron: '0 2 * * *'

jobs:
nightly_tests:
name: Nightly tests
strategy:
fail-fast: false
matrix:
branch: [main, develop]
uses: ./.github/workflows/integration_tests.yml
with:
BRANCH: ${{ matrix.branch }}
24 changes: 15 additions & 9 deletions ray_on_golem/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ray_on_golem.client.exceptions import RayOnGolemClientError, RayOnGolemClientValidationError
from ray_on_golem.server import models, settings
from ray_on_golem.server.models import CreateClusterResponseData
from ray_on_golem.server.models import BootstrapClusterResponseData

TResponseModel = TypeVar("TResponseModel")

Expand All @@ -22,15 +22,19 @@ def __init__(self, port: int) -> None:
self.base_url = URL("http://127.0.0.1").with_port(self.port)
self._session = requests.Session()

def create_cluster(
def bootstrap_cluster(
self,
cluster_config: Dict[str, Any],
) -> CreateClusterResponseData:
provider_config: Dict[str, Any],
cluster_name: str,
) -> BootstrapClusterResponseData:
return self._make_request(
url=settings.URL_CREATE_CLUSTER,
request_data=models.CreateClusterRequestData(**cluster_config),
response_model=models.CreateClusterResponseData,
error_message="Couldn't create cluster",
url=settings.URL_BOOTSTRAP_CLUSTER,
request_data=models.BootstrapClusterRequestData(
provider_config=provider_config,
cluster_name=cluster_name,
),
response_model=models.BootstrapClusterResponseData,
error_message="Couldn't bootstrap cluster",
)

def request_nodes(
Expand Down Expand Up @@ -222,7 +226,9 @@ def _make_request(
data=request_data.json() if request_data else None,
)
except requests.ConnectionError as e:
raise RayOnGolemClientError(f"{error_message or f'Connection failed: {url}'}: {e}")
raise RayOnGolemClientError(
"{}: {}".format(error_message or f"Connection failed: {url}", e)
)

if response.status_code != 200:
raise RayOnGolemClientError(
Expand Down
14 changes: 9 additions & 5 deletions ray_on_golem/provider/node_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,20 @@ def __init__(self, provider_config: Dict[str, Any], cluster_name: str):
provider_parameters = self._map_ssh_config(provider_parameters)
self._payment_network = provider_parameters["payment_network"].lower().strip()

cluster_creation_response = self._ray_on_golem_client.create_cluster(provider_parameters)
cluster_bootstrap_response = self._ray_on_golem_client.bootstrap_cluster(
provider_parameters, cluster_name
)

self._wallet_address = cluster_creation_response.wallet_address
self._is_cluster_just_created = cluster_creation_response.is_cluster_just_created
self._wallet_address = cluster_bootstrap_response.wallet_address
self._is_cluster_just_created = cluster_bootstrap_response.is_cluster_just_created

self._print_mainnet_onboarding_message(
cluster_creation_response.yagna_payment_status_output
cluster_bootstrap_response.yagna_payment_status_output
)

wallet_glm_amount = float(cluster_creation_response.yagna_payment_status.get("amount", "0"))
wallet_glm_amount = float(
cluster_bootstrap_response.yagna_payment_status.get("amount", "0")
)
if not wallet_glm_amount:
cli_logger.abort("You don't seem to have any GLM tokens on your Golem wallet.")

Expand Down
7 changes: 4 additions & 3 deletions ray_on_golem/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ray_on_golem.server.services import GolemService, RayService, YagnaService
from ray_on_golem.server.settings import (
DEFAULT_DATADIR,
RAY_ON_GOLEM_SHUTDOWN_TIMEOUT,
RAY_ON_GOLEM_SHUTDOWN_CONNECTIONS_TIMEOUT,
WEBSOCAT_PATH,
YAGNA_PATH,
get_datadir,
Expand Down Expand Up @@ -65,7 +65,7 @@ def main(port: int, self_shutdown: bool, registry_stats: bool, datadir: Path):
app,
port=app["port"],
print=None,
shutdown_timeout=RAY_ON_GOLEM_SHUTDOWN_TIMEOUT.total_seconds(),
shutdown_timeout=RAY_ON_GOLEM_SHUTDOWN_CONNECTIONS_TIMEOUT.total_seconds(),
)
except Exception:
logger.info("Server unexpectedly died, bye!")
Expand Down Expand Up @@ -124,7 +124,8 @@ async def startup_print(app: web.Application) -> None:
async def shutdown_print(app: web.Application) -> None:
print("") # explicit new line to console to visually better handle ^C
logger.info(
"Waiting up to `%s` for current connections to close...", RAY_ON_GOLEM_SHUTDOWN_TIMEOUT
"Waiting up to `%s` for current connections to close...",
RAY_ON_GOLEM_SHUTDOWN_CONNECTIONS_TIMEOUT,
)


Expand Down
7 changes: 4 additions & 3 deletions ray_on_golem/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,12 @@ class ProviderConfigData(BaseModel):
ssh_user: str


class CreateClusterRequestData(ProviderConfigData):
pass
class BootstrapClusterRequestData(BaseModel):
provider_config: ProviderConfigData
cluster_name: str


class CreateClusterResponseData(BaseModel):
class BootstrapClusterResponseData(BaseModel):
is_cluster_just_created: bool
wallet_address: str
yagna_payment_status_output: str
Expand Down
9 changes: 8 additions & 1 deletion ray_on_golem/server/services/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
self._datadir = datadir

self._provider_config: Optional[ProviderConfigData] = None
self._cluster_name: Optional[str] = None
self._wallet_address: Optional[str] = None

self._nodes: Dict[NodeId, Node] = {}
Expand All @@ -78,11 +79,17 @@ async def shutdown(self) -> None:
logger.info("Stopping RayService done")

async def create_cluster(
self, provider_config: ProviderConfigData
self, provider_config: ProviderConfigData, cluster_name: str
) -> Tuple[bool, str, str, Dict]:
is_cluster_just_created = self._provider_config is None

if not is_cluster_just_created and self._cluster_name != cluster_name:
raise RayServiceError(
f"Webserver is running only for `{self._cluster_name}` cluster, not for `{cluster_name}`!"
)

self._provider_config = provider_config
self._cluster_name = cluster_name

self._ssh_private_key_path = Path(provider_config.ssh_private_key)
self._ssh_public_key_path = self._ssh_private_key_path.with_suffix(".pub")
Expand Down
12 changes: 10 additions & 2 deletions ray_on_golem/server/services/yagna.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import logging
from asyncio.subprocess import Process
from datetime import datetime
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, Optional

Expand Down Expand Up @@ -197,6 +197,7 @@ async def run_payment_fund(self, network: str, driver: str) -> Dict:
"--driver",
driver,
"--json",
timeout=timedelta(seconds=30),
)
)

Expand Down Expand Up @@ -224,7 +225,14 @@ async def run_payment_fund(self, network: str, driver: str) -> Dict:

async def fetch_payment_status(self, network: str, driver: str) -> str:
output = await run_subprocess_output(
self._yagna_path, "payment", "status", "--network", network, "--driver", driver
self._yagna_path,
"payment",
"status",
"--network",
network,
"--driver",
driver,
timeout=timedelta(seconds=30),
)
return output.decode()

Expand Down
5 changes: 4 additions & 1 deletion ray_on_golem/server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
# how long a shutdown request will wait until the webserver shutdown is initiated
RAY_ON_GOLEM_SHUTDOWN_DELAY = timedelta(seconds=60)

# how long we wait for the webserver shutdown pending connection to complete
RAY_ON_GOLEM_SHUTDOWN_CONNECTIONS_TIMEOUT = timedelta(seconds=5)

# how long we wait for the webserver shutdown to complete
RAY_ON_GOLEM_SHUTDOWN_TIMEOUT = timedelta(seconds=60)

Expand All @@ -31,7 +34,7 @@
RAY_ON_GOLEM_PID_FILENAME = "ray_on_golem.pid"

URL_STATUS = "/"
URL_CREATE_CLUSTER = "/create_cluster"
URL_BOOTSTRAP_CLUSTER = "/bootstrap_cluster"
URL_NON_TERMINATED_NODES = "/non_terminated_nodes"
URL_IS_RUNNING = "/is_running"
URL_IS_TERMINATED = "/is_terminated"
Expand Down
28 changes: 17 additions & 11 deletions ray_on_golem/server/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ray_on_golem.server import models, settings
from ray_on_golem.server.models import ShutdownState
from ray_on_golem.server.services import RayService
from ray_on_golem.server.services.ray import RayServiceError
from ray_on_golem.utils import raise_graceful_exit
from ray_on_golem.version import get_version

Expand All @@ -18,7 +19,7 @@
def reject_if_shutting_down(func):
async def wrapper(request: web.Request) -> web.Response:
if request.app.get("shutting_down"):
return web.HTTPBadRequest(text="Action not allowed while server is shutting down!")
return web.HTTPBadRequest(reason="Action not allowed while server is shutting down!")

return await func(request)

Expand All @@ -43,21 +44,26 @@ async def status(request: web.Request) -> web.Response:
)


@routes.post(settings.URL_CREATE_CLUSTER)
async def create_cluster(request: web.Request) -> web.Response:
@routes.post(settings.URL_BOOTSTRAP_CLUSTER)
async def bootstrap_cluster(request: web.Request) -> web.Response:
ray_service: RayService = request.app["ray_service"]

request_data = models.CreateClusterRequestData.parse_raw(await request.text())
request_data = models.BootstrapClusterRequestData.parse_raw(await request.text())

(
is_cluster_just_created,
wallet_address,
yagna_payment_status_output,
yagna_payment_status,
) = await ray_service.create_cluster(provider_config=request_data)
try:
(
is_cluster_just_created,
wallet_address,
yagna_payment_status_output,
yagna_payment_status,
) = await ray_service.create_cluster(
request_data.provider_config, request_data.cluster_name
)
except RayServiceError as e:
raise web.HTTPBadRequest(reason=str(e))

return json_response(
models.CreateClusterResponseData(
models.BootstrapClusterResponseData(
is_cluster_just_created=is_cluster_just_created,
wallet_address=wallet_address,
yagna_payment_status_output=yagna_payment_status_output,
Expand Down
17 changes: 14 additions & 3 deletions ray_on_golem/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import os
from asyncio.subprocess import Process
from collections import deque
from datetime import timedelta
from pathlib import Path
from typing import Dict
from typing import Dict, Optional

from aiohttp.web_runner import GracefulExit

Expand All @@ -27,14 +28,24 @@ async def run_subprocess(
return process


async def run_subprocess_output(*args) -> bytes:
async def run_subprocess_output(*args, timeout: Optional[timedelta] = None) -> bytes:
process = await run_subprocess(
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)

stdout, stderr = await process.communicate()
try:
stdout, stderr = await asyncio.wait_for(
process.communicate(),
timeout.total_seconds() if timeout else None,
)
except asyncio.TimeoutError as e:
if process.returncode is None:
process.kill()
await process.wait()

raise RayOnGolemError(f"Process could not finish in timeout of {timeout}!") from e

if process.returncode != 0:
raise RayOnGolemError(
Expand Down