[torch.compile] store inductor compiled Python file #12182
```diff
@@ -25,23 +25,30 @@
 logger = init_logger(__name__)
 
 
+@dataclasses.dataclass
+class InductorArtifact:
+    hash_str: str = ""
+    file_path: str = ""
+
+
 class InductorHashCache:
     """
     Disk format: a Python list of tuples, each tuple is
-    (runtime_shape, graph_index, hash_str)
+    (runtime_shape, graph_index, hash_str, file_path)
     We use list of tuple for readability.
 
     In-memory format: a defaultdict of dict, where the key is
     runtime_shape, and the value is a dict of graph_index to hash_str.
 
-    The data is essentially `Dict[Optional[int], Dict[int, str]]`,
+    The data is essentially `Dict[Optional[int], Dict[int, InductorArtifact]]`,
     we don't use json here because json doesn't support int as key.
 
     TODO: better off-the-shelf solution to serialize the data?
     """
 
     def __init__(self, cache_dir: str, disabled: bool = False):
-        self.cache: defaultdict = defaultdict(dict)
+        self.cache: Dict[Optional[int],
+                         Dict[int, InductorArtifact]] = defaultdict(dict)
         self.disabled = disabled
         self.cache_dir = cache_dir
         self.cache_file_path = os.path.join(cache_dir,
```
```diff
@@ -66,14 +73,25 @@ def deserialize(self, data: str):
         # because it is a safe way to parse Python literals.
         # do not use eval(), it is unsafe.
         list_data = ast.literal_eval(data)
-        for runtime_shape, graph_index, hash_str in list_data:
-            self.cache[runtime_shape][graph_index] = hash_str
+        for item in list_data:
+            runtime_shape = item[0]
+            graph_index = item[1]
+            hash_str = item[2]
+            # for compatibility of old version,
+            # where we don't have file_path.
+            # NOTE: after running the new code, the file_path
+            # will be updated.
+            file_path = "" if len(item) == 3 else item[3]
+            self.cache[runtime_shape][graph_index] = InductorArtifact(
+                hash_str=hash_str, file_path=file_path)
 
     def serialize(self) -> str:
         data = []
-        for runtime_shape, graph_index_to_hash_str in self.cache.items():
-            for graph_index, hash_str in graph_index_to_hash_str.items():
-                data.append((runtime_shape, graph_index, hash_str))
+        for runtime_shape, value in self.cache.items():
+            for graph_index, inductor_artifact in value.items():
+                data.append(
+                    (runtime_shape, graph_index, inductor_artifact.hash_str,
+                     inductor_artifact.file_path))
         printer = pprint.PrettyPrinter(indent=4)
         return printer.pformat(data)
```
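A note on the serialization scheme used above: the cache file is just a `pprint`-formatted Python list of tuples, read back with `ast.literal_eval`. Here is a minimal, self-contained sketch of the same round-trip, including the backward-compatible handling of old 3-tuples; the hash strings and file path below are invented for illustration:

```python
# Sketch of the cache's disk round-trip, mirroring serialize()/deserialize().
# Hash strings and the file path are invented examples, not real values.
import ast
import pprint
from collections import defaultdict

data = [
    (None, 0, "hash_general_shape", "/tmp/torchinductor/ab/cabc.py"),
    (1, 0, "hash_shape_1"),  # old 3-tuple format, still accepted
]
# what serialize() writes to disk
text = pprint.PrettyPrinter(indent=4).pformat(data)

# what deserialize() rebuilds in memory
cache: defaultdict = defaultdict(dict)
for item in ast.literal_eval(text):  # safe literal parsing; never use eval()
    runtime_shape, graph_index, hash_str = item[0], item[1], item[2]
    file_path = "" if len(item) == 3 else item[3]  # tolerate the old format
    cache[runtime_shape][graph_index] = (hash_str, file_path)

assert cache[1][0] == ("hash_shape_1", "")
```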
@@ -90,13 +108,14 @@ def __contains__(self, key: Tuple[Optional[int], int]) -> bool: | |
return runtime_shape in self.cache and graph_index in self.cache[ | ||
runtime_shape] | ||
|
||
def __getitem__(self, key: Tuple[Optional[int], int]) -> str: | ||
def __getitem__(self, key: Tuple[Optional[int], int]) -> InductorArtifact: | ||
if self.disabled: | ||
raise KeyError("cannot read from disabled cache") | ||
runtime_shape, graph_index = key | ||
return self.cache[runtime_shape][graph_index] | ||
|
||
def __setitem__(self, key: Tuple[Optional[int], int], value: str): | ||
def __setitem__(self, key: Tuple[Optional[int], int], | ||
value: InductorArtifact): | ||
# setitem for disabled cache is fine, because we | ||
# don't actually write to the disk | ||
runtime_shape, graph_index = key | ||
|
```diff
@@ -181,7 +200,8 @@ def wrap_inductor(graph: fx.GraphModule,
     if (runtime_shape, graph_index) in cache_data:
         # we compiled this graph before
         # so we can directly lookup the compiled graph via hash
-        hash_str = cache_data[(runtime_shape, graph_index)]
+        inductor_artifact = cache_data[(runtime_shape, graph_index)]
+        hash_str = inductor_artifact.hash_str
         if graph_index == 0:
             # adds some info logging for the first graph
             logger.info(
@@ -199,6 +219,7 @@ def wrap_inductor(graph: fx.GraphModule,
                 "Inductor cache lookup failed. Please remove"
                 f"the cache file {cache_data.cache_file_path} and try again."  # noqa
             )
+        inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
```
**Review comment:** Why do we only update the `file_path` of …

**Reply:** Seems to be a misunderstanding.
```diff
 
         # Inductor calling convention (function signature):
         # f(list) -> tuple
```
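The `current_callable.__code__.co_filename` lookup above leans on a general Python fact: a function's code object records the filename it was compiled from, and Inductor writes its generated code to a real `.py` file before executing it. A self-contained illustration of that mechanism (the temp file only stands in for Inductor's generated output file):

```python
# Illustration of the __code__.co_filename lookup used above. A function
# compiled from a file remembers that file's path on its code object; the
# temp file merely mimics Inductor's generated output file.
import tempfile

src = "def current_callable(args):\n    return args\n"
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
    f.write(src)
    path = f.name

namespace: dict = {}
exec(compile(src, path, "exec"), namespace)  # like Inductor exec'ing its file
assert namespace["current_callable"].__code__.co_filename == path
```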
```diff
@@ -224,19 +245,20 @@ def compiled_graph(*args):
     # the assumption is that we don't have nested Inductor compilation.
     # compiled_fx_graph_hash will only be called once, and we can hook
     # it to get the hash of the compiled graph directly.
-    from torch._inductor.codecache import compiled_fx_graph_hash
 
+    inductor_artifact = InductorArtifact()
+    from torch._inductor.codecache import (FxGraphCache,
+                                           compiled_fx_graph_hash)
+    original_load = FxGraphCache.load
+
+    def hijack_load(*args, **kwargs):
+        inductor_compiled_graph = original_load(*args, **kwargs)
+        inductor_artifact.file_path = inductor_compiled_graph.current_callable.__code__.co_filename  # noqa
+        return inductor_compiled_graph
 
     def hijack_compiled_fx_graph_hash(*args, **kwargs):
         out = compiled_fx_graph_hash(*args, **kwargs)
-        # store the hash in the cache
-        nonlocal cache_data
-        cache_data[(runtime_shape, graph_index)] = out[0]
-        if graph_index == 0:
-            # adds some info logging for the first graph
-            logger.info("Cache the graph of shape %s for later use",
-                        str(runtime_shape))
-        logger.debug("store the %s-th graph for shape %s via hash %s",
-                     graph_index, str(runtime_shape), out[0])
+        inductor_artifact.hash_str = out[0]
         return out
 
     def _check_can_cache(*args, **kwargs):
```
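Both hijack functions follow the same capture-by-monkeypatching pattern: wrap a library function, record something about its result on the side, and delegate to the original so callers see no behavioral change. A minimal standalone sketch of the pattern (`Lib.load` is a stand-in, not the real `FxGraphCache.load`):

```python
# Standalone sketch of the hijack pattern: wrap, record, delegate.
# Lib.load is a stand-in for torch._inductor.codecache.FxGraphCache.load.
from contextlib import ExitStack
from unittest.mock import patch


class Lib:

    @staticmethod
    def load(x):
        return x * 2  # pretend this loads/compiles something expensive


captured = {}
original_load = Lib.load


def hijack_load(*args, **kwargs):
    out = original_load(*args, **kwargs)
    captured["result"] = out  # side channel, like inductor_artifact.file_path
    return out  # delegate unchanged


with ExitStack() as stack:
    stack.enter_context(patch.object(Lib, "load", hijack_load))
    assert Lib.load(21) == 42  # behavior unchanged for the caller

assert captured["result"] == 42  # but the result was captured on the way out
```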
```diff
@@ -255,6 +277,11 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
         if not cache_data.disabled:
             # compilation cache is enabled, patch several functions
 
+            # hijack to get the compiled graph itself
+            stack.enter_context(
+                patch("torch._inductor.codecache.FxGraphCache.load",
+                      hijack_load))
+
             # for hijacking the hash of the compiled graph
             stack.enter_context(
                 patch("torch._inductor.codecache.compiled_fx_graph_hash",
```
```diff
@@ -275,7 +302,16 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
         compiled_graph = compile_fx(graph,
                                     example_inputs,
                                     config_patches=current_config)
 
+        # store the inductor_artifact in the cache
+        cache_data[(runtime_shape, graph_index)] = inductor_artifact
+        if graph_index == 0:
+            # adds some info logging for the first graph
+            logger.info("Cache the graph of shape %s for later use",
+                        str(runtime_shape))
+        logger.debug(
+            "store the %s-th graph for shape %s via hash %s from file %s",
+            graph_index, str(runtime_shape), inductor_artifact.hash_str,
+            inductor_artifact.file_path)
         # after compiling the last graph, record the end time
         if graph_index == num_graphs - 1:
             now = time.time()
```
```diff
@@ -2862,17 +2862,8 @@ def model_post_init(self, __context: Any) -> None:
                 "vllm.unified_attention_with_output",
             ]
         else:
-            # v0 can use full graph compilation without splitting,
-            # splitting is optional.
-            # right now we still need it. kv cache shape
-            # will be included in the graph if we don't split
-            # the graph.
-            # TODO: hide kv cache in static forward context
-            # so that inductor does not see it.
-            self.splitting_ops = [
-                "vllm.unified_attention",
-                "vllm.unified_attention_with_output",
-            ]
+            # v0 uses full graph compilation
+            self.splitting_ops = []
```
**Review comment:** now that kv cache is completely hidden from …
```diff
 
         for k, v in self.inductor_passes.items():
             if not isinstance(v, str):
```
**Review comment:** Inside the `file_path`, we can find a `call` function. That is the ultimate function Inductor compiled and called.
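To check this for yourself, you can scan the file stored in `InductorArtifact.file_path` for that entry point. A small hypothetical helper (not part of this PR; it assumes `file_path` points at an Inductor-generated Python file):

```python
# Hypothetical helper, not part of the PR: scan an Inductor-generated file
# (InductorArtifact.file_path) for its `call` entry point, the ultimate
# function that Inductor compiled and invokes at runtime.
def find_call_definition(file_path: str) -> str:
    with open(file_path) as f:
        for lineno, line in enumerate(f, start=1):
            if line.lstrip().startswith("def call("):
                return f"{file_path}:{lineno}: {line.strip()}"
    raise ValueError(f"no `call` function found in {file_path}")
```

Pointing it at a cached `file_path` should land on a line like `def call(args):`, assuming the generated-code layout holds.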