diff --git a/test/data/lightning/test_datamodule.py b/test/data/lightning/test_datamodule.py index cb86d810dfa8..6093f5f4e53c 100644 --- a/test/data/lightning/test_datamodule.py +++ b/test/data/lightning/test_datamodule.py @@ -23,6 +23,7 @@ onlyNeighborSampler, onlyOnline, withPackage, + has_package, ) try: @@ -114,12 +115,22 @@ def expect_rank_zero_user_warning(match: str): num_workers=3, shuffle=True) assert 'shuffle' not in datamodule.kwargs old_x = train_dataset._data.x.clone() - assert str(datamodule) == ('LightningDataset(train_dataset=MUTAG(50), ' - 'val_dataset=MUTAG(30), ' - 'test_dataset=MUTAG(10), ' - 'pred_dataset=MUTAG(98), batch_size=5, ' - 'num_workers=3, pin_memory=True, ' - 'persistent_workers=True)') + new_datamodule_repr = has_package('pytorch_lightning>=2.5.0') + datamodule_repr = ( + '{Train dataloader: size=50}\n' + '{Validation dataloader: size=30}\n' + '{Test dataloader: size=10}\n' + '{Predict dataloader: size=98}' + if new_datamodule_repr else + 'LightningDataset(train_dataset=MUTAG(50), ' + 'val_dataset=MUTAG(30), ' + 'test_dataset=MUTAG(10), ' + 'pred_dataset=MUTAG(98), batch_size=5, ' + 'num_workers=3, pin_memory=True, ' + 'persistent_workers=True)' + ) + assert str(datamodule) == datamodule_repr + trainer.fit(model, datamodule) trainer.test(model, datamodule) new_x = train_dataset._data.x @@ -133,10 +144,18 @@ def expect_rank_zero_user_warning(match: str): log_every_n_steps=1) datamodule = LightningDataset(train_dataset, batch_size=5) - assert str(datamodule) == ('LightningDataset(train_dataset=MUTAG(50), ' - 'batch_size=5, num_workers=0, ' - 'pin_memory=True, ' - 'persistent_workers=False)') + datamodule_repr = ( + '{Train dataloader: size=50}\n' + '{Validation dataloader: None}\n' + '{Test dataloader: None}\n{' + 'Predict dataloader: None}' + if new_datamodule_repr else + 'LightningDataset(train_dataset=MUTAG(50), ' + 'batch_size=5, num_workers=0, ' + 'pin_memory=True, ' + 'persistent_workers=False)' + ) + assert str(datamodule) == datamodule_repr with expect_rank_zero_user_warning("defined a `validation_step`"): trainer.fit(model, datamodule) @@ -231,11 +250,22 @@ def test_lightning_node_data(get_dataset, strategy_type, loader): num_workers=num_workers, **kwargs) old_x = data.x.clone().cpu() - assert str(datamodule) == (f'LightningNodeData(data={data_repr}, ' - f'loader={loader}, batch_size={batch_size}, ' - f'num_workers={num_workers}, {kwargs_repr}' - f'pin_memory={loader != "full"}, ' - f'persistent_workers={loader != "full"})') + new_datamodule_repr = has_package('pytorch_lightning>=2.5.0') + flag = loader != 'full' + datamodule_repr = ( + '{Train dataloader: ' + f'size={140 if flag else 1}' + '}\n' + '{Validation dataloader: ' + f'size={500 if flag else 1}' + '}\n' + '{Test dataloader: ' + f'size={1000 if flag else 1}' + '}\n' + '{Predict dataloader: ' + f'size={2708 if flag else 1}' + '}' + if new_datamodule_repr else + f'LightningNodeData(data={data_repr}, ' + f'loader={loader}, batch_size={batch_size}, ' + f'num_workers={num_workers}, {kwargs_repr}' + f'pin_memory={flag}, ' + f'persistent_workers={flag})' + ) + assert str(datamodule) == datamodule_repr + trainer.fit(model, datamodule) trainer.test(model, datamodule) new_x = data.x.cpu()