def get_training_plan_generic(self, df: pd.DataFrame) -> TrainingPlan:
    """Build a training plan from an information-schema-style DataFrame.

    The relevant columns are located by case-insensitive substring match on
    the column names: ``database`` or ``table_catalog`` for the database,
    ``table_schema``, ``table_name``, ``column_name`` and ``data_type``.
    When several columns match, the first one in DataFrame order is used.

    One plan item is produced per distinct (database, schema, table)
    combination, in order of first appearance. Each item's value is a short
    sentence followed by a markdown table of that table's columns.

    Args:
        df: One row per column of the described tables.

    Returns:
        A TrainingPlan holding one ``ITEM_TYPE_IS`` item per table.

    Raises:
        IndexError: If one of the required columns cannot be found.
    """

    def _find_column(*needles: str):
        # First column (in DataFrame order) whose lower-cased name contains
        # any of the given substrings.
        for column in df.columns:
            lowered = str(column).lower()
            if any(needle in lowered for needle in needles):
                return column
        # IndexError is kept so callers that relied on the previous
        # ``.to_list()[0]`` failure mode still catch the same exception type.
        raise IndexError(
            "No column with a name containing "
            + " or ".join(repr(needle) for needle in needles)
            + f" found in DataFrame columns {list(df.columns)}"
        )

    database_column = _find_column("database", "table_catalog")
    schema_column = _find_column("table_schema")
    table_column = _find_column("table_name")
    column_column = _find_column("column_name")
    data_type_column = _find_column("data_type")

    documented_columns = [
        database_column,
        schema_column,
        table_column,
        column_column,
        data_type_column,
    ]

    plan = TrainingPlan([])

    # Filter with boolean masks rather than DataFrame.query() strings: values
    # containing quotes or backslashes would otherwise corrupt the query
    # expression, and non-string values would never match their quoted form.
    # Each level narrows the previous subset, so the full frame is scanned
    # only once per database. Null keys are skipped because they can never
    # match an equality filter and would only yield empty items.
    for database in df[database_column].dropna().unique().tolist():
        df_database = df[df[database_column] == database]

        for schema in df_database[schema_column].dropna().unique().tolist():
            df_schema = df_database[df_database[schema_column] == schema]

            for table in df_schema[table_column].dropna().unique().tolist():
                df_table = df_schema[df_schema[table_column] == table]

                doc = f"The following columns are in the {table} table in the {database} database:\n\n"
                doc += df_table[documented_columns].to_markdown()

                plan._plan.append(
                    TrainingPlanItem(
                        item_type=TrainingPlanItem.ITEM_TYPE_IS,
                        item_group=f"{database}.{schema}",
                        item_name=table,
                        item_value=doc,
                    )
                )

    return plan