diff --git a/onediff_diffusers_extensions/examples/lightning/README.md b/onediff_diffusers_extensions/examples/lightning/README.md
new file mode 100644
index 000000000..f89420613
--- /dev/null
+++ b/onediff_diffusers_extensions/examples/lightning/README.md
@@ -0,0 +1,122 @@
+# Run SDXL-Lightning with OneDiff
+
+1. [Environment Setup](#environment-setup)
+ - [Set Up OneDiff](#set-up-onediff)
+ - [Set Up Compiler Backend](#set-up-compiler-backend)
+ - [Set Up SDXL-Lightning](#set-up-sdxl-lightning)
+2. [Compile](#compile)
+   - [Without Compile (Original PyTorch HF Diffusers Baseline)](#run-1024x1024-without-compile-original-pytorch-hf-diffusers-baseline)
+   - [With OneFlow Backend](#run-1024x1024-with-compile-oneflow-backend)
+   - [With NexFort Backend](#run-1024x1024-with-compile-nexfort-backend)
+3. [Quantization (Int8)](#quantization-int8)
+   - [With Quantization - OneFlow Backend](#run-1024x1024-with-quantization-oneflow-backend)
+   - [With Quantization - NexFort Backend](#run-1024x1024-with-quantization-nexfort-backend)
+4. [Performance Comparison](#performance-comparison)
+5. [Quality](#quality)
+
+## Environment Setup
+
+### Set Up OneDiff
+Follow the instructions at https://github.com/siliconflow/onediff?tab=readme-ov-file#installation to set up OneDiff.
+
+### Set Up Compiler Backend
+OneDiff supports two compiler backends: OneFlow and NexFort. Follow the setup instructions at https://github.com/siliconflow/onediff?tab=readme-ov-file#install-a-compiler-backend.
+
+
+### Set Up SDXL-Lightning
+- HF model: [SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)
+- HF pipeline: [diffusers usage](https://huggingface.co/ByteDance/SDXL-Lightning#2-step-4-step-8-step-unet)
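+
+For reference, here is a minimal plain-diffusers sketch of the Lightning loading flow (adapted from the model card's diffusers usage; the example script in this directory follows the same steps):
+
+```python
+import torch
+from diffusers import (
+    EulerDiscreteScheduler,
+    StableDiffusionXLPipeline,
+    UNet2DConditionModel,
+)
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+
+base = "stabilityai/stable-diffusion-xl-base-1.0"
+repo = "ByteDance/SDXL-Lightning"
+ckpt = "sdxl_lightning_8step_unet.safetensors"  # 8-step distilled UNet
+
+# Swap the distilled Lightning UNet into the SDXL base pipeline.
+unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
+unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    base, unet=unet, torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+
+# Lightning checkpoints expect "trailing" timestep spacing and no CFG (guidance_scale=0).
+pipe.scheduler = EulerDiscreteScheduler.from_config(
+    pipe.scheduler.config, timestep_spacing="trailing"
+)
+pipe("A girl smiling", num_inference_steps=8, guidance_scale=0).images[0].save("out.png")
+```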
+
+## Compile
+
+> [!NOTE]
+> The current tests are based on the 8-step distillation model.
+
+### Run 1024x1024 Without Compile (Original PyTorch HF Diffusers Baseline)
+```bash
+python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py \
+ --prompt "product photography, world of warcraft orc warrior, white background" \
+ --compiler none \
+ --saved_image sdxl_light.png
+```
+
+### Run 1024x1024 With Compile [OneFlow Backend]
+```bash
+python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py \
+ --prompt "product photography, world of warcraft orc warrior, white background" \
+ --compiler oneflow \
+ --saved_image sdxl_light_oneflow_compile.png
+```
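+
+Under the hood, `--compiler oneflow` wraps the pipeline with `compile_pipe` from onediffx. A minimal sketch, assuming `pipe` is the SDXL-Lightning pipeline built as above (`save_pipe`/`load_pipe` are what back the script's `--save_graph`/`--load_graph` flags):
+
+```python
+from onediffx import compile_pipe, load_pipe, save_pipe
+
+prompt = "product photography, world of warcraft orc warrior, white background"
+pipe = compile_pipe(pipe)  # OneFlow is the default backend
+# load_pipe(pipe, "cached_pipe")  # optionally reuse a saved graph to shorten warmup
+images = pipe(prompt, num_inference_steps=8, guidance_scale=0).images  # first call compiles
+# save_pipe(pipe, "cached_pipe")  # optionally persist the compiled graph
+```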
+
+### Run 1024x1024 With Compile [NexFort Backend]
+```bash
+python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py \
+ --prompt "product photography, world of warcraft orc warrior, white background" \
+ --compiler nexfort \
+ --compiler-config '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last", "options": {"triton.fuse_attention_allow_fp16_reduction": false}}' \
+ --saved_image sdxl_light_nexfort_compile.png
+```
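+
+The `--compiler-config` JSON is parsed by the script and passed to `compile_pipe` as the NexFort `options`; roughly (again assuming `pipe` is built as above):
+
+```python
+import json
+
+from onediffx import compile_pipe
+
+compiler_config = '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last"}'
+pipe = compile_pipe(
+    pipe, backend="nexfort", options=json.loads(compiler_config), fuse_qkv_projections=True
+)
+```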
+
+
+## Quantization (Int8)
+
+> [!NOTE]
+> Quantization is a OneDiff Enterprise feature.
+
+### Run 1024x1024 With Quantization [OneFlow Backend]
+
+Execute the following command to quantize the model; `--quantized_model` specifies the output path for the quantized model. For an introduction to the quantization parameters, see https://github.com/siliconflow/onediff/blob/main/README_ENTERPRISE.md#diffusers-with-onediff-enterprise
+
+```bash
+python3 onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py \
+ --quantized_model /path/to/sdxl_lightning_oneflow_quant \
+ --conv_ssim_threshold 0.1 \
+ --linear_ssim_threshold 0.1 \
+ --conv_compute_density_threshold 300 \
+ --linear_compute_density_threshold 300 \
+ --save_as_float true \
+ --use_lightning 1
+```
+
+Test the quantized model:
+
+```bash
+python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py \
+ --prompt "product photography, world of warcraft orc warrior, white background" \
+ --compiler oneflow \
+ --use_quantization \
+ --base /path/to/sdxl_lightning_oneflow_quant \
+ --saved_image sdxl_light_oneflow_quant.png
+```
+
+
+### Run 1024x1024 With Quantization [NexFort Backend]
+
+```bash
+python3 onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py \
+ --prompt "product photography, world of warcraft orc warrior, white background" \
+ --compiler nexfort \
+ --compiler-config '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last", "options": {"triton.fuse_attention_allow_fp16_reduction": false}}' \
+ --use_quantization \
+ --quantize-config '{"quant_type": "int8_dynamic"}' \
+ --saved_image sdxl_light_nexfort_quant.png
+```
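+
+With NexFort, quantization is applied after compilation: the script parses `--quantize-config` and forwards it to `quantize_pipe`; roughly:
+
+```python
+from onediffx import quantize_pipe
+
+# "int8_dynamic" selects int8 weights with dynamically quantized activations.
+pipe = quantize_pipe(pipe, ignores=[], quant_type="int8_dynamic")
+```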
+
+
+## Performance Comparison
+
+**Testing on an NVIDIA RTX 4090 GPU, using a resolution of 1024x1024 and 8 steps:**
+
+Data updated: 2024-07-29
+| Configuration | Iteration Speed (it/s) | E2E Time (seconds) | Warmup time (seconds)<sup>1</sup> | Warmup with Cache time (seconds) |
+|---------------------------|------------------------|--------------------|-----------------------------------|----------------------------------|
+| PyTorch | 14.68 | 0.840 | 1.31 | - |
+| OneFlow Compile | 29.06 (+97.83%) | 0.530 (-36.90%) | 52.26 | 0.64 |
+| OneFlow Quantization | 43.45 (+195.95%) | 0.424 (-49.52%) | 59.87 | 0.51 |
+| NexFort Compile | 28.07 (+91.18%) | 0.526 (-37.38%) | 539.67 | 68.79 |
+| NexFort Quantization | 30.85 (+110.15%) | 0.476 (-43.33%) | 610.25 | 93.28 |
+
+<sup>1</sup> The OneDiff warmup-with-compilation time was measured on an AMD EPYC 7543 32-Core Processor.
+
+## Quality
+https://github.com/siliconflow/odeval/tree/main/models/lightning
diff --git a/onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py b/onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py
new file mode 100644
index 000000000..3f1f7813b
--- /dev/null
+++ b/onediff_diffusers_extensions/examples/lightning/text_to_image_sdxl_light.py
@@ -0,0 +1,198 @@
+import argparse
+import json
+import os
+import time
+
+import torch
+from diffusers import StableDiffusionXLPipeline
+from huggingface_hub import hf_hub_download
+from onediffx import compile_pipe, load_pipe, quantize_pipe, save_pipe
+from onediffx.utils.performance_monitor import track_inference_time
+from safetensors.torch import load_file
+
+try:
+    from diffusers.utils import USE_PEFT_BACKEND
+except Exception:
+    USE_PEFT_BACKEND = False
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--base", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
+)
+parser.add_argument("--repo", type=str, default="ByteDance/SDXL-Lightning")
+parser.add_argument("--cpkt", type=str, default="sdxl_lightning_8step_unet.safetensors")
+parser.add_argument("--variant", type=str, default="fp16")
+parser.add_argument(
+ "--prompt",
+ type=str,
+ # default="street style, detailed, raw photo, woman, face, shot on CineStill 800T",
+ default="A girl smiling",
+)
+parser.add_argument("--save_graph", action="store_true")
+parser.add_argument("--load_graph", action="store_true")
+parser.add_argument("--save_graph_dir", type=str, default="cached_pipe")
+parser.add_argument("--load_graph_dir", type=str, default="cached_pipe")
+parser.add_argument("--height", type=int, default=1024)
+parser.add_argument("--width", type=int, default=1024)
+parser.add_argument(
+ "--saved_image", type=str, required=False, default="sdxl-light-out.png"
+)
+parser.add_argument("--seed", type=int, default=1)
+parser.add_argument(
+ "--compiler",
+ type=str,
+ default="oneflow",
+ help="Compiler backend to use. Options: 'none', 'nexfort', 'oneflow'",
+)
+parser.add_argument(
+ "--compiler-config", type=str, help="JSON string for nexfort compiler config."
+)
+parser.add_argument(
+ "--quantize-config", type=str, help="JSON string for nexfort quantization config."
+)
+parser.add_argument("--bits", type=int, default=8)
+parser.add_argument("--use_quantization", action="store_true")
+
+
+args = parser.parse_args()
+
+OUTPUT_TYPE = "pil"
+
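+# Infer the inference step count from checkpoint names like
+# "sdxl_lightning_8step_unet.safetensors" (single-digit step counts only).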
+n_steps = int(args.cpkt[len("sdxl_lightning_") : len("sdxl_lightning_") + 1])
+
+is_lora_cpkt = "lora" in args.cpkt
+
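+# The OneFlow backend uses onediff's own EulerDiscreteScheduler; other paths use the stock diffusers scheduler.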
+if args.compiler == "oneflow":
+ from onediff.schedulers import EulerDiscreteScheduler
+else:
+ from diffusers import EulerDiscreteScheduler
+
+if is_lora_cpkt:
+ if not USE_PEFT_BACKEND:
+ print("PEFT backend is required for load_lora_weights")
+        exit(1)
+ pipe = StableDiffusionXLPipeline.from_pretrained(
+ args.base, torch_dtype=torch.float16, variant="fp16"
+ ).to("cuda")
+ if os.path.isfile(os.path.join(args.repo, args.cpkt)):
+ pipe.load_lora_weights(os.path.join(args.repo, args.cpkt))
+ else:
+ pipe.load_lora_weights(hf_hub_download(args.repo, args.cpkt))
+ pipe.fuse_lora()
+else:
+ if args.use_quantization and args.compiler == "oneflow":
+ print("oneflow backend quant...")
+ pipe = StableDiffusionXLPipeline.from_pretrained(
+ args.base, torch_dtype=torch.float16, variant="fp16"
+ ).to("cuda")
+ import onediff_quant
+ from onediff_quant.utils import replace_sub_module_with_quantizable_module
+
+ quantized_layers_count = 0
+ onediff_quant.enable_load_quantized_model()
+
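+        # calibrate_info.txt (written by the quantization tool alongside the quantized model)
+        # has one line per quantizable sub-module: "<name> <float> <int> <comma-separated floats>".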
+ calibrate_info = {}
+ with open(os.path.join(args.base, "calibrate_info.txt"), "r") as f:
+ for line in f.readlines():
+ line = line.strip()
+ items = line.split(" ")
+ calibrate_info[items[0]] = [
+ float(items[1]),
+ int(items[2]),
+ [float(x) for x in items[3].split(",")],
+ ]
+
+ for sub_module_name, sub_calibrate_info in calibrate_info.items():
+ replace_sub_module_with_quantizable_module(
+ pipe.unet,
+ sub_module_name,
+ sub_calibrate_info,
+ False,
+ False,
+ args.bits,
+ )
+ quantized_layers_count += 1
+
+ print(f"Total quantized layers: {quantized_layers_count}")
+
+ else:
+ from diffusers import UNet2DConditionModel
+
+ unet = UNet2DConditionModel.from_config(args.base, subfolder="unet").to(
+ "cuda", torch.float16
+ )
+ if os.path.isfile(os.path.join(args.repo, args.cpkt)):
+ unet.load_state_dict(
+ load_file(os.path.join(args.repo, args.cpkt), device="cuda")
+ )
+ else:
+ unet.load_state_dict(
+ load_file(hf_hub_download(args.repo, args.cpkt), device="cuda")
+ )
+ pipe = StableDiffusionXLPipeline.from_pretrained(
+ args.base, unet=unet, torch_dtype=torch.float16, variant="fp16"
+ ).to("cuda")
+
+pipe.scheduler = EulerDiscreteScheduler.from_config(
+ pipe.scheduler.config, timestep_spacing="trailing"
+)
+
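+# The fp16 SDXL VAE is prone to overflow; upcast it when the model config requests it.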
+if pipe.vae.dtype == torch.float16 and pipe.vae.config.force_upcast:
+ pipe.upcast_vae()
+
+# Compile the pipeline
+if args.compiler == "oneflow":
+ print("oneflow backend compile...")
+ pipe = compile_pipe(
+ pipe,
+ )
+ if args.load_graph:
+ print("Loading graphs...")
+ load_pipe(pipe, args.load_graph_dir)
+elif args.compiler == "nexfort":
+ print("nexfort backend compile...")
+    options = (
+        json.loads(args.compiler_config) if args.compiler_config else None
+    )
+ pipe = compile_pipe(
+ pipe, backend="nexfort", options=options, fuse_qkv_projections=True
+ )
+    if args.use_quantization:
+ print("nexfort backend quant...")
+ nexfort_quantize_config = (
+            json.loads(args.quantize_config) if args.quantize_config else {}
+ )
+ pipe = quantize_pipe(pipe, ignores=[], **nexfort_quantize_config)
+
+
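+# Warmup run: the first call triggers compilation/autotuning, so it is timed separately.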
+with track_inference_time(warmup=True):
+ image = pipe(
+ prompt=args.prompt,
+ height=args.height,
+ width=args.width,
+ num_inference_steps=n_steps,
+ guidance_scale=0,
+ output_type=OUTPUT_TYPE,
+ ).images
+
+
+# Normal run
+torch.manual_seed(args.seed)
+with track_inference_time(warmup=False):
+ image = pipe(
+ prompt=args.prompt,
+ height=args.height,
+ width=args.width,
+ num_inference_steps=n_steps,
+ guidance_scale=0,
+ output_type=OUTPUT_TYPE,
+ ).images
+
+
+image[0].save(args.saved_image)
+
+if args.save_graph:
+ print("Saving graphs...")
+ save_pipe(pipe, args.save_graph_dir)
diff --git a/onediff_diffusers_extensions/examples/save_and_load_pipeline.sh b/onediff_diffusers_extensions/examples/save_and_load_pipeline.sh
index d2a45a720..ea779a753 100644
--- a/onediff_diffusers_extensions/examples/save_and_load_pipeline.sh
+++ b/onediff_diffusers_extensions/examples/save_and_load_pipeline.sh
@@ -1,10 +1,10 @@
#!/bin/bash
-python3 examples/text_to_image_sdxl_light.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --repo /share_nfs/hf_models/SDXL-Lightning --cpkt sdxl_lightning_4step_unet.safetensors --save_graph --save_graph_dir cached_unet_pipe
+python3 examples/lightning/text_to_image_sdxl_light.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --repo /share_nfs/hf_models/SDXL-Lightning --cpkt sdxl_lightning_4step_unet.safetensors --save_graph --save_graph_dir cached_unet_pipe
-python3 examples/text_to_image_sdxl_light.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --repo /share_nfs/hf_models/SDXL-Lightning --cpkt sdxl_lightning_4step_unet.safetensors --load_graph --load_graph_dir cached_unet_pipe
+python3 examples/lightning/text_to_image_sdxl_light.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --repo /share_nfs/hf_models/SDXL-Lightning --cpkt sdxl_lightning_4step_unet.safetensors --load_graph --load_graph_dir cached_unet_pipe
-HF_HUB_OFFLINE=0 python3 examples/text_to_image_sdxl_light.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --repo /share_nfs/hf_models/SDXL-Lightning --cpkt sdxl_lightning_4step_lora.safetensors --save_graph --save_graph_dir cached_lora_pipe
+HF_HUB_OFFLINE=0 python3 examples/lightning/text_to_image_sdxl_light.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --repo /share_nfs/hf_models/SDXL-Lightning --cpkt sdxl_lightning_4step_lora.safetensors --save_graph --save_graph_dir cached_lora_pipe
-HF_HUB_OFFLINE=0 python3 examples/text_to_image_sdxl_light.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --repo /share_nfs/hf_models/SDXL-Lightning --cpkt sdxl_lightning_4step_lora.safetensors --load_graph --load_graph_dir cached_lora_pipe
+HF_HUB_OFFLINE=0 python3 examples/lightning/text_to_image_sdxl_light.py --base /share_nfs/hf_models/stable-diffusion-xl-base-1.0 --repo /share_nfs/hf_models/SDXL-Lightning --cpkt sdxl_lightning_4step_lora.safetensors --load_graph --load_graph_dir cached_lora_pipe
diff --git a/onediff_diffusers_extensions/examples/text_to_image_sdxl_light.py b/onediff_diffusers_extensions/examples/text_to_image_sdxl_light.py
deleted file mode 100644
index d88b1f074..000000000
--- a/onediff_diffusers_extensions/examples/text_to_image_sdxl_light.py
+++ /dev/null
@@ -1,137 +0,0 @@
-import argparse
-import os
-import time
-
-import torch
-from diffusers import StableDiffusionXLPipeline
-from huggingface_hub import hf_hub_download
-from onediffx import compile_pipe, load_pipe, save_pipe
-from safetensors.torch import load_file
-
-try:
- USE_PEFT_BACKEND = diffusers.utils.USE_PEFT_BACKEND
-except Exception as e:
- USE_PEFT_BACKEND = False
-
-parser = argparse.ArgumentParser()
-parser.add_argument(
- "--base", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
-)
-parser.add_argument("--repo", type=str, default="ByteDance/SDXL-Lightning")
-parser.add_argument("--cpkt", type=str, default="sdxl_lightning_4step_unet.safetensors")
-parser.add_argument("--variant", type=str, default="fp16")
-parser.add_argument(
- "--prompt",
- type=str,
- # default="street style, detailed, raw photo, woman, face, shot on CineStill 800T",
- default="A girl smiling",
-)
-parser.add_argument("--save_graph", action="store_true")
-parser.add_argument("--load_graph", action="store_true")
-parser.add_argument("--save_graph_dir", type=str, default="cached_pipe")
-parser.add_argument("--load_graph_dir", type=str, default="cached_pipe")
-parser.add_argument("--height", type=int, default=1024)
-parser.add_argument("--width", type=int, default=1024)
-parser.add_argument(
- "--saved_image", type=str, required=False, default="sdxl-light-out.png"
-)
-parser.add_argument("--seed", type=int, default=1)
-parser.add_argument(
- "--compile",
- type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
- default=True,
-)
-
-
-args = parser.parse_args()
-
-OUTPUT_TYPE = "pil"
-
-n_steps = int(args.cpkt[len("sdxl_lightning_") : len("sdxl_lightning_") + 1])
-
-is_lora_cpkt = "lora" in args.cpkt
-
-if args.compile:
- from onediff.schedulers import EulerDiscreteScheduler
-else:
- from diffusers import EulerDiscreteScheduler
-
-if is_lora_cpkt:
- if not USE_PEFT_BACKEND:
- print("PEFT backend is required for load_lora_weights")
- exit(0)
- pipe = StableDiffusionXLPipeline.from_pretrained(
- args.base, torch_dtype=torch.float16, variant="fp16"
- ).to("cuda")
- if os.path.isfile(os.path.join(args.repo, args.cpkt)):
- pipe.load_lora_weights(os.path.join(args.repo, args.cpkt))
- else:
- pipe.load_lora_weights(hf_hub_download(args.repo, args.cpkt))
- pipe.fuse_lora()
-else:
- from diffusers import UNet2DConditionModel
-
- unet = UNet2DConditionModel.from_config(args.base, subfolder="unet").to(
- "cuda", torch.float16
- )
- if os.path.isfile(os.path.join(args.repo, args.cpkt)):
- unet.load_state_dict(
- load_file(os.path.join(args.repo, args.cpkt), device="cuda")
- )
- else:
- unet.load_state_dict(
- load_file(hf_hub_download(args.repo, args.cpkt), device="cuda")
- )
- pipe = StableDiffusionXLPipeline.from_pretrained(
- args.base, unet=unet, torch_dtype=torch.float16, variant="fp16"
- ).to("cuda")
-
-pipe.scheduler = EulerDiscreteScheduler.from_config(
- pipe.scheduler.config, timestep_spacing="trailing"
-)
-
-if pipe.vae.dtype == torch.float16 and pipe.vae.config.force_upcast:
- pipe.upcast_vae()
-
-# Compile the pipeline
-if args.compile:
- pipe = compile_pipe(
- pipe,
- )
- if args.load_graph:
- print("Loading graphs...")
- load_pipe(pipe, args.load_graph_dir)
-
-print("Warmup with running graphs...")
-torch.manual_seed(args.seed)
-image = pipe(
- prompt=args.prompt,
- height=args.height,
- width=args.width,
- num_inference_steps=n_steps,
- guidance_scale=0,
- output_type=OUTPUT_TYPE,
-).images
-
-
-# Normal run
-print("Normal run...")
-torch.manual_seed(args.seed)
-start_t = time.time()
-image = pipe(
- prompt=args.prompt,
- height=args.height,
- width=args.width,
- num_inference_steps=n_steps,
- guidance_scale=0,
- output_type=OUTPUT_TYPE,
-).images
-
-end_t = time.time()
-print(f"e2e ({n_steps} steps) elapsed: {end_t - start_t} s")
-
-image[0].save(args.saved_image)
-
-if args.save_graph:
- print("Saving graphs...")
- save_pipe(pipe, args.save_graph_dir)
diff --git a/onediff_diffusers_extensions/onediffx/utils/performance_monitor.py b/onediff_diffusers_extensions/onediffx/utils/performance_monitor.py
new file mode 100644
index 000000000..99e46a2c3
--- /dev/null
+++ b/onediff_diffusers_extensions/onediffx/utils/performance_monitor.py
@@ -0,0 +1,35 @@
+import time
+from contextlib import contextmanager
+
+import torch
+
+
+@contextmanager
+def track_inference_time(warmup=False, use_cuda=True):
+    """
+    A context manager to measure the execution time of model inference.
+
+    Parameters:
+        warmup (bool): If True, labels the printed time as a warmup run; otherwise as a normal run.
+        use_cuda (bool): If True and CUDA is available, uses torch.cuda.Event for timing; otherwise falls back to time.time().
+    """
+ if use_cuda and torch.cuda.is_available():
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ start.record()
+ else:
+ start_time = time.time()
+
+ try:
+ yield
+ finally:
+ if use_cuda and torch.cuda.is_available():
+ end.record()
+ torch.cuda.synchronize()
+ elapsed_time = start.elapsed_time(end) / 1000.0
+ else:
+ elapsed_time = time.time() - start_time
+
+ if warmup:
+ print(f"Warmup run - Execution time: {elapsed_time:.2f} seconds")
+ else:
+ print(f"Normal run - Execution time: {elapsed_time:.2f} seconds")
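+
+
+# Example usage (illustrative):
+#
+#     with track_inference_time(warmup=False):
+#         images = pipe(prompt).images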
diff --git a/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py b/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
index f4d332a55..9ed6ecadd 100644
--- a/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
+++ b/onediff_diffusers_extensions/tools/quantization/quantize-sd-fast.py
@@ -11,10 +11,13 @@
StableDiffusionPipeline,
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLPipeline,
+ UNet2DConditionModel,
)
+from huggingface_hub import hf_hub_download
from onediff.quantization import QuantPipeline
from PIL import Image
+from safetensors.torch import load_file
parser = argparse.ArgumentParser()
@@ -63,6 +66,18 @@
)
parser.add_argument("--seed", type=int, default=111)
parser.add_argument("--cache_dir", type=str, default=None)
+parser.add_argument(
+ "--use_lightning",
+ type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
+ default=False,
+ help="Use the SDXL Lightning model if true",
+)
+parser.add_argument(
+ "--lightning_ckpt",
+ type=str,
+ default="sdxl_lightning_8step_unet.safetensors",
+ help="Checkpoint file name for the ByteDance SDXL-Lightning model",
+)
args = parser.parse_args()
pipeline_cls = (
@@ -102,6 +117,21 @@
use_safetensors=True,
)
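+# When --use_lightning is set, swap in the ByteDance SDXL-Lightning UNet before quantizing.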
+if args.use_lightning:
+ repo = "ByteDance/SDXL-Lightning"
+ ckpt = args.lightning_ckpt
+ unet = UNet2DConditionModel.from_config(args.model, subfolder="unet").to(
+ "cuda", torch.float16
+ )
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
+ pipe = QuantPipeline.from_pretrained(
+ pipeline_cls,
+ args.model,
+ unet=unet,
+ torch_dtype=torch.float16,
+ variant=args.variant,
+ use_safetensors=True,
+ )
else:
pipe = QuantPipeline.from_pretrained(
pipeline_cls,
diff --git a/src/onediff/infer_compiler/README.md b/src/onediff/infer_compiler/README.md
index 1fa181a79..32f4a4def 100644
--- a/src/onediff/infer_compiler/README.md
+++ b/src/onediff/infer_compiler/README.md
@@ -112,4 +112,3 @@ python3 ./benchmarks/text_to_image.py \
--compiler-config '{"mode": "max-optimize:max-autotune:low-precision", "memory_format": "channels_last", "dynamic": true}' \
--run_multiple_resolutions 1
```
-