From 42ea5b419bb686af1a2ffceb5eae33f53538d4cc Mon Sep 17 00:00:00 2001 From: xla authors Date: Mon, 11 Nov 2024 11:40:04 -0800 Subject: [PATCH] [XLA] Modify comments in ragged all-to-all HLO. PiperOrigin-RevId: 695425829 --- xla/hlo/ir/hlo_instruction.h | 4 ++-- xla/hlo/parser/hlo_parser_test.cc | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/xla/hlo/ir/hlo_instruction.h b/xla/hlo/ir/hlo_instruction.h index dca1b25bba4a5..f51b28feab717 100644 --- a/xla/hlo/ir/hlo_instruction.h +++ b/xla/hlo/ir/hlo_instruction.h @@ -1067,10 +1067,10 @@ class HloInstruction { // The ragged all-to-all HLO has the following arguments: // input: ragged input data tensor. // input_offsets: ragged input offsets tensor. - // input_sizes: ragged input sizes tensor. + // send_sizes: ragged send sizes tensor. // output: ragged output data tensor. // output_offsets: ragged output offsets tensor. - // output_sizes: ragged output sizes tensor. + // recv_sizes: ragged recv sizes tensor. // // The '*_offsets' and '*_sizes' tensors must have the same shape. // The output buffer is passed in as an input (and aliased in the output), diff --git a/xla/hlo/parser/hlo_parser_test.cc b/xla/hlo/parser/hlo_parser_test.cc index 34311e3c95e76..ade7da2c25191 100644 --- a/xla/hlo/parser/hlo_parser_test.cc +++ b/xla/hlo/parser/hlo_parser_test.cc @@ -2193,10 +2193,10 @@ ENTRY AllToAll { input = bf16[1024,256]{1,0} parameter(0) output = bf16[1024,256]{1,0} parameter(1) input_offsets = s32[8]{0} parameter(2) - input_sizes = s32[8]{0} parameter(3) + send_sizes = s32[8]{0} parameter(3) output_offsets = s32[8]{0} parameter(4) - output_sizes = s32[8]{0} parameter(5) - ROOT ra2a = bf16[1024,256]{1,0} ragged-all-to-all(input, output, input_offsets, input_sizes, output_offsets, output_sizes), replica_groups={{0,1,2,3,4,5,6,7}} + recv_sizes = s32[8]{0} parameter(5) + ROOT ra2a = bf16[1024,256]{1,0} ragged-all-to-all(input, output, input_offsets, send_sizes, output_offsets, recv_sizes), replica_groups={{0,1,2,3,4,5,6,7}} } )", @@ -2211,10 +2211,10 @@ ENTRY AllToAll { input = bf16[1024,256]{1,0} parameter(0) output = bf16[1024,256]{1,0} parameter(1) input_offsets = s32[8]{0} parameter(2) - input_sizes = s32[8]{0} parameter(3) + send_sizes = s32[8]{0} parameter(3) output_offsets = s32[8]{0} parameter(4) - output_sizes = s32[8]{0} parameter(5) - ROOT ra2a = bf16[1024,256]{1,0} ragged-all-to-all(input, output, input_offsets, input_sizes, output_offsets, output_sizes), replica_groups=[2,4]<=[4,2]T(1,0) + recv_sizes = s32[8]{0} parameter(5) + ROOT ra2a = bf16[1024,256]{1,0} ragged-all-to-all(input, output, input_offsets, send_sizes, output_offsets, recv_sizes), replica_groups=[2,4]<=[4,2]T(1,0) } )", @@ -2229,10 +2229,10 @@ ENTRY AllToAll { input = bf16[1024,256]{1,0} parameter(0) output = bf16[1024,256]{1,0} parameter(1) input_offsets = s32[8]{0} parameter(2) - input_sizes = s32[8]{0} parameter(3) + send_sizes = s32[8]{0} parameter(3) output_offsets = s32[8]{0} parameter(4) - output_sizes = s32[8]{0} parameter(5) - ROOT ra2a = bf16[1024,256]{1,0} ragged-all-to-all(input, output, input_offsets, input_sizes, output_offsets, output_sizes), replica_groups={} + recv_sizes = s32[8]{0} parameter(5) + ROOT ra2a = bf16[1024,256]{1,0} ragged-all-to-all(input, output, input_offsets, send_sizes, output_offsets, recv_sizes), replica_groups={} } )"