Allow partial download of imagesets #52

Open · wants to merge 3 commits into master
1 change: 0 additions & 1 deletion ravenml/data/commands.py
@@ -10,7 +10,6 @@
import yaml
import shutil
from pkg_resources import iter_entry_points
from click_plugins import with_plugins
from colorama import Fore
from pathlib import Path
from ravenml.utils.imageset import get_imageset_names, get_imageset_metadata
2 changes: 1 addition & 1 deletion ravenml/data/helpers.py
@@ -296,7 +296,7 @@ def read_json_metadata(dir_entry, image_id):
dataframe with image_id key and True/False values
for each tag.
"""
with open(dir_entry.path, "r") as read_file:
with open(dir_entry, "r") as read_file:
data = json.load(read_file)
tag_list = data.get("tags", ['untagged'])
if len(tag_list) == 0:
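Note on the helpers.py change: `read_json_metadata` now receives a plain path rather than an `os.DirEntry`, so callers build the metadata path themselves (as the updated `interactive_tag_filter` call later in this diff does). A minimal sketch of the new call shape; the imageset directory, metadata prefix, and image id are placeholders:

```python
from pathlib import Path
from ravenml.data.helpers import read_json_metadata

# Assumed local imageset layout: <imageset_dir>/<prefix><image_id>.json
imageset_dir = Path("/tmp/imagesets/imageset_a")          # placeholder directory
image_id = "0001"                                         # placeholder image id
metadata_path = imageset_dir / f"meta_{image_id}.json"    # prefix "meta_" is assumed

# Returns a one-row dataframe keyed by image_id with True/False tag columns.
tags_df = read_json_metadata(metadata_path, image_id)
```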
12 changes: 11 additions & 1 deletion ravenml/data/interfaces.py
@@ -10,6 +10,7 @@
import os
import shutil
import json
import asyncio
from pathlib import Path
from datetime import datetime
from ravenml.utils.local_cache import RMLCache
@@ -94,7 +95,16 @@ def __init__(self, config:dict=None, plugin_name:str=None):
## Download imagesets
self.imageset_cache.ensure_subpath_exists('imagesets')
self.imageset_paths = []
self.download_imagesets(imageset_list)
if config.get("download_full_imagesets") and config["download_full_imagesets"]:
self.download_imagesets(imageset_list)
self.lazy_loading = False
else:
self.imageset_cache.ensure_subpath_exists('imagesets/')
for imageset in imageset_list:
self.imageset_cache.ensure_subpath_exists(f'imagesets/{imageset}')
self.imageset_paths.append(self.imageset_cache.path / 'imagesets' / imageset)
self.lazy_loading = True

# local imagesets
else:
imageset_paths = config.get('imageset')
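A standalone sketch of what the new branch in `CreateInput.__init__` does: when `download_full_imagesets` is falsy, only the per-imageset cache directories are created and `lazy_loading` is set, leaving the actual S3 downloads to the `DatasetWriter`. The key names mirror the diff; the helper function and cache layout below are assumptions for illustration:

```python
from pathlib import Path

def resolve_imageset_paths(config: dict, imageset_list: list, cache_root: Path):
    """Sketch of the eager vs. lazy decision added to CreateInput.__init__."""
    imageset_paths = []
    if config.get("download_full_imagesets"):
        # Eager path: every imageset is downloaded up front (download_imagesets in the diff).
        lazy_loading = False
        # ... download each imageset into cache_root / "imagesets" / name ...
    else:
        # Lazy path: only create the per-imageset cache directories now; metadata and
        # images are fetched later by the DatasetWriter as they are needed.
        lazy_loading = True
        for name in imageset_list:
            path = cache_root / "imagesets" / name
            path.mkdir(parents=True, exist_ok=True)
            imageset_paths.append(path)
    return imageset_paths, lazy_loading
```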
57 changes: 37 additions & 20 deletions ravenml/data/write_dataset.py
@@ -1,4 +1,4 @@
import os, shutil, time, json
import os, shutil, time, json, asyncio
import pandas as pd
from random import sample
from pathlib import Path
@@ -7,6 +7,7 @@
from ravenml.utils.question import cli_spinner, cli_spinner_wrapper, DecoratorSuperClass, user_input
from ravenml.utils.config import get_config
from ravenml.data.helpers import default_filter, copy_associated_files, split_data, read_json_metadata
from ravenml.utils.aws import conditional_download, download_file_list

class DatasetWriter(DecoratorSuperClass):
"""Interface for creating datasets, methods are in order of what is expected to be
@@ -46,7 +47,8 @@ def __init__(self, create: CreateInput, **kwargs):
created_by (String): name of person creating dataset
comments (String): comments on dataset
plugin_name (String): name of the plugin being used
imageset_paths (list): list of paths to all imagesets being used
imageset_paths (list): list of paths to all imagesets being used;
empty when lazy loading is enabled.
tags_df (pandas dataframe): after load_image_ids() is run, holds
tags associated with each image_id
image_ids (list): list of tuples containing a path to an imageset
@@ -73,6 +75,8 @@ def __init__(self, create: CreateInput, **kwargs):
self.filter_metadata = {"groups": []}
self.obj_dict = {}
self.metadata_format = None
self.lazy_loading = create.lazy_loading
self.imageset_cache = create.imageset_cache

@cli_spinner_wrapper("Loading Image Ids...")
def load_image_ids(self):
@@ -194,6 +198,21 @@ def load_image_ids(self, metadata_format: tuple):
self.metadata_format = metadata_format
metadata_prefix = metadata_format[0]
metadata_suffix = metadata_format[1]

# Handle lazy loading case
if self.lazy_loading:
bucketConfig = get_config()
image_bucket_name = bucketConfig.get('image_bucket_name')
metadata_cond = lambda x : x.startswith(metadata_prefix) and x.endswith(metadata_suffix)
loop = asyncio.get_event_loop()

# Download all metadata files in order to enumerate image ids
for imageset in self.imageset_paths:
loop.run_until_complete(conditional_download(image_bucket_name,
os.path.basename(imageset),
self.imageset_cache.path / 'imagesets',
metadata_cond))

if metadata_suffix != '.json':
raise Exception("Currently non-json metadata files are not supported for the default loading of image ids")

@@ -206,18 +225,19 @@ def load_image_ids(self, metadata_format: tuple):
image_id = dir_entry.name.replace(metadata_prefix, '').replace(metadata_suffix, '')
self.image_ids.append((data_dir, image_id))

def set_size_filter(self, set_sizes: dict=None):
def set_size_filter(self, set_sizes: dict=None, associated_files: list=[]):
"""Method is expected to only be called after 'load_image_ids' is called, as it relies on
'self.image_ids' to be prepopulated. Method filters by choosing specified amount of images
from each imageset.
from each imageset. If lazy loading is enabled, also downloads associated files.

If overridden, method is expected to set 'self.image_ids' to whatever image_ids are still
to be used after filtering. 'self.filter_metadata' also needs to be set to a dict containing
imageset names as keys and lists of image_ids as values.

Args:
set_sizes (dict): contains the amoung of images from each imageset in the following format,
set_sizes (dict): contains the amount of images from each imageset in the following format,
{ imageset_name (str) : num_images (int) }
associated_files (list): contains file formats for all files associated with an image id
Variables Needed:
image_ids (list): needed for filtering
"""
@@ -241,6 +261,17 @@ def set_size_filter(self, set_sizes: dict=None):
# Updates image_ids with the new information
self.image_ids = filtered_image_ids

# If using lazy loading, after this filtering is complete, the images and related files can now be downloaded
if self.lazy_loading:
bucketConfig = get_config()
image_bucket_name = bucketConfig.get('image_bucket_name')

# Create list of s3 uri and local path for each related file
files_to_download = [(os.path.basename(image_id[0]) + '/' + file_format[0] + image_id[1] + file_format[1],
str(image_id[0]) + '/' + file_format[0] + image_id[1] + file_format[1]) for image_id in self.image_ids for file_format in associated_files]

cli_spinner("Downloading Files...", download_file_list, image_bucket_name, files_to_download)

def interactive_tag_filter(self):
"""Method is expected to only be called after 'load_image_ids' is called, as it relies on
'self.image_ids' to be prepopulated. Method prompts user through interactive filtering
@@ -262,25 +293,11 @@ def interactive_tag_filter(self):
imageset_to_image_ids_dict[os.path.basename(image_id[0])].append(image_id)

for image_id in self.image_ids:
temp = read_json_metadata(image_id[0] / f'{image_id[1]}{self.metadata_format[1]}', image_id[1])
temp = read_json_metadata(image_id[0] / f'{self.metadata_format[0]}{image_id[1]}{self.metadata_format[1]}', image_id[1])
self.tags_df = pd.concat((self.tags_df, temp), sort=False)
self.tags_df = self.tags_df.fillna(False)
self.image_ids = default_filter(self.tags_df, self.filter_metadata)

def load_data(self):
"""Method is expected to be called after 'load_image_ids' and filtering methods if filtering is
desired. Method goes through each image_id and copies its corresponding files into a temp directory
which will be later used by the plugin to create their dataset.

If overloaded, method is expected to copy all files the plugin needs into the provided 'temp_dir'.

Variables Needed:
image_ids (list): needed to find what needs to be copied (provided by 'load_image_ids'/filtering)
temp_dir (Path): needed to know where to copy to (provided by 'create' input)
associated_files (dict): needed to know what files need to be copied (provided by plugin)
"""
copy_associated_files(self.image_ids, self.temp_dir, self.associated_files)

def write_metadata(self):
"""Method writes out metadata in JSON format in file 'metadata.json',
in root directory of dataset.
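Taken together, the write_dataset.py changes mean a lazy-loading run downloads only metadata in `load_image_ids` and fetches images only after `set_size_filter` has picked the ids to keep. A rough usage sketch from the plugin side; the writer subclass name, metadata prefix/suffix, and associated file formats are assumptions:

```python
# Hypothetical plugin-side flow with lazy loading enabled.
writer = MyPluginDatasetWriter(create_input)               # assumed DatasetWriter subclass

# Downloads only the metadata files for each imageset, then enumerates image ids.
writer.load_image_ids(("meta_", ".json"))                  # (prefix, suffix) assumed

# Filters ids per imageset, then downloads just the files backing the kept ids.
writer.set_size_filter(
    set_sizes={"imageset_a": 100, "imageset_b": 50},
    associated_files=[("image_", ".png"), ("meta_", ".json")],  # (prefix, suffix) pairs
)
```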
73 changes: 73 additions & 0 deletions ravenml/utils/aws.py
@@ -1,7 +1,9 @@
import os
import aioboto3
import boto3
import json
import subprocess
import asyncio
from pathlib import Path
from ravenml.utils.config import get_config
from ravenml.utils.local_cache import RMLCache
@@ -54,6 +56,77 @@ def download_prefix(bucket_name: str, prefix: str, cache: RMLCache, custom_path:
except:
return False

def download_imageset_file(prefix: str, local_path: str):
"""Downloads file into the provided location. Meant for plugins to
download imageset-wide required files, not images or related information

Args:
prefix (str): path to s3 file
local_path (str): local path to download the file to

Returns:
bool: T if successful, F on error
"""
bucketConfig = get_config()
image_bucket_name = bucketConfig.get('image_bucket_name')
try:
s3_uri = 's3://' + image_bucket_name + '/' + prefix

subprocess.call(["aws", "s3", "cp", s3_uri, str(local_path), '--quiet'])
return True
except:
return False

async def conditional_download(bucket_name, prefix, local_path, cond_func = lambda x: True):
"""Downloads all files with the specified prefix into the provided local cache based on a condition.

Args:
bucket_name (str): name of bucket
prefix (str): prefix to filter on
local_path (Path): local directory to download into
cond_func (function, optional): boolean function specifying which
files to download

Returns:
bool: T if successful, F on error
"""
try:
async with aioboto3.resource("s3") as s3:
bucket = await s3.Bucket(bucket_name)
async for s3_object in bucket.objects.filter(Prefix=prefix+"/"):
if cond_func(s3_object.key.split('/')[-1]) and not os.path.exists(local_path / s3_object.key):
await bucket.download_file(s3_object.key, local_path / s3_object.key)
except:
return False
return True

def download_file_list(bucket_name, file_list):
"""Downloads all files with the specified prefix into the provided local cache based on a condition.

Args:
bucket_name (str): name of bucket
file_list (list): list of files in format [(s3_key, local_path)] specifying
what to download

Returns:
bool: T if successful, F if a download fails
"""
async def download_helper(bucket_name, file_list):
async with aioboto3.resource("s3") as s3:
bucket = await s3.Bucket(bucket_name)
for f in file_list:
if not os.path.exists(f[1]):
try:
await bucket.download_file(f[0], f[1])
except Exception as e:
if e.response['Error']['Code'] != '404':
return False
return True

loop = asyncio.get_event_loop()
return loop.run_until_complete(download_helper(bucket_name, file_list))

### UPLOAD FUNCTIONS ###
def upload_file_to_s3(prefix: str, file_path: Path, alternate_name=None):
"""Uploads file at given file path to model bucket on S3.
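For completeness, a minimal sketch of driving the two new helpers directly; the bucket name, prefixes, and local paths are placeholders, and the target directories are assumed to already exist (the lazy branch in CreateInput creates them):

```python
import asyncio
from pathlib import Path
from ravenml.utils.aws import conditional_download, download_file_list

bucket = "example-image-bucket"                     # placeholder bucket name
cache_dir = Path("/tmp/rml_cache/imagesets")        # placeholder cache root

# 1) Pull only the JSON metadata for one imageset (what load_image_ids does).
is_metadata = lambda name: name.startswith("meta_") and name.endswith(".json")
asyncio.get_event_loop().run_until_complete(
    conditional_download(bucket, "imageset_a", cache_dir, is_metadata)
)

# 2) Later, fetch just the filtered files as (s3_key, local_path) pairs
#    (what set_size_filter does after filtering).
files = [
    ("imageset_a/image_0001.png", str(cache_dir / "imageset_a" / "image_0001.png")),
    ("imageset_a/meta_0001.json", str(cache_dir / "imageset_a" / "meta_0001.json")),
]
download_file_list(bucket, files)
```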
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,3 +5,4 @@ shortuuid==0.5.0
halo==0.0.26
colorama==0.3.9
pyaml==19.4.1
aioboto3==8.0.5