Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

infer env params - rfc #52

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions launch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
SyncEndpoint,
)
from launch.request_validation import validate_task_request
from launch.utils import trim_kwargs
from launch.utils import infer_env_params, trim_kwargs

DEFAULT_NETWORK_TIMEOUT_SEC = 120

Expand Down Expand Up @@ -198,9 +198,10 @@ def create_model_bundle_from_dirs(
model_bundle_name: str,
base_paths: List[str],
requirements_path: str,
env_params: Dict[str, str],
load_predict_fn_module_path: str,
load_model_fn_module_path: str,
env_params: Optional[Dict[str, str]],
env_selector: Optional[str],
app_config: Optional[Union[Dict[str, Any], str]] = None,
) -> ModelBundle:
"""
Expand Down Expand Up @@ -275,6 +276,9 @@ def create_model_bundle_from_dirs(
with open(requirements_path, "r", encoding="utf-8") as req_f:
requirements = req_f.read().splitlines()

if env_params is None:
env_params = infer_env_params(env_selector)

tmpdir = tempfile.mkdtemp()
try:
zip_path = os.path.join(tmpdir, "bundle.zip")
Expand Down Expand Up @@ -331,7 +335,8 @@ def create_model_bundle_from_dirs(
def create_model_bundle( # pylint: disable=too-many-statements
self,
model_bundle_name: str,
env_params: Dict[str, str],
env_params: Optional[Dict[str, str]],
env_selector: Optional[str],
*,
load_predict_fn: Optional[
Callable[[LaunchModel_T], Callable[[Any], Any]]
Expand Down Expand Up @@ -435,6 +440,9 @@ def create_model_bundle( # pylint: disable=too-many-statements
)
# TODO should we try to catch when people intentionally pass both model and load_model_fn as None?

if env_params is None:
env_params = infer_env_params(env_selector)

if requirements is None:
# TODO explore: does globals() actually work as expected? Should we use globals_copy instead?
requirements_inferred = find_packages_from_imports(globals())
Expand Down
44 changes: 43 additions & 1 deletion launch/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict
from typing import Any, Dict, Optional


def trim_kwargs(kwargs_dict: Dict[Any, Any]):
Expand All @@ -7,3 +7,45 @@ def trim_kwargs(kwargs_dict: Dict[Any, Any]):
"""
dict_copy = {k: v for k, v in kwargs_dict.items() if v is not None}
return dict_copy


def infer_env_params(env_selector: Optional[str]):
"""
Returns an env_params dict from the env_selector.

env_selector: str - Either "pytorch" or "tensorflow"
"""
if env_selector == "pytorch":
import torch

try:
ver = torch.__version__.split("+")
torch_version = ver[0]
cuda_version = ver[1][2:] if len(ver) > 1 else "113"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can also default to "latest" if parsing goes awry but i'm hesitant cuz upstream build issues are worse than client side errors

if (
len(cuda_version) < 3
): # we can only parse cuda versions in the double digits
raise ValueError(
"PyTorch version parsing does not support CUDA versions below 10.0"
)
tag = f"{torch_version}-cuda{cuda_version[:2]}.{cuda_version[2:]}-cudnn8-runtime"
syandroo marked this conversation as resolved.
Show resolved Hide resolved
return {
"framework_type": "pytorch",
"pytorch_image_tag": tag,
}
except Exception as e:
raise ValueError(
"Failed to parse correct PyTorch version, try setting your own env_params."
)
elif env_selector == "tensorflow":
import tensorflow as tf

ver = tf.__version__
return {
"framework_type": "tensorflow",
"tensorflow_version": ver,
}
else:
raise ValueError(
"Unsupported env_selector, please set to pytorch or tensorflow, or set your own env_params."
)