Update ROCM libs and improvements #2358

Closed · wants to merge 24 commits
Dockerfile_amd: 52 changes (40 additions, 12 deletions)
@@ -41,7 +41,7 @@ COPY launcher launcher
RUN cargo build --profile release-opt

# Text Generation Inference base image for RoCm
FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update AS base
FROM rocm/dev-ubuntu-22.04:6.2 AS base

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
build-essential \
@@ -50,33 +50,34 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
curl \
git \
make \
libmsgpack-dev \
libssl-dev \
llvm-dev \
g++ \
# Needed to build VLLM & flash.
rocthrust-dev \
hipsparse-dev \
hipblas-dev \
hipblaslt-dev \
hipcub-dev \
rocblas-dev \
hiprand-dev \
hipfft-dev \
rocrand-dev \
miopen-hip-dev \
hipfft-dev \
hipcub-dev \
hipsolver-dev \
rccl-dev \
cmake \
python3.11-dev && \
python3.11-venv && \
rm -rf /var/lib/apt/lists/*

# Keep in sync with `server/pyproject.toml`
ARG MAMBA_VERSION=23.1.0-1
ARG PYTORCH_VERSION='2.3.0'
ARG ROCM_VERSION='6.0.2'
ARG PYTHON_VERSION='3.11.10'
# Automatically set by buildx
ARG TARGETPLATFORM
ENV PATH /opt/conda/bin:$PATH
ENV PATH=/opt/conda/bin:$PATH

ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"

# TGI seems to require libssl.so.1.1 instead of libssl.so.3, so we can't use Ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda.
# Install mamba
@@ -100,19 +101,38 @@ RUN case ${TARGETPLATFORM} in \
/opt/conda/bin/conda install -y "python=${PYTHON_VERSION}" ;; \
esac && \
/opt/conda/bin/conda clean -ya

# Install flash-attention, torch dependencies
RUN pip install numpy einops ninja --no-cache-dir
RUN pip install numpy einops ninja joblib msgpack cmake --no-cache-dir

# Install HIPBLASLt
ARG HIPBLASLT_BRANCH="6f65c6e"
RUN git clone https://github.com/ROCm/hipBLASLt \
&& cd hipBLASLt \
&& git checkout ${HIPBLASLT_BRANCH} \
&& SCCACHE_IDLE_TIMEOUT=1800 ./install.sh --architecture ${PYTORCH_ROCM_ARCH} \
&& cd build/release \
&& make package
RUN dpkg -i hipBLASLt/build/release/*.deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status \
&& rm -rf hipBLASLt

RUN pip uninstall -y triton && \
git clone --depth 1 --single-branch https://github.com/ROCm/triton.git && \
cd triton/python && \
pip install .
pip install . && \
rm -r triton

RUN git clone --depth 1 --recursive --single-branch --branch 2.3-patched https://github.com/fxmarty/pytorch.git pytorch && cd pytorch && pip install -r requirements.txt --no-cache-dir
ARG PYTORCH_COMMIT="da320214e66b5af0f7db8fd18a64dbb519d17b27"
RUN git clone --depth 1 --recursive --single-branch --branch main https://github.com/pytorch/pytorch.git pytorch && \
cd pytorch && git fetch --depth 1 origin ${PYTORCH_COMMIT} && \
git checkout ${PYTORCH_COMMIT} && \
git submodule update --init --recursive && \
pip install -r requirements.txt --no-cache-dir

ARG _GLIBCXX_USE_CXX11_ABI="1"
ARG CMAKE_PREFIX_PATH="/opt/conda"
ARG PYTORCH_ROCM_ARCH="gfx90a;gfx942"
ARG BUILD_CAFFE2="0" \
BUILD_CAFFE2_OPS="0" \
USE_CUDA="0" \
@@ -126,6 +146,7 @@ ARG BUILD_CAFFE2="0" \
USE_MEM_EFF_ATTENTION="0"

RUN cd pytorch && python tools/amd_build/build_amd.py && python setup.py install
RUN rm -rf pytorch

# Set as recommended: https://github.com/ROCm/triton/wiki/A-script-to-set-program-execution-environment-in-ROCm
ENV HIP_FORCE_DEV_KERNARG=1
@@ -224,6 +245,13 @@ ENTRYPOINT ["./entrypoint.sh"]
# Final image
FROM base-copy

ENV ROCM_USE_CUSTOM_PAGED_ATTN=1
ENV PYTORCH_TUNABLEOP_TUNING_AFTER_WARMUP=0
ENV VLLM_MOE_PADDING=0
ENV ATTENTION=paged
ENV USE_PREFIX_CACHING=0
ENV ROCM_USE_SKINNY_GEMM=1

COPY ./tgi-entrypoint.sh /tgi-entrypoint.sh
RUN chmod +x /tgi-entrypoint.sh

docs/source/installation_amd.md: 6 changes (6 additions, 0 deletions)
@@ -31,6 +31,12 @@ Two implementations of Flash Attention are available for ROCm, the first is [ROC

By default, the Composable Kernel implementation is used. The Triton implementation has slightly lower latency on MI250 and MI300, but it requires a warmup that must be repeated for each new prompt length, which can be prohibitive. If needed, the FA Triton implementation can be enabled with `--env ROCM_USE_FLASH_ATTN_V2_TRITON="true"` when launching TGI's Docker container.
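
As an illustration only (not part of this diff): enabling the Triton backend at launch could look like the sketch below, where the image tag, model id, volume and device flags are placeholders to adapt to your setup.

```bash
# Hypothetical launch command; image tag, model id, volume and device flags
# are placeholders, not values taken from this pull request.
docker run --rm -it \
    --device=/dev/kfd --device=/dev/dri --group-add video \
    --ipc=host --shm-size 64g \
    -v $PWD/data:/data \
    --env ROCM_USE_FLASH_ATTN_V2_TRITON="true" \
    ghcr.io/huggingface/text-generation-inference:latest-rocm \
    --model-id teknium/OpenHermes-2.5-Mistral-7B
```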

## Custom PagedAttention

For better performance on ROCm, a custom Paged Attention kernel is available and is enabled by default. To disable it and fall back to the PagedAttention v2 kernel, set the environment variable `ROCM_USE_CUSTOM_PAGED_ATTN=0`.

The custom kernel supports bf16 and fp16 data types, block size of 16, head size of 128, a maximum context length of 16k, and GQA ratios between 1 and 16. For other configurations, we use the PagedAttention v2 kernel.
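
A minimal sketch of falling back to the v2 kernel (same placeholder image and model id as in the example above; only the environment variable changes):

```bash
# Hypothetical: disable the custom ROCm paged attention kernel and fall back
# to PagedAttention v2. Image tag and model id are placeholders.
docker run --rm -it \
    --device=/dev/kfd --device=/dev/dri --group-add video \
    --env ROCM_USE_CUSTOM_PAGED_ATTN=0 \
    ghcr.io/huggingface/text-generation-inference:latest-rocm \
    --model-id teknium/OpenHermes-2.5-Mistral-7B
```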

## Unsupported features

The following features are currently not supported in the ROCm version of TGI, and support may be extended in the future:
server/Makefile-flash-att-v2: 2 changes (1 addition, 1 deletion)
@@ -1,5 +1,5 @@
flash_att_v2_commit_cuda := v2.6.1
flash_att_v2_commit_rocm := 2554f490101742ccdc56620a938f847f61754be6
flash_att_v2_commit_rocm := 3cea2fb6ee54fb7e1aad9db6ac6c9331184b8647 # (Aug28)

build-flash-attention-v2-cuda:
pip install -U packaging wheel
server/Makefile-vllm: 4 changes (2 additions, 2 deletions)
@@ -1,5 +1,5 @@
commit_cuda := d243e9dc7e2c9c2e36a4150ec8e64809cb55c01b
commit_rocm := c6ee53b1be97e3bbc791b95f22827501297f8921
commit_rocm := 4e0929e6e4fa0a3d09d358715c288020ea9dc247
build-vllm-cuda:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
@@ -13,7 +13,7 @@ install-vllm-cuda: build-vllm-cuda
build-vllm-rocm:
if [ ! -d 'vllm' ]; then \
pip install -U ninja packaging --no-cache-dir && \
git clone https://github.com/fxmarty/rocm-vllm.git vllm; \
git clone https://github.com/mht-sharma/vllm.git vllm; \
fi
cd vllm && git fetch && git checkout $(commit_rocm) && \
PYTORCH_ROCM_ARCH="gfx90a;gfx942" python setup.py build
server/text_generation_server/layers/attention/__init__.py: 18 changes (16 additions, 2 deletions)
@@ -11,11 +11,24 @@
        paged_attention,
        reshape_and_cache,
        SUPPORTS_WINDOWING,
        PREFILL_IN_KV_CACHE,
    )
elif SYSTEM == "rocm":
    from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
    from .rocm import (
        attention,
        paged_attention,
        reshape_and_cache,
        SUPPORTS_WINDOWING,
        PREFILL_IN_KV_CACHE,
    )
elif SYSTEM == "ipex":
    from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
    from .ipex import (
        attention,
        paged_attention,
        reshape_and_cache,
        SUPPORTS_WINDOWING,
        PREFILL_IN_KV_CACHE,
    )
else:
    raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")

@@ -25,5 +38,6 @@
    "paged_attention",
    "reshape_and_cache",
    "SUPPORTS_WINDOWING",
    "PREFILL_IN_KV_CACHE",
    "Seqlen",
]
server/text_generation_server/layers/attention/common.py: 3 changes (3 additions, 0 deletions)
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models.globals import ATTENTION
import torch
from typing import Optional
@@ -65,5 +66,7 @@ class Seqlen:
    max_k: int

    def clamp(self, max):
        if SYSTEM == "rocm":
            return self
        raise NotImplementedError("Not implemented seqlen for paged")
        return Seqlen(torch.clamp(self.input_lengths, max=max))
server/text_generation_server/layers/attention/cuda.py: 6 changes (6 additions, 0 deletions)
@@ -351,3 +351,9 @@ def attention(
        None,
    )
    return out


# Prefill in the cache with every kind of attention, unless we
# have a configuration that requires flash-attention v1, which
# does not support block tables.
PREFILL_IN_KV_CACHE = ATTENTION != "paged" or V2
server/text_generation_server/layers/attention/ipex.py: 1 change (1 addition, 0 deletions)
@@ -5,6 +5,7 @@
from typing import Optional

SUPPORTS_WINDOWING = False
PREFILL_IN_KV_CACHE = False


def attention(