Initial implementation of azure blob container interface #1853

Open · wants to merge 4 commits into main
Changes from all commits
7 changes: 6 additions & 1 deletion api/main/cache.py
@@ -16,6 +16,9 @@

EXPIRE_TIME = 60 * 60 * 24 * 30
REDIS_USE_SSL = os.getenv("REDIS_USE_SSL", "FALSE").lower() == "true"
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", None)


class TatorCache:
@@ -25,7 +28,9 @@ class TatorCache:
    def setup_redis(cls):
        retry = Retry(ExponentialBackoff(), 3)
        cls.rds = Redis(
-            host=os.getenv("REDIS_HOST"),
+            host=REDIS_HOST,
+            port=REDIS_PORT,
+            password=REDIS_PASSWORD,
            retry=retry,
            retry_on_error=[BusyLoadingError, ConnectionError, TimeoutError],
            health_check_interval=30,
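For a quick local check of the new settings, a minimal sketch (the import path and values are illustrative; note the env vars must be set before the module is imported, since the REDIS_* constants are read at import time):

    import os

    os.environ["REDIS_HOST"] = "redis.example.internal"
    os.environ["REDIS_PORT"] = "6380"
    os.environ["REDIS_PASSWORD"] = "s3cret"
    os.environ["REDIS_USE_SSL"] = "true"

    from main.cache import TatorCache  # hypothetical import path

    TatorCache.setup_redis()
    TatorCache.rds.ping()  # raises on a bad host, port, or password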
223 changes: 223 additions & 0 deletions api/main/store.py
@@ -14,6 +14,12 @@
from botocore.client import Config
from google.cloud import storage
from google.oauth2.service_account import Credentials
from azure.storage.blob import (
    BlobBlock,
    BlobClient,
    BlobSasPermissions,
    BlobServiceClient,
    generate_blob_sas,
)

logger = logging.getLogger(__name__)

@@ -30,6 +36,7 @@ class ObjectStore(Enum):
    GCP = "GCP"
    OCI = "OCI"
    VAST = "VAST"
    AZURE = "AZURE"


# TODO Deprecated: remove once support for old bucket model is removed
@@ -56,13 +63,15 @@ class OldObjectStore(Enum):
        ObjectStore.GCP: ["STANDARD", "COLDLINE"],
        ObjectStore.OCI: ["STANDARD"],
        ObjectStore.VAST: ["STANDARD"],
        ObjectStore.AZURE: ["Cold", "Archive"],
    },
    "live_sc": {
        ObjectStore.AWS: ["STANDARD"],
        ObjectStore.MINIO: ["STANDARD"],
        ObjectStore.GCP: ["STANDARD"],
        ObjectStore.OCI: ["STANDARD"],
        ObjectStore.VAST: ["STANDARD"],
        ObjectStore.AZURE: ["Hot"],
    },
}

@@ -74,13 +83,15 @@ class OldObjectStore(Enum):
        ObjectStore.GCP: "COLDLINE",
        ObjectStore.OCI: "STANDARD",
        ObjectStore.VAST: "STANDARD",
        ObjectStore.AZURE: "Archive",
    },
    "live_sc": {
        ObjectStore.AWS: "STANDARD",
        ObjectStore.MINIO: "STANDARD",
        ObjectStore.GCP: "STANDARD",
        ObjectStore.OCI: "STANDARD",
        ObjectStore.VAST: "STANDARD",
        ObjectStore.AZURE: "Hot",
    },
}

@@ -115,6 +126,9 @@ def _client_from_config(
    if store_type == ObjectStore.GCP:
        return storage.Client(config["project_id"], Credentials.from_service_account_info(config))

    if store_type == ObjectStore.AZURE:
        return BlobServiceClient.from_connection_string(config["connection_string"])

    raise ValueError(f"Received unhandled store type '{store_type}'")


@@ -167,6 +181,8 @@ def get_tator_store(server, bucket, client, bucket_name, external_host=None):
        return OCIStorage(bucket, client, bucket_name, external_host)
    if server is ObjectStore.VAST:
        return VASTStorage(bucket, client, bucket_name, external_host)
    if server is ObjectStore.AZURE:
        return AzureStorage(bucket, client, bucket_name, external_host)

    raise ValueError(f"Server type '{server}' is not supported")

@@ -676,6 +692,213 @@ def restore_object(self, path, desired_storage_class, min_exp_days):
        # TODO determine if we need to update the `current_time` field
        self._update_storage_class(path, desired_storage_class)

class AzureStorage(TatorStorage):
    def __init__(self, bucket, client, bucket_name, external_host=None):
        super().__init__(bucket, client, bucket_name, external_host)
        self._server = ObjectStore.AZURE

    def check_key(self, path: str) -> bool:
        blob_client = self.client.get_blob_client(
            container=self.bucket_name, blob=self.path_to_key(path)
        )
        return blob_client.exists()

    def object_tagged_for_archive(self, path: str) -> bool:
        blob_client = self.client.get_blob_client(
            container=self.bucket_name, blob=self.path_to_key(path)
        )
        try:
            tags = blob_client.get_blob_tags() or {}
            return tags.get(ARCHIVE_KEY) == "true"
        except Exception:
            # Treat blobs with unreadable or missing tags as untagged
            return False

    def _put_archive_tag(self, path: str):
        blob_client = self.client.get_blob_client(
            container=self.bucket_name, blob=self.path_to_key(path)
        )
        tags = blob_client.get_blob_tags() or {}
        tags[ARCHIVE_KEY] = "true"
        blob_client.set_blob_tags(tags)

    def put_media_id_tag(self, path: str, media_id):
        blob_client = self.client.get_blob_client(
            container=self.bucket_name, blob=self.path_to_key(path)
        )
        tags = blob_client.get_blob_tags() or {}
        tags[MEDIA_ID_KEY] = str(media_id)
        blob_client.set_blob_tags(tags)
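Side note: these helpers use Azure blob index tags, so tagged blobs can later be located across the account with a filter expression. A sketch (assumes azure-storage-blob >= 12.4; `client` is the BlobServiceClient and MEDIA_ID_KEY is the tag key defined elsewhere in store.py):

    results = client.find_blobs_by_tags(f"\"{MEDIA_ID_KEY}\"='1234'")
    for blob in results:
        print(blob.container, blob.name)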

    def _head_object(self, path: str) -> dict:
        blob_client = self.client.get_blob_client(
            container=self.bucket_name, blob=self.path_to_key(path)
        )
        properties = blob_client.get_blob_properties()
        # Mirror the shape of an S3 HeadObject response
        return {
            "ContentLength": properties.size,
            "ContentType": properties.content_settings.content_type,
            "LastModified": properties.last_modified,
            "StorageClass": properties.blob_tier,
        }

    def copy(self, source_path: str, dest_path: str, extra_args: Optional[dict] = None):
        source_blob = self.client.get_blob_client(
            container=self.bucket_name, blob=self.path_to_key(source_path)
        )
        dest_blob = self.client.get_blob_client(
            container=self.bucket_name, blob=self.path_to_key(dest_path)
        )
        # Server-side copy; within the same storage account this typically
        # completes synchronously
        dest_blob.start_copy_from_url(source_blob.url)

        if extra_args:
            if "StorageClass" in extra_args:
                dest_blob.set_standard_blob_tier(extra_args["StorageClass"])

            if extra_args.get("MetadataDirective") == "REPLACE":
                metadata = extra_args.get("Metadata", {})
                dest_blob.set_blob_metadata(metadata)

            if extra_args.get("TaggingDirective") == "REPLACE":
                tags = extra_args.get("Tagging", {})
                dest_blob.set_blob_tags(tags)

    def restore_object(self, path: str, desired_storage_class: str, min_exp_days: int):
        blob_client = self.client.get_blob_client(
            container=self.bucket_name, blob=self.path_to_key(path)
        )
        # Rehydration out of the Archive tier is asynchronous and may take
        # hours; the tier-change request itself returns immediately
        blob_client.set_standard_blob_tier(desired_storage_class)

    def delete_object(self, path: str):
        blob_client = self.client.get_blob_client(
            container=self.bucket_name, blob=self.path_to_key(path)
        )
        blob_client.delete_blob()

    def get_download_url(self, path: str, expiration: int) -> str:
        sas_token = generate_blob_sas(
            account_name=self.client.account_name,
            container_name=self.bucket_name,
            blob_name=self.path_to_key(path),
            account_key=self.client.credential.account_key,
            permission=BlobSasPermissions(read=True),
            expiry=datetime.utcnow() + timedelta(seconds=expiration),
        )
        blob_client = self.client.get_blob_client(
            container=self.bucket_name, blob=self.path_to_key(path)
        )
        url = f"{blob_client.url}?{sas_token}"

        if self.external_host:
            parsed = urlsplit(url)
            external = urlsplit(self.external_host, scheme=self.proto)
            parsed = parsed._replace(
                netloc=external.netloc + external.path, scheme=external.scheme
            )
            url = urlunsplit(parsed)
        return url
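The returned URL is self-authenticating until the SAS expires, so a plain GET suffices. A usage sketch (`store` is a hypothetical AzureStorage instance, the key is illustrative, and the `requests` package is assumed available):

    import requests

    url = store.get_download_url("721/some/key", expiration=3600)
    resp = requests.get(url)
    resp.raise_for_status()
    data = resp.content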

    def _get_multiple_upload_urls(
        self, key: str, expiration: int, num_parts: int, domain: str
    ) -> Tuple[List[str], str]:
        sas_token = generate_blob_sas(
            account_name=self.client.account_name,
            container_name=self.bucket_name,
            blob_name=key,
            account_key=self.client.credential.account_key,
            permission=BlobSasPermissions(write=True, create=True),
            expiry=datetime.utcnow() + timedelta(seconds=expiration),
        )
        url = f"{self.client.primary_endpoint}/{self.bucket_name}/{key}?{sas_token}"
        # Azure has no multipart upload IDs; every part is staged against the
        # same SAS URL (clients append `comp=block&blockid=<id>` per block)
        urls = [url] * num_parts
        upload_id = ""
        return urls, upload_id

    def _get_single_upload_url(
        self, key: str, expiration: int, domain: str
    ) -> Tuple[List[str], str]:
        sas_token = generate_blob_sas(
            account_name=self.client.account_name,
            container_name=self.bucket_name,
            blob_name=key,
            account_key=self.client.credential.account_key,
            permission=BlobSasPermissions(write=True, create=True),
            expiry=datetime.utcnow() + timedelta(seconds=expiration),
        )
        url = f"{self.client.primary_endpoint}/{self.bucket_name}/{key}?{sas_token}"
        return [url], ""

    def _list_objects_v2(self, prefix: Optional[str] = None, **kwargs) -> list:
        container_client = self.client.get_container_client(self.bucket_name)
        blobs = container_client.list_blobs(name_starts_with=prefix)
        # Mirror the shape of an S3 ListObjectsV2 response
        contents = []
        for blob in blobs:
            contents.append(
                {
                    "Key": blob.name,
                    "LastModified": blob.last_modified,
                    "ETag": blob.etag,
                    "Size": blob.size,
                    "StorageClass": blob.blob_tier,
                }
            )
        return contents

    def complete_multipart_upload(self, path: str, parts: int, upload_id: str) -> bool:
        blob_client = self.client.get_blob_client(
            container=self.bucket_name, blob=self.path_to_key(path)
        )
        # In Azure, parts are blocks identified by block IDs; commit_block_list
        # expects BlobBlock objects whose IDs match the blocks the client staged
        block_list = [BlobBlock(block_id=f"{i:06d}") for i in range(1, parts + 1)]
        try:
            blob_client.commit_block_list(block_list)
        except Exception as excep:
            logger.warning(f"Failed to commit block list for '{path}': {excep}")
            return False
        return True
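For context, the client-side counterpart stages each part as a block whose ID matches the zero-padded scheme committed above. A sketch using the SDK directly (uploads through the SAS URLs would issue the equivalent Put Block REST calls instead; container and blob names are placeholders):

    from azure.storage.blob import BlobBlock

    blob = client.get_blob_client(container="my-container", blob="big-file.bin")
    blob.stage_block(block_id="000001", data=b"first chunk")
    blob.stage_block(block_id="000002", data=b"second chunk")
    blob.commit_block_list(
        [BlobBlock(block_id="000001"), BlobBlock(block_id="000002")]
    )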

    def put_object(self, path: str, body: IO):
        blob_client = self.client.get_blob_client(
            container=self.bucket_name, blob=self.path_to_key(path)
        )
        blob_client.upload_blob(body, overwrite=True)

    def put_string(self, path: str, body: Union[bytes, str]):
        if isinstance(body, str):
            body = body.encode("utf-8")
        self.put_object(path, body)

    def get_object(
        self, path: str, start: Optional[int] = None, stop: Optional[int] = None
    ) -> bytes:
        blob_client = self.client.get_blob_client(
            container=self.bucket_name, blob=self.path_to_key(path)
        )
        if start is not None and stop is not None:
            # `stop` is inclusive, matching HTTP/S3 byte-range semantics
            length = stop - start + 1
            stream = blob_client.download_blob(offset=start, length=length)
        elif start is None and stop is None:
            stream = blob_client.download_blob()
        else:
            raise ValueError("Must specify both or neither of the start and stop arguments")
        return stream.readall()
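For example, fetching the first kilobyte of an object (`store` again a hypothetical AzureStorage instance):

    first_kb = store.get_object("721/some/key", start=0, stop=1023)
    assert len(first_kb) == 1024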

    def download_fileobj(self, path: str, fp: IO):
        blob_client = self.client.get_blob_client(
            container=self.bucket_name, blob=self.path_to_key(path)
        )
        stream = blob_client.download_blob()
        # readinto replaces the deprecated download_to_stream
        stream.readinto(fp)

    def _update_storage_class(self, path: str, desired_storage_class: str):
        blob_client = self.client.get_blob_client(
            container=self.bucket_name, blob=self.path_to_key(path)
        )
        blob_client.set_standard_blob_tier(desired_storage_class)


def get_tator_store(
    bucket=None, connect_timeout=5, read_timeout=5, max_attempts=5, upload=False, backup=False
8 changes: 7 additions & 1 deletion api/main/worker.py
@@ -14,6 +14,9 @@
import logging

REDIS_USE_SSL = os.getenv("REDIS_USE_SSL", "FALSE").lower() == "true"
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))
REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD", None)

if os.getenv("DD_LOGS_INJECTION"):
    import ddtrace.auto
@@ -37,9 +40,12 @@ def push_job(queue, function, *args, **kwargs):
    """
    retry = Retry(ExponentialBackoff(), 3)
    redis = Redis(
-        host=os.getenv("REDIS_HOST"),
+        host=REDIS_HOST,
+        port=REDIS_PORT,
+        password=REDIS_PASSWORD,
        retry=retry,
        retry_on_error=[BusyLoadingError, ConnectionError, TimeoutError],
        health_check_interval=30,
        ssl=REDIS_USE_SSL,
    )
    queue = Queue(queue, connection=redis)
1 change: 1 addition & 0 deletions containers/tator/requirements.txt
@@ -1,3 +1,4 @@
azure-storage-blob==12.23.1
boto3==1.34.66
certifi==2024.7.4
cryptography==42.0.5