From 4eed8a425874f67a5f4f08be0b72acd5161b40f7 Mon Sep 17 00:00:00 2001
From: loic-lb
Date: Thu, 16 May 2024 14:36:58 +0200
Subject: [PATCH] Refactor Patch based embedder classes to accept img_folder as a parameter

---
 src/prismtoolbox/utils/torch_utils.py |  9 ++++-----
 src/prismtoolbox/wsiemb/embedder.py   | 29 +++++++++++++++------------
 2 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/src/prismtoolbox/utils/torch_utils.py b/src/prismtoolbox/utils/torch_utils.py
index 452e0ef..7dd4ec1 100644
--- a/src/prismtoolbox/utils/torch_utils.py
+++ b/src/prismtoolbox/utils/torch_utils.py
@@ -131,13 +131,12 @@ def create_dataloader(self, dataset):
 
 
 class BasePatchHandler:
-    def __init__(self, img_folder, batch_size, num_workers, transforms_dict=None):
-        self.img_folder = img_folder
+    def __init__(self, batch_size, num_workers, transforms_dict=None):
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.transforms_dict = transforms_dict
 
-    def create_dataset(self):
+    def create_dataset(self, img_folder):
         if self.transforms_dict is not None:
             log.info("Creating transform from transforms_dict.")
             transform = create_transforms(self.transforms_dict)
@@ -147,8 +146,8 @@ def create_dataset(self):
         else:
             log.info("No transform provided.")
             transform = None
-        dataset = ImageFolder(self.img_folder, transform=transform)
-        log.info(f"Created dataset from {self.img_folder} using ImageFolder from torchvision.")
+        dataset = ImageFolder(img_folder, transform=transform)
+        log.info(f"Created dataset from {img_folder} using ImageFolder from torchvision.")
         return dataset
 
     def create_dataloader(self, dataset):
diff --git a/src/prismtoolbox/wsiemb/embedder.py b/src/prismtoolbox/wsiemb/embedder.py
index 80429c5..31bfe50 100644
--- a/src/prismtoolbox/wsiemb/embedder.py
+++ b/src/prismtoolbox/wsiemb/embedder.py
@@ -168,7 +168,6 @@ def save_embeddings(
 class PatchEmbedder(BasePatchHandler):
     def __init__(
         self,
-        img_folder: str,
         arch_name: str,
         batch_size: int,
         num_workers: int,
@@ -180,7 +179,6 @@ def __init__(
         """The PatchEmbedder class is used to extract embeddings from patches extracted as images in a folder.
 
         Args:
-            img_folder: The directory containing the images of the patches.
             arch_name: The name of the architecture to use. See
                 [create_model][prismtoolbox.wsiemb.emb_utils.create_model] for available architectures.
             batch_size: The batch size to use for the dataloader.
@@ -194,7 +192,6 @@ def __init__(
             need_login: Whether to login to the HuggingFace Hub (for Uni and Conch models).
 
         Attributes:
-            img_folder: The directory containing the images of the patches.
            batch_size: The batch size to use for the dataloader.
            num_workers: The number of workers to use for the dataloader.
            transforms_dict: The dictionary of transforms to use.
@@ -205,7 +202,7 @@ def __init__(
 
         """
        super().__init__(
-            img_folder, batch_size, num_workers, transforms_dict
+            batch_size, num_workers, transforms_dict
        )
 
        self.device = device
@@ -225,27 +222,33 @@ def __init__(
 
        self.embeddings = []
 
-    def extract_embeddings(self, show_progress: bool = True):
+    def extract_embeddings(self, img_folder, show_progress: bool = True):
        """Extract embeddings from the images in the img_folder.
 
        Args:
+            img_folder: A folder containing a series of subfolders, each containing images.
+                For example, img_folder could be a folder where the subfolders correspond to different slides.
            show_progress: Whether to show the progress bar.
""" - log.info(f"Extracting embeddings from images in {self.img_folder}.") - dataset = self.create_dataset() + log.info(f"Extracting embeddings from images in {img_folder}.") + dataset = self.create_dataset(img_folder=img_folder) dataloader = self.create_dataloader(dataset) start_time = time.time() - embeddings = [] + embeddings = [[] for _ in range(len(dataset.classes))] + img_ids = [] + for i in range(len(dataset.classes)): + img_ids.append(np.array(dataset.imgs)[np.array(dataset.targets)==i][:,0]) for imgs, folder_id in tqdm( dataloader, - desc=f"Extracting embeddings from images in {self.img_folder}", + desc=f"Extracting embeddings from images in {img_folder}", disable=not show_progress, ): - log.info(f"Extracting embeddings from folder {folder_id}.") imgs = imgs.to(self.device) with torch.no_grad(): output = self.model(imgs) - embeddings.append(output.cpu()) + for i in range(len(dataset.classes)): + embeddings[i].append(output[folder_id==i].cpu()) log.info(f"Embedding time: {time.time() - start_time}.") - log.info(f"Extracted {len(embeddings)} rom images in {self.img_folder}.") - self.embeddings.append(embeddings) \ No newline at end of file + log.info(f"Extracted {len(embeddings)} from images in {img_folder}.") + self.img_ids.append(img_ids) + self.embeddings.append([torch.cat(embedding, dim=0) for embedding in embeddings]) \ No newline at end of file