Feature proposal: TorchIO hub - A system to store and fetch transform objects for reproducibility #972

Open
tomvars opened this issue Sep 28, 2022 · 6 comments
Labels
enhancement New feature or request

Comments


tomvars commented Sep 28, 2022

🚀 Feature

Introduce a public TorchIO hub where researchers can save the transform object used to randomly sample and augment their data during training, so that others can fetch it with one line of code: transform = tio.from_hub("cool_recent_paper")

Motivation

DL researchers and practitioners hoping to reproduce other people's work can easily fetch model weights and architecture definitions (e.g., Torch Hub or MONAI Bundle), training parameters (e.g., AutoModel from HuggingFace) and preprocessing strategies (e.g., AutoFeatureExtractor from HuggingFace). However, data augmentation is still an obstacle to reproducing someone's setup in a few lines of code. Libraries like Albumentations and TorchIO provide a variety of common data augmentation strategies, but they lack the hub features of HuggingFace or PyTorch for easily storing and fetching those strategies.

Pitch

I'm not sure how best to implement this. As an MVP, there could be a separate repo where users submit their transforms as code, plus a big dictionary that maps a chosen string to each transform.

tomvars added the enhancement label Sep 28, 2022

fepegar (Owner) commented Oct 9, 2022

Hi, @tomvars. Thanks for the proposal. I think this is an excellent idea.

One potential solution would be leveraging the PyTorch Hub tools. I got this code working. What do you think?

import torchio as tio
fpg = tio.datasets.FPG()
fpg.plot(reorient=False)

Figure_1

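import torch
# Branch of the resseg repo that contains the hubconf.py
repo = 'fepegar/resseg:add-preprocessing-hubconf'
# Entry point defined in that hubconf.py
function_name = 'get_preprocessing_transform'
input_path = fpg.t1.path
# Extra arguments are forwarded to the entry point; force_reload=True downloads the repo again
preprocess = torch.hub.load(repo, function_name, input_path, image_name='t1', force_reload=True)
preprocessed = preprocess(fpg)
preprocessed.plot(reorient=False)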

Figure_2

tomvars (Author) commented Oct 9, 2022

I really like this API! You could maybe create a new repo like fepegar/torchiohub:main and have a single hubconf.py file as the access point to different preprocessing functions. In the repo, users could append their transform functions to a large transforms.py file, and hubconf.py would have lines such as from transforms import ronneberger_unet_2015_transform

fepegar (Owner) commented Oct 9, 2022

I think it's more convenient to let users keep their own hubconf in their own repos, because:

  1. This is what PyTorch does, so people are familiar with the syntax etc.
  2. Sometimes, getting a transform needs some special code. The snippet I shared is an example in which additional libraries or files might be needed just to compute the transform, and we wouldn't want to put everyone's code in the same repo.

So the contribution to this library (which I'm happy to write) would be documentation on how to set up transforms for reproducibility on top of PyTorch Hub. Does that sound good?
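As a starting point for that documentation, a user's hubconf.py could look roughly like the sketch below. This is an assumption for illustration, not the contents of the hubconf in fepegar/resseg: the entry point signature (a `landmarks_path` argument pointing to precomputed histogram landmarks) and the choice of transforms are hypothetical. The `dependencies` list and the convention that public functions in hubconf.py are entry points come from PyTorch Hub.

```python
# hubconf.py, placed at the root of the user's own GitHub repository

# Packages that PyTorch Hub checks for before calling an entry point
dependencies = ['torch', 'torchio']


def get_preprocessing_transform(landmarks_path, image_name='t1'):
    """Hypothetical entry point that returns a TorchIO preprocessing transform.

    landmarks_path: path to a NumPy file with precomputed histogram landmarks
    image_name: key of the image in the subject that the landmarks apply to
    """
    # Imported here so that a missing package is reported by the
    # dependency check above rather than when this file is imported
    import torchio as tio

    return tio.Compose([
        tio.ToCanonical(),
        tio.HistogramStandardization({image_name: landmarks_path}),
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
    ])
```

With a file like this in a repository (written here as the placeholder 'user/repo'), the transform could be fetched with torch.hub.load('user/repo', 'get_preprocessing_transform', 'landmarks.npy', image_name='t1').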

tomvars (Author) commented Oct 9, 2022

That makes sense 👍 Any thoughts on introducing a class method on Transform called from_hub, which would wrap the torch.hub.load call and pass in the relevant arguments?

tomvars closed this as completed Oct 9, 2022
tomvars reopened this Oct 9, 2022

fepegar (Owner) commented Oct 9, 2022

You mean something like this?

@classmethod
def from_hub(cls, *args, **kwargs):
    return torch.hub.load(*args, **kwargs)
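A slightly fuller sketch of the same idea is shown below, written as a standalone mixin so it can be read in isolation. The class name `HubTransformMixin`, the explicit `repo` and `entry_point` parameters and the type check on the result are all assumptions for illustration; none of this exists in TorchIO.

```python
class HubTransformMixin:
    """Hypothetical mixin that adds a from_hub class method (not part of TorchIO)."""

    @classmethod
    def from_hub(cls, repo, entry_point, *args, **kwargs):
        # Imported lazily so that defining the class does not require PyTorch
        import torch

        # Remaining arguments are forwarded to the entry point; keywords such
        # as force_reload are consumed by torch.hub.load itself
        loaded = torch.hub.load(repo, entry_point, *args, **kwargs)

        # Guard against entry points that return something else, e.g. a model
        if not isinstance(loaded, cls):
            raise TypeError(
                f'Entry point "{entry_point}" returned {type(loaded).__name__}, '
                f'expected an instance of {cls.__name__}'
            )
        return loaded
```

The type check is the one thing this wrapper adds over calling torch.hub.load directly: since any callable can be exposed through a hubconf, failing early with a clear message may be friendlier than an error deep inside a training loop.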

fepegar (Owner) commented Nov 20, 2022

Hey, I forgot to share some experiments I conducted. The code below needs the unet package to be installed with pip:

import torch
import torchio as tio

colin = tio.datasets.Colin27()
path = colin.t1.path
repo = 'fepegar/resseg:add-preprocessing-hubconf'
# Fetch the transform through PyTorch Hub; arguments after the entry point name are forwarded to it
transform = torch.hub.load(repo, 'get_preprocessing_transform', path, image_name='t1')
transform(colin).plot()

Here, HistogramStandardization makes it a bit awkward, but things work. We should write a tutorial about this. If you think the class method would be helpful, feel free to contribute a PR!
