Skip to content

Commit

Permalink
query: Add Query class and begin implementing decode
Browse files Browse the repository at this point in the history
  • Loading branch information
ohsayan committed May 1, 2024
1 parent 5bb7bdc commit 077b7ff
Show file tree
Hide file tree
Showing 7 changed files with 314 additions and 52 deletions.
4 changes: 3 additions & 1 deletion src/skytable_py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .connection import Config, Connection
from .connection import Connection
from .query import Query, UInt, SInt
from .config import Config
62 changes: 62 additions & 0 deletions src/skytable_py/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2024, Sayan Nandan <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from .connection import Connection
from .exception import ClientException


class Config:
def __init__(self, username: str, password: str, host: str = "127.0.0.1", port: int = 2003) -> None:
self._username = username
self._password = password
self._host = host
self._port = port

def get_username(self) -> str:
return self._username

def get_password(self) -> str:
return self._password

def get_host(self) -> str:
return self._host

def get_port(self) -> int:
return self._port

def __hs(self) -> bytes:
return f"H\0\0\0\0\0{len(self.get_username())}\n{len(self.get_password())}\n{self.get_username()}{self.get_password()}".encode()

async def connect(self) -> Connection:
"""
Establish a connection to the database instance using the set configuration.
## Exceptions
Exceptions are raised in the following scenarios:
- If the server responds with a handshake error
- If the server sends an unknown handshake (usually caused by version incompatibility)
"""
reader, writer = await asyncio.open_connection(self.get_host(), self.get_port())
con = Connection(reader, writer)
await con._write_all(self.__hs())
resp = await con._read_exact(4)
a, b, c, d = resp[0], resp[1], resp[2], resp[3]
if resp == b"H\0\0\0":
return con
elif a == ord(b'H') and b == 0 and c == 1:
raise ClientException(f"handshake error {d}")
else:
raise ClientException("unknown handshake")
114 changes: 63 additions & 51 deletions src/skytable_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from asyncio import StreamReader, StreamWriter


class ClientException(Exception):
"""
An exception thrown by this client library
"""
pass
from .query import Query
from .exception import ProtocolException


class Connection:
Expand All @@ -32,6 +26,8 @@ class Connection:
def __init__(self, reader: StreamReader, writer: StreamWriter) -> None:
self._reader = reader
self._writer = writer
self._cursor = 0
self.buffer = bytes()

async def _write_all(self, bytes: bytes):
self._write(bytes)
Expand All @@ -40,6 +36,9 @@ async def _write_all(self, bytes: bytes):
def _write(self, bytes: bytes) -> None:
self._writer.write(bytes)

def __buffer(self) -> bytes:
return self.buffer[:self._cursor]

async def _flush(self):
await self._writer.drain()

Expand All @@ -53,46 +52,59 @@ async def close(self):
self._writer.close()
await self._writer.wait_closed()


class Config:
def __init__(self, username: str, password: str, host: str = "127.0.0.1", port: int = 2003) -> None:
self._username = username
self._password = password
self._host = host
self._port = port

def get_username(self) -> str:
return self._username

def get_password(self) -> str:
return self._password

def get_host(self) -> str:
return self._host

def get_port(self) -> int:
return self._port

def __hs(self) -> bytes:
return f"H\0\0\0\0\0{len(self.get_username())}\n{len(self.get_password())}\n{self.get_username()}{self.get_password()}".encode()

async def connect(self) -> Connection:
"""
Establish a connection to the database instance using the set configuration.
## Exceptions
Exceptions are raised in the following scenarios:
- If the server responds with a handshake error
- If the server sends an unknown handshake (usually caused by version incompatibility)
"""
reader, writer = await asyncio.open_connection(self.get_host(), self.get_port())
con = Connection(reader, writer)
await con._write_all(self.__hs())
resp = await con._read_exact(4)
a, b, c, d = resp[0], resp[1], resp[2], resp[3]
if resp == b"H\0\0\0":
return con
elif a == ord(b'H') and b == 0 and c == 1:
raise ClientException(f"handshake error {d}")
else:
raise ClientException("unknown handshake")
def __parse_string(self) -> None | str:
strlen = self.__parse_int()
if strlen:
if len(self.__buffer()) >= strlen:
strlen = self.__buffer()[:strlen].decode()
self._cursor += strlen
return strlen

def __parse_binary(self) -> None | bytes:
binlen = self.__parse_int()
if binlen:
if len(self.__buffer()) >= binlen:
binlen = self.__buffer()[:binlen].decode()
self._cursor += binlen
return binlen

def __parse_int(self) -> None | int:
i = 0
strlen = 0
stop = False
buffer = self.__buffer()

while i < len(buffer) and not stop:
digit = None
if 48 <= buffer[i] <= 57:
digit = buffer[i] - 48

if digit is not None:
strlen = (10 * strlen) + digit
i += 1
else:
raise ProtocolException("invalid response from server")

if i < len(buffer) and buffer[i] == ord(b'\n'):
stop = True
i += 1

if stop:
self._cursor += i
self._cursor += 1 # for LF
return strlen

async def run_simple_query(self, query: Query):
query_window_str = str(len(query._q_window))
total_packet_size = len(query_window_str) + 1 + len(query._buffer)
# write metaframe
metaframe = f"S{str(total_packet_size)}\n{query_window_str}\n"
await self._write_all(metaframe.encode())
# write dataframe
await self._write_all(query._buffer)
# now enter read loop
while True:
read = await self._reader.read(1024)
if len(read) == 0:
raise ConnectionResetError
self.buffer = self.buffer + read
28 changes: 28 additions & 0 deletions src/skytable_py/exception.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2024, Sayan Nandan <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# See the License for the specific language governing permissions and
# limitations under the License.


class ClientException(Exception):
"""
An exception thrown by this client library
"""
pass


class ProtocolException(ClientException):
"""
An exception thrown by the protocol
"""
pass
76 changes: 76 additions & 0 deletions src/skytable_py/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright 2024, Sayan Nandan <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
# internal
from .exception import ClientException


class Query:
def __init__(self, query: str, *argv) -> None:
self._buffer = query.encode()
self._param_cnt = 0
self._q_window = len(self._buffer)
for param in argv:
self.add_param(param)

def add_param(self, param: any) -> None:
payload, param_cnt = encode_parameter(param)
self._param_cnt += param_cnt
self._buffer = self._buffer + payload

def get_param_count(self) -> int:
return self._param_cnt


class SkyhashParameter(ABC):
def encode_self(self) -> tuple[bytes, int]: pass


class UInt(SkyhashParameter):
def __init__(self, v: int) -> None:
if v < 0:
raise ClientException("unsigned int can't be negative")
self.v = v

def encode_self(self) -> tuple[bytes, int]:
return (f"\x02{self.v}\n".encode(), 1)


class SInt(SkyhashParameter):
def __init__(self, v: int) -> None:
self.v = v

def encode_self(self) -> tuple[bytes, int]:
return (f"\x03{self.v}\n".encode(), 1)


def encode_parameter(parameter: any) -> tuple[bytes, int]:
encoded = None
if isinstance(parameter, SkyhashParameter):
return parameter.encode_self()
elif parameter is None:
encoded = "\0".encode()
elif isinstance(parameter, bool):
encoded = f"\1{1 if parameter else 0}".encode()
elif isinstance(parameter, float):
encoded = f"\x04{parameter}\n".encode()
elif isinstance(parameter, bytes):
encoded = f"\x05{len(parameter)}\n".encode() + parameter
elif isinstance(parameter, str):
encoded = f"\x06{len(parameter)}\n{parameter}".encode()
else:
raise ClientException("unsupported type")
return (encoded, 1)
54 changes: 54 additions & 0 deletions tests/test_encoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2024, Sayan Nandan <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from src.skytable_py.query import encode_parameter, UInt, SInt
from src.skytable_py.exception import ClientException


class TestConfig(unittest.TestCase):
def test_encode_null(self):
self.assertEqual(encode_parameter(None), (b"\0", 1))

def test_encode_bool(self):
self.assertEqual(encode_parameter(False), (b"\x010", 1))
self.assertEqual(encode_parameter(True), (b"\x011", 1))

def test_encode_uint(self):
self.assertEqual(encode_parameter(UInt(1234)), (b"\x021234\n", 1))

def test_encode_sint(self):
self.assertEqual(encode_parameter(SInt(-1234)), (b"\x03-1234\n", 1))

def test_encode_float(self):
self.assertEqual(encode_parameter(3.141592654),
(b"\x043.141592654\n", 1))

def test_encode_bin(self):
self.assertEqual(encode_parameter(b"binary"), (b"\x056\nbinary", 1))

def test_encode_str(self):
self.assertEqual(encode_parameter("string"), (b"\x066\nstring", 1))

def test_int_causes_exception(self):
try:
encode_parameter(1234)
except ClientException as e:
if str(e) == "unsupported type":
pass
else:
self.fail(f"expected 'unsupported type' but got '{e}'")
else:
self.fail("expected exception but no exception was raised")
Loading

0 comments on commit 077b7ff

Please sign in to comment.