From 165935a1b6c3dfe2af686545188c3abebc4941d8 Mon Sep 17 00:00:00 2001
From: Konstantinos Margaritis
Date: Thu, 6 Oct 2022 14:53:56 +0000
Subject: [PATCH] [NEON] Add highbd FDCT 4x4 function

~80% faster than C version for both best/rt profiles.

Change-Id: Ibb3c8e1862131d2a020922420d53c66b31d5c2c3
---
Notes (below the "---" marker, so not part of the commit message):
 - NOTE(review): the template arguments of fdct_wrapper / idct_wrapper /
   highbd_idct_wrapper and the static_cast target type in test/dct_test.cc
   were reconstructed. Confirm them, and the whitespace of all context
   lines, against the tree before applying.
 - const_1000 is now a scalar array loaded with vld1q_s32() instead of a
   brace-initialized int32x4_t.
 - highbd_butterfly_one_coeff_s32() takes tran_coef_t instead of
   tran_high_t, matching highbd_butterfly_two_coeff_s32().

 test/dct_test.cc             | 21 ++++++++--
 vpx_dsp/arm/fdct4x4_neon.c   | 45 ++++++++++++++++++++
 vpx_dsp/arm/fdct_neon.h      | 92 ++++++++++++++++++++++++++++++++++++
 vpx_dsp/vpx_dsp_rtcd_defs.pl |  2 +-
 4 files changed, 156 insertions(+), 4 deletions(-)

diff --git a/test/dct_test.cc b/test/dct_test.cc
index 2182f87e5..e34122ac9 100644
--- a/test/dct_test.cc
+++ b/test/dct_test.cc
@@ -539,6 +539,18 @@ INSTANTIATE_TEST_SUITE_P(AVX2, TransDCT,
 #endif  // HAVE_AVX2 && !CONFIG_VP9_HIGHBITDEPTH
 
 #if HAVE_NEON
+#if CONFIG_VP9_HIGHBITDEPTH
+static const FuncInfo dct_neon_func_info[] = {
+  { &fdct_wrapper<vpx_highbd_fdct4x4_neon>,
+    &highbd_idct_wrapper<vpx_highbd_idct4x4_16_add_neon>, 4, 2 },
+  /*  { &fdct_wrapper<vpx_highbd_fdct8x8_neon>,
+      &highbd_idct_wrapper<vpx_highbd_idct8x8_64_add_neon>, 8, 2 },
+    { &fdct_wrapper<vpx_highbd_fdct16x16_neon>,
+      &highbd_idct_wrapper<vpx_highbd_idct16x16_256_add_neon>, 16, 2 },
+    { &fdct_wrapper<vpx_highbd_fdct32x32_neon>,
+      &highbd_idct_wrapper<vpx_highbd_idct32x32_1024_add_neon>, 32, 2 },*/
+};
+#else
 static const FuncInfo dct_neon_func_info[4] = {
   { &fdct_wrapper<vpx_fdct4x4_neon>, &idct_wrapper<vpx_idct4x4_16_add_neon>, 4,
     1 },
@@ -549,12 +561,15 @@ static const FuncInfo dct_neon_func_info[4] = {
   { &fdct_wrapper<vpx_fdct32x32_neon>,
     &idct_wrapper<vpx_idct32x32_1024_add_neon>, 32, 1 }
 };
+#endif  // CONFIG_VP9_HIGHBITDEPTH
 
 INSTANTIATE_TEST_SUITE_P(
     NEON, TransDCT,
-    ::testing::Combine(::testing::Range(0, 4),
-                       ::testing::Values(dct_neon_func_info),
-                       ::testing::Values(0), ::testing::Values(VPX_BITS_8)));
+    ::testing::Combine(
+        ::testing::Range(0, static_cast<int>(sizeof(dct_neon_func_info) /
+                                             sizeof(dct_neon_func_info[0]))),
+        ::testing::Values(dct_neon_func_info), ::testing::Values(0),
+        ::testing::Values(VPX_BITS_8, VPX_BITS_10, VPX_BITS_12)));
 #endif  // HAVE_NEON
 
 #if HAVE_MSA && !CONFIG_VP9_HIGHBITDEPTH
diff --git a/vpx_dsp/arm/fdct4x4_neon.c b/vpx_dsp/arm/fdct4x4_neon.c
index 2827791f1..11df7292d 100644
--- a/vpx_dsp/arm/fdct4x4_neon.c
+++ b/vpx_dsp/arm/fdct4x4_neon.c
@@ -48,3 +48,48 @@ void vpx_fdct4x4_neon(const int16_t *input, tran_low_t *final_output,
     store_s16q_to_tran_low(final_output + 1 * 8, out_23);
   }
 }
+
+#if CONFIG_VP9_HIGHBITDEPTH
+
+// High-bitdepth 4x4 forward DCT, NEON version.
+//
+// Reads four rows of four int16_t samples starting at `input` (consecutive
+// rows are `stride` elements apart), applies vpx_highbd_fdct4x4_pass1_neon()
+// twice, computes (x + 1) >> 2 on every lane, and writes the 16 resulting
+// coefficients contiguously to `final_output`.
+void vpx_highbd_fdct4x4_neon(const int16_t *input, tran_low_t *final_output,
+                             int stride) {
+  // Lane 0 is 1, the other lanes are 0. Stored as a scalar array and loaded
+  // with vld1q_s32(): brace-initializing a NEON vector type directly is not
+  // portable across compilers.
+  static const int32_t k1000[4] = { 1, 0, 0, 0 };
+  const int32x4_t const_one = vdupq_n_s32(1);
+
+  // input[M * stride] * 16
+  int32x4_t in[4];
+  in[0] = vshll_n_s16(vld1_s16(input + 0 * stride), 4);
+  in[1] = vshll_n_s16(vld1_s16(input + 1 * stride), 4);
+  in[2] = vshll_n_s16(vld1_s16(input + 2 * stride), 4);
+  in[3] = vshll_n_s16(vld1_s16(input + 3 * stride), 4);
+
+  // If the very first value != 0, then add 1.
+  if (input[0] != 0) {
+    in[0] = vaddq_s32(in[0], vld1q_s32(k1000));
+  }
+
+  // Each call performs one 1-D transform followed by a transpose.
+  vpx_highbd_fdct4x4_pass1_neon(in);
+  vpx_highbd_fdct4x4_pass1_neon(in);
+
+  // Not quite a rounding shift. Only add 1 despite shifting by 2.
+  in[0] = vshrq_n_s32(vaddq_s32(in[0], const_one), 2);
+  in[1] = vshrq_n_s32(vaddq_s32(in[1], const_one), 2);
+  in[2] = vshrq_n_s32(vaddq_s32(in[2], const_one), 2);
+  in[3] = vshrq_n_s32(vaddq_s32(in[3], const_one), 2);
+
+  vst1q_s32(final_output, in[0]);
+  vst1q_s32(final_output + 4, in[1]);
+  vst1q_s32(final_output + 8, in[2]);
+  vst1q_s32(final_output + 12, in[3]);
+}
+#endif  // CONFIG_VP9_HIGHBITDEPTH
diff --git a/vpx_dsp/arm/fdct_neon.h b/vpx_dsp/arm/fdct_neon.h
index 056cae408..68aeab3aa 100644
--- a/vpx_dsp/arm/fdct_neon.h
+++ b/vpx_dsp/arm/fdct_neon.h
@@ -340,4 +340,96 @@ static INLINE void vpx_fdct8x8_pass1_neon(int16x8_t *in) {
     // 07 17 27 37 47 57 67 77
   }
 }
+
+#if CONFIG_VP9_HIGHBITDEPTH
+// Single-coefficient butterfly on 32-bit lanes:
+//   *add = round_shift((a + b) * c)
+//   *sub = round_shift((a - b) * c)
+// Products are accumulated in 64 bits, then narrowed back to 32 bits with a
+// rounding right shift by DCT_CONST_BITS.
+//
+// `c` is tran_coef_t, the same coefficient type
+// highbd_butterfly_two_coeff_s32() passes to these intrinsics. Their scalar
+// operand is int32_t, so a wider parameter type would be narrowed implicitly.
+static INLINE void highbd_butterfly_one_coeff_s32(const int32x4_t a,
+                                                  const int32x4_t b,
+                                                  const tran_coef_t c,
+                                                  int32x4_t *add,
+                                                  int32x4_t *sub) {
+  const int32x2_t a_lo = vget_low_s32(a);
+  const int32x2_t a_hi = vget_high_s32(a);
+  const int32x2_t b_lo = vget_low_s32(b);
+  const int32x2_t b_hi = vget_high_s32(b);
+
+  const int64x2_t a64_lo = vmull_n_s32(a_lo, c);
+  const int64x2_t a64_hi = vmull_n_s32(a_hi, c);
+
+  const int64x2_t sum_lo = vmlal_n_s32(a64_lo, b_lo, c);
+  const int64x2_t sum_hi = vmlal_n_s32(a64_hi, b_hi, c);
+  const int64x2_t diff_lo = vmlsl_n_s32(a64_lo, b_lo, c);
+  const int64x2_t diff_hi = vmlsl_n_s32(a64_hi, b_hi, c);
+
+  *add = vcombine_s32(vrshrn_n_s64(sum_lo, DCT_CONST_BITS),
+                      vrshrn_n_s64(sum_hi, DCT_CONST_BITS));
+  *sub = vcombine_s32(vrshrn_n_s64(diff_lo, DCT_CONST_BITS),
+                      vrshrn_n_s64(diff_hi, DCT_CONST_BITS));
+}
+
+// Two-coefficient butterfly on 32-bit lanes:
+//   *add = round_shift(a * c0 + b * c1)
+//   *sub = round_shift(a * c1 - b * c0)
+// Products are accumulated in 64 bits, then narrowed back to 32 bits with a
+// rounding right shift by DCT_CONST_BITS.
+static INLINE void highbd_butterfly_two_coeff_s32(
+    const int32x4_t a, const int32x4_t b, const tran_coef_t c0,
+    const tran_coef_t c1, int32x4_t *add, int32x4_t *sub) {
+  const int32x2_t a_lo = vget_low_s32(a);
+  const int32x2_t a_hi = vget_high_s32(a);
+  const int32x2_t b_lo = vget_low_s32(b);
+  const int32x2_t b_hi = vget_high_s32(b);
+
+  const int64x2_t axc0_64_lo = vmull_n_s32(a_lo, c0);
+  const int64x2_t axc0_64_hi = vmull_n_s32(a_hi, c0);
+  const int64x2_t axc1_64_lo = vmull_n_s32(a_lo, c1);
+  const int64x2_t axc1_64_hi = vmull_n_s32(a_hi, c1);
+
+  const int64x2_t sum_lo = vmlal_n_s32(axc0_64_lo, b_lo, c1);
+  const int64x2_t sum_hi = vmlal_n_s32(axc0_64_hi, b_hi, c1);
+  const int64x2_t diff_lo = vmlsl_n_s32(axc1_64_lo, b_lo, c0);
+  const int64x2_t diff_hi = vmlsl_n_s32(axc1_64_hi, b_hi, c0);
+
+  *add = vcombine_s32(vrshrn_n_s64(sum_lo, DCT_CONST_BITS),
+                      vrshrn_n_s64(sum_hi, DCT_CONST_BITS));
+  *sub = vcombine_s32(vrshrn_n_s64(diff_lo, DCT_CONST_BITS),
+                      vrshrn_n_s64(diff_hi, DCT_CONST_BITS));
+}
+
+// One 1-D pass of the high-bitdepth 4x4 forward DCT over the four vectors in
+// `in`, followed by a 4x4 transpose. The result overwrites `in`.
+static INLINE void vpx_highbd_fdct4x4_pass1_neon(int32x4_t *in) {
+  int32x4_t out[4];
+  // in_0 +/- in_3, in_1 +/- in_2
+  const int32x4_t s_0 = vaddq_s32(in[0], in[3]);
+  const int32x4_t s_1 = vaddq_s32(in[1], in[2]);
+  const int32x4_t s_2 = vsubq_s32(in[1], in[2]);
+  const int32x4_t s_3 = vsubq_s32(in[0], in[3]);
+
+  // out[0] = (s_0 + s_1) * cospi_16_64
+  // out[2] = (s_0 - s_1) * cospi_16_64
+  highbd_butterfly_one_coeff_s32(s_0, s_1, cospi_16_64, &out[0], &out[2]);
+
+  // out[1] = s_3 * cospi_8_64 + s_2 * cospi_24_64
+  // out[3] = s_3 * cospi_24_64 - s_2 * cospi_8_64
+  highbd_butterfly_two_coeff_s32(s_3, s_2, cospi_8_64, cospi_24_64, &out[1],
+                                 &out[3]);
+
+  transpose_s32_4x4(&out[0], &out[1], &out[2], &out[3]);
+
+  in[0] = out[0];
+  in[1] = out[1];
+  in[2] = out[2];
+  in[3] = out[3];
+}
+
+#endif  // CONFIG_VP9_HIGHBITDEPTH
 #endif  // VPX_VPX_DSP_ARM_FDCT_NEON_H_
diff --git a/vpx_dsp/vpx_dsp_rtcd_defs.pl b/vpx_dsp/vpx_dsp_rtcd_defs.pl
index cbf0e6ea8..c5514b14d 100644
--- a/vpx_dsp/vpx_dsp_rtcd_defs.pl
+++ b/vpx_dsp/vpx_dsp_rtcd_defs.pl
@@ -552,7 +552,7 @@ if (vpx_config("CONFIG_VP9_HIGHBITDEPTH") eq "yes") {
   specialize qw/vpx_fdct32x32_1 sse2 neon/;
 
   add_proto qw/void vpx_highbd_fdct4x4/, "const int16_t *input, tran_low_t *output, int stride";
-  specialize qw/vpx_highbd_fdct4x4 sse2/;
+  specialize qw/vpx_highbd_fdct4x4 sse2 neon/;
 
   add_proto qw/void vpx_highbd_fdct8x8/, "const int16_t *input, tran_low_t *output, int stride";
   specialize qw/vpx_highbd_fdct8x8 sse2/;
-- 
2.49.0