class BaseDBCreator(abc.ABC):
    """
    Base class for database creators

    Subclasses implement ``connect``, ``write``, ``close`` and ``clean``.
    The class also supports the ``with`` statement on top of
    ``connect``/``close``, so the connection is released even when the
    body of the ``with`` block raises:

        with SomeCreator(...) as creator:
            creator.write(...)
    """

    def __init__(self):
        """
        Initializes the base creator; it keeps no state of its own
        """
        super().__init__()

    def __enter__(self):
        """
        Connects to the database and returns the creator itself

        Returns:
            BaseDBCreator: this object, already connected
        """
        self.connect()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Closes the database connection when the ``with`` block ends

        Returns:
            bool: always False, so an exception raised inside the block
                keeps propagating to the caller
        """
        self.close()
        return False

    @abc.abstractmethod
    def connect(self):
        """
        Connects to the database
        """

    @abc.abstractmethod
    def write(self, *args, **kwargs):
        """
        Writes the data
        """

    @abc.abstractmethod
    def close(self):
        """
        Closes the database connection
        """

    @abc.abstractmethod
    def clean(self):
        """
        Cleans the database
        """
class MongoDBCreator(BaseDBCreator):
    """
    MongoDB Creator class

    Samples are written into two collections of ``database_name``:
    ``<collection_name>.bin`` and ``<collection_name>.meta``.
    """

    def __init__(
        self,
        database_name: str,
        collection_name: str,
        host: str,
        port: int,
        preprocessing_functions: Optional[Dict[str, Callable]] = None,
        chunk_size: int = 10,
    ) -> None:
        """
        Constructor

        Only stores the configuration; no connection is opened until
        ``connect`` is called.

        Args:
            database_name: str: name of the database
            collection_name: str: name of the collection
            host: str: host name
            port: int: port number
            preprocessing_functions: Optional[Dict[str, Callable]]: dictionary
                of preprocessing functions, forwarded to ``insert_samples``
            chunk_size: int: size of the chunk, forwarded to ``insert_samples``
        """
        super().__init__()
        self._database_name = database_name
        self._collection_name = collection_name
        self._host = host
        self._port = port
        self._url = f"mongodb://{self._host}:{self._port}"
        self._preprocessing_functions = preprocessing_functions
        self._chunk_size = chunk_size
        # Connection state; filled in by connect() and reset by close()
        self._client = None
        self._database = None
        self._collection_bin = None
        self._collection_meta = None

    def _require_connection(self) -> None:
        """
        Fails with a clear message when connect() has not been called

        Raises:
            RuntimeError: if there is no open client
        """
        if self._client is None:
            raise RuntimeError(
                "MongoDBCreator is not connected; call connect() first"
            )

    def connect(self) -> None:
        """
        Connects to the database

        Creates the client for ``mongodb://<host>:<port>`` and resolves the
        ``.bin`` and ``.meta`` collections. A client left over from an
        earlier ``connect`` call is closed first so it is not leaked.
        """
        if self._client is not None:
            self._client.close()
        self._client = MongoClient(self._url)
        self._database = self._client[self._database_name]
        self._collection_bin = self._database[f"{self._collection_name}.bin"]
        self._collection_meta = self._database[f"{self._collection_name}.meta"]

    def write(
        self,
        data: pd.DataFrame,
        input_columns: List[str],
        label_columns: List[str],
        meta_columns: List[str],
        label_description: Optional[Dict[str, str]] = None,
        *args, **kwargs
    ) -> None:
        """
        Writes the data

        Delegates to ``insert_samples`` and then creates an ascending index
        on ``id`` in both collections.

        Args:
            data: pd.DataFrame: data
            input_columns: List[str]: list of input columns
            label_columns: List[str]: list of label columns
            meta_columns: List[str]: list of meta columns
            label_description: Optional[Dict[str, str]]: dictionary of label
                description
            *args, **kwargs: accepted for interface compatibility; unused

        Returns:
            None

        Raises:
            RuntimeError: if ``connect`` has not been called
        """
        self._require_connection()
        insert_samples(
            data,
            input_columns,
            label_columns,
            meta_columns,
            self._collection_bin,
            self._collection_meta,
            label_description=label_description,
            chunk_size=self._chunk_size,
            preprocessing_functions=self._preprocessing_functions
        )
        self._collection_meta.create_index([("id", ASCENDING)])
        self._collection_bin.create_index([("id", ASCENDING)])

    def clean(self) -> None:
        """
        Cleans the database

        Drops both the ``.bin`` and the ``.meta`` collection.

        Returns:
            None

        Raises:
            RuntimeError: if ``connect`` has not been called
        """
        self._require_connection()
        self._collection_bin.drop()
        self._collection_meta.drop()

    def close(self) -> None:
        """
        Closes the database connection

        Does nothing when there is no open client, so it is safe to call
        more than once or without a prior ``connect``.

        Returns:
            None
        """
        if self._client is None:
            return
        self._client.close()
        self._client = None
        self._database = None
        self._collection_bin = None
        self._collection_meta = None


if __name__ == "__main__":

    # Example usage
    database_name = "mydatabase"
    collection_name = "mycollection"
    # NOTE(review): address of one specific MongoDB server; adjust it to
    # your own deployment before running the example
    mongo_host = "10.245.12.58"
    mongo_port = 27017
    metadata = pd.DataFrame(
        columns=['t1', 't2', 'subject_id', 'age', 'gender']
    )
    metadata.loc[metadata.shape[0]] = [
        './test_data/Template-T1-U8-RALPFH-BR.nii.gz',
        './test_data/Template-T2-U8-RALPFH-BR.nii.gz',
        1, 25, 'M'
    ]
    creator = MongoDBCreator(
        database_name=database_name,
        collection_name=collection_name,
        host=mongo_host,
        port=mongo_port,
        preprocessing_functions={
            't1': lambda x: normalize(x) * 255,
            't2': lambda x: normalize(x) * 255,
        },
        chunk_size=10,
    )
    creator.connect()
    try:
        creator.write(
            data=metadata,
            input_columns=['t1'],
            label_columns=['t2'],
            meta_columns=['subject_id', 'age', 'gender'],
            label_description={"t2": "Lesion mask"}
        )
        sample_doc = creator._collection_bin.find_one({"kind": "t2"})
        print(sample_doc.keys())
        print(sample_doc['chunk_id'])
        print(len(sample_doc['chunk']))
        print(sample_doc['kind'])
        # Largest stored id + 1
        num_examples = int(
            creator._collection_bin.find_one(sort=[("id", -1)])["id"] + 1
        )
        print(num_examples)
    finally:
        # Remove the example collections and release the client even when
        # the write or one of the queries above failed
        try:
            creator.clean()
        finally:
            creator.close()
def tensor_2_bin(tensor: Tensor) -> bytes:
    """
    Convert tensor to binary

    The tensor is cast to ``torch.uint8`` and serialized with
    ``torch.save`` into an in-memory buffer.

    Args:
        tensor: Tensor: tensor

    Returns:
        bytes: the serialized uint8 tensor
    """
    buffer = io.BytesIO()
    torch.save(tensor.to(torch.uint8), buffer)
    return buffer.getvalue()


# Generator function: it yields one dictionary per chunk rather than
# returning a single dictionary, hence no ``Dict`` return annotation.
def chunk_binobj(
    tensor_compressed: bytes,
    id: int,
    kind: str,
    chunksize: int
):
    """
    Chunk the binary object

    Args:
        tensor_compressed: bytes: serialized tensor (see ``tensor_2_bin``)
        id: int: id stored in every chunk document
        kind: str: kind stored in every chunk document
        chunksize: int: maximum chunk size in megabytes

    Yields:
        Dict[str, Any]: ``{"id", "chunk_id", "kind", "chunk"}`` where
            ``chunk_id`` counts from 0 and ``chunk`` is a ``bson.Binary``
            holding at most ``chunksize`` megabytes; nothing is yielded for
            an empty payload

    Raises:
        ValueError: if ``chunksize`` does not amount to at least one byte
            (raised when iteration starts)
    """
    # Convert chunksize from megabytes to bytes
    chunksize_bytes = int(chunksize * 1024 * 1024)
    if chunksize_bytes <= 0:
        raise ValueError(
            f"chunksize must be a positive number of megabytes, got {chunksize!r}"
        )

    # Stepping through the start offsets gives every full chunk followed by
    # the shorter remainder, without computing the chunk count up front
    offsets = range(0, len(tensor_compressed), chunksize_bytes)
    for chunk_id, start in enumerate(offsets):
        chunk = tensor_compressed[start:start + chunksize_bytes]
        yield {
            "id": id,
            "chunk_id": chunk_id,
            "kind": kind,
            "chunk": bson.Binary(chunk),
        }


def nifti_filename_2_tensor(filename: str) -> Tensor:
    """
    Convert NIFTI filename to tensor

    Args:
        filename: str: filename of NIFTI file

    Returns:
        Tensor: tensor built from the image's ``get_fdata()`` array

    Raises:
        FileNotFoundError: if ``filename`` does not exist
        ValueError: if ``filename`` does not end with ``.nii`` or ``.nii.gz``
    """
    # Explicit exceptions instead of asserts: asserts vanish under ``-O``
    if not os.path.exists(filename):
        raise FileNotFoundError(f"NIFTI file does not exist: {filename}")
    if not filename.endswith((".nii", ".nii.gz")):
        raise ValueError(
            f"Expected a .nii or .nii.gz file, got: {filename}"
        )
    return torch.from_numpy(np.asanyarray(nib.load(filename).get_fdata()))


def insert_data(
    column: str,
    filename: str,
    index: int,
    collection_bin: Collection,
    chunk_size: int = 10,
    preprocessing_functions: Optional[Dict[str, Callable]] = None,
) -> Tuple[int, ...]:
    """
    Insert data

    Loads the NIFTI file, applies the preprocessing function registered for
    ``column`` (if any), serializes the result and inserts it chunk by
    chunk into ``collection_bin``.

    Args:
        column: str: column; stored as the ``kind`` of every chunk
        filename: str: filename
        index: int: index; stored as the ``id`` of every chunk
        collection_bin: Collection: collection bin
        chunk_size: int: chunk size
        preprocessing_functions: Optional[Dict[str, Callable]]: dictionary
            of preprocessing functions keyed by column name

    Returns:
        shape: Tuple[int, ...]: shape of the tensor as loaded, i.e. before
            preprocessing
    """
    tensor_data = nifti_filename_2_tensor(filename)
    shape = tensor_data.shape
    if preprocessing_functions and column in preprocessing_functions:
        tensor_data = preprocessing_functions[column](tensor_data)
    serialized = tensor_2_bin(tensor_data)
    # write data
    for chunk in chunk_binobj(serialized, index, column, chunk_size):
        collection_bin.insert_one(chunk)
    return shape


def insert_samples(
    data: pd.DataFrame,
    input_columns: List[str],
    label_columns: List[str],
    meta_columns: List[str],
    collection_bin: Collection,
    collection_meta: Collection,
    label_description: Optional[Dict[str, str]] = None,
    chunk_size: int = 10,
    preprocessing_functions: Optional[Dict[str, Callable]] = None,
) -> None:
    """
    Insert samples

    Every row of ``data`` becomes one sample. Its ``id`` is the row
    position (0 for the first row), so the result does not depend on the
    labels of ``data.index``.

    Args:
        data: pd.DataFrame: data
        input_columns: List[str]: list of input columns
        label_columns: List[str]: list of label columns
        meta_columns: List[str]: list of meta columns
        collection_bin: Collection: collection bin
        collection_meta: Collection: collection meta
        label_description: Optional[Dict[str, str]]: dictionary of label
            description; labels without an entry (or all labels when this
            is None) are stored as "Label is not described"
        chunk_size: int: chunk size
        preprocessing_functions: Optional[Dict[str, Callable]]: dictionary
            of preprocessing functions

    Returns:
        None

    Raises:
        ValueError: if the tensors of one sample do not share a shape
    """
    # The default None must behave like "no descriptions given"
    descriptions = label_description if label_description is not None else {}
    selected_columns = input_columns + label_columns + meta_columns
    # Rows are read with .iloc, so the loop runs over positions, not labels
    for position in tqdm(range(len(data))):
        meta_data = {"id": position, "labels": {}}
        for column in selected_columns:
            value = data[column].iloc[position]
            if column in meta_columns:
                meta_data[column] = str(value)
                continue
            shape = insert_data(
                column, value, position,
                collection_bin, chunk_size,
                preprocessing_functions=preprocessing_functions
            )
            if "shape" not in meta_data:
                meta_data["shape"] = shape
            elif meta_data["shape"] != shape:
                raise ValueError(
                    f"Sample {position}: column {column!r} has shape "
                    f"{tuple(shape)}, expected {tuple(meta_data['shape'])}"
                )
            if column in label_columns:
                meta_data["labels"][column] = descriptions.get(
                    column, "Label is not described"
                )
        collection_meta.insert_one(meta_data)
index fdbe0ee..784c6d6 100644 --- a/mindfultensors/mongoloader.py +++ b/mindfultensors/mongoloader.py @@ -102,7 +102,6 @@ def __getitem__(self, batch): # Separate processing for each 'kind' data = self.make_serial(samples_for_id, self.sample[0]) label = self.make_serial(samples_for_id, self.sample[1]) - # Add to results results[id] = { "input": self.normalize(self.transform(data).float()), diff --git a/mindfultensors/utils.py b/mindfultensors/utils.py index f46a6fa..750408b 100644 --- a/mindfultensors/utils.py +++ b/mindfultensors/utils.py @@ -86,7 +86,7 @@ class DBBatchSampler(Sampler): data_source: Sized - def __init__(self, data_source, batch_size=1, seed=None): + def __init__(self, data_source, batch_size=1, seed=None, shuffle=True): """TODO describe function :param data_source: a dataset of Dataset class @@ -98,6 +98,7 @@ def __init__(self, data_source, batch_size=1, seed=None): self.data_source = data_source self.data_size = len(self.data_source) self.seed = seed + self.shuffle = shuffle def __chunks__(self, l, n): for i in range(0, len(l), n): @@ -106,9 +107,12 @@ def __chunks__(self, l, n): def __iter__(self): if self.seed is not None: np.random.seed(self.seed) - return self.__chunks__( - np.random.permutation(self.data_size), self.batch_size - ) + if self.shuffle: + return self.__chunks__( + np.random.permutation(self.data_size), self.batch_size + ) + else: + return self.__chunks__(range(self.data_size), self.batch_size) def __len__(self): return (