Skip to content

Commit

Permalink
Move encoders weights to HF-Hub (#1035)
Browse files Browse the repository at this point in the history
* Move everything to HF hub

* Add backup plan for downloading weights

* Rename with dot

* Update revisions

* Add test

* Add requirement

* Move loading file outside of try/except

* Fixup
  • Loading branch information
qubvel authored Jan 16, 2025
1 parent ce65165 commit 28877ed
Show file tree
Hide file tree
Showing 21 changed files with 1,692 additions and 623 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
'numpy>=1.19.3',
'pillow>=8',
'pretrainedmodels>=0.7.1',
'safetensors>=0.3.1',
'six>=1.5',
'timm>=0.9',
'torch>=1.8',
Expand Down
1 change: 1 addition & 0 deletions requirements/minimum.old
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ huggingface-hub==0.24.0
numpy==1.19.3
pillow==8.0.0
pretrainedmodels==0.7.1
safetensors==0.3.1
six==1.5.0
timm==0.9.0
torch==1.9.0
Expand Down
1 change: 1 addition & 0 deletions requirements/required.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ huggingface_hub==0.27.1
numpy==2.2.1
pillow==11.1.0
pretrainedmodels==0.7.4
safetensors==0.5.2
six==1.17.0
timm==1.0.13
torch==2.5.1
Expand Down
69 changes: 60 additions & 9 deletions segmentation_models_pytorch/encoders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import json
import timm
import copy
import warnings
import functools
import torch.utils.model_zoo as model_zoo
from torch.utils.model_zoo import load_url
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


from .resnet import resnet_encoders
from .dpn import dpn_encoders
Expand All @@ -22,6 +26,7 @@
from .timm_universal import TimmUniversalEncoder

from ._preprocessing import preprocess_input
from ._legacy_pretrained_settings import pretrained_settings

__all__ = [
"encoders",
Expand Down Expand Up @@ -101,15 +106,43 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **
encoder = EncoderClass(**params)

if weights is not None:
try:
settings = encoders[name]["pretrained_settings"][weights]
except KeyError:
if weights not in encoders[name]["pretrained_settings"]:
available_weights = list(encoders[name]["pretrained_settings"].keys())
raise KeyError(
"Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
weights, name, list(encoders[name]["pretrained_settings"].keys())
)
f"Wrong pretrained weights `{weights}` for encoder `{name}`. "
f"Available options are: {available_weights}"
)

settings = encoders[name]["pretrained_settings"][weights]
repo_id = settings["repo_id"]
revision = settings["revision"]

# First, try to load the weights from the HF Hub. The Hub is not reachable
# from every country (e.g. China), so if that attempt fails we fall back to
# downloading from the original url.
weights_path = None
try:
hf_hub_download(repo_id, filename="config.json", revision=revision)
weights_path = hf_hub_download(
repo_id, filename="model.safetensors", revision=revision
)
encoder.load_state_dict(model_zoo.load_url(settings["url"]))
except Exception as e:
if name in pretrained_settings and weights in pretrained_settings[name]:
message = (
f"Error loading {name} `{weights}` weights from Hugging Face Hub, "
"trying loading from original url..."
)
warnings.warn(message, UserWarning)
url = pretrained_settings[name][weights]["url"]
state_dict = load_url(url, map_location="cpu")
else:
raise e

if weights_path is not None:
state_dict = load_file(weights_path, device="cpu")

# Load model weights
encoder.load_state_dict(state_dict)

encoder.set_in_channels(in_channels, pretrained=weights is not None)
if output_stride != 32:
Expand All @@ -136,7 +169,25 @@ def get_preprocessing_params(encoder_name, pretrained="imagenet"):
raise ValueError(
"Available pretrained options {}".format(all_settings.keys())
)
settings = all_settings[pretrained]

repo_id = all_settings[pretrained]["repo_id"]
revision = all_settings[pretrained]["revision"]

# Load the config from the HF Hub, falling back to the legacy pretrained settings
try:
config_path = hf_hub_download(
repo_id, filename="config.json", revision=revision
)
with open(config_path, "r") as f:
settings = json.load(f)
except Exception as e:
if (
encoder_name in pretrained_settings
and pretrained in pretrained_settings[encoder_name]
):
settings = pretrained_settings[encoder_name][pretrained]
else:
raise e

formatted_settings = {}
formatted_settings["input_space"] = settings.get("input_space", "RGB")
Expand Down
1 change: 0 additions & 1 deletion segmentation_models_pytorch/encoders/_efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import math
import collections
from functools import partial
from torch.utils import model_zoo


class MBConvBlock(nn.Module):
Expand Down
Loading

0 comments on commit 28877ed

Please sign in to comment.