Skip to content

Commit

Permalink
Changes made to public functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Sherwin-14 committed Sep 3, 2024
1 parent e4fbe9a commit 2afc65b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
10 changes: 8 additions & 2 deletions earthaccess/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def download(
local_path: Optional[str],
provider: Optional[str] = None,
threads: int = 8,
fail_fast: bool = True,
) -> List[str]:
"""Retrieves data granules from a remote storage system.
Expand All @@ -201,7 +202,9 @@ def download(
elif isinstance(granules, str):
granules = [granules]
try:
results = earthaccess.__store__.get(granules, local_path, provider, threads)
results = earthaccess.__store__.get(
granules, local_path, provider, threads, fail_fast=fail_fast
)
except AttributeError as err:
logger.error(
f"{err}: You must call earthaccess.login() before you can download data"
Expand All @@ -213,6 +216,7 @@ def download(
def open(
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
fail_fast: bool = True,
) -> List[AbstractFileSystem]:
"""Returns a list of fsspec file-like objects that can be used to access files
hosted on S3 or HTTPS by third party libraries like xarray.
Expand All @@ -226,7 +230,9 @@ def open(
a list of s3fs "file pointers" to s3 files.
"""
provider = _normalize_location(provider)
results = earthaccess.__store__.open(granules=granules, provider=provider)
results = earthaccess.__store__.open(
granules=granules, provider=provider, fail_fast=fail_fast
)
return results


Expand Down
25 changes: 17 additions & 8 deletions earthaccess/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def open(
self,
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
fail_fast: bool = True,
) -> List[Any]:
"""Returns a list of fsspec file-like objects that can be used to access files
hosted on S3 or HTTPS by third party libraries like xarray.
Expand All @@ -343,14 +344,15 @@ def open(
A list of s3fs "file pointers" to s3 files.
"""
if len(granules):
return self._open(granules, provider)
return self._open(granules, provider, fail_fast=fail_fast)
return []

@singledispatchmethod
def _open(
self,
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
fail_fast: bool = True,
) -> List[Any]:
"""Returns a list of fsspec file-like objects that can be used to access files
hosted on S3 or HTTPS by third party libraries like xarray.
Expand All @@ -371,6 +373,7 @@ def _open_granules(
granules: List[DataGranule],
provider: Optional[str] = None,
threads: Optional[int] = 8,
fail_fast: bool = True,
) -> List[Any]:
fileset: List = []
total_size = round(sum([granule.size() for granule in granules]) / 1024, 2)
Expand Down Expand Up @@ -401,9 +404,7 @@ def _open_granules(
if s3_fs is not None:
try:
fileset = _open_files(
url_mapping,
fs=s3_fs,
threads=threads,
url_mapping, fs=s3_fs, threads=threads, fail_fast=fail_fast
)
except Exception as e:
raise RuntimeError(
Expand All @@ -412,11 +413,15 @@ def _open_granules(
f"Exception: {traceback.format_exc()}"
) from e
else:
fileset = self._open_urls_https(url_mapping, threads=threads)
fileset = self._open_urls_https(
url_mapping, threads=threads, fail_fast=fail_fast
)
return fileset
else:
url_mapping = _get_url_granule_mapping(granules, access="on_prem")
fileset = self._open_urls_https(url_mapping, threads=threads)
fileset = self._open_urls_https(
url_mapping, threads=threads, fail_fast=fail_fast
)
return fileset

@_open.register
Expand Down Expand Up @@ -470,7 +475,7 @@ def _open_urls(
raise ValueError(
"We cannot open S3 links when we are not in-region, try using HTTPS links"
)
fileset = self._open_urls_https(url_mapping, threads)
fileset = self._open_urls_https(url_mapping, threads, fail_fast=fail_fast)
return fileset

def get(
Expand All @@ -479,6 +484,7 @@ def get(
local_path: Union[Path, str, None] = None,
provider: Optional[str] = None,
threads: int = 8,
fail_fast: bool = True,
) -> List[str]:
"""Retrieves data granules from a remote storage system.
Expand Down Expand Up @@ -507,7 +513,9 @@ def get(
local_path = Path(local_path)

if len(granules):
files = self._get(granules, local_path, provider, threads)
files = self._get(
granules, local_path, provider, threads, fail_fast=fail_fast
)
return files
else:
raise ValueError("List of URLs or DataGranule instances expected")
Expand All @@ -519,6 +527,7 @@ def _get(
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
fail_fast: bool = True,
) -> List[str]:
"""Retrieves data granules from a remote storage system.
Expand Down

0 comments on commit 2afc65b

Please sign in to comment.