Skip to content

Commit

Permalink
Using new begin() syntax of Session for transactional scope handling
Browse files Browse the repository at this point in the history
  • Loading branch information
freol35241 committed Feb 28, 2023
1 parent 301e86c commit 56355ef
Showing 1 changed file with 61 additions and 62 deletions.
123 changes: 61 additions & 62 deletions custom_components/ltss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
EVENT_HOMEASSISTANT_START,
EVENT_HOMEASSISTANT_STOP,
EVENT_STATE_CHANGED,
STATE_UNKNOWN
STATE_UNKNOWN,
)
from homeassistant.components import persistent_notification
from homeassistant.core import CoreState, HomeAssistant, callback
Expand Down Expand Up @@ -56,7 +56,9 @@
DOMAIN: INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA.extend(
{
vol.Required(CONF_DB_URL): cv.string,
vol.Optional(CONF_CHUNK_TIME_INTERVAL, default=2592000000000): cv.positive_int, # 30 days
vol.Optional(
CONF_CHUNK_TIME_INTERVAL, default=2592000000000
): cv.positive_int, # 30 days
}
)
},
Expand Down Expand Up @@ -84,37 +86,15 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
return await instance.async_db_ready


@contextmanager
def session_scope(*, session=None):
"""Provide a transactional scope around a series of operations."""

if session is None:
raise RuntimeError("Session required")

need_rollback = False
try:
yield session
if session.transaction:
need_rollback = True
session.commit()
except Exception as err: # pylint: disable=broad-except
_LOGGER.error("Error executing query: %s", err)
if need_rollback:
session.rollback()
raise
finally:
session.close()


class LTSS_DB(threading.Thread):
"""A threaded LTSS class."""

def __init__(
self,
hass: HomeAssistant,
uri: str,
chunk_time_interval: int,
entity_filter: Callable[[str], bool],
self,
hass: HomeAssistant,
uri: str,
chunk_time_interval: int,
entity_filter: Callable[[str], bool],
) -> None:
"""Initialize the ltss."""
threading.Thread.__init__(self, name="LTSS")
Expand Down Expand Up @@ -158,6 +138,7 @@ def run(self):
tries += 1

if not connected:

@callback
def connection_failed():
"""Connect failed tasks."""
Expand Down Expand Up @@ -222,17 +203,18 @@ def notify_hass_started(event):
if tries != 1:
time.sleep(CONNECT_RETRY_WAIT)
try:
with session_scope(session=self.get_session()) as session:
try:
row = LTSS.from_event(event)
session.add(row)
except (TypeError, ValueError):
_LOGGER.warning(
"State is not JSON serializable: %s",
event.data.get("new_state"),
)

updated = True
with self.get_session() as session:
with session.begin():
try:
row = LTSS.from_event(event)
session.add(row)
except (TypeError, ValueError):
_LOGGER.warning(
"State is not JSON serializable: %s",
event.data.get("new_state"),
)

updated = True

except exc.OperationalError as err:
_LOGGER.error(
Expand Down Expand Up @@ -273,36 +255,49 @@ def _setup_connection(self):
if self.engine is not None:
self.engine.dispose()

self.engine = create_engine(self.db_url, echo=False,
json_serializer=lambda obj: json.dumps(obj, cls=JSONEncoder))
self.engine = create_engine(
self.db_url,
echo=False,
json_serializer=lambda obj: json.dumps(obj, cls=JSONEncoder),
)

inspector = inspect(self.engine)

with self.engine.connect() as con:
con = con.execution_options(isolation_level="AUTOCOMMIT")
available_extensions = {row.name: row.installed_version for row in
con.execute(text("SELECT name, installed_version FROM pg_available_extensions"))}
available_extensions = {
row.name: row.installed_version
for row in con.execute(
text("SELECT name, installed_version FROM pg_available_extensions")
)
}

# create table if necessary
if not inspector.has_table(LTSS.__tablename__):
self._create_table(available_extensions)

if 'timescaledb' in available_extensions:
if "timescaledb" in available_extensions:
# chunk_time_interval can be adjusted even after first setup
try:
con.execute(
text(f"SELECT set_chunk_time_interval('{LTSS.__tablename__}', {self.chunk_time_interval})")
text(
f"SELECT set_chunk_time_interval('{LTSS.__tablename__}', {self.chunk_time_interval})"
)
)
except exc.ProgrammingError as exception:
if isinstance(exception.orig, psycopg2.errors.UndefinedTable):
# The table does exist but is not a hypertable, not much we can do except log that fact
_LOGGER.exception(
"TimescaleDB is available as an extension but the LTSS table is not a hypertable!")
"TimescaleDB is available as an extension but the LTSS table is not a hypertable!"
)
else:
raise

# check if table has been set up with location extraction
if "location" in [column_conf["name"] for column_conf in inspector.get_columns(LTSS.__tablename__)]:
if "location" in [
column_conf["name"]
for column_conf in inspector.get_columns(LTSS.__tablename__)
]:
# activate location extraction in model/ORM
LTSS.activate_location_extraction()

Expand All @@ -315,28 +310,32 @@ def _create_table(self, available_extensions):
_LOGGER.info("Creating LTSS table")
with self.engine.connect() as con:
con = con.execution_options(isolation_level="AUTOCOMMIT")
if 'postgis' in available_extensions:
_LOGGER.info("PostGIS extension is available, activating location extraction...")
con.execute(
text("CREATE EXTENSION IF NOT EXISTS postgis CASCADE"
))
if "postgis" in available_extensions:
_LOGGER.info(
"PostGIS extension is available, activating location extraction..."
)
con.execute(text("CREATE EXTENSION IF NOT EXISTS postgis CASCADE"))

# activate location extraction in model/ORM to add necessary column when calling create_all()
LTSS.activate_location_extraction()

Base.metadata.create_all(self.engine)

if 'timescaledb' in available_extensions:
_LOGGER.info("TimescaleDB extension is available, creating hypertable...")
con.execute(
text("CREATE EXTENSION IF NOT EXISTS timescaledb CASCADE"
))
if "timescaledb" in available_extensions:
_LOGGER.info(
"TimescaleDB extension is available, creating hypertable..."
)
con.execute(text("CREATE EXTENSION IF NOT EXISTS timescaledb CASCADE"))

# Create hypertable
con.execute(text(f"""SELECT create_hypertable(
'{LTSS.__tablename__}',
'time',
if_not_exists => TRUE);"""))
con.execute(
text(
f"""SELECT create_hypertable(
'{LTSS.__tablename__}',
'time',
if_not_exists => TRUE);"""
)
)

def _close_connection(self):
"""Close the connection."""
Expand Down

0 comments on commit 56355ef

Please sign in to comment.