Skip to content

Commit

Permalink
Add session setting support for specs (#381)
Browse files Browse the repository at this point in the history
This adds support for session settings for specs. The session settings
will be applied to each connection before executing the spec.

Co-authored-by: Mathias Fussenegger <[email protected]>
  • Loading branch information
mkleen and mfussenegger authored Sep 9, 2024
1 parent 9d07688 commit d910d77
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 40 deletions.
4 changes: 3 additions & 1 deletion cr8/bench_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ def from_dict(d):


class Spec:
def __init__(self, setup, teardown, queries=None, load_data=None, meta=None):
def __init__(self, setup, teardown, queries=None, load_data=None, meta=None, session_settings=None):
self.setup = setup
self.teardown = teardown
self.queries = queries
self.load_data = load_data
self.meta = meta or {}
self.session_settings = session_settings or {}

@staticmethod
def from_dict(d):
Expand All @@ -45,6 +46,7 @@ def from_dict(d):
meta=d.get('meta', {}),
queries=d.get('queries', []),
load_data=d.get('load_data', []),
session_settings=d.get('session_settings', {}),
)

@staticmethod
Expand Down
90 changes: 61 additions & 29 deletions cr8/clients.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import json

import aiohttp
import itertools
import calendar
import types
import time
from urllib.parse import urlparse, parse_qs, urlunparse
from datetime import datetime, date
from typing import List, Union, Iterable
from typing import List, Union, Iterable, Dict
from decimal import Decimal
from cr8.aio import asyncio # import via aio for uvloop setup

Expand Down Expand Up @@ -216,18 +217,25 @@ def _verify_ssl_from_first(hosts):


class AsyncpgClient:
def __init__(self, hosts, pool_size=25):
def __init__(self, hosts, pool_size=25, session_settings=None):
self.dsn = _to_dsn(hosts)
self.pool_size = pool_size
self._pool = None
self.is_cratedb = True
self.session_settings = session_settings or {}

async def _get_pool(self):

async def set_session_settings(conn):
for setting, value in self.session_settings.items():
await conn.execute(f'set {setting}={value}')

if not self._pool:
self._pool = await asyncpg.create_pool(
self.dsn,
min_size=self.pool_size,
max_size=self.pool_size
max_size=self.pool_size,
init=set_session_settings
)
return self._pool

Expand Down Expand Up @@ -308,59 +316,83 @@ def _append_sql(host):


class HttpClient:
def __init__(self, hosts, conn_pool_limit=25):
def __init__(self, hosts, conn_pool_limit=25, session_settings=None):
self.hosts = hosts
self.urls = itertools.cycle(list(map(_append_sql, hosts)))
self._connector_params = {
'limit': conn_pool_limit,
'verify_ssl': _verify_ssl_from_first(self.hosts)
}
self.__session = None
self.conn_pool_limit = conn_pool_limit
self.is_cratedb = True

@property
async def _session(self):
session = self.__session
if session is None:
conn = aiohttp.TCPConnector(**self._connector_params)
self.__session = session = aiohttp.ClientSession(connector=conn)
return session
self._pools = {}
self.session_settings = session_settings or {}

async def _session(self, url):
    """Check out a client session from the pool that belongs to ``url``.

    The pool for a URL is created on first use. It is filled with
    ``conn_pool_limit`` sessions, each using its own connector created
    with ``limit=1``. Every entry of ``self.session_settings`` is sent to
    ``url`` as a ``set <setting>=<value>`` statement on each session
    before that session is made available.

    The caller is responsible for handing the session back by putting it
    into ``self._pools[url]`` once it is done with it.

    :param url: endpoint the ``set`` statements are sent to; also the key
        of the pool in ``self._pools``.
    :return: a session taken from the pool; waits until one is free.
    :raises: whatever ``_exec`` raises while applying a session setting.
        The session that failed is closed before the error propagates.
    """
    pool = self._pools.get(url)
    if not pool:
        pool = asyncio.Queue(maxsize=self.conn_pool_limit)
        # The pool is registered before it is filled.
        # NOTE(review): presumably so that concurrent callers arriving
        # while settings are being applied wait on this pool instead of
        # building a second one - confirm.
        self._pools[url] = pool
        connector_params = {
            'limit': 1,
            'verify_ssl': _verify_ssl_from_first(self.hosts)
        }
        # The statements are identical for every session, so serialize
        # them once instead of once per session.
        set_statements = [
            dumps({'stmt': f'set {setting}={value}'}, cls=CrateJsonEncoder)
            for setting, value in self.session_settings.items()
        ]
        for _ in range(self.conn_pool_limit):
            tcp_connector = aiohttp.TCPConnector(**connector_params)
            session = aiohttp.ClientSession(connector=tcp_connector)
            try:
                for data in set_statements:
                    await _exec(session, url, data)
            except BaseException:
                # The session never reaches the pool, so _close() would
                # not find it; close it here to avoid leaking it.
                await session.close()
                raise
            pool.put_nowait(session)

    return await pool.get()

async def execute(self, stmt, args=None):
payload = {'stmt': _plain_or_callable(stmt)}
if args:
payload['args'] = _plain_or_callable(args)
session = await self._session
return await _exec(
url = next(self.urls)
session = await self._session(url)
result = await _exec(
session,
next(self.urls),
url,
dumps(payload, cls=CrateJsonEncoder)
)
await self._pools[url].put(session)
return result

async def execute_many(self, stmt, bulk_args):
data = dumps(dict(
stmt=_plain_or_callable(stmt),
bulk_args=_plain_or_callable(bulk_args)
), cls=CrateJsonEncoder)
session = await self._session
return await _exec(session, next(self.urls), data)
url = next(self.urls)
session = await self._session(url)
result = await _exec(session, url, data)
await self._pools[url].put(session)
return result

async def get_server_version(self):
session = await self._session
urlparts = urlparse(self.hosts[0])
url = urlunparse((urlparts.scheme, urlparts.netloc, '/', '', '', ''))
session = await self._session(url)
async with session.get(url) as resp:
r = await resp.json()
version = r['version']
return {
result = {
'hash': version['build_hash'],
'number': version['number'],
'date': _date_or_none(version['build_timestamp'][:10])
}
await self._pools[url].put(session)
return result

async def _close(self):
if self.__session:
await self.__session.close()
for url, pool in self._pools.items():
while not pool.empty():
session = await pool.get()
await session.close()
self._pools = {}

def close(self):
asyncio.get_event_loop().run_until_complete(self._close())
Expand All @@ -372,10 +404,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.close()


def client(hosts, concurrency=25):
def client(hosts, session_settings=None, concurrency=25):
hosts = hosts or 'localhost:4200'
if hosts.startswith('asyncpg://'):
if not asyncpg:
raise ValueError('Cannot use "asyncpg" scheme if asyncpg is not available')
return AsyncpgClient(hosts, pool_size=concurrency)
return HttpClient(_to_http_hosts(hosts), conn_pool_limit=concurrency)
return AsyncpgClient(hosts, pool_size=concurrency, session_settings=session_settings)
return HttpClient(_to_http_hosts(hosts), conn_pool_limit=concurrency, session_settings=session_settings)
7 changes: 3 additions & 4 deletions cr8/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

from cr8 import aio
from cr8.metrics import Stats, get_sampler
from cr8.clients import client

from cr8.clients import client, HttpClient

TimedStats = namedtuple('TimedStats', ['started', 'ended', 'stats'])

Expand Down Expand Up @@ -69,9 +68,9 @@ def _generate_statements(stmt, args, iterations, duration):


class Runner:
def __init__(self, hosts, concurrency, sample_mode):
def __init__(self, hosts, concurrency, sample_mode, session_settings=None):
self.concurrency = concurrency
self.client = client(hosts, concurrency=concurrency)
self.client = client(hosts, session_settings=session_settings, concurrency=concurrency)
self.sampler = get_sampler(sample_mode)

def warmup(self, stmt, num_warmup, concurrency=0, args=None):
Expand Down
6 changes: 3 additions & 3 deletions cr8/run_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def _skip_message(self, min_version, stmt):
server_version='.'.join((str(x) for x in self.server_version)))
return msg

def run_queries(self, queries: Iterable[dict], meta=None):
def run_queries(self, queries: Iterable[dict], meta=None, session_settings=None):
for query in queries:
stmt = query['statement']
iterations = query.get('iterations', 1)
Expand All @@ -204,7 +204,7 @@ def run_queries(self, queries: Iterable[dict], meta=None):
f' Concurrency: {concurrency}\n'
f' {mode_desc}: {duration or iterations}')
)
with Runner(self.benchmark_hosts, concurrency, self.sample_mode) as runner:
with Runner(self.benchmark_hosts, concurrency, self.sample_mode, session_settings) as runner:
if warmup > 0:
runner.warmup(stmt, warmup, concurrency, args)
timed_stats = runner.run(
Expand Down Expand Up @@ -266,7 +266,7 @@ def do_run_spec(spec,
queries = (q for q in spec.queries if 'name' in q and rex.match(q['name']))
else:
queries = spec.queries
executor.run_queries(queries, spec.meta)
executor.run_queries(queries, spec.meta, spec.session_settings)
finally:
if not action or 'teardown' in action:
log.info('# Running tearDown')
Expand Down
4 changes: 4 additions & 0 deletions specs/count_countries.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
}
]
},
"session_settings": {
"application_name": "my_app",
"timezone": "UTC"
},
"queries": [{
"iterations": 1000,
"statement": "select count(*) from countries"
Expand Down
4 changes: 2 additions & 2 deletions specs/sample.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

from itertools import count
from cr8.bench_spec import Spec, Instructions

Expand All @@ -21,4 +20,5 @@ def queries():
setup=Instructions(statements=["create table t (x int)"]),
teardown=Instructions(statements=["drop table t"]),
queries=queries(),
)
session_settings={'application_name': 'my_app', 'timezone': 'UTC'}
)
4 changes: 4 additions & 0 deletions specs/sample.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ statement_files = ["sql/create_countries.sql"]
target = "countries"
cmd = ['echo', '{"capital": "Demo"}']

[session_settings]
application_name = 'my_app'
timezone = 'UTC'

[[queries]]
name = "count countries" # Gives the query a name, which makes the results easier to analyze
statement = "select count(*) from countries"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def parse(self, string, name='<string>'):
class SourceBuildTest(TestCase):

def test_build_from_branch(self):
self.assertIsNotNone(get_crate('4.1'))
self.assertIsNotNone(get_crate('5.8'))


def load_tests(loader, tests, ignore):
Expand Down
30 changes: 30 additions & 0 deletions tests/test_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os
from unittest import TestCase
from doctest import DocTestSuite

from cr8.bench_spec import load_spec

from cr8 import engine


class SpecTest(TestCase):
    """Verify that session settings are loaded from the sample specs.

    One test each covers the Python, TOML and JSON sample files found in
    the ``specs`` directory.
    """

    def test_session_settings_from_spec(self):
        expected = {'application_name': 'my_app', 'timezone': 'UTC'}
        loaded = self.get_spec('sample.py')
        self.assertEqual(loaded.session_settings, expected)

    def test_session_settings_from_toml(self):
        expected = {'application_name': 'my_app', 'timezone': 'UTC'}
        loaded = self.get_spec('sample.toml')
        self.assertEqual(loaded.session_settings, expected)

    def test_session_settings_from_json(self):
        expected = {'application_name': 'my_app', 'timezone': 'UTC'}
        loaded = self.get_spec('count_countries.json')
        self.assertEqual(loaded.session_settings, expected)

    def get_spec(self, name):
        """Load the spec file ``name`` from ``../specs/`` relative to this module."""
        here = os.path.dirname(__file__)
        spec_path = os.path.abspath(os.path.join(here, '../specs/', name))
        return load_spec(spec_path)


def load_tests(loader, tests, ignore):
    """unittest ``load_tests`` hook: add the doctests of ``engine`` to ``tests``.

    :param loader: unused.
    :param tests: suite the doctest suite is added to.
    :param ignore: unused.
    :return: ``tests``, now including the doctests.
    """
    engine_doctests = DocTestSuite(engine)
    tests.addTests(engine_doctests)
    return tests

0 comments on commit d910d77

Please sign in to comment.