Skip to content

Commit

Permalink
Merge pull request #8291 from kyleknap/v2-crt-nonseekable
Browse files Browse the repository at this point in the history
[v2] Add support for stdin/stdout streams for CRT client
  • Loading branch information
kyleknap committed Nov 22, 2023
2 parents 2b6c0a7 + 8aa85d2 commit 04ba860
Show file tree
Hide file tree
Showing 11 changed files with 575 additions and 100 deletions.
5 changes: 5 additions & 0 deletions .changes/next-release/enhancement-s3cp-40300.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "enhancement",
"category": "``s3 cp``",
"description": "Support streaming uploads from stdin and streaming downloads to stdout for CRT transfer client"
}
2 changes: 0 additions & 2 deletions awscli/customizations/s3/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ def create_transfer_manager(self, params, runtime_config,
def _compute_transfer_client_type(self, params, runtime_config):
if params.get('paths_type') == 's3s3':
return constants.DEFAULT_TRANSFER_CLIENT
if params.get('is_stream'):
return constants.DEFAULT_TRANSFER_CLIENT
return runtime_config.get(
'preferred_transfer_client', constants.DEFAULT_TRANSFER_CLIENT)

Expand Down
168 changes: 130 additions & 38 deletions awscli/s3transfer/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,19 +428,12 @@ def _crt_request_from_aws_request(self, aws_request):
headers_list.append((name, str(value, 'utf-8')))

crt_headers = awscrt.http.HttpHeaders(headers_list)
# CRT requires body (if it exists) to be an I/O stream.
crt_body_stream = None
if aws_request.body:
if hasattr(aws_request.body, 'seek'):
crt_body_stream = aws_request.body
else:
crt_body_stream = BytesIO(aws_request.body)

crt_request = awscrt.http.HttpRequest(
method=aws_request.method,
path=crt_path,
headers=crt_headers,
body_stream=crt_body_stream,
body_stream=aws_request.body,
)
return crt_request

Expand All @@ -453,6 +446,25 @@ def _convert_to_crt_http_request(self, botocore_http_request):
crt_request.headers.set("host", url_parts.netloc)
if crt_request.headers.get('Content-MD5') is not None:
crt_request.headers.remove("Content-MD5")

# In general, the CRT S3 client expects a content length header. It
# only expects a missing content length header if the body is not
# seekable. However, botocore does not set the content length header
# for GetObject API requests and so we set the content length to zero
# to meet the CRT S3 client's expectation that the content length
# header is set even if there is no body.
if crt_request.headers.get('Content-Length') is None:
if botocore_http_request.body is None:
crt_request.headers.add('Content-Length', "0")

# Botocore sets the Transfer-Encoding header when it cannot determine
# the content length of the request body (e.g. it's not seekable).
# However, CRT does not support this header, but it supports
# non-seekable bodies. So we remove this header to not cause issues
# in the downstream CRT S3 request.
if crt_request.headers.get('Transfer-Encoding') is not None:
crt_request.headers.remove('Transfer-Encoding')

return crt_request

def _capture_http_request(self, request, **kwargs):
Expand Down Expand Up @@ -555,39 +567,20 @@ def __init__(self, crt_request_serializer, os_utils):
def get_make_request_args(
self, request_type, call_args, coordinator, future, on_done_after_calls
):
recv_filepath = None
send_filepath = None
s3_meta_request_type = getattr(
S3RequestType, request_type.upper(), S3RequestType.DEFAULT
request_args_handler = getattr(
self,
f'_get_make_request_args_{request_type}',
self._default_get_make_request_args,
)
on_done_before_calls = []
if s3_meta_request_type == S3RequestType.GET_OBJECT:
final_filepath = call_args.fileobj
recv_filepath = self._os_utils.get_temp_filename(final_filepath)
file_ondone_call = RenameTempFileHandler(
coordinator, final_filepath, recv_filepath, self._os_utils
)
on_done_before_calls.append(file_ondone_call)
elif s3_meta_request_type == S3RequestType.PUT_OBJECT:
send_filepath = call_args.fileobj
data_len = self._os_utils.get_file_size(send_filepath)
call_args.extra_args["ContentLength"] = data_len

crt_request = self._request_serializer.serialize_http_request(
request_type, future
return request_args_handler(
request_type=request_type,
call_args=call_args,
coordinator=coordinator,
future=future,
on_done_before_calls=[],
on_done_after_calls=on_done_after_calls,
)

return {
'request': crt_request,
'type': s3_meta_request_type,
'recv_filepath': recv_filepath,
'send_filepath': send_filepath,
'on_done': self.get_crt_callback(
future, 'done', on_done_before_calls, on_done_after_calls
),
'on_progress': self.get_crt_callback(future, 'progress'),
}

def get_crt_callback(
self,
future,
Expand All @@ -613,6 +606,97 @@ def invoke_all_callbacks(*args, **kwargs):

return invoke_all_callbacks

def _get_make_request_args_put_object(
self,
request_type,
call_args,
coordinator,
future,
on_done_before_calls,
on_done_after_calls,
):
send_filepath = None
if isinstance(call_args.fileobj, str):
send_filepath = call_args.fileobj
data_len = self._os_utils.get_file_size(send_filepath)
call_args.extra_args["ContentLength"] = data_len
else:
call_args.extra_args["Body"] = call_args.fileobj

# Suppress botocore's automatic MD5 calculation by setting an override
# value that will get deleted in the BotocoreCRTRequestSerializer.
# The CRT S3 client is able automatically compute checksums as part of
# requests it makes, and the intention is to configure automatic
# checksums in a future update.
call_args.extra_args["ContentMD5"] = "override-to-be-removed"

make_request_args = self._default_get_make_request_args(
request_type=request_type,
call_args=call_args,
coordinator=coordinator,
future=future,
on_done_before_calls=on_done_before_calls,
on_done_after_calls=on_done_after_calls,
)
make_request_args['send_filepath'] = send_filepath
return make_request_args

def _get_make_request_args_get_object(
self,
request_type,
call_args,
coordinator,
future,
on_done_before_calls,
on_done_after_calls,
):
recv_filepath = None
on_body = None
if isinstance(call_args.fileobj, str):
final_filepath = call_args.fileobj
recv_filepath = self._os_utils.get_temp_filename(final_filepath)
on_done_before_calls.append(
RenameTempFileHandler(
coordinator, final_filepath, recv_filepath, self._os_utils
)
)
else:
on_body = OnBodyFileObjWriter(call_args.fileobj)

make_request_args = self._default_get_make_request_args(
request_type=request_type,
call_args=call_args,
coordinator=coordinator,
future=future,
on_done_before_calls=on_done_before_calls,
on_done_after_calls=on_done_after_calls,
)
make_request_args['recv_filepath'] = recv_filepath
make_request_args['on_body'] = on_body
return make_request_args

def _default_get_make_request_args(
self,
request_type,
call_args,
coordinator,
future,
on_done_before_calls,
on_done_after_calls,
):
return {
'request': self._request_serializer.serialize_http_request(
request_type, future
),
'type': getattr(
S3RequestType, request_type.upper(), S3RequestType.DEFAULT
),
'on_done': self.get_crt_callback(
future, 'done', on_done_before_calls, on_done_after_calls
),
'on_progress': self.get_crt_callback(future, 'progress'),
}


class RenameTempFileHandler:
def __init__(self, coordinator, final_filename, temp_filename, osutil):
Expand Down Expand Up @@ -642,3 +726,11 @@ def __init__(self, coordinator):

def __call__(self, **kwargs):
self._coordinator.set_done_callbacks_complete()


class OnBodyFileObjWriter:
def __init__(self, fileobj):
self._fileobj = fileobj

def __call__(self, chunk, **kwargs):
self._fileobj.write(chunk)
3 changes: 0 additions & 3 deletions awscli/topics/s3-config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -297,9 +297,6 @@ files to and from S3. Valid choices are:

* S3 to S3 copies - Falls back to using the ``default`` transfer client

* Streaming uploads from standard input and downloads to standard output -
Falls back to using ``default`` transfer client.

* Region redirects - Transfers fail for requests sent to a region that does
not match the region of the targeted S3 bucket.

Expand Down
22 changes: 19 additions & 3 deletions tests/functional/s3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ def setUp(self):
self.mock_crt_client.return_value.make_request.side_effect = \
self.simulate_make_request_side_effect
self.files = FileCreator()
self.expected_download_content = b'content'

def tearDown(self):
super(BaseCRTTransferClientTest, self).tearDown()
Expand All @@ -456,6 +457,8 @@ def get_config_file_contents(self):
def simulate_make_request_side_effect(self, *args, **kwargs):
if kwargs.get('recv_filepath'):
self.simulate_file_download(kwargs['recv_filepath'])
elif kwargs.get('on_body'):
self.simulate_on_body(kwargs['on_body'])
s3_request = FakeCRTS3Request(
future=FakeCRTFuture(kwargs.get('on_done'))
)
Expand All @@ -465,11 +468,14 @@ def simulate_file_download(self, recv_filepath):
parent_dir = os.path.dirname(recv_filepath)
if not os.path.isdir(parent_dir):
os.makedirs(parent_dir)
with open(recv_filepath, 'w') as f:
with open(recv_filepath, 'wb') as f:
# The content is arbitrary as most functional tests are just going
# to assert the file exists since it is the CRT writing the
# data to the file.
f.write('content')
f.write(self.expected_download_content)

def simulate_on_body(self, on_body_callback):
on_body_callback(chunk=self.expected_download_content, offset=0)

def get_crt_make_request_calls(self):
return self.mock_crt_client.return_value.make_request.call_args_list
Expand All @@ -489,7 +495,8 @@ def assert_crt_make_request_call(
self, make_request_call, expected_type, expected_host,
expected_path, expected_http_method=None,
expected_send_filepath=None,
expected_recv_startswith=None):
expected_recv_startswith=None,
expected_body_content=None):
make_request_kwargs = make_request_call[1]
self.assertEqual(
make_request_kwargs['type'], expected_type)
Expand Down Expand Up @@ -522,6 +529,15 @@ def assert_crt_make_request_call(
f"start with {expected_recv_startswith}"
)
)
if expected_body_content is not None:
# Note: The underlying CRT awscrt.io.InputStream does not expose
# a public read method so we have to reach into the private,
# underlying stream to determine the content. We should update
# to use a public interface if a public interface is ever exposed.
self.assertEqual(
make_request_kwargs['request'].body_stream._stream.read(),
expected_body_content
)


class FakeCRTS3Request:
Expand Down
33 changes: 23 additions & 10 deletions tests/functional/s3/test_cp_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -2116,25 +2116,38 @@ def test_does_not_use_crt_client_for_copies(self):
self.assertEqual(self.get_crt_make_request_calls(), [])
self.assert_no_remaining_botocore_responses()

def test_does_not_use_crt_client_for_streaming_upload(self):
def test_streaming_upload_using_crt_client(self):
cmdline = [
's3', 'cp', '-', 's3://bucket/key'
]
self.add_botocore_put_object_response()
with mock.patch('sys.stdin', BufferedBytesIO(b'foo')):
self.run_command(cmdline)
self.assertEqual(self.get_crt_make_request_calls(), [])
self.assert_no_remaining_botocore_responses()
crt_requests = self.get_crt_make_request_calls()
self.assertEqual(len(crt_requests), 1)
self.assert_crt_make_request_call(
crt_requests[0],
expected_type=S3RequestType.PUT_OBJECT,
expected_host=self.get_virtual_s3_host('bucket'),
expected_path='/key',
expected_body_content=b'foo',
)

def test_does_not_use_crt_client_for_streaming_download(self):
def test_streaming_download_using_crt_client(self):
cmdline = [
's3', 'cp', 's3://bucket/key', '-'
]
self.add_botocore_head_object_response()
self.add_botocore_get_object_response()
self.run_command(cmdline)
self.assertEqual(self.get_crt_make_request_calls(), [])
self.assert_no_remaining_botocore_responses()
result = self.run_command(cmdline)
crt_requests = self.get_crt_make_request_calls()
self.assertEqual(len(crt_requests), 1)
self.assert_crt_make_request_call(
crt_requests[0],
expected_type=S3RequestType.GET_OBJECT,
expected_host=self.get_virtual_s3_host('bucket'),
expected_path='/key',
)
self.assertEqual(
result.stdout, self.expected_download_content.decode('utf-8')
)

def test_respects_region_parameter(self):
filename = self.files.create_file('myfile', 'mycontent')
Expand Down
Loading

0 comments on commit 04ba860

Please sign in to comment.