Skip to content

Commit

Permalink
Clean up code
Browse files Browse the repository at this point in the history
  • Loading branch information
titu1994 committed Dec 17, 2017
1 parent 6e98c74 commit fffb893
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
8 changes: 7 additions & 1 deletion .idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 10 additions & 6 deletions resnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from keras.applications.imagenet_utils import _obtain_input_shape
import keras.backend as K


CIFAR_TH_WEIGHTS_PATH = ''
CIFAR_TF_WEIGHTS_PATH = ''
CIFAR_TH_WEIGHTS_PATH_NO_TOP = ''
Expand Down Expand Up @@ -235,7 +234,7 @@ def ResNextImageNet(input_shape=None, depth=[3, 4, 6, 3], cardinality=32, width=
default_size=224,
min_size=112,
data_format=K.image_data_format(),
include_top=include_top)
require_flatten=include_top)

if input_tensor is None:
img_input = Input(shape=input_shape)
Expand Down Expand Up @@ -319,7 +318,7 @@ def __initial_conv_block(input, weight_decay=5e-4):
return x


def __initial_conv_block_inception(input, weight_decay=5e-4):
def __initial_conv_block_imagenet(input, weight_decay=5e-4):
''' Adds an initial conv block, with batch norm and relu for the inception resnext
Args:
input: input tensor
Expand Down Expand Up @@ -363,8 +362,8 @@ def __grouped_convolution_block(input, grouped_channels, cardinality, strides, w

for c in range(cardinality):
x = Lambda(lambda z: z[:, :, :, c * grouped_channels:(c + 1) * grouped_channels]
if K.image_data_format() == 'channels_last' else
lambda z: z[:, c * grouped_channels:(c + 1) * grouped_channels, :, :])(input)
if K.image_data_format() == 'channels_last' else
lambda z: z[:, c * grouped_channels:(c + 1) * grouped_channels, :, :])(input)

x = Conv2D(grouped_channels, (3, 3), padding='same', use_bias=False, strides=(strides, strides),
kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(x)
Expand Down Expand Up @@ -537,7 +536,7 @@ def __create_res_next_imagenet(nb_classes, img_input, include_top, depth, cardin
filters_list.append(filters)
filters *= 2 # double the size of the filters

x = __initial_conv_block_inception(img_input, weight_decay)
x = __initial_conv_block_imagenet(img_input, weight_decay)

# block 1 (no pooling)
for i in range(N[0]):
Expand Down Expand Up @@ -567,3 +566,8 @@ def __create_res_next_imagenet(nb_classes, img_input, include_top, depth, cardin
x = GlobalMaxPooling2D()(x)

return x


if __name__ == '__main__':
model = ResNext((32, 32, 3), depth=29, cardinality=8, width=64)
model.summary()

0 comments on commit fffb893

Please sign in to comment.