Commit

fix
sunggg committed Dec 10, 2023
1 parent bf7409c commit d697e46
Showing 1 changed file with 104 additions and 47 deletions.
151 changes: 104 additions & 47 deletions serve/tests/unittest/test_engine_with_samplers.py
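Most of the hunks below apply a single formatting convention: spaces around `=` in keyword arguments are dropped, trailing commas are added, and long call sites are wrapped one argument per line; the `__main__` block additionally wraps the artifact path in `pathlib.Path`. A minimal illustration of the pattern (the names `f`, `a`, and `b` are made up, not lines from this file):

    # before
    result = f(a = 1, b = 2)
    # after
    result = f(
        a=1,
        b=2,
    )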
@@ -1,32 +1,35 @@
import argparse
import os
+from pathlib import Path
from mlc_serve.engine import (
Request,
ChatMessage,
DebugOptions,
SamplingParams,
StoppingCriteria,
FinishReason,
-    get_engine_config
+    get_engine_config,
)
from mlc_serve.engine.staging_engine import StagingInferenceEngine
from mlc_serve.engine.sync_engine import SynchronousInferenceEngine
from mlc_serve.model.base import get_model_artifact_config
from mlc_serve.model.paged_cache_model import HfTokenizerModule, PagedCacheModelModule

-def create_engine(
-    model_artifact_path,
-    use_staging_engine,
-    max_num_sequences,
-    max_input_len,
-
-):
-    engine_config = get_engine_config({
-        "use_staging_engine": use_staging_engine,
-        "max_num_sequences": max_num_sequences,
-        "max_input_len": max_input_len,
-        # Use defaults for "min_decode_steps", "max_decode_steps"
-    })
+def create_engine(
+    model_artifact_path,
+    use_staging_engine,
+    max_num_sequences,
+    max_input_len,
+):
+    engine_config = get_engine_config(
+        {
+            "use_staging_engine": use_staging_engine,
+            "max_num_sequences": max_num_sequences,
+            "max_input_len": max_input_len,
+            # Use defaults for "min_decode_steps", "max_decode_steps"
+        }
+    )

if use_staging_engine:
engine = StagingInferenceEngine(
@@ -41,33 +44,33 @@ def create_engine(
else:
engine = SynchronousInferenceEngine(
PagedCacheModelModule(
-                model_artifact_path = model_artifact_path,
-                engine_config = engine_config,
-        ))
+                model_artifact_path=model_artifact_path,
+                engine_config=engine_config,
+            )
+        )
return engine


def create_request(idx, prompt, temp, max_tokens, stop, ignore_eos):
return Request(
-        request_id = str(idx),
-        messages = [ChatMessage(role="user", content=prompt)],
-        sampling_params = SamplingParams(
-            temperature=0.0,
-        ),
-        stopping_criteria = StoppingCriteria(
-            max_tokens=max_tokens,
-            stop_sequences=stop
+        request_id=str(idx),
+        messages=[ChatMessage(role="user", content=prompt)],
+        sampling_params=SamplingParams(
+            temperature=0.0,
        ),
-        debug_options = DebugOptions(ignore_eos = ignore_eos)
+        stopping_criteria=StoppingCriteria(max_tokens=max_tokens, stop_sequences=stop),
+        debug_options=DebugOptions(ignore_eos=ignore_eos),
)


def _test_max_tokens(
-    model_artifact_path,
-    use_staging_engine,
-    max_num_sequences=4,
-    max_input_len=512,
-    num_requests=5,
-    ignore_eos=False
-):
+    model_artifact_path,
+    use_staging_engine,
+    max_num_sequences=4,
+    max_input_len=512,
+    num_requests=5,
+    ignore_eos=False,
+):
prompt = "Write a merge sort program in Python."
engine = create_engine(
model_artifact_path,
@@ -76,7 +79,17 @@ def _test_max_tokens(
max_input_len,
)

-    requests = [create_request(idx=str(n-1), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=ignore_eos) for n in range(1, num_requests)]
+    requests = [
+        create_request(
+            idx=str(n - 1),
+            prompt=prompt,
+            temp=0,
+            max_tokens=n,
+            stop=None,
+            ignore_eos=ignore_eos,
+        )
+        for n in range(1, num_requests)
+    ]
engine.add(requests)

generated = ["" for _ in range(num_requests)]
@@ -88,7 +101,10 @@ def _test_max_tokens(
seq = res.sequences[0]

if seq.is_finished:
-                assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens
+                assert (
+                    seq.num_generated_tokens
+                    == requests[int(res.request_id)].stopping_criteria.max_tokens
+                )
assert seq.finish_reason == FinishReason.Length
else:
generated[int(res.request_id)] += seq.delta
@@ -98,12 +114,12 @@ def _test_max_tokens(


def _test_max_context_length(
-    model_artifact_path,
-    use_staging_engine,
-    max_num_sequences=4,
-    num_requests=5,
-    ignore_eos=False
-):
+    model_artifact_path,
+    use_staging_engine,
+    max_num_sequences=4,
+    num_requests=5,
+    ignore_eos=False,
+):
model_artifact_config = get_model_artifact_config(model_artifact_path)
max_context_length = model_artifact_config.max_context_length

@@ -115,7 +131,17 @@ def _test_max_context_length(
)
prompt = "hi " * (max_context_length - 15)

-    requests = [create_request(idx=str(n), prompt=prompt, temp=0, max_tokens=None, stop=None, ignore_eos=ignore_eos) for n in range(num_requests)]
+    requests = [
+        create_request(
+            idx=str(n),
+            prompt=prompt,
+            temp=0,
+            max_tokens=None,
+            stop=None,
+            ignore_eos=ignore_eos,
+        )
+        for n in range(num_requests)
+    ]
engine.add(requests)

generated = ["" for _ in range(num_requests)]
@@ -150,7 +176,17 @@ def _test_ignore_eos(
max_input_len,
)
s = 113
-    requests = [create_request(idx=str(n-s), prompt=prompt, temp=0, max_tokens=n, stop=None, ignore_eos=True) for n in range(s, s+num_requests)]
+    requests = [
+        create_request(
+            idx=str(n - s),
+            prompt=prompt,
+            temp=0,
+            max_tokens=n,
+            stop=None,
+            ignore_eos=True,
+        )
+        for n in range(s, s + num_requests)
+    ]
engine.add(requests)

generated = ["" for _ in range(num_requests)]
@@ -162,7 +198,10 @@ def _test_ignore_eos(
seq = res.sequences[0]

if seq.is_finished:
-                assert seq.num_generated_tokens == requests[int(res.request_id)].stopping_criteria.max_tokens
+                assert (
+                    seq.num_generated_tokens
+                    == requests[int(res.request_id)].stopping_criteria.max_tokens
+                )
assert seq.finish_reason == FinishReason.Length
else:
generated[int(res.request_id)] += seq.delta
@@ -188,7 +227,16 @@ def _test_stop(
ignore_eos = False
requests = []
for n, stop in enumerate(["\n", ["\n"], "\n\n", "!", ["n", "!"]]):
-        requests.append(create_request(idx=str(n), prompt=prompt, temp=0, max_tokens=300, stop=stop, ignore_eos=False))
+        requests.append(
+            create_request(
+                idx=str(n),
+                prompt=prompt,
+                temp=0,
+                max_tokens=300,
+                stop=stop,
+                ignore_eos=False,
+            )
+        )
engine.add(requests)

generated = ["" for _ in range(num_requests)]
@@ -200,12 +248,21 @@ def _test_stop(
seq = res.sequences[0]
req_id = int(res.request_id)
if seq.is_finished:
-                assert seq.finish_reason == FinishReason.Stop, f"{seq.finish_reason.name}"
+                assert (
+                    seq.finish_reason == FinishReason.Stop
+                ), f"{seq.finish_reason.name}"
assert not seq.delta
gen_txt = generated[req_id]

# stop token should appear only once in the gen text.
-                found = sum([gen_txt.count(str_stop) for str_stop in requests[req_id].stopping_criteria.stop_sequences])
+                found = sum(
+                    [
+                        gen_txt.count(str_stop)
+                        for str_stop in requests[
+                            req_id
+                        ].stopping_criteria.stop_sequences
+                    ]
+                )
assert found == 1, f"{gen_txt!r}, matches: {found}"
else:
generated[int(res.request_id)] += seq.delta
@@ -219,7 +276,7 @@ def _test_stop(
parser.add_argument("--local-id", type=str, required=True)
parser.add_argument("--artifact-path", type=str, default="dist")
args = parser.parse_args()
-    model_artifact_path = os.path.join(args.artifact_path, args.local_id)
+    model_artifact_path = Path(os.path.join(args.artifact_path, args.local_id))

_test_max_tokens(model_artifact_path, use_staging_engine=True)
_test_max_tokens(model_artifact_path, use_staging_engine=False)
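
For reference, a minimal sketch of how the helpers in this file fit together after the change; the local id below is a placeholder (any model built under the `--artifact-path` directory), and only calls that appear in the diff above are used:

    # assumes the helpers defined in this test module (create_engine, create_request)
    from pathlib import Path

    model_artifact_path = Path("dist") / "<local-id>"  # hypothetical local id
    engine = create_engine(
        model_artifact_path,
        use_staging_engine=True,
        max_num_sequences=4,
        max_input_len=512,
    )
    engine.add(
        [
            create_request(
                idx="0",
                prompt="Write a merge sort program in Python.",
                temp=0,
                max_tokens=16,
                stop=None,
                ignore_eos=False,
            )
        ]
    )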