diff --git a/src/asm/aarch64/transform/inverse.rs b/src/asm/aarch64/transform/inverse.rs index 268968124c..67bb7664f1 100644 --- a/src/asm/aarch64/transform/inverse.rs +++ b/src/asm/aarch64/transform/inverse.rs @@ -16,30 +16,25 @@ use crate::{Pixel, PixelType}; use crate::asm::shared::transform::inverse::*; use crate::asm::shared::transform::*; -#[inline] -pub fn inverse_transform_add_lossless( +pub fn inverse_transform_add( input: &[T::Coeff], output: &mut PlaneRegionMut<'_, T>, eob: usize, - bd: usize, cpu: CpuFeatureLevel, + tx_size: TxSize, tx_type: TxType, bd: usize, cpu: CpuFeatureLevel, ) { - match T::type_enum() { - PixelType::U8 => { - if let Some(func) = INV_TXFM_WHT_FN[cpu.as_index()] { - return call_inverse_func(func, input, output, eob, 4, 4, bd); + if tx_type == TxType::WHT_WHT { + debug_assert!(tx_size == TxSize::TX_4X4); + match T::type_enum() { + PixelType::U8 => { + if let Some(func) = INV_TXFM_WHT_FN[cpu.as_index()] { + return call_inverse_func(func, input, output, eob, 4, 4, bd); + } } - } - PixelType::U16 => { - if let Some(func) = INV_TXFM_WHT_HBD_FN[cpu.as_index()] { - return call_inverse_hbd_func(func, input, output, eob, 4, 4, bd); + PixelType::U16 => { + if let Some(func) = INV_TXFM_WHT_HBD_FN[cpu.as_index()] { + return call_inverse_hbd_func(func, input, output, eob, 4, 4, bd); + } } } } - rust::inverse_transform_add_lossless(input, output, eob, bd, cpu); -} - -pub fn inverse_transform_add( - input: &[T::Coeff], output: &mut PlaneRegionMut<'_, T>, eob: usize, - tx_size: TxSize, tx_type: TxType, bd: usize, cpu: CpuFeatureLevel, -) { match T::type_enum() { PixelType::U8 => { if let Some(func) = INV_TXFM_FNS[cpu.as_index()] diff --git a/src/asm/x86/transform/forward.rs b/src/asm/x86/transform/forward.rs index eb4cad0edb..541769692f 100644 --- a/src/asm/x86/transform/forward.rs +++ b/src/asm/x86/transform/forward.rs @@ -21,8 +21,6 @@ use std::arch::x86::*; #[cfg(target_arch = "x86_64")] use std::arch::x86_64::*; -pub use crate::transform::forward::rust::forward_transform_lossless; - type TxfmFuncI32X8 = unsafe fn(&mut [I32X8]); #[inline] @@ -41,6 +39,7 @@ fn get_func_i32x8(t: TxfmType) -> TxfmFuncI32X8 { Identity8 => fidentity, Identity16 => fidentity, Identity32 => fidentity, + WHT4 => fwht4, } } @@ -509,6 +508,7 @@ unsafe fn forward_transform_avx2( /// # Panics /// /// - If called with an invalid combination of `tx_size` and `tx_type` +#[inline] pub fn forward_transform( input: &[i16], output: &mut [MaybeUninit], stride: usize, tx_size: TxSize, tx_type: TxType, bd: usize, cpu: CpuFeatureLevel, diff --git a/src/asm/x86/transform/inverse.rs b/src/asm/x86/transform/inverse.rs index 007ba3e651..df99f4b4b3 100644 --- a/src/asm/x86/transform/inverse.rs +++ b/src/asm/x86/transform/inverse.rs @@ -16,30 +16,25 @@ use crate::{Pixel, PixelType}; use crate::asm::shared::transform::inverse::*; use crate::asm::shared::transform::*; -#[inline] -pub fn inverse_transform_add_lossless( +pub fn inverse_transform_add( input: &[T::Coeff], output: &mut PlaneRegionMut<'_, T>, eob: usize, - bd: usize, cpu: CpuFeatureLevel, + tx_size: TxSize, tx_type: TxType, bd: usize, cpu: CpuFeatureLevel, ) { - match T::type_enum() { - PixelType::U8 => { - if let Some(func) = INV_TXFM_WHT_FN[cpu.as_index()] { - return call_inverse_func(func, input, output, eob, 4, 4, bd); + if tx_type == TxType::WHT_WHT { + debug_assert!(tx_size == TxSize::TX_4X4); + match T::type_enum() { + PixelType::U8 => { + if let Some(func) = INV_TXFM_WHT_FN[cpu.as_index()] { + return call_inverse_func(func, input, output, eob, 4, 4, bd); + } } - } - PixelType::U16 => { - if let Some(func) = INV_TXFM_WHT_HBD_FN[cpu.as_index()] { - return call_inverse_hbd_func(func, input, output, eob, 4, 4, bd); + PixelType::U16 => { + if let Some(func) = INV_TXFM_WHT_HBD_FN[cpu.as_index()] { + return call_inverse_hbd_func(func, input, output, eob, 4, 4, bd); + } } } } - rust::inverse_transform_add_lossless(input, output, eob, bd, cpu); -} - -pub fn inverse_transform_add( - input: &[T::Coeff], output: &mut PlaneRegionMut<'_, T>, eob: usize, - tx_size: TxSize, tx_type: TxType, bd: usize, cpu: CpuFeatureLevel, -) { match T::type_enum() { PixelType::U8 => { if let Some(func) = INV_TXFM_FNS[cpu.as_index()] diff --git a/src/transform/forward.rs b/src/transform/forward.rs index 8e1b97fc0a..8c56a9de6b 100644 --- a/src/transform/forward.rs +++ b/src/transform/forward.rs @@ -92,36 +92,7 @@ pub mod rust { Identity8 => fidentity, Identity16 => fidentity, Identity32 => fidentity, - } - } - - pub fn forward_transform_lossless( - input: &[i16], output: &mut [T], stride: usize, _cpu: CpuFeatureLevel, - ) { - let mut tmp = [0i32; 4 * 4]; - let buf = &mut tmp[..]; - let mut col_coeffs_backing = [0i32; 4]; - let col_coeffs = &mut col_coeffs_backing[..]; - - // Columns - for c in 0..4 { - for r in 0..4 { - col_coeffs[r] = (input[r * stride + c]).into(); - } - fwht4(col_coeffs); - for r in 0..4 { - buf[r * 4 + c] = col_coeffs[r]; - } - } - - // Rows - for r in 0..4 { - let row_coeffs = &mut buf[r * 4..]; - fwht4(row_coeffs); - av1_round_shift_array(row_coeffs, 4, -2); - for c in 0..4 { - output[c * 4 + r] = T::cast_from(row_coeffs[c]); - } + WHT4 => fwht4, } } diff --git a/src/transform/forward_shared.rs b/src/transform/forward_shared.rs index 232af22866..507bf355ba 100644 --- a/src/transform/forward_shared.rs +++ b/src/transform/forward_shared.rs @@ -39,6 +39,8 @@ const FWD_SHIFT_32X8: TxfmShifts = [[4, -1, 0], [2, 0, 1], [0, 0, 3]]; const FWD_SHIFT_16X64: TxfmShifts = [[4, -2, 0], [2, 0, 0], [0, 0, 2]]; const FWD_SHIFT_64X16: TxfmShifts = [[4, -2, 0], [2, 0, 0], [0, 0, 2]]; +const FWD_SHIFT_4X4_WHT: TxfmShift = [0, 0, 2]; + pub const FWD_TXFM_SHIFT_LS: [TxfmShifts; TxSize::TX_SIZES_ALL] = [ FWD_SHIFT_4X4, FWD_SHIFT_8X8, @@ -75,31 +77,35 @@ pub enum TxfmType { Identity8, Identity16, Identity32, + WHT4, } impl TxfmType { - const TX_TYPES_1D: usize = 4; + const TX_TYPES_1D: usize = 5; const AV1_TXFM_TYPE_LS: [[Option; Self::TX_TYPES_1D]; 5] = [ [ Some(TxfmType::DCT4), Some(TxfmType::ADST4), Some(TxfmType::ADST4), Some(TxfmType::Identity4), + Some(TxfmType::WHT4), ], [ Some(TxfmType::DCT8), Some(TxfmType::ADST8), Some(TxfmType::ADST8), Some(TxfmType::Identity8), + None, ], [ Some(TxfmType::DCT16), Some(TxfmType::ADST16), Some(TxfmType::ADST16), Some(TxfmType::Identity16), + None, ], - [Some(TxfmType::DCT32), None, None, Some(TxfmType::Identity32)], - [Some(TxfmType::DCT64), None, None, None], + [Some(TxfmType::DCT32), None, None, Some(TxfmType::Identity32), None], + [Some(TxfmType::DCT64), None, None, None, None], ]; } @@ -129,12 +135,17 @@ impl Txfm2DFlipCfg { let txfm_type_row = TxfmType::AV1_TXFM_TYPE_LS[txw_idx][tx_type_1d_row as usize].unwrap(); let (ud_flip, lr_flip) = Self::get_flip_cfg(tx_type); + let shift = if tx_type == TxType::WHT_WHT { + FWD_SHIFT_4X4_WHT + } else { + FWD_TXFM_SHIFT_LS[tx_size as usize][(bd - 8) / 2] + }; Txfm2DFlipCfg { tx_size, ud_flip, lr_flip, - shift: FWD_TXFM_SHIFT_LS[tx_size as usize][(bd - 8) / 2], + shift, txfm_type_col, txfm_type_row, } @@ -145,7 +156,7 @@ impl Txfm2DFlipCfg { use self::TxType::*; match tx_type { DCT_DCT | ADST_DCT | DCT_ADST | ADST_ADST | IDTX | V_DCT | H_DCT - | V_ADST | H_ADST => (false, false), + | V_ADST | H_ADST | WHT_WHT => (false, false), FLIPADST_DCT | FLIPADST_ADST | V_FLIPADST => (true, false), DCT_FLIPADST | ADST_FLIPADST | H_FLIPADST => (false, true), FLIPADST_FLIPADST => (true, true), @@ -1728,7 +1739,6 @@ $($s)* fn daala_fdct64(coeffs: &mut [T]) { #[$m] $($s)* fn fidentity(_coeffs: &mut [T]) {} -#[allow(unused)] #[$m] $($s)* fn fwht4(coeffs: &mut [T]) { assert!(coeffs.len() >= 4); diff --git a/src/transform/inverse.rs b/src/transform/inverse.rs index e4ca4bbcde..870e517f37 100644 --- a/src/transform/inverse.rs +++ b/src/transform/inverse.rs @@ -33,7 +33,7 @@ use super::TxType; /// # Panics /// /// - If `input` or `output` have fewer than 4 items. -pub fn av1_iwht4(input: &[i32], output: &mut [i32]) { +pub fn av1_iwht4(input: &[i32], output: &mut [i32], _range: usize) { assert!(input.len() >= 4); assert!(output.len() >= 4); @@ -1591,7 +1591,7 @@ fn av1_idct64(input: &[i32], output: &mut [i32], range: usize) { type InvTxfmFn = fn(input: &[i32], output: &mut [i32], range: usize); -static INV_TXFM_FNS: [[InvTxfmFn; 5]; 4] = [ +static INV_TXFM_FNS: [[InvTxfmFn; 5]; 5] = [ [av1_idct4, av1_idct8, av1_idct16, av1_idct32, av1_idct64], [ av1_iadst4, @@ -1614,6 +1614,13 @@ static INV_TXFM_FNS: [[InvTxfmFn; 5]; 4] = [ av1_iidentity32, |_, _, _| unimplemented!(), ], + [ + av1_iwht4, + |_, _, _| unimplemented!(), + |_, _, _| unimplemented!(), + |_, _, _| unimplemented!(), + |_, _, _| unimplemented!(), + ], ]; pub(crate) mod rust { @@ -1624,52 +1631,6 @@ pub(crate) mod rust { use simd_helpers::cold_for_target_arch; use std::cmp; - #[cold_for_target_arch("x86_64", "aarch64")] - pub fn inverse_transform_add_lossless( - input: &[T::Coeff], output: &mut PlaneRegionMut<'_, T>, _eob: usize, - _bd: usize, _cpu: CpuFeatureLevel, - ) { - // - let input: &[T::Coeff] = &input[..4 * 4]; - let mut buffer = [0i32; 4 * 4]; - - // perform inv txfm on every row - for (r, buffer_slice) in buffer.chunks_exact_mut(4).enumerate() { - let mut temp_in: [i32; 4] = [0; 4]; - for (val, transposed) in input[r..] - .iter() - .map(|a| i32::cast_from(*a)) - .step_by(4) - .zip(temp_in.iter_mut()) - { - *transposed = val >> 2; - } - av1_iwht4(&temp_in, buffer_slice); - } - - // perform inv txfm on every col - for c in 0..4 { - let mut temp_in: [i32; 4] = [0; 4]; - let mut temp_out: [i32; 4] = [0; 4]; - for (val, transposed) in buffer[c..] - .iter() - .map(|a| i32::cast_from(*a)) - .step_by(4) - .zip(temp_in.iter_mut()) - { - *transposed = val; - } - av1_iwht4(&temp_in, &mut temp_out); - for (temp, out) in temp_out - .iter() - .zip(output.rows_iter_mut().map(|row| &mut row[c]).take(4)) - { - let v = i32::cast_from(*out) + *temp; - *out = T::cast_from(v); - } - } - } - #[cold_for_target_arch("x86_64", "aarch64")] pub fn inverse_transform_add( input: &[T::Coeff], output: &mut PlaneRegionMut<'_, T>, _eob: usize, @@ -1686,6 +1647,7 @@ pub(crate) mod rust { let mut buffer = vec![0i32; width * height].into_boxed_slice(); let rect_type = get_rect_tx_log_ratio(width, height); let tx_types_1d = get_1d_tx_types(tx_type); + let lossless = tx_type == TxType::WHT_WHT; // perform inv txfm on every row let range = bd + 8; @@ -1705,6 +1667,8 @@ pub(crate) mod rust { { let val = if rect_type.abs() == 1 { round_shift(raw * INV_SQRT2, SQRT2_BITS) + } else if lossless { + raw >> 2 } else { raw }; @@ -1733,7 +1697,8 @@ pub(crate) mod rust { .zip(output.rows_iter_mut().map(|row| &mut row[c]).take(height)) { let v: i32 = (*out).as_(); - let v = clamp(v + round_shift(*temp, 4), 0, (1 << bd) - 1); + let r = if lossless { *temp } else { round_shift(*temp, 4) }; + let v = clamp(v + r, 0, (1 << bd) - 1); *out = T::cast_from(v); } } diff --git a/src/transform/mod.rs b/src/transform/mod.rs index fda9c1d3d3..23d4619280 100644 --- a/src/transform/mod.rs +++ b/src/transform/mod.rs @@ -14,9 +14,7 @@ pub mod forward_shared; pub use self::forward::forward_transform; -pub use self::forward::forward_transform_lossless; pub use self::inverse::inverse_transform_add; -pub use self::inverse::inverse_transform_add_lossless; use crate::context::MI_SIZE_LOG2; use crate::partition::{BlockSize, BlockSize::*}; @@ -52,6 +50,7 @@ pub mod consts { } pub const TX_TYPES: usize = 16; +pub const TX_TYPES_PLUS_LL: usize = 17; #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord)] pub enum TxType { @@ -71,6 +70,7 @@ pub enum TxType { H_ADST = 13, V_FLIPADST = 14, H_FLIPADST = 15, + WHT_WHT = 16, } impl TxType { @@ -337,6 +337,7 @@ enum TxType1D { ADST, FLIPADST, IDTX, + WHT, } const fn get_1d_tx_types(tx_type: TxType) -> (TxType1D, TxType1D) { @@ -357,10 +358,11 @@ const fn get_1d_tx_types(tx_type: TxType) -> (TxType1D, TxType1D) { TxType::H_ADST => (TxType1D::IDTX, TxType1D::ADST), TxType::V_FLIPADST => (TxType1D::FLIPADST, TxType1D::IDTX), TxType::H_FLIPADST => (TxType1D::IDTX, TxType1D::FLIPADST), + TxType::WHT_WHT => (TxType1D::WHT, TxType1D::WHT), } } -const VTX_TAB: [TxType1D; TX_TYPES] = [ +const VTX_TAB: [TxType1D; TX_TYPES_PLUS_LL] = [ TxType1D::DCT, TxType1D::ADST, TxType1D::DCT, @@ -377,9 +379,10 @@ const VTX_TAB: [TxType1D; TX_TYPES] = [ TxType1D::IDTX, TxType1D::FLIPADST, TxType1D::IDTX, + TxType1D::WHT, ]; -const HTX_TAB: [TxType1D; TX_TYPES] = [ +const HTX_TAB: [TxType1D; TX_TYPES_PLUS_LL] = [ TxType1D::DCT, TxType1D::DCT, TxType1D::ADST, @@ -396,6 +399,7 @@ const HTX_TAB: [TxType1D; TX_TYPES] = [ TxType1D::ADST, TxType1D::IDTX, TxType1D::FLIPADST, + TxType1D::WHT, ]; #[inline] @@ -421,6 +425,26 @@ pub fn get_valid_txfm_types(tx_size: TxSize) -> &'static [TxType] { &[DCT_DCT] } else if size_sq == TxSize::TX_32X32 { &[DCT_DCT, IDTX] + } else if size_sq == TxSize::TX_4X4 { + &[ + DCT_DCT, + ADST_DCT, + DCT_ADST, + ADST_ADST, + FLIPADST_DCT, + DCT_FLIPADST, + FLIPADST_FLIPADST, + ADST_FLIPADST, + FLIPADST_ADST, + IDTX, + V_DCT, + H_DCT, + V_ADST, + H_ADST, + V_FLIPADST, + H_FLIPADST, + WHT_WHT, + ] } else { &[ DCT_DCT, @@ -495,30 +519,6 @@ mod test { } } - fn test_lossless_roundtrip() { - let cpu = CpuFeatureLevel::default(); - - let mut src_storage = [T::cast_from(0); 4 * 4]; - let src = &mut src_storage[..]; - // dynamic allocation: test - let mut dst = Plane::from_slice(&vec![T::cast_from(0); 4 * 4], 4); - let mut res_storage = [0i16; 4 * 4]; - let res = &mut res_storage[..]; - let mut freq_storage = [T::Coeff::cast_from(0); 4 * 4]; - let freq = &mut freq_storage[..4 * 4]; - for ((r, s), d) in - res.iter_mut().zip(src.iter_mut()).zip(dst.data.iter_mut()) - { - *s = T::cast_from(random::()); - *d = T::cast_from(random::()); - *r = i16::cast_from(*s) - i16::cast_from(*d); - } - forward_transform_lossless(res, freq, 4, cpu); - inverse_transform_add_lossless(freq, &mut dst.as_region_mut(), 15, 8, cpu); - - assert_eq!(&src[..], &dst.data[..]); - } - #[test] fn log_tx_ratios() { let combinations = [ @@ -557,6 +557,7 @@ mod test { fn roundtrips() { let combinations = [ + (TX_4X4, WHT_WHT, 0), (TX_4X4, DCT_DCT, 0), (TX_4X4, ADST_DCT, 0), (TX_4X4, DCT_ADST, 0), @@ -604,8 +605,6 @@ mod test { (TX_16X32, DCT_DCT, 2), (TX_32X16, DCT_DCT, 2), ]; - println!("Testing combination TX_4X4, WHT_WHT"); - test_lossless_roundtrip::(); for &(tx_size, tx_type, tolerance) in combinations.iter() { println!("Testing combination {:?}, {:?}", tx_size, tx_type); test_roundtrip::(tx_size, tx_type, tolerance);