Skip to content

Commit

Permalink
Use batch_per_epoch to slice the train data (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel4x authored Jan 13, 2024
1 parent a1a9caf commit 04d5c13
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion script/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
import math
import pprint
from itertools import islice

import torch
import torch_geometric as pyg
Expand Down Expand Up @@ -59,7 +60,7 @@ def train_and_validate(cfg, model, train_data, valid_data, device, logger, filte

losses = []
sampler.set_epoch(epoch)
for batch in train_loader:
for batch in islice(train_loader, batch_per_epoch):
batch = tasks.negative_sampling(train_data, batch, cfg.task.num_negative,
strict=cfg.task.strict_negative)
pred = parallel_model(train_data, batch)
Expand Down

0 comments on commit 04d5c13

Please sign in to comment.