-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Bump version * Update gitignore * Update flake * Add hub mixin * Fix interpolation * Add from_pretrained * Update example
- Loading branch information
Showing
8 changed files
with
339 additions
and
241 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
[flake8] | ||
max-line-length = 119 | ||
exclude =.git,__pycache__,docs/conf.py,build,dist,setup.py,tests,.venv | ||
ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,W504,D202,E203,W503,B006,D412 | ||
ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,W504,D202,E203,W503,B006,D412,F821,E501 | ||
inline-quotes = " |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,8 @@ __pycache__/ | |
*$py.class | ||
.idea/ | ||
.venv* | ||
examples/images* | ||
examples/annotations* | ||
|
||
# C extensions | ||
*.so | ||
|
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
VERSION = (0, 3, 3) | ||
VERSION = (0, 3, "4dev0") | ||
|
||
__version__ = ".".join(map(str, VERSION)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
import json | ||
from pathlib import Path | ||
from typing import Optional, Union | ||
from functools import wraps | ||
from huggingface_hub import PyTorchModelHubMixin, ModelCard, ModelCardData, hf_hub_download | ||
|
||
|
||
MODEL_CARD = """ | ||
--- | ||
{{ card_data }} | ||
--- | ||
# {{ model_name }} Model Card | ||
Table of Contents: | ||
- [Load trained model](#load-trained-model) | ||
- [Model init parameters](#model-init-parameters) | ||
- [Model metrics](#model-metrics) | ||
- [Dataset](#dataset) | ||
## Load trained model | ||
```python | ||
import segmentation_models_pytorch as smp | ||
model = smp.{{ model_name }}.from_pretrained("{{ save_directory | default("<save-directory-or-repo>", true)}}") | ||
``` | ||
## Model init parameters | ||
```python | ||
model_init_params = {{ model_parameters }} | ||
``` | ||
## Model metrics | ||
{{ metrics | default("[More Information Needed]", true) }} | ||
## Dataset | ||
Dataset name: {{ dataset | default("[More Information Needed]", true) }} | ||
## More Information | ||
- Library: {{ repo_url | default("[More Information Needed]", true) }} | ||
- Docs: {{ docs_url | default("[More Information Needed]", true) }} | ||
This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) | ||
""" | ||
|
||
|
||
def _format_parameters(parameters: dict): | ||
params = {k: v for k, v in parameters.items() if not k.startswith("_")} | ||
params = [f'"{k}": {v}' if not isinstance(v, str) else f'"{k}": "{v}"' for k, v in params.items()] | ||
params = ",\n".join([f" {param}" for param in params]) | ||
params = "{\n" + f"{params}" + "\n}" | ||
return params | ||
|
||
|
||
class SMPHubMixin(PyTorchModelHubMixin): | ||
def generate_model_card(self, *args, **kwargs) -> ModelCard: | ||
|
||
model_parameters_json = _format_parameters(self._hub_mixin_config) | ||
directory = self._save_directory if hasattr(self, "_save_directory") else None | ||
repo_id = self._repo_id if hasattr(self, "_repo_id") else None | ||
repo_or_directory = repo_id if repo_id is not None else directory | ||
|
||
metrics = self._metrics if hasattr(self, "_metrics") else None | ||
dataset = self._dataset if hasattr(self, "_dataset") else None | ||
|
||
if metrics is not None: | ||
metrics = json.dumps(metrics, indent=4) | ||
metrics = f"```json\n{metrics}\n```" | ||
|
||
model_card_data = ModelCardData( | ||
languages=["python"], | ||
library_name="segmentation-models-pytorch", | ||
license="mit", | ||
tags=["semantic-segmentation", "pytorch", "segmentation-models-pytorch"], | ||
pipeline_tag="image-segmentation", | ||
) | ||
model_card = ModelCard.from_template( | ||
card_data=model_card_data, | ||
template_str=MODEL_CARD, | ||
repo_url="https://github.com/qubvel/segmentation_models.pytorch", | ||
docs_url="https://smp.readthedocs.io/en/latest/", | ||
model_parameters=model_parameters_json, | ||
save_directory=repo_or_directory, | ||
model_name=self.__class__.__name__, | ||
metrics=metrics, | ||
dataset=dataset, | ||
) | ||
return model_card | ||
|
||
def _set_attrs_from_kwargs(self, attrs, kwargs): | ||
for attr in attrs: | ||
if attr in kwargs: | ||
setattr(self, f"_{attr}", kwargs.pop(attr)) | ||
|
||
def _del_attrs(self, attrs): | ||
for attr in attrs: | ||
if hasattr(self, f"_{attr}"): | ||
delattr(self, f"_{attr}") | ||
|
||
@wraps(PyTorchModelHubMixin.save_pretrained) | ||
def save_pretrained(self, save_directory: Union[str, Path], *args, **kwargs) -> Optional[str]: | ||
|
||
# set additional attributes to be used in generate_model_card | ||
self._save_directory = save_directory | ||
self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs) | ||
|
||
# set additional attribute to be used in from_pretrained | ||
self._hub_mixin_config["_model_class"] = self.__class__.__name__ | ||
|
||
try: | ||
# call the original save_pretrained | ||
result = super().save_pretrained(save_directory, *args, **kwargs) | ||
finally: | ||
# delete the additional attributes | ||
self._del_attrs(["save_directory", "metrics", "dataset"]) | ||
self._hub_mixin_config.pop("_model_class") | ||
|
||
return result | ||
|
||
@wraps(PyTorchModelHubMixin.push_to_hub) | ||
def push_to_hub(self, repo_id: str, *args, **kwargs): | ||
self._repo_id = repo_id | ||
self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs) | ||
result = super().push_to_hub(repo_id, *args, **kwargs) | ||
self._del_attrs(["repo_id", "metrics", "dataset"]) | ||
return result | ||
|
||
@property | ||
def config(self): | ||
return self._hub_mixin_config | ||
|
||
|
||
@wraps(PyTorchModelHubMixin.from_pretrained) | ||
def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs): | ||
config_path = hf_hub_download( | ||
pretrained_model_name_or_path, filename="config.json", revision=kwargs.get("revision", None) | ||
) | ||
with open(config_path, "r") as f: | ||
config = json.load(f) | ||
model_class_name = config.pop("_model_class") | ||
|
||
import segmentation_models_pytorch as smp | ||
|
||
model_class = getattr(smp, model_class_name) | ||
return model_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters