
Lower distributed matmul to pipelined algorithm for fine-grained overlap: AG+GEMM layout #3695

Closed
wants to merge 47 commits

Conversation

xwang233 (Collaborator)

Stacked on top of

What

Lower a MatmulOp sharded on the first inner axis into a pipelined AG+GEMM algorithm, achieving fine-grained overlap of communication and computation.

We introduce a new parallel type Stream to account for this scheduling.

More precisely, this patch enables lowering the fusion:

  TensorView* a = makeContigTensor(4); //[S, DIDx(D), M/(S*D), K]
  TensorView* b = makeContigTensor(2); //[K, N]
  TensorView* c = matmul(a, b); //[S, D, M/(S*D), N]

  fusion->addInput(a);
  fusion->addInput(b);
  fusion->addOutput(c);

  auto mesh = DeviceMesh::createForNumDevices(D);
  a->setDeviceMesh(mesh);
  b->setDeviceMesh(mesh);
  c->setDeviceMesh(mesh);

  a->axis(1)->parallelize(ParallelType::DIDx);
  c->axis(0)->parallelize(ParallelType::Stream);

to the following Host IR program (obtained with NVFUSER_DUMP=host_ir):

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T2_g_float[iStream6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  GetCurrentStream into Stream 0
  T3_g_float[iS11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iS11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
  T2_g_float[iStream6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[iStream6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i6 ), zero_init=false, resets_to_zero=false)
  FOR i104 in iS0{i0}:
    SetCurrentStream to Stream ( i104 % numberOfStreams )
    T4_l_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = select( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = i104 )
    T5_l_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = select( T3_g_float[iS11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS11{i0}, index = i104 )
    Communication 46 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T4_l_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T5_l_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    Wait Communication 46
    T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = select( T2_g_float[iStream6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStream6{i0}, index = i104 )
    T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = matmul(T5_l_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    SetCurrentStream to Stream 0
    Synchronize Stream ( i104 % numberOfStreams )
} // %HostIrContainer
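
For intuition, the pattern this Host IR encodes is the usual stream-pipelined allgather+GEMM loop. Below is a minimal, illustrative CUDA/C++ sketch of that pattern written directly against NCCL and cuBLAS. It is not nvFuser code: error checking is omitted, the final synchronization is simplified (the Host IR instead synchronizes the worker stream with the original stream every iteration), and the buffer layout assumptions are spelled out in the comments.

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <nccl.h>
#include <vector>

// Illustrative only: pipelined allgather + GEMM over S tiles on a pool of streams.
// A is logically [S, D, M/(S*D), K]; each rank holds its [S, 1, M/(S*D), K] shard
// contiguously. B is [K, N], replicated. C is [S, M/S, N], row-major.
void pipelinedAgGemm(
    const float* a_shard,
    float* a_gathered, // staging buffer, [S, M/S, K]
    const float* b,
    float* c,
    int64_t S, int64_t D, int64_t M, int64_t K, int64_t N,
    ncclComm_t comm, cublasHandle_t cublas, int num_streams) {
  std::vector<cudaStream_t> streams(num_streams);
  for (auto& s : streams) {
    cudaStreamCreate(&s);
  }
  const int64_t tile_rows = M / S; // rows of C produced per iteration
  const int64_t send_count = (M / (S * D)) * K; // elements each rank contributes per tile
  const float alpha = 1.0f, beta = 0.0f;

  for (int64_t j = 0; j < S; ++j) {
    cudaStream_t stream = streams[j % num_streams]; // SetCurrentStream(j % numberOfStreams)
    const float* a_j = a_shard + j * send_count;
    float* a_gathered_j = a_gathered + j * tile_rows * K;
    float* c_j = c + j * tile_rows * N;

    // Allgather tile j of A across the D devices, asynchronously on `stream`.
    ncclAllGather(a_j, a_gathered_j, send_count, ncclFloat, comm, stream);

    // GEMM for tile j enqueued on the same stream, so it is ordered after the
    // allgather. cuBLAS is column-major, so the row-major product C_j = A_j * B
    // is expressed as C_j^T = B^T * A_j^T.
    cublasSetStream(cublas, stream);
    cublasSgemm(
        cublas, CUBLAS_OP_N, CUBLAS_OP_N,
        /*m=*/static_cast<int>(N), /*n=*/static_cast<int>(tile_rows),
        /*k=*/static_cast<int>(K),
        &alpha, b, /*lda=*/static_cast<int>(N),
        a_gathered_j, /*ldb=*/static_cast<int>(K),
        &beta, c_j, /*ldc=*/static_cast<int>(N));
  }
  for (auto& s : streams) {
    cudaStreamSynchronize(s);
    cudaStreamDestroy(s);
  }
}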

The Nsight profile shows that we do achieve overlap, in a way that is comparable to the ATen overlap experiments:

[Screenshot, 2024-12-18: Nsight profile showing the overlap]

xwang233 (Collaborator, Author)

For CI testing only, please ignore.

xwang233 (Collaborator, Author)

!test

@xwang233 xwang233 changed the title [WIP] Lower distributed matmul to pipelined algorithm for fine-grained overlap: AG+GEMM layout [CI testing, please ignore] Lower distributed matmul to pipelined algorithm for fine-grained overlap: AG+GEMM layout Jan 10, 2025
@xwang233 xwang233 closed this Jan 10, 2025
@xwang233 xwang233 changed the title [CI testing, please ignore] Lower distributed matmul to pipelined algorithm for fine-grained overlap: AG+GEMM layout Lower distributed matmul to pipelined algorithm for fine-grained overlap: AG+GEMM layout Jan 10, 2025

xwang233 commented Jan 10, 2025

PR Reviewer Guide 🔍

(o1-mini)

(Review updated until commit 061955f)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Function Signature

The canLower function signature was changed to include an additional parameter. Ensure all existing calls to this function are updated accordingly to prevent potential mismatches or unexpected behaviors.

bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) {
  if (!isResharding(expr)) {
    return true;
  }
  if (!ir_utils::isTvOp(expr)) {
    return false;
  }
  if (auto* reduction = dynamic_cast<ReductionOp*>(expr)) {
    if (isInnerResharding(expr) && !ignore_inner_resharding) {
      return false;
    }
    auto in = reduction->in()->as<TensorView>();
    auto out = reduction->out()->as<TensorView>();
    // get the reduced axis
    std::vector<IterDomain*> reduction_axis;
    std::copy_if(
        out->getLogicalDomain().begin(),
        out->getLogicalDomain().end(),
        std::back_inserter(reduction_axis),
        [](IterDomain* id) { return id->isReduction(); });
    // check whether the reduction involves only one axis
    if (reduction_axis.size() != 1) {
      return false;
    }
    // We check whether the reduced axis is sharded on the input
    const auto c2p_map =
        PairwiseLogicalDomainMap(in, out).mapConsumerToProducer();
    auto c2p_map_it = c2p_map.find(reduction_axis.at(0));
    return c2p_map_it != c2p_map.end() && c2p_map_it->second->isDeviceDim();
  } else if (auto* ldst = dynamic_cast<LoadStoreOp*>(expr)) {
    return (!isInnerResharding(ldst) || ignore_inner_resharding) &&
        ldst->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set;
  } else if (auto* matmul = dynamic_cast<MatmulOp*>(expr)) {
    // For now we only support c = matmul(a,b) when b,c are fully replicated and
    // a is sharded on axis 1
    return !isSharded(matmul->inB()) && !isSharded(matmul->out()) &&
        matmul->inA()->axis(0)->getParallelType() == ParallelType::Serial &&
        getShardedLogicalAxis(matmul->inA(), ParallelType::DIDx) == 1 &&
        matmul->out()->axis(0)->getParallelType() == ParallelType::Stream;
  }
  return false;
}
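
Since the header gives the new parameter a default (canLower(Expr* expr, bool ignore_inner_resharding = false), see the header diff further down in this conversation), existing call sites stay source-compatible. A hypothetical call-site sketch, for illustration only:

// Hypothetical call sites (expr is some Expr*); both compile with the defaulted
// parameter, and only the second opts in to ignoring inner resharding.
void exampleCallSites(Expr* expr) {
  bool lowerable = HostIrLower::canLower(expr);
  bool lowerable_ignoring_inner =
      HostIrLower::canLower(expr, /*ignore_inner_resharding=*/true);
  (void)lowerable;
  (void)lowerable_ignoring_inner;
}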
New ParallelType

Introduction of the new ParallelType::Stream. Verify that this new parallel type is consistently handled across all relevant parts of the codebase to avoid integration issues.

case ParallelType::Stream:
  return "Stream";
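
One concrete thing to audit is every switch over ParallelType: with the new enumerator, any switch without a default case needs an explicit decision for Stream. A hypothetical helper, purely for illustration (not from the diff):

// Illustrative only: this PR handles Stream (like DIDx) at the host level,
// outside generated kernels, so host-side code must now recognize it.
bool isHostLevelParallelType(ParallelType pt) {
  switch (pt) {
    case ParallelType::DIDx:
    case ParallelType::Stream:
      return true;
    default:
      return false;
  }
}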
Complex Logic Addition

The lowerToCollectiveBasedPipelinedGemmComm function introduces intricate logic for pipelined communication. Thoroughly review this section to ensure correctness and identify any potential edge cases that might lead to performance bottlenecks or incorrect computations.

std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(
    Expr* expr) {
  auto matmul = expr->as<MatmulOp>();
  NVF_ERROR(matmul != nullptr, "Expect a MatmulOp, got", expr);
  TensorView* tva = matmul->inA();
  TensorView* tvb = matmul->inB();
  TensorView* tvc = matmul->out();
  NVF_ERROR(
      !isSharded(tvb), "The B operand ", tvb, " is expected to not be sharded");
  NVF_ERROR(
      !isSharded(tvc),
      "The output ",
      matmul->out(),
      " is expected to not be sharded");
  const int64_t sharded_axis_index =
      getShardedLogicalAxis(tva, ParallelType::DIDx);
  IterDomain* stream_axis = tva->axis(0);
  NVF_ERROR(
      stream_axis->getParallelType() == ParallelType::Serial &&
          sharded_axis_index == 1,
      "The operand A ",
      tva,
      " is expected to be sharded on the dimension 1");

  auto hic = FusionGuard::getCurFusion()->as<hir::HostIrContainer>();

  auto* get_current_stream = IrBuilder::create<hir::GetCurrentStream>();
  hir::Stream* original_stream = get_current_stream->stream();

  TensorView* tva_allgathered =
      ops::newValLike(tva, tva->dtype())->as<TensorView>();
  tva_allgathered->axis(sharded_axis_index)->parallelize(ParallelType::Serial);
  tva_allgathered->setMemoryType(MemoryType::Global);
  auto* allocate_tva_allgathered =
      IrBuilder::create<kir::Allocate>(tva_allgathered, MemoryType::Global);

  tvc->setMemoryType(MemoryType::Global);
  auto* allocate_tvc =
      IrBuilder::create<kir::Allocate>(tvc, MemoryType::Global);

  auto* j =
      IrBuilder::create<Val>(DataType::Index); // running index of the for-loop
  auto* start = hic->zeroVal();
  auto* stop = stream_axis->extent();
  auto* step = hic->oneVal();
  auto* for_loop = IrBuilder::create<ForLoop>(
      stream_axis,
      /*index=*/j,
      start,
      stop,
      step,
      /*vectorize=*/false,
      /*vectorize_shift=*/nullptr,
      /*unroll_required=*/false,
      CircularBufferLoopStage::NotApplicable,
      /*circular_buffer_loop_stage_depth=*/0);

  auto* number_of_streams =
      IrBuilder::create<NamedScalar>("numberOfStreams", DataType::Int);
  auto* stream_index = mod(j, number_of_streams);
  auto* stream = IrBuilder::create<hir::Stream>(stream_index);
  auto* set_stream = IrBuilder::create<hir::SetCurrentStream>(stream);

  TensorView* tva_j = select(tva, 0, j);
  TensorView* tva_allgathered_j = select(tva_allgathered, 0, j);
  TensorView* tvc_j = select(tvc, 0, j);

  NVF_ERROR(
      tva->hasDeviceMesh(),
      "The matmul's input ",
      tva,
      "is expected to have a DeviceMesh");
  for (auto tv : {tva_j, tva_allgathered_j, tvc_j}) {
    tv->setDeviceMesh(tva->getDeviceMesh());
  }

  auto* communication = IrBuilder::create<Communication>(
      CommunicationType::Allgather,
      /*out=*/tva_allgathered_j,
      /*in=*/tva_j,
      /*team=*/tva->getDeviceMesh().vector());
  auto* wait = IrBuilder::create<hir::Wait>(communication);

  auto* mm = IrBuilder::create<MatmulOp>(tvc_j, tva_allgathered_j, tvb);

  auto* set_back_original_stream =
      IrBuilder::create<hir::SetCurrentStream>(original_stream);
  auto* sync_stream = IrBuilder::create<hir::Synchronize>(stream);

  std::vector<Expr*> loop_body = {
      set_stream,
      tva_j->definition(),
      tva_allgathered_j->definition(),
      communication,
      wait,
      tvc_j->definition(),
      mm,
      set_back_original_stream,
      sync_stream};
  for (Expr* expr : loop_body) {
    for_loop->body().push_back(expr);
  }

  return {get_current_stream, allocate_tva_allgathered, allocate_tvc, for_loop};
}


xwang233 commented Jan 11, 2025

PR Code Suggestions ✨

(gpt-4o)

Category: Possible issue
Add null pointer checks for tensor views to prevent segmentation faults

Ensure that the lowerToCollectiveBasedPipelinedGemmComm function checks for null
pointers before dereferencing tva, tvb, and tvc to prevent potential segmentation
faults.

csrc/host_ir/lower.cpp [354-359]

 auto matmul = expr->as<MatmulOp>();
 NVF_ERROR(matmul != nullptr, "Expect a MatmulOp, got", expr);
 TensorView* tva = matmul->inA();
 TensorView* tvb = matmul->inB();
 TensorView* tvc = matmul->out();
+NVF_ERROR(tva != nullptr && tvb != nullptr && tvc != nullptr, "TensorViews must not be null");
Suggestion importance[1-10]: 8

Why: Adding null pointer checks for tva, tvb, and tvc is a crucial safety measure to prevent potential segmentation faults, which can cause the program to crash. This suggestion directly addresses a possible runtime error, enhancing the robustness of the code.

Add null check for stream_axis to prevent null pointer dereferencing

Ensure that the lowerToCollectiveBasedPipelinedGemmComm function verifies that
stream_axis is not null before using it to prevent potential null pointer
dereferencing.

csrc/host_ir/lower.cpp [369-375]

 IterDomain* stream_axis = tva->axis(0);
+NVF_ERROR(stream_axis != nullptr, "Stream axis must not be null");
 NVF_ERROR(
     stream_axis->getParallelType() == ParallelType::Serial &&
     sharded_axis_index == 1,
     "The operand A ",
     tva,
     " is expected to be sharded on the dimension 1");
Suggestion importance[1-10]: 8

Why: Adding a null check for stream_axis is a critical enhancement to prevent null pointer dereferencing, which could lead to crashes. This suggestion significantly improves the code's safety and stability.

Add error handling for invalid sharded axis index to prevent out-of-bounds access

Consider adding error handling for the getShardedLogicalAxis function to ensure it
returns a valid axis index, preventing potential out-of-bounds access.

csrc/host_ir/lower.cpp [367-369]

 const int64_t sharded_axis_index = getShardedLogicalAxis(tva, ParallelType::DIDx);
+NVF_ERROR(sharded_axis_index >= 0, "Invalid sharded axis index");
 IterDomain* stream_axis = tva->axis(0);
Suggestion importance[1-10]: 7

Why: Introducing error handling for the getShardedLogicalAxis function ensures that the returned axis index is valid, preventing potential out-of-bounds access. This suggestion improves the code's reliability by guarding against incorrect index values.



xwang233 commented Jan 12, 2025

PR Code Suggestions ✨

(o1-mini)

No code suggestions found for the PR.


xwang233 commented Jan 12, 2025

PR Reviewer Guide 🔍

(Qwen/Qwen2.5-Coder-32B-Instruct)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Logic Change

The canLower function now includes additional checks for LoadStoreOp and MatmulOp. Ensure that these checks are correct and cover all necessary cases.

(The quoted code is identical to the canLower implementation shown in the o1-mini reviewer guide above.)
New Function

The new function lowerToCollectiveBasedPipelinedGemmComm is introduced for lowering MatmulOp. Verify that this function correctly handles all edge cases and integrates well with the existing code.

(The quoted code is identical to the lowerToCollectiveBasedPipelinedGemmComm implementation shown in the o1-mini reviewer guide above.)
New Parallel Type

The new ParallelType::Stream is introduced. Ensure that this type is handled correctly throughout the codebase, especially in places where parallel types are checked or manipulated.

Stream,


xwang233 commented Jan 12, 2025

PR Code Suggestions ✨

(Qwen/Qwen2.5-Coder-32B-Instruct)

Category: Possible issue
Validate the number of streams before binding it to the expression evaluator

Ensure that params_.number_of_streams is properly initialized and validated before
use to avoid runtime errors.

csrc/host_ir/executor.cpp [204]

-expr_evaluator_.bind("numberOfStreams", params_.number_of_streams);
+if (params_.number_of_streams > 0) {
+    expr_evaluator_.bind("numberOfStreams", params_.number_of_streams);
+} else {
+    throw std::invalid_argument("Number of streams must be greater than zero.");
+}
Suggestion importance[1-10]: 7

Why: Ensures that params_.number_of_streams is greater than zero before binding it, preventing potential runtime errors.

Validate expressions in the for_loop body to avoid adding null pointers

Ensure that the for_loop body is correctly constructed and that all necessary
expressions are added.

csrc/host_ir/lower.cpp [452-453]

 for (Expr* expr : loop_body) {
-    for_loop->body().push_back(expr);
+    if (expr != nullptr) {
+        for_loop->body().push_back(expr);
+    } else {
+        throw std::runtime_error("Null expression found in loop body.");
+    }
 }
Suggestion importance[1-10]: 7

Why: Adds a null check before adding expressions to the for_loop body, preventing potential runtime errors.

Ensure cudaProfilerStop is called after the last iteration to correctly capture profiling data

Ensure that the cudaProfilerStart and cudaProfilerStop calls are correctly paired
and that profiling is only started after warmup iterations.

tests/cpp/test_multidevice_host_ir.cpp [401-405]

 if (i == kNumberOfWarmupIterations) {
     cudaProfilerStart();
 }
 tc = executor.runWithInput(inputs).at(0);
+if (i == kNumberOfIterations - 1) {
+    cudaProfilerStop();
+}
Suggestion importance[1-10]: 7

Why: Ensures that cudaProfilerStop is called after the last iteration, ensuring that profiling data is correctly captured.


@@ -235,6 +236,10 @@ void lowerToReduceScatter(
std::vector<Expr*> HostIrLower::lower(Expr* c) {

Suggestion: Proposed documentation

Suggested change
std::vector<Expr*> HostIrLower::lower(Expr* c) {
/**
* Lower a given expression to a series of communication operations.
* This function checks the type of the expression and calls the appropriate
* lowering function.
*
* @param c The expression to be lowered.
* @return A vector of expressions representing the lowered communication operations.
*/
std::vector<Expr*> HostIrLower::lower(Expr* c) {

@@ -302,16 +307,19 @@
return comms;
}

bool HostIrLower::canLower(Expr* expr) {
bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) {

Suggestion: Proposed documentation

Suggested change
bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) {
/**
* Determines if a given expression can be lowered.
* This function checks if the expression is a resharding operation and if it involves
* tensor operations. It also checks specific conditions for reduction and load/store operations.
*
* @param expr The expression to be checked.
* @param ignore_inner_resharding Whether to ignore inner resharding checks.
* @return True if the expression can be lowered, false otherwise.
*/
bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) {

return false;
}

std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(

Suggestion: Proposed documentation

Suggested change
std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(
/**
* Lower a MatmulOp to a collective-based pipelined GEMM communication.
* This function handles the specific lowering of a matrix multiplication operation
* into a series of communication and computation operations suitable for pipelining.
*
* @param expr The MatmulOp expression to be lowered.
* @return A vector of expressions representing the lowered communication and computation operations.
*/
std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(

@@ -16,14 +16,17 @@ namespace nvfuser {

class HostIrLower {
public:
static bool canLower(Expr* expr);
static bool canLower(Expr* expr, bool ignore_inner_resharding = false);

Suggestion: Proposed documentation

Suggested change
static bool canLower(Expr* expr, bool ignore_inner_resharding = false);
/**
* Determines if a given expression can be lowered.
* This function checks if the expression is a resharding operation and if it involves
* tensor operations. It also checks specific conditions for reduction and load/store operations.
*
* @param expr The expression to be checked.
* @param ignore_inner_resharding Whether to ignore inner resharding checks.
* @return True if the expression can be lowered, false otherwise.
*/
static bool canLower(Expr* expr, bool ignore_inner_resharding = false);


// Lower a sharded Expr into a series of Communication.
static std::vector<Expr*> lower(Expr* c);

static std::unique_ptr<hir::HostIrContainer> lower(
std::unique_ptr<Fusion> fusion,
int64_t my_device_index);

private:
static std::vector<Expr*> lowerToCollectiveBasedPipelinedGemmComm(Expr* expr);

Suggestion: Proposed documentation

Suggested change
static std::vector<Expr*> lowerToCollectiveBasedPipelinedGemmComm(Expr* expr);
/**
* Lower a MatmulOp to a collective-based pipelined GEMM communication.
* This function handles the specific lowering of a matrix multiplication operation
* into a series of communication and computation operations suitable for pipelining.
*
* @param expr The MatmulOp expression to be lowered.
* @return A vector of expressions representing the lowered communication and computation operations.
*/
static std::vector<Expr*> lowerToCollectiveBasedPipelinedGemmComm(Expr* expr);

@@ -74,6 +74,9 @@ struct HostIrEvaluatorParams {
// Experimental: whether to cache fusion executor. WAR: avoid recompilation
// but implicitly assumes that the input shapes don't change over iterations
bool cache_fusion_executor = false;
// number of additional cuda streams to use at runtime for comm+compute
// pipelining
int64_t number_of_streams = 4;

Suggestion: Proposed documentation

Suggested change
int64_t number_of_streams = 4;
/**
* Number of additional CUDA streams to use at runtime for communication and computation pipelining.
*/
int64_t number_of_streams = 4;

@@ -208,6 +208,8 @@ class Wait : public Expr {
}
};

// Makes the current stream wait on the given stream. Non-blocking from the host
// point of view.
class Synchronize : public Expr {

Suggestion: Proposed documentation

Suggested change
class Synchronize : public Expr {
/**
* Makes the current stream wait on the given stream. This operation is non-blocking from the host point of view.
*/
class Synchronize : public Expr {

@@ -672,6 +672,7 @@ enum class ParallelType {
TIDz,
TIDy,
TIDx,
Stream,

Suggestion: Proposed documentation

Suggested change
Stream,
/**
* Stream parallel type.
*/
Stream,


xwang233 commented Jan 12, 2025

PR Reviewer Guide 🔍

(meta/llama-3.1-405b-instruct)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Logic Change

The function HostIrLower::canLower has been modified to include an additional parameter ignore_inner_resharding which is used to determine whether to ignore inner resharding or not. This change may affect the logic of the function and should be reviewed carefully.

(The quoted code is identical to the canLower implementation shown in the o1-mini reviewer guide above.)
New Function

A new function HostIrLower::lowerToCollectiveBasedPipelinedGemmComm has been added. This function appears to be related to lowering a MatmulOp to a collective-based pipelined GEMM communication. The implementation of this function should be reviewed carefully to ensure it is correct and efficient.

(The quoted code is identical to the lowerToCollectiveBasedPipelinedGemmComm implementation shown in the o1-mini reviewer guide above.)
Parameter Addition

A new parameter number_of_streams has been added to the HostIrEvaluatorParams struct. This parameter is used to determine the number of additional CUDA streams to use at runtime for comm+compute pipelining. The usage of this parameter should be reviewed to ensure it is correct and does not introduce any issues.

expr_evaluator_.bind("numberOfStreams", params_.number_of_streams);
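
For context, configuring the new knob would look roughly like the following sketch. It assumes the HostIrEvaluator class and HostIrEvaluatorParams struct touched by this PR, a constructor taking (container, communicator, params), and placeholders hic (a lowered HostIrContainer) and communicator; the exact signatures may differ.

// Sketch only; names follow csrc/host_ir/executor.h as touched by this PR,
// but the constructor shown here is an assumption, not verbatim API.
hir::HostIrEvaluatorParams params;
params.number_of_streams = 8; // default is 4 additional streams
hir::HostIrEvaluator evaluator(std::move(hic), communicator, params);
// "numberOfStreams" is then bound in the expression evaluator, so the
// ( i104 % numberOfStreams ) stream index in the lowered loop resolves to 8.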


xwang233 commented Jan 12, 2025

PR Code Suggestions ✨

(meta/llama-3.1-405b-instruct)

Category: General
Add null check for expr_evaluator_ to prevent null pointer dereferences

Consider adding a null check for expr_evaluator_ before calling bind to prevent
potential null pointer dereferences.

csrc/host_ir/executor.cpp [204]

-expr_evaluator_.bind("numberOfStreams", params_.number_of_streams);
+if (expr_evaluator_) {
+    expr_evaluator_->bind("numberOfStreams", params_.number_of_streams);
+}
Suggestion importance[1-10]: 8

Why: The suggestion prevents potential null pointer dereferences by adding a null check for expr_evaluator_.

Add null check for matmul to prevent null pointer dereferences

In the lowerToCollectiveBasedPipelinedGemmComm function, consider adding a check to
ensure that the matmul object is not null before accessing its members to prevent
null pointer dereferences.

csrc/host_ir/lower.cpp [355]

 auto matmul = expr->as<MatmulOp>();
+if (!matmul) {
+    // Handle the case where matmul is null
+    return {};
+}
 ...
Suggestion importance[1-10]: 8

Why: The suggestion prevents potential null pointer dereferences by adding a null check for matmul.

Add null check for tva to prevent null pointer dereferences

In the lowerToCollectiveBasedPipelinedGemmComm function, consider adding a check to
ensure that the tva object is not null before accessing its members to prevent null
pointer dereferences.

csrc/host_ir/lower.cpp [357]

 TensorView* tva = matmul->inA();
+if (!tva) {
+    // Handle the case where tva is null
+    return {};
+}
 ...
Suggestion importance[1-10]: 8

Why: The suggestion prevents potential null pointer dereferences by adding a null check for tva.

Add null check for tvb to prevent null pointer dereferences

In the lowerToCollectiveBasedPipelinedGemmComm function, consider adding a check to
ensure that the tvb object is not null before accessing its members to prevent null
pointer dereferences.

csrc/host_ir/lower.cpp [358]

 TensorView* tvb = matmul->inB();
+if (!tvb) {
+    // Handle the case where tvb is null
+    return {};
+}
 ...
Suggestion importance[1-10]: 8

Why: The suggestion prevents potential null pointer dereferences by adding a null check for tvb.

