Mamba-Ssm - Loader for Mamba State Space models #5228

Closed
IggoOnCode wants to merge 25 commits

Conversation

IggoOnCode
Contributor

@IggoOnCode IggoOnCode commented Jan 10, 2024

This PR adds support for "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" as described in https://arxiv.org/abs/2312.00752

Currently it's able to use https://huggingface.co/state-spaces/mamba-2.8b-slimpj for inference.

I plan to add more features and training later. For now I hope we can add just basic inference.

This implements the issue: #4830
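
For reference, basic inference through the mamba_ssm package looks roughly like the sketch below. This is a minimal example assuming a CUDA GPU with bfloat16 support, not the exact code in this PR's modules/mamba.py:

import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Mamba checkpoints from state-spaces reuse the GPT-NeoX tokenizer.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

# Load the checkpoint on the GPU; the fused kernels expect bfloat16 (or float32).
model = MambaLMHeadModel.from_pretrained(
    "state-spaces/mamba-2.8b-slimpj", device="cuda", dtype=torch.bfloat16
)

input_ids = tokenizer("The Mamba architecture is", return_tensors="pt").input_ids.to("cuda")
# Greedy decoding by default; the returned tensor includes the prompt tokens.
out = model.generate(input_ids=input_ids, max_length=64)
print(tokenizer.decode(out[0]))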

@thistleknot

would love to see this merged

@IggoOnCode
Contributor Author

I would very much like to get feedback on what still needs to be done to get this merged.

Maybe I can test/review another PR in turn. Give me a hint on how to help out!

@IggoOnCode IggoOnCode marked this pull request as draft January 18, 2024 22:34
@IggoOnCode
Contributor Author

Training already works as a prototype but needs serious rework before pushing. It should only take a few days.

@minipasila
Contributor

Is this only supported on Linux? I tried installing your branch and it kept giving me some pip installation errors.

Collecting mamba-ssm (from -r requirements.txt (line 25))
  Using cached mamba_ssm-1.1.1.tar.gz (34 kB)
  Preparing metadata (setup.py) ... error
  error: subprocess-exited-with-error

  × python setup.py egg_info did not run successfully.
  │ exit code: 1
  ╰─> [6 lines of output]
      Traceback (most recent call last):
        File "<string>", line 2, in <module>
        File "<pip-setuptools-caller>", line 34, in <module>
        File "C:\Users\pasil\AppData\Local\Temp\pip-install-1yt4027x\mamba-ssm_f7a5d5610da94a50a7d6e5a2f11f858d\setup.py", line 8, in <module>
          from packaging.version import parse, Version
      ModuleNotFoundError: No module named 'packaging'
      [end of output]

  note: This error originates from a subprocess, and is likely not a problem with pip.
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.
Collecting causal-conv1d (from -r requirements.txt (line 2))
  Using cached causal_conv1d-1.1.1.tar.gz (6.6 kB)
  Preparing metadata (setup.py) ... error
  error: subprocess-exited-with-error

  × python setup.py egg_info did not run successfully.
  │ exit code: 1
  ╰─> [6 lines of output]
      Traceback (most recent call last):
        File "<string>", line 2, in <module>
        File "<pip-setuptools-caller>", line 34, in <module>
        File "C:\Users\pasil\AppData\Local\Temp\pip-install-bnpazlzy\causal-conv1d_002ca3f0e39e4f87849057ec3db54a04\setup.py", line 9, in <module>
          from packaging.version import parse, Version
      ModuleNotFoundError: No module named 'packaging'
      [end of output]

  note: This error originates from a subprocess, and is likely not a problem with pip.
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.

@IggoOnCode
Contributor Author

Hi @minipasila

Thanks for using and thereby testing this branch!

Is this only supported on Linux? I tried installing your branch and it kept giving me some pip installation errors.

While I develop and test only on Ubuntu, there is no reason this should not work on Windows.

In this case the error is caused by the original mamba-ssm Python package not declaring its dependencies correctly. I have added the "packaging" package to requirements.txt, and now I'm at least able to successfully run pip install -r requirements.txt in a new conda environment (on Ubuntu 22.04).

And there may also be a hard dependency on CUDA that comes from the original mamba-ssm package. I'm still unsure how to deal with that.

Please try the updated branch and report any errors. While I can't test on Windows, I'll try to address all bug reports for Windows as well as I can.

@IggoOnCode
Contributor Author

The first version of training code has landed.

These models seem to learn very well.

@minipasila
Contributor

Please try the updated branch and report any errors. While I can't test on Windows, I'll try to address all bug reports for Windows as well as I can.

For whatever reason it wouldn't install the "packaging" package before it tried installing the mamba packages, so I did that manually. It then started installing them until it gave another error stating that mamba-ssm depends on triton, which I think is only supported on Linux currently, unless that has changed recently.

INFO: pip is looking at multiple versions of mamba-ssm to determine which version is compatible with other requirements. This could take a while.
Collecting mamba-ssm (from -r requirements.txt (line 11))
  Using cached mamba_ssm-1.1.0.tar.gz (34 kB)
  Preparing metadata (setup.py) ... done
  Using cached mamba_ssm-1.0.1.tar.gz (28 kB)
  Preparing metadata (setup.py) ... done
ERROR: Cannot install -r requirements.txt (line 11) because these package versions have conflicting dependencies.

The conflict is caused by:
    mamba-ssm 1.1.1 depends on triton
    mamba-ssm 1.1.0 depends on triton
    mamba-ssm 1.0.1 depends on triton

To fix this you could try to:
1. loosen the range of package versions you've specified
2. remove package versions to allow pip attempt to solve the dependency conflict

ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts

@IggoOnCode
Contributor Author

it depends on triton but I think that's only supported on Linux currently, unless that has changed recently.

I looked it up: triton is still Linux-only. That's unlucky, but I don't think I can fix that easily.

Then this has to be Linux-only for the moment.

Can you try WSL?

@minipasila
Contributor

it depends on triton but I think that's only supported on Linux currently, unless that has changed recently.

I looked it up: triton is still Linux-only. That's unlucky, but I don't think I can fix that easily.

Then this has to be Linux-only for the moment.

Can you try WSL?

I did manage to find https://github.com/jakaline-dev/Triton_win/releases/tag/2.1.0 and even though it installed triton successfully, that didn't fix the issue. But since I don't want to deal with WSL right now, I tried it on Google Colab instead, and that seemed broken as well for some reason (maybe an old CUDA version?)

Colab error
13:04:43-753830 INFO     Loading state-spaces_mamba-2.8b-slimpj                                     
13:04:54-793189 INFO     LOADER: Mamba-Ssm                                                          
13:04:54-794505 INFO     TRUNCATION LENGTH: 2048                                                    
13:04:54-795532 INFO     INSTRUCTION TEMPLATE: Alpaca                                               
13:04:54-796439 INFO     Loaded the model in 11.04 seconds.                                         
Traceback (most recent call last):
File "/content/text-generation-webui/modules/callbacks.py", line 61, in gentask
  ret = self.mfunc(callback=_callback, *args, **self.kwargs)
File "/content/text-generation-webui/modules/mamba.py", line 154, in generate
  output = self.model.generate(
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 244, in generate
  output = decode(
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
  return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 145, in decode
  model._decoding_cache = update_graph_cache(
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
  return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 305, in update_graph_cache
  cache.callables[batch_size, decoding_seqlen] = capture_graph(
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/utils/generation.py", line 339, in capture_graph
  logits = model(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
  return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 233, in forward
  hidden_states = self.backbone(input_ids, inference_params=inference_params)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
  return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/models/mixer_seq_simple.py", line 155, in forward
  hidden_states, residual = layer(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
  return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/modules/mamba_simple.py", line 340, in forward
  hidden_states, residual = fused_add_norm_fn(
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
  return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 539, in apply
  return super().apply(*args, **kwargs)  # type: ignore[misc]
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
  y, mean, rstd, residual_out = _layer_norm_fwd(
File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
  _layer_norm_fwd_1pass_kernel[(M,)](
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 100, in run
  timings = {config: self._bench(*args, config=config, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 100, in <dictcomp>
  timings = {config: self._bench(*args, config=config, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 83, in _bench
  return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 104, in do_bench
  fn()
File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 81, in kernel_call
  self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
File "<string>", line 63, in _layer_norm_fwd_1pass_kernel
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 425, in compile
  so_path = make_stub(name, signature, constants)
File "/usr/local/lib/python3.10/dist-packages/triton/compiler/make_launcher.py", line 39, in make_stub
  so = _build(name, src_path, tmpdir)
File "/usr/local/lib/python3.10/dist-packages/triton/common/build.py", line 61, in _build
  cuda_lib_dirs = libcuda_dirs()
File "/usr/local/lib/python3.10/dist-packages/triton/common/build.py", line 30, in libcuda_dirs
  assert any(os.path.exists(os.path.join(path, 'libcuda.so')) for path in dirs), msg
AssertionError: libcuda.so cannot found!

But on Runpod I was able to get it working successfully without problems. I was able to load the model and generate text.

@IggoOnCode
Contributor Author

IggoOnCode commented Jan 22, 2024

@minipasila Thank you for trying again!

Too bad Triton_win didn't work. That would have been perfect.

For the Google Colab error, I can't say anything. I've never tried Colab.

But on Runpod I was able to get it working successfully without problems. I was able to load the model and generate text.

That is awesome! Thank you for investing the resources to try it!

Then we have one datapoint that the changes generally work. :-)

IggoOnCode added 2 commits January 24, 2024 01:22
… Made fix for mixed precision models not apply to SSMs.
… not work because the free instances on Colab all don't support bfloat16
@IggoOnCode
Contributor Author

I fixed the "libcuda.so not found" error on Google Colab. The cause was the default installation of triton==2.1, which breaks on Colab. Forcing a downgrade to triton==2.0 fixed that.

But Mamba still doesn't work on the free runtimes, because neither the T4 (I assume Tesla T4 GPUs) nor the TPU instances support bfloat16.

Maybe someone with Colab credits can now test the premium GPUs.

I will have a look into making Mamba work with other data types, but only after this PR is merged.
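
For the dtype question, something like the following check could pick a fallback dtype before loading. This is just a sketch, not code from this PR, and the fallback path would still require kernels that accept float32:

import torch

# bfloat16 needs an Ampere-or-newer NVIDIA GPU; the free Colab T4
# (compute capability 7.5) doesn't support it, so fall back to float32 there.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
else:
    dtype = torch.float32
print(f"Loading Mamba weights as {dtype}")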

@IggoOnCode IggoOnCode marked this pull request as ready for review January 25, 2024 01:05
@oobabooga
Owner

This PR lacks a justification for adding mamba support. It is a promising alternative architecture for LLMs, but the linked model is small both in number of parameters (2.8b) and training dataset size (600b tokens). What is the use case?

Another promising alternative architecture is RWKV and it suffers from the same problem.

@thistleknot

Mamba being the hottest thing since sliced bread is the justification: effectively limitless context, a non-transformer model (no attention), and much more efficient training than transformer models (faster convergence).

@IggoOnCode
Contributor Author

This PR lacks a justification for adding mamba support.

I didn't know I needed one, as there already was an issue asking for it and, as others have already stated, mamba is the new cool kid in town.

But I'll happily provide one:

This PR adds support for Mamba State space models to allow experimentation with and evaluation of this new model architecture.

In first experiments, Mamba models provide similar or better performance than transformer models of comparable size, while also offering benefits like constant memory and linear time requirements for larger context sizes, plus higher training efficiency. These benefits are especially important for users of text-generation-webui, as most of us have to make do with very limited resources.

It is a promising alternative architecture for LLMs, but the linked model is small both in number of parameters (2.8b) and training dataset size (600b).

It performs on the level of the widely used 7b transformers, with lower VRAM requirements, making it even more accessible for smaller GPUs.

I would love to already have a mamba-13b or larger. There are rumours that companies are training models of that size, but I know nothing for sure. What I do know for sure is that companies are much more likely to invest millions in compute if the technology has wider software support. By implementing it in a widely used piece of software like text-generation-webui we can encourage those investments, leading to better research on this model type. Any new model architecture that could surpass transformer models will need a lot of support from all sides to actually do it, because transformers already have a large ecosystem and literally billions of investment dollars behind them. The open-source community can implement, test and then adopt or drop new model types easily, but we need the upfront pretraining by larger organisations. Getting that will be easier if the tooling is ready.

What is the use case?

Fun, Research, specialised fine tunes.

During development of this PR I used it (an in between version, not the final one) to create this model: https://huggingface.co/IggoOnCode/mamba-2.8b-slimpj-OpenOrca_1ep

First evaluation suggests that I'm probably not good at this: perplexity went up, accuracy down. That can happen in fine-tuning, as I have read; further testing of how it feels is required. But it illustrates why I wanted to use text-generation-webui for training. Using the webui I had saved a snapshot during training which I could evaluate too. It turns out that perplexity and accuracy were worse at 50% of the training run and got better after that. (Next I may try to train another epoch on top of it, but I'll see.)
I will now use the evaluation feature of the webui too and put the results on the model card soon.

Another promising alternative architecture is RWKV and it suffers from the same problem.

RWKV was supported individually by text-generation-webui before that support moved directly into the transformers library. When mamba moves into transformers too, the special mamba support can easily be removed. If Mamba doesn't make it, it can also be removed.

What I take from this argument is that the mamba integration should be as easily removable as possible. To facilitate this I will go over the training code again and make the distinction not between llama and mamba models, but between LoRA and full fine-tune, which is basically what it is. That has the additional advantage that we may be able to train RWKV models too (and other model types that may come in the future).
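
As a rough illustration of that split (hypothetical names, not the actual training.py code in this PR; the LoRA path assumes peft):

from peft import LoraConfig, get_peft_model

def build_trainable_model(model, use_lora: bool, lora_rank: int = 8):
    # Hypothetical helper sketching the LoRA vs. full fine-tune distinction.
    if use_lora:
        # LoRA: freeze the base weights and train small adapter matrices
        # (target modules omitted here; peft can infer them for common architectures).
        config = LoraConfig(r=lora_rank, lora_alpha=2 * lora_rank, task_type="CAUSAL_LM")
        return get_peft_model(model, config)
    # Full fine-tune (the only option for mamba-ssm models): train all parameters.
    for param in model.parameters():
        param.requires_grad = True
    return model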

@IggoOnCode IggoOnCode marked this pull request as draft January 30, 2024 00:38
@IggoOnCode IggoOnCode marked this pull request as ready for review February 10, 2024 19:10
@IggoOnCode
Contributor Author

IggoOnCode commented Feb 10, 2024

Good news, mamba support in transformers is definitely coming: huggingface/transformers#28094

Sadly it didn't "just work" when using the add-mamba branch from transformers in text-generation-webui. The models could not even be loaded, although loading should already work for them. And the transformers branch does not support any training yet (according to the discussion on that PR).

The code in this PR now works for full fine-tuning of mamba and llama models and LoRA training of llama models. But I had to keep four references to mamba in training.py.

I would be happy if this PR could be reviewed and considered for merging.

I already have prepared a branch for the removal of mamba-ssm (https://github.com/IggoOnCode/text-generation-webui/tree/remove_mamba-ssm). When this PR gets merged I will immediately create a draft PR from that branch that I will keep conflict-free until the transformers implementation for mamba is ready to replace this implementation.

@thistleknot

thistleknot commented Feb 10, 2024 via email

@ArthurZucker

from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch

# Mamba checkpoints use the GPT-NeoX tokenizer; pad on the left for batched generation.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m", vocab_size=50280, num_hidden_layers=24, torch_dtype=torch.float32)
model.config.use_cache = True
input_ids = tokenizer(["Hey how are you doing?", "Explain how soy sauce is made"], padding=True, return_tensors="pt")["input_ids"]

out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))

the branch in transformers now supports this 🤗

@IggoOnCode
Contributor Author

IggoOnCode commented Feb 16, 2024

@oobabooga

As you obviously don't want the mamba-ssm code, are you interested in only the training changes (bringing back full fine-tuning)?

If so, I would move them to a separate PR. But only if you give at least some indication of how to proceed.

@IggoOnCode IggoOnCode closed this Feb 16, 2024
@oobabooga
Owner

@IggoOnCode, I appreciate your contribution, but as a hobby developer maintaining this project by myself, I have to be selective about which changes I can take on. Accepting the PR means committing to long-term maintenance and handling of future PRs related to the new loader, which can be a significant time investment.

To make it easier to integrate experimental loaders, I would like to refactor the project to have self-contained loaders. This would allow each loader to have its own functions for logits, text generation, etc., and be maintained independently.

Regarding training, the current code could benefit from a review to ensure best practices are being followed. Specifically, the code contains parameters not found elsewhere, like Overlap Length, Prefer Newline Cut Length, and Hard Cut String, that I think should be removed to make it more aligned with how things are done in axolotl. The string chunking logic also probably needs improvement (#3476). If you could review the training code while adding Mamba support, it would be greatly appreciated.
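
Just to sketch what "self-contained" could mean here (a hypothetical interface, not anything that exists in the codebase yet):

from typing import Iterator, Protocol

class SelfContainedLoader(Protocol):
    # Each loader would ship its own implementations of these, so experimental
    # backends could be added and removed independently of the rest of the webui.
    def load(self, model_name: str) -> None: ...
    def generate(self, prompt: str, state: dict) -> Iterator[str]: ...
    def logits(self, prompt: str) -> dict: ...
    def unload(self) -> None: ...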

@IggoOnCode
Contributor Author

@oobabooga That's completely reasonable. Thanks for letting me know.

I'll see what I can do for the training code while doing my experiments.
