[Hardware][Ascend] Init Ascend NPU Support
Co-authored-by: wangshuai09 <[email protected]>
Signed-off-by: MengqingCao <[email protected]>
MengqingCao and wangshuai09 committed Nov 29, 2024
1 parent c82b432 commit 198b85b
Showing 31 changed files with 1,820 additions and 7 deletions.
19 changes: 19 additions & 0 deletions Dockerfile.npu
@@ -0,0 +1,19 @@
FROM ascendai/cann:8.0.rc3.alpha002-910b-ubuntu22.04-py3.9
## Switch to the following base image if pulling the one above fails
# FROM quay.io/ascend/cann:8.0.rc3.alpha002-910b-ubuntu22.04-py3.9

# Define environments
ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update -y && \
    apt-get install -y python3-pip git vim
WORKDIR /workspace

COPY . /workspace/vllm/

# install build requirements
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt
# build vLLM with NPU backend
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE="npu" python3 -m pip install /workspace/vllm/

CMD ["/bin/bash"]
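As a rough usage sketch (not part of this commit; the device list and image tag are assumptions that depend on the local Ascend driver setup): build the image from the repository root with "docker build -f Dockerfile.npu -t vllm-npu .", then start it with the NPU devices and driver directory mounted, e.g. "docker run -it --device /dev/davinci0 --device /dev/davinci_manager --device /dev/devmm_svm --device /dev/hisi_hdc -v /usr/local/Ascend/driver:/usr/local/Ascend/driver vllm-npu".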
41 changes: 41 additions & 0 deletions examples/offline_inference_npu.py
@@ -0,0 +1,41 @@
import gc

import torch

from vllm import LLM, SamplingParams
from vllm.distributed.parallel_state import (destroy_distributed_environment,
                                             destroy_model_parallel)


def clean_up():
    destroy_model_parallel()
    destroy_distributed_environment()
    gc.collect()
    torch.npu.empty_cache()


# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

del llm
clean_up()
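clean_up() tears down vLLM's distributed process groups and empties the NPU memory cache so device memory is released promptly. For scripts adapted from this example, a minimal sketch of the same teardown wrapped in try/finally (reusing prompts, sampling_params, and clean_up exactly as defined above) keeps the cleanup on error paths too:

def main():
    llm = LLM(model="facebook/opt-125m")
    try:
        outputs = llm.generate(prompts, sampling_params)
        for output in outputs:
            print(f"Prompt: {output.prompt!r}, "
                  f"Generated text: {output.outputs[0].text!r}")
    finally:
        # Drop the engine reference before destroying the distributed
        # state and emptying the NPU cache.
        del llm
        clean_up()


if __name__ == "__main__":
    main()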
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -83,7 +83,7 @@ exclude = [
]

[tool.codespell]
-ignore-words-list = "dout, te, indicies, subtile"
+ignore-words-list = "dout, te, indicies, subtile, cann"
skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"

[tool.isort]
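Adding "cann" to the ignore list keeps codespell from flagging the CANN toolkit name, which this commit introduces in Dockerfile.npu and elsewhere, as a misspelling of "can".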
8 changes: 8 additions & 0 deletions requirements-npu.txt
@@ -0,0 +1,8 @@
# Common dependencies
-r requirements-common.txt

decorator
pyyaml
scipy
setuptools
torch_npu == 2.5.1rc1
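A note on version pairing (an assumption from Huawei's compatibility matrix, not stated in this diff): torch_npu builds target a specific torch release and CANN toolkit, so torch_npu == 2.5.1rc1 is meant to run against torch 2.5.1 inside a CANN 8.0 environment such as the base image in Dockerfile.npu above.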
8 changes: 8 additions & 0 deletions setup.py
@@ -307,6 +307,10 @@ def _is_xpu() -> bool:
    return VLLM_TARGET_DEVICE == "xpu"


def _is_npu() -> bool:
    return VLLM_TARGET_DEVICE == "npu"


def _build_custom_ops() -> bool:
    return _is_cuda() or _is_hip() or _is_cpu()

@@ -429,6 +433,8 @@ def get_vllm_version() -> str:
        version += f"{sep}cpu"
    elif _is_xpu():
        version += f"{sep}xpu"
    elif _is_npu():
        version += f"{sep}npu"
    else:
        raise RuntimeError("Unknown runtime environment")

@@ -489,6 +495,8 @@ def _read_requirements(filename: str) -> List[str]:
        requirements = _read_requirements("requirements-cpu.txt")
    elif _is_xpu():
        requirements = _read_requirements("requirements-xpu.txt")
    elif _is_npu():
        requirements = _read_requirements("requirements-npu.txt")
    else:
        raise ValueError(
            "Unsupported platform, please use CUDA, ROCm, Neuron, HPU, "
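With these hooks in place, an NPU build is selected the same way as the other backends: set the target device at install time, exactly as the RUN line in Dockerfile.npu above does with VLLM_TARGET_DEVICE="npu".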
3 changes: 2 additions & 1 deletion vllm/_custom_ops.py
@@ -13,7 +13,8 @@

logger = init_logger(__name__)

-if not current_platform.is_tpu() and not current_platform.is_hpu():
+if not (current_platform.is_tpu() or current_platform.is_hpu()
+        or current_platform.is_npu()):
    try:
        import vllm._C
    except ImportError as e:
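Because _build_custom_ops() in setup.py (above) still returns True only for CUDA, ROCm, and CPU targets, the compiled vllm._C extension is never built for NPU; widening this guard skips the import on NPU builds just as it already did on TPU and HPU.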
(Diffs for the remaining 25 changed files are not shown.)
