diff --git a/examples/rf_graph.cc b/examples/rf_graph.cc index 226fd4b..73fda36 100644 --- a/examples/rf_graph.cc +++ b/examples/rf_graph.cc @@ -15,7 +15,7 @@ using namespace tiledkernel::graph; int main() { // Build a RF Gemm graph - std::cout << "Run RF GEMM Graph Example:" << std::endl; + std::cout << "Run RF GEMM Graph Example:" << std::endl << std::endl; // Define buffers auto sA = std::make_shared("sA", MemoryLevel::Shared, DataType::Float32); @@ -78,6 +78,26 @@ int main() { std::vector{acc_sC_edge}, std::vector{rA_gemm_edge, rB_gemm_edge, gemm_acc_edge}); + // Define Access Map + auto rA_gemm_access_map_i = std::make_shared( + 1, 1, std::vector>{std::vector{1}}, + std::vector>{std::make_pair(0, 10)}, + std::vector{1}, std::vector{0}); + + auto rB_gemm_access_map_i = std::make_shared( + 1, 1, std::vector>{std::vector{1}}, + std::vector>{std::make_pair(0, 10)}, + std::vector{1}, std::vector{0}); + + auto gemm_acc_access_map_i = std::make_shared( + 1, 1, std::vector>{std::vector{0}}, + std::vector>{std::make_pair(0, 10)}, + std::vector{1}, std::vector{0}); + + rA_gemm_edge->setAccessMapI(rA_gemm_access_map_i); + rB_gemm_edge->setAccessMapI(rB_gemm_access_map_i); + gemm_acc_edge->setAccessMapI(gemm_acc_access_map_i); + auto sorted_nodes = rf_gemm_graph->topoSort(); std::cout << "ID\tName" << std::endl; @@ -85,6 +105,8 @@ int main() { std::cout << node->id << "\t" << node->name << std::endl; } + std::cout << std::endl; + auto generator = std::make_shared(); auto kernel = generator->emit(Platform::Cute, rf_gemm_graph); diff --git a/include/access_map.hpp b/include/access_map.hpp index 09c700b..3383071 100644 --- a/include/access_map.hpp +++ b/include/access_map.hpp @@ -4,6 +4,7 @@ #include "context.hpp" #include #include +#include namespace tiledkernel { class AccessMap { @@ -30,5 +31,7 @@ namespace tiledkernel { // loop_depth std::vector step_size; std::vector offset; + + using Pointer = std::shared_ptr; }; } // namespace tiledkernel \ No newline at end of file diff --git a/include/graph/tilededge.hpp b/include/graph/tilededge.hpp index 60a5d54..689c825 100644 --- a/include/graph/tilededge.hpp +++ b/include/graph/tilededge.hpp @@ -1,6 +1,5 @@ #pragma once #include "mem_level.hpp" -#include "microop.hpp" #include "platform.hpp" #include "access_map.hpp" #include "id.hpp" @@ -9,16 +8,20 @@ namespace tiledkernel::graph { + enum class EdgeType { Load, Store, Compute }; + class TiledNode; class TiledEdge { public: - std::string name; - TiledEdge(std::string name = "", std::shared_ptr producer = nullptr, std::shared_ptr consumer = nullptr); + std::string getName() { return name; } + + ID getID() { return id; } + std::shared_ptr getProducer() { return producer; } std::shared_ptr getConsumer() { return consumer; } @@ -31,15 +34,30 @@ namespace tiledkernel::graph { this->consumer = consumer; } + void setAccessMapI(AccessMap::Pointer access_map_i) { + this->access_map_i = access_map_i; + } + + void setAccessMapO(AccessMap::Pointer access_map_o) { + this->access_map_o = access_map_o; + } + + AccessMap::Pointer getAccessMapI() { return access_map_i; } + + AccessMap::Pointer getAccessMapO() { return access_map_o; } + + void inferType(); + using Pointer = std::shared_ptr; protected: - // std::shared_ptr input; - // std::shared_ptr output; + std::string name; ID id; + EdgeType edge_type; std::shared_ptr producer; std::shared_ptr consumer; - // std::shared_ptr access_map; + AccessMap::Pointer access_map_i; + AccessMap::Pointer access_map_o; }; using EdgePtr = std::shared_ptr; diff --git a/include/graph/tilednode.hpp b/include/graph/tilednode.hpp index 892300f..29234fc 100644 --- a/include/graph/tilednode.hpp +++ b/include/graph/tilednode.hpp @@ -6,6 +6,7 @@ #include #include #include +#include namespace tiledkernel::graph { class TiledEdge; @@ -47,8 +48,17 @@ namespace tiledkernel::graph { return OpType::Null; } + MemoryLevel getMemLevel() { return mem_level; } + std::string getName() { return name; } + std::optional getBufferName() { + if (node_type == NodeType::Buffer) { + return std::get>(data)->name; + } + return {}; + } + std::vector> getInEdges() { return in_edges; } diff --git a/include/kernel/header.hpp b/include/kernel/header.hpp new file mode 100644 index 0000000..0f2a4f2 --- /dev/null +++ b/include/kernel/header.hpp @@ -0,0 +1,5 @@ +namespace tiledkernel::kernel { + enum class Header { + Cute, + }; +} \ No newline at end of file diff --git a/include/mem_level.hpp b/include/mem_level.hpp index a718e55..a935bb0 100644 --- a/include/mem_level.hpp +++ b/include/mem_level.hpp @@ -1,4 +1,8 @@ #pragma once namespace tiledkernel { - enum MemoryLevel { RF, Shared, Global }; -}; \ No newline at end of file + enum MemoryLevel { RF = 1, Shared = 2, Global = 3 }; + + // bool operator>(MemoryLevel a, MemoryLevel b) { + // return static_cast(a) > static_cast(b); + // } +}; // namespace tiledkernel \ No newline at end of file diff --git a/src/generator.cc b/src/generator.cc index f032a76..74a940b 100644 --- a/src/generator.cc +++ b/src/generator.cc @@ -1,4 +1,5 @@ #include "graph/tilednode.hpp" +#include "graph/tilededge.hpp" #include "generator.hpp" #include "error_handler.hpp" #include @@ -70,10 +71,26 @@ namespace tiledkernel { auto rB = predecessors[1]; auto acc = successors[0]; + auto gemm_in_edges = node->getInEdges(); + auto gemm_out_edges = node->getOutEdges(); + + ASSERT(gemm_in_edges.size() == 2, "Gemm node should have 2 in edges."); + ASSERT(gemm_out_edges.size() == 1, "Gemm node should have 1 out edge."); + // TODO: Add `access_map` to `TiledNode` to store the access pattern. std::string kernel; - kernel += fmt::format("gemm({}, {}, {});\n", rA->getName(), - rB->getName(), acc->getName()); + + // TODO: Generate `for` loop. + auto access_map_rA = gemm_in_edges[0]->getAccessMapI(); + auto access_map_rB = gemm_in_edges[1]->getAccessMapI(); + auto access_map_acc = gemm_out_edges[0]->getAccessMapI(); + + // TODO: Check if the access map is valid. + + // TODO: Use macro kernel instead of hardcoding the kernel. + kernel += fmt::format( + "gemm({}, {}, {});\n", rA->getBufferName().value(), + rB->getBufferName().value(), acc->getBufferName().value()); return kernel; } diff --git a/src/graph/tilededge.cc b/src/graph/tilededge.cc index efe5db9..999fdf6 100644 --- a/src/graph/tilededge.cc +++ b/src/graph/tilededge.cc @@ -1,5 +1,6 @@ #include "graph/tilededge.hpp" #include "graph/tilednode.hpp" +#include "error_handler.hpp" namespace tiledkernel::graph { @@ -8,4 +9,27 @@ namespace tiledkernel::graph { : name(name), producer(producer), consumer(consumer) { id = ID::make(); } + + void TiledEdge::inferType() { + auto producer = getProducer(); + auto consumer = getConsumer(); + auto producer_type = producer->getType(); + auto consumer_type = consumer->getType(); + if (producer_type == NodeType::Buffer && + consumer_type == NodeType::Buffer) { + if (producer->getMemLevel() >= consumer->getMemLevel()) { + edge_type = EdgeType::Load; + } else { + edge_type = EdgeType::Store; + } + } else if (producer_type == NodeType::Buffer && + consumer_type == NodeType::Operator) { + edge_type = EdgeType::Compute; + } else if (producer_type == NodeType::Operator && + consumer_type == NodeType::Buffer) { + edge_type = EdgeType::Compute; + } else { + TODO("Not implemented yet"); + } + } }; // namespace tiledkernel::graph \ No newline at end of file