diff --git a/torchx/runner/config.py b/torchx/runner/config.py index 53c6ccc9f..72da133f5 100644 --- a/torchx/runner/config.py +++ b/torchx/runner/config.py @@ -268,6 +268,13 @@ def dump( val = ";".join(opt.default) else: val = _NONE + elif opt.opt_type == Dict[str, str]: + # deal with empty or None default lists + if opt.default: + # pyre-ignore[6] opt.default type checked already as Dict[str, str] + val = ";".join([f"{k}={v}" for k, v in opt.default.items()]) + else: + val = _NONE else: val = f"{opt.default}" @@ -527,6 +534,11 @@ def load(scheduler: str, f: TextIO, cfg: Dict[str, CfgVal]) -> None: cfg[name] = config.getboolean(section, name) elif runopt.opt_type is List[str]: cfg[name] = value.split(";") + elif runopt.opt_type is Dict[str, str]: + cfg[name] = { + s.split(":", 1)[0]: s.split(":", 1)[1] + for s in value.replace(",", ";").split(";") + } else: # pyre-ignore[29] cfg[name] = runopt.opt_type(value) diff --git a/torchx/runner/test/config_test.py b/torchx/runner/test/config_test.py index c9730d9cf..fe86dc007 100644 --- a/torchx/runner/test/config_test.py +++ b/torchx/runner/test/config_test.py @@ -105,6 +105,18 @@ def _run_opts(self) -> runopts: default=None, help="a None list option", ) + opts.add( + "d", + type_=Dict[str, str], + default={"foo": "bar"}, + help="a dict option", + ) + opts.add( + "d_none", + type_=Dict[str, str], + default=None, + help="a None dict option", + ) opts.add( "empty", type_=str, diff --git a/torchx/schedulers/docker_scheduler.py b/torchx/schedulers/docker_scheduler.py index d7dfe9470..a6d75071a 100644 --- a/torchx/schedulers/docker_scheduler.py +++ b/torchx/schedulers/docker_scheduler.py @@ -123,6 +123,7 @@ def ensure_network(client: Optional["DockerClient"] = None) -> None: class DockerOpts(TypedDict, total=False): copy_env: Optional[List[str]] + env: Optional[Dict[str, str]] class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]): @@ -217,6 +218,10 @@ def _submit_dryrun(self, app: AppDef, cfg: DockerOpts) -> AppDryRunInfo[DockerJo for k in keys: default_env[k] = os.environ[k] + env = cfg.get("env") + if env: + default_env.update(env) + app_id = make_unique(app.name) req = DockerJob(app_id=app_id, containers=[]) rank0_name = f"{app_id}-{app.roles[0].name}-0" @@ -359,6 +364,12 @@ def _run_opts(self) -> runopts: default=None, help="list of glob patterns of environment variables to copy if not set in AppDef. Ex: FOO_*", ) + opts.add( + "env", + type_=Dict[str, str], + default=None, + help="environment varibles to be passed to the run (e.g. ENV1:v1,ENV2:v2,ENV3:v3)", + ) return opts def _get_app_state(self, container: "Container") -> AppState: diff --git a/torchx/schedulers/test/docker_scheduler_test.py b/torchx/schedulers/test/docker_scheduler_test.py index 929444fff..e69fa6d99 100644 --- a/torchx/schedulers/test/docker_scheduler_test.py +++ b/torchx/schedulers/test/docker_scheduler_test.py @@ -185,6 +185,21 @@ def test_copy_env(self) -> None: }, ) + def test_env(self) -> None: + app = _test_app() + cfg = DockerOpts({"env": {"FOO_1": "BAR_1"}}) + with patch("torchx.schedulers.docker_scheduler.make_unique") as make_unique_ctx: + make_unique_ctx.return_value = "app_name_42" + info = self.scheduler._submit_dryrun(app, cfg) + self.assertEqual( + info.request.containers[0].kwargs["environment"], + { + "FOO": "bar", + "FOO_1": "BAR_1", + "TORCHX_RANK0_HOST": "app_name_42-trainer-0", + }, + ) + if has_docker(): # These are the live tests that require a local docker instance. diff --git a/torchx/specs/api.py b/torchx/specs/api.py index fd8761585..505986716 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -639,11 +639,11 @@ def __init__(self, status: AppStatus, *args: object) -> None: self.status = status -# valid run cfg values; only support primitives (str, int, float, bool, List[str]) +# valid run cfg values; only support primitives (str, int, float, bool, List[str], Dict[str, str]) # TODO(wilsonhong): python 3.9+ supports list[T] in typing, which can be used directly # in isinstance(). Should replace with that. # see: https://docs.python.org/3/library/stdtypes.html#generic-alias-type -CfgVal = Union[str, int, float, bool, List[str], None] +CfgVal = Union[str, int, float, bool, List[str], Dict[str, str], None] T = TypeVar("T") @@ -757,6 +757,10 @@ def is_type(obj: CfgVal, tp: Type[CfgVal]) -> bool: except TypeError: if isinstance(obj, list): return all(isinstance(e, str) for e in obj) + elif isinstance(obj, dict): + return all( + isinstance(k, str) and isinstance(v, str) for k, v in obj.items() + ) else: return False @@ -865,6 +869,11 @@ def _cast_to_type(value: str, opt_type: Type[CfgVal]) -> CfgVal: # lists may be ; or , delimited # also deal with trailing "," by removing empty strings return [v for v in value.replace(";", ",").split(",") if v] + elif opt_type == Dict[str, str]: + return { + s.split(":", 1)[0]: s.split(":", 1)[1] + for s in value.replace(";", ",").split(",") + } else: # pyre-ignore[19] return opt_type(value) diff --git a/torchx/specs/test/api_test.py b/torchx/specs/test/api_test.py index 6a74fe421..af8e34b49 100644 --- a/torchx/specs/test/api_test.py +++ b/torchx/specs/test/api_test.py @@ -431,6 +431,7 @@ def test_cfg_from_str(self) -> None: opts = runopts() opts.add("K", type_=List[str], help="a list opt", default=[]) opts.add("J", type_=str, help="a str opt", required=True) + opts.add("E", type_=Dict[str, str], help="a dict opt", default=[]) self.assertDictEqual({}, opts.cfg_from_str("")) self.assertDictEqual({}, opts.cfg_from_str("UNKWN=b")) @@ -455,6 +456,9 @@ def test_cfg_from_str(self) -> None: self.assertDictEqual( {"K": ["a"], "J": "d"}, opts.cfg_from_str("J=d,K=a,UNKWN=e") ) + self.assertDictEqual( + {"E": {"f": "b", "F": "B"}}, opts.cfg_from_str("E=f:b,F:B") + ) def test_resolve_from_str(self) -> None: opts = runopts() @@ -489,6 +493,10 @@ def test_runopts_is_type(self) -> None: self.assertFalse(runopts.is_type(None, List[str])) self.assertTrue(runopts.is_type([], List[str])) self.assertTrue(runopts.is_type(["a", "b"], List[str])) + # List[str] + self.assertFalse(runopts.is_type(None, Dict[str, str])) + self.assertTrue(runopts.is_type({}, Dict[str, str])) + self.assertTrue(runopts.is_type({"foo": "bar", "fee": "bez"}, Dict[str, str])) def test_runopts_iter(self) -> None: runopts = self.get_runopts()