Skip to content

Commit

Permalink
Merge pull request #279 from kyleknap/crt-streams
Browse files Browse the repository at this point in the history
Support file-like objects in CRT transfer manager
  • Loading branch information
kyleknap authored Nov 3, 2023
2 parents cc9345a + 7d5ec27 commit b8906b3
Show file tree
Hide file tree
Showing 6 changed files with 531 additions and 82 deletions.
5 changes: 5 additions & 0 deletions .changes/next-release/enhancement-crt-51520.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "enhancement",
"category": "``crt``",
"description": "Add support for uploading and downloading file-like objects using CRT transfer manager. It supports both seekable and non-seekable file-like objects."
}
168 changes: 130 additions & 38 deletions s3transfer/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,19 +428,12 @@ def _crt_request_from_aws_request(self, aws_request):
headers_list.append((name, str(value, 'utf-8')))

crt_headers = awscrt.http.HttpHeaders(headers_list)
# CRT requires body (if it exists) to be an I/O stream.
crt_body_stream = None
if aws_request.body:
if hasattr(aws_request.body, 'seek'):
crt_body_stream = aws_request.body
else:
crt_body_stream = BytesIO(aws_request.body)

crt_request = awscrt.http.HttpRequest(
method=aws_request.method,
path=crt_path,
headers=crt_headers,
body_stream=crt_body_stream,
body_stream=aws_request.body,
)
return crt_request

Expand All @@ -453,6 +446,25 @@ def _convert_to_crt_http_request(self, botocore_http_request):
crt_request.headers.set("host", url_parts.netloc)
if crt_request.headers.get('Content-MD5') is not None:
crt_request.headers.remove("Content-MD5")

# In general, the CRT S3 client expects a content length header. It
# only expects a missing content length header if the body is not
# seekable. However, botocore does not set the content length header
# for GetObject API requests and so we set the content length to zero
# to meet the CRT S3 client's expectation that the content length
# header is set even if there is no body.
if crt_request.headers.get('Content-Length') is None:
if botocore_http_request.body is None:
crt_request.headers.add('Content-Length', "0")

# Botocore sets the Transfer-Encoding header when it cannot determine
# the content length of the request body (e.g. it's not seekable).
# CRT does not support this header, although it does support
# non-seekable bodies, so we remove the header to avoid causing issues
# in the downstream CRT S3 request.
if crt_request.headers.get('Transfer-Encoding') is not None:
crt_request.headers.remove('Transfer-Encoding')

return crt_request

def _capture_http_request(self, request, **kwargs):
Expand Down Expand Up @@ -555,39 +567,20 @@ def __init__(self, crt_request_serializer, os_utils):
def get_make_request_args(
self, request_type, call_args, coordinator, future, on_done_after_calls
):
recv_filepath = None
send_filepath = None
s3_meta_request_type = getattr(
S3RequestType, request_type.upper(), S3RequestType.DEFAULT
request_args_handler = getattr(
self,
f'_get_make_request_args_{request_type}',
self._default_get_make_request_args,
)
on_done_before_calls = []
if s3_meta_request_type == S3RequestType.GET_OBJECT:
final_filepath = call_args.fileobj
recv_filepath = self._os_utils.get_temp_filename(final_filepath)
file_ondone_call = RenameTempFileHandler(
coordinator, final_filepath, recv_filepath, self._os_utils
)
on_done_before_calls.append(file_ondone_call)
elif s3_meta_request_type == S3RequestType.PUT_OBJECT:
send_filepath = call_args.fileobj
data_len = self._os_utils.get_file_size(send_filepath)
call_args.extra_args["ContentLength"] = data_len

crt_request = self._request_serializer.serialize_http_request(
request_type, future
return request_args_handler(
request_type=request_type,
call_args=call_args,
coordinator=coordinator,
future=future,
on_done_before_calls=[],
on_done_after_calls=on_done_after_calls,
)

return {
'request': crt_request,
'type': s3_meta_request_type,
'recv_filepath': recv_filepath,
'send_filepath': send_filepath,
'on_done': self.get_crt_callback(
future, 'done', on_done_before_calls, on_done_after_calls
),
'on_progress': self.get_crt_callback(future, 'progress'),
}

def get_crt_callback(
self,
future,
Expand All @@ -613,6 +606,97 @@ def invoke_all_callbacks(*args, **kwargs):

return invoke_all_callbacks

def _get_make_request_args_put_object(
    self,
    request_type,
    call_args,
    coordinator,
    future,
    on_done_before_calls,
    on_done_after_calls,
):
    """Build the make-request arguments for a PutObject transfer.

    When ``call_args.fileobj`` is a ``str`` it is treated as a path: its
    size (from ``self._os_utils.get_file_size``) is stored in
    ``call_args.extra_args["ContentLength"]`` and the path is returned
    under ``'send_filepath'``. Any other object is stored in
    ``call_args.extra_args["Body"]`` and ``'send_filepath'`` is ``None``.

    Note that ``call_args.extra_args`` is modified in place.

    :returns: The dict produced by ``_default_get_make_request_args``
        with an additional ``'send_filepath'`` entry.
    """
    source = call_args.fileobj
    upload_from_path = isinstance(source, str)

    if upload_from_path:
        call_args.extra_args["ContentLength"] = self._os_utils.get_file_size(
            source
        )
    else:
        call_args.extra_args["Body"] = source

    # Suppress botocore's automatic MD5 calculation by setting an override
    # value that will get deleted in the BotocoreCRTRequestSerializer.
    # The CRT S3 client is able to automatically compute checksums as part
    # of the requests it makes, and the intention is to configure automatic
    # checksums in a future update.
    call_args.extra_args["ContentMD5"] = "override-to-be-removed"

    request_args = self._default_get_make_request_args(
        request_type=request_type,
        call_args=call_args,
        coordinator=coordinator,
        future=future,
        on_done_before_calls=on_done_before_calls,
        on_done_after_calls=on_done_after_calls,
    )
    request_args['send_filepath'] = source if upload_from_path else None
    return request_args

def _get_make_request_args_get_object(
    self,
    request_type,
    call_args,
    coordinator,
    future,
    on_done_before_calls,
    on_done_after_calls,
):
    """Build the make-request arguments for a GetObject transfer.

    When ``call_args.fileobj`` is a ``str`` it is treated as the final
    destination path: a temporary filename is obtained from
    ``self._os_utils.get_temp_filename`` and returned under
    ``'recv_filepath'``, and a ``RenameTempFileHandler`` is appended to
    ``on_done_before_calls`` (the list is modified in place). Any other
    object is wrapped in an ``OnBodyFileObjWriter`` and returned under
    ``'on_body'``.

    :returns: The dict produced by ``_default_get_make_request_args``
        with additional ``'recv_filepath'`` and ``'on_body'`` entries;
        exactly one of the two is not ``None``.
    """
    destination = call_args.fileobj

    if isinstance(destination, str):
        # Path destination: receive into a temporary file and register a
        # done-callback that is given both the final and temporary names.
        temp_path = self._os_utils.get_temp_filename(destination)
        rename_handler = RenameTempFileHandler(
            coordinator, destination, temp_path, self._os_utils
        )
        on_done_before_calls.append(rename_handler)
        body_writer = None
    else:
        # Non-path destination: hand each body chunk to the object's
        # ``write`` method through a callback.
        temp_path = None
        body_writer = OnBodyFileObjWriter(destination)

    request_args = self._default_get_make_request_args(
        request_type=request_type,
        call_args=call_args,
        coordinator=coordinator,
        future=future,
        on_done_before_calls=on_done_before_calls,
        on_done_after_calls=on_done_after_calls,
    )
    request_args.update(recv_filepath=temp_path, on_body=body_writer)
    return request_args

def _default_get_make_request_args(
    self,
    request_type,
    call_args,
    coordinator,
    future,
    on_done_before_calls,
    on_done_after_calls,
):
    """Build the make-request arguments shared by every request type.

    ``call_args`` and ``coordinator`` are accepted so that this method has
    the same signature as the request-type-specific handlers; they are
    not used here.

    :returns: A dict with the keys ``'request'``, ``'type'``,
        ``'on_done'`` and ``'on_progress'``.
    """
    crt_request = self._request_serializer.serialize_http_request(
        request_type, future
    )
    # Use the S3RequestType member named after the upper-cased request
    # type, falling back to DEFAULT when no such member exists.
    meta_request_type = getattr(
        S3RequestType, request_type.upper(), S3RequestType.DEFAULT
    )
    done_callback = self.get_crt_callback(
        future, 'done', on_done_before_calls, on_done_after_calls
    )
    progress_callback = self.get_crt_callback(future, 'progress')

    return {
        'request': crt_request,
        'type': meta_request_type,
        'on_done': done_callback,
        'on_progress': progress_callback,
    }


class RenameTempFileHandler:
def __init__(self, coordinator, final_filename, temp_filename, osutil):
Expand Down Expand Up @@ -642,3 +726,11 @@ def __init__(self, coordinator):

def __call__(self, **kwargs):
self._coordinator.set_done_callbacks_complete()


class OnBodyFileObjWriter:
    """Callable that writes every chunk it receives to a file-like object."""

    def __init__(self, fileobj):
        """Remember the object whose ``write`` method receives the chunks."""
        self._fileobj = fileobj

    def __call__(self, chunk, **kwargs):
        """Write ``chunk`` to the wrapped object.

        Extra keyword arguments are accepted and ignored. The return value
        of ``write`` is discarded, so this always returns ``None``.
        """
        self._fileobj.write(chunk)
3 changes: 3 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,9 @@ def write(self, b):
def read(self, n=-1):
    """Return the result of ``read(n)`` on the wrapped ``self._data``."""
    return self._data.read(n)

def readinto(self, b):
    """Return the result of ``readinto(b)`` on the wrapped ``self._data``."""
    return self._data.readinto(b)


class NonSeekableWriter(io.RawIOBase):
def __init__(self, fileobj):
Expand Down
Loading

0 comments on commit b8906b3

Please sign in to comment.