VisionDataModule set/get transform doesn't change dataset transform #1064

Open
jascase901 opened this issue Aug 22, 2023 · 0 comments
Labels: bug (Something isn't working), help wanted (Extra attention is needed)


jascase901 commented Aug 22, 2023

🐛 Bug

Setting the transform of the data module should change the transform of the underlying dataset.
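A plausible mechanism, assuming `setup()` reads `self.train_transforms` once and passes it to the dataset constructor (which is consistent with the results reported below): the dataset keeps its own reference to the transform, so rebinding the attribute on the data module afterwards never reaches it. The pure-Python sketch below uses hypothetical `ToyDataset` and `ToyDataModule` classes, not the pl_bolts ones, to show the effect, and that calling `setup()` again after the assignment does pick up the new transform.

```python
from typing import Callable, List, Optional

Transform = Callable[[int], int]


class ToyDataset:
    """Hypothetical stand-in for a torchvision dataset; stores the transform it is given."""

    def __init__(self, data: List[int], transform: Optional[Transform] = None) -> None:
        self.data = data
        self.transform = transform

    def __getitem__(self, idx: int) -> int:
        item = self.data[idx]
        return self.transform(item) if self.transform is not None else item


class ToyDataModule:
    """Hypothetical stand-in for a data module that builds its dataset in setup()."""

    def __init__(self, train_transforms: Optional[Transform] = None) -> None:
        self.train_transforms = train_transforms
        self.dataset_train: Optional[ToyDataset] = None

    def setup(self) -> None:
        # The transform is read once here; the dataset keeps its own reference to it.
        self.dataset_train = ToyDataset([1, 2, 3], transform=self.train_transforms)


dm = ToyDataModule()
dm.setup()
dm.train_transforms = lambda x: x * 10  # rebinds the attribute on the module only
print(dm.dataset_train[0])              # 1: the existing dataset still has no transform

dm.setup()                              # rebuild the dataset after the assignment
print(dm.dataset_train[0])              # 10: the new transform is now applied
```

If the real data module behaves like this sketch, assigning `train_transforms` before calling `setup()`, or calling `setup()` again afterwards, would be a workaround.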

```python
from pl_bolts.datamodules import MNISTDataModule
from torchvision import transforms as transform_lib

mnist = MNISTDataModule(data_dir="/tmp/mnist")
mnist.prepare_data()
mnist.setup(stage="fit")

print("before set_transform")
print(mnist.dataset_train.dataset.transforms)

# Expect this to change the train dataset transform?
mnist.train_transforms = transform_lib.Compose(
    [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.6,), std=(0.5,))]
)

# Expect to print the new transform
print("after transform")
print(mnist.dataset_train.dataset.transforms)
```

Results

```text
before set_transform
StandardTransform
Transform: Compose(
               ToTensor()
           )
after transform
StandardTransform
Transform: Compose(
               ToTensor()
           )
```
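A related detail about what the repro prints: in torchvision, `VisionDataset.__init__` wraps `transform` and `target_transform` in a `StandardTransform` and stores it as `.transforms`, which is the object shown above. That wrapper appears to be built once at construction, so it would not reflect even a later direct assignment to the dataset's `.transform`, while `MNIST.__getitem__` (as far as I can tell) applies `self.transform` directly. The sketch below mimics that split with hypothetical `ToyStandardTransform` and `ToyVisionDataset` classes; it is not torchvision's code.

```python
from typing import Callable, Optional

Transform = Callable[[int], int]


class ToyStandardTransform:
    """Hypothetical mimic of torchvision's StandardTransform: captures the transform at creation."""

    def __init__(self, transform: Optional[Transform] = None) -> None:
        self.transform = transform

    def __call__(self, x: int) -> int:
        return self.transform(x) if self.transform is not None else x


class ToyVisionDataset:
    """Hypothetical mimic of a VisionDataset subclass."""

    def __init__(self, transform: Optional[Transform] = None) -> None:
        self.transform = transform
        # Built once; later changes to self.transform are not copied into it.
        self.transforms = ToyStandardTransform(transform)

    def __getitem__(self, idx: int) -> int:
        # Assumed to mirror MNIST.__getitem__: apply self.transform directly.
        return self.transform(idx) if self.transform is not None else idx


def add_one(x: int) -> int:
    return x + 1


def times_hundred(x: int) -> int:
    return x * 100


ds = ToyVisionDataset(transform=add_one)
ds.transform = times_hundred

print(ds[2])                                     # 200: __getitem__ uses the new transform
print(ds.transforms.transform is times_hundred)  # False: the wrapper still holds add_one
```

So printing `.transforms` may not be a reliable way to check which transform is actually applied.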

Expected

I expected the dataset's transform to differ after setting `mnist.train_transforms`.
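One way the expected behaviour could be provided, sketched with hypothetical classes rather than the pl_bolts API: make `train_transforms` a property whose setter also updates the dataset that `setup()` has already built. The sketch assumes, as the repro's `dataset_train.dataset` access suggests, that the training set is a split wrapper exposing the full dataset as `.dataset` (like `torch.utils.data.Subset`). `ToySubset` and `PropagatingDataModule` are illustrative names only, and a real torchvision dataset would also need its cached `.transforms` wrapper rebuilt, which is left out here.

```python
from typing import Callable, List, Optional

Transform = Callable[[int], int]


class ToyDataset:
    """Hypothetical dataset that applies self.transform in __getitem__."""

    def __init__(self, data: List[int], transform: Optional[Transform] = None) -> None:
        self.data = data
        self.transform = transform

    def __getitem__(self, idx: int) -> int:
        item = self.data[idx]
        return self.transform(item) if self.transform is not None else item


class ToySubset:
    """Hypothetical split wrapper exposing the full dataset as .dataset, like torch Subset."""

    def __init__(self, dataset: ToyDataset, indices: List[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx: int) -> int:
        return self.dataset[self.indices[idx]]


class PropagatingDataModule:
    """Hypothetical data module whose transform setter reaches an already-built dataset."""

    def __init__(self, train_transforms: Optional[Transform] = None) -> None:
        self._train_transforms = train_transforms
        self.dataset_train: Optional[ToySubset] = None

    @property
    def train_transforms(self) -> Optional[Transform]:
        return self._train_transforms

    @train_transforms.setter
    def train_transforms(self, value: Optional[Transform]) -> None:
        self._train_transforms = value
        # If setup() has already run, push the new transform down to the underlying dataset.
        if self.dataset_train is not None:
            self.dataset_train.dataset.transform = value

    def setup(self) -> None:
        full = ToyDataset([1, 2, 3, 4], transform=self._train_transforms)
        self.dataset_train = ToySubset(full, indices=[0, 1, 2])


dm = PropagatingDataModule()
dm.setup()
dm.train_transforms = lambda x: x * 10
print(dm.dataset_train[1])  # 20: the existing dataset now uses the new transform
```

The setter leaves `setup()` unchanged, so assigning the transform before `setup()` still works as before.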

Environment

  • PyTorch Version (e.g., 1.0): 1.13.1+cu117
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.10
  • CUDA/cuDNN version: 11
jascase901 added the bug and help wanted labels on Aug 22, 2023.