More flexible kernel selection #37

Merged · 11 commits · Dec 7, 2018
1 change: 1 addition & 0 deletions .travis.yml
@@ -49,6 +49,7 @@ install:
# the main build
script:
- |
rustc --print cfg -Ctarget-cpu=native &&
Contributor
I take it this is only for reporting? There is no need to cargo build with -Ctarget-cpu=native after this, since that flag only affects compile-time optimization?

@bluss (Owner, Author) · Dec 5, 2018

Yes, this is just for reporting, so we can see what kind of machine the Travis test runs on!

I'm not that interested in -Ctarget-cpu=native at the moment, since we now use runtime target feature detection to deliver good performance by default instead. I do keep an eye on making sure performance doesn't regress if the user supplies -Ctarget-cpu=native.
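A minimal sketch of that pattern (the function names here are hypothetical; `is_x86_feature_detected!` and `#[target_feature]` are the std building blocks): detect once at a safe boundary, then call a variant compiled with the feature enabled.

```rust
// Detect at runtime, then dispatch to a function compiled with the feature enabled.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if is_x86_feature_detected!("fma") {
            // Safe to call: we just verified the running CPU supports FMA.
            return unsafe { dot_fma(a, b) };
        }
    }
    dot_scalar(a, b)
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "fma")]
unsafe fn dot_fma(a: &[f64], b: &[f64]) -> f64 {
    // The compiler may emit FMA instructions for this function's body.
    dot_scalar(a, b)
}

fn dot_scalar(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
```

The binary ships all variants and picks the best one on the running CPU, which is why good default performance doesn't require -Ctarget-cpu=native.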

cargo build --target=$TARGET &&
([ -n "$BUILD_ONLY" ] || (
cargo test --target=$TARGET &&
240 changes: 168 additions & 72 deletions src/dgemm_kernel.rs
@@ -7,6 +7,8 @@
// except according to those terms.

use kernel::GemmKernel;
use kernel::GemmSelect;
use kernel::{U4, U8};
use archparam;

#[cfg(target_arch="x86")]
@@ -16,34 +18,90 @@ use std::arch::x86_64::*;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
use x86::{FusedMulAdd, AvxMulAdd, DMultiplyAdd};

pub enum Gemm { }
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelAvx;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelFma;
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
struct KernelSse2;
struct KernelFallback;

type T = f64;

pub type T = f64;
/// Detect which implementation to use and select it using the selector's
/// .select(Kernel) method.
///
/// This function is called one or more times during a whole program's
/// execution; it may be called once per gemm kernel invocation, or fewer times.
#[inline]
pub(crate) fn detect<G>(selector: G) where G: GemmSelect<T> {
// dispatch to specific compiled versions
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
{
if is_x86_feature_detected_!("fma") {
return selector.select(KernelFma);
} else if is_x86_feature_detected_!("avx") {
return selector.select(KernelAvx);
} else if is_x86_feature_detected_!("sse2") {
return selector.select(KernelSse2);
}
}
return selector.select(KernelFallback);
}
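// A sketch of the caller's half of this handshake (hypothetical implementor;
// the shape of `select` is inferred from its use above). `select` is generic
// in the kernel type, so the packing and loop code are monomorphized per kernel:
//
//     struct GemmCall { /* operand pointers, strides, dimensions */ }
//
//     impl GemmSelect<f64> for GemmCall {
//         fn select<K>(self, _kernel: K) where K: GemmKernel<Elem = f64> {
//             // run the packed gemm loops specialized for K, calling
//             // K::kernel with K's MR×NR microtile size
//         }
//     }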

const MR: usize = 8;
const NR: usize = 4;

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
macro_rules! loop_m {
($i:ident, $e:expr) => { loop8!($i, $e) };
}
macro_rules! loop_n {
($j:ident, $e:expr) => { loop4!($j, $e) };
}


impl GemmKernel for Gemm {
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelAvx {
type Elem = T;

type MRTy = U8;
type NRTy = U4;

#[inline(always)]
fn align_to() -> usize { 32 }

#[inline(always)]
fn mr() -> usize { MR }
fn always_masked() -> bool { false }

#[inline(always)]
fn nr() -> usize { NR }
fn nc() -> usize { archparam::D_NC }
#[inline(always)]
fn kc() -> usize { archparam::D_KC }
#[inline(always)]
fn mc() -> usize { archparam::D_MC }

#[inline(always)]
fn always_masked() -> bool { false }
unsafe fn kernel(
k: usize,
alpha: T,
a: *const T,
b: *const T,
beta: T,
c: *mut T,
rsc: isize,
csc: isize)
{
kernel_target_avx(k, alpha, a, b, beta, c, rsc, csc)
}
}

#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelFma {
type Elem = T;

type MRTy = <KernelAvx as GemmKernel>::MRTy;
type NRTy = <KernelAvx as GemmKernel>::NRTy;

#[inline(always)]
fn align_to() -> usize { KernelAvx::align_to() }

#[inline(always)]
fn always_masked() -> bool { KernelAvx::always_masked() }

#[inline(always)]
fn nc() -> usize { archparam::D_NC }
Expand All @@ -63,45 +121,84 @@ impl GemmKernel for Gemm {
rsc: isize,
csc: isize)
{
kernel(k, alpha, a, b, beta, c, rsc, csc)
kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc)
}
}

/// matrix multiplication kernel
///
/// This does the matrix multiplication:
///
/// C ← α A B + β C
///
/// + k: length of data in a, b
/// + a, b are packed
/// + c has general strides
/// + rsc: row stride of c
/// + csc: col stride of c
/// + if beta is 0, then c does not need to be initialized
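// For example, with this file's 8×4 AVX microtile and a column major C
// (rsc = 1, csc = 8), the microtile element at row i, column j is stored to
// c.offset(i as isize * rsc + j as isize * csc).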
#[inline(never)]
pub unsafe fn kernel(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
{
// dispatch to specific compiled versions
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
impl GemmKernel for KernelSse2 {
type Elem = T;

type MRTy = U4;
type NRTy = U4;

#[inline(always)]
fn align_to() -> usize { 16 }

#[inline(always)]
fn always_masked() -> bool { true }

#[inline(always)]
fn nc() -> usize { archparam::D_NC }
#[inline(always)]
fn kc() -> usize { archparam::D_KC }
#[inline(always)]
fn mc() -> usize { archparam::D_MC }

#[inline(always)]
unsafe fn kernel(
k: usize,
alpha: T,
a: *const T,
b: *const T,
beta: T,
c: *mut T,
rsc: isize,
csc: isize)
{
if is_x86_feature_detected_!("fma") {
return kernel_target_fma(k, alpha, a, b, beta, c, rsc, csc);
} else if is_x86_feature_detected_!("avx") {
return kernel_target_avx(k, alpha, a, b, beta, c, rsc, csc);
} else if is_x86_feature_detected_!("sse2") {
return kernel_target_sse2(k, alpha, a, b, beta, c, rsc, csc);
}
kernel_target_sse2(k, alpha, a, b, beta, c, rsc, csc)
}
}

impl GemmKernel for KernelFallback {
type Elem = T;

type MRTy = U4;
type NRTy = U4;

#[inline(always)]
fn align_to() -> usize { 0 }

#[inline(always)]
fn always_masked() -> bool { true }

#[inline(always)]
fn nc() -> usize { archparam::D_NC }
#[inline(always)]
fn kc() -> usize { archparam::D_KC }
#[inline(always)]
fn mc() -> usize { archparam::D_MC }

#[inline(always)]
unsafe fn kernel(
k: usize,
alpha: T,
a: *const T,
b: *const T,
beta: T,
c: *mut T,
rsc: isize,
csc: isize)
{
kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc)
}
kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc);
}

#[inline]
#[target_feature(enable="fma")]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
unsafe fn kernel_target_fma(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
beta: T, c: *mut T, rsc: isize, csc: isize)
{
kernel_x86_avx::<FusedMulAdd>(k, alpha, a, b, beta, c, rsc, csc)
}
@@ -110,7 +207,7 @@ unsafe fn kernel_target_fma(k: usize, alpha: T, a: *const T, b: *const T,
#[target_feature(enable="avx")]
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
unsafe fn kernel_target_avx(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
beta: T, c: *mut T, rsc: isize, csc: isize)
{
kernel_x86_avx::<AvxMulAdd>(k, alpha, a, b, beta, c, rsc, csc)
}
@@ -130,6 +227,9 @@ unsafe fn kernel_x86_avx<MA>(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
where MA: DMultiplyAdd
{
const MR: usize = KernelAvx::MR;
const NR: usize = KernelAvx::NR;

debug_assert_ne!(k, 0);

let mut ab = [_mm256_setzero_pd(); MR];
@@ -685,17 +785,21 @@ unsafe fn kernel_x86_avx<MA>(k: usize, alpha: T, a: *const T, b: *const T,
}
}

#[inline(always)]
#[inline]
unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T,
beta: T, c: *mut T, rsc: isize, csc: isize)
{
const MR: usize = KernelFallback::MR;
const NR: usize = KernelFallback::NR;
let mut ab: [[T; NR]; MR] = [[0.; NR]; MR];
let mut a = a;
let mut b = b;
debug_assert_eq!(alpha, 1., "Alpha must be 1 or is not masked");
debug_assert_eq!(beta, 0., "Beta must be 0 or is not masked");

// Compute matrix multiplication into ab[i][j]
unroll_by!(4 => k, {
loop_m!(i, loop_n!(j, ab[i][j] += at(a, i) * at(b, j)));
loop4!(i, loop4!(j, ab[i][j] += at(a, i) * at(b, j)));

a = a.offset(MR as isize);
b = b.offset(NR as isize);
@@ -706,11 +810,7 @@ unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T,
}

// set C = α A B + β C
if beta == 0. {
loop_n!(j, loop_m!(i, *c![i, j] = alpha * ab[i][j]));
} else {
loop_n!(j, loop_m!(i, *c![i, j] = *c![i, j] * beta + alpha * ab[i][j]));
}
loop4!(j, loop4!(i, *c![i, j] = alpha * ab[i][j]));
}

#[inline(always)]
@@ -726,49 +826,42 @@ mod tests {
fn aligned_alloc<T>(elt: T, n: usize) -> Alloc<T> where T: Copy
{
unsafe {
Alloc::new(n, Gemm::align_to()).init_with(elt)
Alloc::new(n, KernelAvx::align_to()).init_with(elt)
}
}

use super::T;
type KernelFn = unsafe fn(usize, T, *const T, *const T, T, *mut T, isize, isize);

fn test_a_kernel(_name: &str, kernel_fn: KernelFn) {
fn test_a_kernel<K: GemmKernel<Elem=T>>(_name: &str) {
const K: usize = 4;
let mut a = aligned_alloc(1., MR * K);
let mut b = aligned_alloc(0., NR * K);
let mr = K::MR;
let nr = K::NR;
let mut a = aligned_alloc(1., mr * K);
let mut b = aligned_alloc(0., nr * K);
for (i, x) in a.iter_mut().enumerate() {
*x = i as _;
}

for i in 0..K {
b[i + i * NR] = 1.;
b[i + i * nr] = 1.;
}
let mut c = [0.; MR * NR];

let mut c = vec![0.; mr * nr];
unsafe {
// Column major C matrix:
// row stride of c, rsc = 1
// column stride of c, csc = MR (the kernel's tile height)
kernel_fn(K, 1., &a[0], &b[0], 0., &mut c[0], 1, MR as isize);
K::kernel(K, 1., &a[0], &b[0], 0., &mut c[0], 1, mr as isize);
// col major C
}
assert_eq!(&a[..], &c[..a.len()]);
}

#[test]
fn test_native_kernel() {
test_a_kernel("kernel", kernel);
}

#[test]
fn test_kernel_fallback_impl() {
test_a_kernel("kernel", kernel_fallback_impl);
test_a_kernel::<KernelFallback>("kernel");
}

#[test]
fn test_loop_m_n() {
let mut m = [[0; NR]; MR];
loop_m!(i, loop_n!(j, m[i][j] += 1));
let mut m = [[0; 4]; KernelAvx::MR];
loop_m!(i, loop4!(j, m[i][j] += 1));
for arr in &m[..] {
for elt in &arr[..] {
assert_eq!(*elt, 1);
@@ -779,13 +872,14 @@ mod tests {
#[cfg(any(target_arch="x86", target_arch="x86_64"))]
mod test_arch_kernels {
use super::test_a_kernel;
use super::super::*;
macro_rules! test_arch_kernels_x86 {
($($feature_name:tt, $function_name:ident),*) => {
($($feature_name:tt, $name:ident, $kernel_ty:ty),*) => {
$(
#[test]
fn $function_name() {
fn $name() {
if is_x86_feature_detected_!($feature_name) {
test_a_kernel(stringify!($function_name), super::super::$function_name);
test_a_kernel::<$kernel_ty>(stringify!($name));
} else {
println!("Skipping, host does not have feature: {:?}", $feature_name);
}
@@ -795,7 +889,9 @@ mod tests {
}

test_arch_kernels_x86! {
"sse2", kernel_target_sse2
"fma", fma, KernelFma,
"avx", avx, KernelAvx,
"sse2", sse2, KernelSse2
}
}
}
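For context, the crate's public entry point is unchanged by this PR: callers go through `matrixmultiply::dgemm`, and kernel selection happens internally. An illustrative call (assuming the published `dgemm` signature; strides are in elements), computing C ← α A B + β C for row-major matrices:

```rust
fn main() {
    let (m, k, n) = (2, 3, 2);
    let a = vec![1., 2., 3., 4., 5., 6.]; // A is 2×3, row-major
    let b = vec![1., 0., 0., 1., 1., 1.]; // B is 3×2, row-major
    let mut c = vec![0.; m * n];          // C is 2×2
    unsafe {
        matrixmultiply::dgemm(
            m, k, n,
            1.0,                           // alpha
            a.as_ptr(), k as isize, 1,     // A: row stride k, col stride 1
            b.as_ptr(), n as isize, 1,     // B: row stride n, col stride 1
            0.0,                           // beta
            c.as_mut_ptr(), n as isize, 1, // C: row stride n, col stride 1
        );
    }
    assert_eq!(c, vec![4., 5., 10., 11.]);
}
```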