diff --git a/ravenml/data/commands.py b/ravenml/data/commands.py
index ff3ee77..b98f09f 100644
--- a/ravenml/data/commands.py
+++ b/ravenml/data/commands.py
@@ -10,7 +10,6 @@
 import yaml
 import shutil
 from pkg_resources import iter_entry_points
-from click_plugins import with_plugins
 from colorama import Fore
 from pathlib import Path
 from ravenml.utils.imageset import get_imageset_names, get_imageset_metadata
diff --git a/ravenml/data/helpers.py b/ravenml/data/helpers.py
index 85b30a7..5aed205 100644
--- a/ravenml/data/helpers.py
+++ b/ravenml/data/helpers.py
@@ -296,7 +296,7 @@ def read_json_metadata(dir_entry, image_id):
         dataframe with image_id key and True/False values for each tag.
     """
-    with open(dir_entry.path, "r") as read_file:
+    with open(dir_entry, "r") as read_file:
         data = json.load(read_file)
 
     tag_list = data.get("tags", ['untagged'])
     if len(tag_list) == 0:
diff --git a/ravenml/data/interfaces.py b/ravenml/data/interfaces.py
index cf467c5..d2e021e 100644
--- a/ravenml/data/interfaces.py
+++ b/ravenml/data/interfaces.py
@@ -10,6 +10,7 @@
 import os
 import shutil
 import json
+import asyncio
 from pathlib import Path
 from datetime import datetime
 from ravenml.utils.local_cache import RMLCache
@@ -94,7 +95,16 @@ def __init__(self, config:dict=None, plugin_name:str=None):
             ## Download imagesets
             self.imageset_cache.ensure_subpath_exists('imagesets')
             self.imageset_paths = []
-            self.download_imagesets(imageset_list)
+            if config.get("download_full_imagesets"):
+                self.download_imagesets(imageset_list)
+                self.lazy_loading = False
+            else:
+                # Lazy loading: only create the local cache directories now;
+                # files are downloaded on demand during dataset creation.
+                for imageset in imageset_list:
+                    self.imageset_cache.ensure_subpath_exists(f'imagesets/{imageset}')
+                    self.imageset_paths.append(self.imageset_cache.path / 'imagesets' / imageset)
+                self.lazy_loading = True
         # local imagesets
         else:
             imageset_paths = config.get('imageset')
diff --git a/ravenml/data/write_dataset.py b/ravenml/data/write_dataset.py
index dc05f92..23ec788 100644
--- a/ravenml/data/write_dataset.py
+++ b/ravenml/data/write_dataset.py
@@ -1,4 +1,4 @@
-import os, shutil, time, json
+import os, shutil, time, json, asyncio
 import pandas as pd
 from random import sample
 from pathlib import Path
@@ -7,6 +7,7 @@
 from ravenml.utils.question import cli_spinner, cli_spinner_wrapper, DecoratorSuperClass, user_input
 from ravenml.utils.config import get_config
 from ravenml.data.helpers import default_filter, copy_associated_files, split_data, read_json_metadata
+from ravenml.utils.aws import conditional_download, download_file_list
 
 class DatasetWriter(DecoratorSuperClass):
     """Interface for creating datasets, methods are in order of what is expected to be
@@ -46,7 +47,8 @@ def __init__(self, create: CreateInput, **kwargs):
         created_by (String): name of person creating dataset
         comments (String): comments on dataset
         plugin_name (String): name of the plugin being used
-        imageset_paths (list): list of paths to all imagesets being used
+        imageset_paths (list): list of paths to all imagesets being used;
+            empty when lazy loading is enabled.
         tags_df (pandas dataframe): after load_image_ids() is run, holds tags
             associated with each image_id
         image_ids (list): list of tuples containing a path to an imageset
@@ -73,6 +75,8 @@ def __init__(self, create: CreateInput, **kwargs):
         self.filter_metadata = {"groups": []}
         self.obj_dict = {}
         self.metadata_format = None
+        self.lazy_loading = create.lazy_loading
+        self.imageset_cache = create.imageset_cache
 
     @cli_spinner_wrapper("Loading Image Ids...")
     def load_image_ids(self):
@@ -194,6 +198,21 @@ def load_image_ids(self, metadata_format: tuple):
         self.metadata_format = metadata_format
         metadata_prefix = metadata_format[0]
         metadata_suffix = metadata_format[1]
+
+        # Handle the lazy loading case
+        if self.lazy_loading:
+            bucketConfig = get_config()
+            image_bucket_name = bucketConfig.get('image_bucket_name')
+            metadata_cond = lambda x: x.startswith(metadata_prefix) and x.endswith(metadata_suffix)
+            loop = asyncio.get_event_loop()
+
+            # Download only the metadata files so image ids can be enumerated
+            for imageset in self.imageset_paths:
+                loop.run_until_complete(conditional_download(image_bucket_name,
+                                                             os.path.basename(imageset),
+                                                             self.imageset_cache.path / 'imagesets',
+                                                             metadata_cond))
+
         if metadata_suffix != '.json':
             raise Exception("Currently non-json metadata files are not supported for the default loading of image ids")
 
@@ -206,18 +225,19 @@ def load_image_ids(self, metadata_format: tuple):
                 image_id = dir_entry.name.replace(metadata_prefix, '').replace(metadata_suffix, '')
                 self.image_ids.append((data_dir, image_id))
 
-    def set_size_filter(self, set_sizes: dict=None):
+    def set_size_filter(self, set_sizes: dict=None, associated_files: list=[]):
         """Method is expected to only be called after 'load_image_ids' is called, as it relies
         on 'self.image_ids' to be prepopulated. Method filters by choosing specified amount of images
-        from each imageset.
+        from each imageset. If lazy loading is enabled, also downloads each image's associated files.
 
         If overridden, method is expected to set 'self.image_ids' to whatever image_ids are still
         to be used after filtering. 'self.filter_metadata' also needs to be set to a dict containing
         imageset names as keys and lists of image_ids as values.
 
         Args:
-            set_sizes (dict): contains the amoung of images from each imageset in the following format,
+            set_sizes (dict): contains the amount of images from each imageset in the following format,
                 { imageset_name (str) : num_images (int) }
+            associated_files (list): (prefix, suffix) file format tuples for all files associated with an image id
         Variables Needed:
             image_ids (list): needed for filtering
         """
@@ -241,6 +261,17 @@ def set_size_filter(self, set_sizes: dict=None):
         # Updates image_ids with the new information
         self.image_ids = filtered_image_ids
 
+        # With lazy loading, the images and related files can be downloaded now that filtering is complete
+        if self.lazy_loading:
+            bucketConfig = get_config()
+            image_bucket_name = bucketConfig.get('image_bucket_name')
+
+            # Build an (s3 key, local path) pair for every file associated with each surviving image id
+            files_to_download = [(os.path.basename(image_id[0]) + '/' + file_format[0] + image_id[1] + file_format[1],
+                                  str(image_id[0]) + '/' + file_format[0] + image_id[1] + file_format[1])
+                                 for image_id in self.image_ids for file_format in associated_files]
+
+            cli_spinner("Downloading Files...", download_file_list, image_bucket_name, files_to_download)
+
     def interactive_tag_filter(self):
         """Method is expected to only be called after 'load_image_ids' is called, as it relies
         on 'self.image_ids' to be prepopulated. Method prompts user through interactive filtering
@@ -262,25 +293,11 @@ def interactive_tag_filter(self):
             imageset_to_image_ids_dict[os.path.basename(image_id[0])].append(image_id)
 
         for image_id in self.image_ids:
-            temp = read_json_metadata(image_id[0] / f'{image_id[1]}{self.metadata_format[1]}', image_id[1])
+            temp = read_json_metadata(image_id[0] / f'{self.metadata_format[0]}{image_id[1]}{self.metadata_format[1]}', image_id[1])
             self.tags_df = pd.concat((self.tags_df, temp), sort=False)
         self.tags_df = self.tags_df.fillna(False)
 
         self.image_ids = default_filter(self.tags_df, self.filter_metadata)
 
-    def load_data(self):
-        """Method is expected to be called after 'load_image_ids' and filtering methods if filtering is
-        desired. Method goes through each image_id and copies its corresponing files into a temp directory
-        which will be later used by the plugin to create their dataset.
-
-        If overloaded, method is expected to copy all files the plugin needs into the provided 'temp_dir'.
-
-        Variables Needed:
-            image_ids (list): needed to find what needs to be copied (provided by 'load_image_ids'/filtering)
-            temp_dir (Path): needed to know where to copy to (provided by 'create' input)
-            associated_files (dict): needed to know what files need to be copied (provided by plugin)
-        """
-        copy_associated_files(self.image_ids, self.temp_dir, self.associated_files)
-
     def write_metadata(self):
         """Method writes out metadata in JSON format in file 'metadata.json', in root
         directory of dataset.
diff --git a/ravenml/utils/aws.py b/ravenml/utils/aws.py
index 26dca93..28c016e 100644
--- a/ravenml/utils/aws.py
+++ b/ravenml/utils/aws.py
@@ -1,7 +1,10 @@
 import os
+import aioboto3
 import boto3
 import json
 import subprocess
+import asyncio
 from pathlib import Path
+from botocore.exceptions import ClientError
 from ravenml.utils.config import get_config
 from ravenml.utils.local_cache import RMLCache
@@ -54,6 +56,79 @@ def download_prefix(bucket_name: str, prefix: str, cache: RMLCache, custom_path:
     except:
         return False
 
+def download_imageset_file(prefix: str, local_path: str):
+    """Downloads a single file into the provided location. Meant for plugins to
+    download imageset-wide required files, not images or related information.
+
+    Args:
+        prefix (str): path to s3 file
+        local_path (str): local path where the file is downloaded to
+
+    Returns:
+        bool: T if successful, F if the download failed
+    """
+    bucketConfig = get_config()
+    image_bucket_name = bucketConfig.get('image_bucket_name')
+    try:
+        s3_uri = 's3://' + image_bucket_name + '/' + prefix
+        # aws s3 cp signals failure through a nonzero exit code rather than raising
+        return subprocess.call(["aws", "s3", "cp", s3_uri, str(local_path), '--quiet']) == 0
+    except:
+        return False
+
+async def conditional_download(bucket_name, prefix, local_path, cond_func=lambda x: True):
+    """Downloads all files with the specified prefix into the provided local cache,
+    subject to a condition on the file name.
+
+    Args:
+        bucket_name (str): name of bucket
+        prefix (str): prefix to filter on
+        local_path (Path): local cache directory to download into
+        cond_func (function, optional): boolean function on the file name
+            specifying which files to download
+
+    Returns:
+        bool: T if successful, F if the download failed
+    """
+    try:
+        async with aioboto3.resource("s3") as s3:
+            bucket = await s3.Bucket(bucket_name)
+            async for s3_object in bucket.objects.filter(Prefix=prefix+"/"):
+                # skip files that are already cached locally
+                if cond_func(s3_object.key.split('/')[-1]) and not os.path.exists(local_path / s3_object.key):
+                    await bucket.download_file(s3_object.key, str(local_path / s3_object.key))
+    except:
+        return False
+    return True
+
+def download_file_list(bucket_name, file_list):
+    """Downloads each file in the provided list of (s3_key, local_path) pairs,
+    skipping files that already exist locally.
+
+    Args:
+        bucket_name (str): name of bucket
+        file_list (list): list of files in format [(s3_key, local_path)] for
+            what to download
+
+    Returns:
+        bool: T if successful, F if a download failed
+    """
+    async def download_helper(bucket_name, file_list):
+        async with aioboto3.resource("s3") as s3:
+            bucket = await s3.Bucket(bucket_name)
+            for f in file_list:
+                if not os.path.exists(f[1]):
+                    try:
+                        await bucket.download_file(f[0], f[1])
+                    except ClientError as e:
+                        # a missing object (404) is tolerated; anything else is a failure
+                        if e.response['Error']['Code'] != '404':
+                            return False
+        return True
+
+    loop = asyncio.get_event_loop()
+    return loop.run_until_complete(download_helper(bucket_name, file_list))
+
 ### UPLOAD FUNCTIONS ###
 def upload_file_to_s3(prefix: str, file_path: Path, alternate_name=None):
     """Uploads file at given file path to model bucket on S3.
diff --git a/requirements.txt b/requirements.txt
index 4c27acf..ec535ee 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ shortuuid==0.5.0
 halo==0.0.26
 colorama==0.3.9
 pyaml==19.4.1
+aioboto3==8.0.5
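
For review context, a minimal sketch of how the two new helpers compose into the lazy-loading flow that write_dataset.py implements. All concrete names below (bucket, imageset, prefixes, ids, cache path) are placeholder assumptions, not values from this changeset; in ravenML they come from get_config() and the plugin's CreateInput.

import asyncio
from pathlib import Path

from ravenml.utils.aws import conditional_download, download_file_list

# Placeholder values for illustration only.
bucket_name = 'example-imageset-bucket'
imageset = 'example_imageset'
cache_root = Path('/tmp/ravenml/imagesets')
(cache_root / imageset).mkdir(parents=True, exist_ok=True)

# Step 1: download only the JSON metadata files for the imageset so the
# image ids can be enumerated without pulling down any images.
metadata_cond = lambda name: name.startswith('meta_') and name.endswith('.json')
loop = asyncio.get_event_loop()
loop.run_until_complete(
    conditional_download(bucket_name, imageset, cache_root, metadata_cond))

# Step 2: after filtering down to the desired ids, fetch each id's
# associated files as (s3 key, local path) pairs.
image_ids = ['pic_001', 'pic_002']                  # assumed ids surviving the filter
associated_files = [('image_', '.png'), ('meta_', '.json')]
files_to_download = [
    (f'{imageset}/{prefix}{image_id}{suffix}',
     str(cache_root / imageset / f'{prefix}{image_id}{suffix}'))
    for image_id in image_ids
    for prefix, suffix in associated_files
]
download_file_list(bucket_name, files_to_download)

Downloading metadata first keeps the expensive image transfers behind the size and tag filters, so only files belonging to surviving image ids ever leave S3.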