Skip to content

Commit

Permalink
Enabled libxsmm for spmm kernel on arm platform
Browse files Browse the repository at this point in the history
  • Loading branch information
choudhary-devang committed Jan 17, 2025
1 parent ba73133 commit f8295e7
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 7 deletions.
8 changes: 7 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,17 @@ else(MSVC)
endif(NOT APPLE)
endif(MSVC)

if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)")
if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "(x86)|(X86)|(amd64)|(AMD64)|(aarch64)|(AARCH64)")
message(STATUS "Disabling LIBXSMM on ${CMAKE_SYSTEM_PROCESSOR}.")
set(USE_LIBXSMM OFF)
endif()

# Flag for arm specific optimization
if(CMAKE_SYSTEM_PROCESSOR MATCHES "(aarch64)|(AARCH64)")
message(STATUS "setting flag for arm specific optimizations = ${CMAKE_SYSTEM_PROCESSOR} to ON.")
add_definitions(-DAARCH64)
endif()

# Source file lists
file(GLOB DGL_SRC
src/*.cc
Expand Down
32 changes: 28 additions & 4 deletions src/array/cpu/spmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,21 @@ void SpMMSumCsr(
}
#if !defined(_WIN32)
#ifdef USE_LIBXSMM
int cpu_id = libxsmm_cpuid_x86();
int cpu_id, limit;
#ifdef AARCH64
static int arm_cpu_id = -1;
if (arm_cpu_id == -1){
arm_cpu_id = libxsmm_cpuid_arm();
}
cpu_id = arm_cpu_id;
limit = LIBXSMM_AARCH64_A64FX;
#else //x86
cpu_id = libxsmm_cpuid_x86();
limit = LIBXSMM_X86_AVX512;
#endif//AARCH64
const bool no_libxsmm =
bcast.use_bcast || std::is_same<DType, double>::value ||
(std::is_same<DType, BFloat16>::value && cpu_id < LIBXSMM_X86_AVX512) ||
(std::is_same<DType, BFloat16>::value && cpu_id < limit) ||
!dgl::runtime::Config::Global()->IsLibxsmmAvailable();
if (!no_libxsmm) {
SpMMSumCsrLibxsmm<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
Expand Down Expand Up @@ -266,10 +277,23 @@ void SpMMCmpCsr(
}
#if !defined(_WIN32)
#ifdef USE_LIBXSMM
int cpu_id = libxsmm_cpuid_x86();
#ifdef AARCH64
static int arm_cpu_id = -1;
if (arm_cpu_id == -1){
arm_cpu_id = libxsmm_cpuid_arm();
}
#endif//AARCH64
int cpu_id, limit;
#ifdef AARCH64
cpu_id = arm_cpu_id;
limit = LIBXSMM_AARCH64_A64FX;
#else //x86
cpu_id = libxsmm_cpuid_x86();
limit = LIBXSMM_AARCH64_A64FX;
#endif//AARCH64
const bool no_libxsmm = bcast.use_bcast ||
std::is_same<DType, double>::value ||
cpu_id < LIBXSMM_X86_AVX512 ||
cpu_id < limit ||
!dgl::runtime::Config::Global()->IsLibxsmmAvailable();
if (!no_libxsmm) {
SpMMCmpCsrLibxsmm<IdType, DType, Op, Cmp>(
Expand Down
2 changes: 1 addition & 1 deletion src/array/cpu/spmm_blocking_libxsmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct CSRMatrixInternal {
int32_t GetLLCSize() {
#ifdef _SC_LEVEL3_CACHE_SIZE
int32_t cache_size = sysconf(_SC_LEVEL3_CACHE_SIZE);
if (cache_size < 0) cache_size = DGL_CPU_LLC_SIZE;
if (cache_size <= 0) cache_size = DGL_CPU_LLC_SIZE;
#else
int32_t cache_size = DGL_CPU_LLC_SIZE;
#endif
Expand Down
14 changes: 13 additions & 1 deletion src/runtime/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,21 @@ namespace runtime {

Config::Config() {
#if !defined(_WIN32) && defined(USE_LIBXSMM)
int cpu_id = libxsmm_cpuid_x86();
int cpu_id;
#if defined(AARCH64)
static int arm_cpu_id = -1;
if (arm_cpu_id == -1){
arm_cpu_id = libxsmm_cpuid_arm();
}
cpu_id = arm_cpu_id;
// Enable libxsmm on ARM machines by default
libxsmm_ = LIBXSMM_AARCH64_SVE128 <= cpu_id && cpu_id <= LIBXSMM_AARCH64_ALLFEAT;
#else
cpu_id = libxsmm_cpuid_x86();

// Enable libxsmm on AVX machines by default
libxsmm_ = LIBXSMM_X86_AVX2 <= cpu_id && cpu_id <= LIBXSMM_X86_ALLFEAT;
#endif //AARCH64
#else
libxsmm_ = false;
#endif
Expand Down

0 comments on commit f8295e7

Please sign in to comment.