-
Notifications
You must be signed in to change notification settings - Fork 114
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a comprehensive README for the SD35 example, detailing environment setup, execution instructions, performance comparisons, and image quality. - Added a command-line interface for generating images using the Stable Diffusion 3.5 model, allowing customization through various options. - Updated README for the Flux example with revised performance comparison commands. - **Documentation** - New README.md file provides structured guidance for users on utilizing the SD35 model effectively. - Modifications to the Flux example README ensure continued clarity on execution instructions and performance metrics. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
- Loading branch information
Showing
6 changed files
with
373 additions
and
1 deletion.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
# Run SD35 with nexfort backend (Beta Release) | ||
|
||
1. [Environment Setup](#environment-setup) | ||
- [Set Up OneDiff](#set-up-onediff) | ||
- [Set Up NexFort Backend](#set-up-nexfort-backend) | ||
- [Set Up Diffusers Library](#set-up-diffusers) | ||
- [Download SD35 Model for Diffusers](#set-up-sd35) | ||
2. [Execution Instructions](#run) | ||
3. [Performance Comparison](#performance-comparation) | ||
4. [Dynamic Shape for SD35](#dynamic-shape-for-sd25) | ||
5. [Quality](#quality) | ||
|
||
## Environment setup | ||
### Set up onediff | ||
https://github.com/siliconflow/onediff?tab=readme-ov-file#installation | ||
|
||
### Set up nexfort backend | ||
https://github.com/siliconflow/onediff/tree/main/src/onediff/infer_compiler/backends/nexfort | ||
|
||
### Set up diffusers | ||
|
||
``` | ||
# Ensure diffusers include the SD35 pipeline. | ||
pip3 install --upgrade diffusers[torch] | ||
``` | ||
### Set up SD35 | ||
Model version for diffusers: https://huggingface.co/stabilityai/stable-diffusion-3.5-large | ||
|
||
HF pipeline: https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/README.md | ||
|
||
## Run | ||
|
||
### Run 1024*1024 without compile (the original pytorch HF diffusers baseline) | ||
``` | ||
python3 onediff_diffusers_extensions/examples/sd35/text_to_image_sd35.py \ | ||
--saved-image sd35.png | ||
``` | ||
|
||
### Run 1024*1024 with compile | ||
|
||
|
||
## Performance comparation | ||
### Acceleration with Onediff-Community | ||
|
||
``` | ||
NEXFORT_ENABLE_TRITON_AUTOTUNE_CACHE=0 \ | ||
NEXFORT_ENABLE_FP8_QUANTIZE_ATTENTION=0 \ | ||
python3 onediff_diffusers_extensions/examples/sd35/text_to_image_sd35.py \ | ||
--transform \ | ||
--saved-image sd35_compile.png | ||
``` | ||
|
||
Testing on NVIDIA H20, with image size of 1024*1024, iterating 28 steps: | ||
| Metric | | | ||
| ------------------------------------------------ | ------------------- | | ||
| Data update date(yyyy-mm-dd) | 2024-11-22 | | ||
| PyTorch iteration speed | 1.47 it/s | | ||
| OneDiff iteration speed | 1.82 it/s (+23.8%) | | ||
| PyTorch E2E time | 19.41 s | | ||
| OneDiff E2E time | 15.99 s (-17.6%) | | ||
| PyTorch Max Mem Used | 28.525 GiB | | ||
| OneDiff Max Mem Used | 28.524 GiB | | ||
| PyTorch Warmup with Run time | 20.42 s | | ||
| OneDiff Warmup with Compilation time<sup>1</sup> | 96.81 s | | ||
| OneDiff Warmup with Cache time | 17.29 s | | ||
|
||
<sup>1</sup> OneDiff Warmup with Compilation time is tested on Intel(R) Xeon(R) Platinum 8468V. Note this is just for reference, and it varies a lot on different CPU. | ||
|
||
### Acceleration with Onediff-Enterprise(with quantization) | ||
``` | ||
NEXFORT_FORCE_QUANTE_ON_CUDA=1 python3 onediff_diffusers_extensions/examples/sd35/text_to_image_sd35.py \ | ||
--quantize \ | ||
--transform \ | ||
--saved-image sd35_compile.png | ||
``` | ||
|
||
Testing on NVIDIA H20, with image size of 1024*1024, iterating 28 steps: | ||
| Metric | | | ||
| ------------------------------------------------ | ------------------- | | ||
| Data update date(yyyy-mm-dd) | 2024-11-22 | | ||
| PyTorch iteration speed | 1.47 it/s | | ||
| OneDiff iteration speed | 2.72 it/s (+85.0%) | | ||
| PyTorch E2E time | 19.41 s | | ||
| OneDiff E2E time | 10.76 s (-44.6%) | | ||
| PyTorch Max Mem Used | 28.525 GiB | | ||
| OneDiff Max Mem Used | 20.713 GiB | | ||
| PyTorch Warmup with Run time | 20.42 s | | ||
| OneDiff Warmup with Compilation time<sup>1</sup> | 157.37 s | | ||
| OneDiff Warmup with Cache time | 12.04 s | | ||
|
||
<sup>1</sup> OneDiff Warmup with Compilation time is tested on Intel(R) Xeon(R) Platinum 8468V. Note this is just for reference, and it varies a lot on different CPU. | ||
|
||
``` | ||
NEXFORT_FORCE_QUANTE_ON_CUDA=1 python3 onediff_diffusers_extensions/examples/sd35/text_to_image_sd35.py \ | ||
--quantize \ | ||
--transform \ | ||
--speedup-t5 \ # Must quantize t5, because 4090 has only 24GB of memory | ||
--saved-image sd35_compile.png | ||
``` | ||
|
||
|
||
Testing on RTX 4090, with image size of 1024*1024, iterating 28 steps:: | ||
| Metric | | | ||
| ------------------------------------------------ | ------------------- | | ||
| Data update date(yyyy-mm-dd) | 2024-11-22 | | ||
| PyTorch iteration speed | OOM | | ||
| OneDiff iteration speed | 3.01 it/s | | ||
| PyTorch E2E time | OOM | | ||
| OneDiff E2E time | 9.79 s | | ||
| PyTorch Max Mem Used | OOM | | ||
| OneDiff Max Mem Used | 20.109 GiB | | ||
| PyTorch Warmup with Run time | OOM | | ||
| OneDiff Warmup with Compilation time<sup>2</sup> | 136.77 s | | ||
| OneDiff Warmup with Cache time | 10.74 s | | ||
|
||
<sup>2</sup> OneDiff Warmup with Compilation time is tested on AMD EPYC 7543 32-Core Processor | ||
|
||
|
||
## Dynamic shape for SD35 | ||
|
||
Run: | ||
|
||
``` | ||
python3 onediff_diffusers_extensions/examples/sd35/text_to_image_sd35.py \ | ||
--quantize \ | ||
--transform \ | ||
--run_multiple_resolutions \ | ||
--saved-image sd35_compile.png | ||
``` | ||
|
||
## Quality | ||
When using nexfort as the backend for onediff compilation acceleration, the generated images are nearly lossless.(The following images are generated on an NVIDIA H20) | ||
|
||
### Generated image with pytorch | ||
<p align="center"> | ||
<img src="../../../imgs/sd35_base.png"> | ||
</p> | ||
|
||
### Generated image with nexfort acceleration(Community) | ||
<p align="center"> | ||
<img src="../../../imgs/nexfort_sd35_community.png"> | ||
</p> | ||
|
||
### Generated image with nexfort acceleration(Enterprise) | ||
<p align="center"> | ||
<img src="../../../imgs/nexfort_sd35_enterprise.png"> | ||
</p> |
223 changes: 223 additions & 0 deletions
223
onediff_diffusers_extensions/examples/sd35/text_to_image_sd35.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
import argparse | ||
import json | ||
import time | ||
|
||
import nexfort | ||
|
||
import torch | ||
from diffusers import StableDiffusion3Pipeline | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser( | ||
description="Use nexfort to accelerate image generation with SD35." | ||
) | ||
parser.add_argument( | ||
"--model", | ||
type=str, | ||
default="stabilityai/stable-diffusion-3.5-large", | ||
help="Model path or identifier.", | ||
) | ||
parser.add_argument( | ||
"--speedup-t5", | ||
action="store_true", | ||
help="Enable optimize t5.", | ||
) | ||
parser.add_argument( | ||
"--quantize", | ||
action="store_true", | ||
help="Enable fp8 quantization.", | ||
) | ||
parser.add_argument( | ||
"--transform", | ||
action="store_true", | ||
help="Enable speedup with nexfort.", | ||
) | ||
parser.add_argument( | ||
"--prompt", | ||
type=str, | ||
default="evening sunset scenery blue sky nature, glass bottle with a galaxy in it.", | ||
) | ||
parser.add_argument( | ||
"--height", type=int, default=1024, help="Height of the generated image." | ||
) | ||
parser.add_argument( | ||
"--width", type=int, default=1024, help="Width of the generated image." | ||
) | ||
parser.add_argument( | ||
"--guidance_scale", | ||
type=float, | ||
default=3.5, | ||
help="The scale factor for the guidance.", | ||
) | ||
parser.add_argument( | ||
"--num-inference-steps", type=int, default=28, help="Number of inference steps." | ||
) | ||
parser.add_argument( | ||
"--saved-image", | ||
type=str, | ||
default="./sd35.png", | ||
help="Path to save the generated image.", | ||
) | ||
parser.add_argument( | ||
"--seed", type=int, default=20, help="Seed for random number generation." | ||
) | ||
parser.add_argument( | ||
"--run_multiple_resolutions", | ||
action="store_true", | ||
) | ||
parser.add_argument( | ||
"--run_multiple_prompts", | ||
action="store_true", | ||
) | ||
return parser.parse_args() | ||
|
||
|
||
args = parse_args() | ||
|
||
device = torch.device("cuda") | ||
|
||
|
||
def generate_texts(min_length=50, max_length=302): | ||
base_text = "a female character with long, flowing hair that appears to be made of ethereal, swirling patterns resembling the Northern Lights or Aurora Borealis. The background is dominated by deep blues and purples, creating a mysterious and dramatic atmosphere. The character's face is serene, with pale skin and striking features. She" | ||
|
||
additional_words = [ | ||
"gracefully", | ||
"beautifully", | ||
"elegant", | ||
"radiant", | ||
"mysteriously", | ||
"vibrant", | ||
"softly", | ||
"gently", | ||
"luminescent", | ||
"sparkling", | ||
"delicately", | ||
"glowing", | ||
"brightly", | ||
"shimmering", | ||
"enchanting", | ||
"gloriously", | ||
"magnificent", | ||
"majestic", | ||
"fantastically", | ||
"dazzlingly", | ||
] | ||
|
||
for i in range(min_length, max_length): | ||
idx = i % len(additional_words) | ||
base_text += " " + additional_words[idx] | ||
yield base_text | ||
|
||
|
||
class SD35Generator: | ||
def __init__( | ||
self, | ||
model, | ||
enable_quantize=False, | ||
enable_fast_transformer=False, | ||
enable_speedup_t5=False, | ||
): | ||
self.pipe = StableDiffusion3Pipeline.from_pretrained( | ||
model, | ||
torch_dtype=torch.bfloat16, | ||
) | ||
|
||
# Put the quantize process after `self.pipe.to(device)` if you have more than 32GB ram. | ||
if enable_quantize: | ||
print("quant...") | ||
from nexfort.quantization import quantize | ||
|
||
self.pipe.transformer = quantize( | ||
self.pipe.transformer, quant_type="fp8_e4m3_e4m3_dynamic_per_tensor" | ||
) | ||
if enable_speedup_t5: | ||
self.pipe.text_encoder_2 = quantize( | ||
self.pipe.text_encoder_2, | ||
quant_type="fp8_e4m3_e4m3_dynamic_per_tensor", | ||
) | ||
|
||
self.pipe.to(device) | ||
|
||
if enable_fast_transformer: | ||
print("compile...") | ||
from nexfort.compilers import transform | ||
|
||
self.pipe.transformer = transform(self.pipe.transformer) | ||
if enable_speedup_t5: | ||
self.pipe.text_encoder_2 = transform(self.pipe.text_encoder_2) | ||
|
||
def warmup(self, gen_args, warmup_iterations=1): | ||
warmup_args = gen_args.copy() | ||
|
||
# warmup_args["generator"] = torch.Generator(device=device).manual_seed(0) | ||
torch.manual_seed(args.seed) | ||
|
||
print("Starting warmup...") | ||
start_time = time.time() | ||
for _ in range(warmup_iterations): | ||
self.pipe(**warmup_args) | ||
end_time = time.time() | ||
print("Warmup complete.") | ||
print(f"Warmup time: {end_time - start_time:.2f} seconds") | ||
|
||
def generate(self, gen_args): | ||
# gen_args["generator"] = torch.Generator(device=device).manual_seed(args.seed) | ||
torch.manual_seed(args.seed) | ||
|
||
# Run the model | ||
start_time = time.time() | ||
image = self.pipe(**gen_args).images[0] | ||
end_time = time.time() | ||
|
||
image.save(args.saved_image) | ||
|
||
return image, end_time - start_time | ||
|
||
|
||
def main(): | ||
sd35 = SD35Generator(args.model, args.quantize, args.transform, args.speedup_t5) | ||
|
||
if args.run_multiple_prompts: | ||
dynamic_prompts = generate_texts(max_length=101) | ||
prompt_list = list(dynamic_prompts) | ||
else: | ||
prompt_list = [args.prompt] | ||
|
||
gen_args = { | ||
"prompt": args.prompt, | ||
"num_inference_steps": args.num_inference_steps, | ||
"height": args.height, | ||
"width": args.width, | ||
"guidance_scale": args.guidance_scale, | ||
} | ||
|
||
sd35.warmup(gen_args) | ||
|
||
for prompt in prompt_list: | ||
gen_args["prompt"] = prompt | ||
print(f"Processing prompt of length {len(prompt)} characters.") | ||
image, inference_time = sd35.generate(gen_args) | ||
print( | ||
f"Generated image saved to {args.saved_image} in {inference_time:.2f} seconds." | ||
) | ||
cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024**3) | ||
print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB") | ||
|
||
if args.run_multiple_resolutions: | ||
gen_args["prompt"] = args.prompt | ||
print("Test run with multiple resolutions...") | ||
sizes = [1536, 1024, 768, 720, 576, 512, 256] | ||
for h in sizes: | ||
for w in sizes: | ||
gen_args["height"] = h | ||
gen_args["width"] = w | ||
print(f"Running at resolution: {h}x{w}") | ||
start_time = time.time() | ||
sd35.generate(gen_args) | ||
end_time = time.time() | ||
print(f"Inference time: {end_time - start_time:.2f} seconds") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |