diff --git a/examples/online_serving/disaggregated_prefill.sh b/examples/online_serving/disaggregated_prefill.sh
index 2bb2824c6c86f..e4a7589e919d4 100644
--- a/examples/online_serving/disaggregated_prefill.sh
+++ b/examples/online_serving/disaggregated_prefill.sh
@@ -23,14 +23,6 @@ cleanup() {
 
 export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
 
-# install quart first -- required for disagg prefill proxy serve
-if python3 -c "import quart" &> /dev/null; then
-    echo "Quart is already installed."
-else
-    echo "Quart is not installed. Installing..."
-    python3 -m pip install quart
-fi
-
 # a function that waits vLLM server to start
 wait_for_server() {
   local port=$1
@@ -46,6 +38,7 @@ wait_for_server() {
 # prefilling instance, which is the KV producer
 CUDA_VISIBLE_DEVICES=0 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
     --port 8100 \
+    --zmq-server-port 7010 \
     --max-model-len 100 \
     --gpu-memory-utilization 0.8 \
     --kv-transfer-config \
@@ -54,24 +47,27 @@ CUDA_VISIBLE_DEVICES=0 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
 # decoding instance, which is the KV consumer
 CUDA_VISIBLE_DEVICES=1 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
     --port 8200 \
+    --zmq-server-port 7011 \
    --max-model-len 100 \
     --gpu-memory-utilization 0.8 \
     --kv-transfer-config \
     '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &
 
-# wait until prefill and decode instances are ready
-wait_for_server 8100
-wait_for_server 8200
-
 # launch a proxy server that opens the service at port 8000
 # the workflow of this proxy:
-# - send the request to prefill vLLM instance (port 8100), change max_tokens
+# - send the request to prefill vLLM instance (via zmq port 7010), change max_tokens
 #   to 1
 # - after the prefill vLLM finishes prefill, send the request to decode vLLM
-#   instance
-# NOTE: the usage of this API is subject to change --- in the future we will
-# introduce "vllm connect" to connect between prefill and decode instances
-python3 ../../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py &
+#   instance (via zmq port 7011)
+vllm connect --port 8000 \
+    --prefill-addr 127.0.0.1:7010 \
+    --decode-addr 127.0.0.1:7011 &
+
+# wait until prefill, decode instances and proxy are ready
+wait_for_server 8000
+wait_for_server 8100
+wait_for_server 8200
+
 sleep 1
 
 # serve two example requests
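For reference, one of the example requests served at the end of this script could look like the following minimal sketch (not part of the patch). Only the /v1/completions route, the proxy port 8000, and the model name come from the script above; the prompt, the max_tokens value, and the use of the `requests` package are illustrative.

# Illustrative client for the proxy started by `vllm connect` above.
# The proxy first forwards the request to the prefill instance with
# max_tokens forced to 1, then replays the original request against the
# decode instance and returns that response.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "prompt": "San Francisco is a",  # example prompt, not from the script
        "max_tokens": 32,
    },
    timeout=180,
)
print(resp.json())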
diff --git a/vllm/entrypoints/disagg_connector.py b/vllm/entrypoints/disagg_connector.py
new file mode 100644
index 0000000000000..3698f6a126963
--- /dev/null
+++ b/vllm/entrypoints/disagg_connector.py
@@ -0,0 +1,234 @@
+import asyncio
+import json
+import signal
+import traceback
+import uuid
+from asyncio import Queue
+from contextlib import asynccontextmanager
+from typing import AsyncGenerator
+
+import uvicorn
+import uvloop
+import zmq
+import zmq.asyncio
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse, StreamingResponse
+
+from vllm.logger import init_logger
+from vllm.utils import FlexibleArgumentParser
+
+# default prefill and decode addresses and socket pool sizes
+time_out = 180
+fastapi_port = 8000
+prefill_addr = "ipc://localhost:7010"
+socket_prefill_num = 100
+decode_addr = "ipc://localhost:7020"
+socket_decode_num = 100
+context_type_json = "application/json"
+context_type_error = "error"
+
+# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
+logger = init_logger('vllm.entrypoints.disagg_connector')
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # create the socket pools for the prefill and decode instances
+    logger.info("start create_socket_pool")
+    app.state.zmqctx = zmq.asyncio.Context()
+    app.state.sockets_prefill = await create_socket_pool(
+        app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx)
+    logger.info("success create_socket_pool sockets_prefill")
+    app.state.sockets_decode = await create_socket_pool(
+        app.state.decode_addr, socket_decode_num, zmqctx=app.state.zmqctx)
+    logger.info("success create_socket_pool sockets_decode")
+    yield
+    # close zmq context on shutdown
+    logger.info("shutdown disagg connector")
+    logger.info("term zmqctx")
+    app.state.zmqctx.destroy(linger=0)
+
+
+app = FastAPI(lifespan=lifespan)
+
+
+# create an async pool of num_sockets ZMQ DEALER sockets
+async def create_socket_pool(url: str, num_sockets: int,
+                             zmqctx: zmq.asyncio.Context) -> Queue:
+    sockets: Queue[zmq.Socket] = Queue()
+    for i in range(num_sockets):
+        sock = zmqctx.socket(zmq.DEALER)
+        identity = f"worker-{i}-{uuid.uuid4()}"
+        sock.setsockopt(zmq.IDENTITY, identity.encode())
+        sock.connect(url)
+        logger.info("%s started at %s with queue size %s", identity, url,
+                    sockets.qsize())
+        await sockets.put(sock)
+    return sockets
+
+
+# take a socket from the pool, execute the request, and put the socket back
+async def execute_task_async(route: str, headers: dict, request: dict,
+                             sockets: Queue):
+    sock: zmq.Socket = await sockets.get()
+    try:
+        request_body = json.dumps(request)
+        headers_json = json.dumps(headers)
+        logger.info("Sending request body: %s to %s with headers: %s",
+                    request_body, route, headers_json)
+        await asyncio.wait_for(sock.send_multipart(
+            [route.encode(),
+             headers_json.encode(),
+             request_body.encode()]),
+                               timeout=time_out)
+        logger.info("Sent end")
+        while True:
+            logger.info("Waiting for reply")
+            [content_type,
+             reply] = await asyncio.wait_for(sock.recv_multipart(),
+                                             timeout=time_out)
+            content_type_str = content_type.decode()
+            reply_str = reply.decode()
+            logger.info("Received result: %s, %s", content_type_str,
+                        reply_str)
+            yield (content_type_str, reply_str)
+            if context_type_json == content_type_str:
+                logger.info("Received %s message, return socket",
+                            content_type_str)
+                break
+            if "[DONE]" in reply_str:
+                logger.info("Received stop signal, return socket")
+                break
+    except asyncio.TimeoutError:
+        logger.error(traceback.format_exc())
+        logger.error("Timeout, return socket: %s",
+                     sock.getsockopt(zmq.IDENTITY))
+        yield (context_type_error, "System Error")
+    finally:
+        await sockets.put(sock)
+
+
+async def generate_stream_response(first_reply: str,
+                                   generator: AsyncGenerator):
+    yield first_reply
+    async for _, reply in generator:
+        yield reply
+
+
+async def prefill(route: str, header: dict, original_request_data: dict):
+    logger.info("start prefill")
+    generator = execute_task_async(route, header, original_request_data,
+                                   app.state.sockets_prefill)
+    async for content_type, reply in generator:
+        logger.info("content_type: %s, reply: %s", content_type, reply)
+        if context_type_error == content_type:
+            response = JSONResponse({"error": reply})
+            response.status_code = 500
+            return response
+    return True
+
+
+async def decode(route: str, header: dict, original_request_data: dict):
+    logger.info("start decode")
+    generator = execute_task_async(route, header, original_request_data,
+                                   app.state.sockets_decode)
+
+    async for content_type, reply in generator:
+        logger.info("content_type: %s, reply: %s", content_type, reply)
+        if context_type_error == content_type:
+            response = JSONResponse({"error": reply})
+            response.status_code = 500
+            return response
+        elif context_type_json == content_type:
+            # reply is already a JSON string; parse it to avoid double encoding
+            return JSONResponse(json.loads(reply))
+        else:
+            return StreamingResponse(generate_stream_response(
+                reply, generator),
+                                     media_type="text/event-stream")
+
+
+@app.post('/v1/completions')
+async def completions(request: Request):
+    try:
+        # Add an X-Request-Id header if the client did not send one
+        x_request_id = str(uuid.uuid4())
+        header = dict(request.headers)
+        if header.get("X-Request-Id") is None:
+            logger.info("add X-Request-Id: %s", x_request_id)
+            header["X-Request-Id"] = x_request_id
+        original_request_data = await request.json()
+        logger.info("Received request: %s header: %s", original_request_data,
+                    header)
+        prefill_request = original_request_data.copy()
+        # set max_tokens to 1 so the prefill instance only does prefill
+        prefill_request['max_tokens'] = 1
+        route = "/v1/completions"
+        # run prefill first, then decode
+        try:
+            prefill_response = await prefill(route, header, prefill_request)
+            if isinstance(prefill_response, JSONResponse):
+                return prefill_response
+            logger.info("finish prefill start decode")
+            response = await decode(route, header, original_request_data)
+            logger.info("finish decode")
+        except Exception as e:
+            logger.error("Error occurred in disagg prefill proxy server, %s",
+                         e)
+            response = JSONResponse({"error": {"message": str(e)}})
+        return response
+
+    except Exception as e:
+        logger.error("Error occurred in disagg prefill proxy server")
+        logger.error(e)
+        logger.error(traceback.format_exc())
+        return JSONResponse({"error": {"message": str(e)}}, status_code=500)
+
+
+async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
+    logger.info("vLLM Disaggregate Connector start %s %s", args,
+                uvicorn_kwargs)
+    logger.info(args.prefill_addr)
+    app.state.port = args.port if args.port is not None else fastapi_port
+    app.state.prefill_addr = (f"ipc://{args.prefill_addr}"
+                              if args.prefill_addr is not None
+                              else prefill_addr)
+    app.state.decode_addr = (f"ipc://{args.decode_addr}"
+                             if args.decode_addr is not None else decode_addr)
+    logger.info(
+        "start connect prefill_addr: %s decode_addr: %s zmq server port: %s",
+        app.state.prefill_addr, app.state.decode_addr, app.state.port)
+
+    def signal_handler(*_) -> None:
+        # Interrupt server on sigterm while initializing
+        raise KeyboardInterrupt("terminated")
+
+    signal.signal(signal.SIGTERM, signal_handler)
+    # init uvicorn server
+    config = uvicorn.Config(app, host="0.0.0.0", port=app.state.port)
+    server = uvicorn.Server(config)
+    await server.serve()
+
+
+if __name__ == "__main__":
+    # NOTE(simon):
+    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
+    parser = FlexibleArgumentParser(description="vLLM disagg zmq server.")
+    parser.add_argument("--port",
+                        type=int,
+                        default=8000,
+                        help="The FastAPI server port")
+    parser.add_argument("--prefill-addr",
+                        type=str,
+                        required=True,
+                        help="The prefill address IP:PORT")
+    parser.add_argument("--decode-addr",
+                        type=str,
+                        required=True,
+                        help="The decode address IP:PORT")
+
+    args = parser.parse_args()
+
+    uvloop.run(run_disagg_connector(args))
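For reference, the connector above and the per-instance ZMQ servers exchange plain multipart frames: each pooled DEALER socket sends [route, headers as JSON, body as JSON] and reads back [content_type, payload] pairs, either a single application/json reply or a stream of text/event-stream chunks that ends with a "[DONE]" marker. A minimal standalone sketch of that framing follows (not part of the patch; the tcp endpoint and payload values are illustrative, while the patch itself uses ipc addresses).

# Illustrative DEALER client showing the frame layout used by the connector.
import json
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.DEALER)
sock.setsockopt(zmq.IDENTITY, b"demo-client")
sock.connect("tcp://127.0.0.1:5555")  # illustrative endpoint

# request: [route, headers as JSON, body as JSON]
sock.send_multipart([
    b"/v1/completions",
    json.dumps({"X-Request-Id": "demo"}).encode(),
    json.dumps({"prompt": "Hello", "max_tokens": 1}).encode(),
])

# reply: [content_type, payload]; "application/json" means a single reply,
# while "text/event-stream" chunks keep arriving until one contains "[DONE]"
content_type, payload = sock.recv_multipart()
print(content_type.decode(), payload.decode())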
diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py
index 5dcf50bd1b0a1..db7f24d0a6d9e 100644
--- a/vllm/entrypoints/launcher.py
+++ b/vllm/entrypoints/launcher.py
@@ -4,11 +4,15 @@
 from typing import Any
 
 import uvicorn
+import zmq
+import zmq.asyncio
+import zmq.devices
 from fastapi import FastAPI, Request, Response
 
 from vllm import envs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError
 from vllm.engine.multiprocessing import MQEngineDeadError
+from vllm.entrypoints.openai.connect_worker import worker_routine
 from vllm.logger import init_logger
 from vllm.utils import find_process_using_port
 
@@ -58,6 +62,37 @@ async def dummy_shutdown() -> None:
             return server.shutdown()
 
 
+async def serve_zmq(args, zmq_server_port: int, app: FastAPI) -> None:
+    """Serve the ZMQ front end used by the disaggregated prefill connector."""
+    logger.info("zmq server start args: %s, zmq_server_port: %d", args,
+                zmq_server_port)
+    # separate zmq contexts cannot communicate over inproc, so use ipc here
+    workers_addr = "ipc://workers"
+    clients_addr = f"ipc://127.0.0.1:{zmq_server_port}"
+    # Prepare our context and sockets
+    context = zmq.asyncio.Context.instance()
+    try:
+        tasks = [
+            asyncio.create_task(worker_routine(workers_addr, app, context, i))
+            for i in range(100)
+        ]
+        logger.info("zmq tasks: %s", tasks)
+        # ThreadProxy creates its sockets in a background thread (thread safe):
+        # https://pyzmq.readthedocs.io/en/latest/api/zmq.devices.html#proxy-devices
+        thread_proxy = zmq.devices.ThreadProxy(zmq.ROUTER, zmq.DEALER)
+        thread_proxy.bind_in(clients_addr)
+        thread_proxy.bind_out(workers_addr)
+        thread_proxy.start()
+        await asyncio.gather(*tasks)
+    except KeyboardInterrupt:
+        logger.info("ZMQ server interrupted")
+    except zmq.ZMQError as e:
+        logger.error("ZMQError: %s", e)
+    finally:
+        # We never get here, but clean up anyhow
+        context.destroy(linger=0)
+
+
 def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
     """Adds handlers for fatal errors that should crash the server"""
 
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 1aeefe86cd05e..05b0178b25b3e 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -33,7 +33,7 @@
 from vllm.engine.multiprocessing.engine import run_mp_engine
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import load_chat_template
-from vllm.entrypoints.launcher import serve_http
+from vllm.entrypoints.launcher import serve_http, serve_zmq
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                               validate_parsed_serve_args)
@@ -799,6 +799,12 @@ def signal_handler(*_) -> None:
     model_config = await engine_client.get_model_config()
     await init_app_state(engine_client, model_config, app.state, args)
 
+    zmq_server_port = args.zmq_server_port
+    if zmq_server_port is not None:
+        logger.info("Starting ZMQ server task on port %d",
+                    zmq_server_port)
+        asyncio.create_task(serve_zmq(args, zmq_server_port, app))
+
     shutdown_task = await serve_http(
         app,
         host=args.host,
diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py
index 35445449463e9..9d3d44fe8d71a 100644
--- a/vllm/entrypoints/openai/cli_args.py
+++ b/vllm/entrypoints/openai/cli_args.py
@@ -251,6 +251,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         default=False,
         help="If set to True, enable prompt_tokens_details in usage.")
 
+    parser.add_argument('--zmq-server-port',
+                        type=int,
+                        default=None,
+                        help='The port to serve the ZMQ server on.')
+
     return parser
 
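For reference, the serve_zmq routine added to launcher.py above is an instance of the standard pyzmq shared-queue pattern: a ROUTER front end faces the connector's DEALER pool, a DEALER back end fair-queues frames to the worker coroutines, and a proxy shuttles messages between the two. A minimal blocking sketch of the same topology follows (not part of the patch; the endpoints are illustrative).

# Minimal ROUTER/DEALER queue sketch.
import zmq

ctx = zmq.Context()
frontend = ctx.socket(zmq.ROUTER)  # remote DEALER clients connect here
backend = ctx.socket(zmq.DEALER)   # local worker DEALER sockets connect here
frontend.bind("tcp://127.0.0.1:7010")  # illustrative endpoint
backend.bind("inproc://workers")

# Blocks forever, forwarding frames in both directions; zmq.devices.ThreadProxy
# in the patch does the same work from a background thread.
zmq.proxy(frontend, backend)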
diff --git a/vllm/entrypoints/openai/connect_worker.py b/vllm/entrypoints/openai/connect_worker.py
new file mode 100644
index 0000000000000..7e68648253716
--- /dev/null
+++ b/vllm/entrypoints/openai/connect_worker.py
@@ -0,0 +1,137 @@
+import json
+import traceback
+import uuid
+from typing import Optional
+
+import httpx
+import zmq
+import zmq.asyncio
+from fastapi import FastAPI, Request
+
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.entrypoints.openai.protocol import (CompletionRequest,
+                                              CompletionResponse,
+                                              ErrorResponse)
+# yapf: enable
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.openai.serving_tokenization import (
+    OpenAIServingTokenization)
+from vllm.logger import init_logger
+
+# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
+logger = init_logger('vllm.entrypoints.openai.connect_worker')
+
+
+def base(app: FastAPI) -> OpenAIServing:
+    # Reuse the existing instance
+    return tokenization(app)
+
+
+def models(app: FastAPI) -> OpenAIServingModels:
+    return app.state.openai_serving_models
+
+
+def chat(app: FastAPI) -> Optional[OpenAIServingChat]:
+    return app.state.openai_serving_chat
+
+
+def completion(app: FastAPI) -> Optional[OpenAIServingCompletion]:
+    return app.state.openai_serving_completion
+
+
+def tokenization(app: FastAPI) -> OpenAIServingTokenization:
+    return app.state.openai_serving_tokenization
+
+
+def bytes_to_headers(bytes_data: bytes) -> httpx.Headers:
+    headers_dict = json.loads(bytes_data.decode())
+    return httpx.Headers(headers_dict)
+
+
+async def worker_routine(worker_addr: str, app: FastAPI,
+                         context: zmq.asyncio.Context, i: int = 0):
+    """Worker routine: serve requests forwarded by the ZMQ proxy."""
+    try:
+        # Socket to talk to dispatcher
+        socket = context.socket(zmq.DEALER)
+        worker_identity = f"worker-{i}-{uuid.uuid4()}"
+        socket.setsockopt(zmq.IDENTITY, worker_identity.encode())
+        socket.connect(worker_addr)
+        logger.info("%s started at %s", worker_identity, worker_addr)
+        while True:
+            identity, url, header, body = await socket.recv_multipart()
+            logger.info("worker-%d Received request identity: [ %s ]",
+                        i, identity.decode())
+            url_str = url.decode()
+            logger.info("worker-%d Received request url: [ %s ]",
+                        i, url_str)
+            headers = bytes_to_headers(header)
+            logger.info("worker-%d Received request headers: [ %s ]",
+                        i, headers)
+            body_json = json.loads(body.decode())
+            logger.info("worker-%d Received request body: [ %s ]",
+                        i, body_json)
+            logger.info("worker-%d Calling OpenAI API", i)
+            completion_request = CompletionRequest(**body_json)
+            raw_request = create_request(url_str, "POST", body_json, headers)
+            generator = await create_completion(app, completion_request,
+                                                raw_request)
+            context_type_json = b"application/json"
+            if isinstance(generator, ErrorResponse):
+                content = generator.model_dump_json()
+                context_json = json.loads(content)
+                context_json["status_code"] = generator.code
+                await socket.send_multipart([
+                    identity, context_type_json,
+                    json.dumps(context_json).encode('utf-8')
+                ])
+            elif isinstance(generator, CompletionResponse):
+                await socket.send_multipart([
+                    identity, context_type_json,
+                    json.dumps(generator.model_dump()).encode('utf-8')
+                ])
+            else:
+                async for chunk in generator:
+                    await socket.send_multipart([
+                        identity, b"text/event-stream",
+                        chunk.encode('utf-8')
+                    ])
+    except Exception as e:
+        logger.error("Error in worker routine: %s worker-%d", e, i)
+        logger.error(traceback.format_exc())
+
+
+async def create_completion(app: FastAPI, request: CompletionRequest,
+                            raw_request: Request):
+    handler = completion(app)
+    logger.info("zmq request post: %s", request)
+    if handler is None:
+        return base(app).create_error_response(
+            message="The model does not support Completions API")
+
+    generator = await handler.create_completion(request, raw_request)
+    logger.info("zmq request end post: %s", generator)
+    return generator
+
+
+def create_request(path: str, method: str, body: dict,
+                   headers: httpx.Headers) -> Request:
+    scope = {
+        'type': 'http',
+        'http_version': '1.1',
+        'method': method,
+        'path': path,
+        'query_string': b'',
+        # ASGI raw headers must be (bytes, bytes) pairs
+        'headers': [(k.encode(), v.encode())
+                    for k, v in headers.items()] if headers else [],
+    }
+    if body:
+        # ASGI request bodies must be bytes
+        scope['body'] = json.dumps(body).encode('utf-8')
+
+    async def receive():
+        return {
+            'type': 'http.request',
+            'body': scope.get('body', b''),
+        }
+
+    async def send(message):
+        pass
+
+    return Request(scope, receive=receive, send=send)
+
+
+if __name__ == "__main__":
+    print(bytes_to_headers(b'{"Content-Type": "application/json"}'))
diff --git a/vllm/scripts.py b/vllm/scripts.py
index 42e1c639eda10..2642702e91519 100644
--- a/vllm/scripts.py
+++ b/vllm/scripts.py
@@ -11,6 +11,7 @@
 import vllm.version
 from vllm.engine.arg_utils import EngineArgs
+from vllm.entrypoints.disagg_connector import run_disagg_connector
 from vllm.entrypoints.openai.api_server import run_server
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                               validate_parsed_serve_args)
 
@@ -42,6 +43,13 @@ def serve(args: argparse.Namespace) -> None:
     uvloop.run(run_server(args))
 
 
+def connect(args: argparse.Namespace) -> None:
+    try:
+        uvloop.run(run_disagg_connector(args))
+    except KeyboardInterrupt:
+        pass
+
+
 def interactive_cli(args: argparse.Namespace) -> None:
     register_signal_handlers()
 
@@ -192,6 +200,25 @@ def main():
         "used for models that support system prompts."))
     chat_parser.set_defaults(dispatch_function=interactive_cli, command="chat")
 
+    connect_parser = subparsers.add_parser(
+        "connect",
+        help="Start a proxy that connects separate prefill and decode "
+        "vLLM instances for disaggregated prefill",
+        usage="vllm connect [options]")
+    connect_parser.add_argument("--port",
+                                type=int,
+                                default=8001,
+                                help="The FastAPI server port")
+    connect_parser.add_argument("--prefill-addr",
+                                type=str,
+                                required=True,
+                                help="The prefill address IP:PORT")
+    connect_parser.add_argument("--decode-addr",
+                                type=str,
+                                required=True,
+                                help="The decode address IP:PORT")
+    connect_parser.set_defaults(dispatch_function=connect)
+
     args = parser.parse_args()
     if args.subparser == "serve":
         validate_parsed_serve_args(args)