diff --git a/tensorflow_gnn/api_def/api_symbols_test.py b/tensorflow_gnn/api_def/api_symbols_test.py index 3b678182..373fb3e6 100644 --- a/tensorflow_gnn/api_def/api_symbols_test.py +++ b/tensorflow_gnn/api_def/api_symbols_test.py @@ -149,6 +149,8 @@ del tfgnn.models.vanilla_mpnn del tfgnn.models +# TODO(b/316135889): remove once fixed. +del tfgnn.proto.graph_schema_pb2 ## ## STEP 3: Recursively collect all module attributes exposed by the public API. diff --git a/tensorflow_gnn/proto/__init__.py b/tensorflow_gnn/proto/__init__.py index 2f074c2a..85389bca 100644 --- a/tensorflow_gnn/proto/__init__.py +++ b/tensorflow_gnn/proto/__init__.py @@ -51,5 +51,13 @@ OriginInfo = graph_schema.OriginInfo # Remove all names added by module imports, unless explicitly allowed here. -api_utils.remove_submodules_except(__name__, []) +api_utils.remove_submodules_except( + __name__, + [ + # Workaround for Beam/pickle, required by + # `experimental/sampler/beam/sampler.py`. + # TODO(b/316135889): remove once fixed. + 'graph_schema_pb2', + ], +) # LINT.ThenChange()../api_def/tfgnn-symbols.txt)