Commit 20f0756 (0 parents): showing 94 changed files with 60,246 additions and 0 deletions.
@@ -0,0 +1,82 @@
# SoundCTM: Uniting Score-based and Consistency Models for Text-to-Sound Generation

This repository is the official implementation of "SoundCTM: Uniting Score-based and Consistency Models for Text-to-Sound Generation".

- [arXiv](https://arxiv.org/abs/2405.18503)
- [Audio Demo Samples](https://koichi-saito-sony.github.io/soundctm/)

Contact:
- Koichi SAITO: [email protected]
## Checkpoints

- Download the [teacher model's checkpoints](https://huggingface.co/koichisaito/soundctm/tree/main/ckpt/teacher) and the [AudioLDM-s-full checkpoint (VAE + vocoder part)](https://huggingface.co/koichisaito/soundctm/blob/main/ckpt/audioldm-s-full.ckpt), and place them in `soundctm/ckpt`.
- [SoundCTM checkpoint](https://huggingface.co/koichisaito/soundctm/tree/main/ckpt/soundctm_ckpt) trained on AudioCaps (ema=0.999, 30K training iterations).

For inference, both the [AudioLDM-s-full (VAE decoder + vocoder)](https://huggingface.co/koichisaito/soundctm/blob/main/ckpt/audioldm-s-full.ckpt) and [SoundCTM](https://huggingface.co/koichisaito/soundctm/tree/main/ckpt/soundctm_ckpt) checkpoints are used.
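A sketch of the expected layout under `soundctm/ckpt`, assuming the file names referenced by the training configuration included in this commit (`ckpt/teacher/pytorch_model_2_sigma_025.bin` and `ckpt/audioldm-s-full.ckpt`); the contents of `soundctm_ckpt` may be organized differently:

```
soundctm/
└── ckpt/
    ├── audioldm-s-full.ckpt                # VAE + vocoder (AudioLDM-s-full)
    ├── teacher/
    │   └── pytorch_model_2_sigma_025.bin   # teacher model weights
    └── soundctm_ckpt/                      # SoundCTM checkpoint (ema=0.999, 30K iterations)
```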
## Prerequisites

Install Docker on your server and build the Docker image:

```bash
docker build -t soundctm .
```

Then run the training, inference, and evaluation scripts inside the container.
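A minimal sketch of starting a container from the image built above; the GPU flag, volume mount, and working directory are assumptions, so adjust them to your setup:

```bash
docker run --gpus all -it --rm \
    -v "$(pwd)":/workspace -w /workspace \
    soundctm bash
```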
## Training
Please see `ctm_train.sh` and `ctm_train.py` and modify the folder paths depending on your environment.

Then run `bash ctm_train.sh`.
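For orientation, the path-related arguments you will most likely need to adapt are sketched below. The argument names are taken from the training configuration stored in this commit; that `ctm_train.sh` passes them to `ctm_train.py` as command-line flags in exactly this form is an assumption, so treat this as a sketch rather than the script's actual contents:

```bash
# Hypothetical excerpt of ctm_train.sh; adapt the paths to your environment.
python ctm_train.py \
    --train_file /data/audiocaps/train.csv \
    --validation_file /data/audiocaps/val.csv \
    --teacher_model_path ckpt/teacher/pytorch_model_2_sigma_025.bin \
    --stage1_path ckpt/audioldm-s-full.ckpt \
    --text_encoder_name google/flan-t5-large \
    --output_dir /output/
```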
## Inference
Please see `ctm_inference.sh` and `ctm_inference.py` and modify the folder paths depending on your environment.

Then run `bash ctm_inference.sh`.
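For orientation only, a hypothetical shape of the call inside `ctm_inference.sh`: `text_encoder_name`, `ctm_unet_model_config`, and `stage1_path` mirror keys in the saved training configuration, while `--ctm_model_path` and `--test_file` are placeholder names for the SoundCTM checkpoint and the caption CSV. Check `ctm_inference.py` for the actual argument names:

```bash
# Hypothetical excerpt of ctm_inference.sh (placeholder flag names, not the script's real ones).
python ctm_inference.py \
    --text_encoder_name google/flan-t5-large \
    --ctm_unet_model_config configs/diffusion_model_config.json \
    --stage1_path ckpt/audioldm-s-full.ckpt \
    --ctm_model_path ckpt/soundctm_ckpt \
    --test_file /data/audiocaps/test.csv
```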
## Numerical evaluation
Please see `numerical_evaluation.sh` and `numerical_evaluation.py` and modify the folder paths depending on your environment.

Then run `bash numerical_evaluation.sh`.
## Dataset
Follow the instructions given in the [AudioCaps repository](https://github.com/cdjkim/audiocaps) for downloading the data.
The data locations need to be specified in `ctm_train.sh`.
You can also see some examples in `data/train.csv`.
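For reference, the training configuration stored in this commit reads captions from a `caption` column and audio paths from a `file_name` column, so a training CSV is expected to look roughly like the sketch below. The rows are placeholders; see `data/train.csv` for real examples:

```csv
file_name,caption
/data/audiocaps/train/example_0001.wav,"A man speaks while birds chirp in the background"
/data/audiocaps/train/example_0002.wav,"Rain falls steadily onto a hard surface"
```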
## WandB for logging
The training code also requires a [Weights & Biases](https://wandb.ai/site) account to log the training outputs and demos. Create an account and log in with:
```bash
wandb login
```
Alternatively, you can pass an API key via the `WANDB_API_KEY` environment variable
(you can obtain the API key from https://wandb.ai/authorize after logging in to your account):
```bash
export WANDB_API_KEY="12345x6789y..."
```
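If you prefer not to export the key in your shell session, you can also set it for a single run (a usage sketch; `ctm_train.sh` is the training entry point from the Training section):

```bash
WANDB_API_KEY="12345x6789y..." bash ctm_train.sh
```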
## Citation
```
@article{saito2024soundctm,
  title={SoundCTM: Uniting Score-based and Consistency Models for Text-to-Sound Generation},
  author={Koichi Saito and Dongjun Kim and Takashi Shibuya and Chieh-Hsin Lai and Zhi Zhong and Yuhta Takida and Yuki Mitsufuji},
  journal={arXiv preprint arXiv:2405.18503},
  year={2024}
}
```
## Reference
Part of the code is borrowed from the following repositories. We would like to thank their authors for their contributions.
> https://github.com/sony/ctm
> https://github.com/declare-lab/tango
> https://github.com/haoheliu/AudioLDM
> https://github.com/haoheliu/audioldm_eval
@@ -0,0 +1 @@
{"seed": 5031, "tango": true, "train_file": "/data/audiocaps/train.csv", "validation_file": "/data/audiocaps/val.csv", "num_examples": -1, "text_encoder_name": "google/flan-t5-large", "unet_model_config": "configs/diffusion_model_config.json", "ctm_unet_model_config": "configs/diffusion_model_config.json", "freeze_text_encoder": true, "text_column": "caption", "audio_column": "file_name", "tango_data_augment": true, "augment_num": 2, "uncond_prob": 0.1, "prefix": null, "per_device_train_batch_size": 6, "per_device_eval_batch_size": 2, "num_train_epochs": 40, "gradient_accumulation_steps": 1, "lr_scheduler_type": "linear", "d_lr_scheduler_type": "linear", "num_warmup_steps": 0, "d_num_warmup_steps": 0, "adam_beta1": 0.9, "adam_beta2": 0.999, "adam_epsilon": 1e-08, "output_dir": "/output/", "duration": 10.0, "checkpointing_steps": "best", "model_grad_clip_value": 1000.0, "disc_grad_clip_value": 1000.0, "sigma_data": 0.25, "resume_from_checkpoint": null, "generated_path": null, "valid_data_path": null, "mixed_precision": "bf16", "allow_tf32": false, "gradient_checkpointing": false, "enable_xformers_memory_efficient_attention": false, "with_tracking": true, "report_to": "wandb", "teacher_model_path": "ckpt/teacher/pytorch_model_2_sigma_025.bin", "stage1_path": "ckpt/audioldm-s-full.ckpt", "schedule_sampler": "uniform", "lr": 8e-05, "weight_decay": 0.0, "lr_anneal_steps": 0, "ema_rate": "0.999", "total_training_steps": 600000, "save_interval": 3000, "unet_mode": "full", "distill_steps_per_iter": 50000, "out_res": -1, "clip_denoised": false, "clip_output": false, "beta_min": 0.1, "beta_max": 20.0, "multiplier": 1.0, "load_optimizer": true, "num_channels": 128, "num_res_blocks": 2, "num_heads": 4, "num_heads_upsample": -1, "num_head_channels": -1, "attention_resolutions": "32,16,8", "channel_mult": "", "dropout": 0.0, "class_cond": false, "use_checkpoint": false, "use_scale_shift_norm": true, "resblock_updown": false, "use_new_attention_order": false, "learn_sigma": false, "out_channels": 8, "in_channels": 8, "deterministic": false, "time_continuous": false, "consistency_weight": 1.0, "loss_norm": "feature_space", "loss_distance": "l2", "loss_domain": "latent", "weight_schedule": "uniform", "parametrization": "euler", "inner_parametrization": "edm", "num_heun_step": 39, "num_heun_step_random": true, "teacher_dropout": 0.1, "training_mode": "ctm", "match_point": "zs", "target_ema_mode": "fixed", "scale_mode": "fixed", "start_ema": 0.999, "start_scales": 40, "end_scales": 40, "sigma_min": 0.002, "sigma_max": 80.0, "rho": 7, "latent_channels": 8, "latent_f_size": 16, "latent_t_size": 256, "cfg_distill": false, "target_cfg": 3.0, "unform_sampled_cfg_distill": true, "w_min": 2.0, "w_max": 5.0, "diffusion_training": true, "denoising_weight": 1.0, "diffusion_mult": 0.7, "diffusion_schedule_sampler": "halflognormal", "apply_adaptive_weight": true, "dsm_loss_target": "z_0", "diffusion_weight_schedule": "karras_weight", "cm_ratio": 0.0, "augment": false, "intermediate_samples": false, "compute_ema_fads": true, "sampling_steps": 18, "ref_path": "", "large_log": false, "discriminator_training": true, "discriminator_input": "latent", "gan_target": "z_target", "sample_s_strategy": "uniform", "heun_step_strategy": "weighted", "heun_step_multiplier": 1.0, "auxiliary_type": "stop_grad", "gan_estimate_type": "same", "discriminator_fix": false, "discriminator_free_target": false, "d_apply_adaptive_weight": true, "discriminator_start_itr": 39000, "discriminator_weight": 1.0, "d_lr": 8e-05, "r1_reg_enable": false, 
"reg_gamma": 2.0, "d_architecture": "CMBDisc", "dac_dis_rates": [], "dac_dis_periods": [2, 3, 5, 7, 11], "dac_dis_fft_sizes": [1024, 512, 256, 128], "dac_dis_sample_rate": 16000, "dac_dis_bands": [[0.0, 0.25], [0.25, 0.5], [0.5, 0.75], [0.75, 1.0]], "d_cond_type": "text_encoder", "c_dim": 1024, "cmap_dim": 128, "vqgan_ndf": 64, "vqgan_n_layers": 1, "vqgan_use_spectral_norm": false, "mbdisc_ndf": 32, "n_bins": 64, "increase_ch": false, "fm_apply_adaptive_weight": true, "fm_weight": 2.0} |
@@ -0,0 +1,5 @@
import os
import sys

# Make the package directory importable so sibling modules resolve correctly.
dir_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(dir_path)

# Expose CLAP_Module at the package level.
from .hook import CLAP_Module
@@ -0,0 +1,8 @@
from .factory import list_models, create_model, create_model_and_transforms, add_model_config
from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
from .model import CLAP, CLAPTextCfg, CLAPVisionCfg, CLAPAudioCfp, convert_weights_to_fp16, trace_model
from .openai import load_openai_model, list_openai_models
from .pretrained import list_pretrained, list_pretrained_tag_models, list_pretrained_model_tags,\
    get_pretrained_url, download_pretrained
from .tokenizer import SimpleTokenizer, tokenize
from .transform import image_transform
@@ -0,0 +1,32 @@
# Minimal text-embedding helpers built on Hugging Face transformers.
# Each encoder keeps its own tokenizer/model pair so the helpers do not
# overwrite one another through shared module-level globals.
from transformers import BertTokenizer, BertModel
from transformers import RobertaTokenizer, RobertaModel
from transformers import BartTokenizer, BartModel

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')


def bert_embeddings(text):
    # e.g. text = "Replace me by any text you'd like."
    encoded_input = bert_tokenizer(text, return_tensors='pt')
    output = bert_model(**encoded_input)
    return output


roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
roberta_model = RobertaModel.from_pretrained('roberta-base')


def Roberta_embeddings(text):
    # e.g. text = "Replace me by any text you'd like."
    encoded_input = roberta_tokenizer(text, return_tensors='pt')
    output = roberta_model(**encoded_input)
    return output


bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
bart_model = BartModel.from_pretrained('facebook/bart-base')


def bart_embeddings(text):
    # e.g. text = "Replace me by any text you'd like."
    encoded_input = bart_tokenizer(text, return_tensors='pt')
    output = bart_model(**encoded_input)
    return output