PR of diffused-heads #2

Open · wants to merge 3 commits into base: main
22 changes: 22 additions & 0 deletions diffused-heads/Dockerfile
@@ -0,0 +1,22 @@
# Use a Conda-ready base image (Conda is assumed to be preinstalled)
FROM continuumio/miniconda3:latest

# Set the working directory
WORKDIR /app

# Copy everything from the build context into /app inside the image
COPY . /app/

# Create the Conda environment
RUN conda create -n diffused-heads python=3.9 -y

# Activate the environment and install the pip and Conda dependencies
RUN /bin/bash -c "source activate diffused-heads && \
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple -r /app/requirements.txt && \
conda install -n diffused-heads -c conda-forge ffmpeg openh264 -y"

# Launch the Python script directly in the diffused-heads environment
ENTRYPOINT ["conda", "run", "-n", "diffused-heads", "python", "/app/sample.py"]

# If no arguments are passed, show the help message by default
CMD ["--help"]
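With this Dockerfile, the image would typically be built from the repository root with `docker build -t diffused-heads .` and run with `docker run --rm diffused-heads`, which prints the help output via the default `CMD`; the `diffused-heads` tag and any GPU flags such as `--gpus all` are illustrative assumptions, not part of this PR.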
58 changes: 58 additions & 0 deletions diffused-heads/README.md
@@ -0,0 +1,58 @@
# Diffused Heads

### [Project](https://mstypulkowski.github.io/diffusedheads/) | [Paper](https://arxiv.org/abs/2301.03396) | [Demo](https://youtu.be/DSipIDj-5q0)

## Setup
A Python 3.x environment with [ffmpeg](https://www.ffmpeg.org/) is needed. The remaining requirements can be installed with:
```
pip install -r requirements.txt
```

## Sampling
Due to the LRW license agreement, we are only able to provide a checkpoint of our model trained on CREMA.

The entire test set generated by our method can be downloaded from [here](https://drive.google.com/file/d/1zWSqtV7O4WGkgh6WB55b8Mdg2lXXUudH/view?usp=drive_link).


1. Download and unpack [checkpoints](https://drive.google.com/file/d/1U90egQvzERHclTYPCjZadrEMyF7TAPa-/view?usp=drive_link) (our model and pretrained audio encoder).

2. Download and unpack preprocessed CREMA [video](https://drive.google.com/file/d/1rM0FZLGiy-bJcxpv4CTlbUf0FuROubdk/view?usp=drive_link) and [audio](https://drive.google.com/file/d/1uS7Vi8EwarJFGQhsYHDMSkQmaNuiJIVW/view?usp=drive_link) files.

3. Specify paths and options in `config_crema.yaml` (check comments in the file).

4. Run the script
```
python sample.py
```


## Using your own data
### Audio
You are free to use any audio recordings of your choosing. The only requirements are a 16 kHz sample rate and a single (mono) audio channel. Please note that our model can generate videos up to 9 seconds long, depending on the audio.
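
If a recording does not already meet these requirements, it can be converted beforehand. A minimal sketch, assuming `librosa` and `soundfile` are available (they are not necessarily listed in `requirements.txt`); the file paths are placeholders:
```
import librosa
import soundfile as sf

# Resample to 16 kHz and mix down to a single channel, then write the result.
audio, sr = librosa.load('my_recording.wav', sr=16000, mono=True)
sf.write('my_recording_16k_mono.wav', audio, sr)
```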

### Identity frame
It is highly recommended to use a frame from the provided CREMA videos. This instance of the model was trained on clips with a green background only. If you want to use your own identity frame anyway, please follow this [repo](https://github.com/DinoMan/face-processor) for face alignment. Additionally, you may want to try segmenting the person and replacing the background with green.

## Training
The training code can be found in the branch [train](https://github.com/MStypulkowski/diffused-heads/tree/train). We apologize for the delay.

## Citation
```
@inproceedings{stypulkowski2024diffused,
title={Diffused heads: Diffusion models beat gans on talking-face generation},
author={Stypu{\l}kowski, Micha{\l} and Vougioukas, Konstantinos and He, Sen and Zi{\k{e}}ba, Maciej and Petridis, Stavros and Pantic, Maja},
booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
pages={5091--5100},
year={2024}
}
```

## License
This work is licensed under a
[Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License][cc-by-nc-sa].

[![CC BY-NC-SA 4.0][cc-by-nc-sa-image]][cc-by-nc-sa]

[cc-by-nc-sa]: http://creativecommons.org/licenses/by-nc-sa/4.0/
[cc-by-nc-sa-image]: https://licensebuttons.net/l/by-nc-sa/4.0/88x31.png
[cc-by-nc-sa-shield]: https://img.shields.io/badge/License-CC%20BY--NC--SA%204.0-lightgrey.svg
17 changes: 17 additions & 0 deletions diffused-heads/config_crema.yaml
@@ -0,0 +1,17 @@
audio: /app/demo.wav
checkpoint: ./checkpoints/crema_script.pt
diffusion:
  image_size: 128
  in_channels: 3
  n_timesteps: 1000
  out_channels: 6
encoder_checkpoint: ./checkpoints/audio_encoder.pt
gpu: true
id_frame: /app/demo.png
id_frame_random: false
inference_steps: 100
output: /app/testgen/demo.mp4
unet:
  motion_channels: 3
  n_audio_motion_embs: 2
  n_motion_frames: 2
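
How `sample.py` consumes this file is not shown in this diff. As a rough, illustrative sketch only, the nested keys above can be read with PyYAML:
```
import yaml

# Load the sampling configuration shipped with this PR.
with open('config_crema.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['audio'])                    # /app/demo.wav
print(cfg['diffusion']['image_size'])  # 128
print(cfg['unet']['n_motion_frames'])  # 2
```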
Binary file added diffused-heads/demo.png
Binary file added diffused-heads/demo.wav
Binary file not shown.
145 changes: 145 additions & 0 deletions diffused-heads/diffusion.py
@@ -0,0 +1,145 @@
import numpy as np
import torch
import torch.nn as nn
from tqdm import trange
from torchvision.transforms import Compose


class Diffusion(nn.Module):
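    """Gaussian diffusion wrapper used to sample talking-head videos.

    Frames are generated autoregressively: each frame is denoised from noise,
    conditioned on the identity frame, the previous motion frames, and a
    sliding window of audio embeddings.
    """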
def __init__(
self, nn_backbone, device, n_timesteps=1000, in_channels=3, image_size=128, out_channels=6, motion_transforms=None):
super(Diffusion, self).__init__()

self.nn_backbone = nn_backbone
self.n_timesteps = n_timesteps
self.in_channels = in_channels
self.out_channels = out_channels
self.x_shape = (image_size, image_size)
self.device = device

self.motion_transforms = motion_transforms if motion_transforms else Compose([])

self.timesteps = torch.arange(n_timesteps)
self.beta = self.get_beta_schedule()
self.set_params()
self.device = device

def sample(self, x_cond, audio_emb, n_audio_motion_embs=2, n_motion_frames=2, motion_channels=3):
with torch.no_grad():
n_frames = audio_emb.shape[1]

xT = torch.randn(x_cond.shape[0], n_frames, self.in_channels, self.x_shape[0], self.x_shape[1]).to(x_cond.device)

audio_ids = [0] * n_audio_motion_embs
for i in range(n_audio_motion_embs + 1):
audio_ids += [i]

motion_frames = [self.motion_transforms(x_cond) for _ in range(n_motion_frames)]
motion_frames = torch.cat(motion_frames, dim=1)

samples = []
            for i in trange(n_frames, desc='Sampling'):
sample_frame = self.sample_loop(xT[:, i].to(x_cond.device), x_cond, motion_frames, audio_emb[:, audio_ids])
samples.append(sample_frame.unsqueeze(1))
motion_frames = torch.cat([motion_frames[:, motion_channels:, :], self.motion_transforms(sample_frame)], dim=1)
audio_ids = audio_ids[1:] + [min(i + n_audio_motion_embs + 1, n_frames - 1)]
return torch.cat(samples, dim=1)

def sample_loop(self, xT, x_cond, motion_frames, audio_emb):
xt = xT
for i, t in reversed(list(enumerate(self.timesteps))):
timesteps = torch.tensor([t] * xT.shape[0]).to(xT.device)
timesteps_ids = torch.tensor([i] * xT.shape[0]).to(xT.device)
nn_out = self.nn_backbone(xt, timesteps, x_cond, motion_frames=motion_frames, audio_emb=audio_emb)
mean, logvar = self.get_p_params(xt, timesteps_ids, nn_out)
noise = torch.randn_like(xt) if t > 0 else torch.zeros_like(xt)
xt = mean + noise * torch.exp(logvar / 2)

return xt

def get_p_params(self, xt, timesteps, nn_out):
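        # When out_channels == 2 * in_channels, the backbone also predicts nu, which
        # interpolates (in log space) between beta_t and the clipped beta_tilde_t,
        # i.e. the learned-variance parameterization of Improved DDPM.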
if self.in_channels == self.out_channels:
eps_pred = nn_out
p_logvar = self.expand(torch.log(self.beta[timesteps]))
else:
eps_pred, nu = nn_out.chunk(2, 1)
nu = (nu + 1) / 2
p_logvar = nu * self.expand(torch.log(self.beta[timesteps])) + (1 - nu) * self.expand(self.log_beta_tilde_clipped[timesteps])

p_mean, _ = self.get_q_params(xt, timesteps, eps_pred=eps_pred)
return p_mean, p_logvar

def get_q_params(self, xt, timesteps, eps_pred=None, x0=None):
if x0 is None:
# predict x0 from xt and eps_pred
coef1_x0 = self.expand(self.coef1_x0[timesteps])
coef2_x0 = self.expand(self.coef2_x0[timesteps])
x0 = coef1_x0 * xt - coef2_x0 * eps_pred
x0 = x0.clamp(-1, 1)

# q(x_{t-1} | x_t, x_0)
coef1_q = self.expand(self.coef1_q[timesteps])
coef2_q = self.expand(self.coef2_q[timesteps])
q_mean = coef1_q * x0 + coef2_q * xt

q_logvar = self.expand(self.log_beta_tilde_clipped[timesteps])

return q_mean, q_logvar

def get_beta_schedule(self, max_beta=0.999):
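        # Cosine noise schedule (Nichol & Dhariwal, 2021), with betas clipped at max_beta.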
alpha_bar = lambda t: np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2
betas = []
for i in range(self.n_timesteps):
t1 = i / self.n_timesteps
t2 = (i + 1) / self.n_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return torch.tensor(betas).float()

def set_params(self):
self.alpha = 1 - self.beta
self.alpha_bar = torch.cumprod(self.alpha, dim=0)
self.alpha_bar_prev = torch.cat([torch.ones(1,), self.alpha_bar[:-1]])

self.beta_tilde = self.beta * (1.0 - self.alpha_bar_prev) / (1.0 - self.alpha_bar)
self.log_beta_tilde_clipped = torch.log(torch.cat([self.beta_tilde[1, None], self.beta_tilde[1:]]))

        # to calculate x0 from eps_pred
self.coef1_x0 = torch.sqrt(1.0 / self.alpha_bar)
self.coef2_x0 = torch.sqrt(1.0 / self.alpha_bar - 1)

# for q(x_{t-1} | x_t, x_0)
self.coef1_q = self.beta * torch.sqrt(self.alpha_bar_prev) / (1.0 - self.alpha_bar)
self.coef2_q = (1.0 - self.alpha_bar_prev) * torch.sqrt(self.alpha) / (1.0 - self.alpha_bar)

def space(self, n_timesteps_new):
# change parameters for spaced timesteps during sampling
self.timesteps = self.space_timesteps(self.n_timesteps, n_timesteps_new)
self.n_timesteps = n_timesteps_new

self.beta = self.get_spaced_beta()
self.set_params()

def space_timesteps(self, n_timesteps, target_timesteps):
all_steps = []
frac_stride = (n_timesteps - 1) / (target_timesteps - 1)
cur_idx = 0.0
taken_steps = []
for _ in range(target_timesteps):
taken_steps.append(round(cur_idx))
cur_idx += frac_stride
all_steps += taken_steps
return all_steps

def get_spaced_beta(self):
last_alpha_cumprod = 1.0
new_beta = []
for i, alpha_cumprod in enumerate(self.alpha_bar):
if i in self.timesteps:
new_beta.append(1 - alpha_cumprod / last_alpha_cumprod)
last_alpha_cumprod = alpha_cumprod
return torch.tensor(new_beta)

def expand(self, arr, dim=4):
while arr.dim() < dim:
arr = arr[:, None]
return arr.to(self.device)
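
For orientation, here is a minimal, self-contained sketch of how this `Diffusion` class can be driven. The `DummyBackbone` stands in for the repository's real UNet (which is not part of this diff), and the audio-embedding dimension is an assumption; `sample.py` remains the actual entry point.
```
import torch

from diffusion import Diffusion


class DummyBackbone(torch.nn.Module):
    # Stand-in for the real UNet (not shown in this PR): returns 6 channels,
    # 3 for the eps prediction and 3 for the variance interpolation (out_channels=6).
    def forward(self, xt, timesteps, x_cond, motion_frames=None, audio_emb=None):
        return torch.randn(xt.shape[0], 6, xt.shape[2], xt.shape[3])


device = torch.device('cpu')  # the shipped config uses `gpu: true`; CPU keeps the sketch simple
diffusion = Diffusion(DummyBackbone(), device, n_timesteps=1000,
                      in_channels=3, image_size=128, out_channels=6)
diffusion.space(100)  # spaced sampling, matching `inference_steps: 100` in config_crema.yaml

x_cond = torch.randn(1, 3, 128, 128)  # identity frame, normalized to [-1, 1]
audio_emb = torch.randn(1, 3, 512)    # (batch, n_frames, emb_dim); emb_dim=512 is assumed

video = diffusion.sample(x_cond, audio_emb)
print(video.shape)  # torch.Size([1, 3, 3, 128, 128]) -> (batch, frames, channels, height, width)
```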