[Feat]: Implementation of Autoclip #605
Hello, I'd like to contribute to the project by working on this issue. Can you assign it to me? Thanks!
@ricor07 Done. Most discussions about new features happen on the Discord server, so if you have any questions, feel free to join: https://discord.com/invite/KwgcQd5scF
FWIW, I hacked a quick implementation of this into my local branch, and it works quite nicely:

```python
import torch

# class BaseModelSetup
def __init__(
        self,
        # ...
):
    # ...
    # Per-module history of observed gradient norms
    self.grad_history = {}

def autoclip(self, model, clip_percentile=0.1):
    # No-op by default; each setup type overrides this
    pass

def _autoclip(self, modules, clip_percentile):
    # clip_percentile is passed to Tensor.quantile, so 0.1 means the 10th percentile
    for name, module in modules.items():
        if name not in self.grad_history:
            self.grad_history[name] = []
        params = [p for p in module.parameters() if p.grad is not None]
        if not params:
            continue  # nothing to clip for this module
        # Total L2 norm of this module's gradients
        grad_obs = torch.stack([p.grad.detach().norm() ** 2.0 for p in params]).sum().sqrt().item()
        self.grad_history[name].append(grad_obs)
        # Keep only the 200 most recent observations
        self.grad_history[name] = self.grad_history[name][-200:]
        clip_value = torch.tensor(self.grad_history[name]).quantile(clip_percentile).item()
        torch.nn.utils.clip_grad_norm_(params, clip_value)

# class StableDiffusionLoRASetup
def autoclip(self, model, clip_percentile=0.1):
    # Treat each LoRAModuleWrapper attribute of the model as one clipping group
    modules = {name: module for name, module in model.__dict__.items() if isinstance(module, LoRAModuleWrapper)}
    self._autoclip(modules, clip_percentile)

# GenericTrainer
# ...
if self.__is_update_step(train_progress):
    # Must run before the optimizer step so that the clipped gradients are used
    self.model_setup.autoclip(self.model, clip_percentile=0.1)
```

It would need to be implemented per setup type, and the percentile would need to be configurable, but I am quite pleased with the results I got from it. The 200-sample history window was also an arbitrary choice, but I don't expect it to make much difference overall.

I experimented with tracking the percentiles over the entire model, split by module, split by layer, and even split by parameter. I actually think that splitting by layer produces better results overall, though for LoRAs there are some interesting implications, because each layer has two params, and the …
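The comment above weighs different granularities for the norm history (whole model, module, layer, parameter). Here is a minimal sketch of the per-layer variant in plain PyTorch. It assumes a layer can be identified by stripping the last dotted component from each parameter name, so a layer's `weight` and `bias` share one history. `GroupedAutoClip` and `layer_key` are hypothetical names, not code from this project; for LoRA wrappers the key function would have to be adapted to however their parameters are named.

```python
from collections import defaultdict, deque
from typing import Callable, Dict, Iterable, List, Tuple

import torch
from torch import nn


def layer_key(param_name: str) -> str:
    # Hypothetical grouping rule: "2.weight" -> "2", so weight and bias share a group
    return param_name.rsplit(".", 1)[0]


class GroupedAutoClip:
    """Clip each group's gradients to a quantile of that group's recent norms."""

    def __init__(self, quantile: float = 0.1, history_size: int = 200,
                 group_fn: Callable[[str], str] = layer_key):
        self.quantile = quantile  # 0.1 == 10th percentile
        self.group_fn = group_fn
        # One bounded history per group; old entries drop off automatically
        self.history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=history_size))

    def __call__(self, named_parameters: Iterable[Tuple[str, nn.Parameter]]) -> Dict[str, float]:
        groups: Dict[str, List[nn.Parameter]] = defaultdict(list)
        for name, p in named_parameters:
            if p.grad is not None:
                groups[self.group_fn(name)].append(p)

        clip_values: Dict[str, float] = {}
        for key, params in groups.items():
            # Total L2 norm of this group's gradients, before clipping
            norm = torch.linalg.vector_norm(
                torch.stack([torch.linalg.vector_norm(p.grad.detach()) for p in params])
            ).item()
            self.history[key].append(norm)
            clip_value = torch.quantile(torch.tensor(list(self.history[key])), self.quantile).item()
            torch.nn.utils.clip_grad_norm_(params, max_norm=clip_value)
            clip_values[key] = clip_value
        return clip_values


# Usage on a toy two-layer model
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
clipper = GroupedAutoClip(quantile=0.1, history_size=200)
for _ in range(3):
    model.zero_grad()
    loss = model(torch.randn(16, 4)).pow(2).mean()
    loss.backward()
    clip_values = clipper(model.named_parameters())  # one threshold per layer
```

Using `deque(maxlen=...)` keeps the same 200-entry window as the snippet's list slicing without rebuilding the list on every step; swapping `group_fn` for one that returns the full parameter name would give the per-parameter variant instead.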
Describe your use-case.
Please consider implementing AutoClip: Adaptive Gradient Clipping as an alternative to the fixed max_grad_norm. You can find the code and more information at https://github.com/pseeth/autoclip
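As I understand AutoClip, it replaces a fixed `max_grad_norm` with a threshold derived from the training run itself: every update step records the total gradient norm, and gradients are clipped to a low percentile of all norms observed so far. The sketch below illustrates that idea for a whole model in plain PyTorch. It is not the code from the pseeth/autoclip repository, and the `AutoClip` class and its `quantile` parameter (0.1 meaning the 10th percentile) are hypothetical names chosen for illustration.

```python
from typing import Iterable, List

import torch
from torch import nn


class AutoClip:
    """Clip gradients to a quantile of the gradient norms seen so far."""

    def __init__(self, quantile: float = 0.1):
        self.quantile = quantile  # 0.1 == 10th percentile
        self.history: List[float] = []  # unbounded: one entry per update step

    def __call__(self, parameters: Iterable[nn.Parameter]) -> float:
        params = [p for p in parameters if p.grad is not None]
        # Total L2 norm over all gradients, measured before clipping
        total_norm = torch.linalg.vector_norm(
            torch.stack([torch.linalg.vector_norm(p.grad.detach()) for p in params])
        ).item()
        self.history.append(total_norm)
        # Adaptive threshold: the chosen quantile of every norm observed so far
        clip_value = torch.quantile(torch.tensor(self.history), self.quantile).item()
        torch.nn.utils.clip_grad_norm_(params, max_norm=clip_value)
        return clip_value


# Usage on a toy regression model
torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
clipper = AutoClip(quantile=0.1)
for _ in range(5):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    clip_value = clipper(model.parameters())  # after backward(), before step()
    optimizer.step()
```

Compared with a flat `max_grad_norm`, the threshold here adapts as the typical gradient scale changes during training, so no single norm value has to be tuned up front; only the quantile does.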
What would you like to see as a solution?
An implementation of AutoClip as an option, though pizza would be nice too.
Have you considered alternatives? List them here.
No.