From e5988bbf4a02225ec5bce183822a8e439c445551 Mon Sep 17 00:00:00 2001 From: Dustin Zelle Date: Tue, 5 Dec 2023 10:23:44 -0800 Subject: [PATCH] Add structured readout versions of binary and multiclass classification tasks. PiperOrigin-RevId: 588115018 --- tensorflow_gnn/api_def/runner-symbols.txt | 2 + tensorflow_gnn/runner/__init__.py | 6 +- tensorflow_gnn/runner/tasks/classification.py | 360 +++++++++++++++++- .../runner/tasks/classification_test.py | 125 +++++- 4 files changed, 467 insertions(+), 26 deletions(-) diff --git a/tensorflow_gnn/api_def/runner-symbols.txt b/tensorflow_gnn/api_def/runner-symbols.txt index ff6aa925..780a247c 100644 --- a/tensorflow_gnn/api_def/runner-symbols.txt +++ b/tensorflow_gnn/api_def/runner-symbols.txt @@ -26,6 +26,7 @@ runner.ParameterServerStrategy runner.PassthruDatasetProvider runner.PassthruSampleDatasetsProvider runner.Predictions +runner.NodeBinaryClassification runner.RootNodeBinaryClassification runner.RootNodeLabelFn runner.RootNodeMeanAbsoluteError @@ -34,6 +35,7 @@ runner.RootNodeMeanAbsolutePercentageError runner.RootNodeMeanSquaredError runner.RootNodeMeanSquaredLogScaledError runner.RootNodeMeanSquaredLogarithmicError +runner.NodeMulticlassClassification runner.RootNodeMulticlassClassification runner.RunResult runner.SampleTFRecordDatasetsProvider diff --git a/tensorflow_gnn/runner/__init__.py b/tensorflow_gnn/runner/__init__.py index d7a654b6..897433f6 100644 --- a/tensorflow_gnn/runner/__init__.py +++ b/tensorflow_gnn/runner/__init__.py @@ -94,10 +94,12 @@ # in `distribute_test.py`.) # # Tasks (Classification) -RootNodeBinaryClassification = classification.RootNodeBinaryClassification -RootNodeMulticlassClassification = classification.RootNodeMulticlassClassification GraphBinaryClassification = classification.GraphBinaryClassification GraphMulticlassClassification = classification.GraphMulticlassClassification +NodeBinaryClassification = classification.NodeBinaryClassification +NodeMulticlassClassification = classification.NodeMulticlassClassification +RootNodeBinaryClassification = classification.RootNodeBinaryClassification +RootNodeMulticlassClassification = classification.RootNodeMulticlassClassification # Tasks (Link Prediction) DotProductLinkPrediction = link_prediction.DotProductLinkPrediction HadamardProductLinkPrediction = link_prediction.HadamardProductLinkPrediction diff --git a/tensorflow_gnn/runner/tasks/classification.py b/tensorflow_gnn/runner/tasks/classification.py index 4d94c543..b068cc1d 100644 --- a/tensorflow_gnn/runner/tasks/classification.py +++ b/tensorflow_gnn/runner/tasks/classification.py @@ -112,8 +112,9 @@ def __init__( Args: units: The units for the classification head. (Typically `1` for binary classification and `num_classes` for multiclass classification.) - name: The classification head's layer name. This name typically appears - in the exported model's SignatureDef. + name: The classification head's layer name. To control the naming of saved + model outputs see the runner model exporters (e.g., + `KerasModelExporter`). label_fn: A label extraction function. This function mutates the input `GraphTensor`. Mutually exclusive with `label_feature_name`. label_feature_name: A label feature name for readout from the auxiliary @@ -148,7 +149,7 @@ def predict(self, inputs: tfgnn.GraphTensor) -> interfaces.Predictions: activations = self.gather_activations(inputs) logits = tf.keras.layers.Dense( self._units, - name=self._name)(activations) # Name seen in SignatureDef. + name=self._name)(activations) return logits def preprocess(self, inputs: GraphTensor) -> tuple[GraphTensor, Field]: @@ -303,21 +304,364 @@ def gather_activations(self, inputs: GraphTensor) -> Field: feature_name=self._state_name)(inputs) -# TODO(dzelle): Add an `__init__` with parameters and doc for all of the below. +class _NodeClassification(_Classification): + """Classification by node(s) via structured readout.""" + + def __init__(self, + key: str = "seed", + *, + feature_name: str = tfgnn.HIDDEN_STATE, + readout_node_set: tfgnn.NodeSetName = "_readout", + validate: bool = True, + **kwargs): + """Classification of node(s) via structured readout. + + This task defines classification via structured readout (see + `tfgnn.keras.layers.StructuredReadout`). Structured readout addresses the + need to read out final hidden states from a GNN computation to make + predictions for some nodes (or edges) of interest. To add auxiliary node + (and edge) sets for structured readout see, e.g.: + `tfgnn.keras.layers.AddReadoutFromFirstNode`. + + Any labels are expected to be sparse, i.e.: scalars. + + Args: + key: A string key to select between possibly multiple named readouts. + feature_name: The name of the feature to read. If unset, + `tfgnn.HIDDEN_STATE` will be read. + readout_node_set: A string, defaults to `"_readout"`. This is used as the + name for the readout node set and as a name prefix for its edge sets. + validate: Setting this to false disables the validity checks for the + auxiliary edge sets. This is stronlgy discouraged, unless great care is + taken to run `tfgnn.validate_graph_tensor_for_readout()` earlier on + structurally unchanged GraphTensors. + **kwargs: Additional keyword arguments. + """ + super().__init__(**kwargs) + self._key = key + self._feature_name = feature_name + self._readout_node_set = readout_node_set + self._validate = validate + + def gather_activations(self, inputs: GraphTensor) -> Field: + """Gather activations from auxiliary node (and edge) sets.""" + try: + return tfgnn.keras.layers.StructuredReadout( + self._key, + feature_name=self._feature_name, + readout_node_set=self._readout_node_set, + validate=self._validate)(inputs) + except (KeyError, ValueError) as e: + raise ValueError( + "This NodeClassification task failed in StructuredReadout(" + f"{self._key}, feature_name={self._feature_name}, " + f"readout_node_set={self._readout_node_set}).\n" + "For a dataset of sampled subgraphs that does not provide a readout " + "structure but follows the conventional placement of root nodes " + "first in their node set, consider using a RootNodeClassification " + "task or tfgnn.keras.layers.AddReadoutFromFirstNode." + ) from e + + class GraphBinaryClassification(_GraphClassification, _BinaryClassification): - pass + """Graph binary (or multi-label) classification from pooled node states.""" + + def __init__(self, + node_set_name: str, + units: int = 1, + *, + state_name: str = tfgnn.HIDDEN_STATE, + reduce_type: str = "mean", + name: str = "classification_logits", + label_fn: Optional[LabelFn] = None, + label_feature_name: Optional[str] = None, + **kwargs): + """Graph binary (or multi-label) classification. + + This task performs binary classification (or multiple independent ones: + often called multi-label classification). + + Args: + node_set_name: The node set to pool. + units: The units for the classification head. (Typically `1` for binary + classification and the number of labels for multi-label classification.) + state_name: The feature name for activations (e.g.: tfgnn.HIDDEN_STATE). + reduce_type: The context pooling reduction type. + name: The classification head's layer name. To control the naming of saved + model outputs see the runner model exporters (e.g., + `KerasModelExporter`). + label_fn: A label extraction function. This function mutates the input + `GraphTensor`. Mutually exclusive with `label_feature_name`. + label_feature_name: A label feature name for readout from the auxiliary + '_readout' node set. Readout does not mutate the input `GraphTensor`. + Mutually exclusive with `label_fn`. + **kwargs: Additional keyword arguments. + """ + super().__init__( + node_set_name, + units=units, + state_name=state_name, + reduce_type=reduce_type, + name=name, + label_fn=label_fn, + label_feature_name=label_feature_name, + **kwargs) class GraphMulticlassClassification(_GraphClassification, _MulticlassClassification): - pass + """Graph multiclass classification from pooled node states.""" + + def __init__(self, + node_set_name: str, + *, + num_classes: Optional[int] = None, + class_names: Optional[Sequence[str]] = None, + per_class_statistics: bool = False, + state_name: str = tfgnn.HIDDEN_STATE, + reduce_type: str = "mean", + name: str = "classification_logits", + label_fn: Optional[LabelFn] = None, + label_feature_name: Optional[str] = None, + **kwargs): + """Graph multiclass classification from pooled node states. + + Args: + node_set_name: The node set to pool. + num_classes: The number of classes. Exactly one of `num_classes` or + `class_names` must be specified + class_names: The class names. Exactly one of `num_classes` or + `class_names` must be specified + per_class_statistics: Whether to compute statistics per class. + state_name: The feature name for activations (e.g.: tfgnn.HIDDEN_STATE). + reduce_type: The context pooling reduction type. + name: The classification head's layer name. To control the naming of saved + model outputs see the runner model exporters (e.g., + `KerasModelExporter`). + label_fn: A label extraction function. This function mutates the input + `GraphTensor`. Mutually exclusive with `label_feature_name`. + label_feature_name: A label feature name for readout from the auxiliary + '_readout' node set. Readout does not mutate the input `GraphTensor`. + Mutually exclusive with `label_fn`. + **kwargs: Additional keyword arguments. + """ + super().__init__( + node_set_name, + num_classes=num_classes, + class_names=class_names, + per_class_statistics=per_class_statistics, + state_name=state_name, + reduce_type=reduce_type, + name=name, + label_fn=label_fn, + label_feature_name=label_feature_name, + **kwargs) class RootNodeBinaryClassification(_RootNodeClassification, _BinaryClassification): - pass + """Root node binary (or multi-label) classification.""" + + def __init__(self, + node_set_name: str, + units: int = 1, + *, + state_name: str = tfgnn.HIDDEN_STATE, + name: str = "classification_logits", + label_fn: Optional[LabelFn] = None, + label_feature_name: Optional[str] = None, + **kwargs): + """Root node binary (or multi-label) classification. + + This task performs binary classification (or multiple independent ones: + often called multi-label classification). + + The task can be used on graph datasets without a readout structure. + It requires that each input graph stores its unique root node as the + first node of `node_set_name`. + + Args: + node_set_name: The node set containing the root node. + units: The units for the classification head. (Typically `1` for binary + classification and the number of labels for multi-label classification.) + state_name: The feature name for activations (e.g.: tfgnn.HIDDEN_STATE). + name: The classification head's layer name. To control the naming of saved + model outputs see the runner model exporters (e.g., + `KerasModelExporter`). + label_fn: A label extraction function. This function mutates the input + `GraphTensor`. Mutually exclusive with `label_feature_name`. + label_feature_name: A label feature name for readout from the auxiliary + '_readout' node set. Readout does not mutate the input `GraphTensor`. + Mutually exclusive with `label_fn`. + **kwargs: Additional keyword arguments. + """ + super().__init__( + node_set_name, + units=units, + state_name=state_name, + name=name, + label_fn=label_fn, + label_feature_name=label_feature_name, + **kwargs) class RootNodeMulticlassClassification(_RootNodeClassification, _MulticlassClassification): - pass + """Root node multiclass classification.""" + + def __init__(self, + node_set_name: str, + *, + num_classes: Optional[int] = None, + class_names: Optional[Sequence[str]] = None, + per_class_statistics: bool = False, + state_name: str = tfgnn.HIDDEN_STATE, + name: str = "classification_logits", + label_fn: Optional[LabelFn] = None, + label_feature_name: Optional[str] = None, + **kwargs): + """Root node multiclass classification. + + This task can be used on graph datasets without a readout structure. + It requires that each input graph stores its unique root node as the + first node of `node_set_name`. + + Args: + node_set_name: The node set containing the root node. + num_classes: The number of classes. Exactly one of `num_classes` or + `class_names` must be specified + class_names: The class names. Exactly one of `num_classes` or + `class_names` must be specified + per_class_statistics: Whether to compute statistics per class. + state_name: The feature name for activations (e.g.: tfgnn.HIDDEN_STATE). + name: The classification head's layer name. To control the naming of saved + model outputs see the runner model exporters (e.g., + `KerasModelExporter`). + label_fn: A label extraction function. This function mutates the input + `GraphTensor`. Mutually exclusive with `label_feature_name`. + label_feature_name: A label feature name for readout from the auxiliary + '_readout' node set. Readout does not mutate the input `GraphTensor`. + Mutually exclusive with `label_fn`. + **kwargs: Additional keyword arguments. + """ + super().__init__( + node_set_name, + num_classes=num_classes, + class_names=class_names, + per_class_statistics=per_class_statistics, + state_name=state_name, + name=name, + label_fn=label_fn, + label_feature_name=label_feature_name, + **kwargs) + + +class NodeBinaryClassification(_NodeClassification, _BinaryClassification): + """Node binary (or multi-label) classification via structured readout.""" + + def __init__(self, + key: str = "seed", + units: int = 1, + *, + feature_name: str = tfgnn.HIDDEN_STATE, + readout_node_set: tfgnn.NodeSetName = "_readout", + validate: bool = True, + name: str = "classification_logits", + label_fn: Optional[LabelFn] = None, + label_feature_name: Optional[str] = None, + **kwargs): + """Node binary (or multi-label) classification. + + This task performs binary classification (or multiple independent ones: + often called multi-label classification). + + Args: + key: A string key to select between possibly multiple named readouts. + units: The units for the classification head. (Typically `1` for binary + classification and the number of labels for multi-label classification.) + feature_name: The name of the feature to read. If unset, + `tfgnn.HIDDEN_STATE` will be read. + readout_node_set: A string, defaults to `"_readout"`. This is used as the + name for the readout node set and as a name prefix for its edge sets. + validate: Setting this to false disables the validity checks for the + auxiliary edge sets. This is stronlgy discouraged, unless great care is + taken to run `tfgnn.validate_graph_tensor_for_readout()` earlier on + structurally unchanged GraphTensors. + name: The classification head's layer name. To control the naming of saved + model outputs see the runner model exporters (e.g., + `KerasModelExporter`). + label_fn: A label extraction function. This function mutates the input + `GraphTensor`. Mutually exclusive with `label_feature_name`. + label_feature_name: A label feature name for readout from the auxiliary + '_readout' node set. Readout does not mutate the input `GraphTensor`. + Mutually exclusive with `label_fn`. + **kwargs: Additional keyword arguments. + """ + super().__init__( + key, + units=units, + feature_name=feature_name, + readout_node_set=readout_node_set, + validate=validate, + name=name, + label_fn=label_fn, + label_feature_name=label_feature_name, + **kwargs) + + +class NodeMulticlassClassification(_NodeClassification, + _MulticlassClassification): + """Node multiclass classification via structured readout.""" + + def __init__(self, + key: str = "seed", + *, + feature_name: str = tfgnn.HIDDEN_STATE, + readout_node_set: tfgnn.NodeSetName = "_readout", + validate: bool = True, + num_classes: Optional[int] = None, + class_names: Optional[Sequence[str]] = None, + per_class_statistics: bool = False, + name: str = "classification_logits", + label_fn: Optional[LabelFn] = None, + label_feature_name: Optional[str] = None, + **kwargs): + """Node multiclass classification via structured readout. + + Args: + key: A string key to select between possibly multiple named readouts. + feature_name: The name of the feature to read. If unset, + `tfgnn.HIDDEN_STATE` will be read. + readout_node_set: A string, defaults to `"_readout"`. This is used as the + name for the readout node set and as a name prefix for its edge sets. + validate: Setting this to false disables the validity checks for the + auxiliary edge sets. This is stronlgy discouraged, unless great care is + taken to run `tfgnn.validate_graph_tensor_for_readout()` earlier on + structurally unchanged GraphTensors. + num_classes: The number of classes. Exactly one of `num_classes` or + `class_names` must be specified + class_names: The class names. Exactly one of `num_classes` or + `class_names` must be specified + per_class_statistics: Whether to compute statistics per class. + name: The classification head's layer name. To control the naming of saved + model outputs see the runner model exporters (e.g., + `KerasModelExporter`). + label_fn: A label extraction function. This function mutates the input + `GraphTensor`. Mutually exclusive with `label_feature_name`. + label_feature_name: A label feature name for readout from the auxiliary + '_readout' node set. Readout does not mutate the input `GraphTensor`. + Mutually exclusive with `label_fn`. + **kwargs: Additional keyword arguments. + """ + super().__init__( + key, + feature_name=feature_name, + readout_node_set=readout_node_set, + validate=validate, + num_classes=num_classes, + class_names=class_names, + per_class_statistics=per_class_statistics, + name=name, + label_fn=label_fn, + label_feature_name=label_feature_name, + **kwargs) diff --git a/tensorflow_gnn/runner/tasks/classification_test.py b/tensorflow_gnn/runner/tasks/classification_test.py index 10140b90..dc30898e 100644 --- a/tensorflow_gnn/runner/tasks/classification_test.py +++ b/tensorflow_gnn/runner/tasks/classification_test.py @@ -29,6 +29,7 @@ # Enables tests for graph pieces that are members of test classes. tfgnn.enable_graph_tensor_validation_at_runtime() +READOUT_KEY = "0x8191" TEST_GRAPH_TENSOR = GraphTensor.from_pieces( context=tfgnn.Context.from_fields( features={"labels": tf.constant((8, 1, 9, 1))} @@ -52,7 +53,16 @@ def fn(inputs): return fn -def with_readout(num_labels: int, gt: GraphTensor) -> GraphTensor: +def add_readout_from_first_node(gt: GraphTensor) -> GraphTensor: + return tfgnn.add_readout_from_first_node( + gt, + key=READOUT_KEY, + node_set_name="nodes") + + +def context_readout_into_feature( + num_labels: int, + gt: GraphTensor) -> GraphTensor: context_fn = lambda inputs: {"labels": inputs["labels"] % num_labels} gt = tfgnn.keras.layers.MapFeatures(context_fn=context_fn)(gt) return tfgnn.experimental.context_readout_into_feature( @@ -81,8 +91,8 @@ def setUp(self): task=classification.GraphBinaryClassification( "nodes", label_feature_name="labels"), - inputs=with_readout(2, TEST_GRAPH_TENSOR), - expected_gt=with_readout(2, TEST_GRAPH_TENSOR), + inputs=context_readout_into_feature(2, TEST_GRAPH_TENSOR), + expected_gt=context_readout_into_feature(2, TEST_GRAPH_TENSOR), expected_labels=(0, 1, 1, 1)), dict( testcase_name="GraphMulticlassClassificationLabelFn", @@ -99,8 +109,8 @@ def setUp(self): "nodes", num_classes=4, label_feature_name="labels"), - inputs=with_readout(4, TEST_GRAPH_TENSOR), - expected_gt=with_readout(4, TEST_GRAPH_TENSOR), + inputs=context_readout_into_feature(4, TEST_GRAPH_TENSOR), + expected_gt=context_readout_into_feature(4, TEST_GRAPH_TENSOR), expected_labels=(0, 1, 1, 1)), dict( testcase_name="RootNodeBinaryClassificationLabelFn", @@ -115,8 +125,8 @@ def setUp(self): task=classification.RootNodeBinaryClassification( "nodes", label_feature_name="labels"), - inputs=with_readout(2, TEST_GRAPH_TENSOR), - expected_gt=with_readout(2, TEST_GRAPH_TENSOR), + inputs=context_readout_into_feature(2, TEST_GRAPH_TENSOR), + expected_gt=context_readout_into_feature(2, TEST_GRAPH_TENSOR), expected_labels=(0, 1, 1, 1)), dict( testcase_name="RootNodeMulticlassClassificationLabelFn", @@ -133,8 +143,50 @@ def setUp(self): "nodes", num_classes=3, label_feature_name="labels"), - inputs=with_readout(3, TEST_GRAPH_TENSOR), - expected_gt=with_readout(3, TEST_GRAPH_TENSOR), + inputs=context_readout_into_feature(3, TEST_GRAPH_TENSOR), + expected_gt=context_readout_into_feature(3, TEST_GRAPH_TENSOR), + expected_labels=(2, 1, 0, 1)), + dict( + testcase_name="NodeBinaryClassificationLabelFn", + task=classification.NodeBinaryClassification( + READOUT_KEY, + label_fn=label_fn(2)), + inputs=add_readout_from_first_node(TEST_GRAPH_TENSOR), + expected_gt=add_readout_from_first_node( + TEST_GRAPH_TENSOR.remove_features(context=("labels",))), + expected_labels=(0, 1, 1, 1)), + dict( + testcase_name="NodeBinaryClassificationReadout", + task=classification.NodeBinaryClassification( + READOUT_KEY, + label_feature_name="labels"), + inputs=add_readout_from_first_node(context_readout_into_feature( + 2, + TEST_GRAPH_TENSOR)), + expected_gt=add_readout_from_first_node( + context_readout_into_feature(2, TEST_GRAPH_TENSOR)), + expected_labels=(0, 1, 1, 1)), + dict( + testcase_name="NodeMulticlassClassificationLabelFn", + task=classification.NodeMulticlassClassification( + READOUT_KEY, + num_classes=3, + label_fn=label_fn(3)), + inputs=add_readout_from_first_node(TEST_GRAPH_TENSOR), + expected_gt=add_readout_from_first_node( + TEST_GRAPH_TENSOR.remove_features(context=("labels",))), + expected_labels=(2, 1, 0, 1)), + dict( + testcase_name="NodeMulticlassClassificationReadout", + task=classification.NodeMulticlassClassification( + READOUT_KEY, + num_classes=3, + label_feature_name="labels"), + inputs=add_readout_from_first_node(context_readout_into_feature( + 3, + TEST_GRAPH_TENSOR)), + expected_gt=add_readout_from_first_node( + context_readout_into_feature(3, TEST_GRAPH_TENSOR)), expected_labels=(2, 1, 0, 1)), ]) def test_preprocess( @@ -163,7 +215,7 @@ def test_preprocess( "nodes", num_classes=4, label_feature_name="labels"), - gt=with_readout(4, TEST_GRAPH_TENSOR), + gt=context_readout_into_feature(4, TEST_GRAPH_TENSOR), expected_loss=tf.keras.losses.SparseCategoricalCrossentropy, expected_shape=tf.TensorShape((None, 4))), dict( @@ -180,7 +232,26 @@ def test_preprocess( "nodes", num_classes=3, label_feature_name="labels"), - gt=with_readout(3, TEST_GRAPH_TENSOR), + gt=context_readout_into_feature(3, TEST_GRAPH_TENSOR), + expected_loss=tf.keras.losses.SparseCategoricalCrossentropy, + expected_shape=tf.TensorShape((None, 3))), + dict( + testcase_name="NodeBinaryClassification", + task=classification.NodeBinaryClassification( + READOUT_KEY, + label_fn=label_fn(2)), + gt=add_readout_from_first_node(TEST_GRAPH_TENSOR), + expected_loss=tf.keras.losses.BinaryCrossentropy, + expected_shape=tf.TensorShape((None, 1))), + dict( + testcase_name="NodeMulticlassClassification", + task=classification.NodeMulticlassClassification( + READOUT_KEY, + num_classes=3, + label_feature_name="labels"), + gt=add_readout_from_first_node(context_readout_into_feature( + 3, + TEST_GRAPH_TENSOR)), expected_loss=tf.keras.losses.SparseCategoricalCrossentropy, expected_shape=tf.TensorShape((None, 3))), ]) @@ -197,7 +268,12 @@ def test_predict( self.assertIsInstance(model.layers[0], tf.keras.layers.InputLayer) self.assertIsInstance( model.layers[1], - (tfgnn.keras.layers.ReadoutFirstNode, tfgnn.keras.layers.Pool)) + ( + tfgnn.keras.layers.ReadoutFirstNode, + tfgnn.keras.layers.Pool, + tfgnn.keras.layers.StructuredReadout, + ), + ) self.assertIsInstance(model.layers[2], tf.keras.layers.Dense) _, _, dense = model.layers @@ -223,7 +299,7 @@ def test_predict( "nodes", num_classes=4, label_feature_name="labels"), - gt=with_readout(4, TEST_GRAPH_TENSOR), + gt=context_readout_into_feature(4, TEST_GRAPH_TENSOR), batch_size=1), dict( testcase_name="RootNodeBinaryClassification", @@ -238,7 +314,24 @@ def test_predict( "nodes", num_classes=3, label_feature_name="labels"), - gt=with_readout(3, TEST_GRAPH_TENSOR), + gt=context_readout_into_feature(3, TEST_GRAPH_TENSOR), + batch_size=1), + dict( + testcase_name="NodeBinaryClassification", + task=classification.NodeBinaryClassification( + READOUT_KEY, + label_fn=label_fn(2)), + gt=add_readout_from_first_node(TEST_GRAPH_TENSOR), + batch_size=1), + dict( + testcase_name="NodeMulticlassClassification", + task=classification.NodeMulticlassClassification( + READOUT_KEY, + num_classes=3, + label_feature_name="labels"), + gt=add_readout_from_first_node(context_readout_into_feature( + 3, + TEST_GRAPH_TENSOR)), batch_size=1), dict( testcase_name="GraphBinaryClassificationBatchSize2", @@ -253,7 +346,7 @@ def test_predict( "nodes", num_classes=4, label_feature_name="labels"), - gt=with_readout(4, TEST_GRAPH_TENSOR), + gt=context_readout_into_feature(4, TEST_GRAPH_TENSOR), batch_size=2), dict( testcase_name="RootNodeBinaryClassificationBatchSize2", @@ -268,7 +361,7 @@ def test_predict( "nodes", num_classes=3, label_feature_name="labels"), - gt=with_readout(3, TEST_GRAPH_TENSOR), + gt=context_readout_into_feature(3, TEST_GRAPH_TENSOR), batch_size=2), ]) def test_fit(