Skip to content
This repository has been archived by the owner on Sep 9, 2024. It is now read-only.

Commit

Permalink
chore: small changes for generate.
Browse files Browse the repository at this point in the history
  • Loading branch information
KuangjuX committed Feb 14, 2024
1 parent 4e08388 commit 205ffa5
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions src/generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ namespace tiledkernel {
ASSERT(access_map_rA->getLoops() == access_map_acc->getLoops(),
"Access map should have the same number of loops.");

uint64_t indient = 0;
for (auto loop = 0; loop < access_map_rA->getLoops(); loop++) {
if (access_map_rA->hasPinIterVar(loop)) {
ASSERT(access_map_rB->hasPinIterVar(loop),
Expand Down Expand Up @@ -129,12 +128,15 @@ namespace tiledkernel {
acc->getBufferName().value() + generate_access_map(access_map_acc);

// TODO: Use macro kernel instead of hardcoding the kernel.
auto kernel_body =
fmt::format("{}gemm({}, {}, {});\n", std::string(indient, ' '),
rA_access, rB_access, acc_access);
auto kernel_body = fmt::format("gemm({}, {}, {});\n", rA_access,
rB_access, acc_access);

kernel +=
generate_loop(access_map_rA->pin_iter_vars[0].value(), kernel_body);
for (auto loop = 0; loop < access_map_rA->getLoops(); loop++) {
kernel_body = generate_loop(
access_map_rA->pin_iter_vars[loop].value(), kernel_body);
}

kernel += kernel_body;

return kernel;
}
Expand All @@ -160,14 +162,21 @@ namespace tiledkernel {
std::string TiledGenerator::generate_access_map(
AccessMap::Pointer access_map) {
std::string kernel;
for (int access_loop = 0; access_loop < access_map->getAccessDims();
for (auto access_loop = 0; access_loop < access_map->getAccessDims();
access_loop++) {
auto access_dim = access_map->access_pattern[access_loop];
for (int loop = 0; loop < access_map->getLoops(); loop++) {
if (access_dim[loop] != 0)
for (auto loop = 0; loop < access_map->getLoops(); loop++) {
if (access_dim[loop] == 0) {
return "";
} else if (access_dim[loop] == 1) {
kernel += fmt::format(
"[{}]", access_map->pin_iter_vars[loop].value()->name);
} else {
kernel += fmt::format(
"[{} * {}]", access_dim[loop],
access_map->pin_iter_vars[loop].value()->name);
"[{} * {}]",
access_map->pin_iter_vars[loop].value()->name,
access_dim[loop]);
}
}
}
return kernel;
Expand Down

0 comments on commit 205ffa5

Please sign in to comment.