Skip to content

Commit

Permalink
Support for guided decoding for offline LLM (vllm-project#6878)
Browse files Browse the repository at this point in the history
Co-authored-by: Cyrus Leung <[email protected]>
  • Loading branch information
2 people authored and kylesayrs committed Aug 17, 2024
1 parent 53426ea commit 53c0dd4
Show file tree
Hide file tree
Showing 9 changed files with 352 additions and 12 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def setup(app):
"tqdm",
"tensorizer",
"pynvml",
"outlines",
]

for mock_target in autodoc_mock_imports:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
import pytest


@pytest.fixture
def sample_prompts():
    """Provide four short, open-ended text prompts for generation tests."""
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    return prompts


@pytest.fixture
def sample_token_ids():
    """Provide four token-id sequences of increasing length (1 to 4)."""
    token_id_sequences = [
        [0],
        [0, 1],
        [0, 2, 1],
        [0, 3, 1, 2],
    ]
    return token_id_sequences


@pytest.fixture
def sample_regex():
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
Expand Down Expand Up @@ -66,4 +86,4 @@ def sample_sql_statements():
table: "table_1" | "table_2"
condition: column "=" number
number: "1" | "2"
""")
""")
142 changes: 142 additions & 0 deletions tests/entrypoints/llm/test_guided_generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import json
import re
import weakref

import jsonschema
import pytest

from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

from ...conftest import cleanup

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


@pytest.fixture(scope="module")
def llm():
    """Yield a weak proxy to a module-scoped ``LLM`` built from MODEL_NAME.

    The engine is created once per test module. After the last test has run,
    the local strong reference is dropped and ``cleanup()`` is called.
    """
    engine = LLM(model=MODEL_NAME, max_model_len=1024)

    with engine.deprecate_legacy_api():
        # pytest caches the fixture value, so hand out a weakref.proxy;
        # that way the cache does not keep the engine alive and it can be
        # garbage collected once the strong reference below is deleted.
        yield weakref.proxy(engine)

        del engine

    cleanup()


@pytest.mark.skip_global_cleanup
def test_guided_regex(sample_regex, llm):
    """Regex-guided generation must produce text fully matching the regex."""
    prompt_text = (
        f"Give an example IPv4 address with this regex: {sample_regex}")
    results = llm.generate(
        prompts=[prompt_text] * 2,
        sampling_params=SamplingParams(temperature=0.8, top_p=0.95),
        use_tqdm=True,
        guided_options_request={"guided_regex": sample_regex})

    assert results is not None
    for result in results:
        assert result is not None
        assert isinstance(result, RequestOutput)
        source_prompt = result.prompt
        completion = result.outputs[0].text
        print(completion)
        assert completion is not None
        # The whole completion, not just a prefix, has to match the pattern.
        assert re.fullmatch(sample_regex, completion) is not None
        print(f"Prompt: {source_prompt!r}, Generated text: {completion!r}")


@pytest.mark.skip_global_cleanup
def test_guided_json_completion(sample_json_schema, llm):
    """JSON-guided generation must produce JSON that validates against
    the schema it was guided by."""
    prompt_text = (f"Give an example JSON for an employee profile "
                   f"that fits this schema: {sample_json_schema}")
    results = llm.generate(
        prompts=[prompt_text] * 2,
        sampling_params=SamplingParams(temperature=1.0, max_tokens=1000),
        use_tqdm=True,
        guided_options_request={"guided_json": sample_json_schema})

    assert results is not None

    for result in results:
        assert result is not None
        assert isinstance(result, RequestOutput)
        source_prompt = result.prompt

        completion = result.outputs[0].text
        assert completion is not None
        print(f"Prompt: {source_prompt!r}, Generated text: {completion!r}")
        # Parsing raises on malformed JSON; validate() raises on a
        # schema mismatch. Either failure fails the test.
        parsed = json.loads(completion)
        jsonschema.validate(instance=parsed, schema=sample_json_schema)


@pytest.mark.skip_global_cleanup
def test_guided_choice_completion(sample_guided_choice, llm):
    """Choice-guided generation must return exactly one of the choices."""
    results = llm.generate(
        prompts="The best language for type-safe systems programming is ",
        sampling_params=SamplingParams(temperature=0.8, top_p=0.95),
        use_tqdm=True,
        guided_options_request={"guided_choice": sample_guided_choice})

    assert results is not None
    for result in results:
        assert result is not None
        assert isinstance(result, RequestOutput)
        source_prompt = result.prompt
        completion = result.outputs[0].text
        print(completion)
        assert completion is not None
        assert completion in sample_guided_choice
        print(f"Prompt: {source_prompt!r}, Generated text: {completion!r}")


@pytest.mark.skip_global_cleanup
def test_guided_grammar(sample_sql_statements, llm):
    """Grammar-guided generation must parse under the grammar and equal
    the expected SQL statement (compared with spaces removed)."""
    results = llm.generate(
        prompts=("Generate a sql state that select col_1 from "
                 "table_1 where it is equals to 1"),
        sampling_params=SamplingParams(
            temperature=0.8,
            top_p=0.95,
            max_tokens=1000,
        ),
        use_tqdm=True,
        guided_options_request={"guided_grammar": sample_sql_statements})

    assert results is not None
    for result in results:
        assert result is not None
        assert isinstance(result, RequestOutput)
        source_prompt = result.prompt

        completion = result.outputs[0].text
        assert completion is not None

        # Parse the output with Lark; parse() raises if the text is not
        # derivable from the grammar.
        from lark import Lark
        Lark(sample_sql_statements).parse(completion)

        # Spaces are stripped from the reference because the grammar
        # itself produces output without them.
        expected = "SELECT col_1 from table_1 where col_1 = 1".replace(
            " ", "")

        assert completion.strip() == expected

        print(f"Prompt: {source_prompt!r}, Generated text: {completion!r}")
44 changes: 43 additions & 1 deletion vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
parse_and_batch_prompt)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding import (
GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
Expand Down Expand Up @@ -262,6 +265,8 @@ def generate(
use_tqdm: bool = True,
lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
guided_options_request: Optional[Union[LLMGuidedOptions,
GuidedDecodingRequest]] = None
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
Expand Down Expand Up @@ -303,6 +308,14 @@ def generate(
else:
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)

if isinstance(guided_options_request, dict):
if len(guided_options_request) > 1:
raise ValueError(
"You can only use one guided decoding but multiple is "
f"specified: {guided_options_request}")
guided_options_request = GuidedDecodingRequest(
**guided_options_request)

if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
Expand All @@ -311,7 +324,8 @@ def generate(
inputs=inputs,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
prompt_adapter_request=prompt_adapter_request,
guided_options=guided_options_request)

outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput)
Expand Down Expand Up @@ -508,6 +522,7 @@ def _validate_and_add_requests(
Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
guided_options: Optional[GuidedDecodingRequest] = None,
) -> None:
if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list.
Expand All @@ -523,6 +538,15 @@ def _validate_and_add_requests(
raise ValueError("The lengths of prompts and lora_request "
"must be the same.")

if isinstance(params, list):
params = [
self._add_guided_processor(param, guided_options)
if isinstance(param, SamplingParams) else param
for param in params
]
elif isinstance(params, SamplingParams):
params = self._add_guided_processor(params, guided_options)

# Add requests to the engine.
for i, request_inputs in enumerate(inputs):
self._add_request(
Expand All @@ -548,6 +572,24 @@ def _add_request(
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)

def _add_guided_processor(
        self,
        params: SamplingParams,
        guided_options: Optional[GuidedDecodingRequest] = None
) -> SamplingParams:
    """Return sampling params that include a guided-decoding processor.

    Args:
        params: Sampling parameters supplied by the caller. This object and
            its ``logits_processors`` list are left untouched.
        guided_options: Guided decoding request. When falsy, ``params`` is
            returned as-is. If its ``guided_decoding_backend`` is ``None``
            it is filled in from the engine's decoding config.

    Returns:
        ``params`` itself when no processor needs to be added; otherwise a
        shallow copy of ``params`` whose ``logits_processors`` is a new list
        ending with the guided-decoding processor.
    """
    if not guided_options:
        return params

    if guided_options.guided_decoding_backend is None:
        # No backend requested explicitly: use the engine-wide default.
        decoding_config = self.llm_engine.get_decoding_config()
        guided_options.guided_decoding_backend = (
            decoding_config.guided_decoding_backend)

    guided_logits_processor = get_local_guided_decoding_logits_processor(
        guided_options.guided_decoding_backend, guided_options,
        self.get_tokenizer())
    if not guided_logits_processor:
        return params

    # Appending in place would modify the caller's SamplingParams, so
    # reusing one params object across generate() calls would stack one
    # extra guided processor per call. Work on a shallow copy with a
    # fresh list instead.
    import copy

    guided_params = copy.copy(params)
    guided_params.logits_processors = [
        *(params.logits_processors or []), guided_logits_processor
    ]
    return guided_params

def _run_engine(
self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
Expand Down
26 changes: 20 additions & 6 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union

import torch
Expand All @@ -14,6 +15,23 @@
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.utils import random_uuid

# During docs generation torch is replaced by a Sphinx mock, so
# torch.iinfo cannot be queried there; keep the signed 64-bit integer
# limits available as plain literals instead.
_MOCK_LONG_INFO = Namespace(min=-(2**63), max=2**63 - 1)

try:
    from sphinx.ext.autodoc.mock import _MockModule

    # Sphinx is importable: use the literals only if torch is mocked.
    _LONG_INFO = (_MOCK_LONG_INFO if isinstance(torch, _MockModule) else
                  torch.iinfo(torch.long))
except ModuleNotFoundError:
    # Sphinx is not installed, so torch cannot be a Sphinx mock.
    _LONG_INFO = torch.iinfo(torch.long)

# Sanity check: the literals must agree with what is actually in use.
assert _LONG_INFO.min == _MOCK_LONG_INFO.min
assert _LONG_INFO.max == _MOCK_LONG_INFO.max


class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
Expand Down Expand Up @@ -108,9 +126,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
n: Optional[int] = 1
presence_penalty: Optional[float] = 0.0
response_format: Optional[ResponseFormat] = None
seed: Optional[int] = Field(None,
ge=torch.iinfo(torch.long).min,
le=torch.iinfo(torch.long).max)
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
Expand Down Expand Up @@ -327,9 +343,7 @@ class CompletionRequest(OpenAIBaseModel):
max_tokens: Optional[int] = 16
n: int = 1
presence_penalty: Optional[float] = 0.0
seed: Optional[int] = Field(None,
ge=torch.iinfo(torch.long).min,
le=torch.iinfo(torch.long).max)
seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
stream_options: Optional[StreamOptions] = None
Expand Down
26 changes: 24 additions & 2 deletions vllm/model_executor/guided_decoding/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
get_lm_format_enforcer_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor,
get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor

Expand All @@ -20,6 +21,8 @@ async def get_guided_decoding_logits_processor(
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa
get_lm_format_enforcer_guided_decoding_logits_processor)
return await get_lm_format_enforcer_guided_decoding_logits_processor(
request, tokenizer)

Expand All @@ -28,6 +31,25 @@ async def get_guided_decoding_logits_processor(
"Must be one of 'outlines, 'lm-format-enforcer'")


def get_local_guided_decoding_logits_processor(
        guided_decoding_backend: str, guided_options: GuidedDecodingRequest,
        tokenizer) -> Optional[LogitsProcessor]:
    """Build a guided-decoding logits processor synchronously.

    Dispatches on the backend name to the matching ``get_local_*`` factory.

    Args:
        guided_decoding_backend: Either ``'outlines'`` or
            ``'lm-format-enforcer'``.
        guided_options: The guided decoding request forwarded to the backend.
        tokenizer: Tokenizer forwarded to the backend.

    Returns:
        Whatever the selected backend factory returns (annotated as an
        optional logits processor).

    Raises:
        ValueError: If ``guided_decoding_backend`` is not a known backend.
    """
    if guided_decoding_backend == 'outlines':
        return get_local_outlines_guided_decoding_logits_processor(
            guided_options, tokenizer)
    if guided_decoding_backend == 'lm-format-enforcer':
        # Imported only when this backend is actually selected.
        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_local_lm_format_enforcer_guided_decoding_logits_processor)
        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
            guided_options, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'")


def _adapt_request_for_tool_use(request: Union[CompletionRequest,
ChatCompletionRequest]):
# the legacy completion API does not support tool use
Expand Down
Loading

0 comments on commit 53c0dd4

Please sign in to comment.