Skip to content

Commit

Permalink
Move connection to ConnectionStore class
Browse files Browse the repository at this point in the history
  • Loading branch information
bhirsz committed Nov 10, 2023
1 parent 32722b1 commit c4f2a72
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 50 deletions.
2 changes: 1 addition & 1 deletion src/DatabaseLibrary/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def table_must_exist(
| Table Must Exist | first_name | msg=my error message |
"""
logger.info(f"Executing : Table Must Exist | {tableName}")
db_connection = self._get_connection_with_alias(alias)
db_connection = self.connection_store.get_connection(alias)
if db_connection.module_name in ["cx_Oracle", "oracledb"]:
query = (
"SELECT * FROM all_objects WHERE object_type IN ('TABLE','VIEW') AND "
Expand Down
102 changes: 60 additions & 42 deletions src/DatabaseLibrary/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,66 @@ class Connection:
module_name: str


class ConnectionManager:
"""
Connection Manager handles the connection & disconnection to the database.
"""

class ConnectionStore:
def __init__(self):
self.omit_trailing_semicolon: bool = False
self._connections: Dict[str, Connection] = {}
self.default_alias: str = "default"

def _register_connection(self, client: Any, module_name: str, alias: str):
def register_connection(self, client: Any, module_name: str, alias: str):
if alias in self._connections:
if alias == self.default_alias:
logger.warn("Overwriting not closed connection.")
else:
logger.warn(f"Overwriting not closed connection for alias = '{alias}'")
self._connections[alias] = Connection(client, module_name)

def get_connection(self, alias: Optional[str]):
"""
Return connection with given alias.
If alias is not provided, it will return default connection.
If there is no default connection, it will return last opened connection.
"""
if not self._connections:
raise ValueError(f"No database connection is open.")
if not alias:
if self.default_alias in self._connections:
return self._connections[self.default_alias]
return list(self._connections.values())[-1]
if alias not in self._connections:
raise ValueError(f"Alias '{alias}' not found in existing connections.")
return self._connections[alias]

def pop_connection(self, alias: Optional[str]):
if not self._connections:
return None
if not alias:
alias = self.default_alias
if alias not in self._connections:
alias = list(self._connections.keys())[-1]
return self._connections.pop(alias, None)

def clear(self):
self._connections = {}

def switch(self, alias: str):
if alias not in self._connections:
raise ValueError(f"Alias '{alias}' not found in existing connections.")
self.default_alias = alias

def __iter__(self):
return iter(self._connections.values())


class ConnectionManager:
"""
Connection Manager handles the connection & disconnection to the database.
"""

def __init__(self):
self.omit_trailing_semicolon: bool = False
self.connection_store: ConnectionStore = ConnectionStore()

def connect_to_database(
self,
dbapiModuleName: Optional[str] = None,
Expand Down Expand Up @@ -279,7 +321,7 @@ def connect_to_database(
host=dbHost,
port=dbPort,
)
self._register_connection(db_connection, db_api_module_name, alias)
self.connection_store.register_connection(db_connection, db_api_module_name, alias)

def connect_to_database_using_custom_params(
self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: str = "default"
Expand Down Expand Up @@ -317,7 +359,7 @@ def connect_to_database_using_custom_params(
)

db_connection = eval(db_connect_string)
self._register_connection(db_connection, db_api_module_name, alias)
self.connection_store.register_connection(db_connection, db_api_module_name, alias)

def connect_to_database_using_custom_connection_string(
self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: str = "default"
Expand All @@ -341,7 +383,7 @@ def connect_to_database_using_custom_connection_string(
f"'{db_connect_string}')"
)
db_connection = db_api_2.connect(db_connect_string)
self._register_connection(db_connection, db_api_module_name, alias)
self.connection_store.register_connection(db_connection, db_api_module_name, alias)

def disconnect_from_database(self, error_if_no_connection: bool = False, alias: Optional[str] = None):
"""
Expand All @@ -356,19 +398,14 @@ def disconnect_from_database(self, error_if_no_connection: bool = False, alias:
| Disconnect From Database | alias=my_alias | # disconnects from current connection to the database |
"""
logger.info("Executing : Disconnect From Database")
if not alias:
if not self._connections or self.default_alias in self._connections:
alias = self.default_alias
else:
alias = list(self._connections.keys())[-1]
try:
db_connection = self._connections.pop(alias)
db_connection.client.close()
except KeyError: # Non-existing alias
db_connection = self.connection_store.pop_connection(alias)
if db_connection is None:
log_msg = "No open database connection to close"
if error_if_no_connection:
raise ConnectionError(log_msg) from None
logger.info(log_msg)
else:
db_connection.client.close()

def disconnect_from_all_databases(self):
"""
Expand All @@ -378,9 +415,9 @@ def disconnect_from_all_databases(self):
| Disconnect From All Databases | # disconnects from all connections to the database |
"""
logger.info("Executing : Disconnect From All Databases")
for db_connection in self._connections.values():
for db_connection in self.connection_store:
db_connection.client.close()
self._connections = {}
self.connection_store.clear()

def set_auto_commit(self, autoCommit: bool = True, alias: Optional[str] = None):
"""
Expand All @@ -400,7 +437,7 @@ def set_auto_commit(self, autoCommit: bool = True, alias: Optional[str] = None):
| Set Auto Commit | False
"""
logger.info("Executing : Set Auto Commit")
db_connection = self._get_connection_with_alias(alias)
db_connection = self.connection_store.get_connection(alias)
db_connection.client.autocommit = autoCommit

def switch_database(self, alias: str):
Expand All @@ -411,23 +448,4 @@ def switch_database(self, alias: str):
| Switch Database | my_alias |
| Switch Database | alias=my_alias |
"""
if alias not in self._connections:
raise ValueError(f"Alias '{alias}' not found in existing connections.")
self.default_alias = alias

def _get_connection_with_alias(self, alias: Optional[str]) -> Connection:
"""
Return connection with given alias.
If alias is not provided, it will return default connection.
If there is no default connection, it will return last opened connection.
"""
if not self._connections:
raise ValueError(f"No database connection is open.")
if not alias:
if self.default_alias in self._connections:
return self._connections[self.default_alias]
return list(self._connections.values())[-1]
if alias not in self._connections:
raise ValueError(f"Alias '{alias}' not found in existing connections.")
return self._connections[alias]
self.connection_store.switch(alias)
14 changes: 7 additions & 7 deletions src/DatabaseLibrary/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def query(
Using optional ``sansTran`` to run command without an explicit transaction commit or rollback:
| @{queryResults} | Query | SELECT * FROM person | True |
"""
db_connection = self._get_connection_with_alias(alias)
db_connection = self.connection_store.get_connection(alias)
cur = None
try:
cur = db_connection.client.cursor()
Expand Down Expand Up @@ -110,7 +110,7 @@ def row_count(self, selectStatement: str, sansTran: bool = False, alias: Optiona
Using optional ``sansTran`` to run command without an explicit transaction commit or rollback:
| ${rowCount} | Row Count | SELECT * FROM person | True |
"""
db_connection = self._get_connection_with_alias(alias)
db_connection = self.connection_store.get_connection(alias)
cur = None
try:
cur = db_connection.client.cursor()
Expand Down Expand Up @@ -149,7 +149,7 @@ def description(self, selectStatement: str, sansTran: bool = False, alias: Optio
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
| @{queryResults} | Description | SELECT * FROM person | True |
"""
db_connection = self._get_connection_with_alias(alias)
db_connection = self.connection_store.get_connection(alias)
cur = None
try:
cur = db_connection.client.cursor()
Expand Down Expand Up @@ -187,7 +187,7 @@ def delete_all_rows_from_table(self, tableName: str, sansTran: bool = False, ali
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
| Delete All Rows From Table | person | True |
"""
db_connection = self._get_connection_with_alias(alias)
db_connection = self.connection_store.get_connection(alias)
cur = None
query = f"DELETE FROM {tableName}"
try:
Expand Down Expand Up @@ -265,7 +265,7 @@ def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False, ali
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
| Execute Sql Script | ${EXECDIR}${/}resources${/}DDL-setup.sql | True |
"""
db_connection = self._get_connection_with_alias(alias)
db_connection = self.connection_store.get_connection(alias)
with open(sqlScriptFileName, encoding="UTF-8") as sql_file:
cur = None
try:
Expand Down Expand Up @@ -351,7 +351,7 @@ def execute_sql_string(self, sqlString: str, sansTran: bool = False, alias: Opti
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
| Execute Sql String | DELETE FROM person_employee_table; DELETE FROM person_table | True |
"""
db_connection = self._get_connection_with_alias(alias)
db_connection = self.connection_store.get_connection(alias)
cur = None
try:
cur = db_connection.client.cursor()
Expand Down Expand Up @@ -407,7 +407,7 @@ def call_stored_procedure(
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
| @{Param values} @{Result sets} = | Call Stored Procedure | DBName.SchemaName.StoredProcName | ${Params} | True |
"""
db_connection = self._get_connection_with_alias(alias)
db_connection = self.connection_store.get_connection(alias)
if spParams is None:
spParams = []
cur = None
Expand Down

0 comments on commit c4f2a72

Please sign in to comment.