-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate.py
107 lines (95 loc) · 3.15 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import argparse
import os
from collections import defaultdict
import torch
from erasing.utils.utils import *
# Class names used to build the prompts "an image of a {cls}" in main().
# NOTE(review): these appear to be the ten Imagenette classes — confirm.
class_list = [
    "cassette player", "chain saw", "church", "gas pump", "tench",
    "garbage truck", "English springer", "golf ball", "parachute",
    "French horn",
]

# How many images main() samples for each class name.
N_IMGS = 1
def main(args):
    """Generate images for every class in ``class_list`` and save them as JPEGs.

    For each class name, ``N_IMGS`` images are sampled from the prompt
    ``"an image of a {cls}"`` with ``img_size=512``, ``n_steps=50`` and
    ``guidance_scale=7.5`` on a DDIM-scheduled ``StableDiffuser``. Images come
    from base Stable Diffusion, or -- when ``args.finetuner`` is set -- from
    the ESD fine-tuned weights loaded out of ``args.model_dir``.

    Files are written as ``{class}_{idx}.jpg`` (class lowercased, spaces
    removed) under ``{output_dir}/sd/`` or
    ``{output_dir}/finetuned_{erase_concept}/``.

    Args:
        args: Parsed CLI namespace. Reads ``erase_concept``, ``train_method``,
            ``epochs``, ``model_dir``, ``output_dir`` and ``finetuner``.
    """
    from contextlib import nullcontext

    diffuser = StableDiffuser(scheduler="DDIM").to("cuda")
    seed = 1234

    if args.finetuner:
        # The checkpoint is only needed for the fine-tuned run. Loading it
        # unconditionally made plain-SD generation fail when it was absent.
        concept_slug = args.erase_concept.lower().replace(" ", "")
        esd_path = (
            f"{args.model_dir}/esd-{concept_slug}_from_{concept_slug}"
            f"-{args.train_method}_1-epochs_{args.epochs}.pt"
        )
        finetuner = FineTunedModel(diffuser, train_method=args.train_method)
        finetuner.load_state_dict(torch.load(esd_path))
        # Sampling inside the FineTunedModel context is what selects the
        # fine-tuned weights (presumably swapped into `diffuser` on enter).
        model_context = finetuner
        subdir = f"finetuned_{args.erase_concept}"
    else:
        model_context = nullcontext()
        subdir = "sd"

    generated_images = defaultdict(list)
    with model_context:
        for cls in class_list:
            for i in range(N_IMGS):
                images = diffuser(
                    f"an image of a {cls}",
                    img_size=512,
                    n_steps=50,
                    n_imgs=1,
                    # Offset the seed per sample: a shared seed reuses the
                    # same initial noise for every sample when N_IMGS > 1.
                    # i == 0 keeps the original seed, so N_IMGS == 1 output
                    # is unchanged.
                    generator=torch.Generator().manual_seed(seed + i),
                    guidance_scale=7.5,
                )
                generated_images[cls] += [image[0] for image in images]

    out_dir = f"{args.output_dir}/{subdir}"
    os.makedirs(out_dir, exist_ok=True)
    for cls, imgs in generated_images.items():
        for idx, img in enumerate(imgs):
            img.save(f"{out_dir}/{cls.lower().replace(' ', '')}_{idx}.jpg")
def _build_parser():
    """Return the argument parser for this image-generation script."""
    # prog/description previously said "TrainESD"/"Finetuning ..." (copied
    # from the training script); this script only generates images.
    parser = argparse.ArgumentParser(
        prog="GenerateESD",
        description="Generate images with Stable Diffusion or an ESD fine-tuned model",
    )
    parser.add_argument(
        "--erase_concept", help="concept to erase", type=str, required=True
    )
    parser.add_argument(
        "--train_method",
        help="Type of method (xattn, noxattn, full, xattn-strict)",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--output_dir",
        help="Path to directory to store images",
        type=str,
        default="output",
    )
    parser.add_argument(
        "--finetuner",
        help="Whether to use stable diffusion or finetuned model",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--model_dir",
        help="Path to directory containing models",
        type=str,
        default="models",
    )
    parser.add_argument(
        "--epochs",
        help="Number of epochs the trained model was trained for",
        type=int,
        default=1000,
    )
    return parser


if __name__ == "__main__":
    main(_build_parser().parse_args())