bark_perform.py
import argparse
import random

from rich import print

from bark_infinity import api
from bark_infinity import config
from bark_infinity import generation
from bark_infinity import text_processing

logger = config.logger

# Fallback prompts, used when neither --text_prompt nor --prompt_file is given.
text_prompts_in_this_file = []
# The first fallback prompt includes the current date and time, so the
# generated audio confirms both that Bark is working and when it was run.
try:
    text_prompts_in_this_file.append(
        f"It's {text_processing.current_date_time_in_words()}. And if you're hearing this, Bark is working, but you didn't provide any text."
    )
except Exception as e:
    print(f"An error occurred while building the dated test prompt: {e}")

text_prompt = """
In the beginning the Universe was created. This has made a lot of people very angry and been widely regarded as a bad move. However, Bark is working.
"""
text_prompts_in_this_file.append(text_prompt)
text_prompt = """
A common mistake that people make when trying to design something completely foolproof is to underestimate the ingenuity of complete fools.
"""
text_prompts_in_this_file.append(text_prompt)

def get_group_args(group_name, updated_args):
    """Return only the arguments that belong to the given group in config.DEFAULTS."""
    # Convert the Namespace object to a dictionary so it can be filtered.
    updated_args_dict = vars(updated_args)
    group_args = {}
    for key, value in updated_args_dict.items():
        if key in dict(config.DEFAULTS[group_name]):
            group_args[key] = value
    return group_args
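
# Example usage (illustrative; assumes config.DEFAULTS has a "model" group):
#   model_args = get_group_args("model", namespace_args)
# This would pull just the model-related settings out of the full namespace.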

def main(args):
    if args.loglevel is not None:
        logger.setLevel(args.loglevel)

    # Only override the generation module's globals when the user passed an
    # explicit value on the command line.
    if args.OFFLOAD_CPU is not None:
        generation.OFFLOAD_CPU = args.OFFLOAD_CPU
    elif generation.get_SUNO_USE_DIRECTML() is not True:
        generation.OFFLOAD_CPU = True  # default on, just in case

    if args.USE_SMALL_MODELS is not None:
        generation.USE_SMALL_MODELS = args.USE_SMALL_MODELS
    if args.GLOBAL_ENABLE_MPS is not None:
        generation.GLOBAL_ENABLE_MPS = args.GLOBAL_ENABLE_MPS
    # Optional diagnostic reports, skipped entirely in --silent mode.
    if not args.silent:
        if args.detailed_gpu_report or args.show_all_reports:
            print(api.startup_status_report(quick=False))
        elif not args.text_prompt and not args.prompt_file:
            # Probably a test run with no input, so default to the quick report.
            print(api.startup_status_report(quick=True))
        if args.detailed_hugging_face_cache_report or args.show_all_reports:
            print(api.hugging_face_cache_report())
        if args.detailed_cuda_report or args.show_all_reports:
            print(api.cuda_status_report())
        if args.detailed_numpy_report:
            print(api.numpy_report())
        if args.run_numpy_benchmark or args.show_all_reports:
            from bark_infinity.debug import numpy_benchmark

            numpy_benchmark()
    # Utility modes that short-circuit normal generation.
    if args.list_speakers:
        api.list_speakers()
        return

    if args.render_npz_samples:
        api.render_npz_samples()
        return
    # Decide what to generate: a single prompt, a file split into prompts, or
    # two of the built-in test prompts.
    if args.text_prompt:
        text_prompts_to_process = [args.text_prompt]
    elif args.prompt_file:
        text_file = text_processing.load_text(args.prompt_file)
        if text_file is None:
            logger.error(f"Error loading file: {args.prompt_file}")
            return
        # Reuse the text we already loaded rather than reading the file twice.
        text_prompts_to_process = text_processing.split_text(
            text_file,
            args.split_input_into_separate_prompts_by,
            args.split_input_into_separate_prompts_by_value,
        )
        print(f"\nProcessing file: {args.prompt_file}")
        print(f"  Looks like: {len(text_prompts_to_process)} prompt(s)")
    else:
        print("No --text_prompt or --prompt_file specified, using test prompts.")
        text_prompts_to_process = random.sample(text_prompts_in_this_file, 2)
    # Each prompt is generated --output_iterations times, so warn before
    # kicking off a large batch.
    total_generations = len(text_prompts_to_process) * args.output_iterations
    if total_generations > 10 and not args.dry_run:
        print(
            f"WARNING: You are about to generate {total_generations} outputs. Consider using '--dry-run' to test things first."
        )
print("Loading Bark models...")
if not args.dry_run and generation.get_SUNO_USE_DIRECTML() is not True:
generation.preload_models(
args.text_use_gpu,
args.text_use_small,
args.coarse_use_gpu,
args.coarse_use_small,
args.fine_use_gpu,
args.fine_use_small,
args.codec_use_gpu,
args.force_reload,
)
print("Done.")
    for idx, text_prompt in enumerate(text_prompts_to_process, start=1):
        if len(text_prompts_to_process) > 1:
            print(f"\nPrompt {idx}/{len(text_prompts_to_process)}:")
        for iteration in range(1, args.output_iterations + 1):
            if args.output_iterations > 1:
                print(f"\nIteration {iteration} of {args.output_iterations}.")
            if iteration == 1:
                print(f"Text prompt: {text_prompt}")
            args.current_iteration = iteration
            args.text_prompt = text_prompt
            # Forward the full set of parsed options to the long-form generator.
            args_dict = vars(args)
            api.generate_audio_long(**args_dict)

if __name__ == "__main__":
    parser = config.create_argument_parser()
    args = parser.parse_args()
    updated_args = config.update_group_args_with_defaults(args)
    namespace_args = argparse.Namespace(**updated_args)
    main(namespace_args)
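
# Example invocations (a sketch; flag spellings other than --text_prompt,
# --prompt_file, and --dry-run are inferred from the attribute names that
# config.create_argument_parser() must define):
#   python bark_perform.py --text_prompt "Hello there." --output_iterations 2
#   python bark_perform.py --prompt_file story.txt
#   python bark_perform.py --list_speakers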