Update ROCM libs and improvements #2358

Closed
wants to merge 24 commits into from

Changes from 14 commits

52 changes: 43 additions & 9 deletions Dockerfile_amd
@@ -41,7 +41,7 @@ COPY launcher launcher
RUN cargo build --profile release-opt

# Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base
FROM rocm/dev-ubuntu-22.04:6.2 AS base

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
@@ -50,23 +50,25 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
curl \
git \
make \
libmsgpack-dev \
libssl-dev \
llvm-dev \
g++ \
# Needed to build VLLM & flash.
rocthrust-dev \
hipsparse-dev \
hipblas-dev \
hipblaslt-dev \
hipcub-dev \
rocblas-dev \
hiprand-dev \
hipfft-dev \
rocrand-dev \
miopen-hip-dev \
hipfft-dev \
hipcub-dev \
hipsolver-dev \
rccl-dev \
cmake \
python3.11-dev && \
python3.11-dev \
python3.11-venv && \
rm -rf /var/lib/apt/lists/*

# Keep in sync with `server/pyproject.toml`
@@ -76,7 +78,14 @@ ARG ROCM_VERSION='6.0.2'
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
ENV PATH=/opt/conda/bin:$PATH

ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"

RUN curl -fsSL -v -o cmake-3.30.2-linux-x86_64.sh https://github.com/Kitware/CMake/releases/download/v3.30.2/cmake-3.30.2-linux-x86_64.sh \
&& chmod +x cmake-3.30.2-linux-x86_64.sh \
&& ./cmake-3.30.2-linux-x86_64.sh --skip-license --prefix=/usr/local \
&& rm cmake-3.30.2-linux-x86_64.sh

# TGI seems to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
@@ -100,19 +109,37 @@ RUN case ${TARGETPLATFORM} in \
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya

# Install flash-attention, torch dependencies
RUN pip install numpy einops ninja --no-cache-dir
RUN pip install numpy einops ninja joblib msgpack --no-cache-dir

# Install HIPBLASLt
ARG HIPBLASLT_BRANCH="6f65c6e"
RUN git clone https://github.com/ROCm/hipBLASLt \
&& cd hipBLASLt \
&& git checkout ${HIPBLASLT_BRANCH} \
&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \
&& cd build/release \
&& make package
RUN dpkg -i hipBLASLt/build/release/*.deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status \
&& rm -rf hipBLASLt

RUN pip uninstall -y triton && \
git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
cd triton/python && \
pip install .

RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
ARG PYTORCH_COMMIT="da320214e66b5af0f7db8fd18a64dbb519d17b27"
RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && \
cd pytorch && git fetch --depth 1 origin ${PYTORCH_COMMIT} && \
git checkout ${PYTORCH_COMMIT} && \
git submodule update --init --recursive && \
pip install -r requirements.txt --no-cache-dir

ARG _GLIBCXX_USE_CXX11_ABI="1"
ARG CMAKE_PREFIX_PATH="/opt/conda"
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
ARG BUILD_CAFFE2="0" \
BUILD_CAFFE2_OPS="0" \
USE_CUDA="0" \
@@ -224,6 +251,13 @@ ENTRYPOINT ["./entrypoint.sh"]
# Final image
FROM base-copy

ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
ENV VLLM_MOE_PADDING=0
ENV ATTENTION=paged
ENV USE_PREFIX_CACHING=0
ENV ROCM_USE_SKINNY_GEMM=1

COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

6 changes: 6 additions & 0 deletions docs/source/installation_amd.md
@@ -31,6 +31,12 @@ Two implementations of Flash Attention are available for ROCm, the first is [ROC

By default, the Composable Kernel implementation is used. However, the Triton implementation has slightly lower latency on MI250 and MI300, but requires a warmup that can be prohibitive, as it has to be repeated for every new prompt length. If needed, the Flash Attention Triton implementation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="true"` when launching TGI's Docker container.
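
For reference, here is a minimal sketch of how this flag is interpreted, mirroring the engine-selection check in `server/text_generation_server/layers/attention/rocm.py` shown later in this diff (only the values `"true"` or `"1"` select the Triton engine; anything else keeps the Composable Kernel path):

```python
import os

# Mirrors the engine selection in rocm.py: the Triton flash-attention path
# is taken only when ROCM_USE_FLASH_ATTN_V2_TRITON is "true" or "1".
use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"
print(f"Flash Attention engine: {ENGINE}")
```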

## Custom PagedAttention

For better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`.

The custom kernel supports bf16 and fp16 data types, block size of 16, head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. For other configurations, we use the PagedAttention v2 kernel.
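
As an illustration, here is a minimal sketch of that gating logic, following the `use_custom` check in this PR's `rocm.py` (the in-code check is somewhat broader than the prose above, also admitting head size 64, block size 32, and contexts up to 32k; the helper name below is illustrative):

```python
import torch

# Sketch of the `use_custom` gate from rocm.py: the custom paged attention
# kernel is used only when all constraints hold; otherwise TGI falls back
# to the PagedAttention v2 kernel.
def can_use_custom_paged_attn(
    query: torch.Tensor,
    head_size: int,
    block_size: int,
    gqa_ratio: int,
    max_seq_len: int,
    custom_attn_available: bool = True,
) -> bool:
    return (
        custom_attn_available
        and query.dtype in (torch.half, torch.bfloat16)
        and head_size in (64, 128)
        and block_size in (16, 32)
        and 1 <= gqa_ratio <= 16
        and max_seq_len <= 32768
    )
```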

## Unsupported features

The following features are currently not supported in the ROCm version of TGI; support may be extended in the future:
2 changes: 1 addition & 1 deletion server/Makefile-flash-att-v2
@@ -1,5 +1,5 @@
flash_att_v2_commit_cuda := v2.6.1
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
flash_att_v2_commit_rocm := 3cea2fb6ee54fb7e1aad9db6ac6c9331184b8647 # (Aug28)

build-flash-attention-v2-cuda:
pip install -U packaging wheel
4 changes: 2 additions & 2 deletions server/Makefile-vllm
@@ -1,5 +1,5 @@
commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
build-vllm-cuda:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
@@ -13,7 +13,7 @@ install-vllm-cuda: build-vllm-cuda
build-vllm-rocm:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
git clone https://github.com/fxmarty/rocm-vllm.git vllm; \
git clone https://github.com/mht-sharma/vllm.git vllm; \
fi
cd vllm && git fetch && git checkout $(commit_rocm) && \
PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
152 changes: 104 additions & 48 deletions server/text_generation_server/layers/attention/rocm.py
@@ -1,4 +1,5 @@
import os
from typing import Optional
import torch
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
@@ -8,13 +9,19 @@

major, minor = torch.cuda.get_device_capability()
is_sm75 = major == 7 and minor == 5
_PARTITION_SIZE = 512

_PARTITION_SIZE_V1V2 = 512
_PARTITION_SIZE_CUSTOM = 256

use_triton = os.getenv("ROCM_USE_FLASH_ATTN_V2_TRITON", "").lower() in {"true", "1"}
ENGINE = "triton" if use_triton else "ck"

custom_attn_available = os.getenv("ROCM_USE_CUSTOM_PAGED_ATTN", "1") != "0"
if custom_attn_available:
from vllm._custom_C import paged_attention_custom

try:
from vllm._C import cache_ops
import vllm._custom_ops as ops
except Exception as e:
raise ImportError(
f"Could not import vllm paged attention. Make sure your installation is correct. Complete error: {e}"
@@ -33,9 +40,7 @@ def reshape_and_cache(
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)
ops.reshape_and_cache(key, value, key_cache, value_cache, slots, "auto", 1.0)


def paged_attention(
@@ -45,8 +50,9 @@ def paged_attention(
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
input_lengths: Seqlen,
seqlen: Seqlen,
max_s: int,
softcap: Optional[float] = None,
):
# Adapted from: https://github.com/vllm-project/vllm/blob/f8a1e39fae05ca610be8d5a78be9d40f5274e5fc/vllm/model_executor/layers/attention.py
# Copyright 2023 The vLLM team. All rights
@@ -68,8 +74,25 @@
# value_cache => [num_blocks, num_heads, head_size, block_size]
block_size = value_cache.shape[3]
num_seqs, num_heads, head_size = query.shape

num_kv_heads = key_cache.shape[1]
gqa_ratio = num_heads // num_kv_heads
use_custom = (
custom_attn_available
and (query.dtype == torch.half or query.dtype == torch.bfloat16)
and (head_size == 128 or head_size == 64)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_s <= 32768
)

if not use_custom:
_PARTITION_SIZE = _PARTITION_SIZE_V1V2
else:
_PARTITION_SIZE = _PARTITION_SIZE_CUSTOM

max_num_partitions = (max_s + _PARTITION_SIZE - 1) // _PARTITION_SIZE
input_lengths = input_lengths.input_lengths
input_lengths = seqlen.input_lengths

out = torch.empty_like(query)

@@ -78,9 +101,13 @@
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
from vllm._C import ops
import vllm._custom_ops as ops

use_v1 = max_s <= 8192 and (max_num_partitions == 1 or num_seqs * num_heads > 512)
use_v1 = (
max_s <= 8192
and (max_num_partitions == 1 or num_seqs * num_heads > 512)
and not use_custom
)
if use_v1:
ops.paged_attention_v1(
out,
@@ -112,24 +139,44 @@ def paged_attention(
)
max_logits = torch.empty_like(exp_sums)

ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
if not use_custom:
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
kv_head_mapping,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
1.0,
)
else:
paged_attention_custom(
out,
exp_sums,
max_logits,
tmp_output,
query,
key_cache,
value_cache,
num_kv_heads,
softmax_scale,
block_tables,
input_lengths,
block_size,
max_s,
None,
"auto",
)

return out


@@ -173,13 +220,14 @@ def paged_attention(

def attention(
q,
k,
v,
cu_seqlens,
max_s,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
if window_size_left <= 0 and window_size_left != -1:
raise ValueError("`window_size_left` must be > 0 or -1")
@@ -189,46 +237,54 @@ def attention(
# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
return flash_attn_2_cuda.varlen_fwd(
q,
k,
v,
key_cache,
value_cache,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
None,
None,
None,
None,
seqlen.max_q,
seqlen.max_k,
0.0,
softmax_scale,
False,
causal,
window_size_left,
0,
softcap,
False,
None,
)
)[0]

elif ENGINE == "triton":
from .flash_attn_triton import triton_attention

def attention(
q,
k,
v,
cu_seqlens,
max_s,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap=0.0,
):
out = torch.empty_like(q)

# We do not need to check window_size_left (not supported) here, so it is already checked ahead of time at model load.
output, _ = triton_attention(
q,
k,
v,
key_cache,
value_cache,
out,
cu_seqlens,
cu_seqlens,
max_s,
max_s,
seqlen.cu_seqlen_q,
seqlen.cu_seqlen_q,
seqlen.max_q,
seqlen.max_k,
causal,
softmax_scale,
)