From 8048ef35ae02cf31cc827e3913d9f8121e01cc98 Mon Sep 17 00:00:00 2001 From: Alexander Jipa Date: Wed, 15 Nov 2023 15:04:48 -0500 Subject: [PATCH] fixing k8s node types for trn1 instance types (#792) Co-authored-by: Alexander Jipa --- torchx/components/structured_arg.py | 6 +++--- torchx/components/test/structured_arg_test.py | 6 +++--- torchx/specs/named_resources_aws.py | 14 ++++++++------ torchx/specs/test/named_resources_aws_test.py | 8 ++++---- 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/torchx/components/structured_arg.py b/torchx/components/structured_arg.py index b857624ab..6c1851cc0 100644 --- a/torchx/components/structured_arg.py +++ b/torchx/components/structured_arg.py @@ -191,12 +191,12 @@ def parse_from(h: str, j: str) -> "StructuredJArgument": .. doctest:: - >>> str(StructuredJArgument.parse_from(h="aws_trn1.32xl", j="2")) + >>> str(StructuredJArgument.parse_from(h="aws_trn1.32xlarge", j="2")) Traceback (most recent call last): ... - ValueError: nproc_per_node cannot be inferred from GPU count. `trn1.32xl` is not a GPU instance. ... + ValueError: nproc_per_node cannot be inferred from GPU count. `trn1.32xlarge` is not a GPU instance. ... - >>> str(StructuredJArgument.parse_from(h="aws_trn1.32xl", j="2x16")) + >>> str(StructuredJArgument.parse_from(h="aws_trn1.32xlarge", j="2x16")) '2x16' """ diff --git a/torchx/components/test/structured_arg_test.py b/torchx/components/test/structured_arg_test.py index 40e535a34..c62374743 100644 --- a/torchx/components/test/structured_arg_test.py +++ b/torchx/components/test/structured_arg_test.py @@ -81,14 +81,14 @@ def test_create(self) -> None: ) self.assertEqual( StructuredJArgument(nnodes=2, nproc_per_node=8), - StructuredJArgument.parse_from(h="aws_trn1.2xl", j="2x8"), + StructuredJArgument.parse_from(h="aws_trn1.2xlarge", j="2x8"), ) with self.assertRaisesRegex( ValueError, - "nproc_per_node cannot be inferred from GPU count. `aws_trn1.32xl` is not a GPU instance.", + "nproc_per_node cannot be inferred from GPU count. `aws_trn1.32xlarge` is not a GPU instance.", ): - StructuredJArgument.parse_from(h="aws_trn1.32xl", j="2") + StructuredJArgument.parse_from(h="aws_trn1.32xlarge", j="2") with self.assertRaisesRegex(ValueError, "Invalid format for `-j"): StructuredJArgument.parse_from(h="aws_p4d.24xlarge", j="2x2x2") diff --git a/torchx/specs/named_resources_aws.py b/torchx/specs/named_resources_aws.py index cd9aec703..376f7145d 100644 --- a/torchx/specs/named_resources_aws.py +++ b/torchx/specs/named_resources_aws.py @@ -205,13 +205,15 @@ def aws_g5_48xlarge() -> Resource: ) -def aws_trn1_2xl() -> Resource: - return Resource(cpu=8, gpu=0, memMB=32 * GiB, capabilities={K8S_ITYPE: "trn1.2xl"}) +def aws_trn1_2xlarge() -> Resource: + return Resource( + cpu=8, gpu=0, memMB=32 * GiB, capabilities={K8S_ITYPE: "trn1.2xlarge"} + ) -def aws_trn1_32xl() -> Resource: +def aws_trn1_32xlarge() -> Resource: return Resource( - cpu=128, gpu=0, memMB=512 * GiB, capabilities={K8S_ITYPE: "trn1.32xl"} + cpu=128, gpu=0, memMB=512 * GiB, capabilities={K8S_ITYPE: "trn1.32xlarge"} ) @@ -239,6 +241,6 @@ def aws_trn1_32xl() -> Resource: "aws_g5.12xlarge": aws_g5_12xlarge, "aws_g5.24xlarge": aws_g5_24xlarge, "aws_g5.48xlarge": aws_g5_48xlarge, - "aws_trn1.2xl": aws_trn1_2xl, - "aws_trn1.32xl": aws_trn1_32xl, + "aws_trn1.2xlarge": aws_trn1_2xlarge, + "aws_trn1.32xlarge": aws_trn1_32xlarge, } diff --git a/torchx/specs/test/named_resources_aws_test.py b/torchx/specs/test/named_resources_aws_test.py index d7d3b9755..1829f9e32 100644 --- a/torchx/specs/test/named_resources_aws_test.py +++ b/torchx/specs/test/named_resources_aws_test.py @@ -31,8 +31,8 @@ aws_p4d_24xlarge, aws_p4de_24xlarge, aws_t3_medium, - aws_trn1_2xl, - aws_trn1_32xl, + aws_trn1_2xlarge, + aws_trn1_32xlarge, EFA_DEVICE, GiB, K8S_ITYPE, @@ -156,13 +156,13 @@ def test_aws_g5(self) -> None: self.assertEqual(g5_48.memMB, g5_12.memMB * 4) def test_aws_trn1(self) -> None: - trn1_2 = aws_trn1_2xl() + trn1_2 = aws_trn1_2xlarge() self.assertEqual(8, trn1_2.cpu) self.assertEqual(0, trn1_2.gpu) self.assertEqual(32 * GiB, trn1_2.memMB) - trn1_32 = aws_trn1_32xl() + trn1_32 = aws_trn1_32xlarge() self.assertEqual(trn1_32.cpu, trn1_2.cpu * 16) self.assertEqual(trn1_32.gpu, trn1_2.gpu) self.assertEqual(trn1_32.memMB, trn1_2.memMB * 16)