diff --git a/.gitignore b/.gitignore index 0c70c5e9..7c5bbca8 100644 --- a/.gitignore +++ b/.gitignore @@ -3,5 +3,6 @@ .pt13 *.egg-info build +/dist /outputs -/checkpoints \ No newline at end of file +/checkpoints diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..3b790a8a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,34 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "sgm" +dynamic = ["version"] +description = "Stability Generative Models" +readme = "README.md" +license-files = { paths = ["LICENSE"] } +requires-python = ">=3.8" + +[project.urls] +Homepage = "https://github.com/Stability-AI/generative-models" + +[tool.hatch.version] +path = "sgm/__init__.py" + +[tool.hatch.build] +# This needs to be explicitly set so the configuration files +# grafted into the `sgm` directory get included in the wheel's +# RECORD file. +include = [ + "sgm", +] +# The force-include configurations below make Hatch copy +# the configs/ directory (containing the various YAML files required +# to generatively model) into the source distribution and the wheel. + +[tool.hatch.build.targets.sdist.force-include] +"./configs" = "sgm/configs" + +[tool.hatch.build.targets.wheel.force-include] +"./configs" = "sgm/configs" diff --git a/setup.py b/setup.py deleted file mode 100644 index 3117b885..00000000 --- a/setup.py +++ /dev/null @@ -1,13 +0,0 @@ -from setuptools import find_packages, setup - -setup( - name="sgm", - version="0.0.1", - packages=find_packages(), - python_requires=">=3.8", - py_modules=["sgm"], - description="Stability Generative Models", - long_description=open("README.md", "r", encoding="utf-8").read(), - long_description_content_type="text/markdown", - url="https://github.com/Stability-AI/generative-models", -) diff --git a/sgm/__init__.py b/sgm/__init__.py index cc9c7dc5..f639416e 100644 --- a/sgm/__init__.py +++ b/sgm/__init__.py @@ -1,3 +1,5 @@ from .data import StableDataModuleFromConfig from .models import AutoencodingEngine, DiffusionEngine -from .util import instantiate_from_config +from .util import instantiate_from_config, get_configs_path + +__version__ = "0.0.1" diff --git a/sgm/util.py b/sgm/util.py index 06f48a88..97713bf0 100644 --- a/sgm/util.py +++ b/sgm/util.py @@ -229,3 +229,21 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True): model.eval() return model + + +def get_configs_path() -> str: + """ + Get the `configs` directory. + For a working copy, this is the one in the root of the repository, + but for an installed copy, it's in the `sgm` package (see pyproject.toml). + """ + this_dir = os.path.dirname(__file__) + candidates = ( + os.path.join(this_dir, "configs"), + os.path.join(this_dir, "..", "configs"), + ) + for candidate in candidates: + candidate = os.path.abspath(candidate) + if os.path.isdir(candidate): + return candidate + raise FileNotFoundError(f"Could not find SGM configs in {candidates}")