
Commit

fix
Signed-off-by: Woosuk Kwon <[email protected]>
WoosukKwon committed Nov 21, 2024
1 parent 09cdeb3 commit c18a961
Showing 1 changed file with 34 additions and 28 deletions.
62 changes: 34 additions & 28 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1,3 +1,4 @@
import gc
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
@@ -515,7 +516,25 @@ def load_model(self) -> None:
logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30))

def _dummy_run(self, model: nn.Module, num_tokens: int) -> None:
@torch.inference_mode()
def _dummy_run(
self,
model: nn.Module,
num_tokens: int,
kv_caches: List[torch.Tensor],
) -> torch.Tensor:
with set_forward_context(None):
hidden_states = model(
input_ids=None,
positions=self.positions[:num_tokens],
kv_caches=kv_caches,
attn_metadata=None,
inputs_embeds=self.inputs_embeds[:num_tokens])
return hidden_states

def profile_run(self) -> None:
# TODO(woosuk): Profile the max memory usage of the encoder and
# the encoder cache.
# use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than by specializing on the value `None`.
# the `dtype` argument does not matter, and we use `float32` as
@@ -527,23 +546,17 @@ def _dummy_run(self, model: nn.Module, num_tokens: int) -> None:
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(self.num_attn_layers)
]
with set_forward_context(None): # noqa: SIM117
with set_compile_context(self.cudagraph_batch_sizes):
# Trigger compilation for general shape.
model(input_ids=None,
positions=self.positions,
kv_caches=dummy_kv_caches,
attn_metadata=None,
inputs_embeds=self.inputs_embeds)

@torch.inference_mode()
def profile_run(self) -> None:
# TODO(woosuk): Profile the max memory usage of the encoder and
# the encoder cache.
self._dummy_run(self.model, self.max_num_tokens)
with set_compile_context(self.cudagraph_batch_sizes):
# Trigger compilation for general shape.
hidden_states = self._dummy_run(self.model, self.max_num_tokens,
dummy_kv_caches)
logits = self.model.compute_logits(hidden_states, None)
logits = logits[:self.max_num_tokens]
# TODO(woosuk): Consider the memory usage of the sampler.
torch.cuda.synchronize()
del hidden_states, logits
gc.collect()

@torch.inference_mode()
def capture_model(self) -> None:
if not self.use_cuda_graph:
logger.warning(
@@ -554,18 +567,11 @@ def capture_model(self) -> None:
start_time = time.perf_counter()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]

with set_forward_context(None):
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for num_tokens in reversed(self.cudagraph_batch_sizes):
self.model(
input_ids=None,
positions=self.positions[:num_tokens],
kv_caches=self.kv_caches,
attn_metadata=None,
inputs_embeds=self.inputs_embeds[:num_tokens],
)
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for num_tokens in reversed(self.cudagraph_batch_sizes):
self._dummy_run(self.model, num_tokens, self.kv_caches)

end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
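For reference, both profile_run and capture_model now route their dummy forward passes through the shared _dummy_run helper, and the capture loop runs the largest shapes first so that the smaller CUDA graphs can reuse the memory pool reserved for the largest one. A minimal standalone sketch of that largest-first, shared-pool capture pattern in plain PyTorch follows; TinyModel, capture_graphs, and batch_sizes are illustrative names for this sketch and are not part of the vLLM code.

import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Stand-in model for the sketch; any CUDA-graph-safe module works."""

    def __init__(self, hidden: int = 64):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


@torch.inference_mode()
def capture_graphs(model: nn.Module,
                   batch_sizes=(256, 64, 8, 1),
                   hidden: int = 64):
    device = torch.device("cuda")
    model = model.to(device).eval()
    max_tokens = max(batch_sizes)
    # Static input buffer: graph replays always read from this fixed storage.
    static_input = torch.zeros(max_tokens, hidden, device=device)

    # Warm up on a side stream before capture, as the PyTorch CUDA graph
    # documentation advises.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        model(static_input)
    torch.cuda.current_stream().wait_stream(side)

    # One shared memory pool: capturing the largest shape first lets the
    # smaller captures reuse the blocks it already reserved.
    pool = torch.cuda.graph_pool_handle()
    graphs = {}
    for num_tokens in sorted(batch_sizes, reverse=True):
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):
            static_out = model(static_input[:num_tokens])
        graphs[num_tokens] = (graph, static_out)
    return static_input, graphs


# Usage: copy new data into static_input[:n], call graphs[n][0].replay(),
# then read the result out of graphs[n][1].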
