Skip to content
This repository has been archived by the owner on Sep 9, 2024. It is now read-only.

Commit

Permalink
Add InferEdge and AccessMap.
Browse files Browse the repository at this point in the history
  • Loading branch information
KuangjuX committed Feb 14, 2024
1 parent 078de87 commit 58f71c2
Show file tree
Hide file tree
Showing 8 changed files with 114 additions and 11 deletions.
24 changes: 23 additions & 1 deletion examples/rf_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using namespace tiledkernel::graph;
int main() {
// Build a RF Gemm graph

std::cout << "Run RF GEMM Graph Example:" << std::endl;
std::cout << "Run RF GEMM Graph Example:" << std::endl << std::endl;
// Define buffers
auto sA = std::make_shared<TiledBuffer>("sA", MemoryLevel::Shared,
DataType::Float32);
Expand Down Expand Up @@ -78,13 +78,35 @@ int main() {
std::vector<EdgePtr>{acc_sC_edge},
std::vector<EdgePtr>{rA_gemm_edge, rB_gemm_edge, gemm_acc_edge});

// Define Access Map
auto rA_gemm_access_map_i = std::make_shared<AccessMap>(
1, 1, std::vector<std::vector<int32_t>>{std::vector<int32_t>{1}},
std::vector<std::pair<int32_t, int32_t>>{std::make_pair(0, 10)},
std::vector<int32_t>{1}, std::vector<int32_t>{0});

auto rB_gemm_access_map_i = std::make_shared<AccessMap>(
1, 1, std::vector<std::vector<int32_t>>{std::vector<int32_t>{1}},
std::vector<std::pair<int32_t, int32_t>>{std::make_pair(0, 10)},
std::vector<int32_t>{1}, std::vector<int32_t>{0});

auto gemm_acc_access_map_i = std::make_shared<AccessMap>(
1, 1, std::vector<std::vector<int32_t>>{std::vector<int32_t>{0}},
std::vector<std::pair<int32_t, int32_t>>{std::make_pair(0, 10)},
std::vector<int32_t>{1}, std::vector<int32_t>{0});

rA_gemm_edge->setAccessMapI(rA_gemm_access_map_i);
rB_gemm_edge->setAccessMapI(rB_gemm_access_map_i);
gemm_acc_edge->setAccessMapI(gemm_acc_access_map_i);

auto sorted_nodes = rf_gemm_graph->topoSort();

std::cout << "ID\tName" << std::endl;
for (auto node : sorted_nodes) {
std::cout << node->id << "\t" << node->name << std::endl;
}

std::cout << std::endl;

auto generator = std::make_shared<TiledGenerator>();

auto kernel = generator->emit(Platform::Cute, rf_gemm_graph);
Expand Down
3 changes: 3 additions & 0 deletions include/access_map.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "context.hpp"
#include <vector>
#include <string>
#include <memory>

namespace tiledkernel {
class AccessMap {
Expand All @@ -30,5 +31,7 @@ namespace tiledkernel {
// loop_depth
std::vector<int32_t> step_size;
std::vector<int32_t> offset;

using Pointer = std::shared_ptr<AccessMap>;
};
} // namespace tiledkernel
30 changes: 24 additions & 6 deletions include/graph/tilededge.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#pragma once
#include "mem_level.hpp"
#include "microop.hpp"
#include "platform.hpp"
#include "access_map.hpp"
#include "id.hpp"
Expand All @@ -9,16 +8,20 @@

namespace tiledkernel::graph {

enum class EdgeType { Load, Store, Compute };

class TiledNode;

class TiledEdge {
public:
std::string name;

TiledEdge(std::string name = "",
std::shared_ptr<TiledNode> producer = nullptr,
std::shared_ptr<TiledNode> consumer = nullptr);

std::string getName() { return name; }

ID getID() { return id; }

std::shared_ptr<TiledNode> getProducer() { return producer; }

std::shared_ptr<TiledNode> getConsumer() { return consumer; }
Expand All @@ -31,15 +34,30 @@ namespace tiledkernel::graph {
this->consumer = consumer;
}

void setAccessMapI(AccessMap::Pointer access_map_i) {
this->access_map_i = access_map_i;
}

void setAccessMapO(AccessMap::Pointer access_map_o) {
this->access_map_o = access_map_o;
}

AccessMap::Pointer getAccessMapI() { return access_map_i; }

AccessMap::Pointer getAccessMapO() { return access_map_o; }

void inferType();

using Pointer = std::shared_ptr<TiledEdge>;

protected:
// std::shared_ptr<TiledBuffer> input;
// std::shared_ptr<TiledBuffer> output;
std::string name;
ID id;
EdgeType edge_type;
std::shared_ptr<TiledNode> producer;
std::shared_ptr<TiledNode> consumer;
// std::shared_ptr<AccessMap> access_map;
AccessMap::Pointer access_map_i;
AccessMap::Pointer access_map_o;
};

using EdgePtr = std::shared_ptr<TiledEdge>;
Expand Down
10 changes: 10 additions & 0 deletions include/graph/tilednode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <memory>
#include <string>
#include <variant>
#include <optional>

namespace tiledkernel::graph {
class TiledEdge;
Expand Down Expand Up @@ -47,8 +48,17 @@ namespace tiledkernel::graph {
return OpType::Null;
}

MemoryLevel getMemLevel() { return mem_level; }

std::string getName() { return name; }

std::optional<std::string> getBufferName() {
if (node_type == NodeType::Buffer) {
return std::get<std::shared_ptr<TiledBuffer>>(data)->name;
}
return {};
}

std::vector<std::shared_ptr<TiledEdge>> getInEdges() {
return in_edges;
}
Expand Down
5 changes: 5 additions & 0 deletions include/kernel/header.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
namespace tiledkernel::kernel {
enum class Header {
Cute,
};
}
8 changes: 6 additions & 2 deletions include/mem_level.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
#pragma once
namespace tiledkernel {
enum MemoryLevel { RF, Shared, Global };
};
enum MemoryLevel { RF = 1, Shared = 2, Global = 3 };

// bool operator>(MemoryLevel a, MemoryLevel b) {
// return static_cast<int>(a) > static_cast<int>(b);
// }
}; // namespace tiledkernel
21 changes: 19 additions & 2 deletions src/generator.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "graph/tilednode.hpp"
#include "graph/tilededge.hpp"
#include "generator.hpp"
#include "error_handler.hpp"
#include <fmt/core.h>
Expand Down Expand Up @@ -70,10 +71,26 @@ namespace tiledkernel {
auto rB = predecessors[1];
auto acc = successors[0];

auto gemm_in_edges = node->getInEdges();
auto gemm_out_edges = node->getOutEdges();

ASSERT(gemm_in_edges.size() == 2, "Gemm node should have 2 in edges.");
ASSERT(gemm_out_edges.size() == 1, "Gemm node should have 1 out edge.");

// TODO: Add `access_map` to `TiledNode` to store the access pattern.
std::string kernel;
kernel += fmt::format("gemm({}, {}, {});\n", rA->getName(),
rB->getName(), acc->getName());

// TODO: Generate `for` loop.
auto access_map_rA = gemm_in_edges[0]->getAccessMapI();
auto access_map_rB = gemm_in_edges[1]->getAccessMapI();
auto access_map_acc = gemm_out_edges[0]->getAccessMapI();

// TODO: Check if the access map is valid.

// TODO: Use macro kernel instead of hardcoding the kernel.
kernel += fmt::format(
"gemm({}, {}, {});\n", rA->getBufferName().value(),
rB->getBufferName().value(), acc->getBufferName().value());
return kernel;
}

Expand Down
24 changes: 24 additions & 0 deletions src/graph/tilededge.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "graph/tilededge.hpp"
#include "graph/tilednode.hpp"
#include "error_handler.hpp"

namespace tiledkernel::graph {

Expand All @@ -8,4 +9,27 @@ namespace tiledkernel::graph {
: name(name), producer(producer), consumer(consumer) {
id = ID::make();
}

void TiledEdge::inferType() {
auto producer = getProducer();
auto consumer = getConsumer();
auto producer_type = producer->getType();
auto consumer_type = consumer->getType();
if (producer_type == NodeType::Buffer &&
consumer_type == NodeType::Buffer) {
if (producer->getMemLevel() >= consumer->getMemLevel()) {
edge_type = EdgeType::Load;
} else {
edge_type = EdgeType::Store;
}
} else if (producer_type == NodeType::Buffer &&
consumer_type == NodeType::Operator) {
edge_type = EdgeType::Compute;
} else if (producer_type == NodeType::Operator &&
consumer_type == NodeType::Buffer) {
edge_type = EdgeType::Compute;
} else {
TODO("Not implemented yet");
}
}
}; // namespace tiledkernel::graph

0 comments on commit 58f71c2

Please sign in to comment.