Changes to support TorchServe on cpu & gpu #15

Open: wants to merge 6 commits into main
18 changes: 15 additions & 3 deletions 1-build/Dockerfile-base-cpu
@@ -1,9 +1,21 @@
FROM python:3.9
ARG BASE_IMAGE=python:3.9

FROM ${BASE_IMAGE}
ARG BASE_IMAGE=python:3.9
ARG MODEL_SERVER=fastapi


LABEL description="Base container for CPU models"

USER root

RUN apt-get update && apt-get install -y htop dnsutils bc vim

RUN pip install torch configparser transformers
RUN pip install configparser

RUN echo "alias ll='ls -alh --color=auto'" >> /root/.bashrc
RUN if [ "$MODEL_SERVER" = "fastapi" ]; then \
        pip install torch transformers; \
        echo "alias ll='ls -alh --color=auto'" >> /root/.bashrc; \
    else \
        apt-get update && apt-get install -y curl; \
    fi
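For orientation, a sketch of how the two variants of this base image might be built from the repo root; the image tags and the choice of pytorch/torchserve:latest-cpu as the TorchServe BASE_IMAGE are assumptions, not values taken from the repo's build scripts:

# Default FastAPI variant, keeping the python:3.9 base (assumed tag name)
docker build -f 1-build/Dockerfile-base-cpu -t base-cpu-fastapi .
# TorchServe variant: swap the base image and skip the FastAPI-only installs
docker build -f 1-build/Dockerfile-base-cpu \
  --build-arg BASE_IMAGE=pytorch/torchserve:latest-cpu \
  --build-arg MODEL_SERVER=torchserve \
  -t base-cpu-torchserve .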
35 changes: 20 additions & 15 deletions 1-build/Dockerfile-base-gpu
@@ -1,20 +1,25 @@
FROM nvidia/cuda:11.1.1-runtime-ubuntu20.04
ARG BASE_IMAGE=nvidia/cuda:11.1.1-runtime-ubuntu20.04
FROM ${BASE_IMAGE}
ARG BASE_IMAGE=nvidia/cuda:11.1.1-runtime-ubuntu20.04
ARG MODEL_SERVER=fastapi

LABEL description="Base container for GPU models"

RUN apt-get update && apt-get install -y htop vim wget curl software-properties-common debconf-utils python3-distutils dnsutils bc
USER root

# Install python3.9
RUN DEBIAN_FRONTEND=noninteractive; add-apt-repository -y ppa:deadsnakes/ppa; apt install -y python3.9; update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1
RUN apt-get update && apt-get install -y htop dnsutils bc vim curl
RUN pip install configparser

# Install pip
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py; python get-pip.py; rm -f get-pip.py

# Install pytorch with GPU support
RUN pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html

RUN echo "PATH=/usr/local/cuda/bin\${PATH:+:\${PATH}}" >> /etc/environment
RUN echo "LD_LIBRARY_PATH=/usr/local/cuda/lib64\${LD_LIBRARY_PATH:+:\${LD_LIBRARY_PATH}}" >> /etc/environment

# Install other python libraries
RUN pip install transformers configparser
RUN if [ "$MODEL_SERVER" = "fastapi" ]; then \
        apt-get update && apt-get install -y wget software-properties-common debconf-utils python3-distutils ; \
        # Install python3.9
        DEBIAN_FRONTEND=noninteractive; add-apt-repository -y ppa:deadsnakes/ppa; apt install -y python3.9; update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1; \
        # Install pip
        curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py; python get-pip.py; rm -f get-pip.py; \
        # Install pytorch with GPU support
        pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html; \
        echo "PATH=/usr/local/cuda/bin\${PATH:+:\${PATH}}" >> /etc/environment; \
        echo "LD_LIBRARY_PATH=/usr/local/cuda/lib64\${LD_LIBRARY_PATH:+:\${LD_LIBRARY_PATH}}" >> /etc/environment; \
        # Install other python libraries
        pip install transformers ; \
    fi
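The GPU base follows the same pattern; again the tag and the pytorch/torchserve:latest-gpu base image are assumptions for illustration only:

docker build -f 1-build/Dockerfile-base-gpu \
  --build-arg BASE_IMAGE=pytorch/torchserve:latest-gpu \
  --build-arg MODEL_SERVER=torchserve \
  -t base-gpu-torchserve .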
21 changes: 21 additions & 0 deletions 3-pack/Dockerfile.torchserve
@@ -0,0 +1,21 @@
ARG BASE_IMAGE
FROM $BASE_IMAGE

ARG MODEL_NAME
ARG MODEL_FILE_NAME
ARG PROCESSOR


LABEL description="Model $MODEL_NAME packed in a TorchServe container to run on $PROCESSOR"

WORKDIR /home/model-server

COPY 3-pack/torchserve torchserve

WORKDIR /home/model-server/torchserve
USER root
COPY 3-pack/torchserve/dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh

RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh \
&& chown -R model-server /home/model-server
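Packing the model into a serving image might then look like the sketch below, run from the repo root so the COPY paths resolve; every --build-arg value here is a placeholder, since the PR does not pin specific ones:

docker build -f 3-pack/Dockerfile.torchserve \
  --build-arg BASE_IMAGE=base-cpu-torchserve \
  --build-arg MODEL_NAME=bert-base-cased-squad2 \
  --build-arg MODEL_FILE_NAME=BERTQA \
  --build-arg PROCESSOR=cpu \
  -t bertqa-torchserve-cpu .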

17 changes: 17 additions & 0 deletions 3-pack/torchserve/dockerd-entrypoint.sh
@@ -0,0 +1,17 @@
#!/bin/bash
set -e

if [[ "$1" = "serve" ]]; then
    shift 1

    pip install -r requirements.txt
    python download_model.py
    torch-model-archiver --model-name BERTQA --version 1.0 --handler handler.py --config-file model-config.yaml --extra-files "./setup_config.json" --archive-format no-archive --export-path /home/model-server/model-store -f
    mv Transformer_model /home/model-server/model-store/BERTQA/
    torchserve --start --ts-config /home/model-server/config.properties --models model0=BERTQA
else
    eval "$@"
fi

# prevent docker exit
tail -f /dev/null
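If the base image already declares this script as its entrypoint with "serve" as the default command (as the official pytorch/torchserve images do), running the packed container might look like the following; the image tag and the port mappings (TorchServe's default 8080 inference and 8081 management ports) are assumptions:

docker run --rm -p 8080:8080 -p 8081:8081 bertqa-torchserve-cpu serve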
116 changes: 116 additions & 0 deletions 3-pack/torchserve/download_model.py
@@ -0,0 +1,116 @@
import json
import os
import sys

import torch
import transformers
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForQuestionAnswering,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
    set_seed,
)

print("Transformers version", transformers.__version__)
set_seed(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def transformers_model_dowloader(
    mode,
    pretrained_model_name,
    do_lower_case,
    max_length,
    torchscript,
    hardware,
    batch_size,
):
    """Save the checkpoint and config file, along with the tokenizer config and vocab files,
    of a transformer model of your choice.
    """
    print("Download model and tokenizer", pretrained_model_name)
    # loading pre-trained model and tokenizer
    config = AutoConfig.from_pretrained(
        pretrained_model_name, torchscript=torchscript
    )
    model = AutoModelForQuestionAnswering.from_pretrained(
        pretrained_model_name, config=config
    )
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name, do_lower_case=do_lower_case
    )

    NEW_DIR = "./Transformer_model"
    try:
        os.mkdir(NEW_DIR)
    except OSError:
        print("Creation of directory %s failed" % NEW_DIR)
    else:
        print("Successfully created directory %s " % NEW_DIR)

    print(
        "Save model and tokenizer/ Torchscript model based on the setting from setup_config",
        pretrained_model_name,
        "in directory",
        NEW_DIR,
    )
    # save_mode is read from the module-level setting loaded in __main__
    if save_mode == "pretrained":
        model.save_pretrained(NEW_DIR)
        tokenizer.save_pretrained(NEW_DIR)
    elif save_mode == "torchscript":
        dummy_input = "This is a dummy input for torch jit trace"
        question = "What does the little engine say?"

        context = """In the childrens story about the little engine a small locomotive is pulling a large load up a mountain.
Since the load is heavy and the engine is small it is not sure whether it will be able to do the job. This is a story
about how an optimistic attitude empowers everyone to achieve more. In the story the little engine says: 'I think I can' as it is
pulling the heavy load all the way to the top of the mountain. On the way down it says: I thought I could."""
        inputs = tokenizer.encode_plus(
            question,
            context,
            max_length=int(max_length),
            padding='max_length',
            add_special_tokens=True,
            return_tensors="pt",
            truncation=True
        )
        model.to(device).eval()
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        traced_model = torch.jit.trace(model, (input_ids, attention_mask))
        torch.jit.save(traced_model, os.path.join(NEW_DIR, "traced_model.pt"))
    return


if __name__ == "__main__":
    dirname = os.path.dirname(__file__)
    if len(sys.argv) > 1:
        filename = os.path.join(dirname, sys.argv[1])
    else:
        filename = os.path.join(dirname, "setup_config.json")
    f = open(filename)
    settings = json.load(f)
    mode = settings["mode"]
    model_name = settings["model_name"]
    do_lower_case = settings["do_lower_case"]
    max_length = settings["max_length"]
    save_mode = settings["save_mode"]
    if save_mode == "torchscript":
        torchscript = True
    else:
        torchscript = False
    hardware = settings.get("hardware")
    batch_size = int(settings.get("batch_size", "1"))

    transformers_model_dowloader(
        mode,
        model_name,
        do_lower_case,
        max_length,
        torchscript,
        hardware,
        batch_size,
    )
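download_model.py expects a setup_config.json next to it, the same file the archiver later bundles via --extra-files. A minimal sketch of that file follows; the keys are inferred from what this script and the handler read, while the values (model name, mode, max_length, and so on) are only illustrative:

cat > 3-pack/torchserve/setup_config.json <<'EOF'
{
  "model_name": "distilbert-base-cased-distilled-squad",
  "mode": "question_answering",
  "do_lower_case": true,
  "max_length": "384",
  "save_mode": "pretrained",
  "hardware": "cpu",
  "batch_size": "1"
}
EOF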
167 changes: 167 additions & 0 deletions 3-pack/torchserve/handler.py
@@ -0,0 +1,167 @@
import ast
import json
import logging
import os

import torch
import transformers
from transformers import (
    AutoModelForQuestionAnswering,
    AutoTokenizer,
)
from optimum.bettertransformer import BetterTransformer

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)


class TransformersSeqClassifierHandler(BaseHandler):
"""
Transformers handler class for sequence, token classification and question answering.
"""

def __init__(self):
super(TransformersSeqClassifierHandler, self).__init__()
self.initialized = False

    def initialize(self, ctx):
        """In this initialize function, the BERT question-answering model and tokenizer
        are loaded from the model artifacts.
        Args:
            ctx (context): It is a JSON Object containing information
            pertaining to the model artifacts parameters.
        """
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        model_weights_dir = ctx.model_yaml_config["handler"]["model_dir"]

        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else "cpu"
        )
        # read configs for the mode, model_name, etc. from setup_config.json
        setup_config_path = os.path.join(model_dir, "setup_config.json")
        if os.path.isfile(setup_config_path):
            with open(setup_config_path) as setup_config_file:
                self.setup_config = json.load(setup_config_file)
        else:
            logger.warning("Missing the setup_config.json file.")

        # Loading the model and tokenizer from checkpoint and config files based on the user's choice of mode
        # further setup config can be added.
        if self.setup_config["save_mode"] == "torchscript":
            serialized_file = "traced_model.pt"
            model_pt_path = os.path.join(model_weights_dir, serialized_file)
            self.model = torch.jit.load(model_pt_path, map_location=self.device)
        elif self.setup_config["save_mode"] == "pretrained":
            self.model = AutoModelForQuestionAnswering.from_pretrained(model_weights_dir)

            try:
                self.model = BetterTransformer.transform(self.model)
            except RuntimeError as error:
                logger.warning(
                    "HuggingFace Optimum does not support this model; for the list of supported models, "
                    "please refer to https://huggingface.co/docs/optimum/bettertransformer/overview"
                )
            self.model.to(self.device)

        if self.setup_config["save_mode"] == "pretrained":
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.setup_config["model_name"],
                do_lower_case=self.setup_config["do_lower_case"],
            )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_dir,
                do_lower_case=self.setup_config["do_lower_case"],
            )

        self.model.eval()
        logger.info("Transformer model from path %s loaded successfully", model_dir)

        self.initialized = True

    def preprocess(self, requests):
        """Basic text preprocessing, based on the user's choice of application mode.
        Args:
            requests (str): The Input data in the form of text is passed on to the preprocess
            function.
        Returns:
            list : The preprocess function returns a list of Tensor for the size of the word tokens.
        """
        input_ids_batch = None
        attention_mask_batch = None
        logger.info(f"req: {requests}")
        for idx, input_text in enumerate(requests):
            max_length = self.setup_config["max_length"]
            logger.info("Received text: '%s'", input_text)

            question = input_text["seq_0"].decode("utf-8")
            context = input_text["seq_1"].decode("utf-8")
            logger.info(f"question: {question}")
            logger.info(f"context: {context}")
            inputs = self.tokenizer.encode_plus(
                question,
                context,
                max_length=int(max_length),
                padding='max_length',
                add_special_tokens=True,
                return_tensors="pt",
                truncation=True
            )
            input_ids = inputs["input_ids"].to(self.device)
            attention_mask = inputs["attention_mask"].to(self.device)
            # making a batch out of the received requests
            # attention masks are passed for cases where input tokens are padded.
            if input_ids.shape is not None:
                if input_ids_batch is None:
                    input_ids_batch = input_ids
                    attention_mask_batch = attention_mask
                else:
                    input_ids_batch = torch.cat((input_ids_batch, input_ids), 0)
                    attention_mask_batch = torch.cat(
                        (attention_mask_batch, attention_mask), 0
                    )
        return (input_ids_batch, attention_mask_batch)

    def inference(self, input_batch):
        """Predict the class (or classes) of the received text using the
        serialized transformers checkpoint.
        Args:
            input_batch (list): List of Text Tensors from the pre-process function is passed here
        Returns:
            list : It returns a list of the predicted value for the input text
        """
        input_ids_batch, attention_mask_batch = input_batch
        inferences = []
        # the output should be only answer_start and answer_end
        # we are outputting the words just for demonstration.
        output = self.model(
            input_ids_batch, attention_mask_batch
        )
        answer_text = str(output[0])
        answer_start = torch.argmax(output[0])
        answer_end = torch.argmax(output[1]) + 1
        if answer_end > answer_start:
            answer_text = self.tokenizer.convert_tokens_to_string(self.tokenizer.convert_ids_to_tokens(input_ids_batch[0][answer_start:answer_end]))
        else:
            answer_text = self.tokenizer.convert_tokens_to_string(self.tokenizer.convert_ids_to_tokens(input_ids_batch[0][answer_start:]))
        inferences.append(answer_text)
        logger.info("Model predicted: '%s'", answer_text)

        print("Generated text", inferences)
        return inferences

    def postprocess(self, inference_output):
        """Post Process Function converts the predicted response into Torchserve readable format.
        Args:
            inference_output (list): It contains the predicted response of the input text.
        Returns:
            (list): Returns a list of the Predictions and Explanations.
        """
        return inference_output
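With the model registered as model0 by the entrypoint, a request against TorchServe's default inference endpoint might look like the sketch below; the seq_0/seq_1 form fields mirror the keys this preprocess step decodes, and the host, port, and sample text are assumptions:

curl -X POST http://localhost:8080/predictions/model0 \
  -F "seq_0=What does the little engine say?" \
  -F "seq_1=In the childrens story about the little engine a small locomotive is pulling a large load up a mountain and says: 'I think I can'."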
6 changes: 6 additions & 0 deletions 3-pack/torchserve/model-config.yaml
@@ -0,0 +1,6 @@
minWorkers: 1
maxWorkers: 1
batchSize: 1
responseTimeout: 240
handler:
  model_dir: "Transformer_model"