Skip to content

Commit

Permalink
Added preprocessing_fn to pytorch model (#32)
Browse files Browse the repository at this point in the history
* added preprocessing_fn to pytorch model

* fixed tabs

* use mxnet 0.10.0 for now, because 0.10.1 is buggy

(apache/mxnet#6874)

* fixed pep8 violations

* added test for pytorch preprocessing support

* fixed test
  • Loading branch information
wielandbrendel authored and jonasrauber committed Jul 11, 2017
1 parent 0b58463 commit 4a77c17
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ install:
- if [[ $TRAVIS_PYTHON_VERSION == 3.5 ]]; then pip install http://download.pytorch.org/whl/cu75/torch-0.1.12.post2-cp35-cp35m-linux_x86_64.whl; fi
- if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install http://download.pytorch.org/whl/cu75/torch-0.1.12.post2-cp36-cp36m-linux_x86_64.whl; fi
- travis_wait travis_retry pip install --upgrade keras
- travis_wait travis_retry pip install --upgrade mxnet
- travis_wait travis_retry pip install --upgrade mxnet==0.10.0
- pip install -e .
script:
- pytest
Expand Down
14 changes: 11 additions & 3 deletions foolbox/models/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class PyTorchModel(DifferentiableModel):
The index of the axis that represents color channels.
cuda : bool
A boolean specifying whether the model uses CUDA.
preprocess_fn : function
Will be called with the images before model predictions are calculated.
"""

Expand All @@ -28,7 +30,8 @@ def __init__(
bounds,
num_classes,
channel_axis=1,
cuda=True):
cuda=True,
preprocess_fn=None):

super(PyTorchModel, self).__init__(bounds=bounds,
channel_axis=channel_axis)
Expand All @@ -37,13 +40,18 @@ def __init__(
self._model = model
self.cuda = cuda

if preprocess_fn is not None:
self.preprocessing_fn = lambda x: preprocess_fn(x.copy())
else:
self.preprocessing_fn = lambda x: x

def batch_predictions(self, images):
# lazy import
import torch
from torch.autograd import Variable

n = len(images)
images = torch.from_numpy(images)
images = torch.from_numpy(self.preprocessing_fn(images))
if self.cuda: # pragma: no cover
images = images.cuda()
images = Variable(images, volatile=True)
Expand Down Expand Up @@ -73,7 +81,7 @@ def predictions_and_gradient(self, image, label):

assert image.ndim == 3
images = image[np.newaxis]
images = torch.from_numpy(images)
images = torch.from_numpy(self.preprocessing_fn(images))
if self.cuda: # pragma: no cover
images = images.cuda()
images = Variable(images, requires_grad=True)
Expand Down
65 changes: 65 additions & 0 deletions foolbox/tests/test_models_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,68 @@ def forward(self, x):
test_gradient)

assert model.num_classes() == num_classes


def test_pytorch_model_preprocessing():
num_classes = 1000
bounds = (0, 255)
channels = num_classes

class Net(nn.Module):

def __init__(self):
super(Net, self).__init__()

def forward(self, x):
x = torch.mean(x, 3)
x = torch.squeeze(x, dim=3)
x = torch.mean(x, 2)
x = torch.squeeze(x, dim=2)
logits = x
return logits

model = Net()

def preprocess_fn(x):
# modify x in-place
x /= 2
return x

model1 = PyTorchModel(
model,
bounds=bounds,
num_classes=num_classes,
cuda=False)

model2 = PyTorchModel(
model,
bounds=bounds,
num_classes=num_classes,
cuda=False,
preprocess_fn=preprocess_fn)

model3 = PyTorchModel(
model,
bounds=bounds,
num_classes=num_classes,
cuda=False)

np.random.seed(22)
test_images = np.random.rand(2, channels, 5, 5).astype(np.float32)
test_images_copy = test_images.copy()

p1 = model1.batch_predictions(test_images)
p2 = model2.batch_predictions(test_images)

# make sure the images have not been changed by
# the in-place preprocessing
assert np.all(test_images == test_images_copy)

p3 = model3.batch_predictions(test_images)

assert p1.shape == p2.shape == p3.shape == (2, num_classes)

np.testing.assert_array_almost_equal(
p1 - p1.max(),
p3 - p3.max(),
decimal=5)

0 comments on commit 4a77c17

Please sign in to comment.