From 043e0f97108dd45e0781fc1f2c00128ce3dce0a4 Mon Sep 17 00:00:00 2001
From: lixiang007666 <88304454@qq.com>
Date: Mon, 1 Jul 2024 00:04:30 +0800
Subject: [PATCH 01/11] Add sdxl lightning quant use
---
.../examples/lightning/README.md | 13 ++++++++++++
.../tools/quantization/quantize-sd-fast.py | 20 +++++++++++++++++++
2 files changed, 33 insertions(+)
create mode 100644 onediff_diffusers_extensions/examples/lightning/README.md
diff --git a/onediff_diffusers_extensions/examples/lightning/README.md b/onediff_diffusers_extensions/examples/lightning/README.md
new file mode 100644
index 000000000..09ffdc916
--- /dev/null
+++ b/onediff_diffusers_extensions/examples/lightning/README.md
@@ -0,0 +1,13 @@
+Run:
+
+
+```
+python3 onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py \
+ --quantized_model ./sdxl_lightning_quant \
+ --conv_ssim_threshold 0.1 \
+ --linear_ssim_threshold 0.1 \
+ --conv_compute_density_threshold 900 \
+ --linear_compute_density_threshold 300 \
+ --save_as_float true \
+ --use_lightning 1
+```
\ No newline at end of file
diff --git a/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py b/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
index 61e920bc8..eb070b2e1 100644
--- a/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
+++ b/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
@@ -11,7 +11,11 @@
StableDiffusionXLImg2ImgPipeline,
StableDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
+ UNet2DConditionModel,
+ EulerDiscreteScheduler
)
+from safetensors.torch import load_file
+from huggingface_hub import hf_hub_download
from onediff.quantization import QuantPipeline
@@ -62,6 +66,9 @@
)
parser.add_argument("--seed", type=int, default=111)
parser.add_argument("--cache_dir", type=str, default=None)
+parser.add_argument("--use_lightning", type=(lambda x: str(x).lower() in ["true", "1", "yes"]), default=False, help="Use the SDXL Lightning model if true")
+parser.add_argument("--lightning_ckpt", type=str, default="sdxl_lightning_4step_unet.safetensors",
+ help="Checkpoint file name for the ByteDance SDXL-Lightning model")
args = parser.parse_args()
pipeline_cls = AutoPipelineForText2Image if args.input_image is None else AutoPipelineForImage2Image
@@ -87,6 +94,19 @@
use_safetensors=True,
)
+if args.use_lightning:
+ repo = "ByteDance/SDXL-Lightning"
+ ckpt = args.lightning_ckpt
+ unet = UNet2DConditionModel.from_config(args.model, subfolder="unet").to("cuda", torch.float16)
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
+ pipe = QuantPipeline.from_pretrained(
+ pipeline_cls,
+ args.model,
+ unet=unet,
+ torch_dtype=torch.float16,
+ variant=args.variant,
+ use_safetensors=True,
+ )
else:
pipe = QuantPipeline.from_pretrained(
pipeline_cls,
From e5911e04eeac5486f77cab39836b05136a8193af Mon Sep 17 00:00:00 2001
From: Li Xiang <54010254+lixiang007666@users.noreply.github.com>
Date: Mon, 1 Jul 2024 03:00:15 +0800
Subject: [PATCH 02/11] Update quantize-sd-fast.py
---
.../tools/quantization/quantize-sd-fast.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py b/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
index eb070b2e1..8d054c550 100644
--- a/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
+++ b/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
@@ -11,8 +11,7 @@
StableDiffusionXLImg2ImgPipeline,
StableDiffusionPipeline,
StableDiffusionImg2ImgPipeline,
- UNet2DConditionModel,
- EulerDiscreteScheduler
+ UNet2DConditionModel
)
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
From c86cf6b794889884053e79051797cf6609cd5061 Mon Sep 17 00:00:00 2001
From: Li Xiang <54010254+lixiang007666@users.noreply.github.com>
Date: Mon, 29 Jul 2024 10:19:55 +0800
Subject: [PATCH 03/11] Update quantize-sd-fast.py
---
.../tools/quantization/quantize-sd-fast.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py b/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
index 15ec57d54..322fb8572 100644
--- a/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
+++ b/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
@@ -8,7 +8,7 @@
AutoPipelineForImage2Image,
AutoPipelineForText2Image,
StableDiffusionImg2ImgPipeline,
- UNet2DConditionModel
+ UNet2DConditionModel,
StableDiffusionPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
From ac9e77c5016ee35b85031ffe8dd8697297a66869 Mon Sep 17 00:00:00 2001
From: lixiang007666 <88304454@qq.com>
Date: Tue, 30 Jul 2024 09:49:14 +0800
Subject: [PATCH 04/11] Refine
---
.../examples/lightning/README.md | 93 +++++++++++++++-
.../text_to_image_sdxl_light.py | 104 ++++++++++++++----
.../tools/quantization/quantize-sd-fast.py | 23 +++-
3 files changed, 189 insertions(+), 31 deletions(-)
rename onediff_diffusers_extensions/examples/{ => lightning}/text_to_image_sdxl_light.py (50%)
diff --git a/onediff_diffusers_extensions/examples/lightning/README.md b/onediff_diffusers_extensions/examples/lightning/README.md
index 09ffdc916..70014b7b1 100644
--- a/onediff_diffusers_extensions/examples/lightning/README.md
+++ b/onediff_diffusers_extensions/examples/lightning/README.md
@@ -1,13 +1,96 @@
-Run:
-
+# Run SDXL-Lightning with OneDiff
+
+## Environment Setup
+
+### Set Up OneDiff
+Follow the instructions to set up OneDiff from the https://github.com/siliconflow/onediff?tab=readme-ov-file#installation.
+
+### Set Up Compiler Backend
+OneDiff supports two compiler backends: OneFlow and NexFort. Follow the setup instructions for these backends from the https://github.com/siliconflow/onediff?tab=readme-ov-file#install-a-compiler-backend.
+
+
+### Set Up SDXL-Lightning
+- HF model: [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)
+- HF pipeline: [diffusers usage](https://huggingface.co/ByteDance/SDXL-Lightning#2-step-4-step-8-step-unet)
+
+## Compile
+
+> [!NOTE]
+Current test is based on an 8 steps distillation model.
+
+### Run 1024x1024 Without Compile (Original PyTorch HF Diffusers Baseline)
+```bash
+python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py \
+--saved_image sdxl_light.png
+```
+
+### Run 1024x1024 With Compile [OneFlow Backend]
+```bash
+python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py \
+--compiler oneflow \
+--saved_image sdxl_light_oneflow_compile.png
+```
+
+### Run 1024x1024 With Compile [NexFort Backend]
+```bash
+python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py \
+--compiler nexfort \
+--compiler-config '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last", "options": {"triton.fuse_attention_allow_fp16_reduction": false}}' \
+--saved_image sdxl_light_nexfort_compile.png
+```
+
+
+## Quantization (Int8)
+
+> [!NOTE]
+Quantization is a feature for onediff enterprise.
+
+### Run 1024x1024 With Quantization [OneFlow Backend]
+
+Execute the following command to quantize the model, where `--quantized_model` is the path to the quantized model. For an introduction to the quantization parameters, refer to: https://github.com/siliconflow/onediff/blob/main/README_ENTERPRISE.md#diffusers-with-onediff-enterprise
```
python3 onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py \
- --quantized_model ./sdxl_lightning_quant \
+ --quantized_model ./sdxl_lightning_oneflow_quant \
--conv_ssim_threshold 0.1 \
--linear_ssim_threshold 0.1 \
- --conv_compute_density_threshold 900 \
+ --conv_compute_density_threshold 300 \
--linear_compute_density_threshold 300 \
--save_as_float true \
--use_lightning 1
-```
\ No newline at end of file
+```
+
+Test the quantized model:
+
+```
+python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py \
+--compiler oneflow \
+--use_quantization \
+--base ./sdxl_lightning_oneflow_quant \
+--saved_image sdxl_light_oneflow_quant.png
+```
+
+
+### Run 1024x1024 With Quantization [NexFort Backend]
+
+```
+python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py \
+ --compiler nexfort \
+ --compiler-config '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last", "options": {"triton.fuse_attention_allow_fp16_reduction": false}}' \
+ --use_quantization \
+ --quantize-config '{"quant_type": "int8_dynamic"}' \
+ --saved_image sdxl_light_nexfort_quant.png
+```
+
+
+## Performance Comparison
+
+**Testing on an NVIDIA RTX 4090 GPU, using a resolution of 1024x1024 and 8 steps:**
+
+| Configuration | Iteration Speed (it/s) | E2E Time (seconds) |
+|---------------------------|---------------------------------|---------------------------------|
+| PyTorch | 14.68 | 0.840 |
+| OneFlow Compile | 29.06 (+97.83%) | 0.530 (-36.90%) |
+| OneFlow Quantization | 43.45 (+195.95%) | 0.424 (-49.52%) |
+| NexFort Compile | 28.07 (+91.18%) | 0.526 (-37.38%) |
+| NexFort Quantization | 30.85 (+110.15%) | 0.476 (-43.33%) |
diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_light.py b/onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py
similarity index 50%
rename from onediff_diffusers_extensions/examples/text_to_image_sdxl_light.py
rename to onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py
index d88b1f074..c7223d759 100644
--- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_light.py
+++ b/onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py
@@ -1,11 +1,12 @@
import argparse
+import json
import os
import time
import torch
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
-from onediffx import compile_pipe, load_pipe, save_pipe
+from onediffx import compile_pipe, load_pipe, quantize_pipe, save_pipe
from safetensors.torch import load_file
try:
@@ -18,7 +19,7 @@
"--base", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
)
parser.add_argument("--repo", type=str, default="ByteDance/SDXL-Lightning")
-parser.add_argument("--cpkt", type=str, default="sdxl_lightning_4step_unet.safetensors")
+parser.add_argument("--cpkt", type=str, default="sdxl_lightning_8step_unet.safetensors")
parser.add_argument("--variant", type=str, default="fp16")
parser.add_argument(
"--prompt",
@@ -37,10 +38,19 @@
)
parser.add_argument("--seed", type=int, default=1)
parser.add_argument(
- "--compile",
- type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
- default=True,
+ "--compiler",
+ type=str,
+ default="none",
+ help="Compiler backend to use. Options: 'none', 'nexfort', 'oneflow'",
+)
+parser.add_argument(
+ "--compiler-config", type=str, help="JSON string for nexfort compiler config."
+)
+parser.add_argument(
+ "--quantize-config", type=str, help="JSON string for nexfort quantization config."
)
+parser.add_argument("--bits", type=int, default=8)
+parser.add_argument("--use_quantization", action="store_true")
args = parser.parse_args()
@@ -51,7 +61,7 @@
is_lora_cpkt = "lora" in args.cpkt
-if args.compile:
+if args.compiler == "oneflow":
from onediff.schedulers import EulerDiscreteScheduler
else:
from diffusers import EulerDiscreteScheduler
@@ -69,22 +79,58 @@
pipe.load_lora_weights(hf_hub_download(args.repo, args.cpkt))
pipe.fuse_lora()
else:
- from diffusers import UNet2DConditionModel
+ if args.use_quantization and args.compiler == "oneflow":
+ print("oneflow backend quant...")
+ pipe = StableDiffusionXLPipeline.from_pretrained(
+ args.base, torch_dtype=torch.float16, variant="fp16"
+ ).to("cuda")
+ import onediff_quant
+ from onediff_quant.utils import replace_sub_module_with_quantizable_module
+
+ quantized_layers_count = 0
+ onediff_quant.enable_load_quantized_model()
+
+ calibrate_info = {}
+ with open(os.path.join(args.base, "calibrate_info.txt"), "r") as f:
+ for line in f.readlines():
+ line = line.strip()
+ items = line.split(" ")
+ calibrate_info[items[0]] = [
+ float(items[1]),
+ int(items[2]),
+ [float(x) for x in items[3].split(",")],
+ ]
+
+ for sub_module_name, sub_calibrate_info in calibrate_info.items():
+ replace_sub_module_with_quantizable_module(
+ pipe.unet,
+ sub_module_name,
+ sub_calibrate_info,
+ False,
+ False,
+ args.bits,
+ )
+ quantized_layers_count += 1
+
+ print(f"Total quantized layers: {quantized_layers_count}")
- unet = UNet2DConditionModel.from_config(args.base, subfolder="unet").to(
- "cuda", torch.float16
- )
- if os.path.isfile(os.path.join(args.repo, args.cpkt)):
- unet.load_state_dict(
- load_file(os.path.join(args.repo, args.cpkt), device="cuda")
- )
else:
- unet.load_state_dict(
- load_file(hf_hub_download(args.repo, args.cpkt), device="cuda")
+ from diffusers import UNet2DConditionModel
+
+ unet = UNet2DConditionModel.from_config(args.base, subfolder="unet").to(
+ "cuda", torch.float16
)
- pipe = StableDiffusionXLPipeline.from_pretrained(
- args.base, unet=unet, torch_dtype=torch.float16, variant="fp16"
- ).to("cuda")
+ if os.path.isfile(os.path.join(args.repo, args.cpkt)):
+ unet.load_state_dict(
+ load_file(os.path.join(args.repo, args.cpkt), device="cuda")
+ )
+ else:
+ unet.load_state_dict(
+ load_file(hf_hub_download(args.repo, args.cpkt), device="cuda")
+ )
+ pipe = StableDiffusionXLPipeline.from_pretrained(
+ args.base, unet=unet, torch_dtype=torch.float16, variant="fp16"
+ ).to("cuda")
pipe.scheduler = EulerDiscreteScheduler.from_config(
pipe.scheduler.config, timestep_spacing="trailing"
@@ -94,13 +140,31 @@
pipe.upcast_vae()
# Compile the pipeline
-if args.compile:
+if args.compiler == "oneflow":
+ print("oneflow backend compile...")
pipe = compile_pipe(
pipe,
)
if args.load_graph:
print("Loading graphs...")
load_pipe(pipe, args.load_graph_dir)
+elif args.compiler == "nexfort":
+ print("nexfort backend compile...")
+ nexfort_compiler_config = (
+ json.loads(args.compiler_config) if args.compiler_config else None
+ )
+
+ options = nexfort_compiler_config
+ pipe = compile_pipe(
+ pipe, backend="nexfort", options=options, fuse_qkv_projections=True
+ )
+ if args.use_quantization and args.compiler == "nexfort":
+ print("nexfort backend quant...")
+ nexfort_quantize_config = (
+ json.loads(args.quantize_config) if args.quantize_config else None
+ )
+ pipe = quantize_pipe(pipe, ignores=[], **nexfort_quantize_config)
+
print("Warmup with running graphs...")
torch.manual_seed(args.seed)
diff --git a/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py b/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
index 322fb8572..9ed6ecadd 100644
--- a/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
+++ b/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
@@ -8,16 +8,16 @@
AutoPipelineForImage2Image,
AutoPipelineForText2Image,
StableDiffusionImg2ImgPipeline,
- UNet2DConditionModel,
StableDiffusionPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
+ UNet2DConditionModel,
)
-from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from onediff.quantization import QuantPipeline
from PIL import Image
+from safetensors.torch import load_file
parser = argparse.ArgumentParser()
@@ -66,9 +66,18 @@
)
parser.add_argument("--seed", type=int, default=111)
parser.add_argument("--cache_dir", type=str, default=None)
-parser.add_argument("--use_lightning", type=(lambda x: str(x).lower() in ["true", "1", "yes"]), default=False, help="Use the SDXL Lightning model if true")
-parser.add_argument("--lightning_ckpt", type=str, default="sdxl_lightning_4step_unet.safetensors",
- help="Checkpoint file name for the ByteDance SDXL-Lightning model")
+parser.add_argument(
+ "--use_lightning",
+ type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
+ default=False,
+ help="Use the SDXL Lightning model if true",
+)
+parser.add_argument(
+ "--lightning_ckpt",
+ type=str,
+ default="sdxl_lightning_8step_unet.safetensors",
+ help="Checkpoint file name for the ByteDance SDXL-Lightning model",
+)
args = parser.parse_args()
pipeline_cls = (
@@ -111,7 +120,9 @@
if args.use_lightning:
repo = "ByteDance/SDXL-Lightning"
ckpt = args.lightning_ckpt
- unet = UNet2DConditionModel.from_config(args.model, subfolder="unet").to("cuda", torch.float16)
+ unet = UNet2DConditionModel.from_config(args.model, subfolder="unet").to(
+ "cuda", torch.float16
+ )
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
pipe = QuantPipeline.from_pretrained(
pipeline_cls,
From 20fb0c333780ffe5fc21e3c787c1069439badc8d Mon Sep 17 00:00:00 2001
From: lixiang007666 <88304454@qq.com>
Date: Tue, 30 Jul 2024 10:54:58 +0800
Subject: [PATCH 05/11] Fix ci
---
.../examples/lightning/README.md | 43 +++++++++++--------
.../lightning/text_to_image_sdxl_light.py | 2 +-
.../examples/save_and_load_pipeline.sh | 8 ++--
3 files changed, 29 insertions(+), 24 deletions(-)
diff --git a/onediff_diffusers_extensions/examples/lightning/README.md b/onediff_diffusers_extensions/examples/lightning/README.md
index 70014b7b1..78be606bc 100644
--- a/onediff_diffusers_extensions/examples/lightning/README.md
+++ b/onediff_diffusers_extensions/examples/lightning/README.md
@@ -21,22 +21,25 @@ Current test is based on an 8 steps distillation model.
### Run 1024x1024 Without Compile (Original PyTorch HF Diffusers Baseline)
```bash
python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py \
---saved_image sdxl_light.png
+ --prompt "product photography, world of warcraft orc warrior, white background" \
+ --saved_image sdxl_light.png
```
### Run 1024x1024 With Compile [OneFlow Backend]
```bash
python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py \
---compiler oneflow \
---saved_image sdxl_light_oneflow_compile.png
+ --prompt "product photography, world of warcraft orc warrior, white background" \
+ --compiler oneflow \
+ --saved_image sdxl_light_oneflow_compile.png
```
### Run 1024x1024 With Compile [NexFort Backend]
```bash
python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py \
---compiler nexfort \
---compiler-config '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last", "options": {"triton.fuse_attention_allow_fp16_reduction": false}}' \
---saved_image sdxl_light_nexfort_compile.png
+ --prompt "product photography, world of warcraft orc warrior, white background" \
+ --compiler nexfort \
+ --compiler-config '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last", "options": {"triton.fuse_attention_allow_fp16_reduction": false}}' \
+ --saved_image sdxl_light_nexfort_compile.png
```
@@ -49,9 +52,9 @@ Quantization is a feature for onediff enterprise.
Execute the following command to quantize the model, where `--quantized_model` is the path to the quantized model. For an introduction to the quantization parameters, refer to: https://github.com/siliconflow/onediff/blob/main/README_ENTERPRISE.md#diffusers-with-onediff-enterprise
-```
+```bash
python3 onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py \
- --quantized_model ./sdxl_lightning_oneflow_quant \
+ --quantized_model /path/to/sdxl_lightning_oneflow_quant \
--conv_ssim_threshold 0.1 \
--linear_ssim_threshold 0.1 \
--conv_compute_density_threshold 300 \
@@ -62,24 +65,26 @@ python3 onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py \
Test the quantized model:
-```
+```bash
python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py \
---compiler oneflow \
---use_quantization \
---base ./sdxl_lightning_oneflow_quant \
---saved_image sdxl_light_oneflow_quant.png
+ --prompt "product photography, world of warcraft orc warrior, white background" \
+ --compiler oneflow \
+ --use_quantization \
+ --base /path/to/sdxl_lightning_oneflow_quant \
+ --saved_image sdxl_light_oneflow_quant.png
```
### Run 1024x1024 With Quantization [NexFort Backend]
-```
+```bash
python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py \
- --compiler nexfort \
- --compiler-config '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last", "options": {"triton.fuse_attention_allow_fp16_reduction": false}}' \
- --use_quantization \
- --quantize-config '{"quant_type": "int8_dynamic"}' \
- --saved_image sdxl_light_nexfort_quant.png
+ --prompt "product photography, world of warcraft orc warrior, white background" \
+ --compiler nexfort \
+ --compiler-config '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last", "options": {"triton.fuse_attention_allow_fp16_reduction": false}}' \
+ --use_quantization \
+ --quantize-config '{"quant_type": "int8_dynamic"}' \
+ --saved_image sdxl_light_nexfort_quant.png
```
diff --git a/onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py b/onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py
index c7223d759..4dc123b70 100644
--- a/onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py
+++ b/onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py
@@ -40,7 +40,7 @@
parser.add_argument(
"--compiler",
type=str,
- default="none",
+ default="oneflow",
help="Compiler backend to use. Options: 'none', 'nexfort', 'oneflow'",
)
parser.add_argument(
diff --git a/onediff_diffusers_extensions/examples/save_and_load_pipeline.sh b/onediff_diffusers_extensions/examples/save_and_load_pipeline.sh
index d2a45a720..ea779a753 100644
--- a/onediff_diffusers_extensions/examples/save_and_load_pipeline.sh
+++ b/onediff_diffusers_extensions/examples/save_and_load_pipeline.sh
@@ -1,10 +1,10 @@
#!/bin/bash
-python3 examples/text_to_image_sdxl_light.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --repo /share_nfs/hf_models/SDXL-Lightning --cpkt sdxl_lightning_4step_unet.safetensors --save_graph --save_graph_dir cached_unet_pipe
+python3 examples/lightning/text_to_image_sdxl_light.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --repo /share_nfs/hf_models/SDXL-Lightning --cpkt sdxl_lightning_4step_unet.safetensors --save_graph --save_graph_dir cached_unet_pipe
-python3 examples/text_to_image_sdxl_light.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --repo /share_nfs/hf_models/SDXL-Lightning --cpkt sdxl_lightning_4step_unet.safetensors --load_graph --load_graph_dir cached_unet_pipe
+python3 examples/lightning/text_to_image_sdxl_light.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --repo /share_nfs/hf_models/SDXL-Lightning --cpkt sdxl_lightning_4step_unet.safetensors --load_graph --load_graph_dir cached_unet_pipe
-HF_HUB_OFFLINE=0 python3 examples/text_to_image_sdxl_light.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --repo /share_nfs/hf_models/SDXL-Lightning --cpkt sdxl_lightning_4step_lora.safetensors --save_graph --save_graph_dir cached_lora_pipe
+HF_HUB_OFFLINE=0 python3 examples/lightning/text_to_image_sdxl_light.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --repo /share_nfs/hf_models/SDXL-Lightning --cpkt sdxl_lightning_4step_lora.safetensors --save_graph --save_graph_dir cached_lora_pipe
-HF_HUB_OFFLINE=0 python3 examples/text_to_image_sdxl_light.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --repo /share_nfs/hf_models/SDXL-Lightning --cpkt sdxl_lightning_4step_lora.safetensors --load_graph --load_graph_dir cached_lora_pipe
+HF_HUB_OFFLINE=0 python3 examples/lightning/text_to_image_sdxl_light.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --repo /share_nfs/hf_models/SDXL-Lightning --cpkt sdxl_lightning_4step_lora.safetensors --load_graph --load_graph_dir cached_lora_pipe
From 21e4822776e8c1c3c446fc81b7aaff0415306b2f Mon Sep 17 00:00:00 2001
From: lixiang007666 <88304454@qq.com>
Date: Tue, 30 Jul 2024 14:44:20 +0800
Subject: [PATCH 06/11] update readme
---
.../examples/lightning/README.md | 17 +++++++++++++++++
1 file changed, 17 insertions(+)
diff --git a/onediff_diffusers_extensions/examples/lightning/README.md b/onediff_diffusers_extensions/examples/lightning/README.md
index 78be606bc..b3cc3cf6e 100644
--- a/onediff_diffusers_extensions/examples/lightning/README.md
+++ b/onediff_diffusers_extensions/examples/lightning/README.md
@@ -1,5 +1,19 @@
# Run SDXL-Lightning with OneDiff
+1. [Environment Setup](#environment-setup)
+ - [Set Up OneDiff](#set-up-onediff)
+ - [Set Up Compiler Backend](#set-up-compiler-backend)
+ - [Set Up SDXL-Lightning](#set-up-sdxl-lightning)
+2. [Compile](#compile)
+ - [Without Compile (Original PyTorch HF Diffusers Baseline)](#without-compile)
+ - [With OneFlow Backend](#with-oneflow-backend)
+ - [With NexFort Backend](#with-nexfort-backend)
+3. [Quantization (Int8)](#quantization)
+ - [With Quantization - OneFlow Backend](#with-quantization---oneflow-backend)
+ - [With Quantization - NexFort Backend](#with-quantization---nexfort-backend)
+4. [Performance Comparison](#performance-comparison)
+5. [Quality](#quality)
+
## Environment Setup
### Set Up OneDiff
@@ -99,3 +113,6 @@ python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light
| OneFlow Quantization | 43.45 (+195.95%) | 0.424 (-49.52%) |
| NexFort Compile | 28.07 (+91.18%) | 0.526 (-37.38%) |
| NexFort Quantization | 30.85 (+110.15%) | 0.476 (-43.33%) |
+
+## Quality
+https://github.com/siliconflow/odeval/tree/main/models/lightning
From ab9064de90189a2eb14b18fb437e4f3566594c81 Mon Sep 17 00:00:00 2001
From: Li Xiang <54010254+lixiang007666@users.noreply.github.com>
Date: Wed, 31 Jul 2024 10:05:39 +0800
Subject: [PATCH 07/11] Update README.md
---
onediff_diffusers_extensions/examples/lightning/README.md | 1 +
1 file changed, 1 insertion(+)
diff --git a/onediff_diffusers_extensions/examples/lightning/README.md b/onediff_diffusers_extensions/examples/lightning/README.md
index b3cc3cf6e..a326f4411 100644
--- a/onediff_diffusers_extensions/examples/lightning/README.md
+++ b/onediff_diffusers_extensions/examples/lightning/README.md
@@ -106,6 +106,7 @@ python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light
**Testing on an NVIDIA RTX 4090 GPU, using a resolution of 1024x1024 and 8 steps:**
+Data update date: 2024-07-29
| Configuration | Iteration Speed (it/s) | E2E Time (seconds) |
|---------------------------|---------------------------------|---------------------------------|
| PyTorch | 14.68 | 0.840 |
From ba63eb7bee953ad4d7173d76f7266cf9b672d771 Mon Sep 17 00:00:00 2001
From: lixiang007666 <88304454@qq.com>
Date: Wed, 31 Jul 2024 10:59:59 +0800
Subject: [PATCH 08/11] Add inference time track context
---
.../lightning/text_to_image_sdxl_light.py | 43 +++++++++----------
.../onediffx/utils/performance_monitor.py | 20 +++++++++
2 files changed, 40 insertions(+), 23 deletions(-)
create mode 100644 onediff_diffusers_extensions/onediffx/utils/performance_monitor.py
diff --git a/onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py b/onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py
index 4dc123b70..3f1f7813b 100644
--- a/onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py
+++ b/onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py
@@ -7,6 +7,7 @@
from diffusers import StableDiffusionXLPipeline
from huggingface_hub import hf_hub_download
from onediffx import compile_pipe, load_pipe, quantize_pipe, save_pipe
+from onediffx.utils.performance_monitor import track_inference_time
from safetensors.torch import load_file
try:
@@ -166,33 +167,29 @@
pipe = quantize_pipe(pipe, ignores=[], **nexfort_quantize_config)
-print("Warmup with running graphs...")
-torch.manual_seed(args.seed)
-image = pipe(
- prompt=args.prompt,
- height=args.height,
- width=args.width,
- num_inference_steps=n_steps,
- guidance_scale=0,
- output_type=OUTPUT_TYPE,
-).images
+with track_inference_time(warmup=True):
+ image = pipe(
+ prompt=args.prompt,
+ height=args.height,
+ width=args.width,
+ num_inference_steps=n_steps,
+ guidance_scale=0,
+ output_type=OUTPUT_TYPE,
+ ).images
# Normal run
-print("Normal run...")
torch.manual_seed(args.seed)
-start_t = time.time()
-image = pipe(
- prompt=args.prompt,
- height=args.height,
- width=args.width,
- num_inference_steps=n_steps,
- guidance_scale=0,
- output_type=OUTPUT_TYPE,
-).images
-
-end_t = time.time()
-print(f"e2e ({n_steps} steps) elapsed: {end_t - start_t} s")
+with track_inference_time(warmup=False):
+ image = pipe(
+ prompt=args.prompt,
+ height=args.height,
+ width=args.width,
+ num_inference_steps=n_steps,
+ guidance_scale=0,
+ output_type=OUTPUT_TYPE,
+ ).images
+
image[0].save(args.saved_image)
diff --git a/onediff_diffusers_extensions/onediffx/utils/performance_monitor.py b/onediff_diffusers_extensions/onediffx/utils/performance_monitor.py
new file mode 100644
index 000000000..56c398107
--- /dev/null
+++ b/onediff_diffusers_extensions/onediffx/utils/performance_monitor.py
@@ -0,0 +1,20 @@
+import time
+from contextlib import contextmanager
+
+
+@contextmanager
+def track_inference_time(warmup=False):
+ """
+ A context manager to measure the execution time of models.
+ Parameters:
+ warmup (bool): If True, prints the time for warmup runs; otherwise, prints the time for normal runs.
+ """
+ try:
+ start_time = time.time()
+ yield
+ finally:
+ end_time = time.time()
+ if warmup:
+ print(f"Warmup run - Execution time: {end_time - start_time:.2f} seconds")
+ else:
+ print(f"Normal run - Execution time: {end_time - start_time:.2f} seconds")
From 8349c4c800736e3ee3f98350ebb13fccd433bdbd Mon Sep 17 00:00:00 2001
From: lixiang007666 <88304454@qq.com>
Date: Wed, 31 Jul 2024 13:49:29 +0800
Subject: [PATCH 09/11] Add warmup time
---
.../examples/lightning/README.md | 17 ++++++++++-------
1 file changed, 10 insertions(+), 7 deletions(-)
diff --git a/onediff_diffusers_extensions/examples/lightning/README.md b/onediff_diffusers_extensions/examples/lightning/README.md
index a326f4411..f89420613 100644
--- a/onediff_diffusers_extensions/examples/lightning/README.md
+++ b/onediff_diffusers_extensions/examples/lightning/README.md
@@ -36,6 +36,7 @@ Current test is based on an 8 steps distillation model.
```bash
python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py \
--prompt "product photography, world of warcraft orc warrior, white background" \
+ --compiler none \
--saved_image sdxl_light.png
```
@@ -107,13 +108,15 @@ python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light
**Testing on an NVIDIA RTX 4090 GPU, using a resolution of 1024x1024 and 8 steps:**
Data update date: 2024-07-29
-| Configuration | Iteration Speed (it/s) | E2E Time (seconds) |
-|---------------------------|---------------------------------|---------------------------------|
-| PyTorch | 14.68 | 0.840 |
-| OneFlow Compile | 29.06 (+97.83%) | 0.530 (-36.90%) |
-| OneFlow Quantization | 43.45 (+195.95%) | 0.424 (-49.52%) |
-| NexFort Compile | 28.07 (+91.18%) | 0.526 (-37.38%) |
-| NexFort Quantization | 30.85 (+110.15%) | 0.476 (-43.33%) |
+| Configuration | Iteration Speed (it/s) | E2E Time (seconds) | Warmup time (seconds) 1 | Warmup with Cache time (seconds) |
+|---------------------------|------------------------|--------------------|-----------------------|----------------------------------|
+| PyTorch | 14.68 | 0.840 | 1.31 | - |
+| OneFlow Compile | 29.06 (+97.83%) | 0.530 (-36.90%) | 52.26 | 0.64 |
+| OneFlow Quantization | 43.45 (+195.95%) | 0.424 (-49.52%) | 59.87 | 0.51 |
+| NexFort Compile | 28.07 (+91.18%) | 0.526 (-37.38%) | 539.67 | 68.79 |
+| NexFort Quantization | 30.85 (+110.15%) | 0.476 (-43.33%) | 610.25 | 93.28 |
+
+ 1 OneDiff Warmup with Compilation time is tested on AMD EPYC 7543 32-Core Processor.
## Quality
https://github.com/siliconflow/odeval/tree/main/models/lightning
From a5b6bf2aec0ce22a4fcb4fca759a609625e4c9a3 Mon Sep 17 00:00:00 2001
From: Li Xiang <54010254+lixiang007666@users.noreply.github.com>
Date: Fri, 9 Aug 2024 14:28:27 +0800
Subject: [PATCH 10/11] Update performance_monitor.py
---
.../onediffx/utils/performance_monitor.py | 26 ++++++++++++++-----
1 file changed, 20 insertions(+), 6 deletions(-)
diff --git a/onediff_diffusers_extensions/onediffx/utils/performance_monitor.py b/onediff_diffusers_extensions/onediffx/utils/performance_monitor.py
index 56c398107..7ea5643ab 100644
--- a/onediff_diffusers_extensions/onediffx/utils/performance_monitor.py
+++ b/onediff_diffusers_extensions/onediffx/utils/performance_monitor.py
@@ -1,20 +1,34 @@
+import torch
import time
from contextlib import contextmanager
-
@contextmanager
-def track_inference_time(warmup=False):
+def track_inference_time(warmup=False, use_cuda=True):
"""
A context manager to measure the execution time of models.
Parameters:
warmup (bool): If True, prints the time for warmup runs; otherwise, prints the time for normal runs.
+ use_cuda (bool): If CUDA is available, uses torch.cuda.Event for timing; otherwise, uses time.time().
"""
- try:
+ if use_cuda and torch.cuda.is_available():
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ else:
start_time = time.time()
+
+ try:
yield
finally:
- end_time = time.time()
+ if use_cuda and torch.cuda.is_available():
+ end.record()
+ torch.cuda.synchronize()
+ elapsed_time = start.elapsed_time(end) / 1000.0
+ else:
+ elapsed_time = time.time() - start_time
+
if warmup:
- print(f"Warmup run - Execution time: {end_time - start_time:.2f} seconds")
+ print(f"Warmup run - Execution time: {elapsed_time:.2f} seconds")
else:
- print(f"Normal run - Execution time: {end_time - start_time:.2f} seconds")
+ print(f"Normal run - Execution time: {elapsed_time:.2f} seconds")
+
From 42c43007525f30133b72a81198fa86423c02fd12 Mon Sep 17 00:00:00 2001
From: lixiang007666 <88304454@qq.com>
Date: Fri, 9 Aug 2024 14:48:50 +0800
Subject: [PATCH 11/11] Format
---
.../onediffx/utils/performance_monitor.py | 9 +++++----
src/onediff/infer_compiler/README.md | 1 -
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/onediff_diffusers_extensions/onediffx/utils/performance_monitor.py b/onediff_diffusers_extensions/onediffx/utils/performance_monitor.py
index 7ea5643ab..99e46a2c3 100644
--- a/onediff_diffusers_extensions/onediffx/utils/performance_monitor.py
+++ b/onediff_diffusers_extensions/onediffx/utils/performance_monitor.py
@@ -1,7 +1,9 @@
-import torch
import time
from contextlib import contextmanager
+import torch
+
+
@contextmanager
def track_inference_time(warmup=False, use_cuda=True):
"""
@@ -16,7 +18,7 @@ def track_inference_time(warmup=False, use_cuda=True):
start.record()
else:
start_time = time.time()
-
+
try:
yield
finally:
@@ -26,9 +28,8 @@ def track_inference_time(warmup=False, use_cuda=True):
elapsed_time = start.elapsed_time(end) / 1000.0
else:
elapsed_time = time.time() - start_time
-
+
if warmup:
print(f"Warmup run - Execution time: {elapsed_time:.2f} seconds")
else:
print(f"Normal run - Execution time: {elapsed_time:.2f} seconds")
-
diff --git a/src/onediff/infer_compiler/README.md b/src/onediff/infer_compiler/README.md
index 1fa181a79..32f4a4def 100644
--- a/src/onediff/infer_compiler/README.md
+++ b/src/onediff/infer_compiler/README.md
@@ -112,4 +112,3 @@ python3 ./benchmarks/text_to_image.py \
--compiler-config '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last", "dynamic": true}' \
--run_multiple_resolutions 1
```
-