diff --git a/serve/benchmarks/benchmark_throughput.py b/serve/benchmarks/benchmark_throughput.py index 75c585f0cd..466c27ae5b 100644 --- a/serve/benchmarks/benchmark_throughput.py +++ b/serve/benchmarks/benchmark_throughput.py @@ -111,6 +111,7 @@ def main(args: argparse.Namespace): engine = LocalProcessInferenceEngine( model_module, max_batched_tokens=args.max_num_batched_tokens, + min_decode_steps=args.min_decode_steps, ) requests = sample_requests( @@ -154,6 +155,7 @@ def main(args: argparse.Namespace): parser.add_argument("--num-shards", type=int, default=1) parser.add_argument("--max-num-batched-tokens", type=int, default=-1) parser.add_argument("--max-input-len", type=int, default=-1) + parser.add_argument("--min-decode-steps", type=int, default=256) parser.add_argument( "--num-prompts", type=int, default=1000, help="Number of prompts to process." ) diff --git a/serve/mlc_serve/run.py b/serve/mlc_serve/run.py index ac137a84a4..c28e6d9bfc 100644 --- a/serve/mlc_serve/run.py +++ b/serve/mlc_serve/run.py @@ -34,6 +34,7 @@ def parse_args(): args.add_argument("--num-shards", type=int, default=1) args.add_argument("--max-num-batched-tokens", type=int, default=-1) args.add_argument("--max-input-len", type=int, default=-1) + args.add_argument("--min-decode-steps", type=int, default=256) args.add_argument("--debug-logging", action="store_true") parsed = args.parse_args() parsed.model, parsed.quantization = parsed.local_id.rsplit("-", 1) @@ -85,6 +86,7 @@ def run_server(): engine = LocalProcessInferenceEngine( model_module, max_batched_tokens=args.max_num_batched_tokens, + min_decode_steps=args.min_decode_steps, ) connector = AsyncEngineConnector(engine) app = create_app(connector)