[Frontend] Disaggregate prefill decode with zmq #11791

Open · wants to merge 18 commits into main
30 changes: 13 additions & 17 deletions examples/online_serving/disaggregated_prefill.sh

@@ -23,14 +23,6 @@ cleanup() {
 
 export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
 
-# install quart first -- required for disagg prefill proxy serve
-if python3 -c "import quart" &> /dev/null; then
-    echo "Quart is already installed."
-else
-    echo "Quart is not installed. Installing..."
-    python3 -m pip install quart
-fi
-
 # a function that waits vLLM server to start
 wait_for_server() {
   local port=$1
@@ -46,6 +38,7 @@ wait_for_server() {
 # prefilling instance, which is the KV producer
 CUDA_VISIBLE_DEVICES=0 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
     --port 8100 \
+    --zmq-server-port 7010 \
     --max-model-len 100 \
     --gpu-memory-utilization 0.8 \
     --kv-transfer-config \
@@ -54,24 +47,27 @@ CUDA_VISIBLE_DEVICES=0 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
 # decoding instance, which is the KV consumer
 CUDA_VISIBLE_DEVICES=1 vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct \
     --port 8200 \
+    --zmq-server-port 7011 \
     --max-model-len 100 \
     --gpu-memory-utilization 0.8 \
     --kv-transfer-config \
     '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &
 
-# wait until prefill and decode instances are ready
-wait_for_server 8100
-wait_for_server 8200
-
 # launch a proxy server that opens the service at port 8000
 # the workflow of this proxy:
-# - send the request to prefill vLLM instance (port 8100), change max_tokens
+# - send the request to prefill vLLM instance (via zmq port 7010), change max_tokens
 #   to 1
 # - after the prefill vLLM finishes prefill, send the request to decode vLLM
-#   instance
-# NOTE: the usage of this API is subject to change --- in the future we will
-#   introduce "vllm connect" to connect between prefill and decode instances
-python3 ../../benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py &
+#   instance (via zmq port 7011)
+vllm connect --port 8000 \
+    --prefill-addr 127.0.0.1:7010 \
+    --decode-addr 127.0.0.1:7011 &
+
+# wait until prefill, decode instances and proxy are ready
+wait_for_server 8000
+wait_for_server 8100
+wait_for_server 8200
 
 sleep 1
 
 # serve two example requests
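Aside: the connector exposes the same OpenAI-style completions API as a plain vLLM server, so it can be exercised with any HTTP client. A minimal client sketch (assuming the connector from the script above is listening on localhost:8000 and serving the same Llama model):

import json
import urllib.request

# the connector first runs prefill on the producer (max_tokens forced to 1),
# then runs decode on the consumer and returns the decode response
payload = {
    "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "prompt": "San Francisco is a",
    "max_tokens": 32,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read()))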
234 changes: 234 additions & 0 deletions vllm/entrypoints/disagg_connector.py

@@ -0,0 +1,234 @@
import asyncio
import json
import signal
import traceback
import uuid
from asyncio import Queue
from contextlib import asynccontextmanager
from typing import AsyncGenerator

import uvicorn
import uvloop
import zmq
import zmq.asyncio
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser

# defaults: request timeout, FastAPI port, prefill/decode zmq addresses and
# socket pool sizes
time_out = 180
fastapi_port = 8000
prefill_addr = "ipc://localhost:7010"
socket_prefill_num = 100
decode_addr = "ipc://localhost:7020"
socket_decode_num = 100
content_type_json = "application/json"
content_type_error = "error"

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.disagg_connector')


@asynccontextmanager
async def lifespan(app: FastAPI):
    # create socket pools connected to the prefill and decode instances
    logger.info("start create_socket_pool")
    app.state.zmqctx = zmq.asyncio.Context()
    app.state.sockets_prefill = await create_socket_pool(
        app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx)
    logger.info("success create_socket_pool sockets_prefill")
    app.state.sockets_decode = await create_socket_pool(
        app.state.decode_addr, socket_decode_num, zmqctx=app.state.zmqctx)
    logger.info("success create_socket_pool sockets_decode")
    yield
    # close the zmq context on shutdown
    logger.info("shutdown disagg connector")
    logger.info("term zmqctx")
    app.state.zmqctx.destroy(linger=0)


app = FastAPI(lifespan=lifespan)


# create an async pool of num_sockets ZMQ DEALER sockets connected to url
async def create_socket_pool(url: str, num_sockets: int,
                             zmqctx: zmq.asyncio.Context) -> Queue:
    sockets: Queue[zmq.Socket] = Queue()
    for i in range(num_sockets):
        sock = zmqctx.socket(zmq.DEALER)
        identity = f"worker-{i}-{uuid.uuid4()}"
        sock.setsockopt(zmq.IDENTITY, identity.encode())
        sock.connect(url)
        logger.info("%s started at %s with queue size %s", identity, url,
                    sockets.qsize())
        await sockets.put(sock)
    return sockets
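Aside, not part of the file: each request checks one of these pooled DEALER sockets out of the queue and puts it back when the exchange finishes. The same check-out/return discipline could be packaged as an async context manager; a hypothetical helper (not in this PR) sketching the pattern that execute_task_async below implements by hand:

from asyncio import Queue
from contextlib import asynccontextmanager

import zmq


# hypothetical helper: wraps the Queue-based pool so sockets are always
# returned, even if the request raises
@asynccontextmanager
async def checked_out_socket(sockets: Queue):
    sock: zmq.Socket = await sockets.get()  # wait until a socket is free
    try:
        yield sock
    finally:
        await sockets.put(sock)  # return the socket to the pool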


# check a socket out of the pool and run one request/reply exchange on it
async def execute_task_async(route: str, headers: dict, request: dict,
                             sockets: Queue):
    sock: zmq.Socket = await sockets.get()
    try:
        request_body = json.dumps(request)
        headers_json = json.dumps(headers)
        logger.info("Sending request_body: %s to %s with headers: %s",
                    request_body, route, headers_json)
        await asyncio.wait_for(sock.send_multipart(
            [route.encode(),
             headers_json.encode(),
             request_body.encode()]),
                               timeout=time_out)
        logger.info("Sent end")
        while True:
            logger.info("Waiting for reply")
            [content_type,
             reply] = await asyncio.wait_for(sock.recv_multipart(),
                                             timeout=time_out)
            content_type_str = content_type.decode()
            reply_str = reply.decode()
            logger.info("Received result: %s, %s", content_type_str,
                        reply_str)
            yield (content_type_str, reply_str)
            if content_type_json == content_type_str:
                logger.info("Received %s message, return socket",
                            content_type_str)
                break
            if "[DONE]" in reply_str:
                logger.info("Received stop signal, return socket")
                break
    except asyncio.TimeoutError:
        logger.error(traceback.format_exc())
        logger.error("Timeout, return socket: %s",
                     sock.getsockopt(zmq.IDENTITY))
        yield (content_type_error, "System Error")
    finally:
        await sockets.put(sock)


async def generate_stream_response(first_reply: str,
                                   generator: AsyncGenerator):
    yield first_reply
    async for _, reply in generator:
        yield reply


async def prefill(route: str, header: dict, original_request_data: dict):
    logger.info("start prefill")
    generator = execute_task_async(route, header, original_request_data,
                                   app.state.sockets_prefill)
    async for content_type, reply in generator:
        logger.info("content_type: %s, reply: %s", content_type, reply)
        if content_type_error == content_type:
            response = JSONResponse({"error": reply})
            response.status_code = 500
            return response
    return True


async def decode(route: str, header: dict, original_request_data: dict):
    logger.info("start decode")
    generator = execute_task_async(route, header, original_request_data,
                                   app.state.sockets_decode)

    async for content_type, reply in generator:
        logger.info("content_type: %s, reply: %s", content_type, reply)
        if content_type_error == content_type:
            response = JSONResponse({"error": reply})
            response.status_code = 500
            return response
        elif content_type_json == content_type:
            # reply is already a JSON document; parse it so the response
            # body is not double-encoded as a JSON string literal
            return JSONResponse(json.loads(reply))
        else:
            return StreamingResponse(generate_stream_response(
                reply, generator),
                                     media_type="text/event-stream")


@app.post('/v1/completions')
async def completions(request: Request):
    try:
        # add an X-Request-Id header if the client did not supply one
        x_request_id = str(uuid.uuid4())
        header = dict(request.headers)
        if header.get("X-Request-Id") is None:
            logger.info("add X-Request-Id: %s", x_request_id)
            header["X-Request-Id"] = x_request_id
        original_request_data = await request.json()
        logger.info("Received request: %s header: %s", original_request_data,
                    header)
        prefill_request = original_request_data.copy()
        # set max_tokens = 1 so the prefill instance only does the prefill
        prefill_request['max_tokens'] = 1
        route = "/v1/completions"
        # run prefill first, then decode
        try:
            prefill_response = await prefill(route, header, prefill_request)
            if isinstance(prefill_response, JSONResponse):
                return prefill_response
            logger.info("finish prefill start decode")
            response = await decode(route, header, original_request_data)
            logger.info("finish decode")
        except Exception as e:
            logger.error("Error occurred in disagg prefill proxy server, %s",
                         e)
            response = JSONResponse({"error": {"message": str(e)}})
        return response

    except Exception as e:
        logger.error("Error occurred in disagg prefill proxy server")
        logger.error(traceback.format_exc())
        return JSONResponse({"error": {"message": str(e)}}, status_code=500)


async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
    logger.info("vLLM Disaggregate Connector start %s %s", args,
                uvicorn_kwargs)
    app.state.port = args.port if args.port is not None else fastapi_port
    app.state.prefill_addr = (f"ipc://{args.prefill_addr}"
                              if args.prefill_addr is not None
                              else prefill_addr)
    app.state.decode_addr = (f"ipc://{args.decode_addr}"
                             if args.decode_addr is not None else decode_addr)
    logger.info(
        "start connect prefill_addr: %s decode_addr: %s zmq server port: %s",
        app.state.prefill_addr, app.state.decode_addr, app.state.port)

    def signal_handler(*_) -> None:
        # interrupt the server on SIGTERM while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, signal_handler)
    # init uvicorn server
    config = uvicorn.Config(app, host="0.0.0.0", port=app.state.port)
    server = uvicorn.Server(config)
    await server.serve()


if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = FlexibleArgumentParser(description="vLLM disagg zmq server.")
    parser.add_argument("--port",
                        type=int,
                        default=8000,
                        help="The fastapi server port")
    parser.add_argument("--prefill-addr",
                        type=str,
                        required=True,
                        help="The prefill address, IP:PORT")
    parser.add_argument("--decode-addr",
                        type=str,
                        required=True,
                        help="The decode address, IP:PORT")

    args = parser.parse_args()

    uvloop.run(run_disagg_connector(args))
35 changes: 35 additions & 0 deletions vllm/entrypoints/launcher.py

@@ -4,11 +4,15 @@
 from typing import Any
 
 import uvicorn
+import zmq
+import zmq.asyncio
+import zmq.devices
 from fastapi import FastAPI, Request, Response
 
 from vllm import envs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError
 from vllm.engine.multiprocessing import MQEngineDeadError
+from vllm.entrypoints.openai.connect_worker import worker_routine
 from vllm.logger import init_logger
 from vllm.utils import find_process_using_port

@@ -58,6 +62,37 @@ async def dummy_shutdown() -> None:
     return server.shutdown()
 

+async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
+    """Server routine"""
+    logger.info("zmq server start arg: %s, zmq_server_port: %d", arg,
+                zmq_server_port)
+    # separate zmq contexts cannot communicate over inproc, so use ipc
+    workers_addr = "ipc://workers"
+    clients_addr = f"ipc://127.0.0.1:{zmq_server_port}"
+    # prepare our context and sockets
+    context = zmq.asyncio.Context.instance()
+    try:
+        tasks = [
+            asyncio.create_task(worker_routine(workers_addr, app, context, i))
+            for i in range(100)
+        ]
+        logger.info("zmq tasks: %s", tasks)
+        # thread-safe proxy that creates its sockets in the background:
+        # https://pyzmq.readthedocs.io/en/latest/api/zmq.devices.html#proxy-devices
+        thread_proxy = zmq.devices.ThreadProxy(zmq.ROUTER, zmq.DEALER)
+        thread_proxy.bind_in(clients_addr)
+        thread_proxy.bind_out(workers_addr)
+        thread_proxy.start()
+        await asyncio.gather(*tasks)
+    except KeyboardInterrupt:
+        logger.info("ZMQ server interrupted")
+    except zmq.ZMQError as e:
+        logger.error("ZMQError: %s", e)
+    finally:
+        # we never get here in practice, but clean up anyhow
+        context.destroy(linger=0)
+
 def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
     """Adds handlers for fatal errors that should crash the server"""
 
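Aside: worker_routine, imported above from vllm.entrypoints.openai.connect_worker, is not included in this diff view. As a rough sketch only, a worker compatible with the frame protocol used by disagg_connector.py ([route, headers, body] in, [content_type, reply] out, with the ROUTER frontend prepending the client identity) could look like the following; the real implementation dispatches into the server's request handlers rather than echoing:

import json

import zmq
import zmq.asyncio
from fastapi import FastAPI


# sketch, not the PR's actual worker_routine; mirrors its call signature
async def worker_routine_sketch(workers_addr: str, app: FastAPI,
                                context: zmq.asyncio.Context,
                                worker_id: int) -> None:
    sock = context.socket(zmq.DEALER)
    sock.connect(workers_addr)
    while True:
        # the ROUTER frontend prepends the client identity frame
        identity, route, headers, body = await sock.recv_multipart()
        request = json.loads(body)
        # dispatching `request` to the handler registered for `route`
        # on `app` is elided here
        reply = {"echo": request}  # placeholder result
        await sock.send_multipart([
            identity,
            b"application/json",
            json.dumps(reply).encode(),
        ])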
8 changes: 7 additions & 1 deletion vllm/entrypoints/openai/api_server.py

@@ -33,7 +33,7 @@
 from vllm.engine.multiprocessing.engine import run_mp_engine
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import load_chat_template
-from vllm.entrypoints.launcher import serve_http
+from vllm.entrypoints.launcher import serve_http, serve_zmq
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                               validate_parsed_serve_args)
@@ -799,6 +799,12 @@ def signal_handler(*_) -> None:
     model_config = await engine_client.get_model_config()
     await init_app_state(engine_client, model_config, app.state, args)
 
+    zmq_server_port = args.zmq_server_port
+    if zmq_server_port is not None:
+        logger.info("asyncio.create_task Starting ZMQ server at port %d",
+                    zmq_server_port)
+        asyncio.create_task(serve_zmq(args, zmq_server_port, app))
+
     shutdown_task = await serve_http(
         app,
         host=args.host,
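Aside: asyncio.create_task keeps only a weak reference to its task, so a long-running task created this way can in principle be garbage-collected. A defensive variant (an illustration under that assumption, not what the diff currently does) would keep the reference on app.state:

# sketch: hold a strong reference so the ZMQ server task cannot be
# garbage-collected while the HTTP server runs
zmq_server_port = args.zmq_server_port
if zmq_server_port is not None:
    logger.info("Starting ZMQ server at port %d", zmq_server_port)
    app.state.zmq_server_task = asyncio.create_task(
        serve_zmq(args, zmq_server_port, app))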
5 changes: 5 additions & 0 deletions vllm/entrypoints/openai/cli_args.py

@@ -251,6 +251,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         default=False,
         help="If set to True, enable prompt_tokens_details in usage.")
 
+    parser.add_argument('--zmq-server-port',
+                        type=int,
+                        default=None,
+                        help='The port to serve the zmq server on.')
+
     return parser