From 6f609dd9a2bd5da3c5d969b6d9f32dcf3a0c33b3 Mon Sep 17 00:00:00 2001 From: Kiuk Chung Date: Fri, 15 Sep 2023 21:41:31 +0000 Subject: [PATCH] [torchx/specs] Use default_factory for the default value of Role.resource and mlflow_test.Config.model_config to support python 3.11 clients --- .github/workflows/python-unittests.yaml | 2 +- torchx/specs/api.py | 7 ++++++- torchx/tracker/test/mlflow_test.py | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python-unittests.yaml b/.github/workflows/python-unittests.yaml index 5a8daa193..867404414 100644 --- a/.github/workflows/python-unittests.yaml +++ b/.github/workflows/python-unittests.yaml @@ -10,7 +10,7 @@ jobs: unittest: strategy: matrix: - python-version: [3.8, 3.9, '3.10'] + python-version: [3.8, 3.9, 3.10, 3.11] platform: ["linux.20_04.4x"] include: - python-version: 3.9 diff --git a/torchx/specs/api.py b/torchx/specs/api.py index 779e1b8de..5def61798 100644 --- a/torchx/specs/api.py +++ b/torchx/specs/api.py @@ -104,6 +104,11 @@ def copy(original: "Resource", **capabilities: Any) -> "Resource": # sentinel value used for cases when resource does not matter (e.g. ignored) NULL_RESOURCE: Resource = Resource(cpu=-1, gpu=-1, memMB=-1) +# no-arg static factory method to use with default_factory in @dataclass +# needed to support python 3.11 since mutable defaults for dataclasses are not allowed in 3.11 +def _null_resource() -> Resource: + return NULL_RESOURCE + # used as "*" scheduler backend ALL: str = "all" @@ -333,7 +338,7 @@ class Role: num_replicas: int = 1 max_retries: int = 0 retry_policy: RetryPolicy = RetryPolicy.APPLICATION - resource: Resource = NULL_RESOURCE + resource: Resource = field(default_factory=_null_resource) port_map: Dict[str, int] = field(default_factory=dict) metadata: Dict[str, Any] = field(default_factory=dict) mounts: List[Union[BindMount, VolumeMount, DeviceMount]] = field( diff --git a/torchx/tracker/test/mlflow_test.py b/torchx/tracker/test/mlflow_test.py index d367218f4..791e388f0 100644 --- a/torchx/tracker/test/mlflow_test.py +++ b/torchx/tracker/test/mlflow_test.py @@ -44,7 +44,7 @@ class Config: locales: List[str] = field(default_factory=lambda: ["us", "eu", "fr"]) empty_list: List[str] = field(default_factory=list) empty_map: Dict[str, str] = field(default_factory=dict) - model_config: ModelConfig = ModelConfig() + model_config: ModelConfig = field(default_factory=ModelConfig) datasets: List[DatasetConfig] = field( default_factory=lambda: [ DatasetConfig(url="s3://dataset1"),