Skip to content

Commit

Permalink
- Bugfix: schema and con params conflict in list_tables(), list_views…
Browse files Browse the repository at this point in the history
…(), list_schema_objects()
  • Loading branch information
joeflack4 committed Jan 24, 2024
1 parent 853d538 commit 1c1ad33
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion backend/db/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _current_counts(
return df
# Get tables
with get_db_connection(schema=schema, local=local) as con:
tables: List[str] = list_tables(con, _filter_temp_refresh_tables=filter_temp_refresh_tables)
tables: List[str] = list_tables(con, filter_temp_refresh_tables=filter_temp_refresh_tables)
with get_db_connection(schema='', local=local) as con:
# Get previous counts
timestamps: List[datetime] = [
Expand Down
9 changes: 5 additions & 4 deletions backend/db/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ def run_sql(con: Connection, command: str, params: Dict = {}) -> Any:

# todo: should add a 'include_views' param because it is ambiguous as to whether or not this returns views. it does.
def list_schema_objects(
con: Connection = None, schema=SCHEMA, filter_views=False, filter_sequences=False,
con: Connection = None, schema: str = None, filter_views=False, filter_sequences=False,
filter_temp_refresh_objects=False, filter_tables=False, names_only=False, verbose=True
) -> Union[List[Row], List[str]]:
"""Show tables
Expand All @@ -577,6 +577,7 @@ def list_schema_objects(
"""
if con and schema:
raise ValueError('`con` and `schema` params should not both be passed; choose one')
schema = SCHEMA if not schema else schema
conn = con if con else get_db_connection(schema=schema)
# Query
query = """
Expand Down Expand Up @@ -616,7 +617,7 @@ def list_schema_objects(
return res


def list_views(con: Connection = None, schema: str = SCHEMA, filter_temp_refresh_views=False) -> List[str]:
def list_views(con: Connection = None, schema: str = None, filter_temp_refresh_views=False) -> List[str]:
"""Get list of names of views in schema"""
return list_schema_objects(con, schema, False, True, filter_temp_refresh_views, True, True, False)

Expand Down Expand Up @@ -707,9 +708,9 @@ def load_csv(
update_db_status_var(f'last_updated_{table}', str(current_datetime()), local)


def list_tables(con: Connection = None, schema: str = SCHEMA, _filter_temp_refresh_tables=False) -> List[str]:
def list_tables(con: Connection = None, schema: str = None, filter_temp_refresh_tables=False) -> List[str]:
"""List table names in schema"""
return list_schema_objects(con, schema, True, True, _filter_temp_refresh_tables, names_only=True, verbose=False)
return list_schema_objects(con, schema, True, True, filter_temp_refresh_tables, names_only=True, verbose=False)


def get_ddl_statements(
Expand Down

0 comments on commit 1c1ad33

Please sign in to comment.