feat: 🐍 add mamba support (mudler#1589)
feat(mamba): Initial import

This is a first iteration of the mamba backend, loosely based on
mamba-chat (https://github.com/havenhq/mamba-chat).
mudler authored Jan 19, 2024
1 parent 52c9a7f commit 9e653d6
Showing 13 changed files with 762 additions and 3 deletions.
5 changes: 4 additions & 1 deletion Dockerfile
@@ -13,7 +13,7 @@ ARG TARGETVARIANT

ENV BUILD_TYPE=${BUILD_TYPE}

ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh"
ENV EXTERNAL_GRPC_BACKENDS="coqui:/build/backend/python/coqui/run.sh,huggingface-embeddings:/build/backend/python/sentencetransformers/run.sh,petals:/build/backend/python/petals/run.sh,transformers:/build/backend/python/transformers/run.sh,sentencetransformers:/build/backend/python/sentencetransformers/run.sh,autogptq:/build/backend/python/autogptq/run.sh,bark:/build/backend/python/bark/run.sh,diffusers:/build/backend/python/diffusers/run.sh,exllama:/build/backend/python/exllama/run.sh,vall-e-x:/build/backend/python/vall-e-x/run.sh,vllm:/build/backend/python/vllm/run.sh,mamba:/build/backend/python/mamba/run.sh,exllama2:/build/backend/python/exllama2/run.sh,transformers-musicgen:/build/backend/python/transformers-musicgen/run.sh"

ARG GO_TAGS="stablediffusion tinydream tts"

@@ -168,6 +168,9 @@ RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
    PATH=$PATH:/opt/conda/bin make -C backend/python/vllm \
    ; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
    PATH=$PATH:/opt/conda/bin make -C backend/python/mamba \
    ; fi
RUN if [ "${IMAGE_TYPE}" = "extras" ]; then \
    PATH=$PATH:/opt/conda/bin make -C backend/python/sentencetransformers \
    ; fi
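The new backend is registered through the `mamba:/build/backend/python/mamba/run.sh` entry above and built by the extra `make -C backend/python/mamba` step. The `run.sh` launcher itself is part of this commit but not rendered here; as a rough, hypothetical sketch only (the environment name and paths are assumptions, not the file's actual contents), a LocalAI Python-backend launcher looks along these lines:

```bash
#!/bin/bash
# Hypothetical sketch of a LocalAI Python backend launcher; the real
# backend/python/mamba/run.sh from this commit is not rendered in this diff.

# Make conda available and activate the shared environment that the
# backend Makefile prepares (the environment name is an assumption).
export PATH=$PATH:/opt/conda/bin
source activate transformers

# Launch the gRPC server; LocalAI passes the listen address via --addr.
BACKEND_DIR="$(dirname "$(realpath "$0")")"
exec python "$BACKEND_DIR/backend_mamba.py" "$@"
```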
2 changes: 2 additions & 0 deletions Makefile
@@ -419,6 +419,7 @@ protogen-python:
	python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/vall-e-x/ --grpc_python_out=backend/python/vall-e-x/ backend/backend.proto
	python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/vllm/ --grpc_python_out=backend/python/vllm/ backend/backend.proto
	python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/petals/ --grpc_python_out=backend/python/petals/ backend/backend.proto
	python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/mamba/ --grpc_python_out=backend/python/mamba/ backend/backend.proto
	python3 -m grpc_tools.protoc -Ibackend/ --python_out=backend/python/exllama2/ --grpc_python_out=backend/python/exllama2/ backend/backend.proto

## GRPC
@@ -429,6 +430,7 @@ prepare-extra-conda-environments:
	$(MAKE) -C backend/python/coqui
	$(MAKE) -C backend/python/diffusers
	$(MAKE) -C backend/python/vllm
	$(MAKE) -C backend/python/mamba
	$(MAKE) -C backend/python/sentencetransformers
	$(MAKE) -C backend/python/transformers
	$(MAKE) -C backend/python/transformers-musicgen
1 change: 0 additions & 1 deletion backend/backend.proto
@@ -112,7 +112,6 @@ message ModelOptions {
  int32 CLIPSkip = 33;
  string ControlNet = 48;

  // RWKV
  string Tokenizer = 34;

  // LLM (llama.cpp)
8 changes: 7 additions & 1 deletion backend/python/coqui/coqui_server.py
@@ -33,7 +33,13 @@ def Health(self, request, context):
    def LoadModel(self, request, context):

        # Get device
        device = "cuda" if request.CUDA else "cpu"
        # device = "cuda" if request.CUDA else "cpu"
        if torch.cuda.is_available():
            print("CUDA is available", file=sys.stderr)
            device = "cuda"
        else:
            print("CUDA is not available", file=sys.stderr)
            device = "cpu"

        if not torch.cuda.is_available() and request.CUDA:
            return backend_pb2.Result(success=False, message="CUDA is not available")
16 changes: 16 additions & 0 deletions backend/python/mamba/Makefile
@@ -0,0 +1,16 @@
.PHONY: mamba
mamba:
	$(MAKE) -C ../common-env/transformers
	bash install.sh

.PHONY: run
run:
	@echo "Running mamba..."
	bash run.sh
	@echo "mamba run."

.PHONY: test
test:
	@echo "Testing mamba..."
	bash test.sh
	@echo "mamba tested."
5 changes: 5 additions & 0 deletions backend/python/mamba/README.md
@@ -0,0 +1,5 @@
# Creating a separate environment for the mamba project

```
make mamba
```
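Once an image with this backend is built, a model can select it from a LocalAI model definition. A minimal sketch, assuming the `havenhq/mamba-chat` weights the commit message refers to (the file name and values are illustrative, not part of this commit):

```yaml
# models/mamba-chat.yaml (hypothetical example configuration)
name: mamba-chat
backend: mamba
parameters:
  model: havenhq/mamba-chat
```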
182 changes: 182 additions & 0 deletions backend/python/mamba/backend_mamba.py
@@ -0,0 +1,182 @@
#!/usr/bin/env python3
from concurrent import futures
import time
import argparse
import signal
import sys
import os

import backend_pb2
import backend_pb2_grpc

import grpc

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS is specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
MAMBA_CHAT = os.environ.get('MAMBA_CHAT', '1') == '1'

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer that implements the Backend service defined in backend.proto.
    """
    def generate(self, prompt, max_new_tokens):
        """
        Generates text based on the given prompt and maximum number of new tokens.

        Args:
            prompt (str): The prompt to generate text from.
            max_new_tokens (int): The maximum number of new tokens to generate.

        Returns:
            str: The generated text.
        """
        # Note: this helper relies on self.generator, which LoadModel never sets
        # (it only sets self.tokenizer and self.model), and it is not called
        # anywhere else in this file; it appears to be carried over from an
        # ExLlama-style backend.
        self.generator.end_beam_search()

        # Tokenizing the input
        ids = self.generator.tokenizer.encode(prompt)

        self.generator.gen_begin_reuse(ids)
        initial_len = self.generator.sequence[0].shape[0]
        has_leading_space = False
        decoded_text = ''
        for i in range(max_new_tokens):
            token = self.generator.gen_single_token()
            if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
                has_leading_space = True

            decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
            if has_leading_space:
                decoded_text = ' ' + decoded_text

            if token.item() == self.generator.tokenizer.eos_token_id:
                break
        return decoded_text

    def Health(self, request, context):
        """
        Returns a health check message.

        Args:
            request: The health check request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The health check reply.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """
        Loads a language model.

        Args:
            request: The load model request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The load model result.
        """
        try:
            tokenizerModel = request.Tokenizer
            if tokenizerModel == "":
                tokenizerModel = request.Model

            tokenizer = AutoTokenizer.from_pretrained(tokenizerModel)
            if MAMBA_CHAT:
                tokenizer.eos_token = "<|endoftext|>"
                tokenizer.pad_token = tokenizer.eos_token
            self.tokenizer = tokenizer
            # mamba_ssm ships CUDA kernels, so the weights are loaded on the GPU in fp16
            self.model = MambaLMHeadModel.from_pretrained(request.Model, device="cuda", dtype=torch.float16)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def Predict(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters.

        Args:
            request: The predict request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The predict result.
        """
        if request.TopP == 0:
            request.TopP = 0.9

        max_tokens = request.Tokens

        if request.Tokens == 0:
            max_tokens = 2000

        # Tokenize the prompt as tensors and move them to the device the model was loaded on
        encoded_input = self.tokenizer(request.Prompt, return_tensors="pt").to("cuda")

        out = self.model.generate(input_ids=encoded_input["input_ids"], max_length=max_tokens, temperature=request.Temperature,
                                  top_p=request.TopP, eos_token_id=self.tokenizer.eos_token_id)

        decoded = self.tokenizer.batch_decode(out)

        generated_text = decoded[0]

        # Remove prompt from response if present
        if request.Prompt in generated_text:
            generated_text = generated_text.replace(request.Prompt, "")

        return backend_pb2.Result(message=bytes(generated_text, encoding='utf-8'))

    def PredictStream(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters, and streams the results.

        Args:
            request: The predict stream request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The predict stream result.
        """
        # Implement PredictStream RPC
        # for reply in some_data_generator():
        #    yield reply
        # Not implemented yet
        return self.Predict(request, context)

def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    serve(args.addr)
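For orientation, the core of what `LoadModel` and `Predict` above do, with the gRPC plumbing stripped away, is roughly the following. This is a sketch under assumptions (a CUDA device, the `havenhq/mamba-chat` checkpoint, default sampling values), not an excerpt of the servicer:

```python
# Minimal standalone sketch of the mamba_ssm calls the servicer wraps.
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model_name = "havenhq/mamba-chat"  # assumption: any compatible Mamba checkpoint

# LoadModel: tokenizer via transformers, weights via mamba_ssm, fp16 on the GPU
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.eos_token = "<|endoftext|>"  # mirrors the MAMBA_CHAT branch above
tokenizer.pad_token = tokenizer.eos_token
model = MambaLMHeadModel.from_pretrained(model_name, device="cuda", dtype=torch.float16)

# Predict: encode, generate with the sampling parameters, decode, strip the prompt
prompt = "What is a state space model?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
out = model.generate(input_ids=input_ids, max_length=2000, temperature=0.9,
                     top_p=0.9, eos_token_id=tokenizer.eos_token_id)
text = tokenizer.batch_decode(out)[0].replace(prompt, "")
print(text)
```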
61 changes: 61 additions & 0 deletions backend/python/mamba/backend_pb2.py

(Generated gRPC/protobuf stubs; this file and the remaining changed files in the commit are not rendered in this view.)

