From 962bbbc1ab6f421268a8394aae71d60cb3da2919 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 30 Dec 2024 12:20:45 -0500 Subject: [PATCH] support for custom trainer classes from plugins --- src/axolotl/core/trainer_builder.py | 5 +++++ src/axolotl/integrations/base.py | 27 +++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index def1d7a264..7eadd3e592 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -295,6 +295,11 @@ def get_post_trainer_create_callbacks(self, trainer): return callbacks def _get_trainer_cls(self): + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + trainer_cls = plugin_manager.get_trainer_cls(self.cfg) + if trainer_cls: + return trainer_cls if self.cfg.relora_steps: return ReLoRATrainer if self.cfg.model_config_type == "mamba": diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index a271c59d10..26f2f8a6f0 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -111,6 +111,17 @@ def post_lora_load(self, cfg, model): # pylint: disable=unused-argument None """ + def get_trainer_cls(self, cfg): # pylint: disable=unused-argument + """ + Returns a custom class for the trainer. + + Parameters: + cfg (dict): The global axolotl configuration. + + Returns: + class: The class for the trainer. + """ + def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument """ Creates and returns an optimizer for training. @@ -346,6 +357,22 @@ def post_lora_load(self, cfg, model): for plugin in self.plugins.values(): plugin.post_lora_load(cfg, model) + def get_trainer_cls(self, cfg): + """ + Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class. + + Parameters: + cfg (dict): The configuration for the plugins. + + Returns: + object: The trainer class, or None if none was found. 
+ """ + for plugin in self.plugins.values(): + trainer_cls = plugin.get_trainer_cls(cfg) + if trainer_cls is not None: + return trainer_cls + return None + def create_optimizer(self, cfg, trainer): """ Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.