Skip to content

Commit

Permalink
Update resource and test_assets_defs
Browse files Browse the repository at this point in the history
  • Loading branch information
maximearmstrong committed Aug 7, 2024
1 parent b71a036 commit e11cfb3
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@ def all_additional_request_params(self) -> Mapping[str, Any]:
raise NotImplementedError()

def make_request(
self, endpoint: str, data: Optional[Mapping[str, object]] = None, method: str = "POST"
self,
endpoint: str,
data: Optional[Mapping[str, object]] = None,
method: str = "POST",
include_additional_request_params: bool = True,
) -> Optional[Mapping[str, object]]:
"""Creates and sends a request to the desired Airbyte REST API endpoint.
Expand All @@ -122,10 +126,11 @@ def make_request(
if data:
request_args["json"] = data

request_args = deep_merge_dicts(
request_args,
self.all_additional_request_params,
)
if include_additional_request_params:
request_args = deep_merge_dicts(
request_args,
self.all_additional_request_params,
)

response = requests.request(
**request_args,
Expand Down Expand Up @@ -275,8 +280,8 @@ class AirbyteCloudResource(BaseAirbyteResource):
client_id: str = Field(..., description="The Airbyte Cloud client ID.")
client_secret: str = Field(..., description="The Airbyte Cloud client secret.")

_access_token_value: str = PrivateAttr()
_access_token_timestamp: float = PrivateAttr()
_access_token_value: Optional[str] = PrivateAttr(default=None)
_access_token_timestamp: Optional[float] = PrivateAttr(default=None)

def setup_for_execution(self, context: InitResourceContext) -> None:
# Refresh access token when the resource is initialized
Expand All @@ -288,7 +293,7 @@ def api_base_url(self) -> str:

@property
def all_additional_request_params(self) -> Mapping[str, Any]:
# Make sure the access token is refreshed before using it.
# Make sure the access token is refreshed before using it when calling the API.
if self._needs_refreshed_access_token():
self._refresh_access_token()
return {
Expand Down Expand Up @@ -333,6 +338,8 @@ def _refresh_access_token(self) -> None:
"client_id": self.client_id,
"client_secret": self.client_secret,
},
# Must not pass the bearer access token when refreshing it.
include_additional_request_params=False,
)
)
self._access_token_value = str(response["access_token"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ def test_assets_with_normalization(


def test_assets_cloud() -> None:
ab_resource = AirbyteCloudResource(api_key="some_key", poll_interval=0)
ab_resource = AirbyteCloudResource(
client_id="some_client_id", client_secret="some_client_secret", poll_interval=0
)
ab_url = ab_resource.api_base_url

ab_assets = build_airbyte_assets(
Expand Down Expand Up @@ -236,6 +238,11 @@ def test_assets_cloud() -> None:
f"{ab_url}/jobs/1",
json={"jobId": 1, "status": "succeeded", "jobType": "sync"},
)
rsps.add(
rsps.POST,
f"{ab_url}/applications/token",
json={"access_token": "some_access_token"},
)

res = materialize_to_memory(
ab_assets,
Expand Down

0 comments on commit e11cfb3

Please sign in to comment.