
Add sd35 example (#1147)
## Summary by CodeRabbit

- **New Features**
  - Introduced a comprehensive README for the SD35 example, detailing environment setup, execution instructions, performance comparisons, and image quality.
  - Added a command-line interface for generating images using the Stable Diffusion 3.5 model, allowing customization through various options.
  - Updated the README for the Flux example with revised performance comparison commands.

- **Documentation**
  - A new README.md file provides structured guidance for users on utilizing the SD35 model effectively.
  - Modifications to the Flux example README ensure continued clarity on execution instructions and performance metrics.
clackhan authored Dec 24, 2024
1 parent 6b09731 commit 74c91c7
Showing 6 changed files with 373 additions and 1 deletion.
Binary file added imgs/nexfort_sd35_community.png
Binary file added imgs/nexfort_sd35_enterprise.png
Binary file added imgs/sd35_base.png
4 changes: 3 additions & 1 deletion onediff_diffusers_extensions/examples/flux/README.md
@@ -43,7 +43,9 @@ python3 onediff_diffusers_extensions/examples/flux/text_to_image_flux.py \
### Acceleration with Onediff-Community

```
-NEXFORT_ENABLE_FP8_QUANTIZE_ATTENTION=0 python3 onediff_diffusers_extensions/examples/flux/text_to_image_flux.py \
+NEXFORT_ENABLE_TRITON_AUTOTUNE_CACHE=0 \
+NEXFORT_ENABLE_FP8_QUANTIZE_ATTENTION=0 \
+python3 onediff_diffusers_extensions/examples/flux/text_to_image_flux.py \
--transform \
--saved-image flux_compile.png
```
147 changes: 147 additions & 0 deletions onediff_diffusers_extensions/examples/sd35/README.md
@@ -0,0 +1,147 @@
# Run SD35 with nexfort backend (Beta Release)

1. [Environment Setup](#environment-setup)
- [Set Up OneDiff](#set-up-onediff)
- [Set Up NexFort Backend](#set-up-nexfort-backend)
- [Set Up Diffusers Library](#set-up-diffusers)
- [Download SD35 Model for Diffusers](#set-up-sd35)
2. [Execution Instructions](#run)
3. [Performance Comparison](#performance-comparison)
4. [Dynamic Shape for SD35](#dynamic-shape-for-sd35)
5. [Quality](#quality)

## Environment setup
### Set up onediff
https://github.com/siliconflow/onediff?tab=readme-ov-file#installation

### Set up nexfort backend
https://github.com/siliconflow/onediff/tree/main/src/onediff/infer_compiler/backends/nexfort
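The linked guide covers the CUDA/PyTorch version requirements in detail; assuming your environment matches them, installation is typically a single pip command (package name as given in the guide):

```
pip3 install -U nexfort
```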

### Set up diffusers

```
# Ensure diffusers include the SD35 pipeline.
pip3 install --upgrade diffusers[torch]
```
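To confirm that the installed diffusers release actually ships the SD3 pipeline this example imports, a quick optional check:

```
python3 -c "from diffusers import StableDiffusion3Pipeline; print('SD3.5 pipeline available')"
```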
### Set up SD35
Model version for diffusers: https://huggingface.co/stabilityai/stable-diffusion-3.5-large

HF pipeline: https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/README.md

## Run

### Run 1024*1024 without compile (the original PyTorch HF diffusers baseline)
```
python3 onediff_diffusers_extensions/examples/sd35/text_to_image_sd35.py \
--saved-image sd35.png
```

### Run 1024*1024 with compile

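The command is left implicit in this revision; based on the script's flags and the performance runs below, a compiled (non-quantized) run is the baseline command plus `--transform`:

```
python3 onediff_diffusers_extensions/examples/sd35/text_to_image_sd35.py \
--transform \
--saved-image sd35_compile.png
```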

## Performance comparison
### Acceleration with Onediff-Community

```
NEXFORT_ENABLE_TRITON_AUTOTUNE_CACHE=0 \
NEXFORT_ENABLE_FP8_QUANTIZE_ATTENTION=0 \
python3 onediff_diffusers_extensions/examples/sd35/text_to_image_sd35.py \
--transform \
--saved-image sd35_compile.png
```

Testing on an NVIDIA H20 with an image size of 1024*1024, iterating 28 steps:
| Metric | |
| ------------------------------------------------ | ------------------- |
| Data update date(yyyy-mm-dd) | 2024-11-22 |
| PyTorch iteration speed | 1.47 it/s |
| OneDiff iteration speed | 1.82 it/s (+23.8%) |
| PyTorch E2E time | 19.41 s |
| OneDiff E2E time | 15.99 s (-17.6%) |
| PyTorch Max Mem Used | 28.525 GiB |
| OneDiff Max Mem Used | 28.524 GiB |
| PyTorch Warmup with Run time | 20.42 s |
| OneDiff Warmup with Compilation time<sup>1</sup> | 96.81 s |
| OneDiff Warmup with Cache time | 17.29 s |

<sup>1</sup> OneDiff Warmup with Compilation time is tested on an Intel(R) Xeon(R) Platinum 8468V. Note this is just for reference; it varies a lot across different CPUs.

### Acceleration with Onediff-Enterprise (with quantization)
```
NEXFORT_FORCE_QUANTE_ON_CUDA=1 python3 onediff_diffusers_extensions/examples/sd35/text_to_image_sd35.py \
--quantize \
--transform \
--saved-image sd35_compile.png
```
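Here `--quantize` applies fp8 dynamic per-tensor quantization to the transformer; in `text_to_image_sd35.py` below this corresponds to `quant_type="fp8_e4m3_e4m3_dynamic_per_tensor"`.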

Testing on an NVIDIA H20 with an image size of 1024*1024, iterating 28 steps:
| Metric | |
| ------------------------------------------------ | ------------------- |
| Data update date(yyyy-mm-dd) | 2024-11-22 |
| PyTorch iteration speed | 1.47 it/s |
| OneDiff iteration speed | 2.72 it/s (+85.0%) |
| PyTorch E2E time | 19.41 s |
| OneDiff E2E time | 10.76 s (-44.6%) |
| PyTorch Max Mem Used | 28.525 GiB |
| OneDiff Max Mem Used | 20.713 GiB |
| PyTorch Warmup with Run time | 20.42 s |
| OneDiff Warmup with Compilation time<sup>1</sup> | 157.37 s |
| OneDiff Warmup with Cache time | 12.04 s |

<sup>1</sup> OneDiff Warmup with Compilation time is tested on an Intel(R) Xeon(R) Platinum 8468V. Note this is just for reference; it varies a lot across different CPUs.

```
# --speedup-t5 quantizes the T5 text encoder as well; required here because the 4090 has only 24 GB of memory.
NEXFORT_FORCE_QUANTE_ON_CUDA=1 python3 onediff_diffusers_extensions/examples/sd35/text_to_image_sd35.py \
--quantize \
--transform \
--speedup-t5 \
--saved-image sd35_compile.png
```


Testing on an RTX 4090 with an image size of 1024*1024, iterating 28 steps:
| Metric | |
| ------------------------------------------------ | ------------------- |
| Data update date(yyyy-mm-dd) | 2024-11-22 |
| PyTorch iteration speed | OOM |
| OneDiff iteration speed | 3.01 it/s |
| PyTorch E2E time | OOM |
| OneDiff E2E time | 9.79 s |
| PyTorch Max Mem Used | OOM |
| OneDiff Max Mem Used | 20.109 GiB |
| PyTorch Warmup with Run time | OOM |
| OneDiff Warmup with Compilation time<sup>2</sup> | 136.77 s |
| OneDiff Warmup with Cache time | 10.74 s |

<sup>2</sup> OneDiff Warmup with Compilation time is tested on an AMD EPYC 7543 32-Core Processor.


## Dynamic shape for SD35

Run:

```
python3 onediff_diffusers_extensions/examples/sd35/text_to_image_sd35.py \
--quantize \
--transform \
--run_multiple_resolutions \
--saved-image sd35_compile.png
```
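With `--run_multiple_resolutions`, the script sweeps every height/width pairing from 1536, 1024, 768, 720, 576, 512, and 256 (see `main()` in `text_to_image_sd35.py` below), timing generation at each shape against the already-compiled model.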

## Quality
When using nexfort as the backend for onediff compilation acceleration, the generated images are nearly lossless. (The following images were generated on an NVIDIA H20.)

### Generated image with PyTorch
<p align="center">
<img src="../../../imgs/sd35_base.png">
</p>

### Generated image with nexfort acceleration (Community)
<p align="center">
<img src="../../../imgs/nexfort_sd35_community.png">
</p>

### Generated image with nexfort acceleration (Enterprise)
<p align="center">
<img src="../../../imgs/nexfort_sd35_enterprise.png">
</p>
223 changes: 223 additions & 0 deletions onediff_diffusers_extensions/examples/sd35/text_to_image_sd35.py
@@ -0,0 +1,223 @@
import argparse
import time

import nexfort  # provides the compile (`transform`) and quantize backends used below

import torch
from diffusers import StableDiffusion3Pipeline


def parse_args():
    parser = argparse.ArgumentParser(
        description="Use nexfort to accelerate image generation with SD35."
    )
    parser.add_argument(
        "--model",
        type=str,
        default="stabilityai/stable-diffusion-3.5-large",
        help="Model path or identifier.",
    )
    parser.add_argument(
        "--speedup-t5",
        action="store_true",
        help="Also optimize (quantize/compile) the T5 text encoder.",
    )
    parser.add_argument(
        "--quantize",
        action="store_true",
        help="Enable fp8 quantization.",
    )
    parser.add_argument(
        "--transform",
        action="store_true",
        help="Enable speedup with nexfort.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="evening sunset scenery blue sky nature, glass bottle with a galaxy in it.",
    )
    parser.add_argument(
        "--height", type=int, default=1024, help="Height of the generated image."
    )
    parser.add_argument(
        "--width", type=int, default=1024, help="Width of the generated image."
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=3.5,
        help="The scale factor for the guidance.",
    )
    parser.add_argument(
        "--num-inference-steps", type=int, default=28, help="Number of inference steps."
    )
    parser.add_argument(
        "--saved-image",
        type=str,
        default="./sd35.png",
        help="Path to save the generated image.",
    )
    parser.add_argument(
        "--seed", type=int, default=20, help="Seed for random number generation."
    )
    parser.add_argument(
        "--run_multiple_resolutions",
        action="store_true",
        help="Sweep a set of resolutions to exercise dynamic shapes.",
    )
    parser.add_argument(
        "--run_multiple_prompts",
        action="store_true",
        help="Generate images for a series of increasingly long prompts.",
    )
    return parser.parse_args()


args = parse_args()

device = torch.device("cuda")


def generate_texts(min_length=50, max_length=302):
    base_text = "a female character with long, flowing hair that appears to be made of ethereal, swirling patterns resembling the Northern Lights or Aurora Borealis. The background is dominated by deep blues and purples, creating a mysterious and dramatic atmosphere. The character's face is serene, with pale skin and striking features. She"

    additional_words = [
        "gracefully",
        "beautifully",
        "elegant",
        "radiant",
        "mysteriously",
        "vibrant",
        "softly",
        "gently",
        "luminescent",
        "sparkling",
        "delicately",
        "glowing",
        "brightly",
        "shimmering",
        "enchanting",
        "gloriously",
        "magnificent",
        "majestic",
        "fantastically",
        "dazzlingly",
    ]

    # Grow the prompt by one word per step, yielding progressively longer prompts.
    for i in range(min_length, max_length):
        idx = i % len(additional_words)
        base_text += " " + additional_words[idx]
        yield base_text


class SD35Generator:
    def __init__(
        self,
        model,
        enable_quantize=False,
        enable_fast_transformer=False,
        enable_speedup_t5=False,
    ):
        self.pipe = StableDiffusion3Pipeline.from_pretrained(
            model,
            torch_dtype=torch.bfloat16,
        )

        # Put the quantize step after `self.pipe.to(device)` if you have more than 32 GB of RAM.
        if enable_quantize:
            print("quant...")
            from nexfort.quantization import quantize

            self.pipe.transformer = quantize(
                self.pipe.transformer, quant_type="fp8_e4m3_e4m3_dynamic_per_tensor"
            )
            if enable_speedup_t5:
                self.pipe.text_encoder_2 = quantize(
                    self.pipe.text_encoder_2,
                    quant_type="fp8_e4m3_e4m3_dynamic_per_tensor",
                )

        self.pipe.to(device)

        if enable_fast_transformer:
            print("compile...")
            from nexfort.compilers import transform

            self.pipe.transformer = transform(self.pipe.transformer)
            if enable_speedup_t5:
                self.pipe.text_encoder_2 = transform(self.pipe.text_encoder_2)

    def warmup(self, gen_args, warmup_iterations=1):
        warmup_args = gen_args.copy()

        # warmup_args["generator"] = torch.Generator(device=device).manual_seed(0)
        torch.manual_seed(args.seed)

        print("Starting warmup...")
        start_time = time.time()
        for _ in range(warmup_iterations):
            self.pipe(**warmup_args)
        end_time = time.time()
        print("Warmup complete.")
        print(f"Warmup time: {end_time - start_time:.2f} seconds")

    def generate(self, gen_args):
        # gen_args["generator"] = torch.Generator(device=device).manual_seed(args.seed)
        torch.manual_seed(args.seed)

        # Run the model
        start_time = time.time()
        image = self.pipe(**gen_args).images[0]
        end_time = time.time()

        image.save(args.saved_image)

        return image, end_time - start_time


def main():
    sd35 = SD35Generator(args.model, args.quantize, args.transform, args.speedup_t5)

    if args.run_multiple_prompts:
        dynamic_prompts = generate_texts(max_length=101)
        prompt_list = list(dynamic_prompts)
    else:
        prompt_list = [args.prompt]

    gen_args = {
        "prompt": args.prompt,
        "num_inference_steps": args.num_inference_steps,
        "height": args.height,
        "width": args.width,
        "guidance_scale": args.guidance_scale,
    }

    sd35.warmup(gen_args)

    for prompt in prompt_list:
        gen_args["prompt"] = prompt
        print(f"Processing prompt of length {len(prompt)} characters.")
        image, inference_time = sd35.generate(gen_args)
        print(
            f"Generated image saved to {args.saved_image} in {inference_time:.2f} seconds."
        )
        cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024**3)
        print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB")

    if args.run_multiple_resolutions:
        gen_args["prompt"] = args.prompt
        print("Test run with multiple resolutions...")
        sizes = [1536, 1024, 768, 720, 576, 512, 256]
        for h in sizes:
            for w in sizes:
                gen_args["height"] = h
                gen_args["width"] = w
                print(f"Running at resolution: {h}x{w}")
                start_time = time.time()
                sd35.generate(gen_args)
                end_time = time.time()
                print(f"Inference time: {end_time - start_time:.2f} seconds")


if __name__ == "__main__":
    main()
