Skip to content

Commit

Permalink
schedulers/docker: Add support for setting environment variables (#855)
Browse files Browse the repository at this point in the history
* schedulers/docker: Add support for setting environment variables

* fix lint

* improve doc

* add more test

---------

Co-authored-by: Viet Anh To <[email protected]>
  • Loading branch information
asuta274 and Viet Anh To authored Mar 27, 2024
1 parent 3edf94a commit e9a6957
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 6 deletions.
12 changes: 12 additions & 0 deletions torchx/runner/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,13 @@ def dump(
val = ";".join(opt.default)
else:
val = _NONE
elif opt.opt_type == Dict[str, str]:
# deal with empty or None default lists
if opt.default:
# pyre-ignore[16] opt.default type checked already as Dict[str, str]
val = ";".join([f"{k}:{v}" for k, v in opt.default.items()])
else:
val = _NONE
else:
val = f"{opt.default}"

Expand Down Expand Up @@ -527,6 +534,11 @@ def load(scheduler: str, f: TextIO, cfg: Dict[str, CfgVal]) -> None:
cfg[name] = config.getboolean(section, name)
elif runopt.opt_type is List[str]:
cfg[name] = value.split(";")
elif runopt.opt_type is Dict[str, str]:
cfg[name] = {
s.split(":", 1)[0]: s.split(":", 1)[1]
for s in value.replace(",", ";").split(";")
}
else:
# pyre-ignore[29]
cfg[name] = runopt.opt_type(value)
16 changes: 16 additions & 0 deletions torchx/runner/test/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,18 @@ def _run_opts(self) -> runopts:
default=None,
help="a None list option",
)
opts.add(
"d",
type_=Dict[str, str],
default={"foo": "bar"},
help="a dict option",
)
opts.add(
"d_none",
type_=Dict[str, str],
default=None,
help="a None dict option",
)
opts.add(
"empty",
type_=str,
Expand All @@ -131,6 +143,8 @@ def _run_opts(self) -> runopts:
s = team_default
i = 50
f = 1.2
d = a:b,c:d
d_none= x:y
"""

_MY_CONFIG = """#
Expand Down Expand Up @@ -356,6 +370,8 @@ def test_apply_default(self, _) -> None:
self.assertEqual("runtime_value", cfg.get("s"))
self.assertEqual(50, cfg.get("i"))
self.assertEqual(1.2, cfg.get("f"))
self.assertEqual({"a": "b", "c": "d"}, cfg.get("d"))
self.assertEqual({"x": "y"}, cfg.get("d_none"))

@patch(
TORCHX_GET_SCHEDULER_FACTORIES,
Expand Down
19 changes: 16 additions & 3 deletions torchx/schedulers/docker_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def ensure_network(client: Optional["DockerClient"] = None) -> None:

class DockerOpts(TypedDict, total=False):
copy_env: Optional[List[str]]
env: Optional[Dict[str, str]]


class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
Expand Down Expand Up @@ -217,6 +218,10 @@ def _submit_dryrun(self, app: AppDef, cfg: DockerOpts) -> AppDryRunInfo[DockerJo
for k in keys:
default_env[k] = os.environ[k]

env = cfg.get("env")
if env:
default_env.update(env)

app_id = make_unique(app.name)
req = DockerJob(app_id=app_id, containers=[])
rank0_name = f"{app_id}-{app.roles[0].name}-0"
Expand Down Expand Up @@ -294,9 +299,9 @@ def _submit_dryrun(self, app: AppDef, cfg: DockerOpts) -> AppDryRunInfo[DockerJo
if resource.memMB >= 0:
# To support PyTorch dataloaders we need to set /dev/shm to
# larger than the 64M default.
c.kwargs["mem_limit"] = c.kwargs["shm_size"] = (
f"{int(resource.memMB)}m"
)
c.kwargs["mem_limit"] = c.kwargs[
"shm_size"
] = f"{int(resource.memMB)}m"
if resource.cpu >= 0:
c.kwargs["nano_cpus"] = int(resource.cpu * 1e9)
if resource.gpu > 0:
Expand Down Expand Up @@ -359,6 +364,14 @@ def _run_opts(self) -> runopts:
default=None,
help="list of glob patterns of environment variables to copy if not set in AppDef. Ex: FOO_*",
)
opts.add(
"env",
type_=Dict[str, str],
default=None,
help="""environment variables to be passed to the run. The separator sign can be eiher comma or semicolon
(e.g. ENV1:v1,ENV2:v2,ENV3:v3 or ENV1:V1;ENV2:V2). Environment variables from env will be applied on top
of the ones from copy_env""",
)
return opts

def _get_app_state(self, container: "Container") -> AppState:
Expand Down
15 changes: 15 additions & 0 deletions torchx/schedulers/test/docker_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,21 @@ def test_copy_env(self) -> None:
},
)

def test_env(self) -> None:
app = _test_app()
cfg = DockerOpts({"env": {"FOO_1": "BAR_1"}})
with patch("torchx.schedulers.docker_scheduler.make_unique") as make_unique_ctx:
make_unique_ctx.return_value = "app_name_42"
info = self.scheduler._submit_dryrun(app, cfg)
self.assertEqual(
info.request.containers[0].kwargs["environment"],
{
"FOO": "bar",
"FOO_1": "BAR_1",
"TORCHX_RANK0_HOST": "app_name_42-trainer-0",
},
)


if has_docker():
# These are the live tests that require a local docker instance.
Expand Down
15 changes: 12 additions & 3 deletions torchx/specs/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,11 +639,11 @@ def __init__(self, status: AppStatus, *args: object) -> None:
self.status = status


# valid run cfg values; only support primitives (str, int, float, bool, List[str])
# valid run cfg values; only support primitives (str, int, float, bool, List[str], Dict[str, str])
# TODO(wilsonhong): python 3.9+ supports list[T] in typing, which can be used directly
# in isinstance(). Should replace with that.
# see: https://docs.python.org/3/library/stdtypes.html#generic-alias-type
CfgVal = Union[str, int, float, bool, List[str], None]
CfgVal = Union[str, int, float, bool, List[str], Dict[str, str], None]


T = TypeVar("T")
Expand Down Expand Up @@ -757,6 +757,10 @@ def is_type(obj: CfgVal, tp: Type[CfgVal]) -> bool:
except TypeError:
if isinstance(obj, list):
return all(isinstance(e, str) for e in obj)
elif isinstance(obj, dict):
return all(
isinstance(k, str) and isinstance(v, str) for k, v in obj.items()
)
else:
return False

Expand Down Expand Up @@ -865,8 +869,13 @@ def _cast_to_type(value: str, opt_type: Type[CfgVal]) -> CfgVal:
# lists may be ; or , delimited
# also deal with trailing "," by removing empty strings
return [v for v in value.replace(";", ",").split(",") if v]
elif opt_type == Dict[str, str]:
return {
s.split(":", 1)[0]: s.split(":", 1)[1]
for s in value.replace(";", ",").split(",")
}
else:
# pyre-ignore[19]
# pyre-ignore[19, 6] type won't be dict here as we handled it above
return opt_type(value)

cfg: Dict[str, CfgVal] = {}
Expand Down
8 changes: 8 additions & 0 deletions torchx/specs/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ def test_cfg_from_str(self) -> None:
opts = runopts()
opts.add("K", type_=List[str], help="a list opt", default=[])
opts.add("J", type_=str, help="a str opt", required=True)
opts.add("E", type_=Dict[str, str], help="a dict opt", default=[])

self.assertDictEqual({}, opts.cfg_from_str(""))
self.assertDictEqual({}, opts.cfg_from_str("UNKWN=b"))
Expand All @@ -455,6 +456,9 @@ def test_cfg_from_str(self) -> None:
self.assertDictEqual(
{"K": ["a"], "J": "d"}, opts.cfg_from_str("J=d,K=a,UNKWN=e")
)
self.assertDictEqual(
{"E": {"f": "b", "F": "B"}}, opts.cfg_from_str("E=f:b,F:B")
)

def test_resolve_from_str(self) -> None:
opts = runopts()
Expand Down Expand Up @@ -489,6 +493,10 @@ def test_runopts_is_type(self) -> None:
self.assertFalse(runopts.is_type(None, List[str]))
self.assertTrue(runopts.is_type([], List[str]))
self.assertTrue(runopts.is_type(["a", "b"], List[str]))
# List[str]
self.assertFalse(runopts.is_type(None, Dict[str, str]))
self.assertTrue(runopts.is_type({}, Dict[str, str]))
self.assertTrue(runopts.is_type({"foo": "bar", "fee": "bez"}, Dict[str, str]))

def test_runopts_iter(self) -> None:
runopts = self.get_runopts()
Expand Down

0 comments on commit e9a6957

Please sign in to comment.