Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Thread-local Session Management and Cookie Reuse to Address EDL DSE issue #909

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
36 changes: 32 additions & 4 deletions earthaccess/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pickle import dumps, loads
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
from uuid import uuid4
import threading

import fsspec
import requests
Expand All @@ -18,7 +19,7 @@

import earthaccess

from .auth import Auth
from .auth import Auth, SessionWithHeaderRedirection
from .daac import DAAC_TEST_URLS, find_provider
from .results import DataGranule
from .search import DataCollections
Expand Down Expand Up @@ -119,6 +120,7 @@ def __init__(self, auth: Any, pre_authorize: bool = False) -> None:
Parameters:
auth: Auth instance to download and access data.
"""
self.thread_locals = threading.local()
if auth.authenticated is True:
self.auth = auth
self._s3_credentials: Dict[
Expand All @@ -127,7 +129,7 @@ def __init__(self, auth: Any, pre_authorize: bool = False) -> None:
oauth_profile = f"https://{auth.system.edl_hostname}/profile"
# sets the initial URS cookie
self._requests_cookies: Dict[str, Any] = {}
self.set_requests_session(oauth_profile)
self.set_requests_session(oauth_profile, bearer_token=True)
if pre_authorize:
# collect cookies from other DAACs
for url in DAAC_TEST_URLS:
Expand Down Expand Up @@ -336,7 +338,10 @@ def get_requests_session(self, bearer_token: bool = True) -> requests.Session:
Returns:
requests Session
"""
return self.auth.get_session()
if hasattr(self, "_http_session"):
return self._http_session
else:
raise AttributeError("The requests session hasn't been set up yet.")

def open(
self,
Expand Down Expand Up @@ -652,6 +657,27 @@ def _get_granules(
data_links, local_path, pqdm_kwargs=pqdm_kwargs
)

def _clone_session_in_local_thread(
self, original_session: SessionWithHeaderRedirection
) -> None:
"""Clone the original session and store it in the local thread context.

This method creates a new session that replicates the headers, cookies, and authentication settings
from the provided original session. The new session is stored in a thread-local storage.

Parameters:
original_session (SessionWithHeaderRedirection): The session to be cloned.

Returns:
None
"""
if not hasattr(self.thread_locals, "local_thread_session"):
local_thread_session = SessionWithHeaderRedirection()
local_thread_session.headers.update(original_session.headers)
local_thread_session.cookies.update(original_session.cookies)
local_thread_session.auth = original_session.auth
self.thread_locals.local_thread_session = local_thread_session

def _download_file(self, url: str, directory: Path) -> str:
"""Download a single file from an on-prem location, a DAAC data center.

Expand All @@ -669,7 +695,9 @@ def _download_file(self, url: str, directory: Path) -> str:
path = directory / Path(local_filename)
if not path.exists():
try:
session = self.auth.get_session()
original_session = self.get_requests_session()
self._clone_session_in_local_thread(original_session)
session = self.thread_locals.local_thread_session
with session.get(
url,
stream=True,
Expand Down
Loading