diff --git a/tests/test_utils.py b/tests/test_utils.py index c68d730af7f8a..d5dc4464e634d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,10 +9,10 @@ from vllm_test_utils import monitor from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.utils import (FlexibleArgumentParser, PlaceholderModule, - StoreBoolean, bind_kv_cache, deprecate_kwargs, - get_open_port, memory_profiling, merge_async_iterators, - supports_kw) +from vllm.utils import (FlexibleArgumentParser, MemorySnapshot, + PlaceholderModule, StoreBoolean, bind_kv_cache, + deprecate_kwargs, get_open_port, memory_profiling, + merge_async_iterators, supports_kw) from .utils import error_on_warning, fork_new_process_for_each_test @@ -284,14 +284,13 @@ def test_memory_profiling(): # 512 MiB allocation outside of this instance handle1 = lib.cudaMalloc(512 * 1024 * 1024) - baseline_memory_in_bytes = \ - torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0] + baseline_snapshot = MemorySnapshot() # load weights weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32) - weights_memory_in_bytes = 128 * 1024 * 1024 * 4 # 512 MiB + weights_memory = 128 * 1024 * 1024 * 4 # 512 MiB def measure_current_non_torch(): free, total = torch.cuda.mem_get_info() @@ -300,8 +299,8 @@ def measure_current_non_torch(): current_non_torch = current_used - current_torch return current_non_torch - with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes, - weights_memory_in_bytes=weights_memory_in_bytes) as result, \ + with memory_profiling(baseline_snapshot=baseline_snapshot, + weights_memory=weights_memory) as result, \ monitor(measure_current_non_torch) as monitored_values: # make a memory spike, 1 GiB spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32) @@ -316,13 +315,12 @@ def measure_current_non_torch(): assert measured_diff == 256 * 1024 * 1024 # Check that the memory usage is within 5% of the expected values - # 5% tolerance is caused by PyTorch caching allocator, - # we cannot control PyTorch's behavior of its internal buffers, + # 5% tolerance is caused by cuda runtime. + # we cannot control cuda runtime in the granularity of bytes, # which causes a small error (<10 MiB in practice) - non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024) # noqa - torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024) # noqa + non_torch_ratio = result.non_torch_increase / (256 * 1024 * 1024) # noqa assert abs(non_torch_ratio - 1) <= 0.05 - assert abs(torch_peak_ratio - 1) <= 0.05 + assert result.torch_peak_increase == 1024 * 1024 * 1024 del weights lib.cudaFree(handle1) lib.cudaFree(handle2) diff --git a/vllm/utils.py b/vllm/utils.py index 89ba119bb5e55..17bffd2846b46 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1923,36 +1923,57 @@ def kill_process_tree(pid: int): @dataclass class MemorySnapshot: """Memory snapshot.""" - torch_peak_in_bytes: int = 0 - torch_memory_in_bytes: int = 0 + torch_peak: int = 0 + cuda_memory: int = 0 + torch_memory: int = 0 + non_torch_memory: int = 0 timestamp: float = 0.0 + auto_measure: bool = True + + def __post_init__(self): + if self.auto_measure: + self.measure() def measure(self): - self.torch_peak_in_bytes = torch.cuda.max_memory_reserved() + # we measure the torch peak memory usage via allocated_bytes, + # rather than `torch.cuda.memory_reserved()` . + # After `torch.cuda.reset_peak_memory_stats()`, + # `torch.cuda.memory_reserved()` will keep growing, and only shrink + # when we call `torch.cuda.empty_cache()` or OOM happens. + self.torch_peak = torch.cuda.memory_stats().get( + "allocated_bytes.all.peak", 0) + + self.cuda_memory = torch.cuda.mem_get_info( + )[1] - torch.cuda.mem_get_info()[0] + # torch.cuda.memory_reserved() is how many bytes # PyTorch gets from cuda (by calling cudaMalloc, etc.) - self.torch_memory_in_bytes = torch.cuda.memory_reserved() + # this is used to measure the non-torch memory usage + self.torch_memory = torch.cuda.memory_reserved() + + self.non_torch_memory = self.cuda_memory - self.torch_memory self.timestamp = time.time() def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot": - """support a - b""" return MemorySnapshot( - torch_peak_in_bytes=self.torch_peak_in_bytes - - other.torch_peak_in_bytes, - torch_memory_in_bytes=self.torch_memory_in_bytes - - other.torch_memory_in_bytes, - timestamp=self.timestamp - other.timestamp) + torch_peak=self.torch_peak - other.torch_peak, + cuda_memory=self.cuda_memory - other.cuda_memory, + torch_memory=self.torch_memory - other.torch_memory, + non_torch_memory=self.non_torch_memory - other.non_torch_memory, + timestamp=self.timestamp - other.timestamp, + auto_measure=False, + ) @dataclass class MemoryProfilingResult: - """Memory profiling result. - """ # noqa - baseline_memory_in_bytes: int = 0 - non_kv_cache_memory_in_bytes: int = 0 - torch_peak_increase_in_bytes: int = 0 - non_torch_increase_in_bytes: int = 0 - weights_memory_in_bytes: float = 0 + """Memory profiling result. All numbers are in bytes. + """ + non_kv_cache_memory: int = 0 + torch_peak_increase: int = 0 + non_torch_increase: int = 0 + weights_memory: float = 0 + before_create: MemorySnapshot = field(default_factory=MemorySnapshot) before_profile: MemorySnapshot = field(default_factory=MemorySnapshot) after_profile: MemorySnapshot = field(default_factory=MemorySnapshot) profile_time: float = 0.0 @@ -1960,18 +1981,14 @@ class MemoryProfilingResult: @contextlib.contextmanager def memory_profiling( - baseline_memory_in_bytes: int, weights_memory_in_bytes: int -) -> Generator[MemoryProfilingResult, None, None]: + baseline_snapshot: MemorySnapshot, + weights_memory: int) -> Generator[MemoryProfilingResult, None, None]: """Memory profiling context manager. - baseline_memory_in_bytes: memory used by all the components other than - the current vLLM instance. It contains: memory used by other processes, memory - used by another vLLM instance in the same process, etc. It is usually measured - before the current vLLM instance initialize the device. And we assume it is - constant during the profiling of the current vLLM instance. - weights_memory_in_bytes: memory used by PyTorch when loading the model weights. + baseline_snapshot: the memory snapshot before the current vLLM instance. + weights_memory: memory used by PyTorch when loading the model weights. Note that, before loading the model weights, we also initialize the device and distributed environment, which may consume some memory. This part is not - included in the weights_memory_in_bytes because PyTorch does not control it. + included in the weights_memory because PyTorch does not control it. The memory in one GPU can be classified into 3 categories: 1. memory used by anything other than the current vLLM instance. @@ -2006,20 +2023,21 @@ def memory_profiling( b. 2 GiB reserved for the peak activation tensors (category 2) c. 1 GiB used by non-torch components (category 3) - The memory used for loading weights (a.) is directly given from the argument `weights_memory_in_bytes`. + The memory used for loading weights (a.) is directly given from the argument `weights_memory`. - The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` after profiling gives (b.). + The increase of `torch.cuda.memory_stats()["allocated_bytes.all.peak"]` during profiling gives (b.). - (c.) is tricky. We measure the total memory used in this GPU (`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`), - subtract the baseline memory, the memory used by the model weights, and diff of `torch.cuda.memory_reserved()`. + The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.). """ # noqa + gc.collect() + torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() result = MemoryProfilingResult() - result.baseline_memory_in_bytes = baseline_memory_in_bytes + result.before_create = baseline_snapshot # the part of memory used for holding the model weights - result.weights_memory_in_bytes = weights_memory_in_bytes + result.weights_memory = weights_memory result.before_profile.measure() @@ -2030,13 +2048,12 @@ def memory_profiling( result.after_profile.measure() - diff = result.after_profile - result.before_profile - result.torch_peak_increase_in_bytes = diff.torch_peak_in_bytes - current_cuda_memory_bytes = torch.cuda.mem_get_info( - )[1] - torch.cuda.mem_get_info()[0] - result.non_torch_increase_in_bytes = current_cuda_memory_bytes - baseline_memory_in_bytes - weights_memory_in_bytes - diff.torch_memory_in_bytes # noqa - result.profile_time = diff.timestamp - result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes # noqa + diff_profile = result.after_profile - result.before_profile + diff_from_create = result.after_profile - result.before_create + result.torch_peak_increase = diff_profile.torch_peak + result.non_torch_increase = diff_from_create.non_torch_memory + result.profile_time = diff_profile.timestamp + result.non_kv_cache_memory = result.non_torch_increase + result.torch_peak_increase + result.weights_memory # noqa # Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501 diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 43eeb287d64eb..29d62ddda3dc0 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -21,7 +21,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, SequenceGroupMetadata, SequenceGroupMetadataDelta) -from vllm.utils import GiB_bytes, bind_kv_cache, memory_profiling +from vllm.utils import (GiB_bytes, MemorySnapshot, bind_kv_cache, + memory_profiling) from vllm.worker.cache_engine import CacheEngine from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner @@ -137,7 +138,8 @@ def init_device(self) -> None: _check_if_gpu_supports_dtype(self.model_config.dtype) gc.collect() torch.cuda.empty_cache() - self.init_gpu_memory = torch.cuda.mem_get_info()[0] + torch.cuda.reset_peak_memory_stats() + self.baseline_snapshot = MemorySnapshot() else: raise RuntimeError( f"Not support device type: {self.device_config.device}") @@ -192,10 +194,9 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # Execute a forward pass with dummy inputs to profile the memory usage # of the model. - with memory_profiling(baseline_memory_in_bytes=total_gpu_memory - - self.init_gpu_memory, - weights_memory_in_bytes=self.model_runner. - model_memory_usage) as result: + with memory_profiling( + self.baseline_snapshot, + weights_memory=self.model_runner.model_memory_usage) as result: self.model_runner.profile_run() self._assert_memory_footprint_increased_during_profiling() @@ -203,7 +204,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: memory_for_current_instance = total_gpu_memory * \ self.cache_config.gpu_memory_utilization available_kv_cache_memory = (memory_for_current_instance - - result.non_kv_cache_memory_in_bytes) + result.non_kv_cache_memory) # Calculate the number of blocks that can be allocated with the # profiled peak memory. @@ -226,11 +227,11 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: f"({self.cache_config.gpu_memory_utilization:.2f})" f" = {(memory_for_current_instance / GiB_bytes):.2f}GiB\n" "model weights take " - f"{(result.weights_memory_in_bytes / GiB_bytes):.2f}GiB;" + f"{(result.weights_memory / GiB_bytes):.2f}GiB;" " non_torch_memory takes " - f"{(result.non_torch_increase_in_bytes / GiB_bytes):.2f}GiB;" + f"{(result.non_torch_increase / GiB_bytes):.2f}GiB;" " PyTorch activation peak memory takes " - f"{(result.torch_peak_increase_in_bytes / GiB_bytes):.2f}GiB;" + f"{(result.torch_peak_increase / GiB_bytes):.2f}GiB;" " the rest of the memory reserved for KV Cache is " f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.") @@ -246,11 +247,13 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: def _assert_memory_footprint_increased_during_profiling(self): # NOTE(woosuk): Here we assume that the other processes using the same # GPU did not change their memory usage during the profiling. - free_gpu_memory, _ = torch.cuda.mem_get_info() - assert self.init_gpu_memory - free_gpu_memory > 0, ( + free_gpu_memory, total = torch.cuda.mem_get_info() + cuda_memory = total - free_gpu_memory + assert self.baseline_snapshot.cuda_memory < cuda_memory, ( "Error in memory profiling. " - f"Initial free memory {self.init_gpu_memory}, current free memory" - f" {free_gpu_memory}. This happens when the GPU memory was " + f"Initial used memory {self.baseline_snapshot.cuda_memory}, " + f"currently used memory {cuda_memory}. " + f"This happens when the GPU memory was " "not properly cleaned up before initializing the vLLM instance.") def initialize_cache(self, num_gpu_blocks: int,