Skip to content

Commit

Permalink
schedulers/docker: Add support for setting environment variables
Browse files Browse the repository at this point in the history
  • Loading branch information
asuta274 committed Mar 23, 2024
1 parent a4ff1b1 commit bc49d4a
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 2 deletions.
12 changes: 12 additions & 0 deletions torchx/runner/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,13 @@ def dump(
val = ";".join(opt.default)
else:
val = _NONE
elif opt.opt_type == Dict[str, str]:
# deal with empty or None default dicts
if opt.default:
# pyre-ignore[6] opt.default type checked already as Dict[str, str]
val = ";".join([f"{k}={v}" for k, v in opt.default.items()])
else:
val = _NONE
else:
val = f"{opt.default}"

Expand Down Expand Up @@ -527,6 +534,11 @@ def load(scheduler: str, f: TextIO, cfg: Dict[str, CfgVal]) -> None:
cfg[name] = config.getboolean(section, name)
elif runopt.opt_type is List[str]:
cfg[name] = value.split(";")
elif runopt.opt_type is Dict[str, str]:
cfg[name] = {
s.split(":", 1)[0]: s.split(":", 1)[1]
for s in value.replace(",", ";").split(";")
}
else:
# pyre-ignore[29]
cfg[name] = runopt.opt_type(value)
12 changes: 12 additions & 0 deletions torchx/runner/test/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,18 @@ def _run_opts(self) -> runopts:
default=None,
help="a None list option",
)
opts.add(
"d",
type_=Dict[str, str],
default={"foo": "bar"},
help="a dict option",
)
opts.add(
"d_none",
type_=Dict[str, str],
default=None,
help="a None dict option",
)
opts.add(
"empty",
type_=str,
Expand Down
11 changes: 11 additions & 0 deletions torchx/schedulers/docker_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def ensure_network(client: Optional["DockerClient"] = None) -> None:

class DockerOpts(TypedDict, total=False):
# Run options accepted by DockerScheduler (it is parameterized as
# Scheduler[DockerOpts]). total=False means every key may be omitted.
#
# copy_env: glob patterns of environment variables to copy into the run
# when they are not already set in the AppDef (e.g. "FOO_*").
copy_env: Optional[List[str]]
# env: explicit name -> value environment variables. _submit_dryrun applies
# them with default_env.update(env) after the copy_env variables, so on a
# name collision the value given here wins.
env: Optional[Dict[str, str]]


class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
Expand Down Expand Up @@ -217,6 +218,10 @@ def _submit_dryrun(self, app: AppDef, cfg: DockerOpts) -> AppDryRunInfo[DockerJo
for k in keys:
default_env[k] = os.environ[k]

env = cfg.get("env")
if env:
default_env.update(env)

app_id = make_unique(app.name)
req = DockerJob(app_id=app_id, containers=[])
rank0_name = f"{app_id}-{app.roles[0].name}-0"
Expand Down Expand Up @@ -359,6 +364,12 @@ def _run_opts(self) -> runopts:
default=None,
help="list of glob patterns of environment variables to copy if not set in AppDef. Ex: FOO_*",
)
opts.add(
"env",
type_=Dict[str, str],
default=None,
help="environment varibles to be passed to the run (e.g. ENV1:v1,ENV2:v2,ENV3:v3)",
)
return opts

def _get_app_state(self, container: "Container") -> AppState:
Expand Down
15 changes: 15 additions & 0 deletions torchx/schedulers/test/docker_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,21 @@ def test_copy_env(self) -> None:
},
)

def test_env(self) -> None:
    """Dry-run with the ``env`` run option and check the container environment.

    The explicitly configured variable (``FOO_1``) must show up in the first
    container's ``environment`` kwarg next to the variables the dry-run
    already provides (presumably ``FOO`` comes from the test app's role
    definition -- confirm against ``_test_app``).
    """
    test_app = _test_app()
    opts = DockerOpts({"env": {"FOO_1": "BAR_1"}})

    # Pin the generated app id so the rank-0 host name is deterministic.
    with patch("torchx.schedulers.docker_scheduler.make_unique") as unique_mock:
        unique_mock.return_value = "app_name_42"
        dryrun_info = self.scheduler._submit_dryrun(test_app, opts)

    first_container = dryrun_info.request.containers[0]
    self.assertEqual(
        first_container.kwargs["environment"],
        {
            "FOO": "bar",
            "FOO_1": "BAR_1",
            "TORCHX_RANK0_HOST": "app_name_42-trainer-0",
        },
    )


if has_docker():
# These are the live tests that require a local docker instance.
Expand Down
13 changes: 11 additions & 2 deletions torchx/specs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,11 +639,11 @@ def __init__(self, status: AppStatus, *args: object) -> None:
self.status = status


# valid run cfg values; only support primitives (str, int, float, bool, List[str])
# valid run cfg values; only support primitives (str, int, float, bool, List[str], Dict[str, str])
# TODO(wilsonhong): python 3.9+ supports list[T] in typing, which can be used directly
# in isinstance(). Should replace with that.
# see: https://docs.python.org/3/library/stdtypes.html#generic-alias-type
CfgVal = Union[str, int, float, bool, List[str], None]
CfgVal = Union[str, int, float, bool, List[str], Dict[str, str], None]


T = TypeVar("T")
Expand Down Expand Up @@ -757,6 +757,10 @@ def is_type(obj: CfgVal, tp: Type[CfgVal]) -> bool:
except TypeError:
if isinstance(obj, list):
return all(isinstance(e, str) for e in obj)
elif isinstance(obj, dict):
return all(
isinstance(k, str) and isinstance(v, str) for k, v in obj.items()
)
else:
return False

Expand Down Expand Up @@ -865,6 +869,11 @@ def _cast_to_type(value: str, opt_type: Type[CfgVal]) -> CfgVal:
# lists may be ; or , delimited
# also deal with trailing "," by removing empty strings
return [v for v in value.replace(";", ",").split(",") if v]
elif opt_type == Dict[str, str]:
return {
s.split(":", 1)[0]: s.split(":", 1)[1]
for s in value.replace(";", ",").split(",")
}
else:
# pyre-ignore[19]
return opt_type(value)
Expand Down
8 changes: 8 additions & 0 deletions torchx/specs/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ def test_cfg_from_str(self) -> None:
opts = runopts()
opts.add("K", type_=List[str], help="a list opt", default=[])
opts.add("J", type_=str, help="a str opt", required=True)
opts.add("E", type_=Dict[str, str], help="a dict opt", default=[])

self.assertDictEqual({}, opts.cfg_from_str(""))
self.assertDictEqual({}, opts.cfg_from_str("UNKWN=b"))
Expand All @@ -455,6 +456,9 @@ def test_cfg_from_str(self) -> None:
self.assertDictEqual(
{"K": ["a"], "J": "d"}, opts.cfg_from_str("J=d,K=a,UNKWN=e")
)
self.assertDictEqual(
{"E": {"f": "b", "F": "B"}}, opts.cfg_from_str("E=f:b,F:B")
)

def test_resolve_from_str(self) -> None:
opts = runopts()
Expand Down Expand Up @@ -489,6 +493,10 @@ def test_runopts_is_type(self) -> None:
self.assertFalse(runopts.is_type(None, List[str]))
self.assertTrue(runopts.is_type([], List[str]))
self.assertTrue(runopts.is_type(["a", "b"], List[str]))
# Dict[str, str]
self.assertFalse(runopts.is_type(None, Dict[str, str]))
self.assertTrue(runopts.is_type({}, Dict[str, str]))
self.assertTrue(runopts.is_type({"foo": "bar", "fee": "bez"}, Dict[str, str]))

def test_runopts_iter(self) -> None:
runopts = self.get_runopts()
Expand Down

0 comments on commit bc49d4a

Please sign in to comment.