Revert "Replace most print()s with logging calls (#42)" (#65)
This reverts commit 6f6d3f8.
Jonas Müller authored Jul 26, 2023
1 parent 7934245 commit 4a3f0f5
Showing 10 changed files with 91 additions and 117 deletions.
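For context on what the revert touches: commit #42 had swapped bare print() calls for per-module loggers, and this commit restores the print() style shown in the hunks below. A minimal sketch of the two patterns side by side (the helper names are made up for illustration and do not appear in the repo):

import logging

logger = logging.getLogger(__name__)  # module-level logger pattern removed by this revert


def report_step(step: int, lr_mult: float) -> None:
    # logging-based style from #42: level-aware, can be filtered or redirected via handlers
    logger.info("current step: %d, recent lr-multiplier: %s", step, lr_mult)


def report_step_print(step: int, lr_mult: float) -> None:
    # print-based style restored here: always writes straight to stdout
    print(f"current step: {step}, recent lr-multiplier: {lr_mult}")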
26 changes: 13 additions & 13 deletions sgm/data/dataset.py
@@ -1,22 +1,20 @@
-import logging
 from typing import Optional
 
 import torchdata.datapipes.iter
 import webdataset as wds
 from omegaconf import DictConfig
 from pytorch_lightning import LightningDataModule
 
-logger = logging.getLogger(__name__)
-
 try:
     from sdata import create_dataset, create_dummy_dataset, create_loader
 except ImportError as e:
-    raise NotImplementedError(
-        "Datasets not yet available. "
-        "To enable, we need to add stable-datasets as a submodule; "
-        "please use ``git submodule update --init --recursive`` "
-        "and do ``pip install -e stable-datasets/`` from the root of this repo"
-    ) from e
+    print("#" * 100)
+    print("Datasets not yet available")
+    print("to enable, we need to add stable-datasets as a submodule")
+    print("please use ``git submodule update --init --recursive``")
+    print("and do ``pip install -e stable-datasets/`` from the root of this repo")
+    print("#" * 100)
+    exit(1)
 
 
 class StableDataModuleFromConfig(LightningDataModule):
@@ -41,8 +39,8 @@ def __init__(
                     "datapipeline" in self.val_config and "loader" in self.val_config
                 ), "validation config requires the fields `datapipeline` and `loader`"
             else:
-                logger.warning(
-                    "No Validation datapipeline defined, using that one from training"
+                print(
+                    "Warning: No Validation datapipeline defined, using that one from training"
                 )
                 self.val_config = train
 
@@ -54,10 +52,12 @@ def __init__(
 
         self.dummy = dummy
         if self.dummy:
-            logger.warning("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
+            print("#" * 100)
+            print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
+            print("#" * 100)
 
     def setup(self, stage: str) -> None:
-        logger.debug("Preparing datasets")
+        print("Preparing datasets")
         if self.dummy:
             data_fn = create_dummy_dataset
         else:
31 changes: 15 additions & 16 deletions sgm/lr_scheduler.py
@@ -1,9 +1,5 @@
-import logging
-
 import numpy as np
 
-logger = logging.getLogger(__name__)
-
 
 class LambdaWarmUpCosineScheduler:
     """
@@ -28,8 +24,9 @@ def __init__(
         self.verbosity_interval = verbosity_interval
 
     def schedule(self, n, **kwargs):
-        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
-            logger.info(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
         if n < self.lr_warm_up_steps:
             lr = (
                 self.lr_max - self.lr_start
@@ -86,11 +83,12 @@ def find_in_interval(self, n):
     def schedule(self, n, **kwargs):
         cycle = self.find_in_interval(n)
         n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
-            logger.info(
-                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                f"current cycle {cycle}"
-            )
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(
+                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                    f"current cycle {cycle}"
+                )
         if n < self.lr_warm_up_steps[cycle]:
             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                 cycle
@@ -116,11 +114,12 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
     def schedule(self, n, **kwargs):
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
-            logger.info(
-                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                f"current cycle {cycle}"
-            )
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(
+                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                    f"current cycle {cycle}"
+                )
 
         if n < self.lr_warm_up_steps[cycle]:
             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
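These scheduler classes only return a learning-rate multiplier for a given step; training wires them into torch's LambdaLR via lr_lambda, as the configure_optimizers hunk in sgm/models/diffusion.py further down shows. A rough, self-contained sketch of that hookup, using a hand-rolled multiplier in place of the sgm classes (illustrative only):

import torch
from torch.optim.lr_scheduler import LambdaLR


def warmup_multiplier(step: int, warmup_steps: int = 1000) -> float:
    # linear warm-up to 1.0, mirroring the warm-up branch of the schedulers above
    return min(1.0, (step + 1) / warmup_steps)


model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)  # base lr is scaled by the multiplier
scheduler = LambdaLR(opt, lr_lambda=warmup_multiplier)

for step in range(3):
    opt.step()
    scheduler.step()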
19 changes: 8 additions & 11 deletions sgm/models/autoencoder.py
@@ -1,4 +1,3 @@
-import logging
 import re
 from abc import abstractmethod
 from contextlib import contextmanager
@@ -15,8 +14,6 @@
 from ..modules.ema import LitEma
 from ..util import default, get_obj_from_str, instantiate_from_config
 
-logger = logging.getLogger(__name__)
-
 
 class AbstractAutoencoder(pl.LightningModule):
     """
@@ -41,7 +38,7 @@ def __init__(
 
         if self.use_ema:
             self.model_ema = LitEma(self, decay=ema_decay)
-            logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@@ -63,16 +60,16 @@ def init_from_ckpt(
         for k in keys:
             for ik in ignore_keys:
                 if re.match(ik, k):
-                    logger.debug(f"Deleting key {k} from state_dict.")
+                    print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
         missing, unexpected = self.load_state_dict(sd, strict=False)
-        logger.debug(
+        print(
             f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
         )
         if len(missing) > 0:
-            logger.info(f"Missing Keys: {missing}")
+            print(f"Missing Keys: {missing}")
         if len(unexpected) > 0:
-            logger.info(f"Unexpected Keys: {unexpected}")
+            print(f"Unexpected Keys: {unexpected}")
 
     @abstractmethod
     def get_input(self, batch) -> Any:
@@ -89,14 +86,14 @@ def ema_scope(self, context=None):
             self.model_ema.store(self.parameters())
             self.model_ema.copy_to(self)
             if context is not None:
-                logger.info(f"{context}: Switched to EMA weights")
+                print(f"{context}: Switched to EMA weights")
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.parameters())
                 if context is not None:
-                    logger.info(f"{context}: Restored training weights")
+                    print(f"{context}: Restored training weights")
 
     @abstractmethod
     def encode(self, *args, **kwargs) -> torch.Tensor:
@@ -107,7 +104,7 @@ def decode(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("decode()-method of abstract base class called")
 
     def instantiate_optimizer_from_config(self, params, lr, cfg):
-        logger.debug(f"loading >>> {cfg['target']} <<< optimizer from config")
+        print(f"loading >>> {cfg['target']} <<< optimizer from config")
         return get_obj_from_str(cfg["target"])(
             params, lr=lr, **cfg.get("params", dict())
         )
17 changes: 7 additions & 10 deletions sgm/models/diffusion.py
@@ -1,4 +1,3 @@
-import logging
 from contextlib import contextmanager
 from typing import Any, Dict, List, Tuple, Union
 
@@ -19,8 +18,6 @@
     log_txt_as_img,
 )
 
-logger = logging.getLogger(__name__)
-
 
 class DiffusionEngine(pl.LightningModule):
     def __init__(
@@ -76,7 +73,7 @@ def __init__(
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self.model, decay=ema_decay_rate)
-            logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         self.scale_factor = scale_factor
         self.disable_first_stage_autocast = disable_first_stage_autocast
@@ -97,13 +94,13 @@ def init_from_ckpt(
             raise NotImplementedError
 
         missing, unexpected = self.load_state_dict(sd, strict=False)
-        logger.info(
+        print(
             f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
         )
         if len(missing) > 0:
-            logger.info(f"Missing Keys: {missing}")
+            print(f"Missing Keys: {missing}")
         if len(unexpected) > 0:
-            logger.info(f"Unexpected Keys: {unexpected}")
+            print(f"Unexpected Keys: {unexpected}")
 
     def _init_first_stage(self, config):
         model = instantiate_from_config(config).eval()
@@ -182,14 +179,14 @@ def ema_scope(self, context=None):
             self.model_ema.store(self.model.parameters())
             self.model_ema.copy_to(self.model)
             if context is not None:
-                logger.info(f"{context}: Switched to EMA weights")
+                print(f"{context}: Switched to EMA weights")
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.model.parameters())
                 if context is not None:
-                    logger.info(f"{context}: Restored training weights")
+                    print(f"{context}: Restored training weights")
 
     def instantiate_optimizer_from_config(self, params, lr, cfg):
         return get_obj_from_str(cfg["target"])(
@@ -205,7 +202,7 @@ def configure_optimizers(self):
         opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
         if self.scheduler_config is not None:
             scheduler = instantiate_from_config(self.scheduler_config)
-            logger.debug("Setting up LambdaLR scheduler...")
+            print("Setting up LambdaLR scheduler...")
             scheduler = [
                 {
                     "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
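The ema_scope context manager shown in both sgm/models/autoencoder.py and sgm/models/diffusion.py follows the same store / copy_to / restore pattern around a LitEma instance. A stripped-down, standalone sketch of that pattern (illustrative; the real methods also check use_ema and live on the Lightning modules above):

from contextlib import contextmanager


@contextmanager
def ema_scope(module, ema, context=None):
    # temporarily swap in EMA weights, then restore the training weights
    ema.store(module.parameters())
    ema.copy_to(module)
    if context is not None:
        print(f"{context}: Switched to EMA weights")
    try:
        yield None
    finally:
        ema.restore(module.parameters())
        if context is not None:
            print(f"{context}: Restored training weights")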
38 changes: 17 additions & 21 deletions sgm/modules/attention.py
@@ -1,4 +1,3 @@
-import logging
 import math
 from inspect import isfunction
 from typing import Any, Optional
@@ -9,10 +8,6 @@
 from packaging import version
 from torch import nn
 
-
-logger = logging.getLogger(__name__)
-
-
 if version.parse(torch.__version__) >= version.parse("2.0.0"):
     SDP_IS_AVAILABLE = True
     from torch.backends.cuda import SDPBackend, sdp_kernel
@@ -41,9 +36,9 @@
     SDP_IS_AVAILABLE = False
     sdp_kernel = nullcontext
     BACKEND_MAP = {}
-    logger.warning(
-        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. "
-        f"In fact, you are using PyTorch {torch.__version__}. You might want to consider upgrading."
+    print(
+        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
+        f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
     )
 
 try:
@@ -53,7 +48,7 @@
     XFORMERS_IS_AVAILABLE = True
 except:
     XFORMERS_IS_AVAILABLE = False
-    logger.debug("no module 'xformers'. Processing without...")
+    print("no module 'xformers'. Processing without...")
 
 from .diffusionmodules.util import checkpoint
 
@@ -294,7 +289,7 @@ def __init__(
         self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
     ):
         super().__init__()
-        logger.info(
+        print(
             f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
             f"{heads} heads with a dimension of {dim_head}."
         )
@@ -398,21 +393,22 @@ def __init__(
         super().__init__()
         assert attn_mode in self.ATTENTION_MODES
         if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
-            logger.warning(
+            print(
                 f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
                 f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
             )
             attn_mode = "softmax"
         elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
-            logger.warning(
+            print(
                 "We do not support vanilla attention anymore, as it is too expensive. Sorry."
             )
             if not XFORMERS_IS_AVAILABLE:
-                raise NotImplementedError(
-                    "Please install xformers via e.g. 'pip install xformers==0.0.16'"
-                )
-            logger.info("Falling back to xformers efficient attention.")
-            attn_mode = "softmax-xformers"
+                assert (
+                    False
+                ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+            else:
+                print("Falling back to xformers efficient attention.")
+                attn_mode = "softmax-xformers"
         attn_cls = self.ATTENTION_MODES[attn_mode]
         if version.parse(torch.__version__) >= version.parse("2.0.0"):
             assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
@@ -441,7 +437,7 @@ def __init__(
         self.norm3 = nn.LayerNorm(dim)
         self.checkpoint = checkpoint
         if self.checkpoint:
-            logger.info(f"{self.__class__.__name__} is using checkpointing")
+            print(f"{self.__class__.__name__} is using checkpointing")
 
     def forward(
         self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
@@ -558,7 +554,7 @@ def __init__(
         sdp_backend=None,
     ):
         super().__init__()
-        logger.debug(
+        print(
             f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
         )
         from omegaconf import ListConfig
@@ -567,8 +563,8 @@ def __init__(
             context_dim = [context_dim]
         if exists(context_dim) and isinstance(context_dim, list):
             if depth != len(context_dim):
-                logger.warning(
-                    f"{self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
+                print(
+                    f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
                     f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
                 )
                 # depth does not match context dims.
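The BasicTransformerBlock hunk above also documents the attention-backend fallback: the requested attn_mode is downgraded based on the XFORMERS_IS_AVAILABLE and SDP_IS_AVAILABLE flags set at import time. The decision logic reduces to roughly this (simplified sketch with a hypothetical helper name, not the exact code):

def resolve_attn_mode(attn_mode: str, xformers_available: bool, sdp_available: bool) -> str:
    # non-softmax modes need xformers; otherwise fall back to native softmax attention
    if attn_mode != "softmax" and not xformers_available:
        return "softmax"
    # vanilla softmax without PyTorch 2.0 SDP support is not kept; xformers is required instead
    if attn_mode == "softmax" and not sdp_available:
        assert xformers_available, "Please install xformers via e.g. 'pip install xformers==0.0.16'"
        return "softmax-xformers"
    return attn_mode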
6 changes: 1 addition & 5 deletions sgm/modules/autoencoding/losses/__init__.py
@@ -1,4 +1,3 @@
-import logging
 from typing import Any, Union
 
 import torch
@@ -11,9 +10,6 @@
 from ....util import default, instantiate_from_config
 
 
-logger = logging.getLogger(__name__)
-
-
 def adopt_weight(weight, global_step, threshold=0, value=0.0):
     if global_step < threshold:
         weight = value
@@ -108,7 +104,7 @@ def __init__(
         super().__init__()
         self.dims = dims
         if self.dims > 2:
-            logger.info(
+            print(
                 f"running with dims={dims}. This means that for perceptual loss calculation, "
                 f"the LPIPS loss will be applied to each frame independently. "
             )