
[torch.compile] store inductor compiled Python file #12182

Merged 2 commits on Jan 19, 2025
80 changes: 58 additions & 22 deletions vllm/compilation/backends.py
@@ -25,23 +25,30 @@
logger = init_logger(__name__)


@dataclasses.dataclass
class InductorArtifact:
hash_str: str = ""
file_path: str = ""
Member Author:

Inside the file at `file_path` there is a `call` function; that is the ultimate function Inductor compiled, and it is what gets called at runtime.
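A minimal sketch (not part of the diff) of how one might inspect a stored artifact, assuming `file_path` points at the Python file Inductor wrote to its cache directory and that the generated wrapper is named `call`, as described above; the example path is hypothetical:

```python
# Sketch only: locate the generated `call` function inside the file that
# `InductorArtifact.file_path` points to.  The path in the usage line is made up.
def show_call_function(file_path: str) -> None:
    """Print where the Inductor-generated `call` function is defined."""
    with open(file_path) as f:
        for lineno, line in enumerate(f, start=1):
            if line.lstrip().startswith("def call("):
                print(f"{file_path}:{lineno}: {line.rstrip()}")
                return
    print(f"no `call` function found in {file_path}")

# show_call_function("/tmp/torchinductor_user/ab/output_code.py")  # hypothetical path
```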



class InductorHashCache:
"""
Disk format: a Python list of tuples, each tuple is
(runtime_shape, graph_index, hash_str)
(runtime_shape, graph_index, hash_str, file_path)
We use list of tuple for readability.

In-memory format: a defaultdict of dict, where the key is
runtime_shape, and the value is a dict of graph_index to hash_str.

The data is essentially `Dict[Optional[int], Dict[int, str]]`,
The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`,
we don't use json here because json doesn't support int as key.

TODO: better off-the-shelf solution to serialize the data?
"""

def __init__(self, cache_dir: str, disabled: bool = False):
self.cache: defaultdict = defaultdict(dict)
self.cache: Dict[Optional[int],
Dict[int, InductorArtifact]] = defaultdict(dict)
self.disabled = disabled
self.cache_dir = cache_dir
self.cache_file_path = os.path.join(cache_dir,
@@ -66,14 +73,25 @@ def deserialize(self, data: str):
# because it is a safe way to parse Python literals.
# do not use eval(), it is unsafe.
list_data = ast.literal_eval(data)
for runtime_shape, graph_index, hash_str in list_data:
self.cache[runtime_shape][graph_index] = hash_str
for item in list_data:
runtime_shape = item[0]
graph_index = item[1]
hash_str = item[2]
# for compatibility of old version,
# where we don't have file_path.
# NOTE: after running the new code, the file_path
# will be updated.
file_path = "" if len(item) == 3 else item[3]
self.cache[runtime_shape][graph_index] = InductorArtifact(
hash_str=hash_str, file_path=file_path)

def serialize(self) -> str:
data = []
for runtime_shape, graph_index_to_hash_str in self.cache.items():
for graph_index, hash_str in graph_index_to_hash_str.items():
data.append((runtime_shape, graph_index, hash_str))
for runtime_shape, value in self.cache.items():
for graph_index, inductor_artifact in value.items():
data.append(
(runtime_shape, graph_index, inductor_artifact.hash_str,
inductor_artifact.file_path))
printer = pprint.PrettyPrinter(indent=4)
return printer.pformat(data)
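To illustrate the disk format described in the class docstring, here is a small, self-contained round trip (not part of the diff; the hashes and paths are made up):

```python
# Not part of the diff: a round trip of the cache's on-disk format.
# serialize() uses pprint.pformat to write a Python literal, and
# deserialize() uses ast.literal_eval to read it back (never eval()).
import ast
import pprint

entries = [
    # (runtime_shape, graph_index, hash_str, file_path); a runtime_shape of
    # None presumably marks the general dynamic-shape graph, per the
    # Optional[int] key type.  Hashes and paths here are invented.
    (None, 0, "fabc123", "/tmp/torchinductor_user/ab/output_code.py"),
    (8, 0, "fdef456", "/tmp/torchinductor_user/cd/output_code.py"),
]

text = pprint.PrettyPrinter(indent=4).pformat(entries)  # what serialize() writes
assert ast.literal_eval(text) == entries                # what deserialize() parses
```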

@@ -90,13 +108,14 @@ def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
return runtime_shape in self.cache and graph_index in self.cache[
runtime_shape]

def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact:
if self.disabled:
raise KeyError("cannot read from disabled cache")
runtime_shape, graph_index = key
return self.cache[runtime_shape][graph_index]

def __setitem__(self, key: Tuple[Optional[int], int], value: str):
def __setitem__(self, key: Tuple[Optional[int], int],
value: InductorArtifact):
# setitem for disabled cache is fine, because we
# don't actually write to the disk
runtime_shape, graph_index = key
@@ -181,7 +200,8 @@ def wrap_inductor(graph: fx.GraphModule,
if (runtime_shape, graph_index) in cache_data:
# we compiled this graph before
# so we can directly lookup the compiled graph via hash
hash_str = cache_data[(runtime_shape, graph_index)]
inductor_artifact = cache_data[(runtime_shape, graph_index)]
hash_str = inductor_artifact.hash_str
if graph_index == 0:
# adds some info logging for the first graph
logger.info(
@@ -199,6 +219,7 @@
"Inductor cache lookup failed. Please remove"
f"the cache file {cache_data.cache_file_path} and try again." # noqa
)
inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
Collaborator:

Why do we only update the file_path for graph_index 0? Does this mean we only support the fullgraph mode of torch.compile?

Member Author:

This seems to be a misunderstanding: the `if graph_index == 0:` check only guards the logging; this line is executed for every piece of the graph.
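As background for the line above, the source file of any pure-Python callable can be recovered from its code object; a standalone sketch of the mechanism:

```python
# Standalone illustration: a pure-Python function records the file it was
# compiled from in its code object, which is exactly what the diff reads off
# the Inductor callable via current_callable.__code__.co_filename.
def some_function():
    return 42

print(some_function.__code__.co_filename)  # path of the .py file defining it
```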


# Inductor calling convention (function signature):
# f(list) -> tuple
@@ -224,19 +245,20 @@ def compiled_graph(*args):
# the assumption is that we don't have nested Inductor compilation.
# compiled_fx_graph_hash will only be called once, and we can hook
# it to get the hash of the compiled graph directly.
from torch._inductor.codecache import compiled_fx_graph_hash

inductor_artifact = InductorArtifact()
from torch._inductor.codecache import (FxGraphCache,
compiled_fx_graph_hash)
original_load = FxGraphCache.load

def hijack_load(*args, **kwargs):
inductor_compiled_graph = original_load(*args, **kwargs)
inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename # noqa
return inductor_compiled_graph

def hijack_compiled_fx_graph_hash(*args, **kwargs):
out = compiled_fx_graph_hash(*args, **kwargs)
# store the hash in the cache
nonlocal cache_data
cache_data[(runtime_shape, graph_index)] = out[0]
if graph_index == 0:
# adds some info logging for the first graph
logger.info("Cache the graph of shape %s for later use",
str(runtime_shape))
logger.debug("store the %s-th graph for shape %s via hash %s",
graph_index, str(runtime_shape), out[0])
inductor_artifact.hash_str = out[0]
return out

def _check_can_cache(*args, **kwargs):
@@ -255,6 +277,11 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
if not cache_data.disabled:
# compilation cache is enabled, patch several functions

# hijack to get the compiled graph itself
stack.enter_context(
patch("torch._inductor.codecache.FxGraphCache.load",
hijack_load))

# for hijacking the hash of the compiled graph
stack.enter_context(
patch("torch._inductor.codecache.compiled_fx_graph_hash",
@@ -275,7 +302,16 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
compiled_graph = compile_fx(graph,
example_inputs,
config_patches=current_config)

# store the inductor_artifact in the cache
cache_data[(runtime_shape, graph_index)] = inductor_artifact
if graph_index == 0:
# adds some info logging for the first graph
logger.info("Cache the graph of shape %s for later use",
str(runtime_shape))
logger.debug(
"store the %s-th graph for shape %s via hash %s from file %s",
graph_index, str(runtime_shape), inductor_artifact.hash_str,
inductor_artifact.file_path)
# after compiling the last graph, record the end time
if graph_index == num_graphs - 1:
now = time.time()
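The `hijack_load` / `hijack_compiled_fx_graph_hash` closures above are an instance of a general capture-via-patch pattern; here is a self-contained sketch of the same technique using `json.dumps` as a stand-in target (names unrelated to vLLM):

```python
# Not vLLM code: a self-contained sketch of the capture-via-patch pattern.
# The wrapper behaves exactly like the original function but stashes the
# return value into an enclosing scope, much like hijack_load does above.
import json
from contextlib import ExitStack
from unittest.mock import patch

captured = {}
original_dumps = json.dumps

def hijack_dumps(*args, **kwargs):
    out = original_dumps(*args, **kwargs)
    captured["last"] = out  # record the result for later inspection
    return out              # callers see no difference

with ExitStack() as stack:
    stack.enter_context(patch("json.dumps", hijack_dumps))
    json.dumps({"shape": 8})  # any call made inside the block is intercepted

print(captured["last"])  # '{"shape": 8}'
```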
13 changes: 2 additions & 11 deletions vllm/config.py
@@ -2862,17 +2862,8 @@ def model_post_init(self, __context: Any) -> None:
"vllm.unified_attention_with_output",
]
else:
# v0 can use full graph compilation without splitting,
# splitting is optional.
# right now we still need it. kv cache shape
# will be included in the graph if we don't split
# the graph.
# TODO: hide kv cache in static forward context
# so that inductor does not see it.
self.splitting_ops = [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
]
# v0 uses full graph compilation
self.splitting_ops = []
Member Author:

Now that the KV cache is completely hidden from torch.compile, changing the GPU memory utilization (and thus the KV cache shape) does not affect torch.compile, so different settings can share the same compilation cache.
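For readability, the post-PR branch from the hunk above can be consolidated into a plain function; note that `use_v1` is an assumed stand-in name for the real condition, which is not visible in this diff:

```python
# Sketch, paraphrased from the hunk above.  `use_v1` is an assumed name for
# the condition guarding the two branches; it does not appear in the diff.
from typing import List

def choose_splitting_ops(use_v1: bool) -> List[str]:
    if use_v1:
        # keep splitting the graph at the attention custom ops (context lines above)
        return [
            "vllm.unified_attention",
            "vllm.unified_attention_with_output",
        ]
    # v0 uses full graph compilation, so there is nothing to split on
    return []

assert choose_splitting_ops(use_v1=False) == []
```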


for k, v in self.inductor_passes.items():
if not isinstance(v, str):