Skip to content

Commit

Permalink
fix: bool error when num_batches is None
Browse files Browse the repository at this point in the history
  • Loading branch information
billshitg committed Nov 17, 2023
1 parent d139b8d commit cf68fe8
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions pyTigerGraph/gds/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,28 +210,28 @@ def on_epoch_start(self, trainer):
self.epoch_bar = self.tqdm(desc="Epochs", total=trainer.num_epochs)
else:
self.epoch_bar = self.tqdm(desc="Training Steps", total=trainer.max_num_steps)
if not(self.batch_bar):
if self.batch_bar is None:
self.batch_bar = self.tqdm(desc="Training Batches", total=trainer.train_loader.num_batches)

def on_train_step_end(self, trainer):
"""NO DOC"""
logger = logging.getLogger(__name__)
logger.info("train_step:"+str(trainer.get_train_step_metrics()))
if self.tqdm:
if self.batch_bar:
if self.batch_bar is not None:
self.batch_bar.update(1)

def on_eval_start(self, trainer):
"""NO DOC"""
trainer.reset_eval_metrics()
if self.tqdm:
if not(self.valid_bar):
if self.valid_bar is None:
self.valid_bar = self.tqdm(desc="Eval Batches", total=trainer.eval_loader.num_batches)

def on_eval_step_end(self, trainer):
"""NO DOC"""
if self.tqdm:
if self.valid_bar:
if self.valid_bar is not None:
self.valid_bar.update(1)

def on_eval_end(self, trainer):
Expand All @@ -240,7 +240,7 @@ def on_eval_end(self, trainer):
logger.info("evaluation:"+str(trainer.get_eval_metrics()))
trainer.model.train()
if self.tqdm:
if self.valid_bar:
if self.valid_bar is not None:
self.valid_bar.close()
self.valid_bar = None

Expand All @@ -249,7 +249,7 @@ def on_epoch_end(self, trainer):
if self.tqdm:
if self.epoch_bar:
self.epoch_bar.update(1)
if self.batch_bar:
if self.batch_bar is not None:
self.batch_bar.close()
self.batch_bar = None
trainer.eval()
Expand Down

0 comments on commit cf68fe8

Please sign in to comment.