Skip to content

Commit

Permalink
fix(Trainer): rm num_batches
Browse files Browse the repository at this point in the history
  • Loading branch information
billshitg committed Nov 10, 2023
1 parent f7dd719 commit e2d6088
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions pyTigerGraph/gds/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import time
import os
import warnings
import math

class BaseCallback():
"""Base class for training callbacks.
Expand Down Expand Up @@ -145,7 +146,7 @@ def on_train_step_end(self, trainer):
trainer.update_train_step_metrics(metric.get_metrics())
metric.reset_metrics()
trainer.update_train_step_metrics({"global_step": trainer.cur_step})
trainer.update_train_step_metrics({"epoch": int(trainer.cur_step/trainer.train_loader.num_batches)})
trainer.update_train_step_metrics({"epoch": trainer.cur_epoch})

def on_eval_start(self, trainer):
"""NO DOC"""
Expand Down Expand Up @@ -407,12 +408,17 @@ def train(self, num_epochs=None, max_num_steps=None):
Defaults to the length of the `training_dataloader`
"""
if num_epochs:
self.max_num_steps = self.train_loader.num_batches * num_epochs
else:
self.max_num_steps = math.inf
self.num_epochs = num_epochs
elif max_num_steps:
self.max_num_steps = max_num_steps
self.num_epochs = num_epochs
self.num_epochs = math.inf
else:
self.max_num_steps = math.inf
self.num_epochs = 1
self.cur_step = 0
while self.cur_step < self.max_num_steps:
self.cur_epoch = 0
while self.cur_step < self.max_num_steps and self.cur_epoch < self.num_epochs:
for callback in self.callbacks:
callback.on_epoch_start(trainer=self)
for batch in self.train_loader:
Expand All @@ -432,7 +438,7 @@ def train(self, num_epochs=None, max_num_steps=None):
self.cur_step += 1
for callback in self.callbacks:
callback.on_train_step_end(trainer=self)

self.cur_epoch += 1
for callback in self.callbacks:
callback.on_epoch_end(trainer=self)

Expand Down

0 comments on commit e2d6088

Please sign in to comment.