diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index a1395982af44c..57166f05cd9bf 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -82,23 +82,25 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend: if backend_by_env_var is not None: selected_backend = backend_name_to_enum(backend_by_env_var) if selected_backend is None: - # For Volta and Turing GPUs, use xformers instead. - device_available = current_platform.has_device_capability(80) - if device_available and support_fa: - from transformers.utils import is_flash_attn_2_available - if is_flash_attn_2_available(): - selected_backend = _Backend.FLASH_ATTN + if current_platform.is_cuda(): + device_available = current_platform.has_device_capability(80) + if device_available and support_fa: + from transformers.utils import is_flash_attn_2_available + if is_flash_attn_2_available(): + selected_backend = _Backend.FLASH_ATTN + else: + logger.warning_once( + "Current `vllm-flash-attn` has a bug inside vision " + "module, so we use xformers backend instead. You can " + "run `pip install flash-attn` to use flash-attention " + "backend.") + selected_backend = _Backend.XFORMERS else: - logger.warning_once( - "Current `vllm-flash-attn` has a bug inside vision module, " - "so we use xformers backend instead. You can run " - "`pip install flash-attn` to use flash-attention backend.") + # For Volta and Turing GPUs, use xformers instead. selected_backend = _Backend.XFORMERS - elif current_platform.is_cpu() or current_platform.is_rocm(): - # ROCM doesn't support xformers - selected_backend = _Backend.TORCH_SDPA else: - selected_backend = _Backend.XFORMERS + # Default to torch SDPA for other non-GPU platforms. + selected_backend = _Backend.TORCH_SDPA return selected_backend