Skip to content
This repository has been archived by the owner on Sep 9, 2024. It is now read-only.

Commit

Permalink
Add basic gemm generator.
Browse files Browse the repository at this point in the history
  • Loading branch information
KuangjuX committed Feb 14, 2024
1 parent 7616dc9 commit 078de87
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 5 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@
[submodule "3rd-party/googletest"]
path = 3rd-party/googletest
url = git@github.com:google/googletest.git
[submodule "3rd-party/fmtlog"]
path = 3rd-party/fmtlog
url = git@github.com:MengRao/fmtlog.git
1 change: 1 addition & 0 deletions 3rd-party/fmtlog
Submodule fmtlog added at 01b6cd
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
CC := g++
EXAMPLE ?= shared_graph
EXAMPLE ?= rf_graph
EXAMPLE_SRCS := $(wildcard examples/*.cc)
EXAMPLES := $(patsubst examples/%.cc, %, $(EXAMPLE_SRCS))

Expand Down
12 changes: 10 additions & 2 deletions examples/rf_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "graph/tilededge.hpp"
#include "type/data_type.hpp"
#include "tiledbuffer.hpp"
#include "generator.hpp"
#include "platform.hpp"
#include "op.hpp"
#include <iostream>

Expand Down Expand Up @@ -69,17 +71,23 @@ int main() {
NodeType::Buffer, MemoryLevel::Shared, TiledNodeData{sC}, "sC_node",
std::vector<EdgePtr>{acc_sC_edge}, std::vector<EdgePtr>{});

auto rf_gemm_graph = TiledGraph(
auto rf_gemm_graph = std::make_shared<TiledGraph>(
MemoryLevel::RF, "rf_gemm_graph",
std::vector<NodePtr>{rA_node, rB_node, gemm_node, acc_node},
std::vector<EdgePtr>{sA_rA_edge, sB_rB_edge},
std::vector<EdgePtr>{acc_sC_edge},
std::vector<EdgePtr>{rA_gemm_edge, rB_gemm_edge, gemm_acc_edge});

auto sorted_nodes = rf_gemm_graph.topoSort();
auto sorted_nodes = rf_gemm_graph->topoSort();

std::cout << "ID\tName" << std::endl;
for (auto node : sorted_nodes) {
std::cout << node->id << "\t" << node->name << std::endl;
}

auto generator = std::make_shared<TiledGenerator>();

auto kernel = generator->emit(Platform::Cute, rf_gemm_graph);

std::cout << "Generate kernel:" << std::endl << kernel << std::endl;
}
2 changes: 2 additions & 0 deletions include/generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,7 @@ namespace tiledkernel {
TiledContext::Pointer ctx;

std::string emit_cute(TiledGraph::Pointer graph);
std::string emit_rf_cute(TiledGraph::Pointer graph);
std::string emit_rf_cute_gemm(TiledNode::Pointer node);
};
} // namespace tiledkernel
2 changes: 2 additions & 0 deletions include/graph/tiledgraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ namespace tiledkernel::graph {
std::vector<std::shared_ptr<TiledEdge>> out_edges = {},
std::vector<std::shared_ptr<TiledEdge>> intra_edges = {});

// Returns the value of this graph's `mem_level` member.
// Declared const so it can be called on const graphs as well.
MemoryLevel getMemLevel() const { return mem_level; }

std::vector<std::shared_ptr<TiledNode>> topoSort();

bool isNodeExist(std::shared_ptr<TiledNode> node);
Expand Down
29 changes: 29 additions & 0 deletions include/graph/tilednode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,35 @@ namespace tiledkernel::graph {
std::string name = "",
std::vector<std::shared_ptr<TiledEdge>> in_edges = {},
std::vector<std::shared_ptr<TiledEdge>> out_edges = {});

// Returns the value of this node's `node_type` member.
// Declared const: reading the type does not modify the node.
NodeType getType() const { return node_type; }

// Returns the operator kind carried by this node.
//
// For nodes whose type is not NodeType::Operator the result is
// OpType::Null. For operator nodes the kind is read from the
// std::shared_ptr<Operator> held in `data`; if that pointer is empty,
// OpType::Null is returned instead of dereferencing it.
OpType getOpType() const {
    if (node_type != NodeType::Operator) {
        return OpType::Null;
    }
    const auto& op = std::get<std::shared_ptr<Operator>>(data);
    return op ? op->getOpType() : OpType::Null;
}

std::string getName() { return name; }

std::vector<std::shared_ptr<TiledEdge>> getInEdges() {
return in_edges;
}

std::vector<std::shared_ptr<TiledEdge>> getOutEdges() {
return out_edges;
}

std::vector<std::shared_ptr<TiledNode>> getPredecessors() {
return predecessors;
}

std::vector<std::shared_ptr<TiledNode>> getSuccessors() {
return successors;
}

using Pointer = std::shared_ptr<TiledNode>;
};

using NodePtr = std::shared_ptr<TiledNode>;
Expand Down
5 changes: 4 additions & 1 deletion include/op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
#include <vector>

namespace tiledkernel {
enum OpType { Add, Sub, Mul, Div, Gemm };
enum OpType { Add, Sub, Mul, Div, Gemm, Null };
// An operation identified by an OpType tag.
class Operator {
  public:
    // Takes the operator kind; OpType::Add is used when none is given.
    // The constructor is defined outside this header.
    Operator(OpType op_type = OpType::Add);

    // Returns the stored `op_type`.
    // Declared const so the kind can be queried through const access paths.
    OpType getOpType() const { return op_type; }

  private:
    OpType op_type;
};

Expand Down
2 changes: 2 additions & 0 deletions include/tiledbuffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ namespace tiledkernel {
std::string name;
MemoryLevel mem_level;
type::DataType dtype;

using Pointer = std::shared_ptr<TiledBuffer>;
};

using BufferPtr = std::shared_ptr<TiledBuffer>;
Expand Down
60 changes: 59 additions & 1 deletion src/generator.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "graph/tilednode.hpp"
#include "generator.hpp"
#include "error_handler.hpp"
#include <fmt/core.h>

namespace tiledkernel {

Expand All @@ -16,7 +18,63 @@ namespace tiledkernel {
}

// Emits code for `graph`, dispatching on the graph's memory level.
//
// Only MemoryLevel::RF graphs produce output (via emit_rf_cute). For
// MemoryLevel::Global and MemoryLevel::Shared nothing is emitted yet, so
// the returned string is empty.
std::string TiledGenerator::emit_cute(TiledGraph::Pointer graph) {
    std::string kernel;
    switch (graph->getMemLevel()) {
    case MemoryLevel::Global:
        // Not lowered yet: contributes no code.
        break;
    case MemoryLevel::Shared:
        // Not lowered yet: contributes no code.
        break;
    case MemoryLevel::RF:
        kernel += emit_rf_cute(graph);
        // Explicit break so a case added below cannot be fallen into.
        break;
    }
    return kernel;
}

// Emits code for every operator node of `graph`.
//
// The nodes are visited in the order returned by graph->topoSort(). Nodes
// whose type is not NodeType::Operator are skipped. OpType::Gemm nodes are
// lowered through emit_rf_cute_gemm; every other operator kind goes
// through TODO(...).
//
// Returns the concatenation of the code emitted for each operator node.
std::string TiledGenerator::emit_rf_cute(TiledGraph::Pointer graph) {
    std::string kernel;
    const auto sorted_nodes = graph->topoSort();

    // Iterate by const reference: copying a shared_ptr per iteration would
    // only add reference-count traffic. Operator nodes are handled as they
    // are encountered, so no intermediate list is needed.
    for (const auto& node : sorted_nodes) {
        if (node->getType() != NodeType::Operator) {
            continue;
        }

        switch (node->getOpType()) {
        case OpType::Gemm:
            kernel += emit_rf_cute_gemm(node);
            break;
        default:
            TODO("Operator not supported yet.");
        }
    }
    return kernel;
}

// Emits a single `gemm(a, b, acc);` line for a Gemm operator node.
//
// The two operands are the names of the node's predecessors (in stored
// order) and the third argument is the name of its only successor. The
// node must have exactly two predecessors and one successor; both
// conditions are checked with ASSERT.
std::string TiledGenerator::emit_rf_cute_gemm(TiledNode::Pointer node) {
    const auto predecessors = node->getPredecessors();
    ASSERT(predecessors.size() == 2,
           "Gemm node should have 2 predecessors.");

    const auto successors = node->getSuccessors();
    ASSERT(successors.size() == 1, "Gemm node should have 1 successor.");

    const auto& lhs = predecessors[0];
    const auto& rhs = predecessors[1];
    const auto& out = successors[0];

    // TODO: Add `access_map` to `TiledNode` to store the access pattern.
    return fmt::format("gemm({}, {}, {});\n", lhs->getName(),
                       rhs->getName(), out->getName());
}

} // namespace tiledkernel

0 comments on commit 078de87

Please sign in to comment.