Commit

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Dec 23, 2024
1 parent 4a5c789 commit 0bea004
Showing 4 changed files with 9 additions and 11 deletions.
5 changes: 2 additions & 3 deletions examples/pytorch_lightning/gin.py
@@ -74,10 +74,9 @@ def configure_optimizers(self):
 
 devices = torch.cuda.device_count()
 strategy = DDPStrategy(accelerator='gpu')
-checkpoint = ModelCheckpoint(monitor='val_acc', save_top_k=1,
-                             mode='max')
+checkpoint = ModelCheckpoint(monitor='val_acc', save_top_k=1, mode='max')
 trainer = Trainer(strategy=strategy, devices=devices, max_epochs=50,
-                  log_every_n_steps=5, callbacks=[checkpoint])
+                  log_every_n_steps=5, callbacks=[checkpoint])
 
 trainer.fit(model, datamodule)
 trainer.test(ckpt_path='best', datamodule=datamodule)
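For readers skimming the diff, the touched lines are part of a standard Lightning training setup: a ModelCheckpoint callback that keeps the single best model by validation accuracy, passed to a Trainer together with a DDP strategy. The sketch below puts that pattern in a self-contained form; the import paths and the fit_and_test wrapper are assumptions made for illustration and are not taken from the example file.

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.strategies import DDPStrategy


def fit_and_test(model, datamodule):
    # Assumes at least one CUDA device is available.
    devices = torch.cuda.device_count()
    strategy = DDPStrategy(accelerator='gpu')
    # Keep only the best checkpoint, ranked by validation accuracy:
    checkpoint = ModelCheckpoint(monitor='val_acc', save_top_k=1, mode='max')
    trainer = Trainer(strategy=strategy, devices=devices, max_epochs=50,
                      log_every_n_steps=5, callbacks=[checkpoint])
    trainer.fit(model, datamodule)
    # Reload the best checkpoint before evaluating on the test split:
    trainer.test(ckpt_path='best', datamodule=datamodule)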
5 changes: 2 additions & 3 deletions examples/pytorch_lightning/graph_sage.py
@@ -74,10 +74,9 @@ def configure_optimizers(self):
 model = Model(dataset.num_node_features, dataset.num_classes)
 
 strategy = SingleDeviceStrategy('cuda:0')
-checkpoint = ModelCheckpoint(monitor='val_acc', save_top_k=1,
-                             mode='max')
+checkpoint = ModelCheckpoint(monitor='val_acc', save_top_k=1, mode='max')
 trainer = Trainer(strategy=strategy, devices=1, max_epochs=20,
-                  callbacks=[checkpoint])
+                  callbacks=[checkpoint])
 
 trainer.fit(model, datamodule)
 trainer.test(ckpt_path='best', datamodule=datamodule)
8 changes: 4 additions & 4 deletions test/data/lightning/test_datamodule.py
@@ -108,7 +108,7 @@ def expect_rank_zero_user_warning(match: str):
     model = LinearGraphModule(dataset.num_features, 64, dataset.num_classes)
 
     trainer = L.Trainer(strategy=strategy, devices=devices, max_epochs=1,
-                        log_every_n_steps=1)
+                        log_every_n_steps=1)
     with pytest.warns(UserWarning, match="'shuffle=True' option is ignored"):
         datamodule = LightningDataset(train_dataset, val_dataset, test_dataset,
                                       pred_dataset, batch_size=5,
@@ -131,7 +131,7 @@ def expect_rank_zero_user_warning(match: str):
     # Test with `val_dataset=None` and `test_dataset=None`:
     if strategy_type is None:
         trainer = L.Trainer(strategy=strategy, devices=devices, max_epochs=1,
-                            log_every_n_steps=1)
+                            log_every_n_steps=1)
 
         datamodule = LightningDataset(train_dataset, batch_size=5)
         assert str(datamodule) == ('LightningDataset(train_dataset=MUTAG(50), '
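The two hunks above re-wrap calls that build PyG's LightningDataset, which packs graph-level train/val/test/pred datasets into a single LightningDataModule. A short construction sketch follows; the TUDataset root path and the split boundaries are assumptions for illustration, not values taken from the test file.

from torch_geometric.data import LightningDataset
from torch_geometric.datasets import TUDataset

# MUTAG is the dataset the test's repr check refers to; the root path and
# split sizes below are assumed.
dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG').shuffle()
train_dataset = dataset[:50]
val_dataset = dataset[50:80]
test_dataset = dataset[80:110]
pred_dataset = dataset[110:]

# batch_size and num_workers are forwarded to the underlying DataLoaders:
datamodule = LightningDataset(train_dataset, val_dataset, test_dataset,
                              pred_dataset, batch_size=5, num_workers=0)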
@@ -228,7 +228,7 @@ def test_lightning_node_data(get_dataset, strategy_type, loader):
         kwargs_repr += 'num_neighbors=[5], '
 
     trainer = L.Trainer(strategy=strategy, devices=devices, max_epochs=5,
-                        log_every_n_steps=1)
+                        log_every_n_steps=1)
     datamodule = LightningNodeData(data, loader=loader, batch_size=batch_size,
                                    num_workers=num_workers, **kwargs)
 
@@ -314,7 +314,7 @@ def test_lightning_hetero_node_data(preserve_context, get_dataset):
     strategy = DDPStrategy(accelerator='gpu')
 
     trainer = L.Trainer(strategy=strategy, devices=devices, max_epochs=5,
-                        log_every_n_steps=1)
+                        log_every_n_steps=1)
     datamodule = LightningNodeData(data, loader='neighbor', num_neighbors=[5],
                                    batch_size=32, num_workers=3)
     assert isinstance(datamodule.graph_sampler, NeighborSampler)
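The two node-level hunks exercise LightningNodeData, the datamodule that wraps a full graph and creates neighbor-sampling loaders for the Trainer. A self-contained sketch of that wiring is below; the Planetoid/Cora dataset, the device count, and the strategy choice are assumptions for illustration, not values taken from the tests.

import lightning as L
from lightning.pytorch.strategies import DDPStrategy

from torch_geometric.data import LightningNodeData
from torch_geometric.datasets import Planetoid

data = Planetoid(root='/tmp/Cora', name='Cora')[0]

# Assumes two CUDA devices are available for DDP:
strategy = DDPStrategy(accelerator='gpu')
trainer = L.Trainer(strategy=strategy, devices=2, max_epochs=5,
                    log_every_n_steps=1)

# loader='neighbor' selects neighbor sampling; num_neighbors=[5] draws at most
# five neighbors per seed node for a single hop:
datamodule = LightningNodeData(data, loader='neighbor', num_neighbors=[5],
                               batch_size=32, num_workers=3)
# The datamodule can then be passed to trainer.fit(...) together with a
# LightningModule.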
2 changes: 1 addition & 1 deletion test/graphgym/test_graphgym.py
@@ -196,7 +196,7 @@ def test_train(destroy_process_group, tmp_path, capfd):
     cfg.params = params_count(model)
     logger = LoggerCallback()
     trainer = L.Trainer(max_epochs=1, max_steps=4, callbacks=logger,
-                        log_every_n_steps=1)
+                        log_every_n_steps=1)
     train_loader, val_loader = loaders[0], loaders[1]
     trainer.fit(model, train_loader, val_loader)
 
