From 8429b5b7ca9d198072f3f7608a3b81f18fde5056 Mon Sep 17 00:00:00 2001 From: Koichi Saito <116609740+koichi-saito-sony@users.noreply.github.com> Date: Wed, 5 Jun 2024 11:35:17 -0400 Subject: [PATCH] Add files via upload --- tango_edm/audioldm/__init__.py | 8 + tango_edm/audioldm/__main__.py | 183 +++ tango_edm/audioldm/audio/__init__.py | 2 + .../audio/__pycache__/__init__.cpython-38.pyc | Bin 0 -> 314 bytes .../audio/__pycache__/__init__.cpython-39.pyc | Bin 0 -> 260 bytes .../audio_processing.cpython-38.pyc | Bin 0 -> 2786 bytes .../audio_processing.cpython-39.pyc | Bin 0 -> 2782 bytes .../audio/__pycache__/mix.cpython-39.pyc | Bin 0 -> 1704 bytes .../audio/__pycache__/stft.cpython-38.pyc | Bin 0 -> 5032 bytes .../audio/__pycache__/stft.cpython-39.pyc | Bin 0 -> 4985 bytes .../audio/__pycache__/tools.cpython-38.pyc | Bin 0 -> 2218 bytes .../audio/__pycache__/tools.cpython-39.pyc | Bin 0 -> 2191 bytes .../__pycache__/torch_tools.cpython-39.pyc | Bin 0 -> 3789 bytes tango_edm/audioldm/audio/audio_processing.py | 100 ++ tango_edm/audioldm/audio/stft.py | 186 +++ tango_edm/audioldm/audio/tools.py | 85 ++ tango_edm/audioldm/hifigan/__init__.py | 7 + .../__pycache__/__init__.cpython-38.pyc | Bin 0 -> 601 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 574 bytes .../hifigan/__pycache__/models.cpython-38.pyc | Bin 0 -> 3730 bytes .../hifigan/__pycache__/models.cpython-39.pyc | Bin 0 -> 3726 bytes .../__pycache__/utilities.cpython-38.pyc | Bin 0 -> 2259 bytes .../__pycache__/utilities.cpython-39.pyc | Bin 0 -> 2373 bytes tango_edm/audioldm/hifigan/models.py | 174 +++ tango_edm/audioldm/hifigan/utilities.py | 86 ++ .../audioldm/latent_diffusion/__init__.py | 0 .../__pycache__/__init__.cpython-38.pyc | Bin 0 -> 168 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 164 bytes .../__pycache__/attention.cpython-38.pyc | Bin 0 -> 11470 bytes .../__pycache__/attention.cpython-39.pyc | Bin 0 -> 11379 bytes .../__pycache__/ddim.cpython-38.pyc | Bin 0 -> 7197 bytes .../__pycache__/ddim.cpython-39.pyc | Bin 0 -> 7110 bytes .../__pycache__/ddpm.cpython-38.pyc | Bin 0 -> 11093 bytes .../__pycache__/ddpm.cpython-39.pyc | Bin 0 -> 11039 bytes .../__pycache__/ema.cpython-38.pyc | Bin 0 -> 3013 bytes .../__pycache__/ema.cpython-39.pyc | Bin 0 -> 3004 bytes .../__pycache__/openaimodel.cpython-39.pyc | Bin 0 -> 23673 bytes .../__pycache__/util.cpython-38.pyc | Bin 0 -> 9569 bytes .../__pycache__/util.cpython-39.pyc | Bin 0 -> 9604 bytes .../audioldm/latent_diffusion/attention.py | 469 ++++++++ tango_edm/audioldm/latent_diffusion/ddim.py | 377 ++++++ tango_edm/audioldm/latent_diffusion/ddpm.py | 441 +++++++ tango_edm/audioldm/latent_diffusion/ema.py | 82 ++ .../audioldm/latent_diffusion/openaimodel.py | 1069 +++++++++++++++++ tango_edm/audioldm/latent_diffusion/util.py | 295 +++++ tango_edm/audioldm/ldm.py | 819 +++++++++++++ tango_edm/audioldm/pipeline.py | 301 +++++ tango_edm/audioldm/utils.py | 281 +++++ .../variational_autoencoder/__init__.py | 1 + .../__pycache__/__init__.cpython-38.pyc | Bin 0 -> 267 bytes .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 220 bytes .../__pycache__/autoencoder.cpython-38.pyc | Bin 0 -> 4431 bytes .../__pycache__/autoencoder.cpython-39.pyc | Bin 0 -> 4372 bytes .../__pycache__/distributions.cpython-38.pyc | Bin 0 -> 3809 bytes .../__pycache__/distributions.cpython-39.pyc | Bin 0 -> 3780 bytes .../__pycache__/modules.cpython-38.pyc | Bin 0 -> 22361 bytes .../__pycache__/modules.cpython-39.pyc | Bin 0 -> 22086 bytes .../variational_autoencoder/autoencoder.py | 135 
+++ .../variational_autoencoder/distributions.py | 102 ++ .../variational_autoencoder/modules.py | 1066 ++++++++++++++++ 60 files changed, 6269 insertions(+) create mode 100644 tango_edm/audioldm/__init__.py create mode 100644 tango_edm/audioldm/__main__.py create mode 100644 tango_edm/audioldm/audio/__init__.py create mode 100644 tango_edm/audioldm/audio/__pycache__/__init__.cpython-38.pyc create mode 100644 tango_edm/audioldm/audio/__pycache__/__init__.cpython-39.pyc create mode 100644 tango_edm/audioldm/audio/__pycache__/audio_processing.cpython-38.pyc create mode 100644 tango_edm/audioldm/audio/__pycache__/audio_processing.cpython-39.pyc create mode 100644 tango_edm/audioldm/audio/__pycache__/mix.cpython-39.pyc create mode 100644 tango_edm/audioldm/audio/__pycache__/stft.cpython-38.pyc create mode 100644 tango_edm/audioldm/audio/__pycache__/stft.cpython-39.pyc create mode 100644 tango_edm/audioldm/audio/__pycache__/tools.cpython-38.pyc create mode 100644 tango_edm/audioldm/audio/__pycache__/tools.cpython-39.pyc create mode 100644 tango_edm/audioldm/audio/__pycache__/torch_tools.cpython-39.pyc create mode 100644 tango_edm/audioldm/audio/audio_processing.py create mode 100644 tango_edm/audioldm/audio/stft.py create mode 100644 tango_edm/audioldm/audio/tools.py create mode 100644 tango_edm/audioldm/hifigan/__init__.py create mode 100644 tango_edm/audioldm/hifigan/__pycache__/__init__.cpython-38.pyc create mode 100644 tango_edm/audioldm/hifigan/__pycache__/__init__.cpython-39.pyc create mode 100644 tango_edm/audioldm/hifigan/__pycache__/models.cpython-38.pyc create mode 100644 tango_edm/audioldm/hifigan/__pycache__/models.cpython-39.pyc create mode 100644 tango_edm/audioldm/hifigan/__pycache__/utilities.cpython-38.pyc create mode 100644 tango_edm/audioldm/hifigan/__pycache__/utilities.cpython-39.pyc create mode 100644 tango_edm/audioldm/hifigan/models.py create mode 100644 tango_edm/audioldm/hifigan/utilities.py create mode 100644 tango_edm/audioldm/latent_diffusion/__init__.py create mode 100644 tango_edm/audioldm/latent_diffusion/__pycache__/__init__.cpython-38.pyc create mode 100644 tango_edm/audioldm/latent_diffusion/__pycache__/__init__.cpython-39.pyc create mode 100644 tango_edm/audioldm/latent_diffusion/__pycache__/attention.cpython-38.pyc create mode 100644 tango_edm/audioldm/latent_diffusion/__pycache__/attention.cpython-39.pyc create mode 100644 tango_edm/audioldm/latent_diffusion/__pycache__/ddim.cpython-38.pyc create mode 100644 tango_edm/audioldm/latent_diffusion/__pycache__/ddim.cpython-39.pyc create mode 100644 tango_edm/audioldm/latent_diffusion/__pycache__/ddpm.cpython-38.pyc create mode 100644 tango_edm/audioldm/latent_diffusion/__pycache__/ddpm.cpython-39.pyc create mode 100644 tango_edm/audioldm/latent_diffusion/__pycache__/ema.cpython-38.pyc create mode 100644 tango_edm/audioldm/latent_diffusion/__pycache__/ema.cpython-39.pyc create mode 100644 tango_edm/audioldm/latent_diffusion/__pycache__/openaimodel.cpython-39.pyc create mode 100644 tango_edm/audioldm/latent_diffusion/__pycache__/util.cpython-38.pyc create mode 100644 tango_edm/audioldm/latent_diffusion/__pycache__/util.cpython-39.pyc create mode 100644 tango_edm/audioldm/latent_diffusion/attention.py create mode 100644 tango_edm/audioldm/latent_diffusion/ddim.py create mode 100644 tango_edm/audioldm/latent_diffusion/ddpm.py create mode 100644 tango_edm/audioldm/latent_diffusion/ema.py create mode 100644 tango_edm/audioldm/latent_diffusion/openaimodel.py create mode 100644 
tango_edm/audioldm/latent_diffusion/util.py
 create mode 100644 tango_edm/audioldm/ldm.py
 create mode 100644 tango_edm/audioldm/pipeline.py
 create mode 100644 tango_edm/audioldm/utils.py
 create mode 100644 tango_edm/audioldm/variational_autoencoder/__init__.py
 create mode 100644 tango_edm/audioldm/variational_autoencoder/__pycache__/__init__.cpython-38.pyc
 create mode 100644 tango_edm/audioldm/variational_autoencoder/__pycache__/__init__.cpython-39.pyc
 create mode 100644 tango_edm/audioldm/variational_autoencoder/__pycache__/autoencoder.cpython-38.pyc
 create mode 100644 tango_edm/audioldm/variational_autoencoder/__pycache__/autoencoder.cpython-39.pyc
 create mode 100644 tango_edm/audioldm/variational_autoencoder/__pycache__/distributions.cpython-38.pyc
 create mode 100644 tango_edm/audioldm/variational_autoencoder/__pycache__/distributions.cpython-39.pyc
 create mode 100644 tango_edm/audioldm/variational_autoencoder/__pycache__/modules.cpython-38.pyc
 create mode 100644 tango_edm/audioldm/variational_autoencoder/__pycache__/modules.cpython-39.pyc
 create mode 100644 tango_edm/audioldm/variational_autoencoder/autoencoder.py
 create mode 100644 tango_edm/audioldm/variational_autoencoder/distributions.py
 create mode 100644 tango_edm/audioldm/variational_autoencoder/modules.py

diff --git a/tango_edm/audioldm/__init__.py b/tango_edm/audioldm/__init__.py
new file mode 100644
index 0000000..075eeed
--- /dev/null
+++ b/tango_edm/audioldm/__init__.py
@@ -0,0 +1,8 @@
+from tango_edm.audioldm.ldm import LatentDiffusion
+from tango_edm.audioldm.utils import seed_everything, save_wave, get_time, get_duration
+# from tango_edm.audioldm.pipeline import *
+
+
+
+
+
diff --git a/tango_edm/audioldm/__main__.py b/tango_edm/audioldm/__main__.py
new file mode 100644
index 0000000..851d9da
--- /dev/null
+++ b/tango_edm/audioldm/__main__.py
@@ -0,0 +1,183 @@
+#!/usr/bin/python3
+import os
+from tango_edm.audioldm import text_to_audio, style_transfer, build_model, save_wave, get_time, round_up_duration, get_duration
+import argparse
+
+CACHE_DIR = os.getenv(
+    "AUDIOLDM_CACHE_DIR",
+    os.path.join(os.path.expanduser("~"), ".cache/audioldm"))
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument(
+    "--mode",
+    type=str,
+    required=False,
+    default="generation",
+    help="generation: text-to-audio generation; transfer: style transfer",
+    choices=["generation", "transfer"]
+)
+
+parser.add_argument(
+    "-t",
+    "--text",
+    type=str,
+    required=False,
+    default="",
+    help="Text prompt to the model for audio generation",
+)
+
+parser.add_argument(
+    "-f",
+    "--file_path",
+    type=str,
+    required=False,
+    default=None,
+    help="(--mode transfer): Original audio file for style transfer; or (--mode generation): the guidance audio file for generating similar audio",
+)
+
+parser.add_argument(
+    "--transfer_strength",
+    type=float,
+    required=False,
+    default=0.5,
+    help="A value between 0 and 1. 0 means original audio without transfer, 1 means completely transfer to the audio indicated by text",
+)
+
+parser.add_argument(
+    "-s",
+    "--save_path",
+    type=str,
+    required=False,
+    help="The path to save model output",
+    default="./output",
+)
+
+parser.add_argument(
+    "--model_name",
+    type=str,
+    required=False,
+    help="The checkpoint to use",
+    default="audioldm-s-full",
+    choices=["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2"]
+)
+
+parser.add_argument(
+    "-ckpt",
+    "--ckpt_path",
+    type=str,
+    required=False,
+    help="The path to the pretrained .ckpt model",
+    default=None,
+)
+
+parser.add_argument(
+    "-b",
+    "--batchsize",
+    type=int,
+    required=False,
+    default=1,
+    help="How many samples to generate at the same time",
+)
+
+parser.add_argument(
+    "--ddim_steps",
+    type=int,
+    required=False,
+    default=200,
+    help="The number of DDIM sampling steps",
+)
+
+parser.add_argument(
+    "-gs",
+    "--guidance_scale",
+    type=float,
+    required=False,
+    default=2.5,
+    help="Guidance scale (large => better quality and relevance to text; small => better diversity)",
+)
+
+parser.add_argument(
+    "-dur",
+    "--duration",
+    type=float,
+    required=False,
+    default=10.0,
+    help="The duration of the samples",
+)
+
+parser.add_argument(
+    "-n",
+    "--n_candidate_gen_per_text",
+    type=int,
+    required=False,
+    default=3,
+    help="Automatic quality control. This number controls the number of candidates (e.g., generate three audios and choose the best one to show you). A larger value usually leads to better quality with heavier computation",
+)
+
+parser.add_argument(
+    "--seed",
+    type=int,
+    required=False,
+    default=42,
+    help="Changing this value (to any integer) will lead to a different generation result.",
+)
+
+args = parser.parse_args()
+
+if(args.ckpt_path is not None):
+    print("Warning: ckpt_path has no effect after version 0.0.20.")
+
+assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5"
+
+mode = args.mode
+if(mode == "generation" and args.file_path is not None):
+    mode = "generation_audio_to_audio"
+    if(len(args.text) > 0):
+        print("Warning: You have specified the --file_path. --text will be ignored")
+        args.text = ""
+
+save_path = os.path.join(args.save_path, mode)
+
+if(args.file_path is not None):
+    save_path = os.path.join(save_path, os.path.basename(args.file_path.split(".")[0]))
+
+text = args.text
+random_seed = args.seed
+duration = args.duration
+guidance_scale = args.guidance_scale
+n_candidate_gen_per_text = args.n_candidate_gen_per_text
+
+os.makedirs(save_path, exist_ok=True)
+audioldm = build_model(model_name=args.model_name)
+
+if(args.mode == "generation"):
+    waveform = text_to_audio(
+        audioldm,
+        text,
+        args.file_path,
+        random_seed,
+        duration=duration,
+        guidance_scale=guidance_scale,
+        ddim_steps=args.ddim_steps,
+        n_candidate_gen_per_text=n_candidate_gen_per_text,
+        batchsize=args.batchsize,
+    )
+
+elif(args.mode == "transfer"):
+    assert args.file_path is not None
+    assert os.path.exists(args.file_path), "The original audio file \'%s\' for style transfer does not exist." % args.file_path
+    waveform = style_transfer(
+        audioldm,
+        text,
+        args.file_path,
+        args.transfer_strength,
+        random_seed,
+        duration=duration,
+        guidance_scale=guidance_scale,
+        ddim_steps=args.ddim_steps,
+        batchsize=args.batchsize,
+    )
+    waveform = waveform[:,None,:]
+
+save_wave(waveform, save_path, name="%s_%s" % (get_time(), text))
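Note: the same pipeline can in principle be driven without this CLI. A minimal sketch, assuming text_to_audio and build_model actually resolve from tango_edm.audioldm as __main__.py expects (the pipeline re-export in __init__.py above is commented out in this patch) and that the "audioldm-s-full" checkpoint is available to build_model; the values simply mirror the argparse defaults above:

    import os
    from tango_edm.audioldm import text_to_audio, build_model, save_wave, get_time

    audioldm = build_model(model_name="audioldm-s-full")
    prompt = "a dog barking in the distance"  # hypothetical text prompt
    waveform = text_to_audio(
        audioldm,
        prompt,
        None,   # no guidance audio file
        42,     # random seed
        duration=10.0,
        guidance_scale=2.5,
        ddim_steps=200,
        n_candidate_gen_per_text=3,
        batchsize=1,
    )
    os.makedirs("./output", exist_ok=True)
    save_wave(waveform, "./output", name="%s_%s" % (get_time(), prompt))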
diff --git a/tango_edm/audioldm/audio/__init__.py b/tango_edm/audioldm/audio/__init__.py
new file mode 100644
index 0000000..a8ad02b
--- /dev/null
+++ b/tango_edm/audioldm/audio/__init__.py
@@ -0,0 +1,2 @@
+from tango_edm.audioldm.audio.tools import wav_to_fbank, read_wav_file
+from tango_edm.audioldm.audio.stft import TacotronSTFT

[GIT binary patch payloads for the tango_edm/audioldm/audio/__pycache__/*.pyc files listed in the header omitted: compiled CPython 3.8/3.9 bytecode caches, not human-readable.]

diff --git a/tango_edm/audioldm/audio/audio_processing.py b/tango_edm/audioldm/audio/audio_processing.py
new file mode 100644
index 0000000..77a4057
--- /dev/null
+++ b/tango_edm/audioldm/audio/audio_processing.py
@@ -0,0 +1,100 @@
+import torch
+import numpy as np
+import librosa.util as librosa_util
+from scipy.signal import get_window
+
+
+def window_sumsquare(
+    window,
+    n_frames,
+    hop_length,
+    win_length,
+    n_fft,
+    dtype=np.float32,
+    norm=None,
+):
+    """
+    # from librosa 0.6
+    Compute the sum-square envelope of a window function at a given hop length.
+
+    This is used to estimate modulation effects induced by windowing
+    observations in short-time fourier transforms.
+
+    Parameters
+    ----------
+    window : string, tuple, number, callable, or list-like
+        Window specification, as in `get_window`
+
+    n_frames : int > 0
+        The number of analysis frames
+
+    hop_length : int > 0
+        The number of samples to advance between frames
+
+    win_length : [optional]
+        The length of the window function. By default, this matches `n_fft`.
+
+    n_fft : int > 0
+        The length of each analysis frame.
+ + dtype : np.dtype + The data type of the output + + Returns + ------- + wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` + The sum-squared envelope of the window function + """ + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + # Compute the squared window at the desired length + win_sq = get_window(window, win_length, fftbins=True) + win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 + win_sq = librosa_util.pad_center(win_sq, n_fft) + + # Fill the envelope + for i in range(n_frames): + sample = i * hop_length + x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] + return x + + +def griffin_lim(magnitudes, stft_fn, n_iters=30): + """ + PARAMS + ------ + magnitudes: spectrogram magnitudes + stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods + """ + + angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) + angles = angles.astype(np.float32) + angles = torch.autograd.Variable(torch.from_numpy(angles)) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + + for i in range(n_iters): + _, angles = stft_fn.transform(signal) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + return signal + + +def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return normalize_fun(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C diff --git a/tango_edm/audioldm/audio/stft.py b/tango_edm/audioldm/audio/stft.py new file mode 100644 index 0000000..ce76808 --- /dev/null +++ b/tango_edm/audioldm/audio/stft.py @@ -0,0 +1,186 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy.signal import get_window +from librosa.util import pad_center, tiny +from librosa.filters import mel as librosa_mel_fn + +from tango_edm.audioldm.audio.audio_processing import ( + dynamic_range_compression, + dynamic_range_decompression, + window_sumsquare, +) + + +class STFT(torch.nn.Module): + """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" + + def __init__(self, filter_length, hop_length, win_length, window="hann"): + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.forward_transform = None + scale = self.filter_length / self.hop_length + fourier_basis = np.fft.fft(np.eye(self.filter_length)) + + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack( + [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] + ) + + forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) + inverse_basis = torch.FloatTensor( + np.linalg.pinv(scale * fourier_basis).T[:, None, :] + ) + + if window is not None: + assert filter_length >= win_length + # get window and zero center pad it to filter_length + fft_window = get_window(window, win_length, fftbins=True) + fft_window = pad_center(fft_window, filter_length) + fft_window = torch.from_numpy(fft_window).float() + + # window the bases + forward_basis *= fft_window + inverse_basis *= fft_window + + self.register_buffer("forward_basis", forward_basis.float()) + self.register_buffer("inverse_basis", inverse_basis.float()) + + def transform(self, input_data): + device = self.forward_basis.device + input_data = input_data.to(device) + 
+ num_batches = input_data.size(0) + num_samples = input_data.size(1) + + self.num_samples = num_samples + + # similar to librosa, reflect-pad the input + input_data = input_data.view(num_batches, 1, num_samples) + input_data = F.pad( + input_data.unsqueeze(1), + (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), + mode="reflect", + ) + input_data = input_data.squeeze(1) + + forward_transform = F.conv1d( + input_data, + torch.autograd.Variable(self.forward_basis, requires_grad=False), + stride=self.hop_length, + padding=0, + )#.cpu() + + cutoff = int((self.filter_length / 2) + 1) + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + + magnitude = torch.sqrt(real_part**2 + imag_part**2) + phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) + + return magnitude, phase + + def inverse(self, magnitude, phase): + device = self.forward_basis.device + magnitude, phase = magnitude.to(device), phase.to(device) + + recombine_magnitude_phase = torch.cat( + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 + ) + + inverse_transform = F.conv_transpose1d( + recombine_magnitude_phase, + torch.autograd.Variable(self.inverse_basis, requires_grad=False), + stride=self.hop_length, + padding=0, + ) + + if self.window is not None: + window_sum = window_sumsquare( + self.window, + magnitude.size(-1), + hop_length=self.hop_length, + win_length=self.win_length, + n_fft=self.filter_length, + dtype=np.float32, + ) + # remove modulation effects + approx_nonzero_indices = torch.from_numpy( + np.where(window_sum > tiny(window_sum))[0] + ) + window_sum = torch.autograd.Variable( + torch.from_numpy(window_sum), requires_grad=False + ) + window_sum = window_sum + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ + approx_nonzero_indices + ] + + # scale by hop ratio + inverse_transform *= float(self.filter_length) / self.hop_length + + inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] + inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] + + return inverse_transform + + def forward(self, input_data): + self.magnitude, self.phase = self.transform(input_data) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction + + +class TacotronSTFT(torch.nn.Module): + def __init__( + self, + filter_length, + hop_length, + win_length, + n_mel_channels, + sampling_rate, + mel_fmin, + mel_fmax, + ): + super(TacotronSTFT, self).__init__() + self.n_mel_channels = n_mel_channels + self.sampling_rate = sampling_rate + self.stft_fn = STFT(filter_length, hop_length, win_length) + mel_basis = librosa_mel_fn( + sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax + ) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + + def spectral_normalize(self, magnitudes, normalize_fun): + output = dynamic_range_compression(magnitudes, normalize_fun) + return output + + def spectral_de_normalize(self, magnitudes): + output = dynamic_range_decompression(magnitudes) + return output + + def mel_spectrogram(self, y, normalize_fun=torch.log): + """Computes mel-spectrograms from a batch of waves + PARAMS + ------ + y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] + + RETURNS + ------- + mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) + """ + assert torch.min(y.data) >= -1, torch.min(y.data) + assert torch.max(y.data) <= 1, torch.max(y.data) + + magnitudes, phases = 
self.stft_fn.transform(y) + magnitudes = magnitudes.data + mel_output = torch.matmul(self.mel_basis, magnitudes) + mel_output = self.spectral_normalize(mel_output, normalize_fun) + energy = torch.norm(magnitudes, dim=1) + + log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun) + + return mel_output, log_magnitudes, energy diff --git a/tango_edm/audioldm/audio/tools.py b/tango_edm/audioldm/audio/tools.py new file mode 100644 index 0000000..d641a98 --- /dev/null +++ b/tango_edm/audioldm/audio/tools.py @@ -0,0 +1,85 @@ +import torch +import numpy as np +import torchaudio + + +def get_mel_from_wav(audio, _stft): + audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) + audio = torch.autograd.Variable(audio, requires_grad=False) + melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio) + melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) + log_magnitudes_stft = ( + torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32) + ) + energy = torch.squeeze(energy, 0).numpy().astype(np.float32) + return melspec, log_magnitudes_stft, energy + + +def _pad_spec(fbank, target_length=1024): + n_frames = fbank.shape[0] + p = target_length - n_frames + # cut and pad + if p > 0: + m = torch.nn.ZeroPad2d((0, 0, 0, p)) + fbank = m(fbank) + elif p < 0: + fbank = fbank[0:target_length, :] + + if fbank.size(-1) % 2 != 0: + fbank = fbank[..., :-1] + + return fbank + + +def pad_wav(waveform, segment_length): + waveform_length = waveform.shape[-1] + assert waveform_length > 100, "Waveform is too short, %s" % waveform_length + if segment_length is None or waveform_length == segment_length: + return waveform + elif waveform_length > segment_length: + return waveform[:segment_length] + elif waveform_length < segment_length: + temp_wav = np.zeros((1, segment_length)) + temp_wav[:, :waveform_length] = waveform + return temp_wav + +def normalize_wav(waveform): + waveform = waveform - np.mean(waveform) + waveform = waveform / (np.max(np.abs(waveform)) + 1e-8) + return waveform * 0.5 + + +def read_wav_file(filename, segment_length): + # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower + waveform, sr = torchaudio.load(filename) # Faster!!! + waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000) + waveform = waveform.numpy()[0, ...] + waveform = normalize_wav(waveform) + waveform = waveform[None, ...] + waveform = pad_wav(waveform, segment_length) + + waveform = waveform / np.max(np.abs(waveform)) + waveform = 0.5 * waveform + + return waveform + + +def wav_to_fbank(filename, target_length=1024, fn_STFT=None): + assert fn_STFT is not None + + # mixup + waveform = read_wav_file(filename, target_length * 160) # hop size is 160 + + waveform = waveform[0, ...] 
+    waveform = torch.FloatTensor(waveform)
+
+    fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)
+
+    fbank = torch.FloatTensor(fbank.T)
+    log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)
+
+    fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
+        log_magnitudes_stft, target_length
+    )
+
+    return fbank, log_magnitudes_stft, waveform
diff --git a/tango_edm/audioldm/hifigan/__init__.py b/tango_edm/audioldm/hifigan/__init__.py
new file mode 100644
index 0000000..e227e9a
--- /dev/null
+++ b/tango_edm/audioldm/hifigan/__init__.py
@@ -0,0 +1,7 @@
+from tango_edm.audioldm.hifigan.models import Generator
+
+
+class AttrDict(dict):
+    def __init__(self, *args, **kwargs):
+        super(AttrDict, self).__init__(*args, **kwargs)
+        self.__dict__ = self
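A quick sketch of the mel front end that audio/stft.py and audio/tools.py above provide, using the same 16 kHz / 64-mel settings that HIFIGAN_16K_64 defines later in this patch; "example.wav" is a hypothetical input path:

    from tango_edm.audioldm.audio import TacotronSTFT, wav_to_fbank

    # STFT settings matching the HiFi-GAN 16 kHz / 64-mel config used elsewhere in this patch
    fn_STFT = TacotronSTFT(
        filter_length=1024, hop_length=160, win_length=1024,
        n_mel_channels=64, sampling_rate=16000, mel_fmin=0, mel_fmax=8000,
    )
    # 1024 frames * 160-sample hop = about 10.24 s of 16 kHz audio
    fbank, log_stft, waveform = wav_to_fbank("example.wav", target_length=1024, fn_STFT=fn_STFT)
    mel = fbank.unsqueeze(0).transpose(1, 2)  # (1, 64, T), the layout the Generator's Conv1d front end expects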
[GIT binary patch payloads for the tango_edm/audioldm/hifigan/__pycache__/*.pyc files listed in the header omitted: compiled CPython 3.8/3.9 bytecode caches, not human-readable.]

diff --git a/tango_edm/audioldm/hifigan/models.py b/tango_edm/audioldm/hifigan/models.py
new file mode 100644
index 0000000..c4382cc
--- /dev/null
+++ b/tango_edm/audioldm/hifigan/models.py
@@ -0,0 +1,174 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm, remove_weight_norm
+
+LRELU_SLOPE = 0.1
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+class ResBlock(torch.nn.Module):
+    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(ResBlock, self).__init__()
+        self.h = h
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+        self.convs2.apply(init_weights)
+
+    def forward(self, x):
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class Generator(torch.nn.Module):
+    def __init__(self, h):
+        super(Generator, self).__init__()
+        self.h = h
+        self.num_kernels = len(h.resblock_kernel_sizes)
+        self.num_upsamples = len(h.upsample_rates)
+        self.conv_pre = weight_norm(
+            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
+        )
+        resblock = ResBlock
+
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+            self.ups.append(
+                weight_norm(
+                    ConvTranspose1d(
+                        h.upsample_initial_channel // (2**i),
+                        h.upsample_initial_channel // (2 ** (i + 1)),
+                        k,
u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + # print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) diff --git a/tango_edm/audioldm/hifigan/utilities.py b/tango_edm/audioldm/hifigan/utilities.py new file mode 100644 index 0000000..ce2a685 --- /dev/null +++ b/tango_edm/audioldm/hifigan/utilities.py @@ -0,0 +1,86 @@ +import os +import json + +import torch +import numpy as np + +import tango_edm.audioldm.hifigan as hifigan + +HIFIGAN_16K_64 = { + "resblock": "1", + "num_gpus": 6, + "batch_size": 16, + "learning_rate": 0.0002, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], + "upsample_initial_channel": 1024, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "segment_size": 8192, + "num_mels": 64, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 160, + "win_size": 1024, + "sampling_rate": 16000, + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": None, + "num_workers": 4, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1, + }, +} + + +def get_available_checkpoint_keys(model, ckpt): + print("==> Attemp to reload from %s" % ckpt) + state_dict = torch.load(ckpt)["state_dict"] + current_state_dict = model.state_dict() + new_state_dict = {} + for k in state_dict.keys(): + if ( + k in current_state_dict.keys() + and current_state_dict[k].size() == state_dict[k].size() + ): + new_state_dict[k] = state_dict[k] + else: + print("==> WARNING: Skipping %s" % k) + print( + "%s out of %s keys are matched" + % (len(new_state_dict.keys()), len(state_dict.keys())) + ) + return new_state_dict + + +def get_param_num(model): + num_param = sum(param.numel() for param in model.parameters()) + return num_param + + +def get_vocoder(config, device): + config = hifigan.AttrDict(HIFIGAN_16K_64) + vocoder = hifigan.Generator(config) + vocoder.eval() + vocoder.remove_weight_norm() + vocoder.to(device) + return vocoder + + +def vocoder_infer(mels, vocoder, lengths=None): + vocoder.eval() + # with torch.no_grad(): # TODO: we need to figure out backprop stuff. 
diff --git a/tango_edm/audioldm/latent_diffusion/__init__.py b/tango_edm/audioldm/latent_diffusion/__init__.py
new file mode 100644
index 0000000..e69de29

[GIT binary patch payloads for the tango_edm/audioldm/latent_diffusion/__pycache__/*.pyc files listed in the header omitted: compiled CPython 3.8/3.9 bytecode caches, not human-readable.]
[GIT binary patch data omitted: compiled __pycache__ deltas for the latent_diffusion modules (attention, ddim, ddpm, ema, openaimodel, util; CPython 3.8/3.9).]
zPeZu{%Bm?tB1knVu2x8E8Y!x)LJ`WlsMYwTVT{erqoCt#I&{Hpj@ZU#^-0_Zs(Y)G bp6l(|>sOC>L-?zB)d^?9@kUV_of!FVoy}qx literal 0 HcmV?d00001 diff --git a/tango_edm/audioldm/latent_diffusion/attention.py b/tango_edm/audioldm/latent_diffusion/attention.py new file mode 100644 index 0000000..47f77e8 --- /dev/null +++ b/tango_edm/audioldm/latent_diffusion/attention.py @@ -0,0 +1,469 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn +from einops import rearrange + +from tango_edm.audioldm.latent_diffusion.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = 
torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class CrossAttention(nn.Module): + """ + ### Cross Attention Layer + This falls-back to self-attention when conditional embeddings are not specified. + """ + + # use_flash_attention: bool = True + use_flash_attention: bool = False + + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + is_inplace: bool = True, + ): + # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True): + """ + :param d_model: is the input embedding size + :param n_heads: is the number of attention heads + :param d_head: is the size of a attention head + :param d_cond: is the size of the conditional embeddings + :param is_inplace: specifies whether to perform the attention softmax computation inplace to + save memory + """ + super().__init__() + + self.is_inplace = is_inplace + self.n_heads = heads + self.d_head = dim_head + + # Attention scaling factor + self.scale = dim_head**-0.5 + + # The normal self-attention layer + if context_dim is None: + context_dim = query_dim + + # Query, key and value mappings + d_attn = dim_head * heads + self.to_q = nn.Linear(query_dim, d_attn, bias=False) + self.to_k = nn.Linear(context_dim, d_attn, bias=False) + self.to_v = nn.Linear(context_dim, d_attn, bias=False) + + # Final linear layer + self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout)) + + # Setup [flash attention](https://github.com/HazyResearch/flash-attention). + # Flash attention is only used if it's installed + # and `CrossAttention.use_flash_attention` is set to `True`. + try: + # You can install flash attention by cloning their Github repo, + # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention) + # and then running `python setup.py install` + from flash_attn.flash_attention import FlashAttention + + self.flash = FlashAttention() + # Set the scale for scaled dot-product attention. 
+ self.flash.softmax_scale = self.scale + # Set to `None` if it's not installed + except ImportError: + self.flash = None + + def forward(self, x, context=None, mask=None): + """ + :param x: are the input embeddings of shape `[batch_size, height * width, d_model]` + :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]` + """ + + # If `cond` is `None` we perform self attention + has_cond = context is not None + if not has_cond: + context = x + + # Get query, key and value vectors + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + + # Use flash attention if it's available and the head size is less than or equal to `128` + if ( + CrossAttention.use_flash_attention + and self.flash is not None + and not has_cond + and self.d_head <= 128 + ): + return self.flash_attention(q, k, v) + # Otherwise, fallback to normal attention + else: + return self.normal_attention(q, k, v) + + def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """ + #### Flash Attention + :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + """ + + # Get batch size and number of elements along sequence axis (`width * height`) + batch_size, seq_len, _ = q.shape + + # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of + # shape `[batch_size, seq_len, 3, n_heads * d_head]` + qkv = torch.stack((q, k, v), dim=2) + # Split the heads + qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head) + + # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to + # fit this size. 
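+        # e.g. with d_head = 40 the heads are zero-padded up to 64 below, and the extra
+        # columns are sliced off again right after the flash attention call.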
+ if self.d_head <= 32: + pad = 32 - self.d_head + elif self.d_head <= 64: + pad = 64 - self.d_head + elif self.d_head <= 128: + pad = 128 - self.d_head + else: + raise ValueError(f"Head size ${self.d_head} too large for Flash Attention") + + # Pad the heads + if pad: + qkv = torch.cat( + (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1 + ) + + # Compute attention + # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ + # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]` + # TODO here I add the dtype changing + out, _ = self.flash(qkv.type(torch.float16)) + # Truncate the extra head size + out = out[:, :, :, : self.d_head].float() + # Reshape to `[batch_size, seq_len, n_heads * d_head]` + out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head) + + # Map to `[batch_size, height * width, d_model]` with a linear layer + return self.to_out(out) + + def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """ + #### Normal Attention + + :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + """ + + # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]` + q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32] + k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32] + v = v.view(*v.shape[:2], self.n_heads, -1) + + # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$ + attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale + + # Compute softmax + # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$ + if self.is_inplace: + half = attn.shape[0] // 2 + attn[half:] = attn[half:].softmax(dim=-1) + attn[:half] = attn[:half].softmax(dim=-1) + else: + attn = attn.softmax(dim=-1) + + # Compute attention output + # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ + # attn: [bs, 20, 64, 1] + # v: [bs, 1, 20, 32] + out = torch.einsum("bhij,bjhd->bihd", attn, v) + # Reshape to `[batch_size, height * width, n_heads * d_head]` + out = out.reshape(*out.shape[:2], -1) + # Map to `[batch_size, height * width, d_model]` with a linear layer + return self.to_out(out) + + +# class CrossAttention(nn.Module): +# def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): +# super().__init__() +# inner_dim = dim_head * heads +# context_dim = default(context_dim, query_dim) + +# self.scale = dim_head ** -0.5 +# self.heads = heads + +# self.to_q = nn.Linear(query_dim, inner_dim, bias=False) +# self.to_k = nn.Linear(context_dim, inner_dim, bias=False) +# self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + +# self.to_out = nn.Sequential( +# nn.Linear(inner_dim, query_dim), +# nn.Dropout(dropout) +# ) + +# def forward(self, x, context=None, mask=None): +# h = self.heads + +# q = self.to_q(x) +# context = default(context, x) +# k = self.to_k(context) +# v = self.to_v(context) + +# q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + +# sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + +# if exists(mask): +# mask = rearrange(mask, 'b ... 
-> b (...)') +# max_neg_value = -torch.finfo(sim.dtype).max +# mask = repeat(mask, 'b j -> (b h) () j', h=h) +# sim.masked_fill_(~mask, max_neg_value) + +# # attention, what we cannot get enough of +# attn = sim.softmax(dim=-1) + +# out = einsum('b i j, b j d -> b i d', attn, v) +# out = rearrange(out, '(b h) n d -> b n (h d)', h=h) +# return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + if context is None: + return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint) + else: + return checkpoint( + self._forward, (x, context), self.parameters(), self.checkpoint + ) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + no_context=False, + ): + super().__init__() + + if no_context: + context_dim = None + + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim + ) + for d in range(depth) + ] + ) + + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + return x + x_in diff --git a/tango_edm/audioldm/latent_diffusion/ddim.py b/tango_edm/audioldm/latent_diffusion/ddim.py new file mode 100644 index 0000000..a8c279c --- /dev/null +++ b/tango_edm/audioldm/latent_diffusion/ddim.py @@ -0,0 +1,377 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm + +from tango_edm.audioldm.latent_diffusion.util import ( + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, + extract_into_tensor, +) + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): 
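+                # NOTE: this assumes a CUDA device is available; buffers are moved to the
+                # GPU unconditionally, so a CPU-only run would need self.model.device here.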
+ attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + ) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
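+        # unconditional_guidance_scale > 1.0 enables classifier-free guidance;
+        # unconditional_conditioning is then typically the encoding of an
+        # "empty" prompt, in the same format as `conditioning`.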
+ **kwargs, + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + else: + if conditioning.shape[0] != batch_size: + print( + f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + samples, intermediates = self.ddim_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( + reversed(range(0, timesteps)) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + # iterator = gr.Progress().tqdm(time_range, desc="DDIM Sampler", total=total_steps) + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps, leave=False) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample( + x0, ts + ) # TODO deterministic forward pass? 
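+                # Inpainting-style blending: where mask == 1 keep the noised
+                # reference x0, elsewhere keep the sample being denoised.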
+ img = ( + img_orig * mask + (1.0 - mask) * img + ) # In the first sampling step, img is pure gaussian noise + + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + img, pred_x0 = outs + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + ) + + @torch.no_grad() + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + ): + + timesteps = ( + np.arange(self.ddpm_num_timesteps) + if use_original_steps + else self.ddim_timesteps + ) + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + # iterator = gr.Progress().tqdm(time_range, desc="Decoding image", total=total_steps) + iterator = tqdm(time_range, desc="Decoding image", total=total_steps) + x_dec = x_latent + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full( + (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long + ) + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return x_dec + + @torch.no_grad() + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + ): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + # When unconditional_guidance_scale == 1: only e_t + # When unconditional_guidance_scale == 0: only unconditional + # When unconditional_guidance_scale > 1: add more unconditional guidance + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = 
score_corrector.modify_score( + self.model, e_t, x, t, c, **corrector_kwargs + ) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise # TODO + return x_prev, pred_x0 diff --git a/tango_edm/audioldm/latent_diffusion/ddpm.py b/tango_edm/audioldm/latent_diffusion/ddpm.py new file mode 100644 index 0000000..9089973 --- /dev/null +++ b/tango_edm/audioldm/latent_diffusion/ddpm.py @@ -0,0 +1,441 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" +import sys +import os + +import torch +import torch.nn as nn +import numpy as np +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm + +from tango_edm.audioldm.utils import exists, default, count_params, instantiate_from_config +from tango_edm.audioldm.latent_diffusion.ema import LitEma +from tango_edm.audioldm.latent_diffusion.util import ( + make_beta_schedule, + extract_into_tensor, + noise_like, +) +import soundfile as sf +import os + + +__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DiffusionWrapper(nn.Module): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [ + None, + "concat", + "crossattn", + "hybrid", + "adm", + "film", + ] + + def forward( + self, x, t, c_concat: list = None, c_crossattn: list = None, c_film: list = None + ): + x = x.contiguous() + t = t.contiguous() + + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == "concat": + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif 
self.conditioning_key == "crossattn": + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == "hybrid": + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif ( + self.conditioning_key == "film" + ): # The condition is assumed to be a global token, which wil pass through a linear layer and added with the time embedding for the FILM + cc = c_film[0].squeeze(1) # only has one token + out = self.diffusion_model(x, t, y=cc) + elif self.conditioning_key == "adm": + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class DDPM(nn.Module): + # classic DDPM with Gaussian diffusion, in image space + def __init__( + self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + latent_t_size=256, + latent_f_size=16, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0.0, + v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1.0, + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0.0, + ): + super().__init__() + assert parameterization in [ + "eps", + "x0", + ], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + self.state = None + # print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + + self.latent_t_size = latent_t_size + self.latent_f_size = latent_f_size + + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + else: + self.logvar = nn.Parameter(self.logvar, requires_grad=False) + + self.logger_save_dir = None + self.logger_project = None + self.logger_version = None + self.label_indices_total = None + # To avoid the system cannot find metric value for checkpoint + self.metrics_buffer = { + "val/kullback_leibler_divergence_sigmoid": 15.0, + "val/kullback_leibler_divergence_softmax": 10.0, + "val/psnr": 0.0, + "val/ssim": 0.0, + "val/inception_score_mean": 
1.0, + "val/inception_score_std": 0.0, + "val/kernel_inception_distance_mean": 0.0, + "val/kernel_inception_distance_std": 0.0, + "val/frechet_inception_distance": 133.0, + "val/frechet_audio_distance": 32.0, + } + self.initial_learning_rate = None + + def get_log_dir(self): + if ( + self.logger_save_dir is None + and self.logger_project is None + and self.logger_version is None + ): + return os.path.join( + self.logger.save_dir, self.logger._project, self.logger.version + ) + else: + return os.path.join( + self.logger_save_dir, self.logger_project, self.logger_version + ) + + def set_log_dir(self, save_dir, project, version): + self.logger_save_dir = save_dir + self.logger_project = project + self.logger_version = version + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule( + beta_schedule, + timesteps, + linear_start=linear_start, + linear_end=linear_end, + cosine_s=cosine_s, + ) + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert ( + alphas_cumprod.shape[0] == self.num_timesteps + ), "alphas have to be defined for each timestep" + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer("betas", to_torch(betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod)) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1)) + ) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * ( + 1.0 - alphas_cumprod_prev + ) / (1.0 - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer("posterior_variance", to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer( + "posterior_log_variance_clipped", + to_torch(np.log(np.maximum(posterior_variance, 1e-20))), + ) + self.register_buffer( + "posterior_mean_coef1", + to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)), + ) + self.register_buffer( + "posterior_mean_coef2", + to_torch( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ), + ) + + if self.parameterization == "eps": + lvlb_weights = self.betas**2 / ( + 2 + * self.posterior_variance + * to_torch(alphas) + * (1 - self.alphas_cumprod) + ) + elif self.parameterization == "x0": + lvlb_weights = ( + 0.5 + * np.sqrt(torch.Tensor(alphas_cumprod)) + / (2.0 * 1 - torch.Tensor(alphas_cumprod)) + ) + else: + raise NotImplementedError("mu not supported") + # TODO how to choose this term + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + # print(f"{context}: Switched to EMA weights") + pass + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + # print(f"{context}: Restored training weights") + pass + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
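+        Concretely, q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I),
+        which is what the returned mean, variance and log_variance encode.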
+        """
+        mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract_into_tensor(
+            self.log_one_minus_alphas_cumprod, t, x_start.shape
+        )
+        return mean, variance, log_variance
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        return (
+            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+            * noise
+        )
+
+    def q_posterior(self, x_start, x_t, t):
+        posterior_mean = (
+            extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+            + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract_into_tensor(
+            self.posterior_log_variance_clipped, t, x_t.shape
+        )
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def p_mean_variance(self, x, t, clip_denoised: bool):
+        model_out = self.model(x, t)
+        if self.parameterization == "eps":
+            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+        elif self.parameterization == "x0":
+            x_recon = model_out
+        if clip_denoised:
+            x_recon.clamp_(-1.0, 1.0)
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
+            x_start=x_recon, x_t=x, t=t
+        )
+        return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+        b, *_, device = *x.shape, x.device
+        model_mean, _, model_log_variance = self.p_mean_variance(
+            x=x, t=t, clip_denoised=clip_denoised
+        )
+        noise = noise_like(x.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (
+            (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous()
+        )
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def p_sample_loop(self, shape, return_intermediates=False):
+        device = self.betas.device
+        b = shape[0]
+        img = torch.randn(shape, device=device)
+        intermediates = [img]
+        for i in tqdm(
+            reversed(range(0, self.num_timesteps)),
+            desc="Sampling t",
+            total=self.num_timesteps,
+        ):
+            img = self.p_sample(
+                img,
+                torch.full((b,), i, device=device, dtype=torch.long),
+                clip_denoised=self.clip_denoised,
+            )
+            if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+                intermediates.append(img)
+        if return_intermediates:
+            return img, intermediates
+        return img
+
+    @torch.no_grad()
+    def sample(self, batch_size=16, return_intermediates=False):
+        channels = self.channels
+        shape = (batch_size, channels, self.latent_t_size, self.latent_f_size)
+        return self.p_sample_loop(shape, return_intermediates=return_intermediates)
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+            * noise
+        )
+
+    def forward(self, x, *args, **kwargs):
+        t = torch.randint(
+            0, self.num_timesteps, (x.shape[0],), device=self.device
+        ).long()
+        return self.p_losses(x, t, *args, **kwargs)
+
+    def get_input(self, batch, k):
+        # fbank, log_magnitudes_stft, label_indices, fname, waveform, clip_label, text = batch
+        fbank, log_magnitudes_stft, label_indices, fname, waveform, text = batch
+        ret = {}
+
+        ret["fbank"] = (
fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float() + ) + ret["stft"] = log_magnitudes_stft.to( + memory_format=torch.contiguous_format + ).float() + # ret["clip_label"] = clip_label.to(memory_format=torch.contiguous_format).float() + ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float() + ret["text"] = list(text) + ret["fname"] = fname + + return ret[k] diff --git a/tango_edm/audioldm/latent_diffusion/ema.py b/tango_edm/audioldm/latent_diffusion/ema.py new file mode 100644 index 0000000..880ca3d --- /dev/null +++ b/tango_edm/audioldm/latent_diffusion/ema.py @@ -0,0 +1,82 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", + torch.tensor(0, dtype=torch.int) + if use_num_upates + else torch.tensor(-1, dtype=torch.int), + ) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_( + one_minus_decay * (shadow_params[sname] - m_param[key]) + ) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. 
+ """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/tango_edm/audioldm/latent_diffusion/openaimodel.py b/tango_edm/audioldm/latent_diffusion/openaimodel.py new file mode 100644 index 0000000..76f7341 --- /dev/null +++ b/tango_edm/audioldm/latent_diffusion/openaimodel.py @@ -0,0 +1,1069 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from tango_edm.audioldm.latent_diffusion.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from tango_edm.audioldm.latent_diffusion.attention import SpatialTransformer + + +# dummy replace +def convert_module_to_f16(x): + pass + + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter( + th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5 + ) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1).contiguous() # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
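+    The upsampling itself uses nearest-neighbour interpolation with a factor of 2,
+    optionally followed by a 3x3 convolution.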
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd( + dims, self.channels, self.out_channels, 3, padding=padding + ) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class TransposedUpsample(nn.Module): + "Learned 2x upsampling without padding" + + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d( + self.channels, self.out_channels, kernel_size=ks, stride=2 + ) + + def forward(self, x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding, + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
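+    :param use_scale_shift_norm: if True, inject the embedding as a (scale, shift)
+        pair after the output normalization (FiLM-style,
+        h = norm(h) * (1 + scale) + shift) instead of adding it to the features.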
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1).contiguous() + qkv = self.qkv(self.norm(x)).contiguous() + h = self.attention(qkv).contiguous() + h = self.proj_out(h).contiguous() + return (x + h).reshape(b, c, *spatial).contiguous() + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial**2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = ( + qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1) + ) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length).contiguous() + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
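+        Unlike QKVAttentionLegacy, the qkv tensor is split into q, k and v
+        before the heads are folded into the batch dimension.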
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum( + "bts,bcs->bct", + weight, + v.reshape(bs * self.n_heads, ch, length).contiguous(), + ) + return a.reshape(bs, -1, length).contiguous() + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. 
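+    :param extra_film_condition_dim: if specified, condition the model on an extra
+        embedding vector (e.g. a CLAP embedding) projected to the time-embedding
+        dimension and applied FiLM-style.
+    :param extra_film_use_concat: if True, concatenate the projected FiLM condition
+        with the time embedding instead of adding it.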
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + extra_film_condition_dim=None, + use_checkpoint=False, + use_fp16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + extra_film_use_concat=False, # If true, concatenate extrafilm condition with time embedding, else addition + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + ): + super().__init__() + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert ( + num_head_channels != -1 + ), "Either num_heads or num_head_channels has to be set" + + if num_head_channels == -1: + assert ( + num_heads != -1 + ), "Either num_heads or num_head_channels has to be set" + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.extra_film_condition_dim = extra_film_condition_dim + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + self.extra_film_use_concat = extra_film_use_concat + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + assert not ( + self.num_classes is not None and self.extra_film_condition_dim is not None + ), "As for the condition of theh UNet model, you can only set using class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim." + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + + self.use_extra_film_by_concat = ( + self.extra_film_condition_dim is not None and self.extra_film_use_concat + ) + self.use_extra_film_by_addition = ( + self.extra_film_condition_dim is not None and not self.extra_film_use_concat + ) + + if self.extra_film_condition_dim is not None: + self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim) + # print("+ Use extra condition on UNet channel using Film. Extra condition dimension is %s. " % self.extra_film_condition_dim) + # if(self.use_extra_film_by_concat): + # print("\t By concatenation with time embedding") + # elif(self.use_extra_film_by_concat): + # print("\t By addition with time embedding") + + if use_spatial_transformer and ( + self.use_extra_film_by_concat or self.use_extra_film_by_addition + ): + # print("+ Spatial transformer will only be used as self-attention. 
Because you have choose to use film as your global condition.") + spatial_transformer_no_context = True + else: + spatial_transformer_no_context = False + + if use_spatial_transformer and not spatial_transformer_no_context: + assert ( + context_dim is not None + ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..." + + if context_dim is not None and not spatial_transformer_no_context: + assert ( + use_spatial_transformer + ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..." + from omegaconf.listconfig import ListConfig + + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + no_context=spatial_transformer_no_context, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + no_context=spatial_transformer_no_context, + ), + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + 
else time_embed_dim * 2, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + # num_heads = 1 + dim_head = ( + ch // num_heads + if use_spatial_transformer + else num_head_channels + ) + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + no_context=spatial_transformer_no_context, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim + if (not self.use_extra_film_by_concat) + else time_embed_dim * 2, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + self.shape_reported = False + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional + :return: an [N x C x ...] Tensor of outputs. 
+ """ + if not self.shape_reported: + # print("The shape of UNet input is", x.size()) + self.shape_reported = True + + assert (y is not None) == ( + self.num_classes is not None or self.extra_film_condition_dim is not None + ), "must specify y if and only if the model is class-conditional or film embedding conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape == (x.shape[0],) + emb = emb + self.label_emb(y) + + if self.use_extra_film_by_addition: + emb = emb + self.film_emb(y) + elif self.use_extra_film_by_concat: + emb = th.cat([emb, self.film_emb(y)], dim=-1) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) + + +class EncoderUNetModel(nn.Module): + """ + The half UNet model with attention and timestep embedding. + For usage, see UNet. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + 
use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + + def forward(self, x, timesteps): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :return: an [N x K] Tensor of outputs. + """ + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + + results = [] + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = self.middle_block(h, emb) + if self.pool.startswith("spatial"): + results.append(h.type(x.dtype).mean(dim=(2, 3))) + h = th.cat(results, axis=-1) + return self.out(h) + else: + h = h.type(x.dtype) + return self.out(h) diff --git a/tango_edm/audioldm/latent_diffusion/util.py b/tango_edm/audioldm/latent_diffusion/util.py new file mode 100644 index 0000000..3acc259 --- /dev/null +++ b/tango_edm/audioldm/latent_diffusion/util.py @@ -0,0 +1,295 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! 
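+# A minimal usage sketch (illustration only, example values): build a linear beta
+# schedule and the matching DDIM timestep subset, roughly what DDPM.register_schedule
+# and DDIMSampler.make_schedule do with the helpers defined below.
+#
+#   betas = make_beta_schedule("linear", 1000)            # numpy array, shape (1000,)
+#   alphas_cumprod = np.cumprod(1.0 - betas)
+#   ddim_steps = make_ddim_timesteps("uniform", 200, 1000, verbose=False)
+#   sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
+#       alphas_cumprod, ddim_steps, eta=0.0, verbose=False
+#   )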
+ + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from tango_edm.audioldm.utils import instantiate_from_config + + +def make_beta_schedule( + schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): + if schedule == "linear": + betas = ( + torch.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 + ) + ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64 + ) + elif schedule == "sqrt": + betas = ( + torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + ** 0.5 + ) + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps( + ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True +): + if ddim_discr_method == "uniform": + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == "quad": + ddim_timesteps = ( + (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 + ).astype(int) + else: + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f"Selected timesteps for ddim sampler: {steps_out}") + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) + if verbose: + print( + f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" + ) + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. 
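+    For example, the standard cosine schedule from the improved-diffusion reference can
+    be obtained with:
+        betas_for_alpha_bar(1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2)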
+ """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t).contiguous() + return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous() + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. 
+ """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/tango_edm/audioldm/ldm.py b/tango_edm/audioldm/ldm.py new file mode 100644 index 0000000..e560ee3 --- /dev/null +++ b/tango_edm/audioldm/ldm.py @@ -0,0 +1,819 @@ + + +import torch +import numpy as np +from tqdm import tqdm +from tango_edm.audioldm.utils import default, instantiate_from_config, save_wave +from tango_edm.audioldm.latent_diffusion.ddpm import DDPM +from tango_edm.audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution +from tango_edm.audioldm.latent_diffusion.util import noise_like +# from tango_edm.audioldm.latent_diffusion.ddim import DDIMSampler +import os + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class LatentDiffusion(DDPM): + """main class""" + + def __init__( + self, + device="cuda", + first_stage_config=None, + cond_stage_config=None, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + base_learning_rate=None, + *args, + **kwargs, + ): + self.device = device + self.learning_rate = base_learning_rate + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs["timesteps"] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = "concat" if concat_mode else "crossattn" + if cond_stage_config == "__is_unconditional__": + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + 
super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + self.cond_stage_key_orig = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer("scale_factor", torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + + def make_cond_schedule( + self, + ): + self.cond_ids = torch.full( + size=(self.num_timesteps,), + fill_value=self.num_timesteps - 1, + dtype=torch.long, + ) + ids = torch.round( + torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond) + ).long() + self.cond_ids[: self.num_timesteps_cond] = ids + + def register_schedule( + self, + given_betas=None, + beta_schedule="linear", + timesteps=1000, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + ): + super().register_schedule( + given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s + ) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != "__is_first_stage__" + assert config != "__is_unconditional__" + model = instantiate_from_config(config) + self.cond_stage_model = model + self.cond_stage_model = self.cond_stage_model.to(self.device) + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, "encode") and callable( + self.cond_stage_model.encode + ): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + # Text input is list + if type(c) == list and len(c) == 1: + c = self.cond_stage_model([c[0], c[0]]) + c = c[0:1] + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + @torch.no_grad() + def get_input( + self, + batch, + k, + return_first_stage_encode=True, + 
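+        # k is the batch key of the first-stage input (e.g. "fbank"); cond_key, when
+        # given, overrides self.cond_stage_key for this call.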
return_first_stage_outputs=False, + force_c_encode=False, + cond_key=None, + return_original_cond=False, + bs=None, + ): + x = super().get_input(batch, k) + + if bs is not None: + x = x[:bs] + + x = x.to(self.device) + + if return_first_stage_encode: + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + else: + z = None + + if self.model.conditioning_key is not None: + if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ["caption", "coordinates_bbox"]: + xc = batch[cond_key] + elif cond_key == "class_label": + xc = batch + else: + # [bs, 1, 527] + xc = super().get_input(batch, cond_key) + if type(xc) == torch.Tensor: + xc = xc.to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + + if bs is not None: + c = c[:bs] + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {"pos_x": pos_x, "pos_y": pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, "b h w c -> b c h w").contiguous() + + z = 1.0 / self.scale_factor * z + return self.first_stage_model.decode(z) + + def mel_spectrogram_to_waveform(self, mel): + # Mel: [bs, 1, t-steps, fbins] + if len(mel.size()) == 4: + mel = mel.squeeze(1) + mel = mel.permute(0, 2, 1) + waveform = self.first_stage_model.vocoder(mel) + waveform = waveform.cpu().detach().numpy() + return waveform + + @torch.no_grad() + def encode_first_stage(self, x): + return self.first_stage_model.encode(x) + + def apply_model(self, x_noisy, t, cond, return_ids=False): + + if isinstance(cond, dict): + # hybrid case, cond is exptected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + if self.model.conditioning_key == "concat": + key = "c_concat" + elif self.model.conditioning_key == "crossattn": + key = "c_crossattn" + else: + key = "c_film" + + cond = {key: cond} + + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def p_mean_variance( + self, + x, + c, + t, + clip_denoised: bool, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + score_corrector=None, + corrector_kwargs=None, + ): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score( + self, model_out, x, t, c, **corrector_kwargs + ) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1.0, 1.0) + if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, 
posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t + ) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample( + self, + x, + c, + t, + clip_denoised=False, + repeat_noise=False, + return_codebook_ids=False, + quantize_denoised=False, + return_x0=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + ): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance( + x=x, + c=c, + t=t, + clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = ( + (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous() + ) + + if return_codebook_ids: + return model_mean + nonzero_mask * ( + 0.5 * model_log_variance + ).exp() * noise, logits.argmax(dim=1) + if return_x0: + return ( + model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, + x0, + ) + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising( + self, + cond, + shape, + verbose=True, + callback=None, + quantize_denoised=False, + img_callback=None, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + batch_size=None, + x_T=None, + start_T=None, + log_every_t=None, + ): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = { + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = ( + [c[:batch_size] for c in cond] + if isinstance(cond, list) + else cond[:batch_size] + ) + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = ( + tqdm( + reversed(range(0, timesteps)), + desc="Progressive Generation", + total=timesteps, + ) + if verbose + else reversed(range(0, timesteps)) + ) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != "hybrid" + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + return_x0=True, + 
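+                # return_x0=True makes p_sample also return its current estimate of the
+                # clean latent (x0_partial), which is what gets appended to intermediates.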
temperature=temperature[i], + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + ) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop( + self, + cond, + shape, + return_intermediates=False, + x_T=None, + verbose=True, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + start_T=None, + log_every_t=None, + ): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = ( + tqdm(reversed(range(0, timesteps)), desc="Sampling t", total=timesteps) + if verbose + else reversed(range(0, timesteps)) + ) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != "hybrid" + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample( + img, + cond, + ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, + ) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1.0 - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: + callback(i) + if img_callback: + img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample( + self, + cond, + batch_size=16, + return_intermediates=False, + x_T=None, + verbose=True, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + shape=None, + **kwargs, + ): + if shape is None: + shape = (batch_size, self.channels, self.latent_t_size, self.latent_f_size) + if cond is not None: + if isinstance(cond, dict): + cond = { + key: cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + for key in cond + } + else: + cond = ( + [c[:batch_size] for c in cond] + if isinstance(cond, list) + else cond[:batch_size] + ) + return self.p_sample_loop( + cond, + shape, + return_intermediates=return_intermediates, + x_T=x_T, + verbose=verbose, + timesteps=timesteps, + quantize_denoised=quantize_denoised, + mask=mask, + x0=x0, + **kwargs, + ) + + @torch.no_grad() + def sample_log( + self, + cond, + batch_size, + ddim, + ddim_steps, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_plms=False, + mask=None, + **kwargs, + ): + + if mask is not None: + shape = (self.channels, mask.size()[-2], mask.size()[-1]) + else: + shape = (self.channels, self.latent_t_size, self.latent_f_size) + + intermediate = None + if ddim and not use_plms: + raise NotImplementedError + # print("Use ddim sampler") + + ddim_sampler = DDIMSampler(self) + samples, intermediates = ddim_sampler.sample( + ddim_steps, + batch_size, + shape, + cond, + verbose=False, + 
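+                # NOTE: this DDIM branch is unreachable as written; the
+                # NotImplementedError above fires first and DDIMSampler is not imported
+                # in this module, so sampling always falls through to the DDPM branch below.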
unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + mask=mask, + **kwargs, + ) + + else: + # print("Use DDPM sampler") + samples, intermediates = self.sample( + cond=cond, + batch_size=batch_size, + return_intermediates=True, + unconditional_guidance_scale=unconditional_guidance_scale, + mask=mask, + unconditional_conditioning=unconditional_conditioning, + **kwargs, + ) + + return samples, intermediate + + @torch.no_grad() + def generate_sample( + self, + batchs, + ddim_steps=200, + ddim_eta=1.0, + x_T=None, + n_candidate_gen_per_text=1, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + name="waveform", + use_plms=False, + save=False, + **kwargs, + ): + # Generate n_candidate_gen_per_text times and select the best + # Batch: audio, text, fnames + assert x_T is None + try: + batchs = iter(batchs) + except TypeError: + raise ValueError("The first input argument should be an iterable object") + + if use_plms: + assert ddim_steps is not None + use_ddim = ddim_steps is not None + # waveform_save_path = os.path.join(self.get_log_dir(), name) + # os.makedirs(waveform_save_path, exist_ok=True) + # print("Waveform save path: ", waveform_save_path) + + with self.ema_scope("Generate"): + for batch in batchs: + z, c = self.get_input( + batch, + self.first_stage_key, + cond_key=self.cond_stage_key, + return_first_stage_outputs=False, + force_c_encode=True, + return_original_cond=False, + bs=None, + ) + text = super().get_input(batch, "text") + + # Generate multiple samples + batch_size = z.shape[0] * n_candidate_gen_per_text + c = torch.cat([c] * n_candidate_gen_per_text, dim=0) + text = text * n_candidate_gen_per_text + + if unconditional_guidance_scale != 1.0: + unconditional_conditioning = ( + self.cond_stage_model.get_unconditional_condition(batch_size) + ) + + samples, _ = self.sample_log( + cond=c, + batch_size=batch_size, + x_T=x_T, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + use_plms=use_plms, + ) + + if(torch.max(torch.abs(samples)) > 1e2): + samples = torch.clip(samples, min=-10, max=10) + + mel = self.decode_first_stage(samples) + + waveform = self.mel_spectrogram_to_waveform(mel) + + if waveform.shape[0] > 1: + similarity = self.cond_stage_model.cos_similarity( + torch.FloatTensor(waveform).squeeze(1), text + ) + + best_index = [] + for i in range(z.shape[0]): + candidates = similarity[i :: z.shape[0]] + max_index = torch.argmax(candidates).item() + best_index.append(i + max_index * z.shape[0]) + + waveform = waveform[best_index] + # print("Similarity between generated audio and text", similarity) + # print("Choose the following indexes:", best_index) + + return waveform + + @torch.no_grad() + def generate_sample_masked( + self, + batchs, + ddim_steps=200, + ddim_eta=1.0, + x_T=None, + n_candidate_gen_per_text=1, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + name="waveform", + use_plms=False, + time_mask_ratio_start_and_end=(0.25, 0.75), + freq_mask_ratio_start_and_end=(0.75, 1.0), + save=False, + **kwargs, + ): + # Generate n_candidate_gen_per_text times and select the best + # Batch: audio, text, fnames + assert x_T is None + try: + batchs = iter(batchs) + except TypeError: + raise ValueError("The first input argument should be an iterable object") + + if use_plms: + assert ddim_steps is not None + use_ddim = ddim_steps is not None + # 
waveform_save_path = os.path.join(self.get_log_dir(), name) + # os.makedirs(waveform_save_path, exist_ok=True) + # print("Waveform save path: ", waveform_save_path) + + with self.ema_scope("Generate"): + for batch in batchs: + z, c = self.get_input( + batch, + self.first_stage_key, + cond_key=self.cond_stage_key, + return_first_stage_outputs=False, + force_c_encode=True, + return_original_cond=False, + bs=None, + ) + text = super().get_input(batch, "text") + + # Generate multiple samples + batch_size = z.shape[0] * n_candidate_gen_per_text + + _, h, w = z.shape[0], z.shape[2], z.shape[3] + + mask = torch.ones(batch_size, h, w).to(self.device) + + mask[:, int(h * time_mask_ratio_start_and_end[0]) : int(h * time_mask_ratio_start_and_end[1]), :] = 0 + mask[:, :, int(w * freq_mask_ratio_start_and_end[0]) : int(w * freq_mask_ratio_start_and_end[1])] = 0 + mask = mask[:, None, ...] + + c = torch.cat([c] * n_candidate_gen_per_text, dim=0) + text = text * n_candidate_gen_per_text + + if unconditional_guidance_scale != 1.0: + unconditional_conditioning = ( + self.cond_stage_model.get_unconditional_condition(batch_size) + ) + + samples, _ = self.sample_log( + cond=c, + batch_size=batch_size, + x_T=x_T, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + use_plms=use_plms, mask=mask, x0=torch.cat([z] * n_candidate_gen_per_text) + ) + + mel = self.decode_first_stage(samples) + + waveform = self.mel_spectrogram_to_waveform(mel) + + if waveform.shape[0] > 1: + similarity = self.cond_stage_model.cos_similarity( + torch.FloatTensor(waveform).squeeze(1), text + ) + + best_index = [] + for i in range(z.shape[0]): + candidates = similarity[i :: z.shape[0]] + max_index = torch.argmax(candidates).item() + best_index.append(i + max_index * z.shape[0]) + + waveform = waveform[best_index] + # print("Similarity between generated audio and text", similarity) + # print("Choose the following indexes:", best_index) + + return waveform \ No newline at end of file diff --git a/tango_edm/audioldm/pipeline.py b/tango_edm/audioldm/pipeline.py new file mode 100644 index 0000000..e35dcf7 --- /dev/null +++ b/tango_edm/audioldm/pipeline.py @@ -0,0 +1,301 @@ +import os + +import argparse +import yaml +import torch +from torch import autocast +from tqdm import tqdm, trange + +from tango_edm.audioldm import LatentDiffusion, seed_everything +from tango_edm.audioldm.utils import default_audioldm_config, get_duration, get_bit_depth, get_metadata, download_checkpoint +from tango_edm.audioldm.audio import wav_to_fbank, TacotronSTFT, read_wav_file +from tango_edm.audioldm.latent_diffusion.ddim import DDIMSampler +from einops import repeat +import os + +def make_batch_for_text_to_audio(text, waveform=None, fbank=None, batchsize=1): + text = [text] * batchsize + if batchsize < 1: + print("Warning: Batchsize must be at least 1. 
Batchsize is set to .") + + if(fbank is None): + fbank = torch.zeros((batchsize, 1024, 64)) # Not used, here to keep the code format + else: + fbank = torch.FloatTensor(fbank) + fbank = fbank.expand(batchsize, 1024, 64) + assert fbank.size(0) == batchsize + + stft = torch.zeros((batchsize, 1024, 512)) # Not used + + if(waveform is None): + waveform = torch.zeros((batchsize, 160000)) # Not used + else: + waveform = torch.FloatTensor(waveform) + waveform = waveform.expand(batchsize, -1) + assert waveform.size(0) == batchsize + + fname = [""] * batchsize # Not used + + batch = ( + fbank, + stft, + None, + fname, + waveform, + text, + ) + return batch + +def round_up_duration(duration): + return int(round(duration/2.5) + 1) * 2.5 + +def build_model( + ckpt_path=None, + config=None, + model_name="audioldm-s-full" +): + print("Load AudioLDM: %s", model_name) + + if(ckpt_path is None): + ckpt_path = get_metadata()[model_name]["path"] + + if(not os.path.exists(ckpt_path)): + download_checkpoint(model_name) + + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + + if config is not None: + assert type(config) is str + config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) + else: + config = default_audioldm_config(model_name) + + # Use text as condition instead of using waveform during training + config["model"]["params"]["device"] = device + config["model"]["params"]["cond_stage_key"] = "text" + + # No normalization here + latent_diffusion = LatentDiffusion(**config["model"]["params"]) + + resume_from_checkpoint = ckpt_path + + checkpoint = torch.load(resume_from_checkpoint, map_location=device) + latent_diffusion.load_state_dict(checkpoint["state_dict"]) + + latent_diffusion.eval() + latent_diffusion = latent_diffusion.to(device) + + latent_diffusion.cond_stage_model.embed_mode = "text" + return latent_diffusion + +def duration_to_latent_t_size(duration): + return int(duration * 25.6) + +def set_cond_audio(latent_diffusion): + latent_diffusion.cond_stage_key = "waveform" + latent_diffusion.cond_stage_model.embed_mode="audio" + return latent_diffusion + +def set_cond_text(latent_diffusion): + latent_diffusion.cond_stage_key = "text" + latent_diffusion.cond_stage_model.embed_mode="text" + return latent_diffusion + +def text_to_audio( + latent_diffusion, + text, + original_audio_file_path = None, + seed=42, + ddim_steps=200, + duration=10, + batchsize=1, + guidance_scale=2.5, + n_candidate_gen_per_text=3, + config=None, +): + seed_everything(int(seed)) + waveform = None + if(original_audio_file_path is not None): + waveform = read_wav_file(original_audio_file_path, int(duration * 102.4) * 160) + + batch = make_batch_for_text_to_audio(text, waveform=waveform, batchsize=batchsize) + + latent_diffusion.latent_t_size = duration_to_latent_t_size(duration) + + if(waveform is not None): + print("Generate audio that has similar content as %s" % original_audio_file_path) + latent_diffusion = set_cond_audio(latent_diffusion) + else: + print("Generate audio using text %s" % text) + latent_diffusion = set_cond_text(latent_diffusion) + + with torch.no_grad(): + waveform = latent_diffusion.generate_sample( + [batch], + unconditional_guidance_scale=guidance_scale, + ddim_steps=ddim_steps, + n_candidate_gen_per_text=n_candidate_gen_per_text, + duration=duration, + ) + return waveform + +def style_transfer( + latent_diffusion, + text, + original_audio_file_path, + transfer_strength, + seed=42, + duration=10, + batchsize=1, + guidance_scale=2.5, + 
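+    # transfer_strength in [0, 1]: the fraction of ddim_steps used to re-noise the source
+    # audio before decoding (t_enc = int(transfer_strength * ddim_steps) below).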
ddim_steps=200, + config=None, +): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + + assert original_audio_file_path is not None, "You need to provide the original audio file path" + + audio_file_duration = get_duration(original_audio_file_path) + + assert get_bit_depth(original_audio_file_path) == 16, "The bit depth of the original audio file %s must be 16" % original_audio_file_path + + # if(duration > 20): + # print("Warning: The duration of the audio file %s must be less than 20 seconds. Longer duration will result in Nan in model output (we are still debugging that); Automatically set duration to 20 seconds") + # duration = 20 + + if(duration >= audio_file_duration): + print("Warning: Duration you specified %s-seconds must equal or smaller than the audio file duration %ss" % (duration, audio_file_duration)) + duration = round_up_duration(audio_file_duration) + print("Set new duration as %s-seconds" % duration) + + # duration = round_up_duration(duration) + + latent_diffusion = set_cond_text(latent_diffusion) + + if config is not None: + assert type(config) is str + config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) + else: + config = default_audioldm_config() + + seed_everything(int(seed)) + # latent_diffusion.latent_t_size = duration_to_latent_t_size(duration) + latent_diffusion.cond_stage_model.embed_mode = "text" + + fn_STFT = TacotronSTFT( + config["preprocessing"]["stft"]["filter_length"], + config["preprocessing"]["stft"]["hop_length"], + config["preprocessing"]["stft"]["win_length"], + config["preprocessing"]["mel"]["n_mel_channels"], + config["preprocessing"]["audio"]["sampling_rate"], + config["preprocessing"]["mel"]["mel_fmin"], + config["preprocessing"]["mel"]["mel_fmax"], + ) + + mel, _, _ = wav_to_fbank( + original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT + ) + mel = mel.unsqueeze(0).unsqueeze(0).to(device) + mel = repeat(mel, "1 ... 
-> b ...", b=batchsize) + init_latent = latent_diffusion.get_first_stage_encoding( + latent_diffusion.encode_first_stage(mel) + ) # move to latent space, encode and sample + if(torch.max(torch.abs(init_latent)) > 1e2): + init_latent = torch.clip(init_latent, min=-10, max=10) + sampler = DDIMSampler(latent_diffusion) + sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=1.0, verbose=False) + + t_enc = int(transfer_strength * ddim_steps) + prompts = text + + with torch.no_grad(): + with autocast("cuda"): + with latent_diffusion.ema_scope(): + uc = None + if guidance_scale != 1.0: + uc = latent_diffusion.cond_stage_model.get_unconditional_condition( + batchsize + ) + + c = latent_diffusion.get_learned_conditioning([prompts] * batchsize) + z_enc = sampler.stochastic_encode( + init_latent, torch.tensor([t_enc] * batchsize).to(device) + ) + samples = sampler.decode( + z_enc, + c, + t_enc, + unconditional_guidance_scale=guidance_scale, + unconditional_conditioning=uc, + ) + # x_samples = latent_diffusion.decode_first_stage(samples) # Will result in Nan in output + # print(torch.sum(torch.isnan(samples))) + x_samples = latent_diffusion.decode_first_stage(samples) + # print(x_samples) + x_samples = latent_diffusion.decode_first_stage(samples[:,:,:-3,:]) + # print(x_samples) + waveform = latent_diffusion.first_stage_model.decode_to_waveform( + x_samples + ) + + return waveform + +def super_resolution_and_inpainting( + latent_diffusion, + text, + original_audio_file_path = None, + seed=42, + ddim_steps=200, + duration=None, + batchsize=1, + guidance_scale=2.5, + n_candidate_gen_per_text=3, + time_mask_ratio_start_and_end=(0.10, 0.15), # regenerate the 10% to 15% of the time steps in the spectrogram + # time_mask_ratio_start_and_end=(1.0, 1.0), # no inpainting + # freq_mask_ratio_start_and_end=(0.75, 1.0), # regenerate the higher 75% to 100% mel bins + freq_mask_ratio_start_and_end=(1.0, 1.0), # no super-resolution + config=None, +): + seed_everything(int(seed)) + if config is not None: + assert type(config) is str + config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) + else: + config = default_audioldm_config() + fn_STFT = TacotronSTFT( + config["preprocessing"]["stft"]["filter_length"], + config["preprocessing"]["stft"]["hop_length"], + config["preprocessing"]["stft"]["win_length"], + config["preprocessing"]["mel"]["n_mel_channels"], + config["preprocessing"]["audio"]["sampling_rate"], + config["preprocessing"]["mel"]["mel_fmin"], + config["preprocessing"]["mel"]["mel_fmax"], + ) + + # waveform = read_wav_file(original_audio_file_path, None) + mel, _, _ = wav_to_fbank( + original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT + ) + + batch = make_batch_for_text_to_audio(text, fbank=mel[None,...], batchsize=batchsize) + + # latent_diffusion.latent_t_size = duration_to_latent_t_size(duration) + latent_diffusion = set_cond_text(latent_diffusion) + + with torch.no_grad(): + waveform = latent_diffusion.generate_sample_masked( + [batch], + unconditional_guidance_scale=guidance_scale, + ddim_steps=ddim_steps, + n_candidate_gen_per_text=n_candidate_gen_per_text, + duration=duration, + time_mask_ratio_start_and_end=time_mask_ratio_start_and_end, + freq_mask_ratio_start_and_end=freq_mask_ratio_start_and_end + ) + return waveform \ No newline at end of file diff --git a/tango_edm/audioldm/utils.py b/tango_edm/audioldm/utils.py new file mode 100644 index 0000000..5401b29 --- /dev/null +++ b/tango_edm/audioldm/utils.py @@ -0,0 +1,281 @@ +import contextlib +import importlib 
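+# The checkpoint cache defaults to ~/.cache/audioldm (CACHE_DIR below) and can be
+# relocated with the AUDIOLDM_CACHE_DIR environment variable, e.g. (illustrative,
+# script name is a placeholder): AUDIOLDM_CACHE_DIR=/data/audioldm_ckpts python run.py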
+ +from inspect import isfunction +import os +import soundfile as sf +import time +import wave + +import urllib.request +import progressbar + +CACHE_DIR = os.getenv( + "AUDIOLDM_CACHE_DIR", + os.path.join(os.path.expanduser("~"), ".cache/audioldm")) + +def get_duration(fname): + with contextlib.closing(wave.open(fname, 'r')) as f: + frames = f.getnframes() + rate = f.getframerate() + return frames / float(rate) + +def get_bit_depth(fname): + with contextlib.closing(wave.open(fname, 'r')) as f: + bit_depth = f.getsampwidth() * 8 + return bit_depth + +def get_time(): + t = time.localtime() + return time.strftime("%d_%m_%Y_%H_%M_%S", t) + +def seed_everything(seed): + import random, os + import numpy as np + import torch + + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + + +def save_wave(waveform, savepath, name="outwav"): + if type(name) is not list: + name = [name] * waveform.shape[0] + + for i in range(waveform.shape[0]): + path = os.path.join( + savepath, + "%s_%s.wav" + % ( + os.path.basename(name[i]) + if (not ".wav" in name[i]) + else os.path.basename(name[i]).split(".")[0], + i, + ), + ) + print("Save audio to %s" % path) + sf.write(path, waveform[i, 0], samplerate=16000) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def default_audioldm_config(model_name="audioldm-s-full"): + basic_config = { + "wave_file_save_path": "./output", + "id": { + "version": "v1", + "name": "default", + "root": "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/AudioLDM-python/config/default/latent_diffusion.yaml", + }, + "preprocessing": { + "audio": {"sampling_rate": 16000, "max_wav_value": 32768}, + "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024}, + "mel": { + "n_mel_channels": 64, + "mel_fmin": 0, + "mel_fmax": 8000, + "freqm": 0, + "timem": 0, + "blur": False, + "mean": -4.63, + "std": 2.74, + "target_length": 1024, + }, + }, + "model": { + "device": "cuda", + "target": "audioldm.pipline.LatentDiffusion", + "params": { + "base_learning_rate": 5e-06, + "linear_start": 0.0015, + "linear_end": 0.0195, + "num_timesteps_cond": 1, + "log_every_t": 200, + "timesteps": 1000, + "first_stage_key": "fbank", + "cond_stage_key": "waveform", + "latent_t_size": 256, + "latent_f_size": 16, + "channels": 8, + "cond_stage_trainable": True, + "conditioning_key": "film", + "monitor": "val/loss_simple_ema", + "scale_by_std": True, + "unet_config": { + "target": "audioldm.latent_diffusion.openaimodel.UNetModel", + "params": { + "image_size": 64, + 
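+                    # These UNet sizes correspond to the small ("-s-") checkpoints; the
+                    # "-l-" / "-m-" variants override model_channels (and related fields)
+                    # at the end of this function.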
"extra_film_condition_dim": 512, + "extra_film_use_concat": True, + "in_channels": 8, + "out_channels": 8, + "model_channels": 128, + "attention_resolutions": [8, 4, 2], + "num_res_blocks": 2, + "channel_mult": [1, 2, 3, 5], + "num_head_channels": 32, + "use_spatial_transformer": True, + }, + }, + "first_stage_config": { + "base_learning_rate": 4.5e-05, + "target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL", + "params": { + "monitor": "val/rec_loss", + "image_key": "fbank", + "subband": 1, + "embed_dim": 8, + "time_shuffle": 1, + "ddconfig": { + "double_z": True, + "z_channels": 8, + "resolution": 256, + "downsample_time": False, + "in_channels": 1, + "out_ch": 1, + "ch": 128, + "ch_mult": [1, 2, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, + }, + }, + }, + "cond_stage_config": { + "target": "audioldm.clap.encoders.CLAPAudioEmbeddingClassifierFreev2", + "params": { + "key": "waveform", + "sampling_rate": 16000, + "embed_mode": "audio", + "unconditional_prob": 0.1, + }, + }, + }, + }, + } + + if("-l-" in model_name): + basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 256 + basic_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = 64 + elif("-m-" in model_name): + basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 192 + basic_config["model"]["params"]["cond_stage_config"]["params"]["amodel"] = "HTSAT-base" # This model use a larger HTAST + + return basic_config + +def get_metadata(): + return { + "audioldm-s-full": { + "path": os.path.join( + CACHE_DIR, + "audioldm-s-full.ckpt", + ), + "url": "https://zenodo.org/record/7600541/files/audioldm-s-full?download=1", + }, + "audioldm-l-full": { + "path": os.path.join( + CACHE_DIR, + "audioldm-l-full.ckpt", + ), + "url": "https://zenodo.org/record/7698295/files/audioldm-full-l.ckpt?download=1", + }, + "audioldm-s-full-v2": { + "path": os.path.join( + CACHE_DIR, + "audioldm-s-full-v2.ckpt", + ), + "url": "https://zenodo.org/record/7698295/files/audioldm-full-s-v2.ckpt?download=1", + }, + "audioldm-m-text-ft": { + "path": os.path.join( + CACHE_DIR, + "audioldm-m-text-ft.ckpt", + ), + "url": "https://zenodo.org/record/7813012/files/audioldm-m-text-ft.ckpt?download=1", + }, + "audioldm-s-text-ft": { + "path": os.path.join( + CACHE_DIR, + "audioldm-s-text-ft.ckpt", + ), + "url": "https://zenodo.org/record/7813012/files/audioldm-s-text-ft.ckpt?download=1", + }, + "audioldm-m-full": { + "path": os.path.join( + CACHE_DIR, + "audioldm-m-full.ckpt", + ), + "url": "https://zenodo.org/record/7813012/files/audioldm-m-full.ckpt?download=1", + }, + } + +class MyProgressBar(): + def __init__(self): + self.pbar = None + + def __call__(self, block_num, block_size, total_size): + if not self.pbar: + self.pbar=progressbar.ProgressBar(maxval=total_size) + self.pbar.start() + + downloaded = block_num * block_size + if downloaded < total_size: + self.pbar.update(downloaded) + else: + self.pbar.finish() + +def download_checkpoint(checkpoint_name="audioldm-s-full"): + meta = get_metadata() + if(checkpoint_name not in meta.keys()): + print("The model name you provided is not supported. 
Please use one of the following: ", meta.keys()) + + if not os.path.exists(meta[checkpoint_name]["path"]) or os.path.getsize(meta[checkpoint_name]["path"]) < 2*10**9: + os.makedirs(os.path.dirname(meta[checkpoint_name]["path"]), exist_ok=True) + print(f"Downloading the main structure of {checkpoint_name} into {os.path.dirname(meta[checkpoint_name]['path'])}") + + urllib.request.urlretrieve(meta[checkpoint_name]["url"], meta[checkpoint_name]["path"], MyProgressBar()) + print( + "Weights downloaded in: {} Size: {}".format( + meta[checkpoint_name]["path"], + os.path.getsize(meta[checkpoint_name]["path"]), + ) + ) + \ No newline at end of file diff --git a/tango_edm/audioldm/variational_autoencoder/__init__.py b/tango_edm/audioldm/variational_autoencoder/__init__.py new file mode 100644 index 0000000..363a2ca --- /dev/null +++ b/tango_edm/audioldm/variational_autoencoder/__init__.py @@ -0,0 +1 @@ +from tango_edm.audioldm.variational_autoencoder.autoencoder import AutoencoderKL \ No newline at end of file
z_k0@yaATp_YsosD*4w+A4G|VCfk$=VOn^S}SpBTJ&7hRC2AL7HeBhYlQ zO6jWUBC`b=S+&&>NNfxRBsqpdy8k5<=kI|iPjv#`ylO>dzuoFW;k4Rda!3R)(bZhq zLA4(bkT3#It@pdoP0}y*9Z0@vsqe%;#o66S?)jy@da|Y5dywZ-b{J-z9W*zqNeAO1 zS92WXUTe+6AQziW&B2J5o6QRwt**|IAya2iQ8W>cXxr;{2ct>}qV8kRU~mEfC&4G@ zI%e6Cls0ZgexwKfz7+zU9<=`~#M|RSHg~@YvvXtkmjas=3Fv;lss-4 z*WQ>?)FY$Sr1SkKGI*EHH&b4f`7WOBKiPZLMAFV_RwvmdnPDoLS1`?QV>gGyMn=TO zblS@CvrAIvYd9VLX#@}whC1$=fQ3*jtB{TY<_T_oLwBt>Cx00U`N?b^?01P#&lE&>f{ z+dD-=g{bu1p5OC=2fO|D`Gqqb1F&#D6P5cLK?2mG2||yiztUGhdn1S@0wCQaH-WO8 zdGddPqP~EeAru9E2a2*>W5ZAfL$inY7A~a1I74U$027y3Di`Eeoz;R!3qlTpQEIMY zn04**wpwETb}7t;))hWq#*>yjIUzZf(1rE`*GUqcY5-9WH9&eoL z?HkHrOYFEUH9>sY1ub4vswd6NPb;Dl4JeK?_AUg`)Q0Z~x=V&qKy8%m0rTE9gowVK zrL!FLyP}-7%wg+r(AQzL=<=tc{Atx+$2>)OH;%|7r_{Uoc)16tJ(#ljoEB+GTcnKg z*tGXB7^Cc;M7mE#HVOofRl-M$vEp?nPPTgQRCTO5%JOh+lHArzjW0WVuc zYjb*dsBb7x7E#|t)=c-GleJ9Cvc6V5qL{u*tPJMMZy+}nD~np7fF)sGrGmtc${`6%@wpsBbyR$r@%8WMCSR(ycju@ za!p7!wT8muWYaGpJ!TI`q(mxdX_SVWwPe_m*&P|b66#6hkOR>TIRwS94U->IhcQSR zkw3;Zln7!r(lQc*JXy8hZExka@>|YUAS*9lcPrk-qA-0?m@?i^Rm}KbM?^drT8Kh$n#D{7SiOd;ioaThaoe<^(drbZ zQz^aNP?+{n`#7_oi#>$PKZAh2l#7PH@2yxk3C)ih zZsudqc;ZL5;k~|{9`;Kxe`pTj{@%a@{lOpjxB-rq9O1Xz0z~3_QD2nPoULCUa+R$QF zJa{>f=46x=zdP(gfn7qP??k=V=mq)`F8Q53BORX4Mfvthb8Vv=h=Pu)AWyjenu)J1 zDifCtytvUM@APO&FV!=sE&;iuaJn_AheT_A9&wf33Lq9i7GW@uI$6y7=R50IxlljJr^X=duOeNRi;ZfL zN|rVYRIsc=SlXDhEvVsdTGt&1dDk7=mJ$_FC1u1bcnVeu{_y`B_I2B_kyppXTB<1F zn3YU%s7x_e9lmplQ1DbttgR*?Ys`(*SOgrhQsjFoJ7&a2J;3!6N`@ti3JzF}653Md#PwnN$GKh(El!DUU$TfOSAosl4A#4CSiN?T z*Z<>mbCBM%T|@AvL%Vd@+{y;j8AXvd&u-=9t|WJlZPBtVhq?U*Fk>DyK&nbT)|L}g zlDtVE%nTeVpG{Mww1ajyK$z9g5jp!IAXdDIYx=O3L&=Hj-=5NSUqPE7*OA9J;b~*F zwml7f1WX&>)md8{pBdx?vx8oY_QQp)l9cb2zRab4k@02g_FXY#?H^-|5gYVOmyCM` zE%FYuBu-2F1{wGc=qfmXP~!BU#P)vOLgXTC`+&Z~eCY8owr`4i-~fie^&{|ubRbyy z+I9Tq4{EtH!&KqETNxu0i-b~ti?O2&{x*Z>8T@qyozdQYQhmC3 zmtuY!*T|1O40 zz){lQIEN+JO_i#GL#Y}Q?Az4q(|2oC;3xA^K?VD95%7{GeR_z80B>0U?*s%M056Q0 z8$n$NpJoAI7+@E|$Aa)#hp8)f2(#x6e;r0qCMfalk~s?@0C5S^ceN6OTUrZTA)@#K-}qoAWTEt&u-Op`qp&(zm8T9wZEY2M}H4V>j>uK z)**JNNw6EYNA4SAO9AqRGXe!~i${4TETJ{O54c#Y8}QN>#XcL1_>mD{NREB^+*xd& z5(LqJ*AIq1czdmp(pTf|MjxqqUsVr#jaclQ*YL5?1HcX2mRdd8yA<>T?2<|`O#>+K z910xucYsgz*YIBT!wjBdkV7zkG=-R2-u?))M2@BcR#1`Pqko260Jzpu*k=@|9nL5% zoPUfJ{0xJ?$$+p&3v3a|^qLAtR6ojC0-j`*Mhol(W?W&gLu%6tHYBF^r%+G*eP$&% zPyIA+6IM>4T~ypmmSEH`vrxj@O1d`DSq7-|n&+`eV%b;!fJLM)BJ$JL{wy>8Ap=nZ z(#iXHn>te{r#?YJ%3(LA2WSo$NpKjkTm|CZ!Zsc2x}8h097}vFXs!a~a$`hC^cJJM zuWwdkF_@sV5yuMwS_0B2r~#oMqu}7v5Q<9)0B14uu;mVt1P0t=TaYhUu?CpIexEGB zQ^q9}vj8-50pXC&qmzzOte2)pX-JPPEHrF8VNSU9cAEV(%>_aokes*X?>A_^E-)jK6>5}xA*bD0C(Rc09%OW;E^(vtc zK^jL{x4?~j3O6*``Nzzoan& zNq0s3-@?-5=20iqBvKCGN_6tEQJp*?I(hOH>tyI!xSz@hi*WfEO>D6R?ZK1~$_@kM z@D3sMb13?mPK882yEcr9fd-}?jdfh?JSw=PSkZdZX!zE&3x_nc;`9>UgH_t<>eOlt zDK)H5Q@1;!+oyx6?Yd}jYTue@O6s@5WwvgnIVa8?I6eUXdB1;bsRO9xCaL9MaC2~o z?YM*`5pejw-m2@ot(o}$tYCn;w0cg;=G<&-hTIQ5KC^v`ykicv#4RCjhjVasP1Dr{ zz4=5qi&AuUNz0fIbK(izm#zKK$orcOAc_6WTf;iWmj=imA_hc&T1!$#1f8rMU;CT4 z0SO#Q@4{&}2bT`rMhD^2vBLdVv`gpCA(xIGPr7u5$8*%BL(Q#zg#m^7kSFydZ@G4p z>epfK%_~u@|AJ5dO9n!_|D3U3W$hv=>nc~hzN}4 zsm2m*ysjk(%i?+nm;a{-LUWalY%J4@&c#B9UP@cdaCu&|nz*sCAf*@i!Bt$G*EvBx zki!8nxAx!|;XRxXF-JZq8oDGJx~w&{Xy;5o*`(7eu?%}6P0=N}z2L$Vp-m1{j}CI- zYL}K=U`kbenui8}7N&G+P@{~K?FDeA$9<`*KM7{wkkRWxL!7}=4WB!(@|E#en30kia>z3IXNJU8|vd1QSV{Cp^ivw zE)VpvY5h@kS+PVL*T^pu-+e2aXeS<*108YG+Jh=s6N!{#$#L?x5v1lU&#&O(Mn$a9 zUCD4u5?AQzFiFKXj&=0-9$OVhQuY04S=6LU$F*^LBNn(WTgw1vU=5PEUY{=nMQoAO z8zc+x8=6pSG;sC0I{ff#bm~tRdtVFacOv5~9PZ!70w)0K;q(~O96v3MIJu|CC#y%$Bf+ng|tvV~Y3p4%K_?s3>!u1NP{1{{jqz@3gd7;QS 
zxR)_5=2I5gaKnnr6)8aTNWrv&gDozrrUTAF#kqlk2XmBWtfzhrDcXEH%9vo^c)*Sz zAu>BrR@_f}deuJ5!K$LH0kIADG*+KubUQOT1UKZV&!CQuV7V!e| zCql=~UXlgNu)u}ft-^(~K)P@TqxHnCIazYWaW#ks?5@-cuHqqac5a{$Zaq81p~IU5 z$GIvJu9elV!AjWVU;!yk7{Wh-m;~^oSVtIIxL$~~vV*HNn)?W*N{MIWLyUP0Mxo0h z4iftD>Q9PY#t4BgBtoEA#3B#b+6IMl?*$ig2vc@~jvRpt%M6xX_U$0L@nT>EZoC&w z=H8va7#Q?I2#-P>WWeT_u2mS-F?sAuO?v?Y*g*&pVOih!gc#nCCDC1OK_2AuA{&Jh$~n(Fa?!?b_PU{q%h_GVIWXmOz!`xJw{v~&R+ zrVp4|+({ z9Ub0;c~0DirSw+UzaNo79sl>8Q2!opR{w#)=NSA)21z^Bf5PqY@eRbf&0WZypGe)c zQBEHY)4Rb$S_woJOOI?cBmwWzTbBrfhf+^z9UO0S{?EwNp3;ahfiKttupX@Fb!usB zgIRAtu#fG$0}$wuGuQ*f(i zATPx6C@aErK~!KE0gA*Br5eZ~86GbTumv9GlOrxiD5E-P_fne5v(ZT^)|oSUm09mz zmVGspc;#JhP-p>`Y7(FW$DpWlk~Yqi+B~T^U(vq7-$dGwiN2q=rx`qr02?Oy8`vmO zW`YbiCq!Lj@C62-yxXe8#Fko9F>$dQb&Hodi?pv5G3HJDAEKwnnaK}Zec`_ zm7@CFJhkk(u^<~P$C0hiiqG#$tkN+4_`5`Ev^iFT-vlDO&@{teKrkfYEMN`ze}Oj2 z0f>XA9I#oe43tE~gT-iwlgwrM?6|3{2!vAxnkvPo`|vFWs3I&*|2%3vgME~O;6=lK z7%AHG_yS|U%s{mMPctT3|A!dk3dGf{o>qA};Ue{4kOhTx4~vm^#Df~;y?7~-!_6Hp zTC#)1zmsL84tcq}jn(Vb?j7$X_=qa;(}VM(^}#(%6>Rfq3L7t0<+$jY=gPMX+ABTH z*ePkTagTB0`EsK|!b)W1XzZ)b#8A%HH*X#{tBDjHVQ`9l&fr71e8RP12FKRngwgcT z0ZaW3%BcU&;9&+S#Wr|LcGzH>bBJhm*gQD4H3Ra`>UCf3Vg4fPASLgu+kYFmZ|57R zH=F`=yZ-wM7=w9jTq8eL$KRiaF=EqU$PKQNrtV^jGG2xT8vdgQhM2+xgio2m0?B6q z?h7t7HCTeqSnDz%e9RK2FhyyA41^`(MUeTTUSWFy^^aX#6IKx%EYvD=``=J_gzR5L z0?ysU4t|ooOzVK+UtogJxzr(53bCgw=&^{f8^whBT?C8y0Q*w+Am)(MCR%E5V*XR7 z{C^MWAyR7xok5n7uc?SH!^v3faJb_mwWgaT?>;a_mXs2yYhT3nl!($>wB(`=RmG6Zu-f zy^YOE<8`mQzrk&X1O<)Fa--FA8ycrIMC)P~`52(J3sQ1)ra3}Y1cHm5E%RAJW8onY z_&H<)2@B_?a-2Z&S5U?O#D8P+STY+46_10p^Ue_Sy?_UY)D884wtD%dVHB!?jn!L^ zY<#ceRFqv;JfZMLKJaD+37dKoy+-$(tQWHm89+pn0?#x?Rwc=vWBC-|(8SbH=y zi}bAGLr!+`sgF{W12zNoE}V%RezS+)q-YQ7;tLpO2hak;cM*(mFn93N0wix*>iYtl zPAu-#1ZN@#Vdi77F~OVpOZpo~IHTylS$H1DS66YC5ZIpEmAETD_2xPYo4JL4;DO8T z!J#;Of2fOw^{_7{*)v0)B)>h0qa+Qzt|U8TmO4+N*k-mt$F1I$&Gx5e>4@J~T*41B z|9>zbB`;1i^_b}IVMNnt$Xf3-cX#4$PilxhNmESB~6|LtdZZloAk; zdg=J3>$eW-#nupdg|2T->Q*FdERdkhc^Lc{XWt&eC8v1k)WumCoUOnq7;FoKzZ@=I z%m{t#w1N0Z0lIIY>A#4F&cKPw&PNmYp3E}7JJZJ)D_^|{Y4h2!rVWeZDe37E#HH+z zjA}6L41&+0b$%HU1I}9;hhK`1ewtO1Ah7{v zQZc2CYvh-pF^Z#v4e-@858tH0S#SO28+vgk7D00IowdA&17P^9mTVxrWG>>HYi=2a zOi9?JjNz(8b7%0?s^IYe$1@-6`vE?U)b@OQb_&g>K8D|j(krTBgF>KMU?Z2GQ1QAW z&R>oOH6prU_`i!_h?nwm==%=nDRRhRY#@Lh#AzpTm(@D}Y{&F_lXRZ3%Zs~~TF!=x zn1{ySO+F<5VdpbEaF}y^bBwkx%D}$^ezUmy82pAI(EJ8Rv4aaI;BaD&7WW(AICi$n zaP1aD93Y1-C2wN!4$Js*k8%a%SFr637GJ!T?;}A0OM}?Gj5k#!UXb{dAj{&y=9?ad z_QX!L!#juQaHBt(iXFK4j`7rIEQlM= zsDFiUzBXisy^MrtcK8EtX-jxyft&ti(fxN=lfY0qiKBv?wUpzSSA} z&OPY$qhI4Qr9@153i3v{g#@Bo$(E=MmMac@>xs=?oJ!jie|j#;#o90{buMOcR11d= z%(Zb+<-d|c9zaD#5uQ#fcipzEZ`v6rGl;=`LsrBtWpU!*I065_ppZVr>o`7;#o+-g zOl>yyAoxTajp$4=C8_!$Tw+gR!HG*rpulZ>Ac{0h%G}OeF+gF(EqI#4Ow+z_P^iON zI0I?J=Of$0*GZI2RdF4kKwI;$mV|GYDeuBY2NQ%ZYw^2g8cpOhrN~+nun3Fqltf=g zrB$>p{=N;nUNJXM?A5vP)G^XkHZ$q!Vcei&JOoQi79aP+NBlCEEOiJSx3RMo_k+FJ zS>3Ii_5|Q~3BKLNud}qh=-}y2YnfkBd!)7D`<+(rG*9Pr@F_bS^$`g@B67Fz)*T&< zW_7jt%w9@Fb4lMZsdRqt-S_bojP#w#h7wk$c}e;nyX!E$BL{(HtsiSyvKv75N^!4X z?vjP+KZqp<8w(8B;Xt>qe3lUV8@O-a!rds9SLXobGcNZLq6Wy{3_jE-%k!J{bB%Um zrSW6~$opL5cEsizZ7e60;4fBx#&P*3gQEzd%w}g@uR~sE+BXF+LgA7Y5MXrIy`$-~HE+`&%GX{{SLRHgROUgfALc_BX00 zojhPhPO%(D_n7|oUjDmH(eb8V>Bu(;PxN~FC|MQ@yz5V^KSF+#^E$o$x-VlIRpcXq z_$c5B5N+4j+Md4F+P3{5E4WcLg^KYLGq;&1_;zMb3Sgb@ 1: + print("Use subband decomposition %s" % self.subband) + + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + + self.vocoder = get_vocoder(None, "cpu") + self.embed_dim = embed_dim + + if 
monitor is not None: + self.monitor = monitor + + self.time_shuffle = time_shuffle + self.reload_from_ckpt = reload_from_ckpt + self.reloaded = False + self.mean, self.std = None, None + + self.scale_factor = scale_factor + + def encode(self, x): + # x = self.time_shuffle_operation(x) + x = self.freq_split_subband(x) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + dec = self.freq_merge_subband(dec) + return dec + + def decode_to_waveform(self, dec): + dec = dec.squeeze(1).permute(0, 2, 1) + wav_reconstruction = vocoder_infer(dec, self.vocoder) + return wav_reconstruction + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + + if self.flag_first_run: + print("Latent size: ", z.size()) + self.flag_first_run = False + + dec = self.decode(z) + + return dec, posterior + + def freq_split_subband(self, fbank): + if self.subband == 1 or self.image_key != "stft": + return fbank + + bs, ch, tstep, fbins = fbank.size() + + assert fbank.size(-1) % self.subband == 0 + assert ch == 1 + + return ( + fbank.squeeze(1) + .reshape(bs, tstep, self.subband, fbins // self.subband) + .permute(0, 2, 1, 3) + ) + + def freq_merge_subband(self, subband_fbank): + if self.subband == 1 or self.image_key != "stft": + return subband_fbank + assert subband_fbank.size(1) == self.subband # Channel dimension + bs, sub_ch, tstep, fbins = subband_fbank.size() + return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1) + + def device(self): + return next(self.parameters()).device + + @torch.no_grad() + def encode_first_stage(self, x): + return self.encode(x) + + # @torch.no_grad() # TODO: Maybe we don't use this one. 
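+    # Inverse of get_first_stage_encoding: the latent is rescaled by 1 / scale_factor before being passed to decode().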
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, "b h w c -> b c h w").contiguous() + + z = 1.0 / self.scale_factor * z + return self.decode(z) + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z \ No newline at end of file diff --git a/tango_edm/audioldm/variational_autoencoder/distributions.py b/tango_edm/audioldm/variational_autoencoder/distributions.py new file mode 100644 index 0000000..58eb535 --- /dev/null +++ b/tango_edm/audioldm/variational_autoencoder/distributions.py @@ -0,0 +1,102 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.mean( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). 
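+    # .to(tensor) also casts the scalar log-variances to the reference tensor's dtype and moves them to its device.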
+ logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/tango_edm/audioldm/variational_autoencoder/modules.py b/tango_edm/audioldm/variational_autoencoder/modules.py new file mode 100644 index 0000000..a234d08 --- /dev/null +++ b/tango_edm/audioldm/variational_autoencoder/modules.py @@ -0,0 +1,1066 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from tango_edm.audioldm.utils import instantiate_from_config +from tango_edm.audioldm.latent_diffusion.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class UpsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=1, padding=2 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class DownsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1 + ) + + 
def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2)) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_ + ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == 
"none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + 
hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb + ) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + downsample_time_stride4_levels=[], + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + if i_level in self.downsample_time_stride4_levels: + down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv) + else: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = 
None + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + downsample_time_stride4_levels=[], + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + # print("Working with z of shape {} = {} dimensions.".format( + # self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level - 1 in self.downsample_time_stride4_levels: + up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv) + else: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = 
None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList( + [ + nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock( + in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, + dropout=0.0, + ), + ResnetBlock( + in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, + dropout=0.0, + ), + nn.Conv2d(2 * in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True), + ] + ) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1, 2, 3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ch, + num_res_blocks, + resolution, + ch_mult=(2, 2), + dropout=0.0, + ): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d( + in_channels, mid_channels, kernel_size=3, stride=1, padding=1 + ) + self.res_block1 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in 
range(depth) + ] + ) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList( + [ + ResnetBlock( + in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0, + ) + for _ in range(depth) + ] + ) + + self.conv_out = nn.Conv2d( + mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate( + x, + size=( + int(round(x.shape[2] * self.factor)), + int(round(x.shape[3] * self.factor)), + ), + ) + x = self.attn(x).contiguous() + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__( + self, + in_channels, + ch, + resolution, + out_ch, + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + ch_mult=(1, 2, 4, 8), + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder( + in_channels=in_channels, + num_res_blocks=num_res_blocks, + ch=ch, + ch_mult=ch_mult, + z_channels=intermediate_chn, + double_z=False, + resolution=resolution, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + out_ch=None, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=intermediate_chn, + mid_channels=intermediate_chn, + out_channels=out_ch, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__( + self, + z_channels, + out_ch, + resolution, + num_res_blocks, + attn_resolutions, + ch, + ch_mult=(1, 2, 4, 8), + dropout=0.0, + resamp_with_conv=True, + rescale_factor=1.0, + rescale_module_depth=1, + ): + super().__init__() + tmp_chn = z_channels * ch_mult[-1] + self.decoder = Decoder( + out_ch=out_ch, + z_channels=tmp_chn, + attn_resolutions=attn_resolutions, + dropout=dropout, + resamp_with_conv=resamp_with_conv, + in_channels=None, + num_res_blocks=num_res_blocks, + ch_mult=ch_mult, + resolution=resolution, + ch=ch, + ) + self.rescaler = LatentRescaler( + factor=rescale_factor, + in_channels=z_channels, + mid_channels=tmp_chn, + out_channels=tmp_chn, + depth=rescale_module_depth, + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size // in_size)) + 1 + factor_up = 1.0 + (out_size % in_size) + print( + f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}" + ) + self.rescaler = LatentRescaler( + factor=factor_up, + in_channels=in_channels, + mid_channels=2 * in_channels, + out_channels=in_channels, + ) + self.decoder = Decoder( + out_ch=out_channels, + resolution=out_size, + z_channels=in_channels, + num_res_blocks=2, + attn_resolutions=[], + in_channels=None, + ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)], + ) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print( + f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode" 
+ ) + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=4, stride=2, padding=1 + ) + + def forward(self, x, scale_factor=1.0): + if scale_factor == 1.0: + return x + else: + x = torch.nn.functional.interpolate( + x, mode=self.mode, align_corners=False, scale_factor=scale_factor + ) + return x + + +class FirstStagePostProcessor(nn.Module): + def __init__( + self, + ch_mult: list, + in_channels, + pretrained_model: nn.Module = None, + reshape=False, + n_channels=None, + dropout=0.0, + pretrained_config=None, + ): + super().__init__() + if pretrained_config is None: + assert ( + pretrained_model is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert ( + pretrained_config is not None + ), 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2) + self.proj = nn.Conv2d( + in_channels, n_channels, kernel_size=3, stride=1, padding=1 + ) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append( + ResnetBlock( + in_channels=ch_in, out_channels=m * n_channels, dropout=dropout + ) + ) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def encode_with_pretrained(self, x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self, x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model, self.downsampler): + z = submodel(z, temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z, "b c h w -> b (h w) c") + return z
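
For reference, a minimal usage sketch of the DiagonalGaussianDistribution added above, mirroring how AutoencoderKL.encode and get_first_stage_encoding use it. The batch size, embed_dim, and time/frequency sizes below are illustrative assumptions, not values taken from this patch.

import torch
from tango_edm.audioldm.variational_autoencoder.distributions import (
    DiagonalGaussianDistribution,
)

# Assumed sizes, for illustration only.
batch, embed_dim, t_steps, f_bins = 2, 8, 256, 16

# encode() feeds the encoder output through quant_conv, which yields
# 2 * embed_dim channels: mean and log-variance stacked along dim=1.
moments = torch.randn(batch, 2 * embed_dim, t_steps, f_bins)

posterior = DiagonalGaussianDistribution(moments)
z = posterior.sample()      # mean + std * noise, shape (batch, embed_dim, t_steps, f_bins)
z_map = posterior.mode()    # deterministic alternative: just the mean
kl = posterior.kl()         # KL term vs. a standard normal, averaged over latent dims -> shape (batch,)

scale_factor = 1.0          # assumed; AutoencoderKL stores this in self.scale_factor
z_scaled = scale_factor * z # what get_first_stage_encoding returns

decode_first_stage would then divide by the same scale_factor before running the decoder on the latent.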