Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Put the RunDB API interface and MongoDB interface together #1442

Merged
merged 18 commits into from
Oct 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 218 additions & 23 deletions straxen/storage/mongo_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,83 @@
from datetime import datetime
from warnings import warn
import pytz
from typing import Tuple, Dict
from typing import Tuple, Dict, Any, List, Optional, Union
from strax import exporter, to_str_tuple
import gridfs
from tqdm import tqdm
from shutil import move
import hashlib
from pymongo.collection import Collection
import utilix
from straxen import uconfig
from utilix.rundb import DB, xent_collection
from utilix import uconfig, logger


export, __all__ = exporter()


@export
class GridFsInterface:
class GridFsBase:
"""Base class for GridFS operations."""

def __init__(self, config_identifier: str = "config_name", **kwargs: Any) -> None:
self.config_identifier = config_identifier

def get_query_config(self, config: str) -> Dict[str, str]:
"""Generate query identifier for a config."""
return {self.config_identifier: config}

def document_format(self, config):
"""Format of the document to upload.

:param config: str, name of the file of interest
:return: dict, that will be used to add the document

"""
doc = self.get_query_config(config)
doc.update(
{
"added": datetime.now(tz=pytz.utc),
}
)
return doc

def config_exists(self, config: str) -> bool:
"""Check if a config exists."""
raise NotImplementedError

def md5_stored(self, abs_path: str) -> bool:
"""Check if file with given MD5 is stored."""
raise NotImplementedError

def test_find(self) -> None:
"""Test the find operation."""
raise NotImplementedError

def list_files(self) -> List[str]:
"""List all files in the database."""
raise NotImplementedError

@staticmethod
def compute_md5(abs_path: str) -> str:
"""Compute MD5 hash of a file.

RAM intensive operation.

"""
if not os.path.exists(abs_path):
return ""
# bandit: disable=B303
hash_md5 = hashlib.md5()
with open(abs_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()


@export
class GridFsInterfaceMongo(GridFsBase):
"""
Base class to upload/download the files to a database using GridFS
Class to upload/download the files to a database using GridFS
for PyMongo:
https://pymongo.readthedocs.io/en/stable/api/gridfs/index.html#module-gridfs

Expand Down Expand Up @@ -75,7 +134,7 @@ def __init__(
}
# We can safely hard-code the collection as that is always
# the same with GridFS.
collection = utilix.rundb.xent_collection(**mongo_kwargs, collection="fs.files")
collection = xent_collection(**mongo_kwargs, collection="fs.files")
else:
# Check the user input is fine for what we want to do.
if not isinstance(collection, Collection):
Expand Down Expand Up @@ -103,21 +162,6 @@ def get_query_config(self, config):
"""
return {self.config_identifier: config}

def document_format(self, config):
"""Format of the document to upload.

:param config: str, name of the file of interest
:return: dict, that will be used to add the document

"""
doc = self.get_query_config(config)
doc.update(
{
"added": datetime.now(tz=pytz.utc),
}
)
return doc

def config_exists(self, config):
"""Quick check if this config is already saved in the collection.

Expand Down Expand Up @@ -189,7 +233,7 @@ def compute_md5(abs_path):


@export
class MongoUploader(GridFsInterface):
class MongoUploader(GridFsInterfaceMongo):
"""Class to upload files to GridFs."""

def __init__(self, readonly=False, *args, **kwargs):
Expand Down Expand Up @@ -251,7 +295,7 @@ def upload_single(self, config, abs_path):


@export
class MongoDownloader(GridFsInterface):
class MongoDownloader(GridFsInterfaceMongo):
"""Class to download files from GridFs."""

_instances: Dict[Tuple, "MongoDownloader"] = {}
Expand Down Expand Up @@ -380,6 +424,157 @@ def _check_store_files_at(cache_folder_alternatives):
)


@export
class GridFsInterfaceAPI(GridFsBase):
"""Interface to gridfs using the runDB API."""

def __init__(self, config_identifier: str = "config_name") -> None:
super().__init__(config_identifier=config_identifier)
self.db = DB()

def config_exists(self, config: str) -> bool:
"""Check if config is saved in the collection."""
query = self.get_query_config(config)
return self.db.count_files(query) > 0

def md5_stored(self, abs_path: str) -> bool:
"""Check if file with same MD5 is stored.

RAM intensive.

"""
if not os.path.exists(abs_path):
return False
query = {"md5": self.compute_md5(abs_path)}
return self.db.count_files(query) > 0

def test_find(self) -> None:
"""Test the connection to the collection."""
if self.db.get_files({}, projection={"_id": 1}) is None:
raise ConnectionError("Could not find any data in this collection")

def list_files(self) -> List[str]:
"""Get list of files stored in the database."""
return [
doc[self.config_identifier]
for doc in self.db.get_files({}, projection={self.config_identifier: 1})
if self.config_identifier in doc
]


@export
class APIUploader(GridFsInterfaceAPI):
"""Upload files to gridfs using the runDB API."""

def __init__(self, config_identifier: str = "config_name") -> None:
super().__init__(config_identifier=config_identifier)

def upload_single(self, config: str, abs_path: str) -> None:
"""Upload a single file to gridfs.

:param config: str, the name under which this file should be stored
:param abs_path: str, the absolute path of the file

"""
if not os.path.exists(abs_path):
raise CouldNotLoadError(f"{abs_path} does not exist")

logger.info(f"uploading file {config} from {abs_path}")
self.db.upload_file(abs_path, config)


@export
class APIDownloader(GridFsInterfaceAPI):
"""Download files from gridfs using the runDB API."""

_instances: Dict[Tuple, "APIDownloader"] = {}
_initialized: Dict[Tuple, bool] = {}

def __new__(cls, *args: Any, **kwargs: Any) -> "APIDownloader":
key = (args, frozenset(kwargs.items()))
if key not in cls._instances:
cls._instances[key] = super(APIDownloader, cls).__new__(cls)
cls._initialized[key] = False
return cls._instances[key]

def __init__(self, *args: Any, **kwargs: Any) -> None:
key = (args, frozenset(kwargs.items()))
if not self._initialized[key]:
self._instances[key].initialize(*args, **kwargs)
self._initialized[key] = True

def initialize(
self,
config_identifier: str = "config_name",
store_files_at: Optional[Union[str, Tuple[str, ...], List[str]]] = None,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(config_identifier=config_identifier)

if store_files_at is None:
store_files_at = (
"./resource_cache",
"/tmp/straxen_resource_cache",
)
elif isinstance(store_files_at, str):
store_files_at = (store_files_at,)
elif isinstance(store_files_at, list):
store_files_at = tuple(store_files_at)
elif not isinstance(store_files_at, tuple):
raise ValueError(f"{store_files_at} should be a string, list, or tuple of paths!")

self.storage_options: Tuple[str, ...] = store_files_at

def download_single(
self,
config_name: str,
write_to: Optional[str] = None,
human_readable_file_name: bool = False,
) -> str:
"""Download the config_name if it exists."""
target_file_name = (
config_name if human_readable_file_name else self.db.get_file_md5(config_name)
)

# check if self.storage_options is None or empty
if not self.storage_options:
raise ValueError("No storage options available")

if write_to is None:
for cache_folder in self.storage_options:
possible_path = os.path.join(cache_folder, target_file_name)
if os.path.exists(possible_path):
return possible_path

store_files_at = self._check_store_files_at(self.storage_options)
else:
store_files_at = write_to

# make sure store_files_at is a string
if not isinstance(store_files_at, str):
raise TypeError(f"Expected string for store_files_at, got {type(store_files_at)}")

destination_path = os.path.join(store_files_at, target_file_name)

with tempfile.TemporaryDirectory() as temp_directory_name:
temp_path = self.db.download_file(config_name, save_dir=temp_directory_name)
if not os.path.exists(destination_path):
move(temp_path, destination_path)
else:
warn(f"File {destination_path} already exists. Not overwriting.")
return destination_path

def _check_store_files_at(self, options: Union[str, Tuple[str, ...]]) -> str:
"""Check and return a valid storage location."""
if isinstance(options, str):
return options
for option in options:
if os.path.isdir(option):
return option
raise ValueError("No valid storage location found")


class DownloadWarning(UserWarning):
pass

Expand Down