
Commit

cache meta in self._meta_cache
Signed-off-by: KumoLiu <[email protected]>
KumoLiu committed Jul 27, 2023
1 parent 304efee commit d3c240e
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions monai/data/dataset.py
@@ -1569,6 +1569,7 @@ def __init__(
             **kwargs,
         )
         self.device = device
+        self._meta_cache: dict[Any, dict[Any, Any]] = {}

     def _cachecheck(self, item_transformed):
         """

@@ -1599,13 +1600,13 @@ def _cachecheck(self, item_transformed):
                 if isinstance(item_transformed, dict):
                     item: dict[Any, Any] = {}  # type:ignore
                     for k in item_transformed:
-                        meta_k = torch.load(self.cache_dir / f"{hashfile.name}-{k}-meta")  # type:ignore
+                        meta_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta")
                         item[k] = kvikio_numpy.fromfile(f"{hashfile}-{k}", dtype=meta_k["dtype"], like=cp.empty(()))
                         item[k] = convert_to_tensor(item[k].reshape(meta_k["shape"]), device=f"cuda:{self.device}")
                         item[f"{k}_meta_dict"] = meta_k
                     return item
                 elif isinstance(item_transformed, (np.ndarray, torch.Tensor)):
-                    _meta = torch.load(self.cache_dir / f"{hashfile.name}-meta")  # type:ignore
+                    _meta = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-meta")
                     _data = kvikio_numpy.fromfile(f"{hashfile}", dtype=_meta.pop("dtype"), like=cp.empty(()))
                     _data = convert_to_tensor(_data.reshape(_meta.pop("shape")), device=f"cuda:{self.device}")
                     if bool(_meta):

@@ -1615,7 +1616,7 @@ def _cachecheck(self, item_transformed):
                     item: list[dict[Any, Any]] = [{} for _ in range(len(item_transformed))]  # type:ignore
                     for i, _item in enumerate(item_transformed):
                         for k in _item:
-                            meta_i_k = torch.load(self.cache_dir / f"{hashfile.name}-{k}-meta-{i}")  # type:ignore
+                            meta_i_k = self._load_meta_cache(meta_hash_file_name=f"{hashfile.name}-{k}-meta-{i}")
                             item_k = kvikio_numpy.fromfile(f"{hashfile}-{k}-{i}", dtype=np.float32, like=cp.empty(()))
                             item_k = convert_to_tensor(item[i].reshape(meta_i_k["shape"]), device=f"cuda:{self.device}")
                             item[i].update({k: item_k, f"{k}_meta_dict": meta_i_k})
@@ -1647,12 +1648,12 @@ def _cachecheck(self, item_transformed):
         return _item_transformed

     def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
-        _item_transformed_meta = data.meta if isinstance(data, MetaTensor) else {}
+        self._meta_cache[meta_hash_file_name] = copy(data.meta) if isinstance(data, MetaTensor) else {}
         _item_transformed_data = data.array if isinstance(data, MetaTensor) else data
         if isinstance(_item_transformed_data, torch.Tensor):
             _item_transformed_data = _item_transformed_data.numpy()
-        _item_transformed_meta["shape"] = _item_transformed_data.shape
-        _item_transformed_meta["dtype"] = _item_transformed_data.dtype
+        self._meta_cache[meta_hash_file_name]["shape"] = _item_transformed_data.shape
+        self._meta_cache[meta_hash_file_name]["dtype"] = _item_transformed_data.dtype
         kvikio_numpy.tofile(_item_transformed_data, data_hashfile)
         try:

             # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation

@@ -1662,7 +1663,7 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
                 meta_hash_file = self.cache_dir / meta_hash_file_name
                 temp_hash_file = Path(tmpdirname) / meta_hash_file_name
                 torch.save(
-                    obj=_item_transformed_meta,
+                    obj=self._meta_cache[meta_hash_file_name],
                     f=temp_hash_file,
                     pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                     pickle_protocol=self.pickle_protocol,

@@ -1677,3 +1678,9 @@ def _create_new_cache(self, data, data_hashfile, meta_hash_file_name):
                         pass
         except PermissionError:  # project-monai/monai issue #3613
             pass
+
+    def _load_meta_cache(self, meta_hash_file_name):
+        if meta_hash_file_name in self._meta_cache:
+            return self._meta_cache[meta_hash_file_name]
+        else:
+            return torch.load(self.cache_dir / meta_hash_file_name)  # type:ignore
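Taken together, the hunks above implement a simple load-through memoization for the per-item metadata files: _create_new_cache populates self._meta_cache whenever a cache entry is written, _load_meta_cache serves later reads from memory, and torch.load only runs for metadata files persisted by an earlier run. Below is a minimal standalone sketch of the same pattern, assuming nothing beyond what the diff shows; the MetaCache class and its load_fn parameter are illustrative, not MONAI API:

import torch
from pathlib import Path
from typing import Any, Callable

class MetaCache:
    """Load-through cache: serve metadata from memory, fall back to disk."""

    def __init__(self, cache_dir: Path, load_fn: Callable[..., dict] = torch.load):
        self.cache_dir = cache_dir
        self.load_fn = load_fn
        self._meta_cache: dict[str, dict[Any, Any]] = {}

    def put(self, name: str, meta: dict[Any, Any]) -> None:
        # write path: remember metadata at the moment it is serialized,
        # mirroring what _create_new_cache does in the diff
        self._meta_cache[name] = meta

    def get(self, name: str) -> dict[Any, Any]:
        if name in self._meta_cache:
            # hit: the entry was cached by the write path in this process
            return self._meta_cache[name]
        # miss: the file came from a previous run; read it once from disk
        return self.load_fn(self.cache_dir / name)

The miss branch matters because the on-disk cache outlives the Python process: a fresh dataset instance starts with an empty _meta_cache, so metadata persisted by earlier runs remains reachable through torch.load.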

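A second detail worth noting: _create_new_cache now stores copy(data.meta) rather than data.meta itself, and then adds "shape" and "dtype" bookkeeping keys to the cached dict before torch.save. The shallow copy keeps those keys from leaking into the live MetaTensor's metadata. A small sketch of the difference, with a plain dict standing in for data.meta:

from copy import copy

meta = {"affine": "identity"}   # stands in for data.meta on a MetaTensor

cached = copy(meta)             # shallow copy, as in _create_new_cache
cached["shape"] = (64, 64, 64)  # bookkeeping keys added before torch.save
cached["dtype"] = "float32"
assert "shape" not in meta      # the tensor's own metadata stays clean

aliased = meta                  # without copy(), both names share one dict
aliased["shape"] = (64, 64, 64)
assert "shape" in meta          # the side effect leaks back into data.meta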