-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_utils.py
30 lines (24 loc) · 915 Bytes
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os
from tqdm import tqdm
from functools import partial
from multiprocessing import Pool
from typing import List, Tuple
import nibabel as nib
import numpy as np
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
def load_nii(path: str, dtype: str = 'float32') -> np.ndarray:
    """Read a scan from disk with nibabel and cast its data to ``dtype``

    :param path: Path to the image file
    :param dtype: Name of the numpy dtype the returned array is cast to
    :return img: Image data as a numpy array of the requested dtype
        (presumably shaped (H, W, D) for a 3D scan -- depends on the file)
    """
    # nibabel hands the voxel data back as floating point; cast afterwards.
    volume = nib.load(path).get_fdata()
    target_dtype = np.dtype(dtype)
    return volume.astype(target_dtype)
def load_segmentations(paths: List[str]) -> Tuple[List[str], np.ndarray]:
    """Load all segmentations and associated subject_ids

    The subject id is parsed from each file name: the text before
    ``_brain_`` is taken, then the part after its first ``-``, cut at the
    next ``_`` (e.g. a file named ``sub-0042_brain_seg.nii.gz`` yields
    ``0042``).

    :param paths: Paths to the segmentation files
    :return filenames: Subject ids, in the same order as ``paths``
    :return segmentations: All loaded volumes combined with ``np.array``
    :raises IndexError: If a file name contains no ``-`` before ``_brain_``
    """
    subject_ids, segmentations = [], []
    for path in tqdm(paths):
        # Strip the directory first so that a '_brain_' inside a parent
        # directory name cannot truncate the path before the file name.
        stem = os.path.basename(path).split('_brain_')[0]
        subject_id = stem.split('-')[1].split('_')[0]
        segmentations.append(load_nii(path))
        subject_ids.append(subject_id)
    return subject_ids, np.array(segmentations)