diff --git a/tests/entrypoints/openai/test_mp_crash.py b/tests/entrypoints/openai/test_mp_crash.py new file mode 100644 index 0000000000000..7dc595a7be351 --- /dev/null +++ b/tests/entrypoints/openai/test_mp_crash.py @@ -0,0 +1,35 @@ +from typing import Any + +import pytest + +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.api_server import build_async_engine_client +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.utils import FlexibleArgumentParser + + +def crashing_from_engine_args( + cls, + engine_args: Any = None, + start_engine_loop: Any = None, + usage_context: Any = None, + stat_loggers: Any = None, +) -> "AsyncLLMEngine": + raise Exception("foo") + + +@pytest.mark.asyncio +async def test_mp_crash_detection(monkeypatch): + + with pytest.raises(RuntimeError) as excinfo, monkeypatch.context() as m: + m.setattr(AsyncLLMEngine, "from_engine_args", + crashing_from_engine_args) + parser = FlexibleArgumentParser( + description="vLLM's remote OpenAI server.") + parser = make_arg_parser(parser) + args = parser.parse_args([]) + + async with build_async_engine_client(args): + pass + assert "The server process died before responding to the readiness probe"\ + in str(excinfo.value) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 48aa904d4721d..d44604b12fb69 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -120,9 +120,18 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: # Build RPCClient, which conforms to AsyncEngineClient Protocol. async_engine_client = AsyncEngineRPCClient(rpc_path) - await async_engine_client.setup() try: + while True: + try: + await async_engine_client.setup() + break + except TimeoutError as e: + if not rpc_server_process.is_alive(): + raise RuntimeError( + "The server process died before " + "responding to the readiness probe") from e + yield async_engine_client finally: # Ensure rpc server process was terminated diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 8552c286eeeea..d69b202e2d1bb 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -18,6 +18,9 @@ from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +# Time to wait before checking it the server process is alive. +SERVER_START_TIMEOUT_MS = 1000 + class AsyncEngineRPCClient: @@ -61,7 +64,16 @@ def socket(self): socket.connect(self.rpc_path) yield socket finally: - socket.close() + # linger == 0 means discard unsent messages + # when the socket is closed. This is necessary + # because otherwise self.context.destroy() will + # wait for 30 seconds until unsent messages are + # received, which is impossible if the server + # crashed. In the absence of a server crash we + # always expect a response before closing the + # socket anyway. + # Reference: http://api.zeromq.org/4-2:zmq-setsockopt#toc24 + socket.close(linger=0) async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, expected_type: Any, @@ -85,14 +97,19 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, return data - async def _send_one_way_rpc_request(self, request: RPC_REQUEST_TYPE, - error_message: str): + async def _send_one_way_rpc_request(self, + request: RPC_REQUEST_TYPE, + error_message: str, + timeout: Optional[int] = None): """Send one-way RPC request to trigger an action.""" with self.socket() as socket: # Ping RPC Server with request. await socket.send(cloudpickle.dumps(request)) # Await acknowledgement from RPCServer. + if timeout is not None and await socket.poll(timeout=timeout) == 0: + raise TimeoutError(f"server didn't reply within {timeout} ms") + response = cloudpickle.loads(await socket.recv()) if not isinstance(response, str) or response != VLLM_RPC_SUCCESS_STR: @@ -117,7 +134,8 @@ async def wait_for_server(self): await self._send_one_way_rpc_request( request=RPCUtilityRequest.IS_SERVER_READY, - error_message="Unable to start RPC Server.") + error_message="Unable to start RPC Server.", + timeout=SERVER_START_TIMEOUT_MS) async def _get_model_config_rpc(self) -> ModelConfig: """Get the ModelConfig object from the RPC Server"""