Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added parameters and exception behaviour to pqdm #792

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
instead ([#766](https://github.com/nsidc/earthaccess/issues/766))([**@Sherwin-14**](https://github.com/Sherwin-14))
- Added Issue Templates([#281](https://github.com/nsidc/earthaccess/issues/281))([**@Sherwin-14**](https://github.com/Sherwin-14))

- Fixed earthaccess.download() ignoring errors ([#581](https://github.com/nsidc/earthaccess/issues/581)) ([**@Sherwin-14**](https://github.com/Sherwin-14))


## [v0.10.0] 2024-07-19
Expand Down
10 changes: 8 additions & 2 deletions earthaccess/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def download(
local_path: Optional[str],
provider: Optional[str] = None,
threads: int = 8,
fail_fast: bool = True,
) -> List[str]:
"""Retrieves data granules from a remote storage system.

Expand All @@ -201,7 +202,9 @@ def download(
elif isinstance(granules, str):
granules = [granules]
try:
results = earthaccess.__store__.get(granules, local_path, provider, threads)
results = earthaccess.__store__.get(
granules, local_path, provider, threads, fail_fast=fail_fast
)
except AttributeError as err:
logger.error(
f"{err}: You must call earthaccess.login() before you can download data"
Expand All @@ -213,6 +216,7 @@ def download(
def open(
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
fail_fast: bool = True,
) -> List[AbstractFileSystem]:
"""Returns a list of fsspec file-like objects that can be used to access files
hosted on S3 or HTTPS by third party libraries like xarray.
Expand All @@ -226,7 +230,9 @@ def open(
a list of s3fs "file pointers" to s3 files.
"""
provider = _normalize_location(provider)
results = earthaccess.__store__.open(granules=granules, provider=provider)
results = earthaccess.__store__.open(
granules=granules, provider=provider, fail_fast=fail_fast
)
return results


Expand Down
64 changes: 48 additions & 16 deletions earthaccess/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,20 @@ def _open_files(
url_mapping: Mapping[str, Union[DataGranule, None]],
fs: fsspec.AbstractFileSystem,
threads: Optional[int] = 8,
fail_fast: bool = True,
) -> List[fsspec.AbstractFileSystem]:
def multi_thread_open(data: tuple) -> EarthAccessFile:
urls, granule = data
return EarthAccessFile(fs.open(urls), granule)

fileset = pqdm(url_mapping.items(), multi_thread_open, n_jobs=threads)
exception_behavior = "immediate" if fail_fast else "deferred"

fileset = pqdm(
url_mapping.items(),
multi_thread_open,
n_jobs=threads,
exception_behaviour=exception_behavior,
)
return fileset


Expand Down Expand Up @@ -322,6 +330,7 @@ def open(
self,
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
fail_fast: bool = True,
) -> List[Any]:
"""Returns a list of fsspec file-like objects that can be used to access files
hosted on S3 or HTTPS by third party libraries like xarray.
Expand All @@ -335,14 +344,15 @@ def open(
A list of s3fs "file pointers" to s3 files.
"""
if len(granules):
return self._open(granules, provider)
return self._open(granules, provider, fail_fast=fail_fast)
return []

@singledispatchmethod
def _open(
self,
granules: Union[List[str], List[DataGranule]],
provider: Optional[str] = None,
fail_fast: bool = True,
) -> List[Any]:
"""Returns a list of fsspec file-like objects that can be used to access files
hosted on S3 or HTTPS by third party libraries like xarray.
Expand All @@ -363,6 +373,7 @@ def _open_granules(
granules: List[DataGranule],
provider: Optional[str] = None,
threads: Optional[int] = 8,
fail_fast: bool = True,
) -> List[Any]:
fileset: List = []
total_size = round(sum([granule.size() for granule in granules]) / 1024, 2)
Expand Down Expand Up @@ -393,9 +404,7 @@ def _open_granules(
if s3_fs is not None:
try:
fileset = _open_files(
url_mapping,
fs=s3_fs,
threads=threads,
url_mapping, fs=s3_fs, threads=threads, fail_fast=fail_fast
)
except Exception as e:
raise RuntimeError(
Expand All @@ -404,11 +413,15 @@ def _open_granules(
f"Exception: {traceback.format_exc()}"
) from e
else:
fileset = self._open_urls_https(url_mapping, threads=threads)
fileset = self._open_urls_https(
url_mapping, threads=threads, fail_fast=fail_fast
)
return fileset
else:
url_mapping = _get_url_granule_mapping(granules, access="on_prem")
fileset = self._open_urls_https(url_mapping, threads=threads)
fileset = self._open_urls_https(
url_mapping, threads=threads, fail_fast=fail_fast
)
return fileset

@_open.register
Expand All @@ -417,6 +430,7 @@ def _open_urls(
granules: List[str],
provider: Optional[str] = None,
threads: Optional[int] = 8,
fail_fast: bool = True,
) -> List[Any]:
fileset: List = []

Expand All @@ -441,9 +455,7 @@ def _open_urls(
if s3_fs is not None:
try:
fileset = _open_files(
url_mapping,
fs=s3_fs,
threads=threads,
url_mapping, fs=s3_fs, threads=threads, fail_fast=fail_fast
)
except Exception as e:
raise RuntimeError(
Expand All @@ -463,7 +475,7 @@ def _open_urls(
raise ValueError(
"We cannot open S3 links when we are not in-region, try using HTTPS links"
)
fileset = self._open_urls_https(url_mapping, threads)
fileset = self._open_urls_https(url_mapping, threads, fail_fast=fail_fast)
return fileset

def get(
Expand All @@ -472,6 +484,7 @@ def get(
local_path: Union[Path, str, None] = None,
provider: Optional[str] = None,
threads: int = 8,
fail_fast: bool = True,
) -> List[str]:
"""Retrieves data granules from a remote storage system.

Expand Down Expand Up @@ -500,7 +513,9 @@ def get(
local_path = Path(local_path)

if len(granules):
files = self._get(granules, local_path, provider, threads)
files = self._get(
granules, local_path, provider, threads, fail_fast=fail_fast
)
return files
else:
raise ValueError("List of URLs or DataGranule instances expected")
Expand All @@ -512,6 +527,7 @@ def _get(
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
fail_fast: bool = True,
) -> List[str]:
"""Retrieves data granules from a remote storage system.

Expand Down Expand Up @@ -541,6 +557,7 @@ def _get_urls(
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
fail_fast: bool = True,
) -> List[str]:
data_links = granules
downloaded_files: List = []
Expand All @@ -562,7 +579,9 @@ def _get_urls(

else:
# if we are not in AWS
return self._download_onprem_granules(data_links, local_path, threads)
return self._download_onprem_granules(
data_links, local_path, threads, fail_fast=fail_fast
)

@_get.register
def _get_granules(
Expand All @@ -571,6 +590,7 @@ def _get_granules(
local_path: Path,
provider: Optional[str] = None,
threads: int = 8,
fail_fast: bool = True,
) -> List[str]:
data_links: List = []
downloaded_files: List = []
Expand Down Expand Up @@ -611,7 +631,9 @@ def _get_granules(
else:
# if the data are cloud-based, but we are not in AWS,
# it will be downloaded as if it was on prem
return self._download_onprem_granules(data_links, local_path, threads)
return self._download_onprem_granules(
data_links, local_path, threads, fail_fast=fail_fast
)

def _download_file(self, url: str, directory: Path) -> str:
"""Download a single file from an on-prem location, a DAAC data center.
Expand Down Expand Up @@ -649,7 +671,7 @@ def _download_file(self, url: str, directory: Path) -> str:
return str(path)

def _download_onprem_granules(
self, urls: List[str], directory: Path, threads: int = 8
self, urls: List[str], directory: Path, threads: int = 8, fail_fast: bool = True
) -> List[Any]:
"""Downloads a list of URLS into the data directory.

Expand All @@ -658,6 +680,9 @@ def _download_onprem_granules(
directory: local directory to store the downloaded files
threads: parallel number of threads to use to download the files;
adjust as necessary, default = 8
fail_fast: if set to True, the download process will stop immediately
upon encountering the first error. If set to False, errors will be
deferred, allowing the download of remaining files to continue.

Returns:
A list of local filepaths to which the files were downloaded.
Expand All @@ -671,23 +696,30 @@ def _download_onprem_granules(
directory.mkdir(parents=True, exist_ok=True)

arguments = [(url, directory) for url in urls]

exception_behavior = "immediate" if fail_fast else "deferred"

results = pqdm(
arguments,
self._download_file,
n_jobs=threads,
argument_type="args",
exception_behaviour=exception_behavior,
)
return results

def _open_urls_https(
self,
url_mapping: Mapping[str, Union[DataGranule, None]],
threads: Optional[int] = 8,
fail_fast: bool = True,
) -> List[fsspec.AbstractFileSystem]:
https_fs = self.get_fsspec_session()
if https_fs is not None:
try:
fileset = _open_files(url_mapping, https_fs, threads)
fileset = _open_files(
url_mapping, https_fs, threads, fail_fast=fail_fast
)
except Exception:
logger.exception(
"An exception occurred while trying to access remote files via HTTPS"
Expand Down
Loading