Skip to content

Commit

Permalink
[sparse] match fx node using target name instead of variables name (l…
Browse files Browse the repository at this point in the history
  • Loading branch information
Peiming Liu authored May 9, 2024
1 parent 64b59c7 commit 2c22087
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions test/python/fx_importer/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,18 @@ def sparse_export(
node.meta["sparsity"] = sparse_metadata(args[k])
k = k + 1
elif node.op == "call_function":
# TODO: use upstream _opname implementation when available
opname = node.target._schema.name.split("::")[1]
# Zero preserving elt-wise unary op.
if node.name in {"abs", "neg", "relu", "sin"}:
if opname in {"abs", "neg", "relu", "sin"}:
node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
elif node.name == "_to_sparse":
elif opname == "_to_sparse":
dim = len(node.meta.get("val").shape)
node.meta["sparsity"] = SparsityMeta(
torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
)
# TODO: Uncomment this to hack sparsity into the network.
# elif node.name == "_to_dense":
# elif opname == "_to_dense":
# # hack (assumes we never really want the to_dense for now)
# node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
return prog
Expand Down

0 comments on commit 2c22087

Please sign in to comment.