Commit

add unittest/
sunggg committed Nov 20, 2023
1 parent fd173de commit deee08b
Showing 4 changed files with 512 additions and 0 deletions.
62 changes: 62 additions & 0 deletions serve/tests/unittest/test_api.py
@@ -0,0 +1,62 @@
import json

import pytest
from fastapi.testclient import TestClient
from httpx_sse import connect_sse
from mlc_serve.api import create_app
from mlc_serve.engine import AsyncEngineConnector, InferenceEngine
from mlc_serve.engine.dummy import DummyInferenceEngine


@pytest.fixture
def engine() -> InferenceEngine:
return DummyInferenceEngine()


@pytest.fixture
def client(engine):
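    # Wire the dummy engine into the FastAPI app via AsyncEngineConnector so the
    # tests exercise the HTTP layer without loading a real model.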
connector = AsyncEngineConnector(engine, engine_wait_timeout=0.1)
app = create_app(connector)
with TestClient(app) as client:
yield client


@pytest.mark.timeout(3)
def test_chat_completion(client):
response = client.post(
"/v1/chat/completions",
json={"model": "test", "messages": "test prompt", "max_tokens": 10},
)
assert response.status_code == 200
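    # DummyInferenceEngine emits the literal token " test" for each generated
    # position, so max_tokens=10 yields " test" repeated ten times.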
assert response.json()["choices"][0]["message"]["content"] == " test" * 10


@pytest.mark.timeout(3)
def test_stream_chat_completion(client):
data = []
with connect_sse(
client,
"POST",
"/v1/chat/completions",
json={
"model": "test",
"messages": "test prompt",
"max_tokens": 10,
"stream": True,
},
) as event_source:
for sse in event_source.iter_sse():
data.append(sse.data)

events = [json.loads(d) for d in data[:-1]]
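    # data[-1] is the literal "[DONE]" sentinel, so only the preceding payloads
    # are JSON completion chunks.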

assert events[0]["choices"][0]["delta"]["role"] == "assistant"
assert events[0]["choices"][0]["delta"]["content"] == ""

assert all(e["choices"][0]["delta"]["content"] == " test" for e in events[1:-1])

assert events[-1]["choices"][0]["delta"] == {}
assert events[-1]["choices"][0]["finish_reason"] == "length"

assert data[-1] == "[DONE]"

24 changes: 24 additions & 0 deletions serve/tests/unittest/test_dummy_inference_engine.py
@@ -0,0 +1,24 @@
from mlc_serve.engine import ChatMessage, Request, SamplingParams, StoppingCriteria
from mlc_serve.engine.dummy import DummyInferenceEngine


def test_single_request():
engine = DummyInferenceEngine()

request_id = "1"
engine.add(
[
Request(
request_id=request_id,
messages=[ChatMessage(role="user", content="test prompt")],
sampling_params=SamplingParams(temperature=1),
stopping_criteria=StoppingCriteria(max_tokens=20),
),
]
)

result = engine.step()
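    # A single step should yield exactly one output, tagged with the request id
    # and carrying no error.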

assert result.outputs[0].request_id == request_id
assert result.outputs[0].error is None
assert len(result.outputs) == 1
152 changes: 152 additions & 0 deletions serve/tests/unittest/test_engine_with_samplers.py
@@ -0,0 +1,152 @@
import torch

import argparse
import json
import random
import os

from mlc_llm import utils
from mlc_serve.engine import (
Request,
ChatMessage,
DebugOptions,
SamplingParams,
StoppingCriteria,
FinishReason,
get_engine_config
)
from mlc_serve.engine.staging_engine import StagingInferenceEngine
from mlc_serve.engine.sync_engine import SynchronousInferenceEngine
from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule

def create_engine(
    model_artifact_path,
    use_staging_engine,
    max_num_batched_tokens,
    max_input_len,
):
engine_config = get_engine_config({
"use_staging_engine": use_staging_engine,
"max_num_batched_tokens": max_num_batched_tokens,
"max_input_len": max_input_len,
# Use defaults for "min_decode_steps", "max_decode_steps", "prompt_allocate_ratio"
})

if use_staging_engine:
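        # The staging engine must be started here and stopped by the caller;
        # the synchronous engine below needs no explicit lifecycle calls.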
engine = StagingInferenceEngine(
tokenizer_module=HfTokenizerModule(model_artifact_path),
model_module_loader=PagedCacheModelModule,
model_module_loader_kwargs={
"model_artifact_path": model_artifact_path,
"engine_config": engine_config,
},
)
engine.start()
else:
        engine = SynchronousInferenceEngine(
            PagedCacheModelModule(
                model_artifact_path=model_artifact_path,
                engine_config=engine_config,
            )
        )
return engine

def create_request(idx, prompt, temp, max_tokens, stop, ignore_eos):
    return Request(
        request_id=str(idx),
        messages=[ChatMessage(role="user", content=prompt)],
        sampling_params=SamplingParams(
            temperature=temp,
        ),
        stopping_criteria=StoppingCriteria(
            max_tokens=max_tokens,
            stop_sequences=stop,
        ),
        debug_options=DebugOptions(ignore_eos=ignore_eos),
    )

def test_max_tokens(
model_artifact_path,
use_staging_engine,
max_num_batched_tokens=2560,
max_input_len=2560,
num_requests=5,
ignore_eos=False
):
prompt = "Write a merge sort program in Python."
engine = create_engine(
model_artifact_path,
use_staging_engine,
max_num_batched_tokens,
max_input_len,
)

    requests = [create_request(idx=str(n - 1), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=ignore_eos) for n in range(1, num_requests + 1)]
engine.add(requests)

generated = ["" for _ in range(num_requests)]

while engine.has_pending_requests():
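        # Drain the engine; every finished sequence must have consumed exactly
        # its max_tokens budget and therefore report FinishReason.Length.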
results = engine.step()
for res in results.outputs:
assert len(res.sequences) == 1
seq = res.sequences[0]

if seq.is_finished:
assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens
assert seq.finish_reason == FinishReason.Length
else:
generated[int(res.request_id)] += seq.delta

if use_staging_engine:
engine.stop()


def test_ignore_eos(
model_artifact_path,
use_staging_engine,
max_num_batched_tokens=2560,
max_input_len=2560,
num_requests=5,
):
prompt = "hi"
engine = create_engine(
model_artifact_path,
use_staging_engine,
max_num_batched_tokens,
max_input_len,
)
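    # Large token budgets (113..113+num_requests-1) make it likely the model would
    # emit EOS before the limit; ignore_eos=True forces generation to run to max_tokens.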
s = 113
requests = [create_request(idx=str(n-s), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=True) for n in range(s, s+num_requests)]
engine.add(requests)

generated = ["" for _ in range(num_requests)]

while engine.has_pending_requests():
results = engine.step()
for res in results.outputs:
assert len(res.sequences) == 1
seq = res.sequences[0]

if seq.is_finished:
assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens
assert seq.finish_reason == FinishReason.Length
else:
generated[int(res.request_id)] += seq.delta

if use_staging_engine:
engine.stop()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local-id", type=str, required=True)
parser.add_argument("--artifact-path", type=str, default="../../../dist")
args = parser.parse_args()
model_artifact_path = os.path.join(args.artifact_path, args.local_id)

test_max_tokens(model_artifact_path, use_staging_engine=True)
test_max_tokens(model_artifact_path, use_staging_engine=False)
test_ignore_eos(model_artifact_path, use_staging_engine=True)
test_ignore_eos(model_artifact_path, use_staging_engine=False)
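
A typical local invocation of this script, assuming a model has already been compiled into the dist/ directory (the local id below is only a placeholder):

    python serve/tests/unittest/test_engine_with_samplers.py --local-id <compiled-model-id>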