Commit

Merge pull request #6 from Point72/tkp/sqla2

Support sqlalchemy>=2

timkpaine authored Feb 7, 2024
2 parents edabf80 + 01b199d commit a5a59d3
Showing 8 changed files with 161 additions and 34 deletions.
66 changes: 66 additions & 0 deletions .github/workflows/build.yml
@@ -540,3 +540,69 @@ jobs:
# run: make test

##########################################################################################################################


#################################
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
#~~~~~~~~~|##########|~~~~~~~~~~#
#~~~~~~~~~|##|~~~~~~~~~~~~~~~~~~#
#~~~~~~~~~|##|~~~~~~~~~~~~~~~~~~#
#~~~~~~~~~|##########|~~~~~~~~~~#
#~~~~~~~~~|##|~~~~|##|~~~~~~~~~~#
#~~~~~~~~~|##|~~~~|##|~~~~~~~~~~#
#~~~~~~~~~|##########|~~~~~~~~~~#
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
# Test Dependencies/Regressions #
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
test_dependencies:
needs:
- initialize
- build

strategy:
matrix:
os:
- ubuntu-20.04
python-version:
- 3.9
package:
- "sqlalchemy>=2"
- "sqlalchemy<2"

runs-on: ${{ matrix.os }}

steps:
- name: Checkout
uses: actions/checkout@v4
with:
submodules: recursive

- name: Set up Python ${{ matrix.python-version }}
uses: ./.github/actions/setup-python
with:
version: '${{ matrix.python-version }}'

- name: Install python dependencies
run: make requirements

- name: Install test dependencies
shell: bash
run: sudo apt-get install graphviz

# Download artifact
- name: Download wheel
uses: actions/download-artifact@v4
with:
name: csp-dist-${{ runner.os }}-${{ runner.arch }}-${{ matrix.python-version }}

- name: Install wheel
run: python -m pip install -U *manylinux2014*.whl --target .

- name: Install package - ${{ matrix.package }}
run: python -m pip install -U "${{ matrix.package }}"

# Run tests
- name: Python Test Steps
run: make test TEST_ARGS="-k TestDBReader"
if: ${{ contains(matrix.package, 'sqlalchemy') }}
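
The test_dependencies job above installs the built wheel, pins either sqlalchemy>=2 or sqlalchemy<2, and reruns the DB adapter tests. To confirm which API line a given matrix cell actually exercises, the same version check the adapter performs at import time can be run by hand; a minimal sketch, assuming the importlib.metadata and packaging modules used in csp/adapters/db.py below are available:

# Illustrative check only, not part of the workflow itself
from importlib.metadata import version as get_package_version
from packaging import version

installed = version.parse(get_package_version("sqlalchemy"))
print(f"sqlalchemy {installed}: {'2.x API' if installed >= version.parse('2') else '1.4-style API'}")
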
7 changes: 4 additions & 3 deletions Makefile
@@ -5,7 +5,7 @@ EXTRA_ARGS :=
#########
.PHONY: develop build-py build install

requirements: ## install python dev dependnecies
requirements: ## install python dev and runtime dependencies
python -m pip install toml
python -m pip install `python -c 'import toml; c = toml.load("pyproject.toml"); print("\n".join(c["build-system"]["requires"]))'`
python -m pip install `python -c 'import toml; c = toml.load("pyproject.toml"); print("\n".join(c["project"]["optional-dependencies"]["develop"]))'`
@@ -64,11 +64,12 @@ checks: check
#########
.PHONY: test-py coverage-py test tests

TEST_ARGS :=
test-py: ## Clean and Make unit tests
python -m pytest -v csp/tests --junitxml=junit.xml
python -m pytest -v csp/tests --junitxml=junit.xml $(TEST_ARGS)

coverage-py:
python -m pytest -v csp/tests --junitxml=junit.xml --cov=csp --cov-report xml --cov-report html --cov-branch --cov-fail-under=80 --cov-report term-missing
python -m pytest -v csp/tests --junitxml=junit.xml --cov=csp --cov-report xml --cov-report html --cov-branch --cov-fail-under=80 --cov-report term-missing $(TEST_ARGS)

test: test-py ## run the tests

72 changes: 59 additions & 13 deletions csp/adapters/db.py
@@ -9,14 +9,28 @@
from backports import zoneinfo

import pytz
import sqlalchemy as db
from importlib.metadata import PackageNotFoundError, version as get_package_version
from packaging import version

from csp import PushMode, ts
from csp.impl.adaptermanager import AdapterManagerImpl, ManagedSimInputAdapter
from csp.impl.wiring import py_managed_adapter_def

UTC = zoneinfo.ZoneInfo("UTC")

try:
if version.parse(get_package_version("sqlalchemy")) >= version.parse("2"):
_SQLALCHEMY_2 = True
else:
_SQLALCHEMY_2 = False

import sqlalchemy as db

_HAS_SQLALCHEMY = True
except (PackageNotFoundError, ValueError, TypeError, ImportError):
_HAS_SQLALCHEMY = False
db = None


class TimeAccessor(ABC):
@abstractmethod
@@ -185,6 +199,8 @@ def __init__(
:param log_query: set to True to see what query was generated to access the data
:param use_raw_user_query: Don't do any alteration to user query, assume it contains all the needed columns and sorting
"""
if not _HAS_SQLALCHEMY:
raise RuntimeError("Could not find SQLAlchemy installation")
self._connection = connection
self._table_name = table_name
self._schema_name = schema_name
@@ -248,7 +264,7 @@ def schema_struct(self):
name = "DBDynStruct_{table}_{schema}".format(table=self._table_name or "", schema=self._schema_name or "")
if name not in globals():
db_metadata = db.MetaData(schema=self._schema_name)
table = db.Table(self._table_name, db_metadata, autoload=True, autoload_with=self._connection)
table = db.Table(self._table_name, db_metadata, autoload_with=self._connection)
struct_metadata = {col: col_obj.type.python_type for col, col_obj in table.columns.items()}

from csp.impl.struct import defineStruct
@@ -301,23 +317,44 @@ def __init__(self, engine, adapterRep):
self._row = None

def start(self, starttime, endtime):
query = self.build_query(starttime, endtime)
self._query = self.build_query(starttime, endtime)
if self._rep._log_query:
import logging

logging.info("DBReader query: %s", query)
self._q = self._rep._connection.execute(query)
logging.info("DBReader query: %s", self._query)
if _SQLALCHEMY_2:
self._data_yielder = self._data_yielder_function()
else:
self._q = self._rep._connection.execute(self._query)

def _data_yielder_function(self):
# Connection yielder for SQLAlchemy 2
with self._rep._connection.connect() as conn:
for result in conn.execute(self._query).mappings():
yield result
# Signify the end
yield None

def build_query(self, starttime, endtime):
if self._rep._table_name:
metadata = db.MetaData(schema=self._rep._schema_name)
table = db.Table(self._rep._table_name, metadata, autoload=True, autoload_with=self._rep._connection)
cols = [table.c[colname] for colname in self._rep._requested_cols]
q = db.select(cols)

if _SQLALCHEMY_2:
table = db.Table(self._rep._table_name, metadata, autoload_with=self._rep._connection)
cols = [table.c[colname] for colname in self._rep._requested_cols]
q = db.select(*cols)
else:
table = db.Table(self._rep._table_name, metadata, autoload=True, autoload_with=self._rep._connection)
cols = [table.c[colname] for colname in self._rep._requested_cols]
q = db.select(cols)

elif self._rep._use_raw_user_query:
return db.text(self._rep._query)
else: # self._rep._query
from_obj = db.text(f"({self._rep._query}) AS user_query")
if _SQLALCHEMY_2:
from_obj = db.text(f"({self._rep._query})")
else:
from_obj = db.text(f"({self._rep._query}) AS user_query")

time_columns = self._rep._time_accessor.get_time_columns(self._rep._connection)
if time_columns:
@@ -330,7 +367,11 @@ def build_query(self, starttime, endtime):
time_columns = []
time_select = []
select_cols = [db.column(colname) for colname in self._rep._requested_cols.difference(set(time_columns))]
q = db.select(select_cols + time_select, from_obj=from_obj)

if _SQLALCHEMY_2:
q = db.select(*(select_cols + time_select)).select_from(from_obj)
else:
q = db.select(select_cols + time_select, from_obj=from_obj)

cond = self._rep._time_accessor.get_time_constraint(starttime.replace(tzinfo=UTC), endtime.replace(tzinfo=UTC))

@@ -361,16 +402,21 @@ def register_input_adapter(self, symbol, adapter):

def process_next_sim_timeslice(self, now):
if self._row is None:
self._row = self._q.fetchone()
if _SQLALCHEMY_2:
self._row = next(self._data_yielder)
else:
self._row = self._q.fetchone()

now = now.replace(tzinfo=UTC)
while self._row is not None:
time = self._rep._time_accessor.get_time(self._row)
if time > now:
return time
self.process_row(self._row)
self._row = self._q.fetchone()

if _SQLALCHEMY_2:
self._row = next(self._data_yielder)
else:
self._row = self._q.fetchone()
return None

def process_row(self, row):
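The bulk of the adapter change tracks SQLAlchemy 2's calling conventions: reflected tables use autoload_with= instead of autoload=True, select() takes columns as positional arguments rather than a list, and results are read through an explicit connection whose .mappings() view yields dict-like rows. A minimal side-by-side sketch of the two styles, using a toy trades table that is purely illustrative (not the adapter's own schema):

import sqlalchemy as db

engine = db.create_engine("sqlite:///:memory:")
metadata = db.MetaData()
trades = db.Table("trades", metadata, db.Column("symbol", db.String), db.Column("price", db.Float))
metadata.create_all(engine)
cols = [trades.c.symbol, trades.c.price]

# SQLAlchemy 1.4 style (removed in 2.x):
#   query = db.select(cols)                  # columns passed as a list
#   rows = engine.execute(query).fetchall()  # implicit execution on the Engine

# SQLAlchemy 2.x style, matching the new _SQLALCHEMY_2 branches above:
query = db.select(*cols)                        # columns passed positionally
with engine.connect() as conn:                  # Engine.execute() no longer exists
    for row in conn.execute(query).mappings():  # dict-like rows keyed by column name
        print(row["symbol"], row["price"])
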
13 changes: 8 additions & 5 deletions csp/adapters/output_adapters/parquet.py
@@ -1,6 +1,6 @@
import numpy
import os
import pkg_resources
from importlib.metadata import PackageNotFoundError, version as get_package_version
from packaging import version
from typing import Callable, Dict, Optional, TypeVar

@@ -37,10 +37,13 @@ def resolve_compression(self):


def _get_default_parquet_version():
if version.parse(pkg_resources.get_distribution("pyarrow").version) >= version.parse("6.0.1"):
return "2.6"
else:
return "2.0"
try:
if version.parse(get_package_version("pyarrow")) >= version.parse("6.0.1"):
return "2.6"
except PackageNotFoundError:
# Don't need to do anything in particular
...
return "2.0"


class ParquetWriter:
6 changes: 3 additions & 3 deletions csp/adapters/parquet.py
@@ -1,10 +1,10 @@
import datetime
import io
import numpy
import pkg_resources
import platform
import pyarrow
import pyarrow.parquet
from importlib.metadata import PackageNotFoundError, version as get_package_version
from packaging import version
from typing import TypeVar

@@ -28,9 +28,9 @@

try:
_CAN_READ_ARROW_BINARY = False
if version.parse(pkg_resources.get_distribution("pyarrow").version) >= version.parse("4.0.1"):
if version.parse(get_package_version("pyarrow")) >= version.parse("4.0.1"):
_CAN_READ_ARROW_BINARY = True
except (ValueError, TypeError):
except (PackageNotFoundError, ValueError, TypeError):
# Cannot read binary arrow
...

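Both parquet modules make the same swap: the deprecated pkg_resources.get_distribution(...).version becomes importlib.metadata.version(...), which raises PackageNotFoundError when the distribution is missing, hence the broadened except clauses. A hedged sketch of the pattern on its own (the helper name is illustrative, not part of csp):

from importlib.metadata import PackageNotFoundError, version as get_package_version
from packaging import version

def pyarrow_at_least(minimum):
    # True only when pyarrow is installed and its version is >= minimum
    try:
        return version.parse(get_package_version("pyarrow")) >= version.parse(minimum)
    except PackageNotFoundError:
        return False  # pyarrow is not installed at all

print(pyarrow_at_least("6.0.1"))
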
27 changes: 19 additions & 8 deletions csp/tests/adapters/test_db.py
@@ -5,7 +5,7 @@
from datetime import date, datetime, time

import csp
from csp.adapters.db import DateTimeAccessor, DBReader, EngineStartTimeAccessor, TimestampAccessor
from csp.adapters.db import _SQLALCHEMY_2, DateTimeAccessor, DBReader, EngineStartTimeAccessor, TimestampAccessor


class PriceQuantity(csp.Struct):
@@ -21,6 +21,15 @@ class PriceQuantity2(csp.Struct):
side: str


def execute_with_commit(engine, query, values):
if _SQLALCHEMY_2:
with engine.connect() as conn:
conn.execute(query, values)
conn.commit()
else:
engine.execute(query, values)


class TestDBReader(unittest.TestCase):
def _prepopulate_in_mem_engine(self):
engine = db.create_engine("sqlite:///:memory:") # in-memory sqlite db
@@ -46,7 +55,7 @@ def _prepopulate_in_mem_engine(self):
{"TIME": starttime.replace(second=5), "SYMBOL": "AAPL", "PRICE": 200.0, "SIZE": 400, "SIDE": "BUY"},
{"TIME": starttime.replace(second=6), "SYMBOL": "GM", "PRICE": 2.0, "SIZE": 1, "SIDE": "BUY"},
]
engine.execute(query, values_list)
execute_with_commit(engine, query, values_list)
return engine

def test_sqlite_basic(self):
@@ -92,7 +101,7 @@ def graph():

# UTC
result = csp.run(graph, starttime=datetime(2020, 3, 3, 9, 30))
print(result)

self.assertEqual(len(result["aapl"]), 4)
self.assertTrue(all(v[1].SYMBOL == "AAPL" for v in result["aapl"]))

@@ -211,7 +220,8 @@ def test_sqlite_constraints(self):
"SIDE": "BUY",
},
]
engine.execute(query, values_list)

execute_with_commit(engine, query, values_list)

def graph():
time_accessor = DateTimeAccessor(date_column="DATE", time_column="TIME", tz=pytz.timezone("US/Eastern"))
@@ -310,7 +320,7 @@ def test_join_query(self):
{"TIME": starttime.replace(second=5), "SYMBOL": "AAPL", "PRICE": 200.0},
{"TIME": starttime.replace(second=6), "SYMBOL": "GM", "PRICE": 2.0},
]
engine.execute(query, values_list1)
execute_with_commit(engine, query, values_list1)

query = db.insert(test2)
values_list2 = [
@@ -322,15 +332,15 @@
# { 'TIME': starttime.replace( second = 5 ), 'SIZE': 400, 'SIDE': 'BUY' },
{"TIME": starttime.replace(second=6), "SIZE": 1, "SIDE": "BUY"},
]
engine.execute(query, values_list2)
execute_with_commit(engine, query, values_list2)

metadata.create_all(engine)

def graph():
time_accessor = TimestampAccessor(time_column="TIME", tz=pytz.timezone("US/Eastern"))
query = "select * from test1 inner join test2 on test2.TIME=test1.TIME"
reader = DBReader.create_from_connection(
connection=engine.connect(), query=query, time_accessor=time_accessor, symbol_column="SYMBOL"
connection=engine, query=query, time_accessor=time_accessor, symbol_column="SYMBOL"
)

# Struct
@@ -414,7 +424,8 @@ def test_DateTimeAccessor(self):
(datetime(2020, 3, 5, 12), 700.0),
]
values_list = [{"DATE": v[0].date(), "TIME": v[0].time(), "SYMBOL": "AAPL", "PRICE": v[1]} for v in values]
engine.execute(query, values_list)

execute_with_commit(engine, query, values_list)

def graph():
time_accessor = DateTimeAccessor(date_column="DATE", time_column="TIME", tz=pytz.timezone("US/Eastern"))
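The execute_with_commit helper exists because SQLAlchemy 2 removed both Engine.execute() and connection-level autocommit: inserts now need an explicit Connection and an explicit commit(). A minimal sketch of the same idea outside the test suite, with a hypothetical events table:

import sqlalchemy as db

engine = db.create_engine("sqlite:///:memory:")
metadata = db.MetaData()
events = db.Table("events", metadata, db.Column("id", db.Integer), db.Column("name", db.String))
metadata.create_all(engine)
rows = [{"id": 1, "name": "open"}, {"id": 2, "name": "close"}]

# SQLAlchemy 1.4: engine.execute(db.insert(events), rows)   # autocommitted implicitly
# SQLAlchemy 2.x: explicit connection and commit, as in execute_with_commit above
with engine.connect() as conn:
    conn.execute(db.insert(events), rows)
    conn.commit()  # without this the insert is rolled back when the connection closes
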
2 changes: 1 addition & 1 deletion dev-environment.yml
@@ -11,7 +11,7 @@ dependencies:
- ruamel.yaml
- scikit-build
- psutil
- sqlalchemy<2
- sqlalchemy
- bump2version>=1.0.0
- python-graphviz
- httpx
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -64,7 +64,7 @@ develop = [
"pytest-cov",
"pytest-sugar",
"scikit-build",
"sqlalchemy<2",
"sqlalchemy",
"tornado",
]

