diff --git a/splink/database_api.py b/splink/database_api.py
index ba58f0861f..5e8c54ed64 100644
--- a/splink/database_api.py
+++ b/splink/database_api.py
@@ -339,3 +339,8 @@ def remove_splinkdataframe_from_cache(self, splink_dataframe: SplinkDataFrame):
 
         for k in keys_to_delete:
             del self._intermediate_table_cache[k]
+
+    def delete_tables_created_by_splink_from_db(self):
+        for splink_df in list(self._intermediate_table_cache.values()):
+            if splink_df.created_by_splink:
+                splink_df.drop_table_from_database_and_remove_from_cache()
diff --git a/splink/duckdb/database_api.py b/splink/duckdb/database_api.py
index e51f010060..24b5081afa 100644
--- a/splink/duckdb/database_api.py
+++ b/splink/duckdb/database_api.py
@@ -55,17 +55,24 @@ def __init__(
             """
         )
 
+    def delete_table_from_database(self, name: str):
+        # If the table is in fact a pandas dataframe that's been registered using
+        # duckdb con.register() then DROP TABLE will fail with
+        # Catalog Error: x is of type View
+        try:
+            drop_sql = f"DROP TABLE IF EXISTS {name}"
+            self._execute_sql_against_backend(drop_sql)
+        except duckdb.CatalogException:
+            drop_sql = f"DROP VIEW IF EXISTS {name}"
+            self._execute_sql_against_backend(drop_sql)
+
     def _table_registration(self, input, table_name) -> None:
         if isinstance(input, dict):
             input = pd.DataFrame(input)
         elif isinstance(input, list):
             input = pd.DataFrame.from_records(input)
 
-        # Registration errors will automatically
-        # occur if an invalid data type is passed as an argument
-        self._execute_sql_against_backend(
-            f"CREATE TABLE {table_name} AS SELECT * FROM input"
-        )
+        self._con.register(table_name, input)
 
     def table_to_splink_dataframe(
         self, templated_name, physical_name
diff --git a/splink/linker.py b/splink/linker.py
index f0b6e3f79b..0ad9431e91 100644
--- a/splink/linker.py
+++ b/splink/linker.py
@@ -806,9 +806,7 @@ def _populate_m_u_from_trained_values(self):
             cl.m_probability = cl._trained_m_median
 
     def delete_tables_created_by_splink_from_db(self):
-        for splink_df in list(self._intermediate_table_cache.values()):
-            if splink_df.created_by_splink:
-                splink_df.drop_table_from_database_and_remove_from_cache()
+        self.db_api.delete_tables_created_by_splink_from_db()
 
     def _raise_error_if_necessary_waterfall_columns_not_computed(self):
         ricc = self._settings_obj._retain_intermediate_calculation_columns
diff --git a/splink/profile_data.py b/splink/profile_data.py
index 9a8d5f4737..432eef2172 100644
--- a/splink/profile_data.py
+++ b/splink/profile_data.py
@@ -319,6 +319,8 @@ def profile_columns(
             )
             inner_charts.append(inner_chart)
 
+    db_api.delete_tables_created_by_splink_from_db()
+
     if inner_charts != []:
         outer_spec = deepcopy(_outer_chart_spec_freq)
         outer_spec["vconcat"] = inner_charts