Skip to content

Commit

Permalink
Add new mila login command
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Apr 11, 2024
1 parent 20e990f commit ddbc25e
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 5 deletions.
12 changes: 12 additions & 0 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,16 @@
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from collections.abc import Sequence
from logging import getLogger as get_logger
from pathlib import Path
from typing import Any
from urllib.parse import urlencode

import questionary as qn
import rich.logging
from typing_extensions import TypedDict

from milatools.cli.login import login
from milatools.utils.remote_v2 import SSH_CONFIG_FILE
from milatools.utils.vscode_utils import (
sync_vscode_extensions_with_hostnames,
)
Expand Down Expand Up @@ -150,6 +153,15 @@ def mila():

init_parser.set_defaults(function=init)

# ----- mila login ------
login_parser = subparsers.add_parser(
"login",
help="Sets up reusable SSH connections to the entries of the SSH config.",
formatter_class=SortingHelpFormatter,
)
login_parser.add_argument("--ssh_config_path", type=Path, default=SSH_CONFIG_FILE)
login_parser.set_defaults(function=login)

# ----- mila forward ------

forward_parser = subparsers.add_parser(
Expand Down
47 changes: 47 additions & 0 deletions milatools/cli/login.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

import asyncio
from pathlib import Path

from paramiko import SSHConfig

from milatools.cli import console
from milatools.utils.remote_v2 import SSH_CONFIG_FILE, RemoteV2


async def login(
ssh_config_path: Path = SSH_CONFIG_FILE,
) -> list[RemoteV2]:
"""Logs in and sets up reusable SSH connections to all the hosts in the SSH config.
Returns the list of remotes where the connection was successfully established.
"""
ssh_config = SSHConfig.from_path(str(ssh_config_path.expanduser()))
potential_clusters = [
host
for host in ssh_config.get_hostnames()
if not any(c in host for c in ["*", "?", "!"])
]
# take out entries like `mila-cpu` that have a proxy and remote command.
potential_clusters = [
hostname
for hostname in potential_clusters
if not (
(config := ssh_config.lookup(hostname)).get("proxycommand")
and config.get("remotecommand")
)
]
remotes = await asyncio.gather(
*(
RemoteV2.connect(hostname, ssh_config_path=ssh_config_path)
for hostname in potential_clusters
),
return_exceptions=True,
)
remotes = [remote for remote in remotes if isinstance(remote, RemoteV2)]
console.log(f"Successfully connected to {[remote.hostname for remote in remotes]}")
return remotes


if __name__ == "__main__":
asyncio.run(login())
6 changes: 4 additions & 2 deletions tests/cli/test_commands/test_help_mila_.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
usage: mila [-h] [--version] [-v]
{docs,intranet,init,forward,code,sync,serve} ...
{docs,intranet,init,login,forward,code,sync,serve} ...

Tools to connect to and interact with the Mila cluster. Cluster documentation:
https://docs.mila.quebec/

positional arguments:
{docs,intranet,init,forward,code,sync,serve}
{docs,intranet,init,login,forward,code,sync,serve}
docs Open the Mila cluster documentation.
intranet Open the Mila intranet in a browser.
init Set up your configuration and credentials.
login Sets up reusable SSH connections to the entries of the
SSH config.
forward Forward a port on a compute node to your local
machine.
code Open a remote VSCode session on a compute node.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
usage: mila [-h] [--version] [-v]
{docs,intranet,init,forward,code,sync,serve} ...
{docs,intranet,init,login,forward,code,sync,serve} ...
mila: error: the following arguments are required: <command>
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
usage: mila [-h] [--version] [-v]
{docs,intranet,init,forward,code,sync,serve} ...
mila: error: argument <command>: invalid choice: 'search' (choose from 'docs', 'intranet', 'init', 'forward', 'code', 'sync', 'serve')
{docs,intranet,init,login,forward,code,sync,serve} ...
mila: error: argument <command>: invalid choice: 'search' (choose from 'docs', 'intranet', 'init', 'login', 'forward', 'code', 'sync', 'serve')
38 changes: 38 additions & 0 deletions tests/cli/test_login.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import textwrap
from logging import getLogger as get_logger
from pathlib import Path

import pytest

from milatools.cli.login import login
from milatools.utils.remote_v2 import SSH_CACHE_DIR, RemoteV2

from .common import requires_ssh_to_localhost

logger = get_logger(__name__)


@requires_ssh_to_localhost
@pytest.mark.asyncio
async def test_login(tmp_path: Path): # ssh_config_file: Path):
assert SSH_CACHE_DIR.exists()
ssh_config_path = tmp_path / "ssh_config"
ssh_config_path.write_text(
textwrap.dedent(
"""\
Host foo
hostname localhost
Host bar
hostname localhost
"""
)
+ "\n"
)

# Should create a connection to every host in the ssh config file.
remotes = await login(ssh_config_path=ssh_config_path)
assert all(isinstance(remote, RemoteV2) for remote in remotes)
assert set(remote.hostname for remote in remotes) == {"foo", "bar"}
for remote in remotes:
logger.info(f"Removing control socket at {remote.control_path}")
remote.control_path.unlink()

0 comments on commit ddbc25e

Please sign in to comment.