Skip to content

Commit

Permalink
setup.py: cleanup get_rocm_version
Browse files Browse the repository at this point in the history
Signed-off-by: Daniele Trifirò <[email protected]>
  • Loading branch information
dtrifiro committed Jan 22, 2025
1 parent 87620b8 commit f9ea699
Showing 1 changed file with 26 additions and 16 deletions.
42 changes: 26 additions & 16 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,16 +389,17 @@ def _build_custom_ops() -> bool:
return _is_cuda() or _is_hip() or _is_cpu()


def get_rocm_version():
def get_rocm_version() -> str:
from torch.utils.cpp_extension import ROCM_HOME

# Get the Rocm version from the ROCM_HOME/bin/librocm-core.so
# see https://github.com/ROCm/rocm-core/blob/d11f5c20d500f729c393680a01fa902ebf92094b/rocm_version.cpp#L21
try:
librocm_core_file = Path(ROCM_HOME) / "lib" / "librocm-core.so"
if not librocm_core_file.is_file():
return None
librocm_core = ctypes.CDLL(librocm_core_file)
raise Exception("librocm-core.so is not availble")

librocm_core = ctypes.CDLL(str(librocm_core_file))
VerErrors = ctypes.c_uint32
get_rocm_core_version = librocm_core.getROCmVersion
get_rocm_core_version.restype = VerErrors
Expand All @@ -411,12 +412,27 @@ def get_rocm_version():
minor = ctypes.c_uint32()
patch = ctypes.c_uint32()

if (get_rocm_core_version(ctypes.byref(major), ctypes.byref(minor),
ctypes.byref(patch)) == 0):
return "%d.%d.%d" % (major.value, minor.value, patch.value)
return None
except Exception:
return None
ret = get_rocm_core_version(
ctypes.byref(major),
ctypes.byref(minor),
ctypes.byref(patch),
)
if ret == 0:
version = f"{major.value}.{minor.value}.{patch.value}"
else:
raise Exception(f"get_rocm_core_version returned: {ret}")
except Exception as exc:
print(f"failed to get version using librocm-core ({exc}), "
"falling back to torch.version.hip")

import torch

if torch.version.hip is None:
raise SetupError("Couldn't get rocm version") from exc

version = torch.version.hip

return version


def get_neuronxcc_version():
Expand Down Expand Up @@ -498,13 +514,7 @@ def get_vllm_version() -> str:
cuda_version_str = cuda_version.replace(".", "")[:3]
version += f"{sep}cu{cuda_version_str}"
elif _is_hip():
import torch

if torch.version.hip is None:
raise SetupError("Couldn't get rocm version")

# Get the Rocm Version
rocm_version = get_rocm_version() or torch.version.hip
rocm_version = get_rocm_version()
if rocm_version and rocm_version != MAIN_CUDA_VERSION:
version += f"{sep}rocm{rocm_version.replace('.', '')[:3]}"
elif _is_neuron():
Expand Down

0 comments on commit f9ea699

Please sign in to comment.