diff --git a/torchx/components/structured_arg.py b/torchx/components/structured_arg.py index b857624ab..6c1851cc0 100644 --- a/torchx/components/structured_arg.py +++ b/torchx/components/structured_arg.py @@ -191,12 +191,12 @@ def parse_from(h: str, j: str) -> "StructuredJArgument": .. doctest:: - >>> str(StructuredJArgument.parse_from(h="aws_trn1.32xl", j="2")) + >>> str(StructuredJArgument.parse_from(h="aws_trn1.32xlarge", j="2")) Traceback (most recent call last): ... - ValueError: nproc_per_node cannot be inferred from GPU count. `trn1.32xl` is not a GPU instance. ... + ValueError: nproc_per_node cannot be inferred from GPU count. `trn1.32xlarge` is not a GPU instance. ... - >>> str(StructuredJArgument.parse_from(h="aws_trn1.32xl", j="2x16")) + >>> str(StructuredJArgument.parse_from(h="aws_trn1.32xlarge", j="2x16")) '2x16' """ diff --git a/torchx/components/test/structured_arg_test.py b/torchx/components/test/structured_arg_test.py index 40e535a34..c62374743 100644 --- a/torchx/components/test/structured_arg_test.py +++ b/torchx/components/test/structured_arg_test.py @@ -81,14 +81,14 @@ def test_create(self) -> None: ) self.assertEqual( StructuredJArgument(nnodes=2, nproc_per_node=8), - StructuredJArgument.parse_from(h="aws_trn1.2xl", j="2x8"), + StructuredJArgument.parse_from(h="aws_trn1.2xlarge", j="2x8"), ) with self.assertRaisesRegex( ValueError, - "nproc_per_node cannot be inferred from GPU count. `aws_trn1.32xl` is not a GPU instance.", + "nproc_per_node cannot be inferred from GPU count. `aws_trn1.32xlarge` is not a GPU instance.", ): - StructuredJArgument.parse_from(h="aws_trn1.32xl", j="2") + StructuredJArgument.parse_from(h="aws_trn1.32xlarge", j="2") with self.assertRaisesRegex(ValueError, "Invalid format for `-j"): StructuredJArgument.parse_from(h="aws_p4d.24xlarge", j="2x2x2") diff --git a/torchx/specs/named_resources_aws.py b/torchx/specs/named_resources_aws.py index c3e45b6dc..821b2ea5c 100644 --- a/torchx/specs/named_resources_aws.py +++ b/torchx/specs/named_resources_aws.py @@ -192,13 +192,15 @@ def aws_g5_48xlarge() -> Resource: ) -def aws_trn1_2xl() -> Resource: - return Resource(cpu=8, gpu=0, memMB=32 * GiB, capabilities={K8S_ITYPE: "trn1.2xl"}) +def aws_trn1_2xlarge() -> Resource: + return Resource( + cpu=8, gpu=0, memMB=32 * GiB, capabilities={K8S_ITYPE: "trn1.2xlarge"} + ) -def aws_trn1_32xl() -> Resource: +def aws_trn1_32xlarge() -> Resource: return Resource( - cpu=128, gpu=0, memMB=512 * GiB, capabilities={K8S_ITYPE: "trn1.32xl"} + cpu=128, gpu=0, memMB=512 * GiB, capabilities={K8S_ITYPE: "trn1.32xlarge"} ) @@ -226,6 +228,6 @@ def aws_trn1_32xl() -> Resource: "aws_g5.12xlarge": aws_g5_12xlarge, "aws_g5.24xlarge": aws_g5_24xlarge, "aws_g5.48xlarge": aws_g5_48xlarge, - "aws_trn1.2xl": aws_trn1_2xl, - "aws_trn1.32xl": aws_trn1_32xl, + "aws_trn1.2xlarge": aws_trn1_2xlarge, + "aws_trn1.32xlarge": aws_trn1_32xlarge, } diff --git a/torchx/specs/test/named_resources_aws_test.py b/torchx/specs/test/named_resources_aws_test.py index d7d3b9755..1829f9e32 100644 --- a/torchx/specs/test/named_resources_aws_test.py +++ b/torchx/specs/test/named_resources_aws_test.py @@ -31,8 +31,8 @@ aws_p4d_24xlarge, aws_p4de_24xlarge, aws_t3_medium, - aws_trn1_2xl, - aws_trn1_32xl, + aws_trn1_2xlarge, + aws_trn1_32xlarge, EFA_DEVICE, GiB, K8S_ITYPE, @@ -156,13 +156,13 @@ def test_aws_g5(self) -> None: self.assertEqual(g5_48.memMB, g5_12.memMB * 4) def test_aws_trn1(self) -> None: - trn1_2 = aws_trn1_2xl() + trn1_2 = aws_trn1_2xlarge() self.assertEqual(8, trn1_2.cpu) self.assertEqual(0, trn1_2.gpu) self.assertEqual(32 * GiB, trn1_2.memMB) - trn1_32 = aws_trn1_32xl() + trn1_32 = aws_trn1_32xlarge() self.assertEqual(trn1_32.cpu, trn1_2.cpu * 16) self.assertEqual(trn1_32.gpu, trn1_2.gpu) self.assertEqual(trn1_32.memMB, trn1_2.memMB * 16)