From ac60da8c50f30d703711e0f627354fe5331219db Mon Sep 17 00:00:00 2001 From: Benjamin Aubin <62965623+benjaminaubin@users.noreply.github.com> Date: Wed, 26 Jul 2023 12:09:28 +0200 Subject: [PATCH] Pre release changes for production (#59) * clean requirements * rm taming deps * isort, black * mv lipips, license * clean vq, fix path * fix loss path, gitignore * tested requirements pt13 * fix numpy req for python3.8, add tests * fix name * fix dep scipy 3.8 pt2 * add black test formatter --- .github/workflows/black.yml | 15 ++ .github/workflows/test-build.yaml | 26 ++++ .gitignore | 9 +- README.md | 23 ++- main.py | 11 +- requirements/pt13.txt | 40 +++++ requirements/pt2.txt | 39 +++++ requirements_pt13.txt | 41 ----- requirements_pt2.txt | 41 ----- scripts/demo/sampling.py | 1 + scripts/demo/streamlit_helpers.py | 19 ++- .../nsfw_and_watermark_dectection.py | 5 +- sgm/__init__.py | 5 +- sgm/data/cifar10.py | 4 +- sgm/data/mnist.py | 4 +- sgm/modules/autoencoding/losses/__init__.py | 6 +- sgm/modules/autoencoding/lpips/__init__.py | 0 .../autoencoding/lpips/loss/.gitignore | 1 + sgm/modules/autoencoding/lpips/loss/LICENSE | 23 +++ .../autoencoding/lpips/loss/__init__.py | 0 sgm/modules/autoencoding/lpips/loss/lpips.py | 147 ++++++++++++++++++ sgm/modules/autoencoding/lpips/model/LICENSE | 58 +++++++ .../autoencoding/lpips/model/__init__.py | 0 sgm/modules/autoencoding/lpips/model/model.py | 88 +++++++++++ sgm/modules/autoencoding/lpips/util.py | 128 +++++++++++++++ .../autoencoding/lpips/vqperceptual.py | 17 ++ sgm/modules/diffusionmodules/__init__.py | 2 +- sgm/modules/diffusionmodules/discretizer.py | 9 +- sgm/modules/diffusionmodules/loss.py | 2 +- sgm/modules/diffusionmodules/wrappers.py | 2 +- sgm/modules/distributions/distributions.py | 2 +- 31 files changed, 641 insertions(+), 127 deletions(-) create mode 100644 .github/workflows/black.yml create mode 100644 .github/workflows/test-build.yaml create mode 100644 requirements/pt13.txt create mode 100644 requirements/pt2.txt delete mode 100644 requirements_pt13.txt delete mode 100644 requirements_pt2.txt create mode 100644 sgm/modules/autoencoding/lpips/__init__.py create mode 100644 sgm/modules/autoencoding/lpips/loss/.gitignore create mode 100644 sgm/modules/autoencoding/lpips/loss/LICENSE create mode 100644 sgm/modules/autoencoding/lpips/loss/__init__.py create mode 100644 sgm/modules/autoencoding/lpips/loss/lpips.py create mode 100644 sgm/modules/autoencoding/lpips/model/LICENSE create mode 100644 sgm/modules/autoencoding/lpips/model/__init__.py create mode 100644 sgm/modules/autoencoding/lpips/model/model.py create mode 100644 sgm/modules/autoencoding/lpips/util.py create mode 100644 sgm/modules/autoencoding/lpips/vqperceptual.py diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml new file mode 100644 index 00000000..80823b44 --- /dev/null +++ b/.github/workflows/black.yml @@ -0,0 +1,15 @@ +name: Run black +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Install venv + run: | + sudo apt-get -y install python3.10-venv + - uses: psf/black@stable + with: + options: "--check --verbose -l88" + src: "./sgm ./scripts ./main.py" diff --git a/.github/workflows/test-build.yaml b/.github/workflows/test-build.yaml new file mode 100644 index 00000000..ffbeff46 --- /dev/null +++ b/.github/workflows/test-build.yaml @@ -0,0 +1,26 @@ +name: Build package + +on: + push: + pull_request: + +jobs: + build: + name: Build + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.8", "3.10"] + requirements-file: ["pt2", "pt13"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements/${{ matrix.requirements-file }}.txt + pip install . \ No newline at end of file diff --git a/.gitignore b/.gitignore index c0902eb8..5506c38d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,14 @@ +# extensions *.egg-info *.py[cod] + +# envs .pt13 .pt2 -.pt2_2 + +# directories /checkpoints /dist /outputs -build +/build +/src \ No newline at end of file diff --git a/README.md b/README.md index e309c39b..3ffd18ad 100644 --- a/README.md +++ b/README.md @@ -59,10 +59,9 @@ This is assuming you have navigated to the `generative-models` root after clonin ```shell # install required packages from pypi -python3 -m venv .pt1 -source .pt1/bin/activate -pip3 install wheel -pip3 install -r requirements_pt13.txt +python3 -m venv .pt13 +source .pt13/bin/activate +pip3 install -r requirements/pt13.txt ``` **PyTorch 2.0** @@ -72,8 +71,20 @@ pip3 install -r requirements_pt13.txt # install required packages from pypi python3 -m venv .pt2 source .pt2/bin/activate -pip3 install wheel -pip3 install -r requirements_pt2.txt +pip3 install -r requirements/pt2.txt +``` + + +#### 3. Install `sgm` + +```shell +pip3 install . +``` + +#### 4. Install `sdata` for training + +```shell +pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata ``` ## Packaging diff --git a/main.py b/main.py index c916f512..5e03c1c5 100644 --- a/main.py +++ b/main.py @@ -12,22 +12,18 @@ import torch import torchvision import wandb -from PIL import Image from matplotlib import pyplot as plt from natsort import natsorted from omegaconf import OmegaConf from packaging import version +from PIL import Image from pytorch_lightning import seed_everything from pytorch_lightning.callbacks import Callback from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.trainer import Trainer from pytorch_lightning.utilities import rank_zero_only -from sgm.util import ( - exists, - instantiate_from_config, - isheatmap, -) +from sgm.util import exists, instantiate_from_config, isheatmap MULTINODE_HACKS = True @@ -910,11 +906,12 @@ def divein(*args, **kwargs): trainer.test(model, data) except RuntimeError as err: if MULTINODE_HACKS: - import requests import datetime import os import socket + import requests + device = os.environ.get("CUDA_VISIBLE_DEVICES", "?") hostname = socket.gethostname() ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") diff --git a/requirements/pt13.txt b/requirements/pt13.txt new file mode 100644 index 00000000..b4f1272d --- /dev/null +++ b/requirements/pt13.txt @@ -0,0 +1,40 @@ +black==23.7.0 +chardet>=5.1.0 +clip @ git+https://github.com/openai/CLIP.git +einops>=0.6.1 +fairscale>=0.4.13 +fire>=0.5.0 +fsspec>=2023.6.0 +invisible-watermark>=0.2.0 +kornia==0.6.9 +matplotlib>=3.7.2 +natsort>=8.4.0 +numpy>=1.24.4 +omegaconf>=2.3.0 +onnx<=1.12.0 +open-clip-torch>=2.20.0 +opencv-python==4.6.0.66 +pandas>=2.0.3 +pillow>=9.5.0 +pudb>=2022.1.3 +pytorch-lightning==1.8.5 +pyyaml>=6.0.1 +scipy>=1.10.1 +streamlit>=1.25.0 +tensorboardx==2.5.1 +timm>=0.9.2 +tokenizers==0.12.1 +--extra-index-url https://download.pytorch.org/whl/cu117 +torch==1.13.1+cu117 +torchaudio==0.13.1 +torchdata==0.5.1 +torchmetrics>=1.0.1 +torchvision==0.14.1+cu117 +tqdm>=4.65.0 +transformers==4.19.1 +triton==2.0.0.post1 +urllib3<1.27,>=1.25.4 +wandb>=0.15.6 +webdataset>=0.2.33 +wheel>=0.41.0 +xformers==0.0.16 \ No newline at end of file diff --git a/requirements/pt2.txt b/requirements/pt2.txt new file mode 100644 index 00000000..003a5264 --- /dev/null +++ b/requirements/pt2.txt @@ -0,0 +1,39 @@ +black==23.7.0 +chardet==5.1.0 +clip @ git+https://github.com/openai/CLIP.git +einops>=0.6.1 +fairscale>=0.4.13 +fire>=0.5.0 +fsspec>=2023.6.0 +invisible-watermark>=0.2.0 +kornia==0.6.9 +matplotlib>=3.7.2 +natsort>=8.4.0 +ninja>=1.11.1 +numpy>=1.24.4 +omegaconf>=2.3.0 +open-clip-torch>=2.20.0 +opencv-python==4.6.0.66 +pandas>=2.0.3 +pillow>=9.5.0 +pudb>=2022.1.3 +pytorch-lightning==2.0.1 +pyyaml>=6.0.1 +scipy>=1.10.1 +streamlit>=0.73.1 +tensorboardx==2.6 +timm>=0.9.2 +tokenizers==0.12.1 +torch>=2.0.1 +torchaudio>=2.0.2 +torchdata==0.6.1 +torchmetrics>=1.0.1 +torchvision>=0.15.2 +tqdm>=4.65.0 +transformers==4.19.1 +triton==2.0.0 +urllib3<1.27,>=1.25.4 +wandb>=0.15.6 +webdataset>=0.2.33 +wheel>=0.41.0 +xformers>=0.0.20 diff --git a/requirements_pt13.txt b/requirements_pt13.txt deleted file mode 100644 index 3d5b117c..00000000 --- a/requirements_pt13.txt +++ /dev/null @@ -1,41 +0,0 @@ -omegaconf -einops -fire -tqdm -pillow -numpy -webdataset>=0.2.33 ---extra-index-url https://download.pytorch.org/whl/cu117 -torch==1.13.1+cu117 -xformers==0.0.16 -torchaudio==0.13.1 -torchvision==0.14.1+cu117 -torchmetrics -opencv-python==4.6.0.66 -fairscale -pytorch-lightning==1.8.5 -fsspec -kornia==0.6.9 -matplotlib -natsort -tensorboardx==2.5.1 -open-clip-torch -chardet -scipy -pandas -pudb -pyyaml -urllib3<1.27,>=1.25.4 -streamlit>=0.73.1 -timm -tokenizers==0.12.1 -torchdata==0.5.1 -transformers==4.19.1 -onnx<=1.12.0 -triton -wandb -invisible-watermark --e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers --e git+https://github.com/openai/CLIP.git@main#egg=clip --e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata --e . \ No newline at end of file diff --git a/requirements_pt2.txt b/requirements_pt2.txt deleted file mode 100644 index 9988b908..00000000 --- a/requirements_pt2.txt +++ /dev/null @@ -1,41 +0,0 @@ -omegaconf -einops -fire -tqdm -pillow -numpy -webdataset>=0.2.33 -ninja -torch -matplotlib -torchaudio>=2.0.2 -torchmetrics -torchvision>=0.15.2 -opencv-python==4.6.0.66 -fairscale -pytorch-lightning==2.0.1 -fire -fsspec -kornia==0.6.9 -natsort -open-clip-torch -chardet==5.1.0 -tensorboardx==2.6 -pandas -pudb -pyyaml -urllib3<1.27,>=1.25.4 -scipy -streamlit>=0.73.1 -timm -tokenizers==0.12.1 -transformers==4.19.1 -triton==2.0.0 -torchdata==0.6.1 -wandb -invisible-watermark -xformers --e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers --e git+https://github.com/openai/CLIP.git@main#egg=clip --e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata --e . \ No newline at end of file diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 98d0af30..87d80155 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -1,4 +1,5 @@ from pytorch_lightning import seed_everything + from scripts.demo.streamlit_helpers import * from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 8f53b5db..2cf165b6 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -1,29 +1,28 @@ +import math import os -from typing import Union, List +from typing import List, Union -import math import numpy as np import streamlit as st import torch -from PIL import Image from einops import rearrange, repeat from imwatermark import WatermarkEncoder -from omegaconf import OmegaConf, ListConfig +from omegaconf import ListConfig, OmegaConf +from PIL import Image +from safetensors.torch import load_file as load_safetensors from torch import autocast from torchvision import transforms from torchvision.utils import make_grid -from safetensors.torch import load_file as load_safetensors from sgm.modules.diffusionmodules.sampling import ( + DPMPP2MSampler, + DPMPP2SAncestralSampler, + EulerAncestralSampler, EulerEDMSampler, HeunEDMSampler, - EulerAncestralSampler, - DPMPP2SAncestralSampler, - DPMPP2MSampler, LinearMultistepSampler, ) -from sgm.util import append_dims -from sgm.util import instantiate_from_config +from sgm.util import append_dims, instantiate_from_config class WatermarkEmbedder: diff --git a/scripts/util/detection/nsfw_and_watermark_dectection.py b/scripts/util/detection/nsfw_and_watermark_dectection.py index af84acf3..ab450c13 100644 --- a/scripts/util/detection/nsfw_and_watermark_dectection.py +++ b/scripts/util/detection/nsfw_and_watermark_dectection.py @@ -1,9 +1,10 @@ import os -import torch + +import clip import numpy as np +import torch import torchvision.transforms as T from PIL import Image -import clip RESOURCES_ROOT = "scripts/util/detection/" diff --git a/sgm/__init__.py b/sgm/__init__.py index f639416e..24bc84af 100644 --- a/sgm/__init__.py +++ b/sgm/__init__.py @@ -1,5 +1,4 @@ -from .data import StableDataModuleFromConfig from .models import AutoencodingEngine, DiffusionEngine -from .util import instantiate_from_config, get_configs_path +from .util import get_configs_path, instantiate_from_config -__version__ = "0.0.1" +__version__ = "0.1.0" diff --git a/sgm/data/cifar10.py b/sgm/data/cifar10.py index aa3ae677..6083646f 100644 --- a/sgm/data/cifar10.py +++ b/sgm/data/cifar10.py @@ -1,7 +1,7 @@ -import torchvision import pytorch_lightning as pl -from torchvision import transforms +import torchvision from torch.utils.data import DataLoader, Dataset +from torchvision import transforms class CIFAR10DataDictWrapper(Dataset): diff --git a/sgm/data/mnist.py b/sgm/data/mnist.py index ab7478f4..dea4d7e6 100644 --- a/sgm/data/mnist.py +++ b/sgm/data/mnist.py @@ -1,7 +1,7 @@ -import torchvision import pytorch_lightning as pl -from torchvision import transforms +import torchvision from torch.utils.data import DataLoader, Dataset +from torchvision import transforms class MNISTDataDictWrapper(Dataset): diff --git a/sgm/modules/autoencoding/losses/__init__.py b/sgm/modules/autoencoding/losses/__init__.py index 6a3b54f7..cc1bdb49 100644 --- a/sgm/modules/autoencoding/losses/__init__.py +++ b/sgm/modules/autoencoding/losses/__init__.py @@ -3,11 +3,11 @@ import torch import torch.nn as nn from einops import rearrange -from taming.modules.discriminator.model import NLayerDiscriminator, weights_init -from taming.modules.losses.lpips import LPIPS -from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss from ....util import default, instantiate_from_config +from ..lpips.loss.lpips import LPIPS +from ..lpips.model.model import NLayerDiscriminator, weights_init +from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss def adopt_weight(weight, global_step, threshold=0, value=0.0): diff --git a/sgm/modules/autoencoding/lpips/__init__.py b/sgm/modules/autoencoding/lpips/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sgm/modules/autoencoding/lpips/loss/.gitignore b/sgm/modules/autoencoding/lpips/loss/.gitignore new file mode 100644 index 00000000..a92958a1 --- /dev/null +++ b/sgm/modules/autoencoding/lpips/loss/.gitignore @@ -0,0 +1 @@ +vgg.pth \ No newline at end of file diff --git a/sgm/modules/autoencoding/lpips/loss/LICENSE b/sgm/modules/autoencoding/lpips/loss/LICENSE new file mode 100644 index 00000000..924cfc85 --- /dev/null +++ b/sgm/modules/autoencoding/lpips/loss/LICENSE @@ -0,0 +1,23 @@ +Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/sgm/modules/autoencoding/lpips/loss/__init__.py b/sgm/modules/autoencoding/lpips/loss/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sgm/modules/autoencoding/lpips/loss/lpips.py b/sgm/modules/autoencoding/lpips/loss/lpips.py new file mode 100644 index 00000000..3e34f3d0 --- /dev/null +++ b/sgm/modules/autoencoding/lpips/loss/lpips.py @@ -0,0 +1,147 @@ +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + +from collections import namedtuple + +import torch +import torch.nn as nn +from torchvision import models + +from ..util import get_ckpt_path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, "sgm/modules/autoencoding/lpips/loss") + self.load_state_dict( + torch.load(ckpt, map_location=torch.device("cpu")), strict=False + ) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name) + model.load_state_dict( + torch.load(ckpt, map_location=torch.device("cpu")), strict=False + ) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor( + outs1[kk] + ) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [ + spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns)) + ] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer( + "shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None] + ) + self.register_buffer( + "scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None] + ) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """A single linear layer which does a 1x1 conv""" + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = ( + [ + nn.Dropout(), + ] + if (use_dropout) + else [] + ) + layers += [ + nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), + ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"] + ) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x / (norm_factor + eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/sgm/modules/autoencoding/lpips/model/LICENSE b/sgm/modules/autoencoding/lpips/model/LICENSE new file mode 100644 index 00000000..4b356e66 --- /dev/null +++ b/sgm/modules/autoencoding/lpips/model/LICENSE @@ -0,0 +1,58 @@ +Copyright (c) 2017, Jun-Yan Zhu and Taesung Park +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +--------------------------- LICENSE FOR pix2pix -------------------------------- +BSD License + +For pix2pix software +Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +----------------------------- LICENSE FOR DCGAN -------------------------------- +BSD License + +For dcgan.torch software + +Copyright (c) 2015, Facebook, Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/sgm/modules/autoencoding/lpips/model/__init__.py b/sgm/modules/autoencoding/lpips/model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sgm/modules/autoencoding/lpips/model/model.py b/sgm/modules/autoencoding/lpips/model/model.py new file mode 100644 index 00000000..66357d4e --- /dev/null +++ b/sgm/modules/autoencoding/lpips/model/model.py @@ -0,0 +1,88 @@ +import functools + +import torch.nn as nn + +from ..util import ActNorm + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find("BatchNorm") != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if ( + type(norm_layer) == functools.partial + ): # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [ + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True), + ] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2**n, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=2, + padding=padw, + bias=use_bias, + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + nf_mult_prev = nf_mult + nf_mult = min(2**n_layers, 8) + sequence += [ + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=1, + padding=padw, + bias=use_bias, + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) + ] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) diff --git a/sgm/modules/autoencoding/lpips/util.py b/sgm/modules/autoencoding/lpips/util.py new file mode 100644 index 00000000..49c76e37 --- /dev/null +++ b/sgm/modules/autoencoding/lpips/util.py @@ -0,0 +1,128 @@ +import hashlib +import os + +import requests +import torch +import torch.nn as nn +from tqdm import tqdm + +URL_MAP = {"vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"} + +CKPT_MAP = {"vgg_lpips": "vgg.pth"} + +MD5_MAP = {"vgg_lpips": "d507d7349b931f0638a25a48a722f98a"} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class ActNorm(nn.Module): + def __init__( + self, num_features, logdet=False, affine=True, allow_reverse_init=False + ): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:, :, None, None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height * width * torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." + ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:, :, None, None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h diff --git a/sgm/modules/autoencoding/lpips/vqperceptual.py b/sgm/modules/autoencoding/lpips/vqperceptual.py new file mode 100644 index 00000000..6195f0a6 --- /dev/null +++ b/sgm/modules/autoencoding/lpips/vqperceptual.py @@ -0,0 +1,17 @@ +import torch +import torch.nn.functional as F + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake)) + ) + return d_loss diff --git a/sgm/modules/diffusionmodules/__init__.py b/sgm/modules/diffusionmodules/__init__.py index ce7968af..867b69e8 100644 --- a/sgm/modules/diffusionmodules/__init__.py +++ b/sgm/modules/diffusionmodules/__init__.py @@ -1,7 +1,7 @@ from .denoiser import Denoiser from .discretizer import Discretization from .loss import StandardDiffusionLoss -from .model import Model, Encoder, Decoder +from .model import Decoder, Encoder, Model from .openaimodel import UNetModel from .sampling import BaseDiffusionSampler from .wrappers import OpenAIWrapper diff --git a/sgm/modules/diffusionmodules/discretizer.py b/sgm/modules/diffusionmodules/discretizer.py index 397b8f38..4135ac99 100644 --- a/sgm/modules/diffusionmodules/discretizer.py +++ b/sgm/modules/diffusionmodules/discretizer.py @@ -1,10 +1,11 @@ -import torch -import numpy as np -from functools import partial from abc import abstractmethod +from functools import partial + +import numpy as np +import torch -from ...util import append_zero from ...modules.diffusionmodules.util import make_beta_schedule +from ...util import append_zero def generate_roughly_equally_spaced_steps( diff --git a/sgm/modules/diffusionmodules/loss.py b/sgm/modules/diffusionmodules/loss.py index 555abc1c..508230c9 100644 --- a/sgm/modules/diffusionmodules/loss.py +++ b/sgm/modules/diffusionmodules/loss.py @@ -3,9 +3,9 @@ import torch import torch.nn as nn from omegaconf import ListConfig -from taming.modules.losses.lpips import LPIPS from ...util import append_dims, instantiate_from_config +from ...modules.autoencoding.lpips.loss.lpips import LPIPS class StandardDiffusionLoss(nn.Module): diff --git a/sgm/modules/diffusionmodules/wrappers.py b/sgm/modules/diffusionmodules/wrappers.py index 87ede606..37449ea6 100644 --- a/sgm/modules/diffusionmodules/wrappers.py +++ b/sgm/modules/diffusionmodules/wrappers.py @@ -30,5 +30,5 @@ def forward( timesteps=t, context=c.get("crossattn", None), y=c.get("vector", None), - **kwargs + **kwargs, ) diff --git a/sgm/modules/distributions/distributions.py b/sgm/modules/distributions/distributions.py index 0b61f030..016be355 100644 --- a/sgm/modules/distributions/distributions.py +++ b/sgm/modules/distributions/distributions.py @@ -1,5 +1,5 @@ -import torch import numpy as np +import torch class AbstractDistribution: