Skip to content

Commit

Permalink
Fixing two lightning tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
drivanov committed Jan 11, 2025
1 parent ef02854 commit 5ba0010
Showing 1 changed file with 45 additions and 15 deletions.
60 changes: 45 additions & 15 deletions test/data/lightning/test_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
onlyNeighborSampler,
onlyOnline,
withPackage,
has_package,
)

try:
Expand Down Expand Up @@ -114,12 +115,22 @@ def expect_rank_zero_user_warning(match: str):
num_workers=3, shuffle=True)
assert 'shuffle' not in datamodule.kwargs
old_x = train_dataset._data.x.clone()
assert str(datamodule) == ('LightningDataset(train_dataset=MUTAG(50), '
'val_dataset=MUTAG(30), '
'test_dataset=MUTAG(10), '
'pred_dataset=MUTAG(98), batch_size=5, '
'num_workers=3, pin_memory=True, '
'persistent_workers=True)')
new_datamodule_repr = has_package('pytorch_lightning>=2.5.0')
datamodule_repr = (
'{Train dataloader: size=50}\n'
'{Validation dataloader: size=30}\n'
'{Test dataloader: size=10}\n'
'{Predict dataloader: size=98}'
if new_datamodule_repr else
'LightningDataset(train_dataset=MUTAG(50), '
'val_dataset=MUTAG(30), '
'test_dataset=MUTAG(10), '
'pred_dataset=MUTAG(98), batch_size=5, '
'num_workers=3, pin_memory=True, '
'persistent_workers=True)'
)
assert str(datamodule) == datamodule_repr

trainer.fit(model, datamodule)
trainer.test(model, datamodule)
new_x = train_dataset._data.x
Expand All @@ -133,10 +144,18 @@ def expect_rank_zero_user_warning(match: str):
log_every_n_steps=1)

datamodule = LightningDataset(train_dataset, batch_size=5)
assert str(datamodule) == ('LightningDataset(train_dataset=MUTAG(50), '
'batch_size=5, num_workers=0, '
'pin_memory=True, '
'persistent_workers=False)')
datamodule_repr = (
'{Train dataloader: size=50}\n'
'{Validation dataloader: None}\n'
'{Test dataloader: None}\n{'
'Predict dataloader: None}'
if new_datamodule_repr else
'LightningDataset(train_dataset=MUTAG(50), '
'batch_size=5, num_workers=0, '
'pin_memory=True, '
'persistent_workers=False)'
)
assert str(datamodule) == datamodule_repr

with expect_rank_zero_user_warning("defined a `validation_step`"):
trainer.fit(model, datamodule)
Expand Down Expand Up @@ -231,11 +250,22 @@ def test_lightning_node_data(get_dataset, strategy_type, loader):
num_workers=num_workers, **kwargs)

old_x = data.x.clone().cpu()
assert str(datamodule) == (f'LightningNodeData(data={data_repr}, '
f'loader={loader}, batch_size={batch_size}, '
f'num_workers={num_workers}, {kwargs_repr}'
f'pin_memory={loader != "full"}, '
f'persistent_workers={loader != "full"})')
new_datamodule_repr = has_package('pytorch_lightning>=2.5.0')
flag = loader != 'full'
datamodule_repr = (
'{Train dataloader: ' + f'size={140 if flag else 1}' + '}\n'
'{Validation dataloader: ' + f'size={500 if flag else 1}' + '}\n'
'{Test dataloader: ' + f'size={1000 if flag else 1}' + '}\n'
'{Predict dataloader: ' + f'size={2708 if flag else 1}' + '}'
if new_datamodule_repr else
f'LightningNodeData(data={data_repr}, '
f'loader={loader}, batch_size={batch_size}, '
f'num_workers={num_workers}, {kwargs_repr}'
f'pin_memory={flag}, '
f'persistent_workers={flag})'
)
assert str(datamodule) == datamodule_repr

trainer.fit(model, datamodule)
trainer.test(model, datamodule)
new_x = data.x.cpu()
Expand Down

0 comments on commit 5ba0010

Please sign in to comment.