Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
provide ability to cache boto client instances directly and on `S3Buc…
Browse files Browse the repository at this point in the history
…ket` (#369)
  • Loading branch information
zzstoatzz authored Jan 19, 2024
1 parent eeb3252 commit f697760
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 12 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Handle `boto3` clients more efficiently with `lru_cache` - [#361](https://github.com/PrefectHQ/prefect-aws/pull/361)

### Fixed

### Deprecated
Expand Down Expand Up @@ -105,6 +107,7 @@ Released August 31st, 2023.
Released July 20th, 2023.

### Changed

- Promoted workers to GA, removed beta disclaimers

## 0.3.5
Expand Down Expand Up @@ -293,6 +296,7 @@ Released on October 28th, 2022.
- `ECSTask` is no longer experimental — [#137](https://github.com/PrefectHQ/prefect-aws/pull/137)

### Fixed

- Fix ignore_file option in `S3Bucket` skipping files which should be included — [#139](https://github.com/PrefectHQ/prefect-aws/pull/139)
- Fixed bug where `basepath` is used twice in the path when using `S3Bucket.put_directory` - [#143](https://github.com/PrefectHQ/prefect-aws/pull/143)

Expand Down
12 changes: 12 additions & 0 deletions prefect_aws/client_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ class AwsClientParameters(BaseModel):
title="Botocore Config",
)

def __hash__(self):
return hash(
(
self.api_version,
self.use_ssl,
self.verify,
self.verify_cert_path,
self.endpoint_url,
self.config,
)
)

@validator("config", pre=True)
def instantiate_config(cls, value: Union[Config, Dict[str, Any]]) -> Dict[str, Any]:
"""
Expand Down
76 changes: 66 additions & 10 deletions prefect_aws/credentials.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Module handling AWS credentials"""

from enum import Enum
from functools import lru_cache
from threading import Lock
from typing import Any, Optional, Union

import boto3
Expand All @@ -16,14 +18,43 @@

from prefect_aws.client_parameters import AwsClientParameters

_LOCK = Lock()


class ClientType(Enum):
"""The supported boto3 clients."""

S3 = "s3"
ECS = "ecs"
BATCH = "batch"
SECRETS_MANAGER = "secretsmanager"


@lru_cache(maxsize=8, typed=True)
def _get_client_cached(ctx, client_type: Union[str, ClientType]) -> Any:
"""
Helper method to cache and dynamically get a client type.
Args:
client_type: The client's service name.
Returns:
An authenticated client.
Raises:
ValueError: if the client is not supported.
"""
with _LOCK:
if isinstance(client_type, ClientType):
client_type = client_type.value

client = ctx.get_boto3_session().client(
service_name=client_type,
**ctx.aws_client_parameters.get_params_override(),
)
return client


class AwsCredentials(CredentialsBlock):
"""
Block used to manage authentication with AWS. AWS authentication is
Expand Down Expand Up @@ -75,6 +106,22 @@ class AwsCredentials(CredentialsBlock):
title="AWS Client Parameters",
)

class Config:
"""Config class for pydantic model."""

arbitrary_types_allowed = True

def __hash__(self):
field_hashes = (
hash(self.aws_access_key_id),
hash(self.aws_secret_access_key),
hash(self.aws_session_token),
hash(self.profile_name),
hash(self.region_name),
hash(frozenset(self.aws_client_parameters.dict().items())),
)
return hash(field_hashes)

def get_boto3_session(self) -> boto3.Session:
"""
Returns an authenticated boto3 session that can be used to create clients
Expand Down Expand Up @@ -104,7 +151,7 @@ def get_boto3_session(self) -> boto3.Session:
region_name=self.region_name,
)

def get_client(self, client_type: Union[str, ClientType]) -> Any:
def get_client(self, client_type: Union[str, ClientType]):
"""
Helper method to dynamically get a client type.
Expand All @@ -120,10 +167,7 @@ def get_client(self, client_type: Union[str, ClientType]) -> Any:
if isinstance(client_type, ClientType):
client_type = client_type.value

client = self.get_boto3_session().client(
service_name=client_type, **self.aws_client_parameters.get_params_override()
)
return client
return _get_client_cached(ctx=self, client_type=client_type)

def get_s3_client(self) -> S3Client:
"""
Expand Down Expand Up @@ -186,6 +230,21 @@ class MinIOCredentials(CredentialsBlock):
description="Extra parameters to initialize the Client.",
)

class Config:
"""Config class for pydantic model."""

arbitrary_types_allowed = True

def __hash__(self):
return hash(
(
hash(self.minio_root_user),
hash(self.minio_root_password),
hash(self.region_name),
hash(frozenset(self.aws_client_parameters.dict().items())),
)
)

def get_boto3_session(self) -> boto3.Session:
"""
Returns an authenticated boto3 session that can be used to create clients
Expand Down Expand Up @@ -218,7 +277,7 @@ def get_boto3_session(self) -> boto3.Session:
region_name=self.region_name,
)

def get_client(self, client_type: Union[str, ClientType]) -> Any:
def get_client(self, client_type: Union[str, ClientType]):
"""
Helper method to dynamically get a client type.
Expand All @@ -234,10 +293,7 @@ def get_client(self, client_type: Union[str, ClientType]) -> Any:
if isinstance(client_type, ClientType):
client_type = client_type.value

client = self.get_boto3_session().client(
service_name=client_type, **self.aws_client_parameters.get_params_override()
)
return client
return _get_client_cached(ctx=self, client_type=client_type)

def get_s3_client(self) -> S3Client:
"""
Expand Down
2 changes: 1 addition & 1 deletion prefect_aws/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def _get_s3_client(self) -> boto3.client:
Authenticate MinIO credentials or AWS credentials and return an S3 client.
This is a helper function called by read_path() or write_path().
"""
return self.credentials.get_s3_client()
return self.credentials.get_client("s3")

def _get_bucket_resource(self) -> boto3.resource:
"""
Expand Down
122 changes: 121 additions & 1 deletion tests/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from botocore.client import BaseClient
from moto import mock_s3

from prefect_aws.credentials import AwsCredentials, ClientType, MinIOCredentials
from prefect_aws.credentials import (
AwsCredentials,
ClientType,
MinIOCredentials,
_get_client_cached,
)


def test_aws_credentials_get_boto3_session():
Expand Down Expand Up @@ -44,3 +49,118 @@ def test_minio_credentials_get_boto3_session():
def test_credentials_get_client(credentials, client_type):
with mock_s3():
assert isinstance(credentials.get_client(client_type), BaseClient)


@pytest.mark.parametrize(
"credentials",
[
AwsCredentials(region_name="us-east-1"),
MinIOCredentials(
minio_root_user="root_user",
minio_root_password="root_password",
region_name="us-east-1",
),
],
)
@pytest.mark.parametrize("client_type", [member.value for member in ClientType])
def test_get_client_cached(credentials, client_type):
"""
Test to ensure that _get_client_cached function returns the same instance
for multiple calls with the same parameters and properly utilizes lru_cache.
"""

_get_client_cached.cache_clear()

assert _get_client_cached.cache_info().hits == 0, "Initial call count should be 0"

credentials.get_client(client_type)
credentials.get_client(client_type)
credentials.get_client(client_type)

assert _get_client_cached.cache_info().misses == 1
assert _get_client_cached.cache_info().hits == 2


@pytest.mark.parametrize("client_type", [member.value for member in ClientType])
def test_aws_credentials_change_causes_cache_miss(client_type):
"""
Test to ensure that changing configuration on an AwsCredentials instance
after fetching a client causes a cache miss.
"""

_get_client_cached.cache_clear()

credentials = AwsCredentials(region_name="us-east-1")

initial_client = credentials.get_client(client_type)

credentials.region_name = "us-west-2"

new_client = credentials.get_client(client_type)

assert (
initial_client is not new_client
), "Client should be different after configuration change"

assert _get_client_cached.cache_info().misses == 2, "Cache should miss twice"


@pytest.mark.parametrize("client_type", [member.value for member in ClientType])
def test_minio_credentials_change_causes_cache_miss(client_type):
"""
Test to ensure that changing configuration on an AwsCredentials instance
after fetching a client causes a cache miss.
"""

_get_client_cached.cache_clear()

credentials = MinIOCredentials(
minio_root_user="root_user",
minio_root_password="root_password",
region_name="us-east-1",
)

initial_client = credentials.get_client(client_type)

credentials.region_name = "us-west-2"

new_client = credentials.get_client(client_type)

assert (
initial_client is not new_client
), "Client should be different after configuration change"

assert _get_client_cached.cache_info().misses == 2, "Cache should miss twice"


@pytest.mark.parametrize(
"credentials_type, initial_field, new_field",
[
(
AwsCredentials,
{"region_name": "us-east-1"},
{"region_name": "us-east-2"},
),
(
MinIOCredentials,
{
"region_name": "us-east-1",
"minio_root_user": "root_user",
"minio_root_password": "root_password",
},
{
"region_name": "us-east-2",
"minio_root_user": "root_user",
"minio_root_password": "root_password",
},
),
],
)
def test_aws_credentials_hash_changes(credentials_type, initial_field, new_field):
credentials = credentials_type(**initial_field)
initial_hash = hash(credentials)

setattr(credentials, list(new_field.keys())[0], list(new_field.values())[0])
new_hash = hash(credentials)

assert initial_hash != new_hash, "Hash should change when region_name changes"

0 comments on commit f697760

Please sign in to comment.