forked from NVIDIA/cutlass
-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for SYCL on example 35 #142
Closed
aacostadiaz
wants to merge
1
commit into
codeplaysoftware:sycl-develop
from
aacostadiaz:aacosta/example35
Closed
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -52,7 +52,11 @@ | |
#include "cutlass/util/reference/host/tensor_compare.h" | ||
#include "cutlass/util/reference/host/tensor_norm.h" | ||
#include "cutlass/util/reference/host/tensor_copy.h" | ||
#if defined(CUTLASS_ENABLE_SYCL) | ||
#include "cutlass/util/reference/device/sycl_tensor_fill.h" | ||
#else | ||
#include "cutlass/util/reference/device/tensor_fill.h" | ||
#endif | ||
#include "cutlass/util/reference/host/tensor_fill.h" | ||
#include "cutlass/util/reference/host/error_metrics.h" | ||
#include "cutlass/util/tensor_view_io.h" | ||
|
@@ -61,6 +65,8 @@ | |
#include "cutlass/epilogue/thread/linear_combination.h" | ||
///////////////////////////////////////////////////////////////////////////////////////////////// | ||
|
||
#include <helper.h> | ||
|
||
#include "gemm_with_softmax.h" | ||
|
||
///////////////////////////////////////////////////////////////////////////////////////////////// | ||
|
@@ -159,6 +165,8 @@ struct Options { | |
/// Returns true if the environment and Toolkit support this | ||
bool supported(bool verbose = true) const { | ||
|
||
#if !defined(CUTLASS_ENABLE_SYCL) | ||
|
||
// Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available | ||
// in CUDA 11.0. | ||
// | ||
|
@@ -187,7 +195,7 @@ struct Options { | |
} | ||
return false; | ||
} | ||
|
||
#endif | ||
return true; | ||
} | ||
}; | ||
|
@@ -333,12 +341,16 @@ struct Testbed { | |
return disposition; | ||
} | ||
|
||
#if defined(CUTLASS_ENABLE_SYCL) | ||
syclcompat::wait(); | ||
#else | ||
cudaError_t result = cudaDeviceSynchronize(); | ||
if (result != cudaSuccess) { | ||
std::cerr << "Device synchronize failed with error " | ||
<< cudaGetErrorString(result) << std::endl; | ||
return disposition; | ||
} | ||
#endif | ||
|
||
// | ||
// Verify | ||
|
@@ -513,6 +525,10 @@ struct Testbed { | |
ElementCompute(0) | ||
); | ||
|
||
#if defined(CUTLASS_ENABLE_SYCL) | ||
syclcompat::wait(); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is that extra wait needed? This should probably be explained in a code comment. |
||
#endif | ||
|
||
// Copy reference results to host memory for verification | ||
std::vector<ElementD> matrix_D_Ref(layout_C.capacity(extent_C)); | ||
cutlass::device_memory::copy_to_host(matrix_D_Ref.data(), block_Ref.get(), matrix_D_Ref.size()); | ||
|
@@ -597,25 +613,10 @@ struct Testbed { | |
// | ||
|
||
cutlass::Status status = cutlass::Status::kSuccess; | ||
cudaError_t result; | ||
cudaEvent_t events[2]; | ||
GpuTimer timer; | ||
int const kIterations = options.iterations; | ||
|
||
for (cudaEvent_t &evt : events) { | ||
result = cudaEventCreate(&evt); | ||
if (result != cudaSuccess) { | ||
std::cerr << "cudaEventCreate failed with error " << cudaGetErrorString(result) << std::endl; | ||
return false; | ||
} | ||
} | ||
|
||
result = cudaEventRecord(events[0]); | ||
|
||
if (result != cudaSuccess) { | ||
std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; | ||
return false; | ||
} | ||
|
||
timer.start(); | ||
for (int iter = 0; iter < kIterations; ++iter) { | ||
|
||
status = execute_device_kernel(); | ||
|
@@ -625,36 +626,9 @@ struct Testbed { | |
return false; | ||
} | ||
} | ||
timer.stop(); | ||
|
||
result = cudaEventRecord(events[1]); | ||
|
||
if (result != cudaSuccess) { | ||
std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; | ||
return false; | ||
} | ||
|
||
result = cudaDeviceSynchronize(); | ||
|
||
if (result != cudaSuccess) { | ||
std::cerr << "cudaDeviceSynchronize() failed with error " << cudaGetErrorString(result) << std::endl; | ||
return false; | ||
} | ||
|
||
float elapsed_ms = 0; | ||
result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); | ||
|
||
if (result != cudaSuccess) { | ||
std::cerr << "cudaEventElapsedTime() failed with error " << cudaGetErrorString(result) << std::endl; | ||
return false; | ||
} | ||
|
||
for (cudaEvent_t &evt : events) { | ||
result = cudaEventDestroy(evt); | ||
if (result != cudaSuccess) { | ||
std::cerr << "cudaEventDestroy() failed with error " << cudaGetErrorString(result) << std::endl; | ||
return false; | ||
} | ||
} | ||
float elapsed_ms = timer.elapsed_millis(); | ||
|
||
int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2; | ||
int64_t bytes = (sizeof(ElementD) * 2 + sizeof(ElementSoftmax)) * options.problem_size.m() * options.problem_size.n(); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -201,12 +201,12 @@ class ApplySoftmax { | |
|
||
using AccessTypeD = AlignedArray<ElementD, kAlignment>; | ||
|
||
int block_batch = blockIdx.z; | ||
int block_m = blockIdx.x * ApplyShape::kRow; | ||
int block_batch = BlockIdxZ(); | ||
int block_m = BlockIdxX() * ApplyShape::kRow; | ||
int block_n = 0; | ||
|
||
int thread_m = threadIdx.y; | ||
int thread_n = threadIdx.x * kAlignment; | ||
int thread_m = ThreadIdxY(); | ||
int thread_n = ThreadIdxX() * kAlignment; | ||
|
||
int idx_m = block_m + thread_m; | ||
int idx_n = block_n + thread_n; | ||
|
@@ -580,6 +580,17 @@ class GemmSoftmax { | |
|
||
cudaError_t result; | ||
|
||
#if defined(CUTLASS_ENABLE_SYCL) | ||
const auto sycl_block = syclcompat::dim3(gemm_block.x, gemm_block.y, gemm_block.z); | ||
const auto sycl_grid = syclcompat::dim3(gemm_grid.x, gemm_grid.y, gemm_grid.z); | ||
|
||
using namespace syclcompat::experimental; | ||
|
||
auto gemm_event = launch<cutlass::Kernel<GemmKernel>>(launch_policy{ | ||
sycl_grid, sycl_block, local_mem_size{static_cast<std::size_t>(gemm_smem_size)}}, | ||
params_.gemm); | ||
EventManager::getInstance().addEvent(gemm_event); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why does the event need to be recorded? (This is a general question for each launch.) |
||
#else | ||
if (gemm_smem_size >= (48 << 10)) { | ||
result = cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, | ||
cudaFuncAttributeMaxDynamicSharedMemorySize, | ||
|
@@ -591,6 +602,7 @@ class GemmSoftmax { | |
} | ||
|
||
cutlass::Kernel<GemmKernel><<<gemm_grid, gemm_block, gemm_smem_size, stream>>>(params_.gemm); | ||
#endif | ||
|
||
result = cudaGetLastError(); | ||
|
||
|
@@ -613,9 +625,21 @@ class GemmSoftmax { | |
dim3 final_reduction_grid(block_per_row, 1, params_.softmax.args.batch_count); | ||
dim3 final_reduction_block(thread_per_block); | ||
|
||
#if defined(CUTLASS_ENABLE_SYCL) | ||
const auto sycl_final_reduction_block = syclcompat::dim3(final_reduction_block.x, final_reduction_block.y, final_reduction_block.z); | ||
const auto sycl_final_reduction_grid = syclcompat::dim3(final_reduction_grid.x, final_reduction_grid.y, final_reduction_grid.z); | ||
|
||
using namespace syclcompat::experimental; | ||
|
||
auto final_reduction_event = launch<Kernel<ApplyFinalReductionKernel>>(launch_policy{ | ||
sycl_final_reduction_grid, sycl_final_reduction_block, local_mem_size{sizeof(typename ApplyFinalReductionKernel::SharedStorage)}}, | ||
params_.reduction); | ||
EventManager::getInstance().addEvent(final_reduction_event); | ||
#else | ||
Kernel<ApplyFinalReductionKernel><<< | ||
final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream | ||
>>>(params_.reduction); | ||
#endif | ||
|
||
result = cudaGetLastError(); | ||
|
||
|
@@ -637,9 +661,21 @@ class GemmSoftmax { | |
(params_.softmax.args.extent.column() + threadblock_columns - 1) / threadblock_columns, | ||
params_.softmax.args.batch_count); | ||
|
||
#if defined(CUTLASS_ENABLE_SYCL) | ||
const auto sycl_apply_block = syclcompat::dim3(apply_block.x, apply_block.y, apply_block.z); | ||
const auto sycl_apply_grid = syclcompat::dim3(apply_grid.x, apply_grid.y, apply_grid.z); | ||
|
||
using namespace syclcompat::experimental; | ||
|
||
auto apply_event = launch<Kernel<SoftmaxApplyKernel>>(launch_policy{ | ||
sycl_apply_grid, sycl_apply_block, local_mem_size{sizeof(typename SoftmaxApplyKernel::SharedStorage)}}, | ||
params_.softmax); | ||
EventManager::getInstance().addEvent(apply_event); | ||
#else | ||
Kernel<SoftmaxApplyKernel><<< | ||
apply_grid, apply_block, sizeof(typename SoftmaxApplyKernel::SharedStorage), stream | ||
>>>(params_.softmax); | ||
#endif | ||
|
||
result = cudaGetLastError(); | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is only temporarily needed until we rework how
__CUDACC_VER_MAJOR__
works for syclcompat, right? If so, a code comment would be good so that we can find these places easily in the future.