From 9231f556a7b77d36b29b07cffd2a93143de2fb98 Mon Sep 17 00:00:00 2001 From: FengWen <109639975+ccssu@users.noreply.github.com> Date: Fri, 20 Sep 2024 21:06:51 +0800 Subject: [PATCH] Fix is_using_oneflow_backend check (#1112) ## Summary by CodeRabbit - **New Features** - Enhanced backend detection logic for improved compatibility with the OneFlow library. - Added a function to check for OneFlow library availability and CUDA support. - **Bug Fixes** - Improved messaging for cases when the OneFlow backend is not detected. --- .../modules/oneflow/utils/booster_utils.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/onediff_comfy_nodes/modules/oneflow/utils/booster_utils.py b/onediff_comfy_nodes/modules/oneflow/utils/booster_utils.py index 3b9f43a56..a7f114cb7 100644 --- a/onediff_comfy_nodes/modules/oneflow/utils/booster_utils.py +++ b/onediff_comfy_nodes/modules/oneflow/utils/booster_utils.py @@ -1,5 +1,7 @@ from typing import Union +import oneflow + import torch from comfy import model_management from comfy.model_base import BaseModel, SVD_img2vid @@ -9,6 +11,7 @@ OneflowDeployableModule as DeployableModule, ) from onediff.utils import set_boolean_env_var +from onediff.utils.import_utils import is_oneflow_available from ..patch_management import create_patch_executor, PatchType @@ -63,6 +66,15 @@ def set_environment_for_svd_img2vid(model: ModelPatcher): def is_using_oneflow_backend(module): + # First, check if oneflow is available and CUDA is enabled + if is_oneflow_available() and not oneflow.cuda.is_available(): + print("OneFlow CUDA support is not available") + return False + + # Check if the module + if isinstance(module, oneflow.nn.Module): + return True + dc_patch_executor = create_patch_executor(PatchType.DCUNetExecutorPatch) if isinstance(module, ModelPatcher): deep_cache_module = dc_patch_executor.get_patch(module) @@ -85,7 +97,17 @@ def is_using_oneflow_backend(module): if isinstance(module, DeployableModule): return True - raise RuntimeError("") + if hasattr(module, "parameters"): + for param in module.parameters(): + if isinstance(param, oneflow.Tensor): + return True + + warn_msg = ( + f"OneFlow backend is not detected for the module, the module is {type(module)}" + ) + print(warn_msg) + # If none of the above conditions are met, it's not using OneFlow backend + return False def clear_deployable_module_cache_and_unbind(