Skip to content

Commit

Permalink
fixing k8s node types for trn1 instance types (#792)
Browse files Browse the repository at this point in the history
Co-authored-by: Alexander Jipa <[email protected]>
  • Loading branch information
Alexander Jipa and azzhipa authored Nov 15, 2023
1 parent b1e56b2 commit 8048ef3
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 16 deletions.
6 changes: 3 additions & 3 deletions torchx/components/structured_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,12 @@ def parse_from(h: str, j: str) -> "StructuredJArgument":
.. doctest::
>>> str(StructuredJArgument.parse_from(h="aws_trn1.32xl", j="2"))
>>> str(StructuredJArgument.parse_from(h="aws_trn1.32xlarge", j="2"))
Traceback (most recent call last):
...
ValueError: nproc_per_node cannot be inferred from GPU count. `trn1.32xl` is not a GPU instance. ...
ValueError: nproc_per_node cannot be inferred from GPU count. `trn1.32xlarge` is not a GPU instance. ...
>>> str(StructuredJArgument.parse_from(h="aws_trn1.32xl", j="2x16"))
>>> str(StructuredJArgument.parse_from(h="aws_trn1.32xlarge", j="2x16"))
'2x16'
"""
Expand Down
6 changes: 3 additions & 3 deletions torchx/components/test/structured_arg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ def test_create(self) -> None:
)
self.assertEqual(
StructuredJArgument(nnodes=2, nproc_per_node=8),
StructuredJArgument.parse_from(h="aws_trn1.2xl", j="2x8"),
StructuredJArgument.parse_from(h="aws_trn1.2xlarge", j="2x8"),
)

with self.assertRaisesRegex(
ValueError,
"nproc_per_node cannot be inferred from GPU count. `aws_trn1.32xl` is not a GPU instance.",
"nproc_per_node cannot be inferred from GPU count. `aws_trn1.32xlarge` is not a GPU instance.",
):
StructuredJArgument.parse_from(h="aws_trn1.32xl", j="2")
StructuredJArgument.parse_from(h="aws_trn1.32xlarge", j="2")

with self.assertRaisesRegex(ValueError, "Invalid format for `-j"):
StructuredJArgument.parse_from(h="aws_p4d.24xlarge", j="2x2x2")
Expand Down
14 changes: 8 additions & 6 deletions torchx/specs/named_resources_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,15 @@ def aws_g5_48xlarge() -> Resource:
)


def aws_trn1_2xl() -> Resource:
return Resource(cpu=8, gpu=0, memMB=32 * GiB, capabilities={K8S_ITYPE: "trn1.2xl"})
def aws_trn1_2xlarge() -> Resource:
return Resource(
cpu=8, gpu=0, memMB=32 * GiB, capabilities={K8S_ITYPE: "trn1.2xlarge"}
)


def aws_trn1_32xl() -> Resource:
def aws_trn1_32xlarge() -> Resource:
return Resource(
cpu=128, gpu=0, memMB=512 * GiB, capabilities={K8S_ITYPE: "trn1.32xl"}
cpu=128, gpu=0, memMB=512 * GiB, capabilities={K8S_ITYPE: "trn1.32xlarge"}
)


Expand Down Expand Up @@ -239,6 +241,6 @@ def aws_trn1_32xl() -> Resource:
"aws_g5.12xlarge": aws_g5_12xlarge,
"aws_g5.24xlarge": aws_g5_24xlarge,
"aws_g5.48xlarge": aws_g5_48xlarge,
"aws_trn1.2xl": aws_trn1_2xl,
"aws_trn1.32xl": aws_trn1_32xl,
"aws_trn1.2xlarge": aws_trn1_2xlarge,
"aws_trn1.32xlarge": aws_trn1_32xlarge,
}
8 changes: 4 additions & 4 deletions torchx/specs/test/named_resources_aws_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
aws_p4d_24xlarge,
aws_p4de_24xlarge,
aws_t3_medium,
aws_trn1_2xl,
aws_trn1_32xl,
aws_trn1_2xlarge,
aws_trn1_32xlarge,
EFA_DEVICE,
GiB,
K8S_ITYPE,
Expand Down Expand Up @@ -156,13 +156,13 @@ def test_aws_g5(self) -> None:
self.assertEqual(g5_48.memMB, g5_12.memMB * 4)

def test_aws_trn1(self) -> None:
trn1_2 = aws_trn1_2xl()
trn1_2 = aws_trn1_2xlarge()

self.assertEqual(8, trn1_2.cpu)
self.assertEqual(0, trn1_2.gpu)
self.assertEqual(32 * GiB, trn1_2.memMB)

trn1_32 = aws_trn1_32xl()
trn1_32 = aws_trn1_32xlarge()
self.assertEqual(trn1_32.cpu, trn1_2.cpu * 16)
self.assertEqual(trn1_32.gpu, trn1_2.gpu)
self.assertEqual(trn1_32.memMB, trn1_2.memMB * 16)
Expand Down

0 comments on commit 8048ef3

Please sign in to comment.