Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Ruff for formatting and linting #877

Merged
merged 5 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions .flake8

This file was deleted.

22 changes: 22 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,27 @@ on:

jobs:

style:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install ruff==0.4.6
# Update output format to enable automatic inline annotations.
- name: Run Ruff Linter
run: ruff check --output-format=github
- name: Run Ruff Formatter
run: ruff format --check

test:
runs-on: ubuntu-latest
needs: [style]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -25,3 +44,6 @@ jobs:
python -m pip install --upgrade pip
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
make install_dev
- name: Test with pytest
run: make test

5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -105,4 +105,7 @@ venv.bak/
/site

# mypy
.mypy_cache/
.mypy_cache/

# ruff
.ruff_cache/
23 changes: 0 additions & 23 deletions .pre-commit-config.yaml

This file was deleted.

11 changes: 6 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
python3 -m venv .venv

install_dev: .venv
.venv/bin/pip install -e .[test]
.venv/bin/pre-commit install
.venv/bin/pip install -e ".[test]"

test: .venv
.venv/bin/pytest -p no:cacheprovider tests/
Expand All @@ -16,7 +15,9 @@ table:
table_timm:
.venv/bin/python misc/generate_table_timm.py

precommit: install_dev
.venv/bin/pre-commit run --all-files
fixup:
.venv/bin/ruff check --fix
.venv/bin/ruff format

all: fixup test

all: precommit test
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ make install_dev # create .venv, install SMP in dev mode
#### Run tests and code checks

```bash
make all # run precommit, tests
make fixup # Ruff for formatting and lint checks
```

#### Update table with encoders
Expand Down
6 changes: 1 addition & 5 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
# import sys
# sys.path.insert(0, os.path.abspath('.'))

import os
import re
import sys
import datetime
import sphinx_rtd_theme

sys.path.append("..")

Expand Down Expand Up @@ -68,14 +67,11 @@ def get_version():
# a list of builtin themes.
#

import sphinx_rtd_theme

html_theme = "sphinx_rtd_theme"
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]

# import karma_sphinx_theme
# html_theme = "karma_sphinx_theme"
import faculty_sphinx_theme

html_theme = "faculty_sphinx_theme"

Expand Down
10 changes: 4 additions & 6 deletions misc/generate_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,17 @@


WIDTH = 32
COLUMNS = [
"Encoder",
"Weights",
"Params, M",
]
COLUMNS = ["Encoder", "Weights", "Params, M"]


def wrap_row(r):
return "|{}|".format(r)


header = "|".join([column.ljust(WIDTH, " ") for column in COLUMNS])
separator = "|".join(["-" * WIDTH] + [":" + "-" * (WIDTH - 2) + ":"] * (len(COLUMNS) - 1))
separator = "|".join(
["-" * WIDTH] + [":" + "-" * (WIDTH - 2) + ":"] * (len(COLUMNS) - 1)
)

print(wrap_row(header))
print(wrap_row(separator))
Expand Down
15 changes: 12 additions & 3 deletions misc/generate_table_timm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,29 @@ def make_table(data):

l1 = "+" + "-" * max_len1 + "+" + "-" * max_len2 + "+\n"
l2 = "+" + "=" * max_len1 + "+" + "=" * max_len2 + "+\n"
top = "| " + "Encoder name".ljust(max_len1 - 2) + " | " + "Support dilation".center(max_len2 - 2) + " |\n"
top = (
"| "
+ "Encoder name".ljust(max_len1 - 2)
+ " | "
+ "Support dilation".center(max_len2 - 2)
+ " |\n"
)

table = l1 + top + l2

for k in sorted(data.keys()):
support = "✅".center(max_len2 - 3) if data[k]["has_dilation"] else " ".center(max_len2 - 2)
support = (
"✅".center(max_len2 - 3)
if data[k]["has_dilation"]
else " ".center(max_len2 - 2)
)
table += "| " + k.ljust(max_len1 - 2) + " | " + support + " |\n"
table += l1

return table


if __name__ == "__main__":

supported_models = {}

with tqdm(timm.list_models()) as names:
Expand Down
19 changes: 0 additions & 19 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,19 +0,0 @@
[tool.black]
line-length = 119
target-version = ['py37', 'py38']
include = '\.pyi?$'
exclude = '''
/(
\.eggs
| \.git
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| docs
| _build
| buck-out
| build
| dist
)/
'''
24 changes: 22 additions & 2 deletions segmentation_models_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ def create_model(
except KeyError:
raise KeyError(
"Wrong architecture type `{}`. Available options are: {}".format(
arch,
list(archs_dict.keys()),
arch, list(archs_dict.keys())
)
)
return model_class(
Expand All @@ -61,3 +60,24 @@ def create_model(
classes=classes,
**kwargs,
)


__all__ = [
"datasets",
"encoders",
"decoders",
"losses",
"metrics",
"Unet",
"UnetPlusPlus",
"MAnet",
"Linknet",
"FPN",
"PSPNet",
"DeepLabV3",
"DeepLabV3Plus",
"PAN",
"from_pretrained",
"create_model",
"__version__",
]
18 changes: 10 additions & 8 deletions segmentation_models_pytorch/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .model import SegmentationModel

from .modules import (
Conv2dReLU,
Attention,
)
from .modules import Conv2dReLU, Attention

from .heads import (
SegmentationHead,
ClassificationHead,
)
from .heads import SegmentationHead, ClassificationHead

__all__ = [
"SegmentationModel",
"Conv2dReLU",
"Attention",
"SegmentationHead",
"ClassificationHead",
]
22 changes: 17 additions & 5 deletions segmentation_models_pytorch/base/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,29 @@


class SegmentationHead(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1):
conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
def __init__(
self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
):
conv2d = nn.Conv2d(
in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
)
upsampling = (
nn.UpsamplingBilinear2d(scale_factor=upsampling)
if upsampling > 1
else nn.Identity()
)
activation = Activation(activation)
super().__init__(conv2d, upsampling, activation)


class ClassificationHead(nn.Sequential):
def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None):
def __init__(
self, in_channels, classes, pooling="avg", dropout=0.2, activation=None
):
if pooling not in ("max", "avg"):
raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling))
raise ValueError(
"Pooling should be one of ('max', 'avg'), got {}.".format(pooling)
)
pool = nn.AdaptiveAvgPool2d(1) if pooling == "avg" else nn.AdaptiveMaxPool2d(1)
flatten = nn.Flatten()
dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity()
Expand Down
22 changes: 16 additions & 6 deletions segmentation_models_pytorch/base/hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
from pathlib import Path
from typing import Optional, Union
from functools import wraps
from huggingface_hub import PyTorchModelHubMixin, ModelCard, ModelCardData, hf_hub_download
from huggingface_hub import (
PyTorchModelHubMixin,
ModelCard,
ModelCardData,
hf_hub_download,
)


MODEL_CARD = """
Expand Down Expand Up @@ -45,15 +50,17 @@

def _format_parameters(parameters: dict):
params = {k: v for k, v in parameters.items() if not k.startswith("_")}
params = [f'"{k}": {v}' if not isinstance(v, str) else f'"{k}": "{v}"' for k, v in params.items()]
params = [
f'"{k}": {v}' if not isinstance(v, str) else f'"{k}": "{v}"'
for k, v in params.items()
]
params = ",\n".join([f" {param}" for param in params])
params = "{\n" + f"{params}" + "\n}"
return params


class SMPHubMixin(PyTorchModelHubMixin):
def generate_model_card(self, *args, **kwargs) -> ModelCard:

model_parameters_json = _format_parameters(self._hub_mixin_config)
directory = self._save_directory if hasattr(self, "_save_directory") else None
repo_id = self._repo_id if hasattr(self, "_repo_id") else None
Expand Down Expand Up @@ -97,8 +104,9 @@ def _del_attrs(self, attrs):
delattr(self, f"_{attr}")

@wraps(PyTorchModelHubMixin.save_pretrained)
def save_pretrained(self, save_directory: Union[str, Path], *args, **kwargs) -> Optional[str]:

def save_pretrained(
self, save_directory: Union[str, Path], *args, **kwargs
) -> Optional[str]:
# set additional attributes to be used in generate_model_card
self._save_directory = save_directory
self._set_attrs_from_kwargs(["metrics", "dataset"], kwargs)
Expand Down Expand Up @@ -132,7 +140,9 @@ def config(self):
@wraps(PyTorchModelHubMixin.from_pretrained)
def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):
config_path = hf_hub_download(
pretrained_model_name_or_path, filename="config.json", revision=kwargs.get("revision", None)
pretrained_model_name_or_path,
filename="config.json",
revision=kwargs.get("revision", None),
)
with open(config_path, "r") as f:
config = json.load(f)
Expand Down
1 change: 0 additions & 1 deletion segmentation_models_pytorch/base/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

def initialize_decoder(module):
for m in module.modules():

if isinstance(m, nn.Conv2d):
nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
if m.bias is not None:
Expand Down
Loading
Loading