
⚡️ Add torch.compile to PatchPredictor #776

Merged · 33 commits · Mar 19, 2024

Commits
eea228d
⚡️ Add `torch.compile` to `PatchPredictor`
Jan 26, 2024
9475ff5
🚨 Remove unused imports
Jan 26, 2024
d77be4c
Merge branch 'enhance-torch-compile' into enhance-torch-compile-patch…
Abdol Jan 26, 2024
77c3bc9
Merge branch 'enhance-torch-compile-patch-predictor' of https://githu…
Jan 26, 2024
74cd8d9
Merge branch 'enhance-torch-compile' into enhance-torch-compile-patch…
Abdol Jan 26, 2024
49473c5
🐛 Add `rcParam` import
Jan 26, 2024
e970d51
✅ Add `torch.compile` tests
Jan 29, 2024
fbb1e7f
✅ Change to a more generic test
Jan 30, 2024
4c2a102
📝 Update test docstring
Jan 30, 2024
8c51770
⚡️ Reset `TorchDynamo` when changing `torch.compile` mode
Feb 2, 2024
d8c78e0
Merge branch 'enhance-torch-compile' into enhance-torch-compile-patch…
shaneahmed Feb 2, 2024
a344b70
👷 Enable CI checks
Feb 16, 2024
9180ba3
👷 Add python package checks
Feb 16, 2024
a2512ef
Merge branch 'enhance-torch-compile' into enhance-torch-compile-patch…
Abdol Feb 27, 2024
0a82ed1
🚑 Disable `torch.compile` in Python 3.12+
Mar 4, 2024
661e25c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
d887d14
🔥 Remove `compile_model` decorator mode for now
Mar 4, 2024
174b2ad
🐛 Fix a bug where a compatibility warning is shown if torch.compile is…
Mar 6, 2024
abb7dff
🚸 Check if model is compiled before compiling
Mar 6, 2024
a14aa12
🚨 Disable protected member access check
Mar 15, 2024
7f8a2f9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 15, 2024
0def095
🚨 Fix `pre-commit.ci` linting
Mar 15, 2024
bd033df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 15, 2024
e87e62f
🚨 Fix another linter error
Mar 15, 2024
d3fe49c
Merge branch 'enhance-torch-compile-patch-predictor' of https://githu…
Mar 15, 2024
b0b201f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 15, 2024
a3c0300
📝 Add comment above version check for `torch.compile`
Mar 15, 2024
2f93768
Merge branch 'enhance-torch-compile' into enhance-torch-compile-patch…
Abdol Mar 15, 2024
879bb4e
🚨 Remove unnecessary noqa
Mar 19, 2024
a2ad0d5
Merge branch 'enhance-torch-compile-patch-predictor' of https://githu…
Mar 19, 2024
52e6d06
Merge branch 'enhance-torch-compile' into enhance-torch-compile-patch…
Abdol Mar 19, 2024
366e4fc
✅ Add test to check if model is already compiled
Mar 19, 2024
150678b
✅ Skip test coverage for the Python 3.12 check
Mar 19, 2024
2 changes: 1 addition & 1 deletion .github/workflows/mypy-type-check.yml
@@ -6,7 +6,7 @@ on:
  push:
    branches: [ develop, pre-release, master, main ]
  pull_request:
-   branches: [ develop, pre-release, master, main ]
+   branches: [ develop, pre-release, master, main, enhance-torch-compile ]

jobs:

2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
@@ -8,7 +8,7 @@ on:
    branches: [ develop, pre-release, master, main ]
    tags: v*
  pull_request:
-   branches: [ develop, pre-release, master, main ]
+   branches: [ develop, pre-release, master, main, enhance-torch-compile ]

jobs:
build:
7 changes: 4 additions & 3 deletions tests/conftest.py
@@ -587,11 +587,12 @@ def data_path(tmp_path_factory: pytest.TempPathFactory) -> dict[str, object]:
# -------------------------------------------------------------------------------------


-def timed(fn: Callable) -> (Callable, float):
+def timed(fn: Callable, *args: object) -> (Callable, float):
    """Time the execution of a function.

    Args:
        fn (Callable): The function to be timed.
+       args (object): Arguments to be passed to the function.

    Returns:
        A tuple containing the result of the function
@@ -602,13 +603,13 @@ def timed(fn: Callable) -> (Callable, float):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
-       result = fn()
+       result = fn(*args)
        end.record()
        torch.cuda.synchronize()
        compile_time = start.elapsed_time(end) / 1000
    else:
        start = time.time()
-       result = fn()
+       result = fn(*args)
        end = time.time()
        compile_time = end - start
    return result, compile_time
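
For context, the new `*args` parameter lets the timing helper forward arguments to the timed callable. A minimal sketch of the resulting call pattern, using a hypothetical `slow_add` stand-in:

```python
import time

from tests.conftest import timed  # the helper shown above


def slow_add(a: int, b: int) -> int:
    """Hypothetical workload used only for illustration."""
    time.sleep(0.1)
    return a + b


# `timed` forwards the positional arguments to the callable and returns
# the callable's result together with the elapsed time in seconds.
result, elapsed = timed(slow_add, 2, 3)
print(result, elapsed)  # 5, roughly 0.1 on the CPU timing path
```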
54 changes: 53 additions & 1 deletion tests/models/test_patch_predictor.py
@@ -13,7 +13,8 @@
import torch
from click.testing import CliRunner

-from tiatoolbox import cli
+from tests.conftest import timed
+from tiatoolbox import cli, logger, rcParam
from tiatoolbox.models import IOPatchPredictorConfig, PatchPredictor
from tiatoolbox.models.architecture.vanilla import CNNModel
from tiatoolbox.models.dataset import (
@@ -1226,3 +1227,54 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -
    assert tmp_path.joinpath("2.merged.npy").exists()
    assert tmp_path.joinpath("2.raw.json").exists()
    assert tmp_path.joinpath("results.json").exists()


# -------------------------------------------------------------------------------------
# torch.compile
# -------------------------------------------------------------------------------------


def test_patch_predictor_torch_compile(
    sample_patch1: Path,
    sample_patch2: Path,
    tmp_path: Path,
) -> None:
    """Test torch.compile functionality.

    Args:
        sample_patch1 (Path): Path to sample patch 1.
        sample_patch2 (Path): Path to sample patch 2.
        tmp_path (Path): Path to temporary directory.
    """
    torch_compile_enabled = rcParam["enable_torch_compile"]
    torch._dynamo.reset()
    rcParam["enable_torch_compile"] = True
    # Test torch.compile with default mode
    rcParam["torch_compile_mode"] = "default"
    _, compile_time = timed(
        test_patch_predictor_api,
        sample_patch1,
        sample_patch2,
        tmp_path,
    )
    logger.info("torch.compile default mode: %s", compile_time)
    torch._dynamo.reset()
    rcParam["torch_compile_mode"] = "reduce-overhead"
    _, compile_time = timed(
        test_patch_predictor_api,
        sample_patch1,
        sample_patch2,
        tmp_path,
    )
    logger.info("torch.compile reduce-overhead mode: %s", compile_time)
    torch._dynamo.reset()
    rcParam["torch_compile_mode"] = "max-autotune"
    _, compile_time = timed(
        test_patch_predictor_api,
        sample_patch1,
        sample_patch2,
        tmp_path,
    )
    logger.info("torch.compile max-autotune mode: %s", compile_time)
    torch._dynamo.reset()
    rcParam["enable_torch_compile"] = torch_compile_enabled
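
The repeated `torch._dynamo.reset()` calls are the reason commit 8c51770 exists: switching `torch.compile` modes without clearing TorchDynamo's caches can silently reuse the previously compiled graphs. A minimal standalone sketch of the same pattern, assuming only `torch` is installed:

```python
import torch


def toy(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1


# Compile and run once in the default mode.
compiled = torch.compile(toy, mode="default")
_ = compiled(torch.randn(8))

# Clear TorchDynamo's caches before switching modes; otherwise the
# artifacts compiled for "default" may be reused instead of recompiling.
torch._dynamo.reset()
compiled = torch.compile(toy, mode="reduce-overhead")
_ = compiled(torch.randn(8))
```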
1 change: 1 addition & 0 deletions tiatoolbox/models/architecture/__init__.py
@@ -150,6 +150,7 @@ def get_pretrained_model(
    model.load_state_dict(saved_state_dict, strict=True)

    # !

    io_info = info["ioconfig"]
    creator = locate(f"tiatoolbox.models.engine.{io_info['class']}")

50 changes: 50 additions & 0 deletions tiatoolbox/models/architecture/utils.py
@@ -2,10 +2,60 @@

from __future__ import annotations

import sys
from typing import Callable

import numpy as np
import torch
from torch import nn

from tiatoolbox import logger


def compile_model(
    model: nn.Module | None = None,
    *,
    mode: str,
    disable: bool,
) -> Callable:
    """Compile a model using torch.compile.

    Args:
        model (torch.nn.Module):
            Model to be compiled.
        mode (str):
            Mode to be used for torch.compile. Available modes are
            `default`, `reduce-overhead`, `max-autotune`, and
            `max-autotune-no-cudagraphs`.
        disable (bool):
            If True, torch.compile will be disabled.

    Returns:
        Callable:
            Compiled model.

    """
    if disable:
shaneahmed (Member): I do not think we need this variable. I think we should only call this function if not disable.

Abdol (Collaborator, Author): Thank you @shaneahmed. That could be done, too. However, I'm mirroring the PyTorch implementation, which includes a disable flag in the function (torch.compile).

        return model

    # This check will be removed when torch.compile is supported in Python 3.12+
    if sys.version_info >= (3, 12):
        logger.warning(
            "torch-compile is currently not supported in Python 3.12+.",
        )
        return model

    if isinstance(
        model,
        torch._dynamo.eval_frame.OptimizedModule,  # skipcq: PYL-W0212 # noqa: SLF001, E501
    ):
        logger.warning(
            "The model is already compiled.",
        )
        return model

    return torch.compile(model, mode=mode, disable=disable)
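
A short usage sketch of the new helper (illustrative only; the small `nn.Sequential` model is a placeholder, not a tiatoolbox architecture):

```python
import torch
from torch import nn

from tiatoolbox.models.architecture.utils import compile_model

net = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())

# On supported setups this returns torch.compile(net, mode=...); with
# disable=True, on Python 3.12+, or if net is already an
# OptimizedModule, the model is returned unchanged.
net = compile_model(net, mode="default", disable=False)

out = net(torch.randn(1, 3, 32, 32))
```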


def centre_crop(
    img: np.ndarray | torch.tensor,
11 changes: 9 additions & 2 deletions tiatoolbox/models/engine/patch_predictor.py
@@ -11,8 +11,9 @@
import torch
import tqdm

-from tiatoolbox import logger
+from tiatoolbox import logger, rcParam
from tiatoolbox.models.architecture import get_pretrained_model
+from tiatoolbox.models.architecture.utils import compile_model
from tiatoolbox.models.dataset.classification import PatchDataset, WSIPatchDataset
from tiatoolbox.models.engine.semantic_segmentor import IOSegmentorConfig
from tiatoolbox.utils import misc, save_as_json
@@ -250,7 +251,13 @@ def __init__(

        self.ioconfig = ioconfig  # for storing original
        self._ioconfig = None  # for storing runtime
-       self.model = model
+       self.model = (
+           compile_model(  # for runtime, such as after wrapping with nn.DataParallel
+               model,
+               mode=rcParam["torch_compile_mode"],
shaneahmed (Member): Do we need rcParam for this? We can just set this as a kwargs argument in the engines.

Abdol (Collaborator, Author): Thank you @shaneahmed. Having kwargs for torch_compile_mode would work, too. I suggest keeping rcParam for now until we implement it in the new engine design. Happy to discuss it in our next meeting.

+               disable=not rcParam["enable_torch_compile"],
+           )
+       )
        self.pretrained_model = pretrained_model
        self.batch_size = batch_size
        self.num_loader_worker = num_loader_workers
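
Taken together with the `rcParam` options, this means `torch.compile` can be toggled globally before constructing the engine. A minimal sketch, assuming the `resnet18-kather100k` pretrained alias is available:

```python
from tiatoolbox import rcParam
from tiatoolbox.models import PatchPredictor

# Set the flags before constructing the engine; __init__ wraps the
# model with compile_model() using these settings.
rcParam["enable_torch_compile"] = True
rcParam["torch_compile_mode"] = "default"  # or "reduce-overhead" / "max-autotune"

predictor = PatchPredictor(pretrained_model="resnet18-kather100k", batch_size=16)
```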