Skip to content

Commit

Permalink
Merge pull request #92 from michalc/test/add-case-for-coverage
Browse files Browse the repository at this point in the history
test: increase test coverage by adding an extra case
  • Loading branch information
michalc authored Mar 16, 2024
2 parents 7fe8861 + c2a2b10 commit 178d09b
Showing 1 changed file with 26 additions and 24 deletions.
50 changes: 26 additions & 24 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ def stream(self, method, url, params, headers):

def test_too_many_bytes(self):
@contextmanager
def server():
def server(port):
def _run(server_sock):
while True:
try:
Expand All @@ -671,52 +671,54 @@ def _run(server_sock):
connection_t.start()

with shutdown(get_new_socket()) as server_sock:
server_sock.bind(('127.0.0.1', 9001))
server_sock.bind(('127.0.0.1', port))
server_sock.listen(socket.IPPROTO_TCP)
threading.Thread(target=_run, args=(server_sock,)).start()
yield server_sock

def get_http_client():
def get_http_client(port):
@contextmanager
def client():
with httpx.Client() as original_client:
class Client():
@contextmanager
def stream(self, method, url, params, headers):
parsed_url = urllib.parse.urlparse(url)
url = urllib.parse.urlunparse(parsed_url._replace(netloc='localhost:9001'))
url = urllib.parse.urlunparse(parsed_url._replace(netloc=f'localhost:{port}'))
range_query = dict(headers).get('range')
is_query = range_query and range_query != 'bytes=0-99'
yield_extra = not only_after_header or (range_query and range_query != 'bytes=0-99')
headers_proxy_host = tuple((key, value) for key, value in headers if key != 'host') + (('host', 'localhost:9000'),)
with original_client.stream(method, url,
params=params, headers=headers_proxy_host
) as response:
chunks = response.iter_bytes()
def iter_bytes(chunk_size=None):
yield from chunks
if is_query:
if yield_extra:
yield b'e'
response.iter_bytes = iter_bytes
yield response
yield Client()
return client()

with server() as server_sock:
with get_db([
("CREATE TABLE my_table (my_col_a text, my_col_b text);",()),
] + [
("INSERT INTO my_table VALUES " + ','.join(["('some-text-a', 'some-text-b')"] * 500),()),
]) as db:
put_object_with_versioning('my-bucket', 'my.db', db)

with sqlite_s3_query('http://localhost:9000/my-bucket/my.db', get_credentials=lambda now: (
'us-east-1',
'AKIAIOSFODNN7EXAMPLE',
'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY',
None,
), get_http_client=get_http_client, get_libsqlite3=get_libsqlite3) as query:
with self.assertRaisesRegex(SQLiteError, 'disk I/O error'):
query('SELECT my_col_a FROM my_table').__enter__()
for only_after_header, port in [(False, 9001), (True, 9002)]:
with self.subTest((only_after_header, port)):
with server(port) as server_sock:
with get_db([
("CREATE TABLE my_table (my_col_a text, my_col_b text);",()),
] + [
("INSERT INTO my_table VALUES " + ','.join(["('some-text-a', 'some-text-b')"] * 500),()),
]) as db:
put_object_with_versioning('my-bucket', 'my.db', db)

with sqlite_s3_query('http://localhost:9000/my-bucket/my.db', get_credentials=lambda now: (
'us-east-1',
'AKIAIOSFODNN7EXAMPLE',
'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY',
None,
), get_http_client=functools.partial(get_http_client, port), get_libsqlite3=get_libsqlite3) as query:
with self.assertRaisesRegex(SQLiteError, 'disk I/O error'):
query('SELECT my_col_a FROM my_table').__enter__()

def test_disconnection(self):
@contextmanager
Expand All @@ -732,7 +734,7 @@ def _run(server_sock):
connection_t.start()

with shutdown(get_new_socket()) as server_sock:
server_sock.bind(('127.0.0.1', 9001))
server_sock.bind(('127.0.0.1', 9003))
server_sock.listen(socket.IPPROTO_TCP)
threading.Thread(target=_run, args=(server_sock,)).start()
yield server_sock
Expand All @@ -744,7 +746,7 @@ def client():
class Client():
def stream(self, method, url, headers, params):
parsed_url = urllib.parse.urlparse(url)
url = urllib.parse.urlunparse(parsed_url._replace(netloc='localhost:9001'))
url = urllib.parse.urlunparse(parsed_url._replace(netloc='localhost:9003'))
headers_proxy_host = tuple((key, value) for key, value in headers if key != 'host') + (('host', 'localhost:9000'),)
return original_client.stream(method, url, headers=headers_proxy_host)
yield Client()
Expand Down

0 comments on commit 178d09b

Please sign in to comment.