diff --git a/serve/mlc_serve/api/handler.py b/serve/mlc_serve/api/handler.py
index 1245553883..8457502c38 100644
--- a/serve/mlc_serve/api/handler.py
+++ b/serve/mlc_serve/api/handler.py
@@ -1,6 +1,8 @@
 import time
 import uuid
 import json
+import os
+
 from http import HTTPStatus
 from typing import Annotated, AsyncIterator, List

@@ -61,6 +63,25 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
     return sampling_params


+@router.get("/metrics")
+def metrics():
+    # See https://prometheus.github.io/client_python/multiprocess/ for why we need this.
+    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
+        from prometheus_client import (
+            CollectorRegistry,
+            generate_latest,
+            multiprocess,
+        )
+        from starlette.responses import Response
+
+        registry = CollectorRegistry()
+        multiprocess.MultiProcessCollector(registry)
+
+        return Response(content=generate_latest(registry))
+    else:
+        return {}
+
+
 @router.post("/v1/chat/completions")
 async def request_completion(
     request: ChatCompletionRequest,
diff --git a/serve/mlc_serve/engine/base.py b/serve/mlc_serve/engine/base.py
index be373bf48b..2eddf04878 100644
--- a/serve/mlc_serve/engine/base.py
+++ b/serve/mlc_serve/engine/base.py
@@ -251,6 +251,7 @@ class RequestState:
     sampling_params: SamplingParams
     stopping_criteria: StoppingCriteria
     debug_options: DebugOptions
+    arrival_timestamp: float
     is_ended: bool = False
     validation_err: Optional[ValidationError] = None

diff --git a/serve/mlc_serve/engine/metrics.py b/serve/mlc_serve/engine/metrics.py
new file mode 100644
index 0000000000..3e76485927
--- /dev/null
+++ b/serve/mlc_serve/engine/metrics.py
@@ -0,0 +1,43 @@
+from .metrics_labels import *
+from prometheus_client import Counter, Histogram, Gauge
+
+
+class PrometheusMetrics:
+    def __init__(self):
+        self.counters = {}
+        self.histograms = {}
+        self.gauges = {}
+
+        for label in [NUM_CACHE_EVICTIONS]:
+            self.counters[label] = Counter(label, label)
+
+        buckets_e2e_lat = (0.5, 2.5, 4.5, 6.5, 8.5, 10.5, 12.5, 14.5, 16.5, 18.5)
+        buckets_ttft = (0.1, 0.3, 0.6, 0.9, 1.2, 1.5, 1.8, 2.1, 2.4, 2.7, 3.0)
+        buckets_batched_prefill_tokens = (500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000, 6000, 7000, 8000)
+        buckets_batched_decode_tokens = (1, 10, 30, 50, 75, 100, 125, 150, 175, 200, 250, 300)
+
+        for label, buckets in [
+            (E2E_LATENCY, buckets_e2e_lat),
+            (FIRST_TOKEN_LATENCY, buckets_ttft),
+            (BATCHED_PREFILL_TOKENS, buckets_batched_prefill_tokens),
+            (BATCHED_DECODE_TOKENS, buckets_batched_decode_tokens),
+        ]:
+            self.histograms[label] = Histogram(label, label, buckets=buckets)
+
+        for label in [KV_CACHE_UTILIZATION]:
+            self.gauges[label] = Gauge(label, label)
+
+    def _lookup(self, metrics_dict, label):
+        if label in metrics_dict:
+            return metrics_dict[label]
+
+        raise RuntimeError(f"No metric {label} found.")
+
+    def counter(self, label: str):
+        return self._lookup(self.counters, label)
+
+    def histogram(self, label: str):
+        return self._lookup(self.histograms, label)
+
+    def gauge(self, label: str):
+        return self._lookup(self.gauges, label)
diff --git a/serve/mlc_serve/engine/metrics_labels.py b/serve/mlc_serve/engine/metrics_labels.py
new file mode 100644
index 0000000000..d90b3ee2cc
--- /dev/null
+++ b/serve/mlc_serve/engine/metrics_labels.py
@@ -0,0 +1,6 @@
+NUM_CACHE_EVICTIONS = "num_cache_evictions"
+E2E_LATENCY = "e2e_latency"
+FIRST_TOKEN_LATENCY = "first_token_latency"
+KV_CACHE_UTILIZATION = "kv_cache_utilization"
+BATCHED_PREFILL_TOKENS = "batched_prefill_tokens"
+BATCHED_DECODE_TOKENS = "batched_decode_tokens"
diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py
index 2ed2b54dcc..e60fb23a24 100644
--- a/serve/mlc_serve/engine/staging_engine.py
+++ b/serve/mlc_serve/engine/staging_engine.py
@@ -1,6 +1,7 @@
 """
 An implementation of InferenceEngine that offloads the text generation loop to another worker process.
 """
+import time
 import logging
 import multiprocessing
 import queue
@@ -259,6 +260,7 @@ def _get_new_request_state(self, request: Request) -> RequestState:
             debug_options=request.debug_options,
             output_text="",
             validation_err=validation_err,
+            arrival_timestamp=time.time(),
         )

     def _decode_last_output(self, state: RequestState) -> str:
diff --git a/serve/mlc_serve/engine/staging_engine_worker.py b/serve/mlc_serve/engine/staging_engine_worker.py
index 8546c3f75a..583aea8f89 100644
--- a/serve/mlc_serve/engine/staging_engine_worker.py
+++ b/serve/mlc_serve/engine/staging_engine_worker.py
@@ -1,7 +1,7 @@
 """
 The worker for StagingInferenceEngine
 """
-import os
+import time
 import multiprocessing
 import multiprocessing.synchronize
 from collections import deque
@@ -12,12 +12,22 @@
 import structlog

 from .base import FinishReason, RequestId, RequestState, ValidationError
-from .model_module import DecodeRequest, ModelModule, PrefillRequest, SequenceId, TextGenerator, Tokenizer as TokenizerP
+from .metrics import PrometheusMetrics
+from .metrics_labels import *
+from .model_module import (
+    DecodeRequest,
+    ModelModule,
+    PrefillRequest,
+    SequenceId,
+    TextGenerator,
+    Tokenizer as TokenizerP,
+)
 from ..model.base import ModelArtifactConfig
 from ..logging_utils import configure_logging

 LOG = structlog.stdlib.get_logger(__name__)

+
 @dataclass
 class ShutdownCommand:
     pass
@@ -73,7 +83,6 @@ class GenerationLoopWorker:
     stopped_requests: List[RequestState]
     current_batch: Dict[RequestId, RequestState]

-
     def __init__(
         self,
         model_module: ModelModule,
@@ -102,6 +111,9 @@ def __init__(

         self.current_batch = dict[RequestId, RequestState]()

+        self.prom_metrics = PrometheusMetrics()
+        self.inv_kv_cache_size = 1.0 / self.cache_manager.get_kv_cache_size()
+
     def add(self, request_states: list[RequestState]):
         LOG.debug("GenerationLoopWorker", requests_states=request_states)
         with self.queue_lock:
@@ -125,7 +137,6 @@ def add(self, request_states: list[RequestState]):
         self.queue.extend(valid_states)
         self.has_new_requests.notify_all()

-
     def _cacnel_or_stop_request(
         self, request_id: RequestId, requests: list[RequestState]
     ):
@@ -182,6 +193,8 @@ def step(self) -> GenerationLoopWorkerOutput:
                 )
             )
             self._remove_request_from_batch(state.request_id)
+            duration = time.time() - state.arrival_timestamp
+            self.prom_metrics.histogram(E2E_LATENCY).observe(duration)

         for state in self.stopped_requests:
             outputs.append(
@@ -226,7 +239,7 @@
             )
             return result

-        requests = self._get_requests_to_process()
+        requests, is_prompt_batch = self._get_requests_to_process()
         results = self.text_generator.generate(requests, self.cache_manager.get_cache())
         LOG.debug("Finished text generation.")

@@ -264,6 +277,14 @@
                 SequenceGenerationOutput(id=res.sequence_id, new_tokens=new_tokens)
             )

+        if is_prompt_batch:
+            ttft = time.time() - state.arrival_timestamp
+            self.prom_metrics.histogram(FIRST_TOKEN_LATENCY).observe(ttft)
+
+        self.prom_metrics.gauge(KV_CACHE_UTILIZATION).set(
+            1.0 - self.cache_manager.get_free_space() * self.inv_kv_cache_size
+        )
+
         LOG.debug("Finished state update and stopping criteria check.")

         return result
@@ -271,6 +292,7 @@ def step(self) -> GenerationLoopWorkerOutput:
     def _adjust_batch(self):
         with self.queue_lock:
             while self.cache_manager.get_max_new_tokens() < 1:
+                self.prom_metrics.counter(NUM_CACHE_EVICTIONS).inc()
                 request_to_remove = min(
                     self.current_batch.values(), key=lambda s: len(s.token_ids)
                 )
@@ -344,6 +366,8 @@

         if is_prompt_batch:
+            prefill_token_counts = 0
+
             for state in self.current_batch.values():
                 if state.next_start_position == 0:
                     requests.append(
@@ -354,10 +378,14 @@ def _get_requests_to_process(self):
                             sampling_params=state.sampling_params,
                         )
                     )
+                    prefill_token_counts += len(state.token_ids)
+
+            self.prom_metrics.histogram(BATCHED_PREFILL_TOKENS).observe(prefill_token_counts)
+
             LOG.debug(
                 "Creating prompt batch.",
                 num_requests=len(requests),
-                total_tokens=sum(len(r.token_ids) for r in requests),
+                total_tokens=prefill_token_counts,
             )
         else:
             for state in self.current_batch.values():
@@ -372,9 +400,13 @@ def _get_requests_to_process(self):
                 self.cache_manager.extend(
                     seq_id, len(state.token_ids) - state.next_start_position
                 )
-            LOG.debug("Creating decode batch with %s requests.", len(requests))
-        return requests
+            decode_token_counts = len(requests)
+            self.prom_metrics.histogram(BATCHED_DECODE_TOKENS).observe(decode_token_counts)
+
+            LOG.debug("Creating decode batch with %s requests.", decode_token_counts)
+
+        return requests, is_prompt_batch

     def _has_request_to_process(self) -> bool:
         return len(self.queue) != 0 or len(self.current_batch) != 0
@@ -406,7 +438,6 @@ def run_generation_loop_worker(
     enable_json_logs = False,
     log_level="INFO",
 ):
-
     configure_logging(enable_json_logs, log_level)
     structlog.contextvars.bind_contextvars(**contextvars)

diff --git a/serve/mlc_serve/engine/sync_engine.py b/serve/mlc_serve/engine/sync_engine.py
index ad3892423e..adc961c15b 100644
--- a/serve/mlc_serve/engine/sync_engine.py
+++ b/serve/mlc_serve/engine/sync_engine.py
@@ -2,12 +2,11 @@
 A implementation of InferenceEngine that executes in the current process.
""" +import time import logging from typing import Deque, Set, Dict from collections import deque -from sre_parse import Tokenizer from threading import Condition, Lock -from uuid import uuid4 from .base import ( FinishReason, @@ -384,6 +383,7 @@ def _get_new_request_state(self, request: Request) -> RequestState: stopping_criteria=request.stopping_criteria, debug_options=request.debug_options, output_text="", + arrival_timestamp=time.time(), ) def _decode_last_output(self, state: RequestState) -> str: diff --git a/serve/mlc_serve/run.py b/serve/mlc_serve/run.py index c2d6dd5b20..59bc7651ce 100644 --- a/serve/mlc_serve/run.py +++ b/serve/mlc_serve/run.py @@ -1,5 +1,5 @@ import argparse -import logging.config +import tempfile import os import uvicorn from pathlib import Path @@ -89,16 +89,20 @@ def run_server(): log_level = "DEBUG" if args.debug_logging else "INFO" configure_logging(enable_json_logs=True, log_level=log_level) - engine = create_engine(args) - connector = AsyncEngineConnector(engine) - app = create_app(connector) - uvicorn.run( - app, - host=args.host, - port=args.port, - reload=False, - access_log=False, - ) + with tempfile.TemporaryDirectory() as temp_dir: + os.environ["PROMETHEUS_MULTIPROC_DIR"] = temp_dir + + engine = create_engine(args) + connector = AsyncEngineConnector(engine) + app = create_app(connector) + + uvicorn.run( + app, + host=args.host, + port=args.port, + reload=False, + access_log=False, + ) if __name__ == "__main__": diff --git a/serve/pyproject.toml b/serve/pyproject.toml index 407438b02d..cb56f1bbf3 100644 --- a/serve/pyproject.toml +++ b/serve/pyproject.toml @@ -4,7 +4,7 @@ requires-python = ">=3.9" description = "LLM Batch Inference server" dynamic = ["version"] -dependencies = ["fastapi==0.103.1", "pydantic>=1.8.0"] +dependencies = ["fastapi==0.103.1", "pydantic>=1.8.0", "prometheus-client>=0.18.0"] [project.optional-dependencies] test = ["pytest~=7.4.2", "httpx_sse~=0.3.1", "pytest-timeout~=2.2.0"]