diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py
index 680a52be558a..474fe2bfddbc 100644
--- a/test/python/fx_importer/sparse_test.py
+++ b/test/python/fx_importer/sparse_test.py
@@ -120,16 +120,18 @@ def sparse_export(
             node.meta["sparsity"] = sparse_metadata(args[k])
             k = k + 1
         elif node.op == "call_function":
+            # TODO: use upstream _opname implementation when available
+            opname = node.target._schema.name.split("::")[1]
             # Zero preserving elt-wise unary op.
-            if node.name in {"abs", "neg", "relu", "sin"}:
+            if opname in {"abs", "neg", "relu", "sin"}:
                 node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
-            elif node.name == "_to_sparse":
+            elif opname == "_to_sparse":
                 dim = len(node.meta.get("val").shape)
                 node.meta["sparsity"] = SparsityMeta(
                     torch.sparse_coo, 0, dim, 0, None, torch.int64, torch.int64
                 )
             # TODO: Uncomment this to hack sparsity into the network.
-            # elif node.name == "_to_dense":
+            # elif opname == "_to_dense":
             #     # hack (assumes we never really want the to_dense for now)
             #     node.meta["sparsity"] = node.args[0].meta.get("sparsity", None)
     return prog
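
Context note (not part of the patch): the switch from node.name to the overload schema name matters because torch.fx deduplicates node names, so the second abs call in a graph is named "abs_1" and would slip past a node.name membership check, while the schema name stays "aten::abs" for every call. Below is a minimal sketch of the lookup the patch performs, assuming a recent torch.export API; the TwoAbs module is hypothetical and exists only for illustration.

import torch

class TwoAbs(torch.nn.Module):
    def forward(self, x):
        return torch.abs(torch.abs(x))

prog = torch.export.export(TwoAbs(), (torch.randn(3),))
for node in prog.graph.nodes:
    if node.op == "call_function":
        # node.target is an aten OpOverload; its schema name is stable
        # (e.g. "aten::abs") even though node.name is "abs", "abs_1", ...
        opname = node.target._schema.name.split("::")[1]
        print(node.name, "->", opname)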