Add initial support for Prometheus metric collection #95

Merged (9 commits, Dec 6, 2023)
21 changes: 21 additions & 0 deletions serve/mlc_serve/api/handler.py
@@ -1,6 +1,8 @@
import time
import uuid
import json
import os

from http import HTTPStatus
from typing import Annotated, AsyncIterator, List

@@ -61,6 +63,25 @@ def _get_sampling_params(request: ChatCompletionRequest) -> SamplingParams:
return sampling_params


@router.get("/metrics")
def metrics():
    # See https://prometheus.github.io/client_python/multiprocess/ for why we need this.
    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
        from prometheus_client import (
            CollectorRegistry,
            generate_latest,
            multiprocess,
        )
        from starlette.responses import Response

        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)

        return Response(content=generate_latest(registry))
    else:
        return {}


@router.post("/v1/chat/completions")
async def request_completion(
request: ChatCompletionRequest,
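
For reference, once the server is running, the new endpoint can be exercised with a plain HTTP GET. A minimal sketch, assuming the server listens on localhost and was started with --port 8000 (both placeholder values), using the requests package:

import requests

# Fetch the Prometheus text exposition output from the new endpoint.
resp = requests.get("http://localhost:8000/metrics")
print(resp.text)  # e.g. lines such as e2e_latency_bucket{le="0.5"} ...
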
1 change: 1 addition & 0 deletions serve/mlc_serve/engine/base.py
@@ -251,6 +251,7 @@ class RequestState:
sampling_params: SamplingParams
stopping_criteria: StoppingCriteria
debug_options: DebugOptions
arrival_timestamp: float
is_ended: bool = False
validation_err: Optional[ValidationError] = None

43 changes: 43 additions & 0 deletions serve/mlc_serve/engine/metrics.py
@@ -0,0 +1,43 @@
from .metrics_labels import *
from prometheus_client import Counter, Histogram, Gauge


class PrometheusMetrics:
    def __init__(self):
        self.counters = {}
        self.histograms = {}
        self.gauges = {}

        for label in [NUM_CACHE_EVICTONS]:
            self.counters[label] = Counter(label, label)

        buckets_e2e_lat = (0.5, 2.5, 4.5, 6.5, 8.5, 10.5, 12.5, 14.5, 16.5, 18.5)
        buckets_ttft = (0.1, 0.3, 0.6, 0.9, 1.2, 1.5, 1.8, 2.1, 2.4, 2.7, 3.0)
        buckets_batched_prefill_tokens = (500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000, 6000, 7000, 8000)
        buckets_batched_decode_tokens = (1, 10, 30, 50, 75, 100, 125, 150, 175, 200, 250, 300)

Member:
For the future, I think tracking memory and compute utilization would be helpful. With my naive understanding, this might be tricky since it may require per-GPU tracking.

Member Author:
If there is a tool to get such information, we can certainly do that. But for memory utilization (as opposed to KV cache utilization), due to memory profiling and KV cache pre-allocation we always use 90% of the available VRAM.
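
Not part of this change, but to make the above concrete: a rough sketch of what per-GPU memory tracking could look like with a labeled Gauge. It assumes the optional pynvml package; the metric name gpu_memory_utilization and the helper update_gpu_memory_metrics are hypothetical.

import pynvml
from prometheus_client import Gauge

# Hypothetical gauge with one time series per GPU index.
GPU_MEMORY_UTILIZATION = Gauge(
    "gpu_memory_utilization", "Fraction of VRAM currently in use", ["gpu"]
)

def update_gpu_memory_metrics():
    # Query NVML for every visible device and record used / total VRAM.
    pynvml.nvmlInit()
    try:
        for i in range(pynvml.nvmlDeviceGetCount()):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
            GPU_MEMORY_UTILIZATION.labels(gpu=str(i)).set(mem.used / mem.total)
    finally:
        pynvml.nvmlShutdown()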

        for label, buckets in [
            (E2E_LATENCY, buckets_e2e_lat),
Member:
How do we handle exceptions in these metrics? Are they included in the e2e latency buckets, for example?

Member Author:
We record a metric only when a request successfully reaches the point in the code where the metric is instrumented. For e2e latency, that's L165 in staging_engine_worker.py. So if an exception is raised, the request never reaches that point and the metric isn't recorded.

Member (@sunggg, Dec 5, 2023):
Should we count the number of each type of exception raised? That would be helpful for identifying abnormal behavior.

Member Author (@masahi, Dec 5, 2023):
That's possible if we have a particular exception in mind. We need to know where that exception could be raised and add a counter in the corresponding except block. This could be one of the follow-up items.
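
As a concrete follow-up illustration (not part of this PR), a specific exception type could be counted by incrementing a Counter inside the corresponding except block. The metric name and the wrapper below are hypothetical, and which exception to catch is left open.

from prometheus_client import Counter

# Hypothetical counter for a specific failure mode.
NUM_GENERATION_ERRORS = Counter(
    "num_generation_errors", "Number of exceptions raised during a generation step"
)

def guarded_step(worker):
    # Count the chosen exception type, then re-raise so existing handling is unchanged.
    try:
        return worker.step()
    except RuntimeError:
        NUM_GENERATION_ERRORS.inc()
        raise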

            (FIRST_TOKEN_LATENCY, buckets_ttft),
            (BATCHED_PREFILL_TOKENS, buckets_batched_prefill_tokens),
            (BATCHED_DECODE_TOKENS, buckets_batched_decode_tokens),
        ]:
            self.histograms[label] = Histogram(label, label, buckets=buckets)

        for label in [KV_CACHE_UTILIZATION]:
            self.gauges[label] = Gauge(label, label)

    def _lookup(self, metrics_dict, label):
        if label in metrics_dict:
            return metrics_dict[label]

        raise RuntimeError(f"No metric {label} found.")

    def counter(self, label: str):
        return self._lookup(self.counters, label)

    def histogram(self, label: str):
        return self._lookup(self.histograms, label)

    def gauge(self, label: str):
        return self._lookup(self.gauges, label)
6 changes: 6 additions & 0 deletions serve/mlc_serve/engine/metrics_labels.py
@@ -0,0 +1,6 @@
NUM_CACHE_EVICTONS = "num_cache_evictions"
E2E_LATENCY = "e2e_latency"
FIRST_TOKEN_LATENCY = "first_token_latency"
KV_CACHE_UTILIZATION = "kv_cache_utilization"
BATCHED_PREFILL_TOKENS = "batched_prefill_tokens"
BATCHED_DECODE_TOKENS = "batched_decode_tokens"
2 changes: 2 additions & 0 deletions serve/mlc_serve/engine/staging_engine.py
@@ -1,6 +1,7 @@
"""
An implementation of InferenceEngine that offloads the text generation loop to another worker process.
"""
import time
import logging
import multiprocessing
import queue
@@ -259,6 +260,7 @@ def _get_new_request_state(self, request: Request) -> RequestState:
debug_options=request.debug_options,
output_text="",
validation_err=validation_err,
arrival_timestamp=time.time(),
)

def _decode_last_output(self, state: RequestState) -> str:
49 changes: 40 additions & 9 deletions serve/mlc_serve/engine/staging_engine_worker.py
@@ -1,7 +1,7 @@
"""
The worker for StagingInferenceEngine
"""
import os
import time
import multiprocessing
import multiprocessing.synchronize
from collections import deque
@@ -12,12 +12,22 @@
import structlog

from .base import FinishReason, RequestId, RequestState, ValidationError
from .model_module import DecodeRequest, ModelModule, PrefillRequest, SequenceId, TextGenerator, Tokenizer as TokenizerP
from .metrics import PrometheusMetrics
from .metrics_labels import *
from .model_module import (
DecodeRequest,
ModelModule,
PrefillRequest,
SequenceId,
TextGenerator,
Tokenizer as TokenizerP,
)
from ..model.base import ModelArtifactConfig
from ..logging_utils import configure_logging

LOG = structlog.stdlib.get_logger(__name__)


@dataclass
class ShutdownCommand:
pass
@@ -73,7 +83,6 @@ class GenerationLoopWorker:
stopped_requests: List[RequestState]
current_batch: Dict[RequestId, RequestState]


def __init__(
self,
model_module: ModelModule,
@@ -102,6 +111,9 @@ def __init__(

self.current_batch = dict[RequestId, RequestState]()

self.prom_metrics = PrometheusMetrics()
self.inv_kv_cache_size = 1.0 / self.cache_manager.get_kv_cache_size()

def add(self, request_states: list[RequestState]):
LOG.debug("GenerationLoopWorker", requests_states=request_states)
with self.queue_lock:
@@ -125,7 +137,6 @@ def add(self, request_states: list[RequestState]):
self.queue.extend(valid_states)
self.has_new_requests.notify_all()


def _cacnel_or_stop_request(
self, request_id: RequestId, requests: list[RequestState]
):
@@ -182,6 +193,8 @@ def step(self) -> GenerationLoopWorkerOutput:
)
)
self._remove_request_from_batch(state.request_id)
duration = time.time() - state.arrival_timestamp
self.prom_metrics.histogram(E2E_LATENCY).observe(duration)

for state in self.stopped_requests:
outputs.append(
@@ -226,7 +239,7 @@ def step(self) -> GenerationLoopWorkerOutput:
)
return result

requests = self._get_requests_to_process()
requests, is_prompt_batch = self._get_requests_to_process()
results = self.text_generator.generate(requests, self.cache_manager.get_cache())
LOG.debug("Finished text generation.")

@@ -264,13 +277,22 @@ def step(self) -> GenerationLoopWorkerOutput:
SequenceGenerationOutput(id=res.sequence_id, new_tokens=new_tokens)
)

if is_prompt_batch:
ttft = time.time() - state.arrival_timestamp
self.prom_metrics.histogram(FIRST_TOKEN_LATENCY).observe(ttft)

self.prom_metrics.gauge(KV_CACHE_UTILIZATION).set(
1.0 - self.cache_manager.get_free_space() * self.inv_kv_cache_size
)

LOG.debug("Finished state update and stopping criteria check.")

return result

def _adjust_batch(self):
with self.queue_lock:
while self.cache_manager.get_max_new_tokens() < 1:
self.prom_metrics.counter(NUM_CACHE_EVICTONS).inc()
request_to_remove = min(
self.current_batch.values(), key=lambda s: len(s.token_ids)
)
@@ -344,6 +366,8 @@ def _get_requests_to_process(self):
)

if is_prompt_batch:
prefill_token_counts = 0

for state in self.current_batch.values():
if state.next_start_position == 0:
requests.append(
@@ -354,10 +378,14 @@
sampling_params=state.sampling_params,
)
)
prefill_token_counts += len(state.token_ids)

self.prom_metrics.histogram(BATCHED_PREFILL_TOKENS).observe(prefill_token_counts)

LOG.debug(
"Creating prompt batch.",
num_requests=len(requests),
total_tokens=sum(len(r.token_ids) for r in requests),
total_tokens=prefill_token_counts,
)
else:
for state in self.current_batch.values():
@@ -372,9 +400,13 @@
self.cache_manager.extend(
seq_id, len(state.token_ids) - state.next_start_position
)
LOG.debug("Creating decode batch with %s requests.", len(requests))

return requests
decode_token_counts = len(requests)
self.prom_metrics.histogram(BATCHED_DECODE_TOKENS).observe(decode_token_counts)

LOG.debug("Creating decode batch with %s requests.", decode_token_counts)

return requests, is_prompt_batch

def _has_request_to_process(self) -> bool:
return len(self.queue) != 0 or len(self.current_batch) != 0
@@ -406,7 +438,6 @@ def run_generation_loop_worker(
enable_json_logs = False,
log_level="INFO",
):

configure_logging(enable_json_logs, log_level)
structlog.contextvars.bind_contextvars(**contextvars)

4 changes: 2 additions & 2 deletions serve/mlc_serve/engine/sync_engine.py
@@ -2,12 +2,11 @@
An implementation of InferenceEngine that executes in the current process.
"""

import time
import logging
from typing import Deque, Set, Dict
from collections import deque
from sre_parse import Tokenizer
from threading import Condition, Lock
from uuid import uuid4

from .base import (
FinishReason,
@@ -384,6 +383,7 @@ def _get_new_request_state(self, request: Request) -> RequestState:
stopping_criteria=request.stopping_criteria,
debug_options=request.debug_options,
output_text="",
arrival_timestamp=time.time(),
)

def _decode_last_output(self, state: RequestState) -> str:
26 changes: 15 additions & 11 deletions serve/mlc_serve/run.py
@@ -1,5 +1,5 @@
import argparse
import logging.config
import tempfile
import os
import uvicorn
from pathlib import Path
@@ -89,16 +89,20 @@ def run_server():
    log_level = "DEBUG" if args.debug_logging else "INFO"
    configure_logging(enable_json_logs=True, log_level=log_level)

    engine = create_engine(args)
    connector = AsyncEngineConnector(engine)
    app = create_app(connector)
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        reload=False,
        access_log=False,
    )
    with tempfile.TemporaryDirectory() as temp_dir:
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = temp_dir

        engine = create_engine(args)
        connector = AsyncEngineConnector(engine)
        app = create_app(connector)

        uvicorn.run(
            app,
            host=args.host,
            port=args.port,
            reload=False,
            access_log=False,
        )


if __name__ == "__main__":
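
The run.py change above points PROMETHEUS_MULTIPROC_DIR at a temporary directory before the engine worker process is created. A minimal sketch of the mechanism this relies on (the counter name below is made up): in multiprocess mode each process writes its metric values to files under that directory, and the MultiProcessCollector used in the /metrics handler aggregates them at scrape time.

import os
import tempfile

# The environment variable must be set before prometheus_client is first imported.
tmp_dir = tempfile.mkdtemp()
os.environ["PROMETHEUS_MULTIPROC_DIR"] = tmp_dir

from prometheus_client import Counter

example = Counter("example_counter", "A hypothetical counter")
example.inc()

# Each process leaves behind .db files that MultiProcessCollector reads.
print(os.listdir(tmp_dir))  # e.g. ['counter_12345.db']
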
2 changes: 1 addition & 1 deletion serve/pyproject.toml
@@ -4,7 +4,7 @@ requires-python = ">=3.9"
description = "LLM Batch Inference server"
dynamic = ["version"]

dependencies = ["fastapi==0.103.1", "pydantic>=1.8.0"]
dependencies = ["fastapi==0.103.1", "pydantic>=1.8.0", "prometheus-client>=0.18.0"]

[project.optional-dependencies]
test = ["pytest~=7.4.2", "httpx_sse~=0.3.1", "pytest-timeout~=2.2.0"]