diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index edf1eaff..32b2b8b4 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -17,7 +17,7 @@ def __init__(self, model, schedule="linear", **kwargs): def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): + if attr.device != torch.device("cuda") and torch.cuda.is_available(): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 78eeb100..102957df 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -17,7 +17,7 @@ def __init__(self, model, schedule="linear", **kwargs): def register_buffer(self, name, attr): if type(attr) == torch.Tensor: - if attr.device != torch.device("cuda"): + if attr.device != torch.device("cuda") and torch.cuda.is_available(): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index aa3031df..6c7e0644 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -35,7 +35,7 @@ def forward(self, batch, key=None): class TransformerEmbedder(AbstractEncoder): """Some transformer encoder layers""" - def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): + def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda" if torch.cuda.is_available() else "cpu"): super().__init__() self.device = device self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, @@ -52,7 +52,7 @@ def encode(self, x): class BERTTokenizer(AbstractEncoder): """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" - def __init__(self, device="cuda", vq_interface=True, max_length=77): + def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu", vq_interface=True, max_length=77): super().__init__() from transformers import BertTokenizerFast # TODO: add to reuquirements self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") @@ -80,7 +80,7 @@ def decode(self, text): class BERTEmbedder(AbstractEncoder): """Uses the BERT tokenizr model and add some transformer encoder layers""" def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, - device="cuda",use_tokenizer=True, embedding_dropout=0.0): + device="cuda" if torch.cuda.is_available() else "cpu", use_tokenizer=True, embedding_dropout=0.0): super().__init__() self.use_tknz_fn = use_tokenizer if self.use_tknz_fn: @@ -139,7 +139,7 @@ class FrozenCLIPTextEmbedder(nn.Module): """ Uses the CLIP transformer encoder for text. """ - def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): + def __init__(self, version='ViT-L/14', device="cuda" if torch.cuda.is_available() else "cpu", max_length=77, n_repeat=1, normalize=True): super().__init__() self.model, _ = clip.load(version, jit=False, device="cpu") self.device = device diff --git a/notebook_helpers.py b/notebook_helpers.py index 5d0ebd7e..9e3dec35 100644 --- a/notebook_helpers.py +++ b/notebook_helpers.py @@ -44,7 +44,8 @@ def load_model_from_config(config, ckpt): sd = pl_sd["state_dict"] model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) - model.cuda() + if torch.cuda.is_available(): + model.cuda() model.eval() return {"model": model}, global_step @@ -117,7 +118,8 @@ def get_cond(mode, selected_path): c = rearrange(c, '1 c h w -> 1 h w c') c = 2. * c - 1. - c = c.to(torch.device("cuda")) + if torch.cuda.is_available(): + c = c.to(torch.device("cuda")) example["LR_image"] = c example["image"] = c_up diff --git a/scripts/knn2img.py b/scripts/knn2img.py index e6eaaeca..a40ce710 100644 --- a/scripts/knn2img.py +++ b/scripts/knn2img.py @@ -53,7 +53,8 @@ def load_model_from_config(config, ckpt, verbose=False): print("unexpected keys:") print(u) - model.cuda() + if torch.cuda.is_available(): + model.cuda() model.eval() return model @@ -358,7 +359,10 @@ def __call__(self, x, n): uc = None if searcher is not None: nn_dict = searcher(c, opt.knn) - c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1) + nn_embeddings = torch.from_numpy(nn_dict['nn_embeddings']) + if torch.cuda.is_available(): + nn_embeddings = nn_embeddings.cuda() + c = torch.cat([c, nn_embeddings], dim=1) if opt.scale != 1.0: uc = torch.zeros_like(c) if isinstance(prompts, tuple): diff --git a/scripts/sample_diffusion.py b/scripts/sample_diffusion.py index 876fe3c3..3a82185f 100644 --- a/scripts/sample_diffusion.py +++ b/scripts/sample_diffusion.py @@ -220,7 +220,8 @@ def get_parser(): def load_model_from_config(config, sd): model = instantiate_from_config(config) model.load_state_dict(sd,strict=False) - model.cuda() + if torch.cuda.is_available(): + model.cuda() model.eval() return model diff --git a/scripts/txt2img.py b/scripts/txt2img.py index 613de5e1..ed0c8cd5 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -25,7 +25,8 @@ def load_model_from_config(config, ckpt, verbose=False): print("unexpected keys:") print(u) - model.cuda() + if torch.cuda.is_available(): + model.cuda() model.eval() return model