Skip to content

Commit

Permalink
Merge pull request #40 from mkmkme/mkmkme/python312
Browse files Browse the repository at this point in the history
Make pytest pass locally with Python 3.12
  • Loading branch information
aiven-sal authored Jul 11, 2024
2 parents f1898b9 + e001c88 commit 9c5e2b7
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 11 deletions.
9 changes: 8 additions & 1 deletion tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self, addr, valkey_addr):
self.send_event = asyncio.Event()
self.server = None
self.task = None
self.pipes = None
self.n_connections = 0

async def start(self):
Expand All @@ -79,12 +80,18 @@ async def handle(self, reader, writer):
self.n_connections += 1
pipe1 = asyncio.create_task(self.pipe(reader, valkey_writer))
pipe2 = asyncio.create_task(self.pipe(valkey_reader, writer))
await asyncio.gather(pipe1, pipe2)
self.pipes = asyncio.gather(pipe1, pipe2)
await self.pipes
except asyncio.CancelledError:
writer.close()
finally:
valkey_writer.close()

async def aclose(self):
self.task.cancel()
# self.pipes can be None if handle was never called
if self.pipes is not None:
self.pipes.cancel()
try:
await self.task
except asyncio.CancelledError:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_asyncio/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,13 @@ async def _handler(reader, writer):
await aserver.start_serving()
try:
await conn.connect()
stop_event.set()
await conn.disconnect()
except ConnectionError:
finished.set()
raise
finally:
stop_event.set()
stop_event.set() # Set stop_event in case of a connection error
aserver.close()
await aserver.wait_closed()
await finished.wait()
Expand Down
20 changes: 13 additions & 7 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,13 @@ def test_tcp_ssl_tls12_custom_ciphers(tcp_address, ssl_ciphers):
ssl_min_version=ssl.TLSVersion.TLSv1_2,
ssl_ciphers=ssl_ciphers,
)
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile)
_assert_connect(
conn,
tcp_address,
certfile=certfile,
keyfile=keyfile,
ssl_version=ssl.TLSVersion.TLSv1_2,
)


@pytest.mark.ssl
Expand All @@ -118,7 +124,7 @@ def test_tcp_ssl_version_mismatch(tcp_address):
tcp_address,
certfile=certfile,
keyfile=keyfile,
ssl_version=ssl.PROTOCOL_TLSv1_2,
ssl_version=ssl.TLSVersion.TLSv1_3,
)


Expand Down Expand Up @@ -147,7 +153,7 @@ def __init__(
*args,
certfile=None,
keyfile=None,
ssl_version=ssl.PROTOCOL_TLS,
ssl_version=ssl.TLSVersion.TLSv1,
**kw,
) -> None:
self._ready_event = threading.Event()
Expand All @@ -174,12 +180,12 @@ def get_request(self):
if self._certfile is None:
return super().get_request()
newsocket, fromaddr = self.socket.accept()
connstream = ssl.wrap_socket(
sslctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
sslctx.load_cert_chain(self._certfile, self._keyfile)
sslctx.minimum_version = self._ssl_version
connstream = sslctx.wrap_socket(
newsocket,
server_side=True,
certfile=self._certfile,
keyfile=self._keyfile,
ssl_version=self._ssl_version,
)
return connstream, fromaddr

Expand Down
19 changes: 17 additions & 2 deletions valkey/commands/graph/query_result.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import sys
from collections import OrderedDict
from distutils.util import strtobool

# from prettytable import PrettyTable
from valkey import ResponseError
Expand Down Expand Up @@ -39,6 +38,22 @@
]


def strtobool(value: str) -> bool:
"""
Convert a string representation of truth to True or False.
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
'value' is anything else.
"""
value = value.lower()
if value in ("y", "yes", "t", "true", "on", "1"):
return True
elif value in ("n", "no", "f", "false", "off", "0"):
return False
else:
raise ValueError(f"invalid truth value {value}")


class ResultSetColumnTypes:
COLUMN_UNKNOWN = 0
COLUMN_SCALAR = 1
Expand Down Expand Up @@ -270,7 +285,7 @@ def parse_boolean(self, value):
"""
value = value.decode() if isinstance(value, bytes) else value
try:
scalar = True if strtobool(value) else False
scalar = strtobool(value)
except ValueError:
sys.stderr.write("unknown boolean type\n")
scalar = None
Expand Down

0 comments on commit 9c5e2b7

Please sign in to comment.