Skip to content

Commit

Permalink
Reconcile merge diffs
Browse files Browse the repository at this point in the history
Also:
* Remove torchrun
* Remove cythonization of sampler
  • Loading branch information
mawong-amd committed Sep 3, 2024
1 parent 8295ea0 commit 922e143
Show file tree
Hide file tree
Showing 33 changed files with 524 additions and 696 deletions.
62 changes: 36 additions & 26 deletions Dockerfile.rocm
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ ARG BUILD_TRITON="1"
# If "0", it is copied in from the local working directory.
ARG REMOTE_VLLM="0"


# -----------------------
# vLLM base image
FROM $BASE_IMAGE AS base
Expand All @@ -27,7 +28,23 @@ ENV PYTORCH_ROCM_ARCH=${ARG_PYTORCH_ROCM_ARCH}
# Install some basic utilities
RUN apt-get update -q -y && apt-get install -q -y python3 python3-pip
RUN apt-get update -q -y && apt-get install -q -y \
sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev
ccache sqlite3 libsqlite3-dev libfmt-dev libmsgpack-dev libsuitesparse-dev

ENV CCACHE_DIR=/root/.cache/ccache

RUN python3 -m pip install --upgrade pip
# Remove sccache so it doesn't interfere with ccache
# TODO: implement sccache support across components
RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
# Install torch == 2.5.0 on ROCm
RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-6.1"*) \
python3 -m pip uninstall -y torch torchvision \
&& python3 -m pip install --no-cache-dir --pre \
torch==2.5.0.dev20240726 \
torchvision==0.20.0.dev20240726 \
--index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
*) ;; esac

ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
ENV PATH=$PATH:/opt/rocm/bin:/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/bin:
Expand All @@ -36,6 +53,7 @@ ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/opt/conda/envs/py_3.9/lib/python3.9/

WORKDIR ${COMMON_WORKDIR}


# -----------------------
# hipBLASLt build stages
FROM base AS build_hipblaslt
Expand All @@ -52,6 +70,7 @@ COPY --from=build_hipblaslt ${COMMON_WORKDIR}/hipBLASLt/build/release/*.deb /
FROM scratch AS export_hipblaslt_0
FROM export_hipblaslt_${BUILD_HIPBLASLT} AS export_hipblaslt


# -----------------------
# RCCL build stages
FROM base AS build_rccl
Expand All @@ -66,14 +85,15 @@ COPY --from=build_rccl ${COMMON_WORKDIR}/rccl/build/release/*.deb /
FROM scratch AS export_rccl_0
FROM export_rccl_${BUILD_RCCL} AS export_rccl


# -----------------------
# flash attn build stages
FROM base AS build_flash_attn
ARG FA_BRANCH="ae7928c"
ARG FA_BRANCH="23a2b1c2"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
RUN git clone ${FA_REPO} \
&& cd flash-attention \
&& git checkout ${FA_BRANCH} \
&& git checkout "${FA_BRANCH}" \
&& git submodule update --init \
&& GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist
FROM scratch AS export_flash_attn_1
Expand All @@ -82,10 +102,11 @@ COPY --from=build_flash_attn ${COMMON_WORKDIR}/flash-attention/dist/*.whl /
FROM scratch AS export_flash_attn_0
FROM export_flash_attn_${BUILD_FA} AS export_flash_attn


# -----------------------
# Triton build stages
FROM base AS build_triton
ARG TRITON_BRANCH="6ddb79b"
ARG TRITON_BRANCH="e0fc12c"
ARG TRITON_REPO="https://github.com/OpenAI/triton.git"
RUN git clone ${TRITON_REPO} \
&& cd triton \
Expand All @@ -98,13 +119,15 @@ COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /
FROM scratch AS export_triton_0
FROM export_triton_${BUILD_TRITON} AS export_triton


# AMD-SMI build stages
FROM base AS build_amdsmi
RUN cd /opt/rocm/share/amd_smi \
&& pip wheel . --wheel-dir=dist
FROM scratch AS export_amdsmi
COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /


# -----------------------
# vLLM (and gradlib) fetch stages
FROM base AS fetch_vllm_0
Expand All @@ -117,6 +140,7 @@ ONBUILD RUN git clone ${VLLM_REPO} \
&& git checkout ${VLLM_BRANCH}
FROM fetch_vllm_${REMOTE_VLLM} AS fetch_vllm


# -----------------------
# vLLM (and gradlib) build stages
FROM fetch_vllm AS build_vllm
Expand All @@ -130,7 +154,8 @@ if ls /install/*.deb; then \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
fi
# Build vLLM
RUN cd vllm \
RUN --mount=type=cache,target=/root/.cache/ccache \
cd vllm \
&& python3 setup.py clean --all && python3 setup.py bdist_wheel --dist-dir=dist
# Build gradlib
RUN cd vllm/gradlib \
Expand All @@ -154,20 +179,9 @@ ARG COMMON_WORKDIR
ARG BUILD_FA

RUN python3 -m pip install --upgrade pip && rm -rf /var/lib/apt/lists/*
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
# Manually remove it so that later steps of numpy upgrade can continue
RUN case "$(which python3)" in \
*"/opt/conda/envs/py_3.9"*) \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
*) ;; esac

RUN --mount=type=bind,from=export_hipblaslt,src=/,target=/install \
if ls /install/*.deb; then \
apt-get purge -y hipblaslt \
&& dpkg -i /install/*.deb \
&& sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \
&& sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status; \
fi
# Package upgrades for useful functionality or to avoid dependency issues
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install --upgrade numba scipy huggingface-hub[cli]

RUN --mount=type=bind,from=export_rccl,src=/,target=/install \
if ls /install/*.deb; then \
Expand Down Expand Up @@ -200,16 +214,14 @@ RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
RUN python3 -m pip install --upgrade numba scipy huggingface-hub[cli]

# Install vLLM (and gradlib)
# Make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1
RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
cd /install \
&& pip install -U -r requirements-rocm.txt \
&& case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
*"rocm-6.0"*) \
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \
*"rocm-6.1"*) \
cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6;; \
cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6 \
# Prevent interference if torch bundles its own HIP runtime
&& rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \
*) ;; esac \
&& pip uninstall -y vllm gradlib \
&& pip install *.whl
Expand All @@ -220,12 +232,10 @@ COPY --from=export_vllm /tests ${COMMON_WORKDIR}/vllm/tests
COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples
COPY --from=export_vllm /.buildkite ${COMMON_WORKDIR}/vllm/.buildkite


ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
ENV TOKENIZERS_PARALLELISM=false

# Performance environment variable.
ENV HIP_FORCE_DEV_KERNARG=1

CMD ["/bin/bash"]

12 changes: 5 additions & 7 deletions ROCm_performance.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
# Overview of the optional performance features uinque to https://github.com/ROCm/vllm
## Multi-GPU torchrun
On ROCm the default multi GPU executor is `torchrun` as opposed to `ray` on NVIDIA
This can be overridden by the `--worker-use-ray` flag to vllm or its benchmarks
To utilize torchran parallelism, the run command should be modified from
`python <command>`
to
`torchrun --standalone --nnodes=1 --nproc-per-node=<world-size> <command>`
## Triton attention
The default attention function on ROCm is using triton attention kernel. To fallback to the https://github.com/ROCm/flash-attention implementation set up the following environment symbol:
`VLLM_USE_TRITON_FLASH_ATTN=0`
Expand Down Expand Up @@ -53,3 +46,8 @@ python3 gradlib/gradlib/gemm_tuner.py --input_file /tmp/fp8_shapes.csv --tuned_f
where `/tmp/tuned_fp8_16` will be used by our fp8 gemm linear layer.

Now, when running inference with fp8, we are using the tuned gemm for best performance.

## NCCL Performance environment variable

For MI300x, setting environment variable NCCL_MIN_NCHANNELS=112 is expected to improve performance.

31 changes: 15 additions & 16 deletions csrc/custom/custom.cu
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>

namespace py = pybind11;
#include "core/registration.h"

// declare templates for front (cpp) and back (cuda) sides of function:
// template <typename T>

void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K,
cudaStream_t stream, const int rows_per_block);
void LLMM_Silu(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
int64_t rows_per_block) {
void LLMM_Silu(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
const int64_t rows_per_block) {
auto M = in_a.size(0);
auto K = in_a.size(1);
LLGemm_Silu(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K,
Expand All @@ -21,10 +20,10 @@ void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K,
cudaStream_t stream, const int rows_per_block);

// template <typename T>
void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
int64_t rows_per_block) {
int M = in_a.size(0);
int K = in_a.size(1);
void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
const int64_t rows_per_block = 4) {
auto M = in_a.size(0);
auto K = in_a.size(1);
// if (N != in_b.numel())
// throw std::invalid_argument("Size mismatch A.numel(): " +
// std::to_string(in_a.numel())
Expand All @@ -41,10 +40,10 @@ void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M, const int K,
const int N, cudaStream_t stream, const int CuCount);

void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, int64_t N_in,
int64_t CuCount) {
int M = in_a.size(0);
int K = in_a.size(1);
void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
const int64_t N_in, const int64_t CuCount) {
auto M = in_a.size(0);
auto K = in_a.size(1);
int N = N_in;
wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N,
at::cuda::getCurrentCUDAStream(), CuCount);
Expand All @@ -54,9 +53,9 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K,
cudaStream_t stream, const int solidx);

void LLZZ(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
const int solidx = 0) {
int M = in_a.size(0);
int K = in_a.size(1);
const int64_t solidx = 0) {
auto M = in_a.size(0);
auto K = in_a.size(1);

LLGemmZZ(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K,
at::cuda::getCurrentCUDAStream(), solidx);
Expand All @@ -69,7 +68,7 @@ void MMGPUKernel(float* in_a, float* in_b, float* out_c, int numARows,
int numAColumns, int numBRows, int numBColumns, int numCRows,
int numCColumns, cudaStream_t stream);

void MMCustomGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c) {
void MMCustomGPU(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c) {
auto matA_sizes{in_a.sizes()};
auto matB_sizes{in_b.sizes()};
auto matO_sizes{out_c.sizes()};
Expand Down
3 changes: 1 addition & 2 deletions csrc/custom/custom_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <cuda_fp16.h>
#include <stdexcept>
#include <algorithm>
#include "cuda_compat.h"

#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
defined(__gfx941__) || defined(__gfx942__))
Expand All @@ -17,8 +18,6 @@
#define UNREACHABLE_CODE assert(false);
#endif

constexpr int WARP_SIZE = 64;

template <typename T>
__device__ __forceinline__ T loadnt(T* addr) {
return __builtin_nontemporal_load(addr);
Expand Down
12 changes: 6 additions & 6 deletions csrc/custom/custom_ops.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#pragma once
#include <torch/all.h>

void LLMM_Silu(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
int64_t rows_per_block);
void LLMM_Silu(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
const int64_t rows_per_block);

void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
int64_t rows_per_block);
void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
const int64_t rows_per_block);

void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, int64_t N_in,
int64_t CuCount);
void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
const int64_t N_in, const int64_t CuCount);

void paged_attention_custom(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out,
Expand Down
2 changes: 1 addition & 1 deletion csrc/custom/paged_attention/attention_ll4mi.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <hip/hip_bf16.h>
#include "cuda_compat.h"

#include <algorithm>

Expand All @@ -23,7 +24,6 @@
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#define WARP_SIZE 64

#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support

Expand Down
35 changes: 17 additions & 18 deletions csrc/custom/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,28 @@

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, custom_ops) {
custom_ops.def(
"LLMM1(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) -> ()"
);
"LLMM1(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block=4) -> "
"()");
custom_ops.impl("LLMM1", torch::kCUDA, &LLMM1);
custom_ops.def(
"LLMM_Silu(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) -> ()"
);
"LLMM_Silu(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) "
"-> ()");
custom_ops.impl("LLMM_Silu", torch::kCUDA, &LLMM_Silu);
custom_ops.def(
"paged_attention_custom(Tensor! out, Tensor exp_sums,"
" Tensor max_logits, Tensor tmp_out,"
" Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads,"
" float scale, Tensor block_tables,"
" Tensor context_lens, int block_size,"
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype) -> ()"
);
custom_ops.impl("paged_attention_custom", torch::kCUDA, &paged_attention_custom);
"paged_attention_custom(Tensor! out, Tensor exp_sums,"
" Tensor max_logits, Tensor tmp_out,"
" Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads,"
" float scale, Tensor block_tables,"
" Tensor context_lens, int block_size,"
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype) -> ()");
custom_ops.impl("paged_attention_custom", torch::kCUDA,
&paged_attention_custom);
custom_ops.def(
"wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in,"
" int CuCount) -> ()"
);
"wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in,"
" int CuCount) -> ()");
custom_ops.impl("wvSpltK", torch::kCUDA, &wvSpltK);
}
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
Loading

0 comments on commit 922e143

Please sign in to comment.