Skip to content

Commit

Permalink
[BugFix] Overhaul async request cancellation (vllm-project#7111)
Browse files Browse the repository at this point in the history
  • Loading branch information
njhill authored and kylesayrs committed Aug 17, 2024
1 parent 30648f1 commit 23fb40b
Show file tree
Hide file tree
Showing 11 changed files with 226 additions and 226 deletions.
9 changes: 5 additions & 4 deletions tests/async_engine/api_server_async_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""vllm.entrypoints.api_server with some extra logging for testing."""
from typing import Any, Dict
from typing import Any, Dict, Iterable

import uvicorn
from fastapi.responses import JSONResponse, Response
Expand All @@ -18,9 +18,10 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._num_aborts = 0

async def abort(self, request_id: str) -> None:
await super().abort(request_id)
self._num_aborts += 1
async def _engine_abort(self, request_ids: Iterable[str]):
ids = list(request_ids)
self._num_aborts += len(ids)
await super()._engine_abort(ids)

def testing_stats(self) -> Dict[str, Any]:
return {"num_aborted_requests": self._num_aborts}
Expand Down
25 changes: 12 additions & 13 deletions tests/async_engine/test_request_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,23 @@ async def test_request_tracker():
stream_1 = tracker.add_request("1")
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
new, aborted = tracker.get_new_and_aborted_requests()
assert not tracker.new_requests_event.is_set()
assert len(new) == 1
assert new[0]["request_id"] == "1"
assert not finished
assert not aborted
assert not stream_1.finished

stream_2 = tracker.add_request("2")
stream_3 = tracker.add_request("3")
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
new, aborted = tracker.get_new_and_aborted_requests()
assert not tracker.new_requests_event.is_set()
assert len(new) == 2
assert new[0]["request_id"] == "2"
assert new[1]["request_id"] == "3"
assert not finished
assert not aborted
assert not stream_2.finished
assert not stream_3.finished

Expand All @@ -36,19 +36,19 @@ async def test_request_tracker():
assert not tracker.new_requests_event.is_set()

tracker.abort_request("1")
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert "1" in finished
new, aborted = tracker.get_new_and_aborted_requests()
assert len(aborted) == 1
assert "1" in aborted
assert not new
assert stream_1.finished

stream_4 = tracker.add_request("4")
tracker.abort_request("4")
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert "4" in finished
new, aborted = tracker.get_new_and_aborted_requests()
assert len(aborted) == 1
assert "4" in aborted
assert not new
assert stream_4.finished

Expand All @@ -57,10 +57,9 @@ async def test_request_tracker():
tracker.process_request_output(
RequestOutput("2", "output", [], [], [], finished=True))
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
new, aborted = tracker.get_new_and_aborted_requests()
assert not tracker.new_requests_event.is_set()
assert len(finished) == 1
assert "2" in finished
assert not aborted
assert len(new) == 1
assert new[0]["request_id"] == "5"
assert stream_2.finished
Expand Down
5 changes: 3 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import socket
import sys
from functools import partial
from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
Tuple, TypeVar)

Expand Down Expand Up @@ -37,11 +38,11 @@ async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
yield f"item from iterator {idx}"
await asyncio.sleep(0.1)
except asyncio.CancelledError:
pass
print(f"iterator {idx} cancelled")

iterators = [mock_async_iterator(i) for i in range(3)]
merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
*iterators)
*iterators, is_cancelled=partial(asyncio.sleep, 0, result=False))

async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
async for idx, output in generator:
Expand Down
Loading

0 comments on commit 23fb40b

Please sign in to comment.