Skip to content

Commit

Permalink
Add missing files for #235 (#239)
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi authored Mar 20, 2024
1 parent 4727147 commit 24f9545
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 0 deletions.
6 changes: 6 additions & 0 deletions serve/mlc_serve/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class ServingError(Exception):
    """Exception type that carries a single, required message.

    ``JSONModeError`` in this module derives from it.

    Args:
        message: Description of the failure. It is forwarded unchanged to
            ``Exception``, so it is available as ``str(error)`` and as
            ``error.args[0]``.
    """

    def __init__(self, message: str) -> None:
        # Explicit constructor so that a message is mandatory, unlike the
        # bare ``Exception`` constructor which accepts zero arguments.
        super().__init__(message)

class JSONModeError(ServingError):
    """``ServingError`` subclass that adds no behavior of its own.

    NOTE(review): presumably raised for JSON-mode failures such as an
    unusable JSON schema -- confirm against the raise sites.
    """
95 changes: 95 additions & 0 deletions serve/tests/test_error_recovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import argparse

from mlc_serve.engine import (
ChatMessage,
DebugOptions,
Request,
SamplingParams,
StoppingCriteria,
)
from mlc_serve.utils import (
create_mlc_engine,
get_default_mlc_serve_argparser,
postproc_mlc_serve_args,
)


def _test_bad_json_schema(args: argparse.Namespace):
    """Check that the engine keeps running after a request that should fail.

    Two requests with the same prompt are queued one after the other:
    request ``"0"`` carries a ``json_schema`` that the test expects the engine
    to reject, request ``"1"`` uses plain sampling parameters. The engine is
    then stepped until it reports no pending requests, and the collected step
    results are inspected.

    Args:
        args: Parsed mlc-serve command-line arguments. They are passed to
            ``create_mlc_engine``; ``args.use_staging_engine`` decides whether
            the engine is explicitly stopped afterwards.

    Raises:
        AssertionError: If at most one step result was produced, or if the
            first output of the first step has sequences or has no error.
    """
    engine = create_mlc_engine(args)
    vocab_size = engine.model_artifact_config.vocab_size

    sampling_params = SamplingParams(
        temperature=0.0,
        vocab_size=vocab_size,
    )

    sampling_params_json = SamplingParams(
        temperature=0.0,
        vocab_size=vocab_size,
        json_schema={
            "schema": {
                "$defs": {
                    "Person": {
                        "type": "object",
                        "properties": {"name": {"type": "string"}},
                    },
                    "Family": {
                        "type": "object",
                        "properties": {
                            "last_name": {"type": "string"},
                            "members": {
                                "type": "array",
                                "items": {"$ref": "#/$defs/Person"},
                            },
                        },
                    },
                },
            }
        },
    )
    prompt = "Hello, my name is"

    results = []
    try:
        # The first request should fail, but the engine should be able to
        # continue processing the second request. Each request is submitted
        # through its own ``engine.add`` call, in this order.
        for request_id, params in (
            ("0", sampling_params_json),
            ("1", sampling_params),
        ):
            engine.add(
                [
                    Request(
                        request_id=request_id,
                        messages=[ChatMessage(role="user", content=prompt)],
                        sampling_params=params,
                        stopping_criteria=StoppingCriteria(
                            max_tokens=20, stop_sequences=None
                        ),
                        debug_options=DebugOptions(prompt=prompt),
                    )
                ]
            )

        while engine.has_pending_requests():
            results.append(engine.step())
    finally:
        # Stop the staging engine even when adding or stepping raised, so a
        # failing run does not leave it running.
        if args.use_staging_engine:
            engine.stop()

    # More than one step result is the signal used here that the engine went
    # on to process the second request.
    assert len(results) > 1

    # The first request fails with an empty output and an error message.
    # Checked separately so a failure shows which condition did not hold.
    first_output = results[0].outputs[0]
    assert len(first_output.sequences) == 0
    assert first_output.error is not None


if __name__ == "__main__":
    # Run the check as a script using the standard mlc-serve command-line
    # options.
    parser = get_default_mlc_serve_argparser("test engine")
    args = parser.parse_args()
    # The return value is ignored, so this presumably adjusts ``args`` in
    # place -- TODO confirm against mlc_serve.utils.
    postproc_mlc_serve_args(args)

    _test_bad_json_schema(args)

0 comments on commit 24f9545

Please sign in to comment.