From 0f563e5fadbccb10fabd6ac80c256a4321401e22 Mon Sep 17 00:00:00 2001 From: Jonathan Wright Date: Fri, 7 May 2021 13:25:51 +0100 Subject: [PATCH] Optimize Neon reductions in sum_neon.h using ADDV instruction Use the AArch64-only ADDV and ADDLV instructions to accelerate reductions that add across a Neon vector in sum_neon.h. This commit also refactors the inline functions to return a scalar instead of a vector - allowing for optimization of the surrounding code at each call site. Bug: b/181236880 Change-Id: Ieed2a2dd3c74f8a52957bf404141ffc044bd5d79 --- vpx_dsp/arm/avg_neon.c | 11 +++-------- vpx_dsp/arm/fdct_partial_neon.c | 32 +++++++++++--------------------- vpx_dsp/arm/sad_neon.c | 24 ++++++++++++------------ vpx_dsp/arm/sum_neon.h | 33 ++++++++++++++++++++++++--------- vpx_dsp/arm/variance_neon.c | 21 +++++++++------------ 5 files changed, 59 insertions(+), 62 deletions(-) diff --git a/vpx_dsp/arm/avg_neon.c b/vpx_dsp/arm/avg_neon.c index fa7dd0960..8e57bdaa5 100644 --- a/vpx_dsp/arm/avg_neon.c +++ b/vpx_dsp/arm/avg_neon.c @@ -22,15 +22,13 @@ uint32_t vpx_avg_4x4_neon(const uint8_t *a, int a_stride) { const uint8x16_t b = load_unaligned_u8q(a, a_stride); const uint16x8_t c = vaddl_u8(vget_low_u8(b), vget_high_u8(b)); - const uint32x2_t d = horizontal_add_uint16x8(c); - return vget_lane_u32(vrshr_n_u32(d, 4), 0); + return (horizontal_add_uint16x8(c) + (1 << 3)) >> 4; } uint32_t vpx_avg_8x8_neon(const uint8_t *a, int a_stride) { int i; uint8x8_t b, c; uint16x8_t sum; - uint32x2_t d; b = vld1_u8(a); a += a_stride; c = vld1_u8(a); @@ -43,9 +41,7 @@ uint32_t vpx_avg_8x8_neon(const uint8_t *a, int a_stride) { sum = vaddw_u8(sum, d); } - d = horizontal_add_uint16x8(sum); - - return vget_lane_u32(vrshr_n_u32(d, 6), 0); + return (horizontal_add_uint16x8(sum) + (1 << 5)) >> 6; } // coeff: 16 bits, dynamic range [-32640, 32640]. 
@@ -139,8 +135,7 @@ int16_t vpx_int_pro_col_neon(uint8_t const *ref, const int width) { ref += 16; } - return vget_lane_s16(vreinterpret_s16_u32(horizontal_add_uint16x8(vec_sum)), - 0); + return (int16_t)horizontal_add_uint16x8(vec_sum); } // ref, src = [0, 510] - max diff = 16-bits diff --git a/vpx_dsp/arm/fdct_partial_neon.c b/vpx_dsp/arm/fdct_partial_neon.c index e73de41d7..0a1cdca41 100644 --- a/vpx_dsp/arm/fdct_partial_neon.c +++ b/vpx_dsp/arm/fdct_partial_neon.c @@ -15,19 +15,10 @@ #include "vpx_dsp/arm/mem_neon.h" #include "vpx_dsp/arm/sum_neon.h" -static INLINE tran_low_t get_lane(const int32x2_t a) { -#if CONFIG_VP9_HIGHBITDEPTH - return vget_lane_s32(a, 0); -#else - return vget_lane_s16(vreinterpret_s16_s32(a), 0); -#endif // CONFIG_VP9_HIGHBITDETPH -} - void vpx_fdct4x4_1_neon(const int16_t *input, tran_low_t *output, int stride) { int16x4_t a0, a1, a2, a3; int16x8_t b0, b1; int16x8_t c; - int32x2_t d; a0 = vld1_s16(input); input += stride; @@ -42,9 +33,7 @@ void vpx_fdct4x4_1_neon(const int16_t *input, tran_low_t *output, int stride) { c = vaddq_s16(b0, b1); - d = horizontal_add_int16x8(c); - - output[0] = get_lane(vshl_n_s32(d, 1)); + output[0] = (tran_low_t)(horizontal_add_int16x8(c) << 1); output[1] = 0; } @@ -57,7 +46,7 @@ void vpx_fdct8x8_1_neon(const int16_t *input, tran_low_t *output, int stride) { sum = vaddq_s16(sum, input_00); } - output[0] = get_lane(horizontal_add_int16x8(sum)); + output[0] = (tran_low_t)horizontal_add_int16x8(sum); output[1] = 0; } @@ -66,7 +55,7 @@ void vpx_fdct16x16_1_neon(const int16_t *input, tran_low_t *output, int r; int16x8_t left = vld1q_s16(input); int16x8_t right = vld1q_s16(input + 8); - int32x2_t sum; + int32_t sum; input += stride; for (r = 1; r < 16; ++r) { @@ -77,9 +66,9 @@ void vpx_fdct16x16_1_neon(const int16_t *input, tran_low_t *output, right = vaddq_s16(right, b); } - sum = vadd_s32(horizontal_add_int16x8(left), horizontal_add_int16x8(right)); + sum = horizontal_add_int16x8(left) + horizontal_add_int16x8(right); - output[0] = get_lane(vshr_n_s32(sum, 1)); + output[0] = (tran_low_t)(sum >> 1); output[1] = 0; } @@ -90,7 +79,7 @@ void vpx_fdct32x32_1_neon(const int16_t *input, tran_low_t *output, int16x8_t a1 = vld1q_s16(input + 8); int16x8_t a2 = vld1q_s16(input + 16); int16x8_t a3 = vld1q_s16(input + 24); - int32x2_t sum; + int32_t sum; input += stride; for (r = 1; r < 32; ++r) { @@ -105,9 +94,10 @@ void vpx_fdct32x32_1_neon(const int16_t *input, tran_low_t *output, a3 = vaddq_s16(a3, b3); } - sum = vadd_s32(horizontal_add_int16x8(a0), horizontal_add_int16x8(a1)); - sum = vadd_s32(sum, horizontal_add_int16x8(a2)); - sum = vadd_s32(sum, horizontal_add_int16x8(a3)); - output[0] = get_lane(vshr_n_s32(sum, 3)); + sum = horizontal_add_int16x8(a0); + sum += horizontal_add_int16x8(a1); + sum += horizontal_add_int16x8(a2); + sum += horizontal_add_int16x8(a3); + output[0] = (tran_low_t)(sum >> 3); output[1] = 0; } diff --git a/vpx_dsp/arm/sad_neon.c b/vpx_dsp/arm/sad_neon.c index c4a49e366..59567bda5 100644 --- a/vpx_dsp/arm/sad_neon.c +++ b/vpx_dsp/arm/sad_neon.c @@ -23,7 +23,7 @@ uint32_t vpx_sad4x4_neon(const uint8_t *src_ptr, int src_stride, const uint8x16_t ref_u8 = load_unaligned_u8q(ref_ptr, ref_stride); uint16x8_t abs = vabdl_u8(vget_low_u8(src_u8), vget_low_u8(ref_u8)); abs = vabal_u8(abs, vget_high_u8(src_u8), vget_high_u8(ref_u8)); - return vget_lane_u32(horizontal_add_uint16x8(abs), 0); + return horizontal_add_uint16x8(abs); } uint32_t vpx_sad4x4_avg_neon(const uint8_t *src_ptr, int src_stride, @@ -35,7 +35,7 @@ uint32_t 
vpx_sad4x4_avg_neon(const uint8_t *src_ptr, int src_stride, const uint8x16_t avg = vrhaddq_u8(ref_u8, second_pred_u8); uint16x8_t abs = vabdl_u8(vget_low_u8(src_u8), vget_low_u8(avg)); abs = vabal_u8(abs, vget_high_u8(src_u8), vget_high_u8(avg)); - return vget_lane_u32(horizontal_add_uint16x8(abs), 0); + return horizontal_add_uint16x8(abs); } uint32_t vpx_sad4x8_neon(const uint8_t *src_ptr, int src_stride, @@ -51,7 +51,7 @@ uint32_t vpx_sad4x8_neon(const uint8_t *src_ptr, int src_stride, abs = vabal_u8(abs, vget_high_u8(src_u8), vget_high_u8(ref_u8)); } - return vget_lane_u32(horizontal_add_uint16x8(abs), 0); + return horizontal_add_uint16x8(abs); } uint32_t vpx_sad4x8_avg_neon(const uint8_t *src_ptr, int src_stride, @@ -71,7 +71,7 @@ uint32_t vpx_sad4x8_avg_neon(const uint8_t *src_ptr, int src_stride, abs = vabal_u8(abs, vget_high_u8(src_u8), vget_high_u8(avg)); } - return vget_lane_u32(horizontal_add_uint16x8(abs), 0); + return horizontal_add_uint16x8(abs); } static INLINE uint16x8_t sad8x(const uint8_t *src_ptr, int src_stride, @@ -114,7 +114,7 @@ static INLINE uint16x8_t sad8x_avg(const uint8_t *src_ptr, int src_stride, uint32_t vpx_sad8x##n##_neon(const uint8_t *src_ptr, int src_stride, \ const uint8_t *ref_ptr, int ref_stride) { \ const uint16x8_t abs = sad8x(src_ptr, src_stride, ref_ptr, ref_stride, n); \ - return vget_lane_u32(horizontal_add_uint16x8(abs), 0); \ + return horizontal_add_uint16x8(abs); \ } \ \ uint32_t vpx_sad8x##n##_avg_neon(const uint8_t *src_ptr, int src_stride, \ @@ -122,7 +122,7 @@ static INLINE uint16x8_t sad8x_avg(const uint8_t *src_ptr, int src_stride, const uint8_t *second_pred) { \ const uint16x8_t abs = \ sad8x_avg(src_ptr, src_stride, ref_ptr, ref_stride, second_pred, n); \ - return vget_lane_u32(horizontal_add_uint16x8(abs), 0); \ + return horizontal_add_uint16x8(abs); \ } sad8xN(4); @@ -172,7 +172,7 @@ static INLINE uint16x8_t sad16x_avg(const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, int ref_stride) { \ const uint16x8_t abs = \ sad16x(src_ptr, src_stride, ref_ptr, ref_stride, n); \ - return vget_lane_u32(horizontal_add_uint16x8(abs), 0); \ + return horizontal_add_uint16x8(abs); \ } \ \ uint32_t vpx_sad16x##n##_avg_neon(const uint8_t *src_ptr, int src_stride, \ @@ -180,7 +180,7 @@ static INLINE uint16x8_t sad16x_avg(const uint8_t *src_ptr, int src_stride, const uint8_t *second_pred) { \ const uint16x8_t abs = \ sad16x_avg(src_ptr, src_stride, ref_ptr, ref_stride, second_pred, n); \ - return vget_lane_u32(horizontal_add_uint16x8(abs), 0); \ + return horizontal_add_uint16x8(abs); \ } sad16xN(8); @@ -240,7 +240,7 @@ static INLINE uint16x8_t sad32x_avg(const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, int ref_stride) { \ const uint16x8_t abs = \ sad32x(src_ptr, src_stride, ref_ptr, ref_stride, n); \ - return vget_lane_u32(horizontal_add_uint16x8(abs), 0); \ + return horizontal_add_uint16x8(abs); \ } \ \ uint32_t vpx_sad32x##n##_avg_neon(const uint8_t *src_ptr, int src_stride, \ @@ -248,7 +248,7 @@ static INLINE uint16x8_t sad32x_avg(const uint8_t *src_ptr, int src_stride, const uint8_t *second_pred) { \ const uint16x8_t abs = \ sad32x_avg(src_ptr, src_stride, ref_ptr, ref_stride, second_pred, n); \ - return vget_lane_u32(horizontal_add_uint16x8(abs), 0); \ + return horizontal_add_uint16x8(abs); \ } sad32xN(16); @@ -338,7 +338,7 @@ static INLINE uint32x4_t sad64x_avg(const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, int ref_stride) { \ const uint32x4_t abs = \ sad64x(src_ptr, src_stride, ref_ptr, ref_stride, n); \ 
- return vget_lane_u32(horizontal_add_uint32x4(abs), 0); \ + return horizontal_add_uint32x4(abs); \ } \ \ uint32_t vpx_sad64x##n##_avg_neon(const uint8_t *src_ptr, int src_stride, \ @@ -346,7 +346,7 @@ static INLINE uint32x4_t sad64x_avg(const uint8_t *src_ptr, int src_stride, const uint8_t *second_pred) { \ const uint32x4_t abs = \ sad64x_avg(src_ptr, src_stride, ref_ptr, ref_stride, second_pred, n); \ - return vget_lane_u32(horizontal_add_uint32x4(abs), 0); \ + return horizontal_add_uint32x4(abs); \ } sad64xN(32); diff --git a/vpx_dsp/arm/sum_neon.h b/vpx_dsp/arm/sum_neon.h index 9e6833aad..630296237 100644 --- a/vpx_dsp/arm/sum_neon.h +++ b/vpx_dsp/arm/sum_neon.h @@ -16,23 +16,38 @@ #include "./vpx_config.h" #include "vpx/vpx_integer.h" -static INLINE int32x2_t horizontal_add_int16x8(const int16x8_t a) { +static INLINE int32_t horizontal_add_int16x8(const int16x8_t a) { +#if defined(__aarch64__) + return vaddlvq_s16(a); +#else const int32x4_t b = vpaddlq_s16(a); const int64x2_t c = vpaddlq_s32(b); - return vadd_s32(vreinterpret_s32_s64(vget_low_s64(c)), - vreinterpret_s32_s64(vget_high_s64(c))); + const int32x2_t d = vadd_s32(vreinterpret_s32_s64(vget_low_s64(c)), + vreinterpret_s32_s64(vget_high_s64(c))); + return vget_lane_s32(d, 0); +#endif } -static INLINE uint32x2_t horizontal_add_uint16x8(const uint16x8_t a) { +static INLINE uint32_t horizontal_add_uint16x8(const uint16x8_t a) { +#if defined(__aarch64__) + return vaddlvq_u16(a); +#else const uint32x4_t b = vpaddlq_u16(a); const uint64x2_t c = vpaddlq_u32(b); - return vadd_u32(vreinterpret_u32_u64(vget_low_u64(c)), - vreinterpret_u32_u64(vget_high_u64(c))); + const uint32x2_t d = vadd_u32(vreinterpret_u32_u64(vget_low_u64(c)), + vreinterpret_u32_u64(vget_high_u64(c))); + return vget_lane_u32(d, 0); +#endif } -static INLINE uint32x2_t horizontal_add_uint32x4(const uint32x4_t a) { +static INLINE uint32_t horizontal_add_uint32x4(const uint32x4_t a) { +#if defined(__aarch64__) + return vaddvq_u32(a); +#else const uint64x2_t b = vpaddlq_u32(a); - return vadd_u32(vreinterpret_u32_u64(vget_low_u64(b)), - vreinterpret_u32_u64(vget_high_u64(b))); + const uint32x2_t c = vadd_u32(vreinterpret_u32_u64(vget_low_u64(b)), + vreinterpret_u32_u64(vget_high_u64(b))); + return vget_lane_u32(c, 0); +#endif } #endif // VPX_VPX_DSP_ARM_SUM_NEON_H_ diff --git a/vpx_dsp/arm/variance_neon.c b/vpx_dsp/arm/variance_neon.c index 77b1015b7..e08f31f2c 100644 --- a/vpx_dsp/arm/variance_neon.c +++ b/vpx_dsp/arm/variance_neon.c @@ -66,10 +66,9 @@ static void variance_neon_w4x4(const uint8_t *src_ptr, int src_stride, ref_ptr += 4 * ref_stride; } - *sum = vget_lane_s32(horizontal_add_int16x8(sum_s16), 0); - *sse = vget_lane_u32(horizontal_add_uint32x4(vreinterpretq_u32_s32( - vaddq_s32(sse_lo_s32, sse_hi_s32))), - 0); + *sum = horizontal_add_int16x8(sum_s16); + *sse = horizontal_add_uint32x4( + vreinterpretq_u32_s32(vaddq_s32(sse_lo_s32, sse_hi_s32))); } // Process a block of any size where the width is divisible by 16. @@ -115,10 +114,9 @@ static void variance_neon_w16(const uint8_t *src_ptr, int src_stride, ref_ptr += ref_stride; } - *sum = vget_lane_s32(horizontal_add_int16x8(sum_s16), 0); - *sse = vget_lane_u32(horizontal_add_uint32x4(vreinterpretq_u32_s32( - vaddq_s32(sse_lo_s32, sse_hi_s32))), - 0); + *sum = horizontal_add_int16x8(sum_s16); + *sse = horizontal_add_uint32x4( + vreinterpretq_u32_s32(vaddq_s32(sse_lo_s32, sse_hi_s32))); } // Process a block of width 8 two rows at a time. 
@@ -157,10 +155,9 @@ static void variance_neon_w8x2(const uint8_t *src_ptr, int src_stride, i += 2; } while (i < h); - *sum = vget_lane_s32(horizontal_add_int16x8(sum_s16), 0); - *sse = vget_lane_u32(horizontal_add_uint32x4(vreinterpretq_u32_s32( - vaddq_s32(sse_lo_s32, sse_hi_s32))), - 0); + *sum = horizontal_add_int16x8(sum_s16); + *sse = horizontal_add_uint32x4( + vreinterpretq_u32_s32(vaddq_s32(sse_lo_s32, sse_hi_s32))); } void vpx_get8x8var_neon(const uint8_t *src_ptr, int src_stride, -- 2.40.0
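
Note appended after the patch (not part of the change itself): the small program below isolates the reduction pattern this commit adds to sum_neon.h so it can be compiled and sanity-checked on its own. The function body mirrors the patched horizontal_add_uint16x8(); the main() harness and the test values are illustrative additions, assuming a toolchain that provides arm_neon.h.

    /*
     * Standalone sketch of the sum_neon.h reduction pattern. On AArch64 a
     * single UADDLV (via vaddlvq_u16) widens and sums all eight 16-bit lanes
     * into one 32-bit scalar; the fallback keeps the widening pairwise-add
     * sequence used on 32-bit Arm.
     */
    #include <arm_neon.h>
    #include <stdint.h>
    #include <stdio.h>

    static inline uint32_t horizontal_add_uint16x8(const uint16x8_t a) {
    #if defined(__aarch64__)
      /* ADDLV: one instruction reduces the whole vector to a scalar. */
      return vaddlvq_u16(a);
    #else
      /* AArch32 fallback: pairwise long adds, then combine the two halves. */
      const uint32x4_t b = vpaddlq_u16(a);
      const uint64x2_t c = vpaddlq_u32(b);
      const uint32x2_t d = vadd_u32(vreinterpret_u32_u64(vget_low_u64(c)),
                                    vreinterpret_u32_u64(vget_high_u64(c)));
      return vget_lane_u32(d, 0);
    #endif
    }

    int main(void) {
      const uint16_t values[8] = { 1, 2, 3, 4, 5, 6, 7, 8 };
      const uint16x8_t v = vld1q_u16(values);
      printf("%u\n", horizontal_add_uint16x8(v)); /* expected: 36 */
      return 0;
    }

Returning a scalar rather than a uint32x2_t is what enables the simplifications at the call sites above: for example, the avg_neon.c hunks replace the vector rounding shift vrshr_n_u32() with plain integer arithmetic such as (sum + (1 << 5)) >> 6 in general-purpose registers.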