diff --git a/torchx/runner/api.py b/torchx/runner/api.py index 72e85ef5d..67771c9c6 100644 --- a/torchx/runner/api.py +++ b/torchx/runner/api.py @@ -357,7 +357,7 @@ def dryrun( cfg = cfg or dict() with log_event("dryrun", scheduler, runcfg=json.dumps(cfg) if cfg else None): sched = self._scheduler(scheduler) - + resolved_cfg = sched.run_opts().resolve(cfg) if workspace and isinstance(sched, WorkspaceMixin): role = app.roles[0] old_img = role.image @@ -366,7 +366,7 @@ def dryrun( logger.info( 'To disable workspaces pass: --workspace="" from CLI or workspace=None programmatically.' ) - sched.build_workspace_and_update_role(role, workspace, cfg) + sched.build_workspace_and_update_role(role, workspace, resolved_cfg) if old_img != role.image: logger.info( @@ -380,7 +380,7 @@ def dryrun( ) sched._validate(app, scheduler) - dryrun_info = sched.submit_dryrun(app, cfg) + dryrun_info = sched.submit_dryrun(app, resolved_cfg) dryrun_info._scheduler = scheduler return dryrun_info diff --git a/torchx/runner/test/api_test.py b/torchx/runner/test/api_test.py index abaf4e2ef..09e3812f1 100644 --- a/torchx/runner/test/api_test.py +++ b/torchx/runner/test/api_test.py @@ -114,6 +114,10 @@ def test_run(self, _) -> None: def test_dryrun(self, _) -> None: scheduler_mock = MagicMock() + scheduler_mock.run_opts.return_value.resolve.return_value = { + **self.cfg, + "foo": "bar", + } with Runner( name=SESSION_NAME, scheduler_factories={"local_dir": lambda name: scheduler_mock}, @@ -127,7 +131,9 @@ def test_dryrun(self, _) -> None: ) app = AppDef("name", roles=[role]) runner.dryrun(app, "local_dir", cfg=self.cfg) - scheduler_mock.submit_dryrun.assert_called_once_with(app, self.cfg) + scheduler_mock.submit_dryrun.assert_called_once_with( + app, {**self.cfg, "foo": "bar"} + ) scheduler_mock._validate.assert_called_once() def test_dryrun_env_variables(self, _) -> None: diff --git a/torchx/schedulers/api.py b/torchx/schedulers/api.py index 52f2a6f16..f2ccc0718 100644 --- a/torchx/schedulers/api.py +++ b/torchx/schedulers/api.py @@ -136,12 +136,15 @@ def submit( Returns: The application id that uniquely identifies the submitted app. """ + # pyre-fixme: Generic cfg type passed to resolve + resolved_cfg = self.run_opts().resolve(cfg) if workspace: sched = self assert isinstance(sched, WorkspaceMixin) role = app.roles[0] - sched.build_workspace_and_update_role(role, workspace, cfg) - dryrun_info = self.submit_dryrun(app, cfg) + sched.build_workspace_and_update_role(role, workspace, resolved_cfg) + # pyre-fixme: submit_dryrun takes Generic type for resolved_cfg + dryrun_info = self.submit_dryrun(app, resolved_cfg) return self.schedule(dryrun_info) @abc.abstractmethod diff --git a/torchx/schedulers/test/api_test.py b/torchx/schedulers/test/api_test.py index 9d36e5946..3b9cb897b 100644 --- a/torchx/schedulers/test/api_test.py +++ b/torchx/schedulers/test/api_test.py @@ -108,8 +108,8 @@ def test_submit_workspace(self) -> None: scheduler_mock = SchedulerTest.MockScheduler("test_session") - bad_type_cfg = {"foo": "asdf"} - scheduler_mock.submit(app, bad_type_cfg, workspace="some_workspace") + cfg = {"foo": "asdf"} + scheduler_mock.submit(app, cfg, workspace="some_workspace") self.assertEqual(app.roles[0].image, "some_workspace") def test_invalid_dryrun_cfg(self) -> None: