From 4a77c1784addf1a454a52915f68938b5d3b852be Mon Sep 17 00:00:00 2001 From: Wieland Brendel Date: Tue, 11 Jul 2017 19:34:48 +0200 Subject: [PATCH] Added preprocessing_fn to pytorch model (#32) * added preprocessing_fn to pytorch model * fixed tabs * use mxnet 0.10.0 for now, because 0.10.1 is buggy (https://github.com/dmlc/mxnet/issues/6874) * fixed pep8 violations * added test for pytorch preprocessing support * fixed test --- .travis.yml | 2 +- foolbox/models/pytorch.py | 14 ++++-- foolbox/tests/test_models_pytorch.py | 65 ++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 4 deletions(-) diff --git a/.travis.yml b/.travis.yml index ff13fbe3..e03b9400 100644 --- a/.travis.yml +++ b/.travis.yml @@ -18,7 +18,7 @@ install: - if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install http://download.pytorch.org/whl/cu75/torch-0.1.12.post2-cp35-cp35m-linux_x86_64.whl; fi - if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install http://download.pytorch.org/whl/cu75/torch-0.1.12.post2-cp36-cp36m-linux_x86_64.whl; fi - travis_wait travis_retry pip install --upgrade keras - - travis_wait travis_retry pip install --upgrade mxnet + - travis_wait travis_retry pip install --upgrade mxnet==0.10.0 - pip install -e . script: - pytest diff --git a/foolbox/models/pytorch.py b/foolbox/models/pytorch.py index 83e593c6..c41cb13f 100644 --- a/foolbox/models/pytorch.py +++ b/foolbox/models/pytorch.py @@ -19,6 +19,8 @@ class PyTorchModel(DifferentiableModel): The index of the axis that represents color channels. cuda : bool A boolean specifying whether the model uses CUDA. + preprocess_fn : function + Will be called with the images before model predictions are calculated. """ @@ -28,7 +30,8 @@ def __init__( bounds, num_classes, channel_axis=1, - cuda=True): + cuda=True, + preprocess_fn=None): super(PyTorchModel, self).__init__(bounds=bounds, channel_axis=channel_axis) @@ -37,13 +40,18 @@ def __init__( self._model = model self.cuda = cuda + if preprocess_fn is not None: + self.preprocessing_fn = lambda x: preprocess_fn(x.copy()) + else: + self.preprocessing_fn = lambda x: x + def batch_predictions(self, images): # lazy import import torch from torch.autograd import Variable n = len(images) - images = torch.from_numpy(images) + images = torch.from_numpy(self.preprocessing_fn(images)) if self.cuda: # pragma: no cover images = images.cuda() images = Variable(images, volatile=True) @@ -73,7 +81,7 @@ def predictions_and_gradient(self, image, label): assert image.ndim == 3 images = image[np.newaxis] - images = torch.from_numpy(images) + images = torch.from_numpy(self.preprocessing_fn(images)) if self.cuda: # pragma: no cover images = images.cuda() images = Variable(images, requires_grad=True) diff --git a/foolbox/tests/test_models_pytorch.py b/foolbox/tests/test_models_pytorch.py index 64d7f146..f9c3cacd 100644 --- a/foolbox/tests/test_models_pytorch.py +++ b/foolbox/tests/test_models_pytorch.py @@ -52,3 +52,68 @@ def forward(self, x): test_gradient) assert model.num_classes() == num_classes + + +def test_pytorch_model_preprocessing(): + num_classes = 1000 + bounds = (0, 255) + channels = num_classes + + class Net(nn.Module): + + def __init__(self): + super(Net, self).__init__() + + def forward(self, x): + x = torch.mean(x, 3) + x = torch.squeeze(x, dim=3) + x = torch.mean(x, 2) + x = torch.squeeze(x, dim=2) + logits = x + return logits + + model = Net() + + def preprocess_fn(x): + # modify x in-place + x /= 2 + return x + + model1 = PyTorchModel( + model, + bounds=bounds, + num_classes=num_classes, + 
cuda=False) + + model2 = PyTorchModel( + model, + bounds=bounds, + num_classes=num_classes, + cuda=False, + preprocess_fn=preprocess_fn) + + model3 = PyTorchModel( + model, + bounds=bounds, + num_classes=num_classes, + cuda=False) + + np.random.seed(22) + test_images = np.random.rand(2, channels, 5, 5).astype(np.float32) + test_images_copy = test_images.copy() + + p1 = model1.batch_predictions(test_images) + p2 = model2.batch_predictions(test_images) + + # make sure the images have not been changed by + # the in-place preprocessing + assert np.all(test_images == test_images_copy) + + p3 = model3.batch_predictions(test_images) + + assert p1.shape == p2.shape == p3.shape == (2, num_classes) + + np.testing.assert_array_almost_equal( + p1 - p1.max(), + p3 - p3.max(), + decimal=5)
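
A minimal usage sketch of the new preprocess_fn argument, appended here for reference and not part of the patch itself: the MeanNet toy module, the division by 255, and the input shapes are made-up illustrations, while the PyTorchModel constructor signature, the cuda flag, and the copy-before-preprocess behaviour come from the diff above.

# Illustrative sketch only: MeanNet and the /255 normalisation are hypothetical;
# PyTorchModel and its preprocess_fn argument are taken from the patch above.
import numpy as np
import torch.nn as nn

from foolbox.models import PyTorchModel


class MeanNet(nn.Module):
    """Toy module: averages each channel's pixels into one logit."""

    def forward(self, x):
        n, c = x.size(0), x.size(1)
        # flatten the spatial dimensions and average them per channel
        return x.view(n, c, -1).mean(2).view(n, c)


def preprocess_fn(x):
    # x is the float32 NCHW numpy batch; PyTorchModel wraps this call with
    # x.copy(), so the in-place division does not touch the caller's images
    x /= 255.
    return x


fmodel = PyTorchModel(
    MeanNet(),
    bounds=(0, 255),
    num_classes=10,
    cuda=False,
    preprocess_fn=preprocess_fn)

images = (255 * np.random.rand(2, 10, 5, 5)).astype(np.float32)
logits = fmodel.batch_predictions(images)
assert logits.shape == (2, 10)

Because the wrapper calls preprocess_fn on a copy of the input, in-place operations such as the division above are safe for the caller's data, which is exactly what test_pytorch_model_preprocessing verifies.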