feat: 🐍 add mamba support (mudler#1589)
feat(mamba): Initial import. This is a first iteration of the mamba backend, loosely based on mamba-chat (https://github.com/havenhq/mamba-chat).
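For orientation, here is a minimal sketch of the generation path this backend wraps, with the gRPC plumbing from the diff below stripped out. The checkpoint name is only an illustrative assumption (any mamba_ssm-compatible checkpoint should work), and a CUDA device with float16 weights is assumed, matching how the backend loads the model.

```python
# Minimal sketch of the mamba generation path (no gRPC), mirroring LoadModel/Predict below.
# "havenhq/mamba-chat" is an assumed example checkpoint; CUDA + float16 are assumed,
# as in the backend itself.
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model_name = "havenhq/mamba-chat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.eos_token = "<|endoftext|>"      # same defaults the backend applies when MAMBA_CHAT=1
tokenizer.pad_token = tokenizer.eos_token

model = MambaLMHeadModel.from_pretrained(model_name, device="cuda", dtype=torch.float16)

prompt = "What is a state space model?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")

out = model.generate(
    input_ids=input_ids,
    max_length=200,
    temperature=0.9,
    top_p=0.9,
    eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.batch_decode(out)[0])
```

The backend below performs essentially these steps per request, falling back to 2000 tokens and a top_p of 0.9 when the request leaves them unset.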
Showing 13 changed files with 762 additions and 3 deletions.
New file (+16 lines): a Makefile for the mamba backend environment.

```makefile
.PHONY: mamba
mamba:
	$(MAKE) -C ../common-env/transformers
	bash install.sh

.PHONY: run
run:
	@echo "Running mamba..."
	bash run.sh
	@echo "mamba run."

.PHONY: test
test:
	@echo "Testing mamba..."
	bash test.sh
	@echo "mamba tested."
```
New file (+5 lines): a README for the mamba backend environment.

````markdown
# Creating a separate environment for the mamba project

```
make mamba
```
````
New file (+182 lines): the Python gRPC backend server.

```python
#!/usr/bin/env python3
from concurrent import futures
import time
import argparse
import signal
import sys
import os

import backend_pb2
import backend_pb2_grpc

import grpc

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS is specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
MAMBA_CHAT = os.environ.get('MAMBA_CHAT', '1') == '1'

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer that implements the Backend service defined in backend.proto.
    """
    def generate(self, prompt, max_new_tokens):
        """
        Generates text based on the given prompt and maximum number of new tokens.

        Args:
            prompt (str): The prompt to generate text from.
            max_new_tokens (int): The maximum number of new tokens to generate.

        Returns:
            str: The generated text.
        """
        # Note: this helper relies on self.generator, which is never set by
        # LoadModel below; it is not used by the gRPC entry points.
        self.generator.end_beam_search()

        # Tokenizing the input
        ids = self.generator.tokenizer.encode(prompt)

        self.generator.gen_begin_reuse(ids)
        initial_len = self.generator.sequence[0].shape[0]
        has_leading_space = False
        decoded_text = ''
        for i in range(max_new_tokens):
            token = self.generator.gen_single_token()
            if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
                has_leading_space = True

            decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
            if has_leading_space:
                decoded_text = ' ' + decoded_text

            if token.item() == self.generator.tokenizer.eos_token_id:
                break
        return decoded_text

    def Health(self, request, context):
        """
        Returns a health check message.

        Args:
            request: The health check request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The health check reply.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """
        Loads a language model.

        Args:
            request: The load model request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The load model result.
        """
        try:
            # Fall back to the model itself when no separate tokenizer is given
            tokenizerModel = request.Tokenizer
            if tokenizerModel == "":
                tokenizerModel = request.Model

            tokenizer = AutoTokenizer.from_pretrained(tokenizerModel)
            if MAMBA_CHAT:
                tokenizer.eos_token = "<|endoftext|>"
                tokenizer.pad_token = tokenizer.eos_token
            self.tokenizer = tokenizer
            self.model = MambaLMHeadModel.from_pretrained(request.Model, device="cuda", dtype=torch.float16)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def Predict(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters.

        Args:
            request: The predict request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The predict result.
        """
        if request.TopP == 0:
            request.TopP = 0.9

        max_tokens = request.Tokens

        if request.Tokens == 0:
            max_tokens = 2000

        # Tokenize the prompt and move it to the same device as the model
        encoded_input = self.tokenizer(request.Prompt, return_tensors="pt").to("cuda")

        out = self.model.generate(input_ids=encoded_input["input_ids"], max_length=max_tokens, temperature=request.Temperature,
                                  top_p=request.TopP, eos_token_id=self.tokenizer.eos_token_id)

        decoded = self.tokenizer.batch_decode(out)

        generated_text = decoded[0]

        # Remove prompt from response if present
        if request.Prompt in generated_text:
            generated_text = generated_text.replace(request.Prompt, "")

        return backend_pb2.Result(message=bytes(generated_text, encoding='utf-8'))

    def PredictStream(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters, and streams the results.

        Args:
            request: The predict stream request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The predict stream result.
        """
        # Implement PredictStream RPC
        # for reply in some_data_generator():
        #     yield reply
        # Not implemented yet
        return self.Predict(request, context)


def serve(address):
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    serve(args.addr)
```
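For completeness, here is a rough client-side sketch of how the servicer above could be driven once the backend_pb2/backend_pb2_grpc stubs are generated and the server is running on its default address (localhost:50051). The message and field names (HealthMessage, ModelOptions, PredictOptions, Model, Prompt, Tokens, TopP, Temperature) are inferred from what the servicer reads and are assumptions about LocalAI's backend.proto rather than a verified API; the stub class name follows the usual protoc naming for the Backend service.

```python
# Hypothetical client sketch for the servicer above.
# Assumes the server is listening on its default address (localhost:50051) and that
# backend.proto defines HealthMessage/ModelOptions/PredictOptions with the fields
# the servicer reads; adjust names to the real proto if they differ.
import grpc
import backend_pb2
import backend_pb2_grpc

channel = grpc.insecure_channel("localhost:50051")
stub = backend_pb2_grpc.BackendStub(channel)  # protoc-generated stub for the Backend service

# Health check
print(stub.Health(backend_pb2.HealthMessage()).message)

# Load a model ("havenhq/mamba-chat" is only an example checkpoint)
res = stub.LoadModel(backend_pb2.ModelOptions(Model="havenhq/mamba-chat"))
assert res.success, res.message

# Run a prediction; Tokens=0 and TopP=0 would fall back to 2000 and 0.9 respectively
reply = stub.Predict(backend_pb2.PredictOptions(
    Prompt="Explain state space models in one paragraph.",
    Tokens=256,
    TopP=0.9,
    Temperature=0.9,
))
print(reply.message)
```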
Some generated files are not rendered by default.