Skip to content

Commit

Permalink
Refactor datasets and update dependencies (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
fcogidi authored Sep 18, 2024
1 parent fb7f8b3 commit 26bb863
Show file tree
Hide file tree
Showing 7 changed files with 545 additions and 480 deletions.
927 changes: 498 additions & 429 deletions poetry.lock

Large diffs are not rendered by default.

26 changes: 12 additions & 14 deletions projects/med_benchmarking/datasets/mimiciv_cxr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import logging
import os
from typing import Callable, Literal, Optional, get_args
from typing import Callable, Literal, Optional, Union, get_args

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(
split: Literal["train", "validate", "test"],
labeler: Literal["chexpert", "negbio", "double_image", "single_image"],
transform: Optional[Callable[[Image.Image], torch.Tensor]] = None,
tokenizer: Optional[Callable[[str], torch.Tensor]] = None,
tokenizer: Optional[Callable[[str], Union[torch.Tensor, dict]]] = None,
include_report: bool = False,
) -> None:
"""Initialize the dataset."""
Expand Down Expand Up @@ -103,20 +103,18 @@ def __init__(
self._labeler = labeler

if self._labeler in ["double_image", "single_image"]:
df = pd.read_csv(data_path)
df = df.dropna(subset=["caption"]) # some captions are missing
self.entries = df.to_dict("records")
self.data_df = pd.read_csv(data_path)
self.data_df = self.data_df.dropna(
subset=["caption"]
) # some captions are missing
else:
with open(data_path, "rb") as file:
entries = json.load(file)
self.data_df = pd.read_json(data_path)

# remove entries with no label if reports are not requested either
old_num = len(entries)
entries_df = pd.DataFrame(entries)
entries_df = entries_df[entries_df["label"].apply(len) > 0]
self.entries = entries_df.to_dict("records")
old_num = len(self.data_df)
entries_df = self.data_df[self.data_df["label"].apply(len) > 0]
logger.info(
f"{old_num - len(entries)} datapoints removed due to lack of a label."
f"{old_num - len(self.data_df)} datapoints removed due to lack of a label."
)

if transform is not None:
Expand All @@ -128,7 +126,7 @@ def __init__(

def __getitem__(self, idx: int) -> Example:
"""Return all the images and the label vector of the idx'th study."""
entry = self.entries[idx]
entry = self.data_df.iloc[idx]
img_path = entry["image_path"]

with Image.open(
Expand Down Expand Up @@ -171,7 +169,7 @@ def __getitem__(self, idx: int) -> Example:

def __len__(self) -> int:
"""Return the length of the dataset."""
return len(self.entries)
return len(self.data_df)


class CreateJSONFiles(object):
Expand Down
34 changes: 17 additions & 17 deletions projects/med_benchmarking/datasets/pmcoa.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
"""PMC-OA dataset."""

import os
from typing import Any, Callable, Dict, Literal, Optional, Tuple
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union

import jsonlines
import pandas as pd
import torch
from omegaconf import MISSING
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import pyarrow.json as pj
from pyarrow import csv
import pyarrow as pa

from mmlearn.conf import external_store
from mmlearn.constants import EXAMPLE_INDEX_KEY
Expand All @@ -30,7 +31,7 @@ def __init__(
caption_key: str = "caption",
csv_separator: str = ",",
transform: Optional[Callable[[Image.Image], torch.Tensor]] = None,
tokenizer: Optional[Callable[[str], torch.Tensor]] = None,
tokenizer: Optional[Callable[[str], Union[torch.Tensor, dict]]] = None,
mask_generator: Optional[
Callable[
[Dict[str, torch.Tensor], Any],
Expand Down Expand Up @@ -104,13 +105,13 @@ def __len__(self) -> int:
def __getitem__(self, idx: int) -> Example:
"""Return items in the dataset."""
image_path = os.path.join(
self.root_dir, self.image_dir, self.image_filenames[idx]
self.root_dir, self.image_dir, self.image_filenames[idx].as_py()
)

with Image.open(image_path) as img:
images = self.transform(img)

caption = str(self.captions[idx])
caption = self.captions[idx].as_py()
example = Example(
{
Modalities.RGB: images,
Expand Down Expand Up @@ -141,19 +142,18 @@ def __getitem__(self, idx: int) -> Example:

def _csv_loader(
self, input_filename: str, img_key: str, caption_key: str, sep: str
) -> Tuple[Any, Any]:
) -> Tuple[pa.ChunkedArray, pa.ChunkedArray]:
"""Load images, captions from CSV data."""
df = pd.read_csv(input_filename, sep=sep)
images, captions = df[img_key].tolist(), df[caption_key].tolist()
return images, captions
table = csv.read_csv(
input_filename,
parse_options=csv.ParseOptions(delimiter=sep, newlines_in_values=True),
)
return table[img_key], table[caption_key]

def _jsonl_loader(
self, input_filename: str, img_key: str, caption_key: str
) -> Tuple[Any, Any]:
) -> Tuple[pa.ChunkedArray, pa.ChunkedArray]:
"""Load images, captions from JSON data."""
images, captions = [], []
with jsonlines.open(input_filename) as reader:
for obj in reader:
images.append(obj[img_key])
captions.append(obj[caption_key])
return images, captions
parse_options = pj.ParseOptions(newlines_in_values=True)
table = pj.read_json(input_filename, parse_options=parse_options)
return table[img_key], table[caption_key]
18 changes: 9 additions & 9 deletions projects/med_benchmarking/datasets/quilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import ast
import os
from typing import Callable, List, Literal, Optional
from typing import Callable, List, Literal, Optional, Union

import pandas as pd
import torch
Expand Down Expand Up @@ -47,7 +47,7 @@ def __init__(
split: Literal["train", "val"] = "train",
subset: Optional[List[str]] = None,
transform: Optional[Callable[[Image.Image], torch.Tensor]] = None,
tokenizer: Optional[Callable[[str], torch.Tensor]] = None,
tokenizer: Optional[Callable[[str], Union[torch.Tensor, dict]]] = None,
processor: Optional[
Callable[[Image.Image, str], tuple[torch.Tensor, torch.Tensor]]
] = None,
Expand Down Expand Up @@ -128,17 +128,17 @@ def __getitem__(self, idx: int) -> Example:
try:
with Image.open(
os.path.join(
self.root_dir, "quilt_1m", self.data_df["image_path"].iloc[idx]
self.root_dir, "quilt_1m", self.data_df.loc[idx, "image_path"]
)
) as img:
image = img.convert("RGB")
except Exception as e:
print(f"ERROR: {e} on {self.data_df['image_path'].iloc[idx]}")
print(f"ERROR: {e} on {self.data_df.loc[idx, 'image_path']}")

if self.transform is not None:
image = self.transform(image)

caption = self.data_df["caption"].iloc[idx]
caption = self.data_df.loc[idx, "caption"]
tokens = self.tokenizer(caption) if self.tokenizer is not None else None

if self.processor is not None:
Expand All @@ -150,9 +150,9 @@ def __getitem__(self, idx: int) -> Example:
Modalities.TEXT: caption,
EXAMPLE_INDEX_KEY: idx,
"qid": self.data_df.index[idx],
"magnification": self.data_df["magnification"].iloc[idx],
"height": self.data_df["height"].iloc[idx],
"width": self.data_df["width"].iloc[idx],
"magnification": self.data_df.loc[idx, "magnification"],
"height": self.data_df.loc[idx, "height"],
"width": self.data_df.loc[idx, "width"],
}
)

Expand All @@ -169,7 +169,7 @@ def __getitem__(self, idx: int) -> Example:

def __len__(self) -> int:
"""Return the length of the dataset."""
return len(self.data_df.index)
return len(self.data_df)


def _safe_eval(x: str) -> list[str]:
Expand Down
13 changes: 5 additions & 8 deletions projects/med_benchmarking/datasets/roco.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""ROCO Dataset."""

import json
import os
from typing import Callable, Dict, Literal, Optional, Union

import torch
import pandas as pd
from omegaconf import MISSING
from PIL import Image
from torch.utils.data import Dataset
Expand Down Expand Up @@ -55,9 +55,7 @@ def __init__(
) -> None:
"""Initialize the dataset."""
data_path = os.path.join(root_dir, group + split + "_dataset.json")
with open(data_path, encoding="utf-8") as file:
entries = [json.loads(line) for line in file.readlines()]
self.entries = entries
self.data_df = pd.read_json(data_path, lines=True)

if processor is None and transform is None:
self.transform = ToTensor()
Expand All @@ -80,14 +78,13 @@ def __getitem__(self, idx: int) -> Example:
image and free text caption are returned. Otherwise, the image, free-
text caption, and caption tokens are returned.
"""
entry = self.entries[idx]
with Image.open(entry["image_path"]) as img:
with Image.open(self.data_df.loc[idx, "image_path"]) as img:
image = img.convert("RGB")

if self.transform is not None:
image = self.transform(image)

caption = entry["caption"]
caption = self.data_df.loc[idx, "caption"]
tokens = self.tokenizer(caption) if self.tokenizer is not None else None

if self.processor is not None:
Expand All @@ -114,4 +111,4 @@ def __getitem__(self, idx: int) -> Example:

def __len__(self) -> int:
"""Return the length of the dataset."""
return len(self.entries)
return len(self.data_df)
2 changes: 1 addition & 1 deletion projects/med_benchmarking/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
timm~=1.0.7
torchvision~=0.19.0
wandb~=0.17.7
wandb~=0.18.0
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ packages = [
[tool.poetry.dependencies]
python = ">=3.9, <3.13"
numpy = "^1.26.4"
pyarrow = "^17.0.0"
hydra-core = "^1.3.0"
hydra-zen = "^0.13.0"
hydra-submitit-launcher = "^1.2.0"
transformers = "^4.44.0"
torch = {version = "^2.4.0", source = "torch-cu121"}
lightning = "^2.4.0"
jsonlines = "^4.0.0"
pandas = {version="^2.2.2", extras=["performance"]}
torchvision = {version = "^0.19.0", source = "torch-cu121"}

Expand All @@ -28,6 +28,7 @@ timm = {version = "^1.0.8", optional = true}
torchaudio = {version = "^2.4.0", source = "torch-cu121", optional = true}
peft = {version = "^0.12.0", optional = true}


[tool.poetry.group.vision.dependencies]
opencv-python = "^4.10.0.84"
timm = "^1.0.8"
Expand All @@ -45,7 +46,7 @@ peft = "^0.12.0"
optional = true

[tool.poetry.group.dev.dependencies]
wandb = "^0.17.7"
wandb = "^0.18.0"
ipykernel = "^6.29.5"
ipython = "8.18.0"
h5py = "^3.11.0"
Expand Down

0 comments on commit 26bb863

Please sign in to comment.