diff --git a/.github/workflows/mypy-type-check.yml b/.github/workflows/mypy-type-check.yml
index 03f013862..2a86b874f 100644
--- a/.github/workflows/mypy-type-check.yml
+++ b/.github/workflows/mypy-type-check.yml
@@ -6,7 +6,7 @@ on:
   push:
     branches: [ develop, pre-release, master, main ]
   pull_request:
-    branches: [ develop, pre-release, master, main ]
+    branches: [ develop, pre-release, master, main, enhance-torch-compile ]

 jobs:
diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index ed2d7a5fb..42919d024 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -8,7 +8,7 @@ on:
     branches: [ develop, pre-release, master, main ]
     tags: v*
   pull_request:
-    branches: [ develop, pre-release, master, main ]
+    branches: [ develop, pre-release, master, main, enhance-torch-compile ]

 jobs:
   build:
diff --git a/tests/conftest.py b/tests/conftest.py
index 682e893a8..98c35817c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -587,11 +587,12 @@ def data_path(tmp_path_factory: pytest.TempPathFactory) -> dict[str, object]:
 # -------------------------------------------------------------------------------------


-def timed(fn: Callable) -> (Callable, float):
+def timed(fn: Callable, *args: object) -> tuple[Callable, float]:
     """A decorator that times the execution of a function.

     Args:
         fn (Callable): The function to be timed.
+        args (object): Arguments to be passed to the function.

     Returns:
         A tuple containing the result of the function
@@ -602,13 +603,13 @@ def timed(fn: Callable) -> (Callable, float):
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)
         start.record()
-        result = fn()
+        result = fn(*args)
         end.record()
         torch.cuda.synchronize()
         compile_time = start.elapsed_time(end) / 1000
     else:
         start = time.time()
-        result = fn()
+        result = fn(*args)
         end = time.time()
         compile_time = end - start
     return result, compile_time
diff --git a/tests/models/test_patch_predictor.py b/tests/models/test_patch_predictor.py
index ab59efc53..e72096be7 100644
--- a/tests/models/test_patch_predictor.py
+++ b/tests/models/test_patch_predictor.py
@@ -13,7 +13,8 @@
 import torch
 from click.testing import CliRunner

-from tiatoolbox import cli
+from tests.conftest import timed
+from tiatoolbox import cli, logger, rcParam
 from tiatoolbox.models import IOPatchPredictorConfig, PatchPredictor
 from tiatoolbox.models.architecture.vanilla import CNNModel
 from tiatoolbox.models.dataset import (
@@ -1226,3 +1227,54 @@ def test_cli_model_multiple_file_mask(remote_sample: Callable, tmp_path: Path) -
     assert tmp_path.joinpath("2.merged.npy").exists()
     assert tmp_path.joinpath("2.raw.json").exists()
     assert tmp_path.joinpath("results.json").exists()
+
+
+# -------------------------------------------------------------------------------------
+# torch.compile
+# -------------------------------------------------------------------------------------
+
+
+def test_patch_predictor_torch_compile(
+    sample_patch1: Path,
+    sample_patch2: Path,
+    tmp_path: Path,
+) -> None:
+    """Test torch.compile functionality.
+
+    Args:
+        sample_patch1 (Path): Path to sample patch 1.
+        sample_patch2 (Path): Path to sample patch 2.
+        tmp_path (Path): Path to temporary directory.
+
+ """ + torch_compile_enabled = rcParam["enable_torch_compile"] + torch._dynamo.reset() + rcParam["enable_torch_compile"] = True + # Test torch.compile with default mode + rcParam["torch_compile_mode"] = "default" + _, compile_time = timed( + test_patch_predictor_api, + sample_patch1, + sample_patch2, + tmp_path, + ) + logger.info("torch.compile default mode: %s", compile_time) + torch._dynamo.reset() + rcParam["torch_compile_mode"] = "reduce-overhead" + _, compile_time = timed( + test_patch_predictor_api, + sample_patch1, + sample_patch2, + tmp_path, + ) + logger.info("torch.compile reduce-overhead mode: %s", compile_time) + torch._dynamo.reset() + rcParam["torch_compile_mode"] = "max-autotune" + _, compile_time = timed( + test_patch_predictor_api, + sample_patch1, + sample_patch2, + tmp_path, + ) + logger.info("torch.compile max-autotune mode: %s", compile_time) + torch._dynamo.reset() + rcParam["enable_torch_compile"] = torch_compile_enabled diff --git a/tests/test_utils.py b/tests/test_utils.py index faab9ac17..cf76028aa 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -13,6 +13,7 @@ import numpy as np import pandas as pd import pytest +import torch from PIL import Image from requests import HTTPError from shapely.geometry import Polygon @@ -21,6 +22,7 @@ from tiatoolbox import utils from tiatoolbox.annotation.storage import SQLiteStore from tiatoolbox.models.architecture import fetch_pretrained_weights +from tiatoolbox.models.architecture.utils import compile_model from tiatoolbox.utils import misc from tiatoolbox.utils.exceptions import FileNotSupportedError from tiatoolbox.utils.transforms import locsize2bounds @@ -1819,3 +1821,19 @@ def test_patch_pred_store_persist_ext(tmp_path: pytest.TempPathFactory) -> None: # check correct error is raised if coordinates are missing with pytest.raises(ValueError, match="coordinates"): misc.dict_to_store(patch_output, (1.0, 1.0)) + + +def test_torch_compile_already_compiled() -> None: + """Test that torch_compile does not recompile a model that is already compiled.""" + # Create a simple model + model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10)) + + # Compile the model + compiled_model = compile_model(model) + + # Compile the model again + recompiled_model = compile_model(compiled_model) + + # Check that the recompiled model + # is the same as the original compiled model + assert compiled_model == recompiled_model diff --git a/tiatoolbox/models/architecture/__init__.py b/tiatoolbox/models/architecture/__init__.py index 9ac1bbd82..6fac9b08b 100644 --- a/tiatoolbox/models/architecture/__init__.py +++ b/tiatoolbox/models/architecture/__init__.py @@ -150,6 +150,7 @@ def get_pretrained_model( model.load_state_dict(saved_state_dict, strict=True) # ! + io_info = info["ioconfig"] creator = locate(f"tiatoolbox.models.engine.{io_info['class']}") diff --git a/tiatoolbox/models/architecture/utils.py b/tiatoolbox/models/architecture/utils.py index cefeca1c3..94f970df8 100644 --- a/tiatoolbox/models/architecture/utils.py +++ b/tiatoolbox/models/architecture/utils.py @@ -2,10 +2,60 @@ from __future__ import annotations +import sys +from typing import Callable + import numpy as np import torch from torch import nn +from tiatoolbox import logger + + +def compile_model( + model: nn.Module | None = None, + *, + mode: str = "default", + disable: bool = False, +) -> Callable: + """A decorator to compile a model using torch-compile. + + Args: + model (torch.nn.Module): + Model to be compiled. 
+        mode (str):
+            Mode to be used for torch.compile. Available modes are
+            `default`, `reduce-overhead`, `max-autotune`, and
+            `max-autotune-no-cudagraphs`.
+        disable (bool):
+            If True, torch.compile is disabled and the model is returned unchanged.
+
+    Returns:
+        Callable:
+            Compiled model.
+
+    """
+    if disable:
+        return model
+
+    # This check will be removed when torch.compile is supported in Python 3.12+
+    if sys.version_info >= (3, 12):  # pragma: no cover
+        logger.warning(
+            "torch.compile is currently not supported in Python 3.12+.",
+        )
+        return model
+
+    if isinstance(
+        model,
+        torch._dynamo.eval_frame.OptimizedModule,  # skipcq: PYL-W0212 # noqa: SLF001
+    ):
+        logger.warning(
+            "The model is already compiled.",
+        )
+        return model
+
+    return torch.compile(model, mode=mode, disable=disable)
+

 def centre_crop(
     img: np.ndarray | torch.tensor,
diff --git a/tiatoolbox/models/engine/patch_predictor.py b/tiatoolbox/models/engine/patch_predictor.py
index 0bc570e69..2aede1393 100644
--- a/tiatoolbox/models/engine/patch_predictor.py
+++ b/tiatoolbox/models/engine/patch_predictor.py
@@ -11,8 +11,9 @@
 import torch
 import tqdm

-from tiatoolbox import logger
+from tiatoolbox import logger, rcParam
 from tiatoolbox.models.architecture import get_pretrained_model
+from tiatoolbox.models.architecture.utils import compile_model
 from tiatoolbox.models.dataset.classification import PatchDataset, WSIPatchDataset
 from tiatoolbox.models.engine.semantic_segmentor import IOSegmentorConfig
 from tiatoolbox.utils import misc, save_as_json
@@ -250,7 +251,13 @@ def __init__(
         self.ioconfig = ioconfig  # for storing original
         self._ioconfig = None  # for storing runtime
-        self.model = model
+        self.model = (
+            compile_model(  # for runtime, such as after wrapping with nn.DataParallel
+                model,
+                mode=rcParam["torch_compile_mode"],
+                disable=not rcParam["enable_torch_compile"],
+            )
+        )
         self.pretrained_model = pretrained_model
         self.batch_size = batch_size
         self.num_loader_worker = num_loader_workers
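
Usage sketch (illustrative, not part of the diff): assuming the `enable_torch_compile` and `torch_compile_mode` rcParam keys introduced on this branch are registered in `tiatoolbox.rcParam`, the new `compile_model` helper can be exercised on its own as follows; the model architecture and input sizes here are arbitrary.

    import torch
    from torch import nn

    from tiatoolbox import rcParam
    from tiatoolbox.models.architecture.utils import compile_model

    # Assumed rcParam keys: both names are taken from the diff above,
    # not from a released tiatoolbox version.
    rcParam["enable_torch_compile"] = True
    rcParam["torch_compile_mode"] = "reduce-overhead"

    model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))

    # compile_model returns the model unchanged when disabled (or on
    # Python 3.12+) and a torch._dynamo OptimizedModule wrapper otherwise.
    compiled = compile_model(
        model,
        mode=rcParam["torch_compile_mode"],
        disable=not rcParam["enable_torch_compile"],
    )

    out = compiled(torch.randn(4, 10))  # compilation is triggered on first call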