-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
⚡️ Add torch.compile
to PatchPredictor
#776
Changes from 28 commits
eea228d
9475ff5
d77be4c
77c3bc9
74cd8d9
49473c5
e970d51
fbb1e7f
4c2a102
8c51770
d8c78e0
a344b70
9180ba3
a2512ef
0a82ed1
661e25c
d887d14
174b2ad
abb7dff
a14aa12
7f8a2f9
0def095
bd033df
e87e62f
d3fe49c
b0b201f
a3c0300
2f93768
879bb4e
a2ad0d5
52e6d06
366e4fc
150678b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,8 +11,9 @@ | |
import torch | ||
import tqdm | ||
|
||
from tiatoolbox import logger | ||
from tiatoolbox import logger, rcParam | ||
from tiatoolbox.models.architecture import get_pretrained_model | ||
from tiatoolbox.models.architecture.utils import compile_model | ||
from tiatoolbox.models.dataset.classification import PatchDataset, WSIPatchDataset | ||
from tiatoolbox.models.engine.semantic_segmentor import IOSegmentorConfig | ||
from tiatoolbox.utils import misc, save_as_json | ||
|
@@ -250,7 +251,13 @@ def __init__( | |
|
||
self.ioconfig = ioconfig # for storing original | ||
self._ioconfig = None # for storing runtime | ||
self.model = model | ||
self.model = ( | ||
compile_model( # for runtime, such as after wrapping with nn.DataParallel | ||
model, | ||
mode=rcParam["torch_compile_mode"], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need rcparam for this? We can just set this as kwargs argument in the engines. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you @shaneahmed. Having kwargs for |
||
disable=not rcParam["enable_torch_compile"], | ||
) | ||
) | ||
self.pretrained_model = pretrained_model | ||
self.batch_size = batch_size | ||
self.num_loader_worker = num_loader_workers | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do not think we need this variable. I think we should only call this function if
not disable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @shaneahmed. That could be done, too. However, I'm mirroring the PyTorch implementation, which includes a
disable
flag in the function (torch.compile).