From 7142689f00e73d461b8d00347ee84da2ee420994 Mon Sep 17 00:00:00 2001
From: Scott LaVarnway
Date: Wed, 12 Oct 2022 10:26:43 -0700
Subject: [PATCH] Add vpx_highbd_sad64x{64,32}_avg_avx2.

~2.8x faster than the sse2 version.

Bug: b/245917257
Change-Id: Ib727ba8a8c8fa4df450bafdde30ed99fd283f06d
---
 test/sad_test.cc              |  6 +++
 vpx_dsp/vpx_dsp_rtcd_defs.pl  |  4 +-
 vpx_dsp/x86/highbd_sad_avx2.c | 77 +++++++++++++++++++++++++++++++++++
 3 files changed, 85 insertions(+), 2 deletions(-)

diff --git a/test/sad_test.cc b/test/sad_test.cc
index a3c2952d6..29e3f57f5 100644
--- a/test/sad_test.cc
+++ b/test/sad_test.cc
@@ -1120,18 +1120,24 @@ const SadMxNAvgParam avg_avx2_tests[] = {
   SadMxNAvgParam(32, 32, &vpx_sad32x32_avg_avx2),
   SadMxNAvgParam(32, 16, &vpx_sad32x16_avg_avx2),
 #if CONFIG_VP9_HIGHBITDEPTH
+  SadMxNAvgParam(64, 64, &vpx_highbd_sad64x64_avg_avx2, 8),
+  SadMxNAvgParam(64, 32, &vpx_highbd_sad64x32_avg_avx2, 8),
   SadMxNAvgParam(32, 64, &vpx_highbd_sad32x64_avg_avx2, 8),
   SadMxNAvgParam(32, 32, &vpx_highbd_sad32x32_avg_avx2, 8),
   SadMxNAvgParam(32, 16, &vpx_highbd_sad32x16_avg_avx2, 8),
   SadMxNAvgParam(16, 32, &vpx_highbd_sad16x32_avg_avx2, 8),
   SadMxNAvgParam(16, 16, &vpx_highbd_sad16x16_avg_avx2, 8),
   SadMxNAvgParam(16, 8, &vpx_highbd_sad16x8_avg_avx2, 8),
+  SadMxNAvgParam(64, 64, &vpx_highbd_sad64x64_avg_avx2, 10),
+  SadMxNAvgParam(64, 32, &vpx_highbd_sad64x32_avg_avx2, 10),
   SadMxNAvgParam(32, 64, &vpx_highbd_sad32x64_avg_avx2, 10),
   SadMxNAvgParam(32, 32, &vpx_highbd_sad32x32_avg_avx2, 10),
   SadMxNAvgParam(32, 16, &vpx_highbd_sad32x16_avg_avx2, 10),
   SadMxNAvgParam(16, 32, &vpx_highbd_sad16x32_avg_avx2, 10),
   SadMxNAvgParam(16, 16, &vpx_highbd_sad16x16_avg_avx2, 10),
   SadMxNAvgParam(16, 8, &vpx_highbd_sad16x8_avg_avx2, 10),
+  SadMxNAvgParam(64, 64, &vpx_highbd_sad64x64_avg_avx2, 12),
+  SadMxNAvgParam(64, 32, &vpx_highbd_sad64x32_avg_avx2, 12),
   SadMxNAvgParam(32, 64, &vpx_highbd_sad32x64_avg_avx2, 12),
   SadMxNAvgParam(32, 32, &vpx_highbd_sad32x32_avg_avx2, 12),
   SadMxNAvgParam(32, 16, &vpx_highbd_sad32x16_avg_avx2, 12),
diff --git a/vpx_dsp/vpx_dsp_rtcd_defs.pl b/vpx_dsp/vpx_dsp_rtcd_defs.pl
index 4db6de37b..5fe9c1287 100644
--- a/vpx_dsp/vpx_dsp_rtcd_defs.pl
+++ b/vpx_dsp/vpx_dsp_rtcd_defs.pl
@@ -995,10 +995,10 @@ if (vpx_config("CONFIG_VP9_HIGHBITDEPTH") eq "yes") {
   add_proto qw/void vpx_highbd_minmax_8x8/, "const uint8_t *s8, int p, const uint8_t *d8, int dp, int *min, int *max";
 
   add_proto qw/unsigned int vpx_highbd_sad64x64_avg/, "const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, int ref_stride, const uint8_t *second_pred";
-  specialize qw/vpx_highbd_sad64x64_avg sse2 neon/;
+  specialize qw/vpx_highbd_sad64x64_avg sse2 neon avx2/;
 
   add_proto qw/unsigned int vpx_highbd_sad64x32_avg/, "const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, int ref_stride, const uint8_t *second_pred";
-  specialize qw/vpx_highbd_sad64x32_avg sse2 neon/;
+  specialize qw/vpx_highbd_sad64x32_avg sse2 neon avx2/;
 
   add_proto qw/unsigned int vpx_highbd_sad32x64_avg/, "const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, int ref_stride, const uint8_t *second_pred";
   specialize qw/vpx_highbd_sad32x64_avg sse2 neon avx2/;
diff --git a/vpx_dsp/x86/highbd_sad_avx2.c b/vpx_dsp/x86/highbd_sad_avx2.c
index 24ebe4e94..7533ccfdd 100644
--- a/vpx_dsp/x86/highbd_sad_avx2.c
+++ b/vpx_dsp/x86/highbd_sad_avx2.c
@@ -225,6 +225,83 @@ unsigned int vpx_highbd_sad16x8_avx2(const uint8_t *src8_ptr, int src_stride,
 }
 
 // AVG -------------------------------------------------------------------------
+static VPX_FORCE_INLINE void highbd_sad64xH_avg(__m256i *sums_16,
+                                                const uint16_t *src,
+                                                int src_stride, uint16_t *ref,
+                                                int ref_stride, uint16_t *sec,
+                                                int height) {
+  int i;
+  for (i = 0; i < height; ++i) {
+    // load one row each of src, ref, and sec
+    const __m256i s0 = _mm256_load_si256((const __m256i *)src);
+    const __m256i s1 = _mm256_load_si256((const __m256i *)(src + 16));
+    const __m256i s2 = _mm256_load_si256((const __m256i *)(src + 32));
+    const __m256i s3 = _mm256_load_si256((const __m256i *)(src + 48));
+    const __m256i r0 = _mm256_loadu_si256((const __m256i *)ref);
+    const __m256i r1 = _mm256_loadu_si256((const __m256i *)(ref + 16));
+    const __m256i r2 = _mm256_loadu_si256((const __m256i *)(ref + 32));
+    const __m256i r3 = _mm256_loadu_si256((const __m256i *)(ref + 48));
+    const __m256i x0 = _mm256_loadu_si256((const __m256i *)sec);
+    const __m256i x1 = _mm256_loadu_si256((const __m256i *)(sec + 16));
+    const __m256i x2 = _mm256_loadu_si256((const __m256i *)(sec + 32));
+    const __m256i x3 = _mm256_loadu_si256((const __m256i *)(sec + 48));
+    const __m256i avg0 = _mm256_avg_epu16(r0, x0);
+    const __m256i avg1 = _mm256_avg_epu16(r1, x1);
+    const __m256i avg2 = _mm256_avg_epu16(r2, x2);
+    const __m256i avg3 = _mm256_avg_epu16(r3, x3);
+    // absolute differences between every ref/pred avg and src
+    const __m256i abs_diff0 = _mm256_abs_epi16(_mm256_sub_epi16(avg0, s0));
+    const __m256i abs_diff1 = _mm256_abs_epi16(_mm256_sub_epi16(avg1, s1));
+    const __m256i abs_diff2 = _mm256_abs_epi16(_mm256_sub_epi16(avg2, s2));
+    const __m256i abs_diff3 = _mm256_abs_epi16(_mm256_sub_epi16(avg3, s3));
+    // sum every abs diff
+    *sums_16 =
+        _mm256_add_epi16(*sums_16, _mm256_add_epi16(abs_diff0, abs_diff1));
+    *sums_16 =
+        _mm256_add_epi16(*sums_16, _mm256_add_epi16(abs_diff2, abs_diff3));
+
+    src += src_stride;
+    ref += ref_stride;
+    sec += 64;
+  }
+}
+
+#define HIGHBD_SAD64XN_AVG(n)                                                 \
+  unsigned int vpx_highbd_sad64x##n##_avg_avx2(                               \
+      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,         \
+      int ref_stride, const uint8_t *second_pred) {                           \
+    const uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
+    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                             \
+    uint16_t *sec = CONVERT_TO_SHORTPTR(second_pred);                         \
+    __m256i sums_32 = _mm256_setzero_si256();                                 \
+    int i;                                                                    \
+                                                                              \
+    for (i = 0; i < (n / 2); ++i) {                                           \
+      __m256i sums_16 = _mm256_setzero_si256();                               \
+                                                                              \
+      highbd_sad64xH_avg(&sums_16, src, src_stride, ref, ref_stride, sec, 2); \
+                                                                              \
+      /* flush the 16-bit partial sums into sums_32 every 2 rows so the      \
+       * uint16_t lanes cannot overflow */                                   \
+      sums_32 = _mm256_add_epi32(                                             \
+          sums_32,                                                            \
+          _mm256_add_epi32(                                                   \
+              _mm256_cvtepu16_epi32(_mm256_castsi256_si128(sums_16)),         \
+              _mm256_cvtepu16_epi32(_mm256_extractf128_si256(sums_16, 1))));  \
+                                                                              \
+      src += src_stride << 1;                                                 \
+      ref += ref_stride << 1;                                                 \
+      sec += 64 << 1;                                                         \
+    }                                                                         \
+    return calc_final(sums_32);                                               \
+  }
+
+// 64x64
+HIGHBD_SAD64XN_AVG(64)
+
+// 64x32
+HIGHBD_SAD64XN_AVG(32)
+
 static VPX_FORCE_INLINE void highbd_sad32xH_avg(__m256i *sums_16,
                                                 const uint16_t *src,
                                                 int src_stride, uint16_t *ref,
-- 
2.40.0
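
For anyone cross-checking the intrinsics, here is a minimal scalar sketch
of the computation the new kernels vectorize; the model function name is
illustrative and not part of the patch, and _mm256_avg_epu16 rounds up,
i.e. (a + b + 1) >> 1, which the sketch mirrors:

  #include <stdlib.h> /* abs() */

  /* Hypothetical reference model of vpx_highbd_sad64x{64,32}_avg_avx2:
   * average each ref pixel with the matching second_pred pixel
   * (rounding up, as _mm256_avg_epu16 does), then accumulate the
   * absolute difference against the source pixel. */
  static unsigned int highbd_sad64xh_avg_model(const uint16_t *src,
                                               int src_stride,
                                               const uint16_t *ref,
                                               int ref_stride,
                                               const uint16_t *sec, int h) {
    unsigned int sad = 0;
    for (int i = 0; i < h; ++i) {
      for (int j = 0; j < 64; ++j) {
        const int avg = (ref[j] + sec[j] + 1) >> 1; /* comp-avg pixel */
        sad += abs((int)src[j] - avg);
      }
      src += src_stride;
      ref += ref_stride;
      sec += 64; /* second_pred is a contiguous 64-wide block */
    }
    return sad;
  }

The two-row flush in HIGHBD_SAD64XN_AVG follows from the lane math: each of
the 16 uint16_t lanes of sums_16 accumulates 64 / 16 = 4 absolute
differences per row, each at most 4095 at 12-bit depth, so a lane can hold
at most four rows' worth (4 * 4 * 4095 = 65520) before overflowing;
widening into the 32-bit sums_32 every two rows (2 * 4 * 4095 = 32760)
keeps a comfortable margin.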