Refactor Patch based embedder classes to accept img_folder as a parameter
loic-lb committed May 16, 2024
1 parent eb5705a commit 4eed8a4
Showing 2 changed files with 20 additions and 18 deletions.
src/prismtoolbox/utils/torch_utils.py (9 changes: 4 additions & 5 deletions)
@@ -131,13 +131,12 @@ def create_dataloader(self, dataset):


 class BasePatchHandler:
-    def __init__(self, img_folder, batch_size, num_workers, transforms_dict=None):
-        self.img_folder = img_folder
+    def __init__(self, batch_size, num_workers, transforms_dict=None):
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.transforms_dict = transforms_dict

-    def create_dataset(self):
+    def create_dataset(self, img_folder):
         if self.transforms_dict is not None:
             log.info("Creating transform from transforms_dict.")
             transform = create_transforms(self.transforms_dict)
@@ -147,8 +146,8 @@ def create_dataset(self):
         else:
             log.info("No transform provided.")
             transform = None
-        dataset = ImageFolder(self.img_folder, transform=transform)
-        log.info(f"Created dataset from {self.img_folder} using ImageFolder from torchvision.")
+        dataset = ImageFolder(img_folder, transform=transform)
+        log.info(f"Created dataset from {img_folder} using ImageFolder from torchvision.")
         return dataset

     def create_dataloader(self, dataset):
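With this change the folder is supplied when the dataset is created rather than when the handler is constructed, so a single handler instance can be reused across several patch folders. A minimal sketch of the new calling convention follows (the import path and the folder paths are illustrative assumptions, not taken from the commit):

from prismtoolbox.utils.torch_utils import BasePatchHandler  # import path assumed from the file location

# Loading options are fixed once; no img_folder is needed at construction.
handler = BasePatchHandler(batch_size=32, num_workers=4)

# The folder is passed per call, so the same handler can build datasets for
# several patch folders.
dataset_a = handler.create_dataset(img_folder="patches/slide_a")  # hypothetical path
dataset_b = handler.create_dataset(img_folder="patches/slide_b")  # hypothetical path
loader_a = handler.create_dataloader(dataset_a)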
src/prismtoolbox/wsiemb/embedder.py (29 changes: 16 additions & 13 deletions)
@@ -168,7 +168,6 @@ def save_embeddings(
 class PatchEmbedder(BasePatchHandler):
     def __init__(
         self,
-        img_folder: str,
         arch_name: str,
         batch_size: int,
         num_workers: int,
@@ -180,7 +179,6 @@ def __init__(
         """The PatchEmbedder class is used to extract embeddings from patches extracted as images in a folder.

         Args:
-            img_folder: The directory containing the images of the patches.
             arch_name: The name of the architecture to use.
                 See [create_model][prismtoolbox.wsiemb.emb_utils.create_model] for available architectures.
             batch_size: The batch size to use for the dataloader.
@@ -194,7 +192,6 @@ def __init__(
             need_login: Whether to login to the HuggingFace Hub (for Uni and Conch models).

         Attributes:
-            img_folder: The directory containing the images of the patches.
             batch_size: The batch size to use for the dataloader.
             num_workers: The number of workers to use for the dataloader.
             transforms_dict: The dictionary of transforms to use.
@@ -205,7 +202,7 @@ def __init__(
         """
         super().__init__(
-            img_folder, batch_size, num_workers, transforms_dict
+            batch_size, num_workers, transforms_dict
         )
         self.device = device

@@ -225,27 +222,33 @@ def __init__(

         self.embeddings = []

-    def extract_embeddings(self, show_progress: bool = True):
+    def extract_embeddings(self, img_folder, show_progress: bool = True):
         """Extract embeddings from the images in the img_folder.

         Args:
+            img_folder: A folder containing a series of subfolders, each containing images.
+                For example, img_folder could be a folder where the subfolders correpond to different slides.
             show_progress: Whether to show the progress bar.
         """
-        log.info(f"Extracting embeddings from images in {self.img_folder}.")
-        dataset = self.create_dataset()
+        log.info(f"Extracting embeddings from images in {img_folder}.")
+        dataset = self.create_dataset(img_folder=img_folder)
         dataloader = self.create_dataloader(dataset)
         start_time = time.time()
-        embeddings = []
+        embeddings = [[] for _ in range(len(dataset.classes))]
+        img_ids = []
+        for i in range(len(dataset.classes)):
+            img_ids.append(np.array(dataset.imgs)[np.array(dataset.targets)==i][:,0])
         for imgs, folder_id in tqdm(
             dataloader,
-            desc=f"Extracting embeddings from images in {self.img_folder}",
+            desc=f"Extracting embeddings from images in {img_folder}",
             disable=not show_progress,
         ):
-            log.info(f"Extracting embeddings from folder {folder_id}.")
             imgs = imgs.to(self.device)
             with torch.no_grad():
                 output = self.model(imgs)
-            embeddings.append(output.cpu())
+            for i in range(len(dataset.classes)):
+                embeddings[i].append(output[folder_id==i].cpu())
         log.info(f"Embedding time: {time.time() - start_time}.")
-        log.info(f"Extracted {len(embeddings)} rom images in {self.img_folder}.")
-        self.embeddings.append(embeddings)
+        log.info(f"Extracted {len(embeddings)} from images in {img_folder}.")
+        self.img_ids.append(img_ids)
+        self.embeddings.append([torch.cat(embedding, dim=0) for embedding in embeddings])
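With the new signature, extract_embeddings receives img_folder per call and groups its outputs per subfolder: one concatenated tensor per subfolder is appended to self.embeddings, and the matching image paths are collected in self.img_ids. A hedged usage sketch, assuming the import path, the architecture name, the folder layout, and that the remaining constructor arguments have defaults (none of these are confirmed by the diff):

from prismtoolbox.wsiemb.embedder import PatchEmbedder  # import path assumed from the file location

embedder = PatchEmbedder(
    arch_name="resnet50",  # assumed to be an architecture accepted by create_model
    batch_size=64,
    num_workers=4,
    device="cpu",  # the diff stores this argument as self.device
)

# img_folder follows the torchvision ImageFolder layout: one subfolder per
# slide, e.g. patches/slide_1/*.png and patches/slide_2/*.png (hypothetical paths).
embedder.extract_embeddings("patches", show_progress=True)

# After the call, embedder.embeddings[-1] holds one stacked tensor per subfolder
# and embedder.img_ids[-1] the corresponding image paths.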
