Skip to content

Commit

Permalink
Update utilities to support multi-output nodes
Browse files Browse the repository at this point in the history
Differential Revision: D48997052

fbshipit-source-id: dcc116a64d2f5ec29516bb84cace46f0e71cb103
  • Loading branch information
jfix71 authored and facebook-github-bot committed Sep 6, 2023
1 parent b8a45f5 commit c4cbc62
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 21 deletions.
43 changes: 31 additions & 12 deletions include/glow/Graph/FXIRUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,31 +103,50 @@ std::string getNodeName(const folly::dynamic &node);
/// Get the target of the node.
std::string getNodeTarget(const folly::dynamic &node);

/// Get the data type of the node.
ElemKind getNodeDataType(const folly::dynamic &node);
/// Get the data type of the node. \p idx represents which output to get a
/// result for in the case that the node has multiple outputs; if it's a single
/// output then should be left to -1.
ElemKind getNodeDataType(const folly::dynamic &node, int idx = -1);

bool hasFxOutTensorView(const folly::dynamic &node);

const folly::dynamic &getFxOutTensorView(const folly::dynamic &node);
/// Get out tensorview for \p node. If \p idx is non-negative then assume this
/// is a multi-output node, so get the tensorview output for that specific idx.
/// \p idx represents which output to get a result for in the case that the node
/// has multiple outputs; if it's a single output then should be left to -1.
const folly::dynamic &getFxOutTensorView(const folly::dynamic &node,
int idx = -1);

/// \returns specific item \p itemName from \p node. \p idx represents which
/// output to get a result for in the case that the node has multiple outputs;
/// if it's a single output then should be left to -1.
std::string getNodeItemAsString(const folly::dynamic &node,
const char *itemName);
std::string getNodeShapeAsString(const folly::dynamic &node);
std::string getNodeStrideAsString(const folly::dynamic &node);
const char *itemName, int idx = -1);
std::string getNodeShapeAsString(const folly::dynamic &node, int idx = -1);
std::string getNodeStrideAsString(const folly::dynamic &node, int idx = -1);

template <class T>
std::vector<T> getNodeItem(const folly::dynamic &node, const char *itemName) {
const std::string itemString = getNodeItemAsString(node, itemName);
std::vector<T> getNodeItem(const folly::dynamic &node, const char *itemName,
int idx = -1) {
const std::string itemString = getNodeItemAsString(node, itemName, idx);
return toIntegerArray<glow::dim_t>(itemString);
}

template <class T> std::vector<T> getNodeShape(const folly::dynamic &node) {
const std::string shapeString = getNodeShapeAsString(node);
/// \returns the shape from \p node. \p idx represents which output to get a
/// result for in the case that the node has multiple outputs; if it's a single
/// output then should be left to -1.
template <class T>
std::vector<T> getNodeShape(const folly::dynamic &node, int idx = -1) {
const std::string shapeString = getNodeShapeAsString(node, idx);
return toIntegerArray<glow::dim_t>(shapeString);
}

template <class T> std::vector<T> getNodeStride(const folly::dynamic &node) {
const std::string strideString = getNodeStrideAsString(node);
/// \returns the stride from \p node. \p idx represents which output to get a
/// result for in the case that the node has multiple outputs; if it's a single
/// output then should be left to -1.
template <class T>
std::vector<T> getNodeStride(const folly::dynamic &node, int idx = -1) {
const std::string strideString = getNodeStrideAsString(node, idx);
return toIntegerArray<glow::dim_t>(strideString);
}

Expand Down
34 changes: 25 additions & 9 deletions lib/Graph/FXIRUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,12 @@ std::string glow::getNodeTarget(const folly::dynamic &node) {
return node["target"].getString();
}

ElemKind glow::getNodeDataType(const folly::dynamic &node) {
ElemKind glow::getNodeDataType(const folly::dynamic &node, int idx) {
CHECK(node.find("dtype") != node.items().end())
<< "dtype field doesn't exist in node " << node;
return getElemKind(node.at("dtype").getString());
auto s = idx < 0 ? node.at("dtype").getString()
: node.at("dtype").at(idx).getString();
return getElemKind(s);
}

double glow::getNodeScale(const folly::dynamic &node) {
Expand Down Expand Up @@ -140,10 +142,18 @@ bool glow::hasFxOutTensorView(const folly::dynamic &node) {
return kwargs.find("out_memref") != kwargs.items().end();
}

const folly::dynamic &glow::getFxOutTensorView(const folly::dynamic &node) {
const folly::dynamic &glow::getFxOutTensorView(const folly::dynamic &node,
int idx) {
const auto &kwargs = getNodeKwargs(node);
CHECK(hasFxOutTensorView(node)) << "Node must have 'out_memref'\n";
return kwargs["out_memref"];
const auto &out = kwargs["out_memref"];
if (idx < 0) {
CHECK(out.isObject());
return out;
}
CHECK(out.isArray() && idx < out.size());
CHECK(out.at(idx).isObject());
return out.at(idx);
}

std::vector<dim_t> glow::getOffsets(const folly::dynamic &node) {
Expand All @@ -158,13 +168,13 @@ std::vector<dim_t> glow::getOffsets(const folly::dynamic &node) {
}

//======================================================================
std::string glow::getNodeShapeAsString(const folly::dynamic &node) {
return glow::getNodeItemAsString(node, "shape");
std::string glow::getNodeShapeAsString(const folly::dynamic &node, int idx) {
return glow::getNodeItemAsString(node, "shape", idx);
}

//======================================================================
std::string glow::getNodeStrideAsString(const folly::dynamic &node) {
return getNodeItemAsString(node, "stride");
std::string glow::getNodeStrideAsString(const folly::dynamic &node, int idx) {
return getNodeItemAsString(node, "stride", idx);
}

//======================================================================
Expand All @@ -188,17 +198,23 @@ std::string glow::getNodeOffsetsAsString(const folly::dynamic &node) {
// returns the shape from the destination node.
//======================================================================
std::string glow::getNodeItemAsString(const folly::dynamic &node,
const char *itemName) {
const char *itemName, int idx) {
if (node.find("kwargs") != node.items().end()) {
const auto &kwargs = getNodeKwargs(node);
if (kwargs.find("out_memref") != kwargs.items().end()) {
const auto &out_memref = kwargs["out_memref"]; // out tensor view
if (idx > -1) {
return out_memref.at(idx).at(itemName).getString();
}
return out_memref.at(itemName).getString();
}
}
CHECK(node.find(itemName) != node.items().end())
<< "Neither " << itemName << " nor out_memref exists in node " << node
<< "\n";
if (idx > -1) {
return node.at(itemName).at(idx).getString();
}
return node.at(itemName).getString();
}

Expand Down

0 comments on commit c4cbc62

Please sign in to comment.