Skip to content

Commit

Permalink
Added optional attr argument to BaseData.snapshot method (#9335)
Browse files Browse the repository at this point in the history
Fixes #9333.

---------

Co-authored-by: Matthias Fey <[email protected]>
  • Loading branch information
nelsonaloysio and rusty1s authored May 21, 2024
1 parent 381d366 commit 9745df0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
3 changes: 2 additions & 1 deletion torch_geometric/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,13 +291,14 @@ def snapshot(
self,
start_time: Union[float, int],
end_time: Union[float, int],
attr: str = 'time',
) -> Self:
r"""Returns a snapshot of :obj:`data` to only hold events that occurred
in period :obj:`[start_time, end_time]`.
"""
out = copy.copy(self)
for store in out.stores:
store.snapshot(start_time, end_time)
store.snapshot(start_time, end_time, attr)
return out

def up_to(self, end_time: Union[float, int]) -> Self:
Expand Down
12 changes: 7 additions & 5 deletions torch_geometric/data/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,18 +370,20 @@ def snapshot(
self,
start_time: Union[float, int],
end_time: Union[float, int],
attr: str = 'time',
) -> Self:
if 'time' in self:
mask = (self.time >= start_time) & (self.time <= end_time)
if attr in self:
time = self[attr]
mask = (time >= start_time) & (time <= end_time)

if self.is_node_attr('time'):
if self.is_node_attr(attr):
keys = self.node_attrs()
elif self.is_edge_attr('time'):
elif self.is_edge_attr(attr):
keys = self.edge_attrs()

self._select(keys, mask)

if self.is_node_attr('time') and 'num_nodes' in self:
if self.is_node_attr(attr) and 'num_nodes' in self:
self.num_nodes: Optional[int] = int(mask.sum())

return self
Expand Down

0 comments on commit 9745df0

Please sign in to comment.