Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vendor efficientnet-pytorch #1036

Merged
merged 4 commits into from
Jan 16, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 2 additions & 133 deletions segmentation_models_pytorch/encoders/_efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,11 @@
Args:
memory_efficient (bool): Whether to use memory-efficient version of swish.
"""
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
self._swish = MemoryEfficientSwish() if memory_efficient else nn.SiLU()

Check warning on line 174 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L174

Added line #L174 was not covered by tests
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved


class EfficientNet(nn.Module):
"""EfficientNet model.
Most easily loaded with the .from_name or .from_pretrained methods.

Args:
blocks_args (list[namedtuple]): A list of BlockArgs to construct blocks.
Expand Down Expand Up @@ -275,9 +274,9 @@
Args:
memory_efficient (bool): Whether to use memory-efficient version of swish.
"""
self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
self._swish = MemoryEfficientSwish() if memory_efficient else nn.SiLU()
for block in self._blocks:
block.set_swish(memory_efficient)

Check warning on line 279 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L277-L279

Added lines #L277 - L279 were not covered by tests

def extract_endpoints(self, inputs):
"""Use convolution layer to extract features
Expand All @@ -302,31 +301,31 @@
>>> print(endpoints['reduction_5'].shape) # torch.Size([1, 320, 7, 7])
>>> print(endpoints['reduction_6'].shape) # torch.Size([1, 1280, 7, 7])
"""
endpoints = dict()

Check warning on line 304 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L304

Added line #L304 was not covered by tests

# Stem
x = self._swish(self._bn0(self._conv_stem(inputs)))
prev_x = x

Check warning on line 308 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L307-L308

Added lines #L307 - L308 were not covered by tests

# Blocks
for idx, block in enumerate(self._blocks):
drop_connect_rate = self._global_params.drop_connect_rate
if drop_connect_rate:
drop_connect_rate *= float(idx) / len(

Check warning on line 314 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L311-L314

Added lines #L311 - L314 were not covered by tests
self._blocks
) # scale drop connect_rate
x = block(x, drop_connect_rate=drop_connect_rate)
if prev_x.size(2) > x.size(2):
endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x
elif idx == len(self._blocks) - 1:
endpoints["reduction_{}".format(len(endpoints) + 1)] = x
prev_x = x

Check warning on line 322 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L317-L322

Added lines #L317 - L322 were not covered by tests

# Head
x = self._swish(self._bn1(self._conv_head(x)))
endpoints["reduction_{}".format(len(endpoints) + 1)] = x

Check warning on line 326 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L325-L326

Added lines #L325 - L326 were not covered by tests

return endpoints

Check warning on line 328 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L328

Added line #L328 was not covered by tests

def extract_features(self, inputs):
"""use convolution layer to extract feature .
Expand All @@ -339,21 +338,21 @@
layer in the efficientnet model.
"""
# Stem
x = self._swish(self._bn0(self._conv_stem(inputs)))

Check warning on line 341 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L341

Added line #L341 was not covered by tests

# Blocks
for idx, block in enumerate(self._blocks):
drop_connect_rate = self._global_params.drop_connect_rate
if drop_connect_rate:
drop_connect_rate *= float(idx) / len(

Check warning on line 347 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L344-L347

Added lines #L344 - L347 were not covered by tests
self._blocks
) # scale drop connect_rate
x = block(x, drop_connect_rate=drop_connect_rate)

Check warning on line 350 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L350

Added line #L350 was not covered by tests

# Head
x = self._swish(self._bn1(self._conv_head(x)))

Check warning on line 353 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L353

Added line #L353 was not covered by tests

return x

Check warning on line 355 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L355

Added line #L355 was not covered by tests

def forward(self, inputs):
"""EfficientNet's forward function.
Expand All @@ -366,136 +365,15 @@
Output of this model after processing.
"""
# Convolution layers
x = self.extract_features(inputs)

Check warning on line 368 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L368

Added line #L368 was not covered by tests
# Pooling and final linear layer
x = self._avg_pooling(x)
if self._global_params.include_top:
x = x.flatten(start_dim=1)
x = self._dropout(x)
x = self._fc(x)
return x

Check warning on line 375 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L370-L375

Added lines #L370 - L375 were not covered by tests

@classmethod
def from_name(cls, model_name, in_channels=3, **override_params):
"""Create an efficientnet model according to name.

Args:
model_name (str): Name for efficientnet.
in_channels (int): Input data's channel number.
override_params (other key word params):
Params to override model's global_params.
Optional key:
'width_coefficient', 'depth_coefficient',
'image_size', 'dropout_rate',
'num_classes', 'batch_norm_momentum',
'batch_norm_epsilon', 'drop_connect_rate',
'depth_divisor', 'min_depth'

Returns:
An efficientnet model.
"""
cls._check_model_name_is_valid(model_name)
blocks_args, global_params = get_model_params(model_name, override_params)
model = cls(blocks_args, global_params)
model._change_in_channels(in_channels)
return model

@classmethod
def from_pretrained(
cls,
model_name,
weights_path=None,
advprop=False,
in_channels=3,
num_classes=1000,
**override_params,
):
"""Create an efficientnet model according to name.

Args:
model_name (str): Name for efficientnet.
weights_path (None or str):
str: path to pretrained weights file on the local disk.
None: use pretrained weights downloaded from the Internet.
advprop (bool):
Whether to load pretrained weights
trained with advprop (valid when weights_path is None).
in_channels (int): Input data's channel number.
num_classes (int):
Number of categories for classification.
It controls the output size for final linear layer.
override_params (other key word params):
Params to override model's global_params.
Optional key:
'width_coefficient', 'depth_coefficient',
'image_size', 'dropout_rate',
'batch_norm_momentum',
'batch_norm_epsilon', 'drop_connect_rate',
'depth_divisor', 'min_depth'

Returns:
A pretrained efficientnet model.
"""
model = cls.from_name(model_name, num_classes=num_classes, **override_params)
load_pretrained_weights(
model,
model_name,
weights_path=weights_path,
load_fc=(num_classes == 1000),
advprop=advprop,
)
model._change_in_channels(in_channels)
return model

@classmethod
def get_image_size(cls, model_name):
"""Get the input image size for a given efficientnet model.

Args:
model_name (str): Name for efficientnet.

Returns:
Input image size (resolution).
"""
cls._check_model_name_is_valid(model_name)
_, _, res, _ = efficientnet_params(model_name)
return res

@classmethod
def _check_model_name_is_valid(cls, model_name):
"""Validates model name.

Args:
model_name (str): Name for efficientnet.

Returns:
bool: Is a valid name or not.
"""
if model_name not in VALID_MODELS:
raise ValueError("model_name should be one of: " + ", ".join(VALID_MODELS))

def _change_in_channels(self, in_channels):
"""Adjust model's first convolution layer to in_channels, if in_channels not equals 3.

Args:
in_channels (int): Input data's channel number.
"""
if in_channels != 3:
Conv2d = get_same_padding_conv2d(image_size=self._global_params.image_size)
out_channels = round_filters(32, self._global_params)
self._conv_stem = Conv2d(
in_channels, out_channels, kernel_size=3, stride=2, bias=False
)


"""utils.py - Helper functions for building the model and for loading model parameters.
These helper functions are built to mirror those in the official TensorFlow implementation.
"""

# Author: lukemelas (github username)
# Github repo: https://github.com/lukemelas/EfficientNet-PyTorch
# With adjustments and added comments by workingcoder (github username).


################################################################################
# Help functions for model architecture
Expand Down Expand Up @@ -553,15 +431,6 @@
GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields)

# Swish activation function
if hasattr(nn, "SiLU"):
Swish = nn.SiLU
else:
# For compatibility with old PyTorch versions
class Swish(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)


# A memory-efficient implementation of Swish function
class SwishImplementation(torch.autograd.Function):
Expand Down Expand Up @@ -596,7 +465,7 @@
"""
multiplier = global_params.width_coefficient
if not multiplier:
return filters

Check warning on line 468 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L468

Added line #L468 was not covered by tests
# TODO: modify the params names.
# maybe the names (width_divisor,min_width)
# are more suitable than (depth_divisor,min_depth).
Expand All @@ -607,7 +476,7 @@
# follow the formula transferred from official TensorFlow implementation
new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor)
if new_filters < 0.9 * filters: # prevent rounding by more than 10%
new_filters += divisor

Check warning on line 479 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L479

Added line #L479 was not covered by tests
return int(new_filters)


Expand All @@ -624,7 +493,7 @@
"""
multiplier = global_params.depth_coefficient
if not multiplier:
return repeats

Check warning on line 496 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L496

Added line #L496 was not covered by tests
# follow the formula transferred from official TensorFlow implementation
return int(math.ceil(multiplier * repeats))

Expand Down Expand Up @@ -673,7 +542,7 @@
if isinstance(x, list) or isinstance(x, tuple):
return x
else:
raise TypeError()

Check warning on line 545 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L545

Added line #L545 was not covered by tests


def calculate_output_image_size(input_image_size, stride):
Expand All @@ -688,7 +557,7 @@
output_image_size: A list [H,W].
"""
if input_image_size is None:
return None

Check warning on line 560 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L560

Added line #L560 was not covered by tests
image_height, image_width = get_width_and_height_from_size(input_image_size)
stride = stride if isinstance(stride, int) else stride[0]
image_height = int(math.ceil(image_height / stride))
Expand All @@ -713,7 +582,7 @@
Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
"""
if image_size is None:
return Conv2dDynamicSamePadding

Check warning on line 585 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L585

Added line #L585 was not covered by tests
else:
return partial(Conv2dStaticSamePadding, image_size=image_size)

Expand Down Expand Up @@ -745,26 +614,26 @@
groups=1,
bias=True,
):
super().__init__(

Check warning on line 617 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L617

Added line #L617 was not covered by tests
in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias
)
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

Check warning on line 620 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L620

Added line #L620 was not covered by tests

def forward(self, x):
ih, iw = x.size()[-2:]
kh, kw = self.weight.size()[-2:]
sh, sw = self.stride
oh, ow = (

Check warning on line 626 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L623-L626

Added lines #L623 - L626 were not covered by tests
math.ceil(ih / sh),
math.ceil(iw / sw),
) # change the output size according to stride ! ! !
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
if pad_h > 0 or pad_w > 0:
x = F.pad(

Check warning on line 633 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L630-L633

Added lines #L630 - L633 were not covered by tests
x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
)
return F.conv2d(

Check warning on line 636 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L636

Added line #L636 was not covered by tests
x,
self.weight,
self.bias,
Expand Down Expand Up @@ -833,10 +702,10 @@
Returns:
MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding.
"""
if image_size is None:
return MaxPool2dDynamicSamePadding

Check warning on line 706 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L705-L706

Added lines #L705 - L706 were not covered by tests
else:
return partial(MaxPool2dStaticSamePadding, image_size=image_size)

Check warning on line 708 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L708

Added line #L708 was not covered by tests


class MaxPool2dDynamicSamePadding(nn.MaxPool2d):
Expand All @@ -853,31 +722,31 @@
return_indices=False,
ceil_mode=False,
):
super().__init__(

Check warning on line 725 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L725

Added line #L725 was not covered by tests
kernel_size, stride, padding, dilation, return_indices, ceil_mode
)
self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
self.kernel_size = (

Check warning on line 729 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L728-L729

Added lines #L728 - L729 were not covered by tests
[self.kernel_size] * 2
if isinstance(self.kernel_size, int)
else self.kernel_size
)
self.dilation = (

Check warning on line 734 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L734

Added line #L734 was not covered by tests
[self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
)

def forward(self, x):
ih, iw = x.size()[-2:]
kh, kw = self.kernel_size
sh, sw = self.stride
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
if pad_h > 0 or pad_w > 0:
x = F.pad(

Check warning on line 746 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L739-L746

Added lines #L739 - L746 were not covered by tests
x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
)
return F.max_pool2d(

Check warning on line 749 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L749

Added line #L749 was not covered by tests
x,
self.kernel_size,
self.stride,
Expand All @@ -894,35 +763,35 @@
"""

def __init__(self, kernel_size, stride, image_size=None, **kwargs):
super().__init__(kernel_size, stride, **kwargs)
self.stride = [self.stride] * 2 if isinstance(self.stride, int) else self.stride
self.kernel_size = (

Check warning on line 768 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L766-L768

Added lines #L766 - L768 were not covered by tests
[self.kernel_size] * 2
if isinstance(self.kernel_size, int)
else self.kernel_size
)
self.dilation = (

Check warning on line 773 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L773

Added line #L773 was not covered by tests
[self.dilation] * 2 if isinstance(self.dilation, int) else self.dilation
)

# Calculate padding based on image size and save it
assert image_size is not None
ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
kh, kw = self.kernel_size
sh, sw = self.stride
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
if pad_h > 0 or pad_w > 0:
self.static_padding = nn.ZeroPad2d(

Check warning on line 786 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L778-L786

Added lines #L778 - L786 were not covered by tests
(pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
)
else:
self.static_padding = nn.Identity()

Check warning on line 790 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L790

Added line #L790 was not covered by tests

def forward(self, x):
x = self.static_padding(x)
x = F.max_pool2d(

Check warning on line 794 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L793-L794

Added lines #L793 - L794 were not covered by tests
x,
self.kernel_size,
self.stride,
Expand All @@ -931,7 +800,7 @@
self.ceil_mode,
self.return_indices,
)
return x

Check warning on line 803 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L803

Added line #L803 was not covered by tests


################################################################################
Expand Down Expand Up @@ -998,7 +867,7 @@
Returns:
block_string: A String form of BlockArgs.
"""
args = [

Check warning on line 870 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L870

Added line #L870 was not covered by tests
"r%d" % block.num_repeat,
"k%d" % block.kernel_size,
"s%d%d" % (block.strides[0], block.strides[1]),
Expand All @@ -1006,11 +875,11 @@
"i%d" % block.input_filters,
"o%d" % block.output_filters,
]
if 0 < block.se_ratio <= 1:
args.append("se%s" % block.se_ratio)
if block.id_skip is False:
args.append("noskip")
return "_".join(args)

Check warning on line 882 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L878-L882

Added lines #L878 - L882 were not covered by tests

@staticmethod
def decode(string_list):
Expand Down Expand Up @@ -1038,10 +907,10 @@
Returns:
block_strings: A list of strings, each string is a notation of block.
"""
block_strings = []
for block in blocks_args:
block_strings.append(BlockDecoder._encode_block_string(block))
return block_strings

Check warning on line 913 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L910-L913

Added lines #L910 - L913 were not covered by tests


def efficientnet_params(model_name):
Expand Down Expand Up @@ -1141,12 +1010,12 @@
width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s
)
else:
raise NotImplementedError(

Check warning on line 1013 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L1013

Added line #L1013 was not covered by tests
"model name is not pre-defined: {}".format(model_name)
)
if override_params:
# ValueError will be raised here if override_params has fields not included in global_params.
global_params = global_params._replace(**override_params)

Check warning on line 1018 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L1018

Added line #L1018 was not covered by tests
return blocks_args, global_params


Expand Down Expand Up @@ -1195,28 +1064,28 @@
advprop (bool): Whether to load pretrained weights
trained with advprop (valid when weights_path is None).
"""
if isinstance(weights_path, str):
state_dict = torch.load(weights_path)

Check warning on line 1068 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L1067-L1068

Added lines #L1067 - L1068 were not covered by tests
else:
# AutoAugment or Advprop (different preprocessing)
url_map_ = url_map_advprop if advprop else url_map
state_dict = model_zoo.load_url(url_map_[model_name])

Check warning on line 1072 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L1071-L1072

Added lines #L1071 - L1072 were not covered by tests

if load_fc:
ret = model.load_state_dict(state_dict, strict=False)
assert not ret.missing_keys, (

Check warning on line 1076 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L1074-L1076

Added lines #L1074 - L1076 were not covered by tests
"Missing keys when loading pretrained weights: {}".format(ret.missing_keys)
)
else:
state_dict.pop("_fc.weight")
state_dict.pop("_fc.bias")
ret = model.load_state_dict(state_dict, strict=False)
assert set(ret.missing_keys) == set(["_fc.weight", "_fc.bias"]), (

Check warning on line 1083 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L1080-L1083

Added lines #L1080 - L1083 were not covered by tests
"Missing keys when loading pretrained weights: {}".format(ret.missing_keys)
)
assert not ret.unexpected_keys, (

Check warning on line 1086 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L1086

Added line #L1086 was not covered by tests
"Missing keys when loading pretrained weights: {}".format(ret.unexpected_keys)
)

if verbose:
print("Loaded pretrained weights for {}".format(model_name))

Check warning on line 1091 in segmentation_models_pytorch/encoders/_efficientnet.py

View check run for this annotation

Codecov / codecov/patch

segmentation_models_pytorch/encoders/_efficientnet.py#L1090-L1091

Added lines #L1090 - L1091 were not covered by tests
Loading