From a2fb16c431152bd677093f9910f405a37533ad5a Mon Sep 17 00:00:00 2001 From: Dengke Tang Date: Mon, 29 Jul 2024 13:12:51 -0700 Subject: [PATCH] adapt change from "TLS deliver buffer data during shutdown" (#474) --- include/aws/http/private/h1_connection.h | 6 +- source/h1_connection.c | 71 ++++++++-- source/websocket.c | 5 + tests/CMakeLists.txt | 1 + tests/test_connection.c | 3 - tests/test_tls.c | 171 ++++++++++++++++++----- 6 files changed, 202 insertions(+), 55 deletions(-) diff --git a/include/aws/http/private/h1_connection.h b/include/aws/http/private/h1_connection.h index 86a5124ea..bff0695b8 100644 --- a/include/aws/http/private/h1_connection.h +++ b/include/aws/http/private/h1_connection.h @@ -128,12 +128,16 @@ struct aws_h1_connection { /* If non-zero, reason to immediately reject new streams. (ex: closing) */ int new_stream_error_code; + /* If true, user has called connection_close() or stream_cancel(), + * but the cross_thread_work_task hasn't processed it yet */ + bool shutdown_requested; + int shutdown_requested_error_code; + /* See `cross_thread_work_task` */ bool is_cross_thread_work_task_scheduled : 1; /* For checking status from outside the event-loop thread. */ bool is_open : 1; - } synced_data; }; diff --git a/source/h1_connection.c b/source/h1_connection.c index 903cf0381..29aa82fd8 100644 --- a/source/h1_connection.c +++ b/source/h1_connection.c @@ -130,7 +130,6 @@ void aws_h1_connection_unlock_synced_data(struct aws_h1_connection *connection) * - Channel is shutting down in the read direction. * - Channel is shutting down in the write direction. * - An error occurs. - * - User wishes to close the connection (this is the only case where the function may run off-thread). 
*/ static void s_stop( struct aws_h1_connection *connection, @@ -139,15 +138,14 @@ static void s_stop( bool schedule_shutdown, int error_code) { + AWS_ASSERT(aws_channel_thread_is_callers_thread(connection->base.channel_slot->channel)); AWS_ASSERT(stop_reading || stop_writing || schedule_shutdown); /* You are required to stop at least 1 thing */ if (stop_reading) { - AWS_ASSERT(aws_channel_thread_is_callers_thread(connection->base.channel_slot->channel)); connection->thread_data.is_reading_stopped = true; } if (stop_writing) { - AWS_ASSERT(aws_channel_thread_is_callers_thread(connection->base.channel_slot->channel)); connection->thread_data.is_writing_stopped = true; } { /* BEGIN CRITICAL SECTION */ @@ -169,6 +167,11 @@ static void s_stop( aws_error_name(error_code)); aws_channel_shutdown(connection->base.channel_slot->channel, error_code); + if (stop_reading) { + /* Increase the window size after shutdown starts, to prevent deadlock when data still pending in the TLS + * handler. */ + aws_channel_slot_increment_read_window(connection->base.channel_slot, SIZE_MAX); + } } } @@ -189,14 +192,45 @@ static void s_shutdown_due_to_error(struct aws_h1_connection *connection, int er s_stop(connection, true /*stop_reading*/, true /*stop_writing*/, true /*schedule_shutdown*/, error_code); } +/** + * Helper to shutdown the connection from non-channel thread. 
(User wishes to close the connection) + **/ +static void s_shutdown_from_off_thread(struct aws_h1_connection *connection, int error_code) { + bool should_schedule_task = false; + { /* BEGIN CRITICAL SECTION */ + aws_h1_connection_lock_synced_data(connection); + if (!connection->synced_data.is_cross_thread_work_task_scheduled) { + connection->synced_data.is_cross_thread_work_task_scheduled = true; + should_schedule_task = true; + } + if (!connection->synced_data.shutdown_requested) { + connection->synced_data.shutdown_requested = true; + connection->synced_data.shutdown_requested_error_code = error_code; + } + /* Connection has shutdown, new streams should not be allowed. */ + connection->synced_data.is_open = false; + connection->synced_data.new_stream_error_code = AWS_ERROR_HTTP_CONNECTION_CLOSED; + aws_h1_connection_unlock_synced_data(connection); + } /* END CRITICAL SECTION */ + + if (should_schedule_task) { + AWS_LOGF_TRACE( + AWS_LS_HTTP_CONNECTION, "id=%p: Scheduling connection cross-thread work task.", (void *)&connection->base); + aws_channel_schedule_task_now(connection->base.channel_slot->channel, &connection->cross_thread_work_task); + } else { + AWS_LOGF_TRACE( + AWS_LS_HTTP_CONNECTION, + "id=%p: Connection cross-thread work task was already scheduled", + (void *)&connection->base); + } +} + /** * Public function for closing connection. */ static void s_connection_close(struct aws_http_connection *connection_base) { struct aws_h1_connection *connection = AWS_CONTAINER_OF(connection_base, struct aws_h1_connection, base); - - /* Don't stop reading/writing immediately, let that happen naturally during the channel shutdown process. 
*/ - s_stop(connection, false /*stop_reading*/, false /*stop_writing*/, true /*schedule_shutdown*/, AWS_ERROR_SUCCESS); + s_shutdown_from_off_thread(connection, AWS_ERROR_SUCCESS); } static void s_connection_stop_new_request(struct aws_http_connection *connection_base) { @@ -412,8 +446,7 @@ void aws_h1_stream_cancel(struct aws_http_stream *stream, int error_code) { (void *)stream, error_code, aws_error_name(error_code)); - - s_stop(connection, false /*stop_reading*/, false /*stop_writing*/, true /*schedule_shutdown*/, error_code); + s_shutdown_from_off_thread(connection, error_code); } struct aws_http_stream *s_make_request( @@ -495,10 +528,17 @@ static void s_cross_thread_work_task(struct aws_channel_task *channel_task, void bool has_new_client_streams = !aws_linked_list_empty(&connection->synced_data.new_client_stream_list); aws_linked_list_move_all_back( &connection->thread_data.stream_list, &connection->synced_data.new_client_stream_list); + bool shutdown_requested = connection->synced_data.shutdown_requested; + int shutdown_error = connection->synced_data.shutdown_requested_error_code; + connection->synced_data.shutdown_requested = false; + connection->synced_data.shutdown_requested_error_code = 0; aws_h1_connection_unlock_synced_data(connection); /* END CRITICAL SECTION */ + if (shutdown_requested) { + s_stop(connection, true /*stop_reading*/, true /*stop_writing*/, true /*schedule_shutdown*/, shutdown_error); + } /* Kick off outgoing-stream task if necessary */ if (has_new_client_streams) { aws_h1_connection_try_write_outgoing_stream(connection); @@ -785,13 +825,8 @@ static void s_http_stream_response_first_byte_timeout_task( (void *)connection_base, response_first_byte_timeout_ms); - /* Don't stop reading/writing immediately, let that happen naturally during the channel shutdown process. 
*/ - s_stop( - connection, - false /*stop_reading*/, - false /*stop_writing*/, - true /*schedule_shutdown*/, - AWS_ERROR_HTTP_RESPONSE_FIRST_BYTE_TIMEOUT); + /* Shutdown the connection. */ + s_shutdown_due_to_error(connection, AWS_ERROR_HTTP_RESPONSE_FIRST_BYTE_TIMEOUT); } static void s_set_outgoing_message_done(struct aws_h1_stream *stream) { @@ -1804,6 +1839,12 @@ static int s_handler_process_read_message( AWS_LOGF_TRACE( AWS_LS_HTTP_CONNECTION, "id=%p: Incoming message of size %zu.", (void *)&connection->base, message_size); + if (connection->thread_data.is_reading_stopped) { + /* Read has stopped, ignore the data, shutdown the channel in case it has not started yet. */ + aws_mem_release(message->allocator, message); /* Release the message as we return success. */ + s_shutdown_due_to_error(connection, AWS_ERROR_HTTP_CONNECTION_CLOSED); + return AWS_OP_SUCCESS; + } /* Shrink connection window by amount of data received. See comments at variable's * declaration site on why we use this instead of the official `aws_channel_slot.window_size`. */ diff --git a/source/websocket.c b/source/websocket.c index 8b5795362..8637081c0 100644 --- a/source/websocket.c +++ b/source/websocket.c @@ -989,8 +989,13 @@ static void s_shutdown_channel_task(struct aws_channel_task *task, void *arg, en s_unlock_synced_data(websocket); /* END CRITICAL SECTION */ + websocket->thread_data.is_reading_stopped = true; + websocket->thread_data.is_writing_stopped = true; aws_channel_shutdown(websocket->channel_slot->channel, error_code); + /* Increase the window size after shutdown starts, to prevent deadlock when data still pending in the upstream + * handler. */ + aws_channel_slot_increment_read_window(websocket->channel_slot, SIZE_MAX); } /* Tell the channel to shut down. It is safe to call this multiple times. 
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6a0a4de15..7b2bdf7e5 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -159,6 +159,7 @@ add_test_case(strutil_is_http_pseudo_header_name) add_net_test_case(tls_download_medium_file_h1) add_net_test_case(tls_download_medium_file_h2) +add_net_test_case(test_tls_download_shutdown_with_window_size_0) add_test_case(websocket_decoder_sanity_check) add_test_case(websocket_decoder_simplest_frame) diff --git a/tests/test_connection.c b/tests/test_connection.c index 71c0c36e9..ab04ce888 100644 --- a/tests/test_connection.c +++ b/tests/test_connection.c @@ -508,9 +508,6 @@ AWS_TEST_CASE(connection_setup_shutdown, s_test_connection_setup_shutdown); static int s_test_connection_setup_shutdown_tls(struct aws_allocator *allocator, void *ctx) { (void)ctx; -#ifdef __APPLE__ /* Something is wrong with APPLE */ - return AWS_OP_SUCCESS; -#endif struct tester_options options = { .alloc = allocator, .tls = true, diff --git a/tests/test_tls.c b/tests/test_tls.c index 779bb49c4..22b92e8d0 100644 --- a/tests/test_tls.c +++ b/tests/test_tls.c @@ -58,8 +58,8 @@ static void s_on_connection_shutdown(struct aws_http_connection *connection, int test->client_connection_is_shutdown = true; test->wait_result = error_code; - AWS_FATAL_ASSERT(aws_mutex_unlock(&test->wait_lock) == AWS_OP_SUCCESS); aws_condition_variable_notify_one(&test->wait_cvar); + AWS_FATAL_ASSERT(aws_mutex_unlock(&test->wait_lock) == AWS_OP_SUCCESS); } static int s_test_wait(struct test_ctx *test, bool (*pred)(void *user_data)) { @@ -108,6 +108,48 @@ static bool s_stream_wait_pred(void *user_data) { return test->wait_result || test->stream_complete; } +static int s_test_ctx_init(struct aws_allocator *allocator, struct test_ctx *test, bool h2_required) { + + AWS_ZERO_STRUCT(*test); + test->alloc = allocator; + + aws_mutex_init(&test->wait_lock); + aws_condition_variable_init(&test->wait_cvar); + + test->event_loop_group = 
aws_event_loop_group_new_default(test->alloc, 1, NULL); + + struct aws_host_resolver_default_options resolver_options = { + .el_group = test->event_loop_group, + .max_entries = 1, + }; + + test->host_resolver = aws_host_resolver_new_default(test->alloc, &resolver_options); + + struct aws_client_bootstrap_options bootstrap_options = { + .event_loop_group = test->event_loop_group, + .host_resolver = test->host_resolver, + }; + ASSERT_NOT_NULL(test->client_bootstrap = aws_client_bootstrap_new(test->alloc, &bootstrap_options)); + struct aws_tls_ctx_options tls_ctx_options; + aws_tls_ctx_options_init_default_client(&tls_ctx_options, allocator); + char *apln = h2_required ? "h2" : "http/1.1"; + aws_tls_ctx_options_set_alpn_list(&tls_ctx_options, apln); + ASSERT_NOT_NULL(test->tls_ctx = aws_tls_client_ctx_new(allocator, &tls_ctx_options)); + + aws_tls_ctx_options_clean_up(&tls_ctx_options); + return AWS_OP_SUCCESS; +} + +static void s_test_ctx_clean_up(struct test_ctx *test) { + aws_client_bootstrap_release(test->client_bootstrap); + aws_host_resolver_release(test->host_resolver); + aws_event_loop_group_release(test->event_loop_group); + aws_tls_ctx_release(test->tls_ctx); + + aws_mutex_clean_up(&test->wait_lock); + aws_condition_variable_clean_up(&test->wait_cvar); +} + static int s_test_tls_download_medium_file_general( struct aws_allocator *allocator, struct aws_byte_cursor url, @@ -125,31 +167,8 @@ static int s_test_tls_download_medium_file_general( }; struct test_ctx test; - AWS_ZERO_STRUCT(test); - test.alloc = allocator; - - aws_mutex_init(&test.wait_lock); - aws_condition_variable_init(&test.wait_cvar); - - test.event_loop_group = aws_event_loop_group_new_default(test.alloc, 1, NULL); - - struct aws_host_resolver_default_options resolver_options = { - .el_group = test.event_loop_group, - .max_entries = 1, - }; + ASSERT_SUCCESS(s_test_ctx_init(allocator, &test, h2_required)); - test.host_resolver = aws_host_resolver_new_default(test.alloc, &resolver_options); - - 
struct aws_client_bootstrap_options bootstrap_options = { - .event_loop_group = test.event_loop_group, - .host_resolver = test.host_resolver, - }; - ASSERT_NOT_NULL(test.client_bootstrap = aws_client_bootstrap_new(test.alloc, &bootstrap_options)); - struct aws_tls_ctx_options tls_ctx_options; - aws_tls_ctx_options_init_default_client(&tls_ctx_options, allocator); - char *apln = h2_required ? "h2" : "http/1.1"; - aws_tls_ctx_options_set_alpn_list(&tls_ctx_options, apln); - ASSERT_NOT_NULL(test.tls_ctx = aws_tls_client_ctx_new(allocator, &tls_ctx_options)); struct aws_tls_connection_options tls_connection_options; aws_tls_connection_options_init_from_ctx(&tls_connection_options, test.tls_ctx); aws_tls_connection_options_set_server_name( @@ -208,21 +227,11 @@ static int s_test_tls_download_medium_file_general( aws_http_connection_release(test.client_connection); ASSERT_SUCCESS(s_test_wait(&test, s_test_connection_shutdown_pred)); - - aws_client_bootstrap_release(test.client_bootstrap); - aws_host_resolver_release(test.host_resolver); - aws_event_loop_group_release(test.event_loop_group); - - aws_tls_ctx_options_clean_up(&tls_ctx_options); aws_tls_connection_options_clean_up(&tls_connection_options); - aws_tls_ctx_release(test.tls_ctx); - + s_test_ctx_clean_up(&test); aws_uri_clean_up(&uri); aws_http_library_clean_up(); - aws_mutex_clean_up(&test.wait_lock); - aws_condition_variable_clean_up(&test.wait_cvar); - return AWS_OP_SUCCESS; } @@ -243,3 +252,93 @@ static int s_tls_download_medium_file_h2(struct aws_allocator *allocator, void * return AWS_OP_SUCCESS; } AWS_TEST_CASE(tls_download_medium_file_h2, s_tls_download_medium_file_h2); + +static int s_test_tls_download_shutdown_with_window_size_0(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + + aws_http_library_init(allocator); + struct aws_byte_cursor uri_str = + aws_byte_cursor_from_c_str("https://aws-crt-test-stuff.s3.amazonaws.com/http_test_doc.txt"); + struct aws_uri uri; + AWS_ZERO_STRUCT(uri); + 
aws_uri_init_parse(&uri, allocator, &uri_str); + + size_t window_size = 20 * 1024; + + struct aws_socket_options socket_options = { + .type = AWS_SOCKET_STREAM, + .domain = AWS_SOCKET_IPV4, + .connect_timeout_ms = + (uint32_t)aws_timestamp_convert(TEST_TIMEOUT_SEC, AWS_TIMESTAMP_SECS, AWS_TIMESTAMP_MILLIS, NULL), + }; + + struct test_ctx test; + ASSERT_SUCCESS(s_test_ctx_init(allocator, &test, false)); + struct aws_tls_connection_options tls_connection_options; + aws_tls_connection_options_init_from_ctx(&tls_connection_options, test.tls_ctx); + aws_tls_connection_options_set_server_name( + &tls_connection_options, allocator, (struct aws_byte_cursor *)aws_uri_host_name(&uri)); + struct aws_http_client_connection_options http_options = AWS_HTTP_CLIENT_CONNECTION_OPTIONS_INIT; + http_options.allocator = test.alloc; + http_options.bootstrap = test.client_bootstrap; + http_options.host_name = *aws_uri_host_name(&uri); + http_options.port = 443; + http_options.on_setup = s_on_connection_setup; + http_options.on_shutdown = s_on_connection_shutdown; + http_options.socket_options = &socket_options; + http_options.tls_options = &tls_connection_options; + http_options.user_data = &test; + http_options.manual_window_management = true; + http_options.initial_window_size = window_size; + + ASSERT_SUCCESS(aws_http_client_connect(&http_options)); + ASSERT_SUCCESS(s_test_wait(&test, s_test_connection_setup_pred)); + ASSERT_INT_EQUALS(0, test.wait_result); + ASSERT_NOT_NULL(test.client_connection); + ASSERT_INT_EQUALS(aws_http_connection_get_version(test.client_connection), AWS_HTTP_VERSION_1_1); + + struct aws_http_message *request = aws_http_message_new_request(allocator); + ASSERT_NOT_NULL(request); + ASSERT_SUCCESS(aws_http_message_set_request_method(request, aws_http_method_get)); + ASSERT_SUCCESS(aws_http_message_set_request_path(request, *aws_uri_path_and_query(&uri))); + + struct aws_http_header header_host = { + .name = AWS_BYTE_CUR_INIT_FROM_STRING_LITERAL("Host"), + 
.value = *aws_uri_host_name(&uri), + }; + ASSERT_SUCCESS(aws_http_message_add_header(request, header_host)); + + struct aws_http_make_request_options req_options = { + .self_size = sizeof(req_options), + .request = request, + .on_response_body = s_on_stream_body, + .on_complete = s_on_stream_complete, + .user_data = &test, + }; + + ASSERT_NOT_NULL(test.stream = aws_http_connection_make_request(test.client_connection, &req_options)); + aws_http_stream_activate(test.stream); + + /* wait for the request to hit the window limitation. */ + aws_thread_current_sleep(aws_timestamp_convert(2, AWS_TIMESTAMP_SECS, AWS_TIMESTAMP_NANOS, NULL)); + aws_http_connection_close(test.client_connection); + + aws_http_stream_release(test.stream); + test.stream = NULL; + + aws_http_connection_release(test.client_connection); + ASSERT_SUCCESS(s_test_wait(&test, s_stream_wait_pred)); + /* Reset the wait error. */ + test.wait_result = 0; + ASSERT_SUCCESS(s_test_wait(&test, s_test_connection_shutdown_pred)); + ASSERT_INT_EQUALS(window_size, test.body_size); + + aws_http_message_destroy(request); + aws_tls_connection_options_clean_up(&tls_connection_options); + s_test_ctx_clean_up(&test); + aws_uri_clean_up(&uri); + aws_http_library_clean_up(); + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(test_tls_download_shutdown_with_window_size_0, s_test_tls_download_shutdown_with_window_size_0);