[core] further polish memory profiling (vllm-project#12126)
Signed-off-by: youkaichao <[email protected]>
youkaichao authored and lckr committed Jan 19, 2025
1 parent c7ea31a commit 914bc14
Showing 3 changed files with 85 additions and 67 deletions.
26 changes: 12 additions & 14 deletions tests/test_utils.py
@@ -9,10 +9,10 @@
from vllm_test_utils import monitor

from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (FlexibleArgumentParser, PlaceholderModule,
StoreBoolean, bind_kv_cache, deprecate_kwargs,
get_open_port, memory_profiling, merge_async_iterators,
supports_kw)
from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
PlaceholderModule, StoreBoolean, bind_kv_cache,
deprecate_kwargs, get_open_port, memory_profiling,
merge_async_iterators, supports_kw)

from .utils import error_on_warning, fork_new_process_for_each_test

@@ -284,14 +284,13 @@ def test_memory_profiling():
# 512 MiB allocation outside of this instance
handle1 = lib.cudaMalloc(512 * 1024 * 1024)

baseline_memory_in_bytes = \
torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]
baseline_snapshot = MemorySnapshot()

# load weights

weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32)

weights_memory_in_bytes = 128 * 1024 * 1024 * 4 # 512 MiB
weights_memory = 128 * 1024 * 1024 * 4 # 512 MiB

def measure_current_non_torch():
free, total = torch.cuda.mem_get_info()
@@ -300,8 +299,8 @@ def measure_current_non_torch():
current_non_torch = current_used - current_torch
return current_non_torch

with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
weights_memory_in_bytes=weights_memory_in_bytes) as result, \
with memory_profiling(baseline_snapshot=baseline_snapshot,
weights_memory=weights_memory) as result, \
monitor(measure_current_non_torch) as monitored_values:
# make a memory spike, 1 GiB
spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)
@@ -316,13 +315,12 @@ def measure_current_non_torch():
assert measured_diff == 256 * 1024 * 1024

# Check that the memory usage is within 5% of the expected values
# 5% tolerance is caused by PyTorch caching allocator,
# we cannot control PyTorch's behavior of its internal buffers,
# 5% tolerance is caused by cuda runtime.
# we cannot control cuda runtime in the granularity of bytes,
# which causes a small error (<10 MiB in practice)
non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024) # noqa
torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024) # noqa
non_torch_ratio = result.non_torch_increase / (256 * 1024 * 1024) # noqa
assert abs(non_torch_ratio - 1) <= 0.05
assert abs(torch_peak_ratio - 1) <= 0.05
assert result.torch_peak_increase == 1024 * 1024 * 1024
del weights
lib.cudaFree(handle1)
lib.cudaFree(handle2)
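Taken together, the updated test reduces to the following usage pattern for the renamed API. A minimal sketch, assuming a CUDA device and the post-change `MemorySnapshot` / `memory_profiling` from `vllm.utils`; the tensor sizes and the exact-peak assertion mirror the test above.

```python
import torch

from vllm.utils import MemorySnapshot, memory_profiling

# Baseline taken before this instance allocates anything on the GPU;
# __post_init__ calls measure(), so no explicit measure() is needed.
baseline_snapshot = MemorySnapshot()

# Stand-in for model weights: 128 * 1024 * 1024 fp32 elements = 512 MiB.
weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32)
weights_memory = 128 * 1024 * 1024 * 4

with memory_profiling(baseline_snapshot=baseline_snapshot,
                      weights_memory=weights_memory) as result:
    # Simulate peak activation memory: a 1 GiB spike freed before exit.
    spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)
    del spike

# torch_peak_increase is tracked via allocated_bytes.all.peak, so the
# 1 GiB spike is reported exactly; non_torch_increase is only approximate
# because the CUDA runtime cannot be controlled at byte granularity.
assert result.torch_peak_increase == 1024 * 1024 * 1024
print(f"non-torch increase: {result.non_torch_increase / (1024 * 1024):.1f} MiB")
```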
95 changes: 56 additions & 39 deletions vllm/utils.py
@@ -1923,55 +1923,72 @@ def kill_process_tree(pid: int):
@dataclass
class MemorySnapshot:
"""Memory snapshot."""
torch_peak_in_bytes: int = 0
torch_memory_in_bytes: int = 0
torch_peak: int = 0
cuda_memory: int = 0
torch_memory: int = 0
non_torch_memory: int = 0
timestamp: float = 0.0
auto_measure: bool = True

def __post_init__(self):
if self.auto_measure:
self.measure()

def measure(self):
self.torch_peak_in_bytes = torch.cuda.max_memory_reserved()
# we measure the torch peak memory usage via allocated_bytes,
# rather than `torch.cuda.memory_reserved()` .
# After `torch.cuda.reset_peak_memory_stats()`,
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
# when we call `torch.cuda.empty_cache()` or OOM happens.
self.torch_peak = torch.cuda.memory_stats().get(
"allocated_bytes.all.peak", 0)

self.cuda_memory = torch.cuda.mem_get_info(
)[1] - torch.cuda.mem_get_info()[0]

# torch.cuda.memory_reserved() is how many bytes
# PyTorch gets from cuda (by calling cudaMalloc, etc.)
self.torch_memory_in_bytes = torch.cuda.memory_reserved()
# this is used to measure the non-torch memory usage
self.torch_memory = torch.cuda.memory_reserved()

self.non_torch_memory = self.cuda_memory - self.torch_memory
self.timestamp = time.time()

def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
"""support a - b"""
return MemorySnapshot(
torch_peak_in_bytes=self.torch_peak_in_bytes -
other.torch_peak_in_bytes,
torch_memory_in_bytes=self.torch_memory_in_bytes -
other.torch_memory_in_bytes,
timestamp=self.timestamp - other.timestamp)
torch_peak=self.torch_peak - other.torch_peak,
cuda_memory=self.cuda_memory - other.cuda_memory,
torch_memory=self.torch_memory - other.torch_memory,
non_torch_memory=self.non_torch_memory - other.non_torch_memory,
timestamp=self.timestamp - other.timestamp,
auto_measure=False,
)


@dataclass
class MemoryProfilingResult:
"""Memory profiling result.
""" # noqa
baseline_memory_in_bytes: int = 0
non_kv_cache_memory_in_bytes: int = 0
torch_peak_increase_in_bytes: int = 0
non_torch_increase_in_bytes: int = 0
weights_memory_in_bytes: float = 0
"""Memory profiling result. All numbers are in bytes.
"""
non_kv_cache_memory: int = 0
torch_peak_increase: int = 0
non_torch_increase: int = 0
weights_memory: float = 0
before_create: MemorySnapshot = field(default_factory=MemorySnapshot)
before_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
after_profile: MemorySnapshot = field(default_factory=MemorySnapshot)
profile_time: float = 0.0


@contextlib.contextmanager
def memory_profiling(
baseline_memory_in_bytes: int, weights_memory_in_bytes: int
) -> Generator[MemoryProfilingResult, None, None]:
baseline_snapshot: MemorySnapshot,
weights_memory: int) -> Generator[MemoryProfilingResult, None, None]:
"""Memory profiling context manager.
baseline_memory_in_bytes: memory used by all the components other than
the current vLLM instance. It contains: memory used by other processes, memory
used by another vLLM instance in the same process, etc. It is usually measured
before the current vLLM instance initialize the device. And we assume it is
constant during the profiling of the current vLLM instance.
weights_memory_in_bytes: memory used by PyTorch when loading the model weights.
baseline_snapshot: the memory snapshot before the current vLLM instance.
weights_memory: memory used by PyTorch when loading the model weights.
Note that, before loading the model weights, we also initialize the device
and distributed environment, which may consume some memory. This part is not
included in the weights_memory_in_bytes because PyTorch does not control it.
included in the weights_memory because PyTorch does not control it.
The memory in one GPU can be classified into 3 categories:
1. memory used by anything other than the current vLLM instance.
@@ -2006,20 +2023,21 @@ def memory_profiling(
b. 2 GiB reserved for the peak activation tensors (category 2)
c. 1 GiB used by non-torch components (category 3)
The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`.
The memory used for loading weights (a.) is directly given from the argument `weights_memory`.
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.).
The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.).
(c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`),
subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`.
The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
""" # noqa
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

result = MemoryProfilingResult()

result.baseline_memory_in_bytes = baseline_memory_in_bytes
result.before_create = baseline_snapshot
# the part of memory used for holding the model weights
result.weights_memory_in_bytes = weights_memory_in_bytes
result.weights_memory = weights_memory

result.before_profile.measure()

@@ -2030,13 +2048,12 @@ def memory_profiling(

result.after_profile.measure()

diff = result.after_profile - result.before_profile
result.torch_peak_increase_in_bytes = diff.torch_peak_in_bytes
current_cuda_memory_bytes = torch.cuda.mem_get_info(
)[1] - torch.cuda.mem_get_info()[0]
result.non_torch_increase_in_bytes = current_cuda_memory_bytes - baseline_memory_in_bytes - weights_memory_in_bytes - diff.torch_memory_in_bytes # noqa
result.profile_time = diff.timestamp
result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes # noqa
diff_profile = result.after_profile - result.before_profile
diff_from_create = result.after_profile - result.before_create
result.torch_peak_increase = diff_profile.torch_peak
result.non_torch_increase = diff_from_create.non_torch_memory
result.profile_time = diff_profile.timestamp
result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa


# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
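As a cross-check on the accounting described in the docstring above, its worked example (1 GiB weights, 2 GiB peak activations, 1 GiB non-torch growth) reduces to the arithmetic below. Plain-Python sketch with illustrative numbers only; nothing here is measured.

```python
GiB = 1024 ** 3

# Illustrative figures from the docstring example above.
weights_memory = 1 * GiB        # (a.) passed in by the caller
torch_peak_increase = 2 * GiB   # (b.) delta of allocated_bytes.all.peak during profiling
non_torch_increase = 1 * GiB    # (c.) delta of non_torch_memory since before_create

# Each snapshot derives non_torch_memory as:
#   cuda_memory  = torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]
#   torch_memory = torch.cuda.memory_reserved()
#   non_torch_memory = cuda_memory - torch_memory
# and (c.) is the difference of that field between two snapshots.

non_kv_cache_memory = non_torch_increase + torch_peak_increase + weights_memory
assert non_kv_cache_memory == 4 * GiB  # everything that cannot go to the KV cache
```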
31 changes: 17 additions & 14 deletions vllm/worker/worker.py
@@ -21,7 +21,8 @@
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.utils import GiB_bytes, bind_kv_cache, memory_profiling
from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache,
memory_profiling)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
@@ -137,7 +138,8 @@ def init_device(self) -> None:
_check_if_gpu_supports_dtype(self.model_config.dtype)
gc.collect()
torch.cuda.empty_cache()
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
torch.cuda.reset_peak_memory_stats()
self.baseline_snapshot = MemorySnapshot()
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
@@ -192,18 +194,17 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:

# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
with memory_profiling(baseline_memory_in_bytes=total_gpu_memory -
self.init_gpu_memory,
weights_memory_in_bytes=self.model_runner.
model_memory_usage) as result:
with memory_profiling(
self.baseline_snapshot,
weights_memory=self.model_runner.model_memory_usage) as result:
self.model_runner.profile_run()

self._assert_memory_footprint_increased_during_profiling()

memory_for_current_instance = total_gpu_memory * \
self.cache_config.gpu_memory_utilization
available_kv_cache_memory = (memory_for_current_instance -
result.non_kv_cache_memory_in_bytes)
result.non_kv_cache_memory)

# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
@@ -226,11 +227,11 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
f"({self.cache_config.gpu_memory_utilization:.2f})"
f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n"
"model weights take "
f"{(result.weights_memory_in_bytes / GiB_bytes):.2f}GiB;"
f"{(result.weights_memory / GiB_bytes):.2f}GiB;"
" non_torch_memory takes "
f"{(result.non_torch_increase_in_bytes / GiB_bytes):.2f}GiB;"
f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;"
" PyTorch activation peak memory takes "
f"{(result.torch_peak_increase_in_bytes / GiB_bytes):.2f}GiB;"
f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;"
" the rest of the memory reserved for KV Cache is "
f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")

@@ -246,11 +247,13 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
def _assert_memory_footprint_increased_during_profiling(self):
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
free_gpu_memory, _ = torch.cuda.mem_get_info()
assert self.init_gpu_memory - free_gpu_memory > 0, (
free_gpu_memory, total = torch.cuda.mem_get_info()
cuda_memory = total - free_gpu_memory
assert self.baseline_snapshot.cuda_memory < cuda_memory, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
f"Initial used memory {self.baseline_snapshot.cuda_memory}, "
f"currently used memory {cuda_memory}. "
f"This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")

def initialize_cache(self, num_gpu_blocks: int,
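Downstream of `result.non_kv_cache_memory`, the block computation in `determine_num_available_blocks` boils down to the arithmetic below. A rough sketch with made-up sizes: `cache_block_size` stands in for the per-block byte size the worker derives from the cache config, and the real code also clamps the block count at zero.

```python
GiB = 1024 ** 3

total_gpu_memory = 80 * GiB          # illustrative 80 GiB device
gpu_memory_utilization = 0.90        # cache_config.gpu_memory_utilization
non_kv_cache_memory = 12 * GiB       # weights + torch peak increase + non-torch increase
cache_block_size = 2 * 1024 * 1024   # hypothetical bytes per KV cache block

memory_for_current_instance = total_gpu_memory * gpu_memory_utilization
available_kv_cache_memory = memory_for_current_instance - non_kv_cache_memory

num_gpu_blocks = max(int(available_kv_cache_memory // cache_block_size), 0)
print(f"KV cache budget: {available_kv_cache_memory / GiB:.2f} GiB "
      f"-> {num_gpu_blocks} blocks of {cache_block_size} bytes")
```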
