Skip to content

Commit

Permalink
support for custom trainer classes from plugins
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Dec 30, 2024
1 parent 9801b45 commit 962bbbc
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,11 @@ def get_post_trainer_create_callbacks(self, trainer):
return callbacks

def _get_trainer_cls(self):
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
if trainer_cls:
return trainer_cls
if self.cfg.relora_steps:
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
Expand Down
27 changes: 27 additions & 0 deletions src/axolotl/integrations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,17 @@ def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
None
"""

def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
"""
Returns a custom class for the trainer.
Parameters:
cfg (dict): The global axolotl configuration.
Returns:
class: The class for the trainer.
"""

def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
Expand Down Expand Up @@ -346,6 +357,22 @@ def post_lora_load(self, cfg, model):
for plugin in self.plugins.values():
plugin.post_lora_load(cfg, model)

def get_trainer_cls(self, cfg):
"""
Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.
Parameters:
cfg (dict): The configuration for the plugins.
Returns:
object: The trainer class, or None if none was found.
"""
for plugin in self.plugins.values():
trainer_cls = plugin.get_trainer_cls(cfg)
if trainer_cls is not None:
return trainer_cls
return None

def create_optimizer(self, cfg, trainer):
"""
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
Expand Down

0 comments on commit 962bbbc

Please sign in to comment.