Skip to content

Commit

Permalink
Adds interfaces for temporal graph sampling
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 627346700
  • Loading branch information
Graph Learning Team authored and tensorflower-gardener committed Apr 25, 2024
1 parent 4edb0dd commit e97cb3d
Showing 1 changed file with 84 additions and 6 deletions.
90 changes: 84 additions & 6 deletions tensorflow_gnn/experimental/sampler/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@

class SamplingPrimitive(abc.ABC):
"""Base class for all sampling primitives."""

pass


Expand Down Expand Up @@ -129,8 +130,7 @@ def call(self, keys: tf.RaggedTensor) -> tf.RaggedTensor:


class KeyToFeaturesAccessor(AccessorBase):
"""Generic key to features dict accessor.
"""
"""Generic key to features dict accessor."""

@abc.abstractmethod
def call(self, keys: tf.RaggedTensor) -> Features:
Expand Down Expand Up @@ -165,7 +165,6 @@ def call(self, source_node_ids: tf.RaggedTensor) -> Features:
shape `[batch_size, (num_source_nodes)]` and tf.int32, tf.int64 or
tf.string type.
Returns:
`Features` containing the subset of all edges whose source nodes are
in `source_node_ids`. All returned features must have shape
Expand All @@ -182,6 +181,47 @@ def edge_set_name(self) -> str:
raise NotImplementedError


class TemporalOutgoingEdgesSampler(SamplingPrimitive):
"""Samples outgoing edges for given source nodes at a specific point in time.
Used to create rooted subgraphs from temporal graphs as of a specific point in
time.
"""

def __call__(
self, source_node_ids: tf.RaggedTensor, *, timestamps: tf.Tensor
) -> Features:
return self.call(source_node_ids=source_node_ids, timestamps=timestamps)

@abc.abstractmethod
def call(
self, source_node_ids: tf.RaggedTensor, *, timestamps: tf.Tensor
) -> Features:
"""Samples outgoing edges for the given source node ids and timestamps.
Args:
source_node_ids: node ids for sampling outgoing edges. Ragged tensor with
shape `[batch_size, (num_source_nodes)]` and tf.int32, tf.int64 or
tf.string dtype.
timestamps: points in time to filter edge states. Tensor with shape
`[batch_size]` and tf.int32, tf.int64 dtype.
Returns:
`Features` containing a subset of all edges whose source nodes are in
`source_node_ids` as of their state as at `timestamps`. All returned
features have shape `[batch_size, (num_edges), ...]`. The result includes
two special features "#source" and "#target" of rank 2 containing,
respectively, source node ids and target node ids of the sampled edges.
"""
raise NotImplementedError

@property
@abc.abstractmethod
def edge_set_name(self) -> str:
"""The edge set name."""
raise NotImplementedError


class UniformEdgesSampler(OutgoingEdgesSampler):
"""Samples up to the `sample_size` outgoing edges uniformly at random."""

Expand All @@ -198,6 +238,22 @@ def edge_target_feature_name(self) -> str:
raise NotImplementedError


class TemporalUniformEdgesSampler(TemporalOutgoingEdgesSampler):
"""Samples up to the `sample_size` outgoing edges uniformly at random."""

@property
@abc.abstractmethod
def sample_size(self) -> int:
"""The maximum number of edges to sample."""
raise NotImplementedError

@property
@abc.abstractmethod
def edge_target_feature_name(self) -> str:
"""The input feature name containing edge target node ids."""
raise NotImplementedError


class TopKEdgesSampler(OutgoingEdgesSampler):
"""Samples up to the `sample_size` top weighted outgoing edges."""

Expand All @@ -220,6 +276,28 @@ def weight_feature_name(self) -> str:
raise NotImplementedError


class TemporalTopKEdgesSampler(TemporalOutgoingEdgesSampler):
"""Samples up to the `sample_size` top weighted outgoing edges."""

@property
@abc.abstractmethod
def sample_size(self) -> int:
"""The maximum number of edges to sample."""
raise NotImplementedError

@property
@abc.abstractmethod
def edge_target_feature_name(self) -> str:
"""The input feature name containing edge target node ids."""
raise NotImplementedError

@property
@abc.abstractmethod
def weight_feature_name(self) -> str:
"""The input feature name containing edge weights."""
raise NotImplementedError


class ConnectingEdgesSampler(SamplingPrimitive):
"""Samples incident edges between given subsets of source and target nodes.
Expand All @@ -229,9 +307,9 @@ class ConnectingEdgesSampler(SamplingPrimitive):
"""

@abc.abstractmethod
def call(self,
source_node_ids: tf.RaggedTensor,
target_node_ids: tf.RaggedTensor) -> Features:
def call(
self, source_node_ids: tf.RaggedTensor, target_node_ids: tf.RaggedTensor
) -> Features:
"""Samples incident edges *from* source *on* target node ids.
Each sampled edges has its source in the `source_node_ids` and its target in
Expand Down

0 comments on commit e97cb3d

Please sign in to comment.