Skip to content

Commit

Permalink
Import node debugName (FX graph node name)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgehre-amd committed Aug 10, 2023
1 parent 82687dc commit caca87b
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 4 deletions.
2 changes: 2 additions & 0 deletions lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,8 @@ static LogicalResult rewriteMonomorphizedFuncClone(
auto newOp = OpBuilder(op).create<GlobalSlotGetOp>(
op.getLoc(), op.getType(),
objectGraphInfo.getGlobalSlotFor(affectedSlot).getSymName());
if (auto fxOutputName = op->getAttr("FXOutputName"))
newOp->setAttr("FXOutputName", fxOutputName);
op.replaceAllUsesWith(&*newOp);
}
toErase.push_back(op);
Expand Down
7 changes: 5 additions & 2 deletions lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,11 @@ class InlineGlobalSlotsPass
getBackwardSliceIncludingRoot(initialValue);
IRMapping mapping;
OpBuilder builder(op);
for (Operation *opInSlice : slice)
builder.clone(*opInSlice, mapping);
for (Operation *opInSlice : slice) {
auto clonedOp = builder.clone(*opInSlice, mapping);
if (auto fxOutputName = op->getAttr("FXOutputName"))
clonedOp->setAttr("FXOutputName", fxOutputName);
}
auto inlinedInitialValue = mapping.lookup(initialValue);
inlinedInitialValue = Torch::adjustStaticInformation(
builder, op.getLoc(), inlinedInitialValue, op.getType(),
Expand Down
4 changes: 3 additions & 1 deletion python/torch_mlir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,8 @@ def compile(model: torch.nn.Module,
backend_legal_ops: Optional[Sequence[str]] = None,
extra_library: Iterable[Callable] = [],
verbose: bool = False,
use_make_fx: bool = False):
use_make_fx: bool = False,
add_fx_outputname: bool = False):
"""Convert a PyTorch model to MLIR.
Args:
Expand Down Expand Up @@ -437,6 +438,7 @@ def compile(model: torch.nn.Module,
mb = ModuleBuilder()
import_options = ImportOptions()
import_options.ignoreExistingTensorShapesAndDtypes = ignore_traced_shapes
import_options.addFxOutputName = add_fx_outputname
try:
original_stderr = sys.stderr
sys.stderr = StringIO()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ struct ImportOptions {
// In that case, the appropriate shape information is provided via the type
// bound annotations on the function arguments instead.
bool ignoreExistingTensorShapesAndDtypes = false;

// Whether to add the FXOutputName attribute from the debug name of the jit
// node.
bool addFxOutputName = false;
};
} // namespace torch_mlir

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,7 @@ void torch_mlir::initImportOptionsBindings(py::module &m) {
.def_readwrite("assumeTensorsHaveValueSemantics",
&ImportOptions::assumeTensorsHaveValueSemantics)
.def_readwrite("ignoreExistingTensorShapesAndDtypes",
&ImportOptions::ignoreExistingTensorShapesAndDtypes);
&ImportOptions::ignoreExistingTensorShapesAndDtypes)
.def_readwrite("addFxOutputName",
&ImportOptions::addFxOutputName);
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,29 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
MlirLocation loc = getMlirLocationFromNode(context, node);
auto kind = node->kind();


auto outs = node->outputs();
auto output = outs.size() == 1 ? outs[0] : nullptr;

auto addFxOutputNameAttr = [&](MlirOperation& operation) {
if (importOptions.addFxOutputName && output && output->hasDebugName()) {
std::string name = output->debugName();
size_t len = name.size();
if (len > 2 && name[len-2] == '.' && name[len-1] == '1')
name = name.substr(0, len-2);
auto strAttr = mlirStringAttrGet(context, toMlirStringRef(name));
mlirOperationSetAttributeByName(operation, toMlirStringRef("FXOutputName"), strAttr);
}
};

auto createAndMapTrivialNode = [&](Node *node, const std::string &opName,
InputsTransformFn t) {
std::vector<MlirValue> mappedInputs = lookupMappedValues(node->inputs());
MlirOperation operation = createMlirOperationAtEnd(
appendToBlock, opName, loc,
getMlirTypesFromValues(loc, node->outputs(), importOptions),
t ? t(mappedInputs) : mappedInputs);
addFxOutputNameAttr(operation);
mapResults(node, operation);
};

Expand All @@ -102,6 +118,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
getMlirTypesFromValues(loc, node->outputs(), importOptions),
lookupMappedValues(node->inputs()),
toMlirNamedAttribute(attrName.c_str(), attr));
addFxOutputNameAttr(operation);
mapResults(node, operation);
};

Expand All @@ -112,6 +129,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock,
appendToBlock, loc, node->schema(),
getMlirTypesFromValues(loc, node->outputs(), importOptions),
lookupMappedValues(node->inputs()));
addFxOutputNameAttr(operation);
mapResults(node, operation);
return;
}
Expand Down

0 comments on commit caca87b

Please sign in to comment.