diff --git a/README.md b/README.md
index fe98b3f..9b3d6d9 100644
--- a/README.md
+++ b/README.md
@@ -109,9 +109,9 @@
 the `config_overrides` parameter with Hydra syntax from the command line.
 
 Here’s an example command with some configuration overrides:
 
 ```bash
-aga run toy_data --config_overrides "feature_transformers=[], autogluon.predictor_fit_kwargs.time_limit=3600"
+aga run toy_data --config_overrides "feature_transformers.enabled=False, autogluon.predictor_fit_kwargs.time_limit=3600"
 # OR
-aga run toy_data --config_overrides "feature_transformers=[]" --config_overrides "autogluon.predictor_fit_kwargs.time_limit=3600"
+aga run toy_data --config_overrides "feature_transformers.enabled=False" --config_overrides "autogluon.predictor_fit_kwargs.time_limit=3600"
 ```
\ No newline at end of file
diff --git a/src/autogluon_assistant/assistant.py b/src/autogluon_assistant/assistant.py
index 1a3cbba..151be91 100644
--- a/src/autogluon_assistant/assistant.py
+++ b/src/autogluon_assistant/assistant.py
@@ -47,8 +47,8 @@ def __init__(self, config: DictConfig) -> None:
         self.config = config
         self.llm: Union[AssistantChatOpenAI, AssistantChatBedrock] = LLMFactory.get_chat_model(config.llm)
         self.predictor = AutogluonTabularPredictor(config.autogluon)
-        self.feature_transformers_config = config.feature_transformers
-        self.use_feature_transformers = config.use_feature_transformers
+        self.feature_transformers_config = config.feature_transformers.transformers
+        self.use_feature_transformers = config.feature_transformers.enabled
 
     def describe(self) -> Dict[str, Any]:
         return {
diff --git a/src/autogluon_assistant/configs/best_quality.yaml b/src/autogluon_assistant/configs/best_quality.yaml
index ccaea01..296d8f4 100644
--- a/src/autogluon_assistant/configs/best_quality.yaml
+++ b/src/autogluon_assistant/configs/best_quality.yaml
@@ -5,19 +5,20 @@ save_artifacts:
   enabled: False
   append_timestamp: True
   path: "./aga-artifacts"
-use_feature_transformers: True
 feature_transformers:
-  - _target_: autogluon_assistant.transformer.CAAFETransformer
-    eval_model: lightgbm
-    llm_provider: ${llm.provider}
-    llm_model: ${llm.model}
-    num_iterations: 5
-    optimization_metric: roc
-  - _target_: autogluon_assistant.transformer.OpenFETransformer
-    n_jobs: 1
-    num_features_to_keep: 10
-  - _target_: autogluon_assistant.transformer.PretrainedEmbeddingTransformer
-    model_name: 'all-mpnet-base-v2'
+  enabled: True
+  transformers:
+    - _target_: autogluon_assistant.transformer.CAAFETransformer
+      eval_model: lightgbm
+      llm_provider: ${llm.provider}
+      llm_model: ${llm.model}
+      num_iterations: 5
+      optimization_metric: roc
+    - _target_: autogluon_assistant.transformer.OpenFETransformer
+      n_jobs: 1
+      num_features_to_keep: 10
+    - _target_: autogluon_assistant.transformer.PretrainedEmbeddingTransformer
+      model_name: 'all-mpnet-base-v2'
 autogluon:
   predictor_init_kwargs: {}
   predictor_fit_kwargs:
diff --git a/src/autogluon_assistant/configs/high_quality.yaml b/src/autogluon_assistant/configs/high_quality.yaml
index e35e0f9..839f10f 100644
--- a/src/autogluon_assistant/configs/high_quality.yaml
+++ b/src/autogluon_assistant/configs/high_quality.yaml
@@ -5,19 +5,20 @@ save_artifacts:
   enabled: False
   append_timestamp: True
   path: "./aga-artifacts"
-use_feature_transformers: False
 feature_transformers:
-  - _target_: autogluon_assistant.transformer.CAAFETransformer
-    eval_model: lightgbm
-    llm_provider: ${llm.provider}
-    llm_model: ${llm.model}
-    num_iterations: 5
-    optimization_metric: roc
-  - _target_: autogluon_assistant.transformer.OpenFETransformer
-    n_jobs: 1
-    num_features_to_keep: 10
-  - _target_: autogluon_assistant.transformer.PretrainedEmbeddingTransformer
-    model_name: 'all-mpnet-base-v2'
+  enabled: False
+  transformers:
+    - _target_: autogluon_assistant.transformer.CAAFETransformer
+      eval_model: lightgbm
+      llm_provider: ${llm.provider}
+      llm_model: ${llm.model}
+      num_iterations: 5
+      optimization_metric: roc
+    - _target_: autogluon_assistant.transformer.OpenFETransformer
+      n_jobs: 1
+      num_features_to_keep: 10
+    - _target_: autogluon_assistant.transformer.PretrainedEmbeddingTransformer
+      model_name: 'all-mpnet-base-v2'
 autogluon:
   predictor_init_kwargs: {}
   predictor_fit_kwargs:
diff --git a/src/autogluon_assistant/configs/medium_quality.yaml b/src/autogluon_assistant/configs/medium_quality.yaml
index 86e85ce..5733829 100644
--- a/src/autogluon_assistant/configs/medium_quality.yaml
+++ b/src/autogluon_assistant/configs/medium_quality.yaml
@@ -5,19 +5,20 @@ save_artifacts:
   enabled: False
   append_timestamp: True
   path: "./aga-artifacts"
-use_feature_transformers: False
 feature_transformers:
-  - _target_: autogluon_assistant.transformer.CAAFETransformer
-    eval_model: lightgbm
-    llm_provider: ${llm.provider}
-    llm_model: ${llm.model}
-    num_iterations: 5
-    optimization_metric: roc
-  - _target_: autogluon_assistant.transformer.OpenFETransformer
-    n_jobs: 1
-    num_features_to_keep: 10
-  - _target_: autogluon_assistant.transformer.PretrainedEmbeddingTransformer
-    model_name: 'all-mpnet-base-v2'
+  enabled: False
+  transformers:
+    - _target_: autogluon_assistant.transformer.CAAFETransformer
+      eval_model: lightgbm
+      llm_provider: ${llm.provider}
+      llm_model: ${llm.model}
+      num_iterations: 5
+      optimization_metric: roc
+    - _target_: autogluon_assistant.transformer.OpenFETransformer
+      n_jobs: 1
+      num_features_to_keep: 10
+    - _target_: autogluon_assistant.transformer.PretrainedEmbeddingTransformer
+      model_name: 'all-mpnet-base-v2'
 autogluon:
   predictor_init_kwargs: {}
   predictor_fit_kwargs:
diff --git a/src/autogluon_assistant/ui/pages/task.py b/src/autogluon_assistant/ui/pages/task.py
index 854bd4d..f70133e 100644
--- a/src/autogluon_assistant/ui/pages/task.py
+++ b/src/autogluon_assistant/ui/pages/task.py
@@ -51,9 +51,9 @@ def update_config_overrides():
         config_overrides.append(f"llm.provider={PROVIDER_MAPPING[st.session_state.llm]}")
 
     if st.session_state.feature_generation:
-        config_overrides.append("use_feature_transformers=True")
+        config_overrides.append("feature_transformers.enabled=True")
     else:
-        config_overrides.append("use_feature_transformers=False")
+        config_overrides.append("feature_transformers.enabled=False")
 
     st.session_state.config_overrides = config_overrides
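For reviewers, here is a minimal sketch (not part of the patch) of how the nested `feature_transformers` group introduced by this diff can be consumed. The config shape, the `_target_` class path, and the `enabled`/`transformers` reads mirror the changed files; the standalone-script framing, the inline `OmegaConf.create` config, and the `print` at the end are illustrative assumptions, and running it requires hydra-core/omegaconf plus an installed `autogluon_assistant` package.

```python
# Minimal sketch, assuming hydra-core/omegaconf and autogluon_assistant are
# installed. Illustrative only; not code from this repository.
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Same shape as the new feature_transformers group in the quality presets.
cfg = OmegaConf.create(
    """
    feature_transformers:
      enabled: True
      transformers:
        - _target_: autogluon_assistant.transformer.OpenFETransformer
          n_jobs: 1
          num_features_to_keep: 10
    """
)

# Mirrors the reads in assistant.py: the enabled flag gates the transformer
# list that sits next to it under the same key.
if cfg.feature_transformers.enabled:
    transformers = [instantiate(t) for t in cfg.feature_transformers.transformers]
else:
    transformers = []
print(transformers)
```

Nesting `enabled` beside `transformers` under one key is what allows the single `feature_transformers.enabled=False` override shown in the README (and the Streamlit toggle in task.py) to turn feature generation off without clearing the transformer list itself.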