Bump vLLM to 0.5.3.post1 (kserve#3828)
* Bump vLLM to 0.5.3.post1

Signed-off-by: Sivanantham Chinnaiyan <[email protected]>

* Update makefile

Signed-off-by: Sivanantham Chinnaiyan <[email protected]>

* approx probability comparison

Signed-off-by: Sivanantham Chinnaiyan <[email protected]>

* Set multiprocessing method to spawn

Signed-off-by: Sivanantham Chinnaiyan <[email protected]>

---------

Signed-off-by: Sivanantham Chinnaiyan <[email protected]>
sivanantha321 authored Aug 2, 2024
1 parent f9e7d5a commit 96fb00e
Showing 11 changed files with 1,152 additions and 691 deletions.
4 changes: 2 additions & 2 deletions Makefile
@@ -1,7 +1,7 @@

# Base Image URL
BASE_IMG ?= python:3.9-slim-bullseye
PMML_BASE_IMG ?= openjdk:11-slim
BASE_IMG ?= python:3.11-slim-bookworm
PMML_BASE_IMG ?= openjdk:21-slim-bookworm

# Image URL to use all building/pushing image targets
IMG ?= kserve-controller:latest
5 changes: 4 additions & 1 deletion python/huggingface_server.Dockerfile
@@ -9,7 +9,7 @@ ARG POETRY_HOME=/opt/poetry
ARG POETRY_VERSION=1.8.3

# Install vllm
ARG VLLM_VERSION=0.4.3
ARG VLLM_VERSION=0.5.3.post1

RUN apt-get update -y && apt-get install gcc python3.10-venv python3-dev -y && apt-get clean && \
rm -rf /var/lib/apt/lists/*
@@ -60,6 +60,9 @@ ENV SAFETENSORS_FAST_GPU="1"
ENV HF_HUB_DISABLE_TELEMETRY="1"
# NCCL Lib path for vLLM. https://github.com/vllm-project/vllm/blob/ec784b2526219cd96159a52074ab8cd4e684410a/vllm/utils.py#L598-L602
ENV VLLM_NCCL_SO_PATH="/lib/x86_64-linux-gnu/libnccl.so.2"
# https://github.com/vllm-project/vllm/issues/6152
# Set the multiprocess method to spawn to avoid issues with cuda initialization for `mp` executor backend.
ENV VLLM_WORKER_MULTIPROC_METHOD="spawn"

USER 1000
ENTRYPOINT ["python3", "-m", "huggingfaceserver"]
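The new VLLM_WORKER_MULTIPROC_METHOD=spawn setting works around the fact that CUDA cannot be re-initialized in a worker created by fork once the parent process has touched the GPU (see the linked vLLM issue). A minimal sketch of that general constraint, independent of KServe and vLLM internals:

```python
# Illustrative sketch only (not KServe or vLLM code): CUDA state does not
# survive fork, so GPU worker processes are created with the "spawn" start
# method and initialize CUDA themselves.
import torch
import torch.multiprocessing as mp


def worker(rank: int) -> None:
    # Each spawned worker starts a fresh interpreter, so initializing CUDA
    # here is safe even if the parent process already used the GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"worker {rank}:", torch.ones(2, device=device).sum().item())


if __name__ == "__main__":
    # torch.multiprocessing.spawn always uses the "spawn" start method;
    # a forked worker would instead fail with
    # "Cannot re-initialize CUDA in forked subprocess".
    mp.spawn(worker, nprocs=2)
```

Spawned workers start more slowly than forked ones, but they do not inherit the parent's CUDA context, which is the failure mode the environment variable avoids.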
5 changes: 2 additions & 3 deletions python/huggingfaceserver/huggingfaceserver/encoder_model.py
@@ -30,7 +30,6 @@
from torch import Tensor
from transformers import (
AutoConfig,
AutoModel,
AutoTokenizer,
BatchEncoding,
PreTrainedModel,
@@ -125,12 +124,13 @@ def load(self) -> bool:
model_id_or_path = self.model_id_or_path

self.max_length = _get_and_verify_max_len(self.model_config, self.max_length)
model_cls = get_model_class_for_task(self.task)

# device_map = "auto" enables model parallelism, but not all model architectures support it.
# As a pre-check, we initialize the model class without weights to check `_no_split_modules`:
# device_map = "auto" for models that support it, otherwise set to cuda/cpu.
with init_empty_weights():
self._model = AutoModel.from_config(self.model_config)
self._model = model_cls.from_config(self.model_config)

device_map = self._device

Expand All @@ -157,7 +157,6 @@ def load(self) -> bool:

# load huggingface model using from_pretrained for inference mode
if not self.predictor_host:
model_cls = get_model_class_for_task(self.task)
self._model = model_cls.from_pretrained(
model_id_or_path,
revision=self.model_revision,
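Replacing AutoModel with the class returned by get_model_class_for_task means the meta-device pre-check now inspects the same architecture that will actually be loaded. A rough sketch of that pattern using a stock transformers auto class in place of the KServe helper (the model name is chosen only for illustration):

```python
# Hedged illustration of the pre-check pattern, not the exact KServe code.
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForSequenceClassification

config = AutoConfig.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

# from_config under init_empty_weights builds the task-specific architecture on
# the "meta" device: its attributes can be inspected without downloading or
# allocating any weights.
with init_empty_weights():
    skeleton = AutoModelForSequenceClassification.from_config(config)

print(type(skeleton).__name__)  # e.g. DistilBertForSequenceClassification
```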
@@ -47,7 +47,6 @@
from kserve.constants.constants import LLM_STATS_KEY
from transformers import (
AutoConfig,
AutoModel,
AutoTokenizer,
GenerationConfig,
PreTrainedModel,
@@ -183,12 +182,13 @@ def load(self) -> bool:
model_id_or_path = self.model_id_or_path

self.max_length = _get_and_verify_max_len(self.model_config, self.max_length)
model_cls = get_model_class_for_task(self.task)

# device_map = "auto" enables model parallelism, but not all model architectures support it.
# As a pre-check, we initialize the model class without weights to check `_no_split_modules`:
# device_map = "auto" for models that support it, otherwise set to cuda/cpu.
with init_empty_weights():
self._model = AutoModel.from_config(self.model_config)
self._model = model_cls.from_config(self.model_config)

device_map = self._device

@@ -222,7 +222,6 @@ def load(self) -> bool:

logger.info("Successfully loaded tokenizer")
# load huggingface model using from_pretrained for inference mode
model_cls = get_model_class_for_task(self.task)
self._model = model_cls.from_pretrained(
model_id_or_path,
revision=self.model_revision,
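The generative path gets the identical substitution. As a hedged illustration of how the `_no_split_modules` pre-check can drive the device_map choice described in the comments above (an illustrative helper, not the actual KServe implementation):

```python
# Illustrative helper, not KServe's implementation.
import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM


def pick_device_map(model_id: str) -> str:
    """Return "auto" when the architecture supports sharded placement,
    otherwise fall back to a single cuda/cpu device."""
    config = AutoConfig.from_pretrained(model_id)
    with init_empty_weights():
        skeleton = AutoModelForCausalLM.from_config(config)
    if skeleton._no_split_modules is not None:
        return "auto"
    return "cuda" if torch.cuda.is_available() else "cpu"


# Usage: only the config is fetched before deciding how to place the weights.
# device_map = pick_device_map("gpt2")
# model = AutoModelForCausalLM.from_pretrained("gpt2", device_map=device_map)
```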
129 changes: 97 additions & 32 deletions python/huggingfaceserver/huggingfaceserver/vllm/vllm_completions.py
@@ -14,10 +14,6 @@

import asyncio
import time

import torch
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
from typing import (
AsyncGenerator,
AsyncIterator,
@@ -27,10 +23,16 @@
Optional,
Tuple,
Union,
Iterator,
)

import torch
from vllm.inputs import parse_and_batch_prompt
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_completion import (
parse_prompt_format,
merge_async_iterators,
)
from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -131,21 +133,20 @@ async def create_completion(self, completion_request: CompletionRequest):
generators = []
try:
sampling_params = to_sampling_params(request)
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

for i, prompt in enumerate(prompts):
if prompt_is_tokens:
input_ids, prompt_text = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt
)
else:
input_ids, prompt_text = self._validate_prompt_and_tokenize(
request, prompt=prompt
)
prompts = list(
self._tokenize_prompt_input_or_inputs(
request,
request.prompt,
# TODO: Introduce vLLM specific sampling params
# truncate_prompt_tokens=sampling_params.truncate_prompt_tokens,
# add_special_tokens=request.add_special_tokens,
)
)

for i, prompt_inputs in enumerate(prompts):
generators.append(
self.engine.generate(
{"prompt": prompt_text, "prompt_token_ids": input_ids},
{"prompt_token_ids": prompt_inputs[0]},
sampling_params,
f"{request_id}-{i}",
)
Expand All @@ -167,6 +168,7 @@ async def create_completion(self, completion_request: CompletionRequest):
if stream:
return self.completion_stream_generator(
request,
prompts,
result_generator,
request_id,
created_time,
Expand All @@ -178,6 +180,8 @@ async def create_completion(self, completion_request: CompletionRequest):
final_res_batch: List[RequestOutput] = [None] * len(prompts)
try:
async for i, res in result_generator:
if res.prompt is None:
res.prompt = prompts[i][1]
final_res_batch[i] = res
response = self.request_output_to_completion_response(
final_res_batch, request, request_id, created_time, model_name
Expand All @@ -190,6 +194,7 @@ async def create_completion(self, completion_request: CompletionRequest):
async def completion_stream_generator(
self,
request: CreateCompletionRequest,
prompts: List[Tuple[List[int], str]],
result_generator: AsyncIterator[Tuple[int, RequestOutput]],
request_id: str,
created_time: int,
Expand All @@ -202,6 +207,8 @@ async def completion_stream_generator(

try:
async for prompt_idx, res in result_generator:
if res.prompt is None:
res.prompt = prompts[prompt_idx][1]

for output in res.outputs:
i = output.index + prompt_idx * request.n
Expand All @@ -215,7 +222,7 @@ async def completion_stream_generator(
elif request.echo and request.max_tokens > 0 and not has_echoed[i]:
# echo the prompt and first token
delta_text = res.prompt + output.text
delta_token_ids = res.prompt_token_ids + output.token_ids
delta_token_ids = res.prompt_token_ids + list(output.token_ids)
top_logprobs = (res.prompt_logprobs or []) + (
output.logprobs or []
)
@@ -295,7 +302,7 @@ def request_output_to_completion_response(
top_logprobs = prompt_logprobs
output_text = prompt_text
elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + output.token_ids
token_ids = prompt_token_ids + list(output.token_ids)
top_logprobs = output.logprobs or prompt_logprobs
if output.logprobs and prompt_logprobs:
top_logprobs = prompt_logprobs + output.logprobs
@@ -362,22 +369,13 @@ async def _post_init(self):
revision=engine_model_config.tokenizer_revision,
)

def _validate_prompt_and_tokenize(
def _validate_input(
self,
request: Union[CreateCompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None,
request: CreateCompletionRequest,
input_ids: List[int],
input_text: str,
) -> Tuple[List[int], str]:
if not (prompt or prompt_ids):
raise InvalidInput("Either prompt or prompt_ids should be provided.")
if prompt and prompt_ids:
raise InvalidInput("Only one of prompt or prompt_ids should be provided.")

input_ids = (
prompt_ids if prompt_ids is not None else self.tokenizer(prompt).input_ids
)
token_num = len(input_ids)
input_text = prompt if prompt is not None else self.tokenizer.decode(prompt_ids)

if request.max_tokens is None:
request.max_tokens = self.max_model_len - token_num
Expand All @@ -394,6 +392,73 @@ def _validate_prompt_and_tokenize(
else:
return input_ids, input_text

def _tokenize_prompt_input_or_inputs(
self,
request: CreateCompletionRequest,
input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
# TODO: Introduce vLLM specific sampling params
# truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> Iterator[tuple[list[int], str]]:
"""
Tokenize/detokenize depending on the input format.
According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
, each input can be a string or array of tokens. Note that each request
can pass one or more inputs.
"""
for prompt_input in parse_and_batch_prompt(input_or_inputs):
# Although our type checking is based on mypy,
# VSCode Pyright extension should still work properly
# "is True" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
if prompt_input["is_tokens"] is False:
yield self._normalize_prompt_text_to_input(
request,
prompt=prompt_input["content"],
# truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else:
yield self._normalize_prompt_tokens_to_input(
request,
prompt_ids=prompt_input["content"],
# truncate_prompt_tokens=truncate_prompt_tokens,
)

def _normalize_prompt_text_to_input(
self,
request: CreateCompletionRequest,
prompt: str,
# TODO: Introduce vLLM specific sampling params
# truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
add_special_tokens: bool,
) -> tuple[list[int], str]:
encoded = self.tokenizer(prompt, add_special_tokens=add_special_tokens)
# if truncate_prompt_tokens is not None:
# encoded = self.tokenizer(prompt,
# add_special_tokens=add_special_tokens,
# truncation=True,
# max_length=truncate_prompt_tokens)

input_ids = encoded.input_ids
input_text = prompt
return self._validate_input(request, input_ids, input_text)

def _normalize_prompt_tokens_to_input(
self,
request: CreateCompletionRequest,
prompt_ids: List[int],
# TODO: Introduce vLLM specific sampling params
# truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
) -> tuple[list[int], str]:
input_ids = prompt_ids
# if truncate_prompt_tokens is not None:
# input_ids = prompt_ids[-truncate_prompt_tokens:]

input_text = self.tokenizer.decode(input_ids)
return self._validate_input(request, input_ids, input_text)

def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str:
if logprob.decoded_token is not None:
return logprob.decoded_token
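The refactored prompt handling delegates input normalization to vLLM's parse_and_batch_prompt, which accepts a string, a list of strings, a token list, or a list of token lists and tags each entry so the caller knows whether to tokenize or detokenize it. A small hedged sketch of that behavior (the token ids are illustrative):

```python
# Illustrative sketch of the vLLM helper the new code relies on.
from vllm.inputs import parse_and_batch_prompt

# Text prompts: each entry is tagged is_tokens=False and later tokenized
# by _normalize_prompt_text_to_input.
for item in parse_and_batch_prompt(["Hello world", "How are you?"]):
    print(item["is_tokens"], item["content"])  # False Hello world / False How are you?

# Pre-tokenized prompts: tagged is_tokens=True and later decoded back to text
# by _normalize_prompt_tokens_to_input so echo and logprobs still have a prompt string.
for item in parse_and_batch_prompt([[15496, 995]]):
    print(item["is_tokens"], item["content"])  # True [15496, 995]
```

The (token ids, prompt text) tuples produced by _tokenize_prompt_input_or_inputs are what the streaming and non-streaming paths above fall back on when RequestOutput.prompt is None.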