From 47632a7fe18f1fbe8464626641ed3b9a871f07d4 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Mon, 30 Dec 2024 17:31:32 -0700 Subject: [PATCH 1/8] :recycle: consolidate models in api layer Signed-off-by: Joe Runde --- vllm/entrypoints/openai/api_server.py | 62 ++--- vllm/entrypoints/openai/run_batch.py | 15 +- vllm/entrypoints/openai/serving_chat.py | 14 +- vllm/entrypoints/openai/serving_completion.py | 14 +- vllm/entrypoints/openai/serving_embedding.py | 9 +- vllm/entrypoints/openai/serving_engine.py | 232 ++++++++++-------- vllm/entrypoints/openai/serving_pooling.py | 9 +- vllm/entrypoints/openai/serving_score.py | 9 +- .../openai/serving_tokenization.py | 12 +- 9 files changed, 200 insertions(+), 176 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index bac72d87376da..5f68fe8aa8951 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -58,7 +58,9 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding -from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + OpenAIServing, + OpenAIServingModels) from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_tokenization import ( @@ -269,6 +271,10 @@ def base(request: Request) -> OpenAIServing: return tokenization(request) +def models(request: Request) -> OpenAIServingModels: + return request.app.state.openai_serving_models + + def chat(request: Request) -> Optional[OpenAIServingChat]: return request.app.state.openai_serving_chat @@ -336,10 +342,10 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request): @router.get("/v1/models") async def show_available_models(raw_request: Request): - handler = base(raw_request) + handler = models(raw_request) - models = await handler.show_available_models() - return JSONResponse(content=models.model_dump()) + models_ = await handler.show_available_models() + return JSONResponse(content=models_.model_dump()) @router.get("/version") @@ -505,26 +511,22 @@ async def stop_profile(raw_request: Request): @router.post("/v1/load_lora_adapter") async def load_lora_adapter(request: LoadLoraAdapterRequest, raw_request: Request): - for route in [chat, completion, embedding]: - handler = route(raw_request) - if handler is not None: - response = await handler.load_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.code) + handler = models(raw_request) + response = await handler.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) return Response(status_code=200, content=response) @router.post("/v1/unload_lora_adapter") async def unload_lora_adapter(request: UnloadLoraAdapterRequest, raw_request: Request): - for route in [chat, completion, embedding]: - handler = route(raw_request) - if handler is not None: - response = await handler.unload_lora_adapter(request) - if isinstance(response, ErrorResponse): - return JSONResponse(content=response.model_dump(), - status_code=response.code) + handler = models(raw_request) + response = await 
handler.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) return Response(status_code=200, content=response) @@ -628,13 +630,18 @@ def init_app_state( resolved_chat_template = load_chat_template(args.chat_template) logger.info("Using supplied chat template:\n%s", resolved_chat_template) + state.openai_serving_models = OpenAIServingModels( + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + ) + # TODO: The chat template is now broken for lora adapters :( state.openai_serving_chat = OpenAIServingChat( engine_client, model_config, - base_model_paths, + state.openai_serving_models, args.response_role, - lora_modules=args.lora_modules, - prompt_adapters=args.prompt_adapters, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, @@ -646,16 +653,14 @@ def init_app_state( state.openai_serving_completion = OpenAIServingCompletion( engine_client, model_config, - base_model_paths, - lora_modules=args.lora_modules, - prompt_adapters=args.prompt_adapters, + state.openai_serving_models, request_logger=request_logger, return_tokens_as_token_ids=args.return_tokens_as_token_ids, ) if model_config.runner_type == "generate" else None state.openai_serving_pooling = OpenAIServingPooling( engine_client, model_config, - base_model_paths, + state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, @@ -663,7 +668,7 @@ def init_app_state( state.openai_serving_embedding = OpenAIServingEmbedding( engine_client, model_config, - base_model_paths, + state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, @@ -671,14 +676,13 @@ def init_app_state( state.openai_serving_scores = OpenAIServingScores( engine_client, model_config, - base_model_paths, + state.openai_serving_models, request_logger=request_logger ) if model_config.task == "score" else None state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, - base_model_paths, - lora_modules=args.lora_modules, + state.openai_serving_models, request_logger=request_logger, chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 572ed27b39083..6c274f8d37548 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -20,7 +20,8 @@ # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding -from vllm.entrypoints.openai.serving_engine import BaseModelPath +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + OpenAIServingModels) from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.version import __version__ as VLLM_VERSION @@ -213,13 +214,17 @@ async def main(args): request_logger = RequestLogger(max_log_len=args.max_log_len) # Create the openai serving objects. 
+ openai_serving_models = OpenAIServingModels( + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=None, + prompt_adapters=None, + ) openai_serving_chat = OpenAIServingChat( engine, model_config, - base_model_paths, + openai_serving_models, args.response_role, - lora_modules=None, - prompt_adapters=None, request_logger=request_logger, chat_template=None, chat_template_content_format="auto", @@ -228,7 +233,7 @@ async def main(args): openai_serving_embedding = OpenAIServingEmbedding( engine, model_config, - base_model_paths, + openai_serving_models, request_logger=request_logger, chat_template=None, chat_template_content_format="auto", diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d085333563d19..32bca3aee6f68 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -21,10 +21,8 @@ ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo) -from vllm.entrypoints.openai.serving_engine import (BaseModelPath, - LoRAModulePath, - OpenAIServing, - PromptAdapterPath) +from vllm.entrypoints.openai.serving_engine import (OpenAIServing, + OpenAIServingModels) from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.logger import init_logger from vllm.outputs import CompletionOutput, RequestOutput @@ -42,11 +40,9 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - base_model_paths: List[BaseModelPath], + models: OpenAIServingModels, response_role: str, *, - lora_modules: Optional[List[LoRAModulePath]], - prompt_adapters: Optional[List[PromptAdapterPath]], request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, @@ -57,9 +53,7 @@ def __init__( ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, - base_model_paths=base_model_paths, - lora_modules=lora_modules, - prompt_adapters=prompt_adapters, + models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index aaad7b8c7f44c..6ec6455adc670 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -21,10 +21,8 @@ RequestResponseMetadata, UsageInfo) # yapf: enable -from vllm.entrypoints.openai.serving_engine import (BaseModelPath, - LoRAModulePath, - OpenAIServing, - PromptAdapterPath) +from vllm.entrypoints.openai.serving_engine import (OpenAIServing, + OpenAIServingModels) from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams, SamplingParams @@ -41,18 +39,14 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - base_model_paths: List[BaseModelPath], + models: OpenAIServingModels, *, - lora_modules: Optional[List[LoRAModulePath]], - prompt_adapters: Optional[List[PromptAdapterPath]], request_logger: Optional[RequestLogger], return_tokens_as_token_ids: bool = False, ): super().__init__(engine_client=engine_client, model_config=model_config, - base_model_paths=base_model_paths, - lora_modules=lora_modules, - prompt_adapters=prompt_adapters, + models=models, request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids) 
diff_sampling_param = self.model_config.get_diff_sampling_param() diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index b8fb9d6bd77f2..e32aee90be945 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -16,7 +16,8 @@ EmbeddingResponse, EmbeddingResponseData, ErrorResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.entrypoints.openai.serving_engine import (OpenAIServing, + OpenAIServingModels) from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, PoolingRequestOutput) @@ -46,7 +47,7 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - base_model_paths: List[BaseModelPath], + models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], chat_template: Optional[str], @@ -54,9 +55,7 @@ def __init__( ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, - base_model_paths=base_model_paths, - lora_modules=None, - prompt_adapters=None, + models=models, request_logger=request_logger) self.chat_template = chat_template diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 5b6a089e4c319..439d17df775d4 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -90,26 +90,27 @@ class TextTokensPrompt(TypedDict): RequestPrompt = Union[List[int], str, TextTokensPrompt] -class OpenAIServing: +class OpenAIServingModels: + """Shared instance to hold data about the loaded base model(s) and adapters. + + Handles the routes: + - /v1/models + - /v1/load_lora_adapter + - /v1/unload_lora_adapter + """ def __init__( self, - engine_client: EngineClient, model_config: ModelConfig, base_model_paths: List[BaseModelPath], *, lora_modules: Optional[List[LoRAModulePath]], prompt_adapters: Optional[List[PromptAdapterPath]], - request_logger: Optional[RequestLogger], - return_tokens_as_token_ids: bool = False, ): super().__init__() - self.engine_client = engine_client - self.model_config = model_config - self.max_model_len = model_config.max_model_len - self.base_model_paths = base_model_paths + self.max_model_len = model_config.max_model_len self.lora_id_counter = AtomicCounter(0) self.lora_requests = [] @@ -139,19 +140,9 @@ def __init__( prompt_adapter_local_path=prompt_adapter.local_path, prompt_adapter_num_virtual_tokens=num_virtual_tokens)) - self.request_logger = request_logger - self.return_tokens_as_token_ids = return_tokens_as_token_ids - - self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) - - self._tokenize_prompt_input_async = make_async( - self._tokenize_prompt_input, executor=self._tokenizer_executor) - self._tokenize_prompt_input_or_inputs_async = make_async( - self._tokenize_prompt_input_or_inputs, - executor=self._tokenizer_executor) - async def show_available_models(self) -> ModelList: - """Show available models. Right now we only have one model.""" + """Show available models. 
This includes the base model and all + adapters""" model_cards = [ ModelCard(id=base_model.name, max_model_len=self.max_model_len, @@ -177,6 +168,110 @@ async def show_available_models(self) -> ModelList: model_cards.extend(prompt_adapter_cards) return ModelList(data=model_cards) + async def load_lora_adapter( + self, + request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]: + error_check_ret = await self._check_load_lora_adapter_request(request) + if error_check_ret is not None: + return error_check_ret + + lora_name, lora_path = request.lora_name, request.lora_path + unique_id = self.lora_id_counter.inc(1) + self.lora_requests.append( + LoRARequest(lora_name=lora_name, + lora_int_id=unique_id, + lora_path=lora_path)) + return f"Success: LoRA adapter '{lora_name}' added successfully." + + async def unload_lora_adapter( + self, + request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]: + error_check_ret = await self._check_unload_lora_adapter_request(request + ) + if error_check_ret is not None: + return error_check_ret + + lora_name = request.lora_name + self.lora_requests = [ + lora_request for lora_request in self.lora_requests + if lora_request.lora_name != lora_name + ] + return f"Success: LoRA adapter '{lora_name}' removed successfully." + + async def _check_load_lora_adapter_request( + self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]: + # Check if both 'lora_name' and 'lora_path' are provided + if not request.lora_name or not request.lora_path: + return create_error_response( + message="Both 'lora_name' and 'lora_path' must be provided.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + # Check if the lora adapter with the given name already exists + if any(lora_request.lora_name == request.lora_name + for lora_request in self.lora_requests): + return create_error_response( + message= + f"The lora adapter '{request.lora_name}' has already been" + "loaded.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + return None + + async def _check_unload_lora_adapter_request( + self, + request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]: + # Check if either 'lora_name' or 'lora_int_id' is provided + if not request.lora_name and not request.lora_int_id: + return create_error_response( + message= + "either 'lora_name' and 'lora_int_id' needs to be provided.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + # Check if the lora adapter with the given name exists + if not any(lora_request.lora_name == request.lora_name + for lora_request in self.lora_requests): + return create_error_response( + message= + f"The lora adapter '{request.lora_name}' cannot be found.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + return None + + +class OpenAIServing: + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + models: OpenAIServingModels, + *, + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + ): + super().__init__() + + self.engine_client = engine_client + self.model_config = model_config + self.max_model_len = model_config.max_model_len + + self.models = models + + self.request_logger = request_logger + self.return_tokens_as_token_ids = return_tokens_as_token_ids + + self._tokenizer_executor = ThreadPoolExecutor(max_workers=1) + + self._tokenize_prompt_input_async = make_async( + self._tokenize_prompt_input, executor=self._tokenizer_executor) + self._tokenize_prompt_input_or_inputs_async = 
make_async( + self._tokenize_prompt_input_or_inputs, + executor=self._tokenizer_executor) + def create_error_response( self, message: str, @@ -205,11 +300,13 @@ async def _check_model( ) -> Optional[ErrorResponse]: if self._is_model_supported(request.model): return None - if request.model in [lora.lora_name for lora in self.lora_requests]: + if request.model in [ + lora.lora_name for lora in self.models.lora_requests + ]: return None if request.model in [ prompt_adapter.prompt_adapter_name - for prompt_adapter in self.prompt_adapter_requests + for prompt_adapter in self.models.prompt_adapter_requests ]: return None return self.create_error_response( @@ -223,10 +320,10 @@ def _maybe_get_adapters( None, PromptAdapterRequest]]: if self._is_model_supported(request.model): return None, None - for lora in self.lora_requests: + for lora in self.models.lora_requests: if request.model == lora.lora_name: return lora, None - for prompt_adapter in self.prompt_adapter_requests: + for prompt_adapter in self.models.prompt_adapter_requests: if request.model == prompt_adapter.prompt_adapter_name: return None, prompt_adapter # if _check_model has been called earlier, this will be unreachable @@ -588,81 +685,9 @@ def _get_decoded_token(logprob: Logprob, return logprob.decoded_token return tokenizer.decode(token_id) - async def _check_load_lora_adapter_request( - self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]: - # Check if both 'lora_name' and 'lora_path' are provided - if not request.lora_name or not request.lora_path: - return self.create_error_response( - message="Both 'lora_name' and 'lora_path' must be provided.", - err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) - - # Check if the lora adapter with the given name already exists - if any(lora_request.lora_name == request.lora_name - for lora_request in self.lora_requests): - return self.create_error_response( - message= - f"The lora adapter '{request.lora_name}' has already been" - "loaded.", - err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) - - return None - - async def _check_unload_lora_adapter_request( - self, - request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]: - # Check if either 'lora_name' or 'lora_int_id' is provided - if not request.lora_name and not request.lora_int_id: - return self.create_error_response( - message= - "either 'lora_name' and 'lora_int_id' needs to be provided.", - err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) - - # Check if the lora adapter with the given name exists - if not any(lora_request.lora_name == request.lora_name - for lora_request in self.lora_requests): - return self.create_error_response( - message= - f"The lora adapter '{request.lora_name}' cannot be found.", - err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) - - return None - - async def load_lora_adapter( - self, - request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]: - error_check_ret = await self._check_load_lora_adapter_request(request) - if error_check_ret is not None: - return error_check_ret - - lora_name, lora_path = request.lora_name, request.lora_path - unique_id = self.lora_id_counter.inc(1) - self.lora_requests.append( - LoRARequest(lora_name=lora_name, - lora_int_id=unique_id, - lora_path=lora_path)) - return f"Success: LoRA adapter '{lora_name}' added successfully." 
- - async def unload_lora_adapter( - self, - request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]: - error_check_ret = await self._check_unload_lora_adapter_request(request - ) - if error_check_ret is not None: - return error_check_ret - - lora_name = request.lora_name - self.lora_requests = [ - lora_request for lora_request in self.lora_requests - if lora_request.lora_name != lora_name - ] - return f"Success: LoRA adapter '{lora_name}' removed successfully." - def _is_model_supported(self, model_name): - return any(model.name == model_name for model in self.base_model_paths) + return any(model.name == model_name + for model in self.models.base_model_paths) def _get_model_name(self, lora: Optional[LoRARequest]): """ @@ -675,4 +700,13 @@ def _get_model_name(self, lora: Optional[LoRARequest]): """ if lora is not None: return lora.lora_name - return self.base_model_paths[0].name + return self.models.base_model_paths[0].name + + +def create_error_response( + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: + return ErrorResponse(message=message, + type=err_type, + code=status_code.value) diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py index 01852f0df1eca..9b78eef08869e 100644 --- a/vllm/entrypoints/openai/serving_pooling.py +++ b/vllm/entrypoints/openai/serving_pooling.py @@ -15,7 +15,8 @@ PoolingChatRequest, PoolingRequest, PoolingResponse, PoolingResponseData, UsageInfo) -from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.entrypoints.openai.serving_engine import (OpenAIServing, + OpenAIServingModels) from vllm.logger import init_logger from vllm.outputs import PoolingOutput, PoolingRequestOutput from vllm.utils import merge_async_iterators @@ -44,7 +45,7 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - base_model_paths: List[BaseModelPath], + models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], chat_template: Optional[str], @@ -52,9 +53,7 @@ def __init__( ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, - base_model_paths=base_model_paths, - lora_modules=None, - prompt_adapters=None, + models=models, request_logger=request_logger) self.chat_template = chat_template diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py index a8a126e697641..671c24dcf76d4 100644 --- a/vllm/entrypoints/openai/serving_score.py +++ b/vllm/entrypoints/openai/serving_score.py @@ -10,7 +10,8 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest, ScoreResponse, ScoreResponseData, UsageInfo) -from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.entrypoints.openai.serving_engine import (OpenAIServing, + OpenAIServingModels) from vllm.inputs.data import TokensPrompt from vllm.logger import init_logger from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput @@ -50,15 +51,13 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - base_model_paths: List[BaseModelPath], + models: OpenAIServingModels, *, request_logger: Optional[RequestLogger], ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, - base_model_paths=base_model_paths, - lora_modules=None, - prompt_adapters=None, + models=models, request_logger=request_logger) async def create_score( diff --git 
a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 2e849333680d4..45ac883b2294c 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -15,9 +15,8 @@ TokenizeRequest, TokenizeResponse) # yapf: enable -from vllm.entrypoints.openai.serving_engine import (BaseModelPath, - LoRAModulePath, - OpenAIServing) +from vllm.entrypoints.openai.serving_engine import (OpenAIServing, + OpenAIServingModels) from vllm.logger import init_logger logger = init_logger(__name__) @@ -29,18 +28,15 @@ def __init__( self, engine_client: EngineClient, model_config: ModelConfig, - base_model_paths: List[BaseModelPath], + models: OpenAIServingModels, *, - lora_modules: Optional[List[LoRAModulePath]], request_logger: Optional[RequestLogger], chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, ) -> None: super().__init__(engine_client=engine_client, model_config=model_config, - base_model_paths=base_model_paths, - lora_modules=lora_modules, - prompt_adapters=None, + models=models, request_logger=request_logger) self.chat_template = chat_template From 43a9ef20cac032f1aa5fd295105a7b082827e1d0 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 31 Dec 2024 10:57:59 -0700 Subject: [PATCH 2/8] :recycle: is supported -> is base model Signed-off-by: Joe Runde --- vllm/entrypoints/openai/serving_engine.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 439d17df775d4..84e8c60412e65 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -121,7 +121,7 @@ def __init__( lora_path=lora.path, base_model_name=lora.base_model_name if lora.base_model_name - and self._is_model_supported(lora.base_model_name) + and self.is_base_model(lora.base_model_name) else self.base_model_paths[0].name) for i, lora in enumerate(lora_modules, start=1) ] @@ -139,6 +139,9 @@ def __init__( prompt_adapter_id=i, prompt_adapter_local_path=prompt_adapter.local_path, prompt_adapter_num_virtual_tokens=num_virtual_tokens)) + + def is_base_model(self, model_name): + return any(model.name == model_name for model in self.base_model_paths) async def show_available_models(self) -> ModelList: """Show available models. 
This includes the base model and all @@ -686,8 +689,7 @@ def _get_decoded_token(logprob: Logprob, return tokenizer.decode(token_id) def _is_model_supported(self, model_name): - return any(model.name == model_name - for model in self.models.base_model_paths) + return self.models.is_base_model(model_name) def _get_model_name(self, lora: Optional[LoRARequest]): """ From 9a44d9d91ca368e22e633c18c544e001d0d11c4b Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 31 Dec 2024 12:52:14 -0700 Subject: [PATCH 3/8] :test_tube: Test dynamic lora loading Signed-off-by: Joe Runde --- tests/entrypoints/openai/test_lora_lineage.py | 32 +++++++++++++++++-- vllm/entrypoints/openai/serving_engine.py | 6 ++-- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/tests/entrypoints/openai/test_lora_lineage.py b/tests/entrypoints/openai/test_lora_lineage.py index ab39684c2f31a..ce4f85c13fff9 100644 --- a/tests/entrypoints/openai/test_lora_lineage.py +++ b/tests/entrypoints/openai/test_lora_lineage.py @@ -55,7 +55,10 @@ def server_with_lora_modules_json(zephyr_lora_files): "64", ] - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + # Enable the /v1/load_lora_adapter endpoint + envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"} + + with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server: yield remote_server @@ -67,8 +70,8 @@ async def client_for_lora_lineage(server_with_lora_modules_json): @pytest.mark.asyncio -async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI, - zephyr_lora_files): +async def test_static_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI, + zephyr_lora_files): models = await client_for_lora_lineage.models.list() models = models.data served_model = models[0] @@ -81,3 +84,26 @@ async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI, assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models) assert lora_models[0].id == "zephyr-lora" assert lora_models[1].id == "zephyr-lora2" + + +@pytest.mark.asyncio +async def test_dynamic_lora_lineage( + client_for_lora_lineage: openai.AsyncOpenAI, zephyr_lora_files): + + response = await client_for_lora_lineage.post("load_lora_adapter", + cast_to=str, + body={ + "lora_name": + "zephyr-lora-3", + "lora_path": + zephyr_lora_files + }) + # Ensure adapter loads before querying /models + assert "success" in response + + models = await client_for_lora_lineage.models.list() + models = models.data + dynamic_lora_model = models[-1] + assert dynamic_lora_model.root == zephyr_lora_files + assert dynamic_lora_model.parent == MODEL_NAME + assert dynamic_lora_model.id == "zephyr-lora-3" diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 84e8c60412e65..00c39a057a56a 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -121,8 +121,8 @@ def __init__( lora_path=lora.path, base_model_name=lora.base_model_name if lora.base_model_name - and self.is_base_model(lora.base_model_name) - else self.base_model_paths[0].name) + and self.is_base_model(lora.base_model_name) else + self.base_model_paths[0].name) for i, lora in enumerate(lora_modules, start=1) ] @@ -139,7 +139,7 @@ def __init__( prompt_adapter_id=i, prompt_adapter_local_path=prompt_adapter.local_path, prompt_adapter_num_virtual_tokens=num_virtual_tokens)) - + def is_base_model(self, model_name): return any(model.name == model_name for model in self.base_model_paths) From 5a399623e33baadff243830e76f17614c2a70052 
Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 31 Dec 2024 14:28:02 -0700 Subject: [PATCH 4/8] :white_check_mark: fixup tests Signed-off-by: Joe Runde --- tests/entrypoints/openai/test_serving_chat.py | 20 +++--- ...rving_engine.py => test_serving_models.py} | 65 +++++++++---------- vllm/entrypoints/openai/serving_chat.py | 2 +- vllm/entrypoints/openai/serving_completion.py | 2 +- vllm/entrypoints/openai/serving_engine.py | 29 ++++----- 5 files changed, 58 insertions(+), 60 deletions(-) rename tests/entrypoints/openai/{test_serving_engine.py => test_serving_models.py} (63%) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 61677b65af342..a48b575b4fd97 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -8,7 +8,8 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_engine import BaseModelPath +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + OpenAIServingModels) from vllm.transformers_utils.tokenizer import get_tokenizer MODEL_NAME = "openai-community/gpt2" @@ -50,14 +51,13 @@ async def _async_serving_chat_init(): engine = MockEngine() model_config = await engine.get_model_config() + models = OpenAIServingModels(model_config, BASE_MODEL_PATHS) serving_completion = OpenAIServingChat(engine, model_config, - BASE_MODEL_PATHS, + models, response_role="assistant", chat_template=CHAT_TEMPLATE, chat_template_content_format="auto", - lora_modules=None, - prompt_adapters=None, request_logger=None) return serving_completion @@ -72,14 +72,14 @@ def test_serving_chat_should_set_correct_max_tokens(): mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) mock_engine.errored = False + models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS, + model_config=MockModelConfig()) serving_chat = OpenAIServingChat(mock_engine, MockModelConfig(), - BASE_MODEL_PATHS, + models, response_role="assistant", chat_template=CHAT_TEMPLATE, chat_template_content_format="auto", - lora_modules=None, - prompt_adapters=None, request_logger=None) req = ChatCompletionRequest( model=MODEL_NAME, @@ -115,14 +115,14 @@ def test_serving_chat_could_load_correct_generation_config(): mock_engine.errored = False # Initialize the serving chat + models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config) serving_chat = OpenAIServingChat(mock_engine, mock_model_config, - BASE_MODEL_PATHS, + models, response_role="assistant", chat_template=CHAT_TEMPLATE, chat_template_content_format="auto", - lora_modules=None, - prompt_adapters=None, request_logger=None) req = ChatCompletionRequest( model=MODEL_NAME, diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_models.py similarity index 63% rename from tests/entrypoints/openai/test_serving_engine.py rename to tests/entrypoints/openai/test_serving_models.py index 096ab6fa0ac09..28f290636c6a9 100644 --- a/tests/entrypoints/openai/test_serving_engine.py +++ b/tests/entrypoints/openai/test_serving_models.py @@ -8,7 +8,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse, LoadLoraAdapterRequest, UnloadLoraAdapterRequest) -from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.entrypoints.openai.serving_engine import 
BaseModelPath, OpenAIServingModels from vllm.lora.request import LoRARequest MODEL_NAME = "meta-llama/Llama-2-7b" @@ -19,47 +19,46 @@ "Success: LoRA adapter '{lora_name}' removed successfully.") -async def _async_serving_engine_init(): - mock_engine_client = MagicMock(spec=EngineClient) +async def _async_serving_models_init() -> OpenAIServingModels: mock_model_config = MagicMock(spec=ModelConfig) # Set the max_model_len attribute to avoid missing attribute mock_model_config.max_model_len = 2048 - serving_engine = OpenAIServing(mock_engine_client, - mock_model_config, - BASE_MODEL_PATHS, - lora_modules=None, - prompt_adapters=None, - request_logger=None) - return serving_engine + serving_models = OpenAIServingModels( + base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config, + lora_modules=None, + prompt_adapters=None) + + return serving_models @pytest.mark.asyncio async def test_serving_model_name(): - serving_engine = await _async_serving_engine_init() - assert serving_engine._get_model_name(None) == MODEL_NAME + serving_models = await _async_serving_models_init() + assert serving_models.model_name(None) == MODEL_NAME request = LoRARequest(lora_name="adapter", lora_path="/path/to/adapter2", lora_int_id=1) - assert serving_engine._get_model_name(request) == request.lora_name + assert serving_models.model_name(request) == request.lora_name @pytest.mark.asyncio async def test_load_lora_adapter_success(): - serving_engine = await _async_serving_engine_init() + serving_models = await _async_serving_models_init() request = LoadLoraAdapterRequest(lora_name="adapter", lora_path="/path/to/adapter2") - response = await serving_engine.load_lora_adapter(request) + response = await serving_models.load_lora_adapter(request) assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter') - assert len(serving_engine.lora_requests) == 1 - assert serving_engine.lora_requests[0].lora_name == "adapter" + assert len(serving_models.lora_requests) == 1 + assert serving_models.lora_requests[0].lora_name == "adapter" @pytest.mark.asyncio async def test_load_lora_adapter_missing_fields(): - serving_engine = await _async_serving_engine_init() + serving_models = await _async_serving_models_init() request = LoadLoraAdapterRequest(lora_name="", lora_path="") - response = await serving_engine.load_lora_adapter(request) + response = await serving_models.load_lora_adapter(request) assert isinstance(response, ErrorResponse) assert response.type == "InvalidUserInput" assert response.code == HTTPStatus.BAD_REQUEST @@ -67,43 +66,43 @@ async def test_load_lora_adapter_missing_fields(): @pytest.mark.asyncio async def test_load_lora_adapter_duplicate(): - serving_engine = await _async_serving_engine_init() + serving_models = await _async_serving_models_init() request = LoadLoraAdapterRequest(lora_name="adapter1", lora_path="/path/to/adapter1") - response = await serving_engine.load_lora_adapter(request) + response = await serving_models.load_lora_adapter(request) assert response == LORA_LOADING_SUCCESS_MESSAGE.format( lora_name='adapter1') - assert len(serving_engine.lora_requests) == 1 + assert len(serving_models.lora_requests) == 1 request = LoadLoraAdapterRequest(lora_name="adapter1", lora_path="/path/to/adapter1") - response = await serving_engine.load_lora_adapter(request) + response = await serving_models.load_lora_adapter(request) assert isinstance(response, ErrorResponse) assert response.type == "InvalidUserInput" assert response.code == HTTPStatus.BAD_REQUEST - assert 
len(serving_engine.lora_requests) == 1 + assert len(serving_models.lora_requests) == 1 @pytest.mark.asyncio async def test_unload_lora_adapter_success(): - serving_engine = await _async_serving_engine_init() + serving_models = await _async_serving_models_init() request = LoadLoraAdapterRequest(lora_name="adapter1", lora_path="/path/to/adapter1") - response = await serving_engine.load_lora_adapter(request) - assert len(serving_engine.lora_requests) == 1 + response = await serving_models.load_lora_adapter(request) + assert len(serving_models.lora_requests) == 1 request = UnloadLoraAdapterRequest(lora_name="adapter1") - response = await serving_engine.unload_lora_adapter(request) + response = await serving_models.unload_lora_adapter(request) assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format( lora_name='adapter1') - assert len(serving_engine.lora_requests) == 0 + assert len(serving_models.lora_requests) == 0 @pytest.mark.asyncio async def test_unload_lora_adapter_missing_fields(): - serving_engine = await _async_serving_engine_init() + serving_models = await _async_serving_models_init() request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None) - response = await serving_engine.unload_lora_adapter(request) + response = await serving_models.unload_lora_adapter(request) assert isinstance(response, ErrorResponse) assert response.type == "InvalidUserInput" assert response.code == HTTPStatus.BAD_REQUEST @@ -111,9 +110,9 @@ async def test_unload_lora_adapter_missing_fields(): @pytest.mark.asyncio async def test_unload_lora_adapter_not_found(): - serving_engine = await _async_serving_engine_init() + serving_models = await _async_serving_models_init() request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter") - response = await serving_engine.unload_lora_adapter(request) + response = await serving_models.unload_lora_adapter(request) assert isinstance(response, ErrorResponse) assert response.type == "InvalidUserInput" assert response.code == HTTPStatus.BAD_REQUEST diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 32bca3aee6f68..a88f9d27fbd82 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -120,7 +120,7 @@ async def create_chat_completion( prompt_adapter_request, ) = self._maybe_get_adapters(request) - model_name = self._get_model_name(lora_request) + model_name = self.models.model_name(lora_request) tokenizer = await self.engine_client.get_tokenizer(lora_request) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 6ec6455adc670..e91944f114fe0 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -164,7 +164,7 @@ async def create_completion( result_generator = merge_async_iterators(*generators) - model_name = self._get_model_name(lora_request) + model_name = self.models.model_name(lora_request) num_prompts = len(engine_prompts) # Similar to the OpenAI API, when n != best_of, we do not stream the diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 00c39a057a56a..3a1de7ad39d66 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -104,8 +104,8 @@ def __init__( model_config: ModelConfig, base_model_paths: List[BaseModelPath], *, - lora_modules: Optional[List[LoRAModulePath]], - prompt_adapters: Optional[List[PromptAdapterPath]], + lora_modules: 
Optional[List[LoRAModulePath]] = None, + prompt_adapters: Optional[List[PromptAdapterPath]] = None, ): super().__init__() @@ -143,6 +143,18 @@ def __init__( def is_base_model(self, model_name): return any(model.name == model_name for model in self.base_model_paths) + def model_name(self, lora_request: Optional[LoRARequest] = None) -> str: + """Returns the appropriate model name depending on the availability + and support of the LoRA or base model. + Parameters: + - lora: LoRARequest that contain a base_model_name. + Returns: + - str: The name of the base model or the first available model path. + """ + if lora_request is not None: + return lora_request.lora_name + return self.base_model_paths[0].name + async def show_available_models(self) -> ModelList: """Show available models. This includes the base model and all adapters""" @@ -691,19 +703,6 @@ def _get_decoded_token(logprob: Logprob, def _is_model_supported(self, model_name): return self.models.is_base_model(model_name) - def _get_model_name(self, lora: Optional[LoRARequest]): - """ - Returns the appropriate model name depending on the availability - and support of the LoRA or base model. - Parameters: - - lora: LoRARequest that contain a base_model_name. - Returns: - - str: The name of the base model or the first available model path. - """ - if lora is not None: - return lora.lora_name - return self.models.base_model_paths[0].name - def create_error_response( message: str, From 898fe8b56cc605afaaf1a02cfd5888199048da17 Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Tue, 31 Dec 2024 14:32:58 -0700 Subject: [PATCH 5/8] :recycle: Move OpenAIServingModels to separate file Signed-off-by: Joe Runde --- tests/entrypoints/openai/test_serving_chat.py | 4 +- .../entrypoints/openai/test_serving_models.py | 4 +- vllm/entrypoints/openai/api_server.py | 5 +- vllm/entrypoints/openai/run_batch.py | 4 +- vllm/entrypoints/openai/serving_chat.py | 4 +- vllm/entrypoints/openai/serving_completion.py | 4 +- vllm/entrypoints/openai/serving_embedding.py | 4 +- vllm/entrypoints/openai/serving_engine.py | 179 +---------------- vllm/entrypoints/openai/serving_models.py | 185 ++++++++++++++++++ vllm/entrypoints/openai/serving_pooling.py | 4 +- vllm/entrypoints/openai/serving_score.py | 4 +- .../openai/serving_tokenization.py | 4 +- 12 files changed, 209 insertions(+), 196 deletions(-) create mode 100644 vllm/entrypoints/openai/serving_models.py diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index a48b575b4fd97..f1c855e417fb8 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -8,8 +8,8 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_engine import (BaseModelPath, - OpenAIServingModels) +from vllm.entrypoints.openai.serving_engine import BaseModelPath +from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.transformers_utils.tokenizer import get_tokenizer MODEL_NAME = "openai-community/gpt2" diff --git a/tests/entrypoints/openai/test_serving_models.py b/tests/entrypoints/openai/test_serving_models.py index 28f290636c6a9..78a01e8759c9d 100644 --- a/tests/entrypoints/openai/test_serving_models.py +++ b/tests/entrypoints/openai/test_serving_models.py @@ -4,11 +4,11 @@ import pytest from vllm.config import ModelConfig -from 
vllm.engine.protocol import EngineClient from vllm.entrypoints.openai.protocol import (ErrorResponse, LoadLoraAdapterRequest, UnloadLoraAdapterRequest) -from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServingModels +from vllm.entrypoints.openai.serving_engine import BaseModelPath +from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.lora.request import LoRARequest MODEL_NAME = "meta-llama/Llama-2-7b" diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 5f68fe8aa8951..b1accad2fc92e 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -58,9 +58,8 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding -from vllm.entrypoints.openai.serving_engine import (BaseModelPath, - OpenAIServing, - OpenAIServingModels) +from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_tokenization import ( diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 6c274f8d37548..13f4d507c6c32 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -20,8 +20,8 @@ # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding -from vllm.entrypoints.openai.serving_engine import (BaseModelPath, - OpenAIServingModels) +from vllm.entrypoints.openai.serving_engine import BaseModelPath +from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, random_uuid from vllm.version import __version__ as VLLM_VERSION diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a88f9d27fbd82..9ba5eeb7709c9 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -21,8 +21,8 @@ ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo) -from vllm.entrypoints.openai.serving_engine import (OpenAIServing, - OpenAIServingModels) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.logger import init_logger from vllm.outputs import CompletionOutput, RequestOutput diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index e91944f114fe0..17197dce8da23 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -21,8 +21,8 @@ RequestResponseMetadata, UsageInfo) # yapf: enable -from vllm.entrypoints.openai.serving_engine import (OpenAIServing, - OpenAIServingModels) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels from 
vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams, SamplingParams diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index e32aee90be945..e7116a3d95d10 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -16,8 +16,8 @@ EmbeddingResponse, EmbeddingResponseData, ErrorResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import (OpenAIServing, - OpenAIServingModels) +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.logger import init_logger from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput, PoolingRequestOutput) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 3a1de7ad39d66..5840bd997cee9 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,5 +1,4 @@ import json -import pathlib from concurrent.futures.thread import ThreadPoolExecutor from dataclasses import dataclass from http import HTTPStatus @@ -28,13 +27,10 @@ DetokenizeRequest, EmbeddingChatRequest, EmbeddingCompletionRequest, - ErrorResponse, - LoadLoraAdapterRequest, - ModelCard, ModelList, - ModelPermission, ScoreRequest, + ErrorResponse, ScoreRequest, TokenizeChatRequest, - TokenizeCompletionRequest, - UnloadLoraAdapterRequest) + TokenizeCompletionRequest) +from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser # yapf: enable from vllm.inputs import TokensPrompt @@ -48,7 +44,7 @@ from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import AtomicCounter, is_list_of, make_async, random_uuid +from vllm.utils import is_list_of, make_async, random_uuid logger = init_logger(__name__) @@ -90,173 +86,6 @@ class TextTokensPrompt(TypedDict): RequestPrompt = Union[List[int], str, TextTokensPrompt] -class OpenAIServingModels: - """Shared instance to hold data about the loaded base model(s) and adapters. 
- - Handles the routes: - - /v1/models - - /v1/load_lora_adapter - - /v1/unload_lora_adapter - """ - - def __init__( - self, - model_config: ModelConfig, - base_model_paths: List[BaseModelPath], - *, - lora_modules: Optional[List[LoRAModulePath]] = None, - prompt_adapters: Optional[List[PromptAdapterPath]] = None, - ): - super().__init__() - - self.base_model_paths = base_model_paths - self.max_model_len = model_config.max_model_len - - self.lora_id_counter = AtomicCounter(0) - self.lora_requests = [] - if lora_modules is not None: - self.lora_requests = [ - LoRARequest(lora_name=lora.name, - lora_int_id=i, - lora_path=lora.path, - base_model_name=lora.base_model_name - if lora.base_model_name - and self.is_base_model(lora.base_model_name) else - self.base_model_paths[0].name) - for i, lora in enumerate(lora_modules, start=1) - ] - - self.prompt_adapter_requests = [] - if prompt_adapters is not None: - for i, prompt_adapter in enumerate(prompt_adapters, start=1): - with pathlib.Path(prompt_adapter.local_path, - "adapter_config.json").open() as f: - adapter_config = json.load(f) - num_virtual_tokens = adapter_config["num_virtual_tokens"] - self.prompt_adapter_requests.append( - PromptAdapterRequest( - prompt_adapter_name=prompt_adapter.name, - prompt_adapter_id=i, - prompt_adapter_local_path=prompt_adapter.local_path, - prompt_adapter_num_virtual_tokens=num_virtual_tokens)) - - def is_base_model(self, model_name): - return any(model.name == model_name for model in self.base_model_paths) - - def model_name(self, lora_request: Optional[LoRARequest] = None) -> str: - """Returns the appropriate model name depending on the availability - and support of the LoRA or base model. - Parameters: - - lora: LoRARequest that contain a base_model_name. - Returns: - - str: The name of the base model or the first available model path. - """ - if lora_request is not None: - return lora_request.lora_name - return self.base_model_paths[0].name - - async def show_available_models(self) -> ModelList: - """Show available models. This includes the base model and all - adapters""" - model_cards = [ - ModelCard(id=base_model.name, - max_model_len=self.max_model_len, - root=base_model.model_path, - permission=[ModelPermission()]) - for base_model in self.base_model_paths - ] - lora_cards = [ - ModelCard(id=lora.lora_name, - root=lora.local_path, - parent=lora.base_model_name if lora.base_model_name else - self.base_model_paths[0].name, - permission=[ModelPermission()]) - for lora in self.lora_requests - ] - prompt_adapter_cards = [ - ModelCard(id=prompt_adapter.prompt_adapter_name, - root=self.base_model_paths[0].name, - permission=[ModelPermission()]) - for prompt_adapter in self.prompt_adapter_requests - ] - model_cards.extend(lora_cards) - model_cards.extend(prompt_adapter_cards) - return ModelList(data=model_cards) - - async def load_lora_adapter( - self, - request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]: - error_check_ret = await self._check_load_lora_adapter_request(request) - if error_check_ret is not None: - return error_check_ret - - lora_name, lora_path = request.lora_name, request.lora_path - unique_id = self.lora_id_counter.inc(1) - self.lora_requests.append( - LoRARequest(lora_name=lora_name, - lora_int_id=unique_id, - lora_path=lora_path)) - return f"Success: LoRA adapter '{lora_name}' added successfully." 
- - async def unload_lora_adapter( - self, - request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]: - error_check_ret = await self._check_unload_lora_adapter_request(request - ) - if error_check_ret is not None: - return error_check_ret - - lora_name = request.lora_name - self.lora_requests = [ - lora_request for lora_request in self.lora_requests - if lora_request.lora_name != lora_name - ] - return f"Success: LoRA adapter '{lora_name}' removed successfully." - - async def _check_load_lora_adapter_request( - self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]: - # Check if both 'lora_name' and 'lora_path' are provided - if not request.lora_name or not request.lora_path: - return create_error_response( - message="Both 'lora_name' and 'lora_path' must be provided.", - err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) - - # Check if the lora adapter with the given name already exists - if any(lora_request.lora_name == request.lora_name - for lora_request in self.lora_requests): - return create_error_response( - message= - f"The lora adapter '{request.lora_name}' has already been" - "loaded.", - err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) - - return None - - async def _check_unload_lora_adapter_request( - self, - request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]: - # Check if either 'lora_name' or 'lora_int_id' is provided - if not request.lora_name and not request.lora_int_id: - return create_error_response( - message= - "either 'lora_name' and 'lora_int_id' needs to be provided.", - err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) - - # Check if the lora adapter with the given name exists - if not any(lora_request.lora_name == request.lora_name - for lora_request in self.lora_requests): - return create_error_response( - message= - f"The lora adapter '{request.lora_name}' cannot be found.", - err_type="InvalidUserInput", - status_code=HTTPStatus.BAD_REQUEST) - - return None - - class OpenAIServing: def __init__( diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py new file mode 100644 index 0000000000000..cd4e2b63f56d0 --- /dev/null +++ b/vllm/entrypoints/openai/serving_models.py @@ -0,0 +1,185 @@ +import json +import pathlib +from http import HTTPStatus +from typing import List, Optional, Union + +from vllm.config import ModelConfig +from vllm.entrypoints.openai.protocol import (ErrorResponse, + LoadLoraAdapterRequest, + ModelCard, ModelList, + ModelPermission, + UnloadLoraAdapterRequest) +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + LoRAModulePath, + PromptAdapterPath, + create_error_response) +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.utils import AtomicCounter + + +class OpenAIServingModels: + """Shared instance to hold data about the loaded base model(s) and adapters. 
+ + Handles the routes: + - /v1/models + - /v1/load_lora_adapter + - /v1/unload_lora_adapter + """ + + def __init__( + self, + model_config: ModelConfig, + base_model_paths: List[BaseModelPath], + *, + lora_modules: Optional[List[LoRAModulePath]] = None, + prompt_adapters: Optional[List[PromptAdapterPath]] = None, + ): + super().__init__() + + self.base_model_paths = base_model_paths + self.max_model_len = model_config.max_model_len + + self.lora_id_counter = AtomicCounter(0) + self.lora_requests = [] + if lora_modules is not None: + self.lora_requests = [ + LoRARequest(lora_name=lora.name, + lora_int_id=i, + lora_path=lora.path, + base_model_name=lora.base_model_name + if lora.base_model_name + and self.is_base_model(lora.base_model_name) else + self.base_model_paths[0].name) + for i, lora in enumerate(lora_modules, start=1) + ] + + self.prompt_adapter_requests = [] + if prompt_adapters is not None: + for i, prompt_adapter in enumerate(prompt_adapters, start=1): + with pathlib.Path(prompt_adapter.local_path, + "adapter_config.json").open() as f: + adapter_config = json.load(f) + num_virtual_tokens = adapter_config["num_virtual_tokens"] + self.prompt_adapter_requests.append( + PromptAdapterRequest( + prompt_adapter_name=prompt_adapter.name, + prompt_adapter_id=i, + prompt_adapter_local_path=prompt_adapter.local_path, + prompt_adapter_num_virtual_tokens=num_virtual_tokens)) + + def is_base_model(self, model_name): + return any(model.name == model_name for model in self.base_model_paths) + + def model_name(self, lora_request: Optional[LoRARequest] = None) -> str: + """Returns the appropriate model name depending on the availability + and support of the LoRA or base model. + Parameters: + - lora: LoRARequest that contain a base_model_name. + Returns: + - str: The name of the base model or the first available model path. + """ + if lora_request is not None: + return lora_request.lora_name + return self.base_model_paths[0].name + + async def show_available_models(self) -> ModelList: + """Show available models. This includes the base model and all + adapters""" + model_cards = [ + ModelCard(id=base_model.name, + max_model_len=self.max_model_len, + root=base_model.model_path, + permission=[ModelPermission()]) + for base_model in self.base_model_paths + ] + lora_cards = [ + ModelCard(id=lora.lora_name, + root=lora.local_path, + parent=lora.base_model_name if lora.base_model_name else + self.base_model_paths[0].name, + permission=[ModelPermission()]) + for lora in self.lora_requests + ] + prompt_adapter_cards = [ + ModelCard(id=prompt_adapter.prompt_adapter_name, + root=self.base_model_paths[0].name, + permission=[ModelPermission()]) + for prompt_adapter in self.prompt_adapter_requests + ] + model_cards.extend(lora_cards) + model_cards.extend(prompt_adapter_cards) + return ModelList(data=model_cards) + + async def load_lora_adapter( + self, + request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]: + error_check_ret = await self._check_load_lora_adapter_request(request) + if error_check_ret is not None: + return error_check_ret + + lora_name, lora_path = request.lora_name, request.lora_path + unique_id = self.lora_id_counter.inc(1) + self.lora_requests.append( + LoRARequest(lora_name=lora_name, + lora_int_id=unique_id, + lora_path=lora_path)) + return f"Success: LoRA adapter '{lora_name}' added successfully." 
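# A minimal usage sketch of the consolidated class above — an aside, not part of
# the diff hunks. All names come from this series; the ModelConfig instance and
# the filesystem paths are assumed placeholders:
#
#     models = OpenAIServingModels(
#         model_config=model_config,
#         base_model_paths=[BaseModelPath(name="base", model_path="/models/base")],
#         lora_modules=None,
#         prompt_adapters=None,
#     )
#     response = await models.load_lora_adapter(
#         LoadLoraAdapterRequest(lora_name="my-lora",
#                                lora_path="/adapters/my-lora"))
#     # On success, `response` is the "Success: ..." string, the adapter is
#     # appended to models.lora_requests, and it is listed by
#     # (await models.show_available_models()).data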
+
+    async def unload_lora_adapter(
+            self,
+            request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
+        error_check_ret = await self._check_unload_lora_adapter_request(request
+                                                                        )
+        if error_check_ret is not None:
+            return error_check_ret
+
+        lora_name = request.lora_name
+        self.lora_requests = [
+            lora_request for lora_request in self.lora_requests
+            if lora_request.lora_name != lora_name
+        ]
+        return f"Success: LoRA adapter '{lora_name}' removed successfully."
+
+    async def _check_load_lora_adapter_request(
+            self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
+        # Check if both 'lora_name' and 'lora_path' are provided
+        if not request.lora_name or not request.lora_path:
+            return create_error_response(
+                message="Both 'lora_name' and 'lora_path' must be provided.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        # Check if the lora adapter with the given name already exists
+        if any(lora_request.lora_name == request.lora_name
+               for lora_request in self.lora_requests):
+            return create_error_response(
+                message=
+                f"The lora adapter '{request.lora_name}' has already been "
+                "loaded.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        return None
+
+    async def _check_unload_lora_adapter_request(
+            self,
+            request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
+        # Check if either 'lora_name' or 'lora_int_id' is provided
+        if not request.lora_name and not request.lora_int_id:
+            return create_error_response(
+                message=
+                "Either 'lora_name' or 'lora_int_id' needs to be provided.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        # Check if the lora adapter with the given name exists
+        if not any(lora_request.lora_name == request.lora_name
+                   for lora_request in self.lora_requests):
+            return create_error_response(
+                message=
+                f"The lora adapter '{request.lora_name}' cannot be found.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        return None
diff --git a/vllm/entrypoints/openai/serving_pooling.py b/vllm/entrypoints/openai/serving_pooling.py
index 9b78eef08869e..5830322071e58 100644
--- a/vllm/entrypoints/openai/serving_pooling.py
+++ b/vllm/entrypoints/openai/serving_pooling.py
@@ -15,8 +15,8 @@
                                               PoolingChatRequest,
                                               PoolingRequest, PoolingResponse,
                                               PoolingResponseData, UsageInfo)
-from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
-                                                    OpenAIServingModels)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
 from vllm.outputs import PoolingOutput, PoolingRequestOutput
 from vllm.utils import merge_async_iterators
diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py
index 671c24dcf76d4..5d3e7139d7a17 100644
--- a/vllm/entrypoints/openai/serving_score.py
+++ b/vllm/entrypoints/openai/serving_score.py
@@ -10,8 +10,8 @@
 from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
                                               ScoreResponse, ScoreResponseData,
                                               UsageInfo)
-from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
-                                                    OpenAIServingModels)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.inputs.data import TokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py
index 45ac883b2294c..b67ecfb01316f 100644
--- a/vllm/entrypoints/openai/serving_tokenization.py
+++ b/vllm/entrypoints/openai/serving_tokenization.py
@@ -15,8 +15,8 @@
                                               TokenizeRequest,
                                               TokenizeResponse)
 # yapf: enable
-from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
-                                                    OpenAIServingModels)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)

From 8c1d8ff3d28486778425aa89dee9f9aeb4028c7a Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Tue, 31 Dec 2024 14:51:01 -0700
Subject: [PATCH 6/8] :art: fmt

Signed-off-by: Joe Runde
---
 tests/entrypoints/openai/test_serving_models.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/tests/entrypoints/openai/test_serving_models.py b/tests/entrypoints/openai/test_serving_models.py
index 78a01e8759c9d..3385b47f965ff 100644
--- a/tests/entrypoints/openai/test_serving_models.py
+++ b/tests/entrypoints/openai/test_serving_models.py
@@ -24,11 +24,10 @@ async def _async_serving_models_init() -> OpenAIServingModels:
     # Set the max_model_len attribute to avoid missing attribute
     mock_model_config.max_model_len = 2048
 
-    serving_models = OpenAIServingModels(
-        base_model_paths=BASE_MODEL_PATHS,
-        model_config=mock_model_config,
-        lora_modules=None,
-        prompt_adapters=None)
+    serving_models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
+                                         model_config=mock_model_config,
+                                         lora_modules=None,
+                                         prompt_adapters=None)
 
     return serving_models
 

From c12e632547e42cbc6995994e61660ca1655dbd0a Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Tue, 31 Dec 2024 14:57:45 -0700
Subject: [PATCH 7/8] :recycle: move model path classes to serving_models

Signed-off-by: Joe Runde
---
 tests/entrypoints/openai/test_cli_args.py     |  2 +-
 tests/entrypoints/openai/test_serving_chat.py |  4 +--
 .../entrypoints/openai/test_serving_models.py |  2 +-
 vllm/entrypoints/openai/api_server.py         |  5 +--
 vllm/entrypoints/openai/cli_args.py           |  2 +-
 vllm/entrypoints/openai/run_batch.py          |  4 +--
 vllm/entrypoints/openai/serving_engine.py     | 30 -----------------
 vllm/entrypoints/openai/serving_models.py     | 33 ++++++++++++++++---
 8 files changed, 39 insertions(+), 43 deletions(-)

diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py
index 45e6980a94630..e49562ad6a21f 100644
--- a/tests/entrypoints/openai/test_cli_args.py
+++ b/tests/entrypoints/openai/test_cli_args.py
@@ -4,7 +4,7 @@
 
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                               validate_parsed_serve_args)
-from vllm.entrypoints.openai.serving_engine import LoRAModulePath
+from vllm.entrypoints.openai.serving_models import LoRAModulePath
 from vllm.utils import FlexibleArgumentParser
 
 from ...utils import VLLM_PATH
diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index f1c855e417fb8..97248f1150979 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -8,8 +8,8 @@
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
-from vllm.entrypoints.openai.serving_engine import BaseModelPath
-from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+                                                    OpenAIServingModels)
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 MODEL_NAME = "openai-community/gpt2"
diff --git a/tests/entrypoints/openai/test_serving_models.py b/tests/entrypoints/openai/test_serving_models.py
index 3385b47f965ff..341850db46c80 100644
--- a/tests/entrypoints/openai/test_serving_models.py
+++ b/tests/entrypoints/openai/test_serving_models.py
@@ -7,7 +7,7 @@
 from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               LoadLoraAdapterRequest,
                                               UnloadLoraAdapterRequest)
-from vllm.entrypoints.openai.serving_engine import BaseModelPath
+from vllm.entrypoints.openai.serving_models import BaseModelPath
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.lora.request import LoRARequest
 
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index b1accad2fc92e..74fe378fdae42 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -58,8 +58,9 @@
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
-from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
-from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+                                                    OpenAIServingModels)
 from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
 from vllm.entrypoints.openai.serving_score import OpenAIServingScores
 from vllm.entrypoints.openai.serving_tokenization import (
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 908f8c3532c9e..22206ef8dbfe6 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -12,7 +12,7 @@
 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
 from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                          validate_chat_template)
-from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
                                                     PromptAdapterPath)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
 from vllm.utils import FlexibleArgumentParser
diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py
index 13f4d507c6c32..822c0f5f7c211 100644
--- a/vllm/entrypoints/openai/run_batch.py
+++ b/vllm/entrypoints/openai/run_batch.py
@@ -20,8 +20,8 @@
 # yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
-from vllm.entrypoints.openai.serving_engine import BaseModelPath
-from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+                                                    OpenAIServingModels)
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, random_uuid
 from vllm.version import __version__ as VLLM_VERSION
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 5840bd997cee9..319f869240036 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -1,6 +1,5 @@
 import json
 from concurrent.futures.thread import ThreadPoolExecutor
-from dataclasses import dataclass
 from http import HTTPStatus
 from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
                     Optional, Sequence, Tuple, TypedDict, Union)
@@ -48,26 +47,6 @@
 
 logger = init_logger(__name__)
 
-
-@dataclass
-class BaseModelPath:
-    name: str
-    model_path: str
-
-
-@dataclass
-class PromptAdapterPath:
-    name: str
-    local_path: str
-
-
-@dataclass
-class LoRAModulePath:
-    name: str
-    path: str
-    base_model_name: Optional[str] = None
-
-
 CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
                               EmbeddingCompletionRequest, ScoreRequest,
                               TokenizeCompletionRequest]
@@ -531,12 +510,3 @@ def _get_decoded_token(logprob: Logprob,
 
     def _is_model_supported(self, model_name):
         return self.models.is_base_model(model_name)
-
-
-def create_error_response(
-        message: str,
-        err_type: str = "BadRequestError",
-        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
-    return ErrorResponse(message=message,
-                         type=err_type,
-                         code=status_code.value)
diff --git a/vllm/entrypoints/openai/serving_models.py b/vllm/entrypoints/openai/serving_models.py
index cd4e2b63f56d0..26966896bc272 100644
--- a/vllm/entrypoints/openai/serving_models.py
+++ b/vllm/entrypoints/openai/serving_models.py
@@ -1,5 +1,6 @@
 import json
 import pathlib
+from dataclasses import dataclass
 from http import HTTPStatus
 from typing import List, Optional, Union
 
@@ -9,15 +10,30 @@
                                               ModelCard, ModelList,
                                               ModelPermission,
                                               UnloadLoraAdapterRequest)
-from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
-                                                    LoRAModulePath,
-                                                    PromptAdapterPath,
-                                                    create_error_response)
 from vllm.lora.request import LoRARequest
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.utils import AtomicCounter
 
 
+@dataclass
+class BaseModelPath:
+    name: str
+    model_path: str
+
+
+@dataclass
+class PromptAdapterPath:
+    name: str
+    local_path: str
+
+
+@dataclass
+class LoRAModulePath:
+    name: str
+    path: str
+    base_model_name: Optional[str] = None
+
+
 class OpenAIServingModels:
     """Shared instance to hold data about the loaded base model(s) and adapters.
 
@@ -183,3 +199,12 @@ async def _check_unload_lora_adapter_request(
                 status_code=HTTPStatus.BAD_REQUEST)
 
     return None
+
+
+def create_error_response(
+        message: str,
+        err_type: str = "BadRequestError",
+        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
+    return ErrorResponse(message=message,
+                         type=err_type,
+                         code=status_code.value)

From 55e774fbd241d9886118b76e226ec3865c956f09 Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Tue, 31 Dec 2024 15:00:12 -0700
Subject: [PATCH 8/8] :art: fmt

Signed-off-by: Joe Runde
---
 tests/entrypoints/openai/test_serving_models.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/entrypoints/openai/test_serving_models.py b/tests/entrypoints/openai/test_serving_models.py
index 341850db46c80..96897dc730da2 100644
--- a/tests/entrypoints/openai/test_serving_models.py
+++ b/tests/entrypoints/openai/test_serving_models.py
@@ -7,8 +7,8 @@
 from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               LoadLoraAdapterRequest,
                                               UnloadLoraAdapterRequest)
-from vllm.entrypoints.openai.serving_models import BaseModelPath
-from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+                                                    OpenAIServingModels)
 from vllm.lora.request import LoRARequest
 
 MODEL_NAME = "meta-llama/Llama-2-7b"
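
A minimal sketch of how the consolidated OpenAIServingModels can be exercised on its own, mirroring the unit-test setup above. The model name, adapter name, and adapter path are illustrative only, and it assumes LoadLoraAdapterRequest / UnloadLoraAdapterRequest accept these fields as keyword arguments:

    import asyncio
    from unittest.mock import MagicMock

    from vllm.config import ModelConfig
    from vllm.entrypoints.openai.protocol import (LoadLoraAdapterRequest,
                                                  UnloadLoraAdapterRequest)
    from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                        OpenAIServingModels)


    async def demo() -> None:
        # Only max_model_len is read from the model config here (as in the test).
        mock_model_config = MagicMock(spec=ModelConfig)
        mock_model_config.max_model_len = 2048

        serving_models = OpenAIServingModels(
            base_model_paths=[BaseModelPath(name="demo-model",
                                            model_path="demo-model")],
            model_config=mock_model_config,
            lora_modules=None,
            prompt_adapters=None)

        # Register a LoRA adapter at runtime (name and path are made up).
        await serving_models.load_lora_adapter(
            LoadLoraAdapterRequest(lora_name="demo-lora",
                                   lora_path="/adapters/demo-lora"))

        # /v1/models is served from this list: base model plus the new adapter.
        model_list = await serving_models.show_available_models()
        print([card.id for card in model_list.data])

        # Unload it again; the adapter disappears from the next listing.
        await serving_models.unload_lora_adapter(
            UnloadLoraAdapterRequest(lora_name="demo-lora"))


    asyncio.run(demo())

Moving BaseModelPath, LoRAModulePath, PromptAdapterPath, and create_error_response into serving_models.py in PATCH 7/8 is presumably what lets serving_engine.py depend on serving_models.py without introducing a circular import.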