Skip to content

Commit

Permalink
arm64: Implement 16bpc cdef_dist_kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
barrbrain committed Nov 29, 2023
1 parent e778d4b commit a8ad43f
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 1 deletion.
113 changes: 113 additions & 0 deletions src/arm/64/cdef_dist.S
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,116 @@ L(cdk_8x8):
CDEF_DIST_REFINE
ret
endfunc

// Register allocation shared by the 16bpc (HBD) macros in this file:
// v0: tmp register
// v1: src input
// v2: dst input
// v3 = sum(src_{i,j})
// v4 = sum(src_{i,j}^2)
// v5 = sum(dst_{i,j})
// v6 = sum(dst_{i,j}^2)
// v7 = sum(src_{i,j} * dst_{i,j})
// v16: zero register
// v17: tmp register (second dst row in LOAD_ROWS_HBD)
//
// CDEF_DIST_HBD_W8: fold the eight u16 pixels held in v1 (src) and v2 (dst)
// into the running sums v3-v7, each kept as four 32-bit lanes. Every
// operation is issued twice: once for the low four pixels and once (the "2"
// variant) for the high four.
// The accumulators and v16 are not initialised here; presumably
// CDEF_DIST_INIT does that -- TODO confirm against its definition.
.macro CDEF_DIST_HBD_W8
// |x - 0| == x for unsigned x, so uabal against the zero register acts as
// a widening add of the pixel values.
uabal v3.4s, v1.4h, v16.4h // sum pixel values
uabal2 v3.4s, v1.8h, v16.8h
umlal v4.4s, v1.4h, v1.4h // square and accumulate
umlal2 v4.4s, v1.8h, v1.8h
uabal v5.4s, v2.4h, v16.4h // same as above, but for dst
uabal2 v5.4s, v2.8h, v16.8h
umlal v6.4s, v2.4h, v2.4h
umlal2 v6.4s, v2.8h, v2.8h
umlal v7.4s, v1.4h, v2.4h // src_{i,j} * dst_{i,j}
umlal2 v7.4s, v1.8h, v2.8h
.endm

// CDEF_DIST_HBD_REFINE: reduce the lane accumulators v3-v7 to scalars and
// store three 32-bit results at [x4], in this order: svar, dvar, sse.
//
// The shift argument encodes the block size. The squared sums are divided
// (with rounding) by 2^(6-shift), i.e. 64 for shift=0, 32 for shift=1 and
// 16 for shift=2, and the resulting variances are then multiplied by
// 2^shift. The callers in this file pass 2 for 4x4, 1 for 4x8 and 8x4, and
// the default 0 for 8x8.
.macro CDEF_DIST_HBD_REFINE shift=0
// Horizontal add of the four lanes, then square the total into 64 bits.
// Only lane 0 (d3) is used afterwards.
addv s3, v3.4s
umull v3.2d, v3.2s, v3.2s
urshr d3, d3, #(6-\shift) // d3: sum(src_{i,j})^2 / N
// Widening horizontal add: the sum of squares is kept as a 64-bit scalar.
uaddlv d4, v4.4s // d4: sum(src_{i,j}^2)
addv s5, v5.4s
umull v5.2d, v5.2s, v5.2s
urshr d5, d5, #(6-\shift) // d5: sum(dst_{i,j})^2 / N
uaddlv d6, v6.4s // d6: sum(dst_{i,j}^2)
uaddlv d7, v7.4s
// sse = sum(src^2) + sum(dst^2) - 2 * sum(src * dst)
add d0, d4, d6
sub d0, d0, d7
sub d0, d0, d7 // d0: sse
// Saturating subtract: a result that would go negative is clamped to 0.
uqsub d4, d4, d3 // d4: svar
uqsub d6, d6, d5 // d6: dvar
.if \shift != 0
shl d4, d4, #\shift
shl d6, d6, #\shift
.endif
// Only the low 32 bits of each 64-bit scalar are written out.
str s4, [x4]
str s6, [x4, #4]
str s0, [x4, #8]
.endm

// LOAD_ROW_HBD: load one 16-byte row (eight u16 pixels) from src (x0) into
// v1 and from dst (x2) into v2, then advance each pointer by its stride
// (x1 / x3). The strides are added to the pointers unscaled, so they are
// byte offsets.
.macro LOAD_ROW_HBD
ldr q1, [x0]
ldr q2, [x2]
add x0, x0, x1
add x2, x2, x3
.endm

// LOAD_ROWS_HBD: load two consecutive 8-byte rows (four u16 pixels each)
// from src (x0) and dst (x2) and pack them into single registers: the first
// row goes to the low half of v1 / v2 and the second row to the high half.
// Both pointers are then advanced by two strides. Clobbers v0 and v17.
.macro LOAD_ROWS_HBD
ldr d1, [x0]
ldr d2, [x2]
// Second row of each plane, one stride further on.
ldr d0, [x0, x1]
ldr d17, [x2, x3]
add x0, x0, x1, lsl 1
add x2, x2, x3, lsl 1
// Merge the second rows into the upper 64 bits.
mov v1.d[1], v0.d[0]
mov v2.d[1], v17.d[0]
.endm

// Arguments of the cdef_dist_kernel_*_hbd_neon functions:
// x0: src: *const u16,
// x1: src_stride: isize,
// x2: dst: *const u16,
// x3: dst_stride: isize,
// x4: ret_ptr: *mut u32,
//
// 4x4 block of u16 pixels. Each loop iteration consumes two 4-pixel rows
// (LOAD_ROWS_HBD) and accumulates them; the refine step then stores
// svar, dvar and sse to ret_ptr using shift=2 (16 pixels).
// w5 is the loop counter. CDEF_DIST_INIT is defined outside this view and
// presumably sets w5 and clears the accumulators -- TODO confirm.
function cdef_dist_kernel_4x4_hbd_neon, export=1
CDEF_DIST_INIT 4, 4
L(cdk_hbd_4x4):
LOAD_ROWS_HBD
CDEF_DIST_HBD_W8
subs w5, w5, #1
bne L(cdk_hbd_4x4)
CDEF_DIST_HBD_REFINE 2
ret
endfunc

// 4x8 block of u16 pixels (arguments in x0-x4: src, src_stride, dst,
// dst_stride, ret_ptr). Each loop iteration consumes two 4-pixel rows
// (LOAD_ROWS_HBD); the refine step stores svar, dvar and sse to ret_ptr
// using shift=1 (32 pixels).
// w5 is the loop counter, presumably set by CDEF_DIST_INIT -- TODO confirm.
function cdef_dist_kernel_4x8_hbd_neon, export=1
CDEF_DIST_INIT 4, 8
L(cdk_hbd_4x8):
LOAD_ROWS_HBD
CDEF_DIST_HBD_W8
subs w5, w5, #1
bne L(cdk_hbd_4x8)
CDEF_DIST_HBD_REFINE 1
ret
endfunc

// 8x4 block of u16 pixels (arguments in x0-x4: src, src_stride, dst,
// dst_stride, ret_ptr). Each loop iteration consumes one 8-pixel row
// (LOAD_ROW_HBD); the refine step stores svar, dvar and sse to ret_ptr
// using shift=1 (32 pixels).
// w5 is the loop counter, presumably set by CDEF_DIST_INIT -- TODO confirm.
function cdef_dist_kernel_8x4_hbd_neon, export=1
CDEF_DIST_INIT 8, 4
L(cdk_hbd_8x4):
LOAD_ROW_HBD
CDEF_DIST_HBD_W8
subs w5, w5, #1
bne L(cdk_hbd_8x4)
CDEF_DIST_HBD_REFINE 1
ret
endfunc

// 8x8 block of u16 pixels (arguments in x0-x4: src, src_stride, dst,
// dst_stride, ret_ptr). Each loop iteration consumes one 8-pixel row
// (LOAD_ROW_HBD); the refine step stores svar, dvar and sse to ret_ptr
// using the default shift=0 (64 pixels).
// w5 is the loop counter, presumably set by CDEF_DIST_INIT -- TODO confirm.
function cdef_dist_kernel_8x8_hbd_neon, export=1
CDEF_DIST_INIT 8, 8
L(cdk_hbd_8x8):
LOAD_ROW_HBD
CDEF_DIST_HBD_W8
subs w5, w5, #1
bne L(cdk_hbd_8x8)
CDEF_DIST_HBD_REFINE
ret
endfunc
64 changes: 63 additions & 1 deletion src/asm/aarch64/dist/cdef_dist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ type CdefDistKernelFn = unsafe extern fn(
ret_ptr: *mut u32,
);

/// Function-pointer type of the high-bit-depth (`u16` pixel) CDEF distortion
/// kernels implemented in assembly.
///
/// The callee writes its results through `ret_ptr`; the caller in this file
/// passes a buffer of three `u32` values.
type CdefDistKernelHBDFn = unsafe extern "C" fn(
  src: *const u16,
  src_stride: isize,
  dst: *const u16,
  dst_stride: isize,
  ret_ptr: *mut u32,
);

extern {
fn rav1e_cdef_dist_kernel_4x4_neon(
src: *const u8, src_stride: isize, dst: *const u8, dst_stride: isize,
Expand All @@ -39,6 +47,22 @@ extern {
src: *const u8, src_stride: isize, dst: *const u8, dst_stride: isize,
ret_ptr: *mut u32,
);
// 16-bit (`u16` pixel) NEON kernels, one per supported block size. All four
// share the `CdefDistKernelHBDFn` signature and write their results through
// `ret_ptr`; the caller in this file supplies a three-element `u32` buffer.

/// 4x4 high-bit-depth kernel.
fn rav1e_cdef_dist_kernel_4x4_hbd_neon(
src: *const u16, src_stride: isize, dst: *const u16, dst_stride: isize,
ret_ptr: *mut u32,
);
/// 4x8 high-bit-depth kernel.
fn rav1e_cdef_dist_kernel_4x8_hbd_neon(
src: *const u16, src_stride: isize, dst: *const u16, dst_stride: isize,
ret_ptr: *mut u32,
);
/// 8x4 high-bit-depth kernel.
fn rav1e_cdef_dist_kernel_8x4_hbd_neon(
src: *const u16, src_stride: isize, dst: *const u16, dst_stride: isize,
ret_ptr: *mut u32,
);
/// 8x8 high-bit-depth kernel.
fn rav1e_cdef_dist_kernel_8x8_hbd_neon(
src: *const u16, src_stride: isize, dst: *const u16, dst_stride: isize,
ret_ptr: *mut u32,
);
}

/// # Panics
Expand Down Expand Up @@ -86,7 +110,25 @@ pub fn cdef_dist_kernel<T: Pixel>(
}
}
PixelType::U16 => {
return call_rust();
if let Some(func) =
CDEF_DIST_KERNEL_HBD_FNS[cpu.as_index()][kernel_fn_index(w, h)]
{
let mut ret_buf = [0u32; 3];
// SAFETY: Calls Assembly code.
unsafe {
func(
src.data_ptr() as *const _,
T::to_asm_stride(src.plane_cfg.stride),
dst.data_ptr() as *const _,
T::to_asm_stride(dst.plane_cfg.stride),
ret_buf.as_mut_ptr(),
)
}

(ret_buf[0], ret_buf[1], ret_buf[2])
} else {
return call_rust();
}
}
};

Expand Down Expand Up @@ -127,3 +169,23 @@ cpu_function_lookup_table!(
default: [None; CDEF_DIST_KERNEL_FNS_LENGTH],
[NEON]
);

/// NEON row of the high-bit-depth kernel table: one optional function
/// pointer per `kernel_fn_index` slot. The 4x4, 4x8, 8x4 and 8x8 slots are
/// filled; every other slot stays `None`.
static CDEF_DIST_KERNEL_HBD_FNS_NEON: [Option<CdefDistKernelHBDFn>;
  CDEF_DIST_KERNEL_FNS_LENGTH] = {
  // Typed `None` so the array element type is fixed before any kernel is
  // assigned.
  const EMPTY: Option<CdefDistKernelHBDFn> = None;
  let mut table = [EMPTY; CDEF_DIST_KERNEL_FNS_LENGTH];

  table[kernel_fn_index(4, 4)] = Some(rav1e_cdef_dist_kernel_4x4_hbd_neon);
  table[kernel_fn_index(4, 8)] = Some(rav1e_cdef_dist_kernel_4x8_hbd_neon);
  table[kernel_fn_index(8, 4)] = Some(rav1e_cdef_dist_kernel_8x4_hbd_neon);
  table[kernel_fn_index(8, 8)] = Some(rav1e_cdef_dist_kernel_8x8_hbd_neon);

  table
};

// Declares CDEF_DIST_KERNEL_HBD_FNS, the table that `cdef_dist_kernel`
// indexes as `[cpu.as_index()][kernel_fn_index(w, h)]` for `u16` pixels.
// The `default` row holds no kernels. NOTE(review): the macro is defined
// elsewhere; presumably it maps the NEON feature level to
// CDEF_DIST_KERNEL_HBD_FNS_NEON by name -- confirm against its definition.
cpu_function_lookup_table!(
CDEF_DIST_KERNEL_HBD_FNS:
[[Option<CdefDistKernelHBDFn>; CDEF_DIST_KERNEL_FNS_LENGTH]],
default: [None; CDEF_DIST_KERNEL_FNS_LENGTH],
[NEON]
);

0 comments on commit a8ad43f

Please sign in to comment.