diff --git a/micro_sam/sam_annotator/image_series_annotator.py b/micro_sam/sam_annotator/image_series_annotator.py index 561abc55..67d42750 100644 --- a/micro_sam/sam_annotator/image_series_annotator.py +++ b/micro_sam/sam_annotator/image_series_annotator.py @@ -1,4 +1,5 @@ import os +import time from glob import glob from pathlib import Path @@ -28,6 +29,8 @@ def _precompute( tile_shape, halo, precompute_amg_state, checkpoint_path, device, ndim, prefer_decoder, ): + t_start = time.time() + device = util.get_device(device) predictor, state = util.get_sam_model( model_type=model_type, checkpoint_path=checkpoint_path, device=device, return_state=True @@ -54,6 +57,11 @@ def _precompute( ] assert all(os.path.exists(emb_path) for emb_path in embedding_paths) + t_run = time.time() - t_start + minutes = int(t_run // 60) + seconds = int(round(t_run % 60, 0)) + print("Precomputation took", t_run, f"seconds (= {minutes:02}:{seconds:02} minutes)") + return predictor, decoder, embedding_paths diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 43fd28df..27922452 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -1,4 +1,5 @@ import os +import time from glob import glob from typing import Any, Dict, List, Optional, Tuple, Union @@ -181,6 +182,8 @@ def train_sam( save_every_kth_epoch: Save checkpoints after every kth epoch separately. pbar_signals: Controls for napari progress bar. """ + t_start = time.time() + _check_loader(train_loader, with_segmentation_decoder) _check_loader(val_loader, with_segmentation_decoder) @@ -281,6 +284,12 @@ def train_sam( trainer.fit(**trainer_fit_params) + t_run = time.time() - t_start + hours = int(t_run // 3600) + minutes = int(t_run // 60) + seconds = int(round(t_run % 60, 0)) + print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)") + def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels): if not isinstance(raw_paths, (str, os.PathLike)):