diff --git a/test.py b/test.py index 8184c17..17e605a 100644 --- a/test.py +++ b/test.py @@ -660,7 +660,7 @@ def stream(self, method, url, params, headers): def test_too_many_bytes(self): @contextmanager - def server(): + def server(port): def _run(server_sock): while True: try: @@ -671,12 +671,12 @@ def _run(server_sock): connection_t.start() with shutdown(get_new_socket()) as server_sock: - server_sock.bind(('127.0.0.1', 9001)) + server_sock.bind(('127.0.0.1', port)) server_sock.listen(socket.IPPROTO_TCP) threading.Thread(target=_run, args=(server_sock,)).start() yield server_sock - def get_http_client(): + def get_http_client(port): @contextmanager def client(): with httpx.Client() as original_client: @@ -684,9 +684,9 @@ class Client(): @contextmanager def stream(self, method, url, params, headers): parsed_url = urllib.parse.urlparse(url) - url = urllib.parse.urlunparse(parsed_url._replace(netloc='localhost:9001')) + url = urllib.parse.urlunparse(parsed_url._replace(netloc=f'localhost:{port}')) range_query = dict(headers).get('range') - is_query = range_query and range_query != 'bytes=0-99' + yield_extra = not only_after_header or (range_query and range_query != 'bytes=0-99') headers_proxy_host = tuple((key, value) for key, value in headers if key != 'host') + (('host', 'localhost:9000'),) with original_client.stream(method, url, params=params, headers=headers_proxy_host @@ -694,29 +694,31 @@ def stream(self, method, url, params, headers): chunks = response.iter_bytes() def iter_bytes(chunk_size=None): yield from chunks - if is_query: + if yield_extra: yield b'e' response.iter_bytes = iter_bytes yield response yield Client() return client() - with server() as server_sock: - with get_db([ - ("CREATE TABLE my_table (my_col_a text, my_col_b text);",()), - ] + [ - ("INSERT INTO my_table VALUES " + ','.join(["('some-text-a', 'some-text-b')"] * 500),()), - ]) as db: - put_object_with_versioning('my-bucket', 'my.db', db) - - with sqlite_s3_query('http://localhost:9000/my-bucket/my.db', get_credentials=lambda now: ( - 'us-east-1', - 'AKIAIOSFODNN7EXAMPLE', - 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY', - None, - ), get_http_client=get_http_client, get_libsqlite3=get_libsqlite3) as query: - with self.assertRaisesRegex(SQLiteError, 'disk I/O error'): - query('SELECT my_col_a FROM my_table').__enter__() + for only_after_header, port in [(False, 9001), (True, 9002)]: + with self.subTest((only_after_header, port)): + with server(port) as server_sock: + with get_db([ + ("CREATE TABLE my_table (my_col_a text, my_col_b text);",()), + ] + [ + ("INSERT INTO my_table VALUES " + ','.join(["('some-text-a', 'some-text-b')"] * 500),()), + ]) as db: + put_object_with_versioning('my-bucket', 'my.db', db) + + with sqlite_s3_query('http://localhost:9000/my-bucket/my.db', get_credentials=lambda now: ( + 'us-east-1', + 'AKIAIOSFODNN7EXAMPLE', + 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY', + None, + ), get_http_client=functools.partial(get_http_client, port), get_libsqlite3=get_libsqlite3) as query: + with self.assertRaisesRegex(SQLiteError, 'disk I/O error'): + query('SELECT my_col_a FROM my_table').__enter__() def test_disconnection(self): @contextmanager @@ -732,7 +734,7 @@ def _run(server_sock): connection_t.start() with shutdown(get_new_socket()) as server_sock: - server_sock.bind(('127.0.0.1', 9001)) + server_sock.bind(('127.0.0.1', 9003)) server_sock.listen(socket.IPPROTO_TCP) threading.Thread(target=_run, args=(server_sock,)).start() yield server_sock @@ -744,7 +746,7 @@ def client(): class Client(): def stream(self, method, url, headers, params): parsed_url = urllib.parse.urlparse(url) - url = urllib.parse.urlunparse(parsed_url._replace(netloc='localhost:9001')) + url = urllib.parse.urlunparse(parsed_url._replace(netloc='localhost:9003')) headers_proxy_host = tuple((key, value) for key, value in headers if key != 'host') + (('host', 'localhost:9000'),) return original_client.stream(method, url, headers=headers_proxy_host) yield Client()