From 39de45d3cc178012cd5e8e17a2f9bcae2f858299 Mon Sep 17 00:00:00 2001 From: Linfeng Zhang Date: Mon, 26 Mar 2018 13:06:09 -0700 Subject: [PATCH] Update sad4d x86 functions Speed change is marginal. Change-Id: I4d548e9763ce43bd546f19132202f7a8509a32bf --- test/sad_test.cc | 4 +- vpx_dsp/x86/sad4d_avx2.c | 216 +++++++++++++++++---------------------- 2 files changed, 93 insertions(+), 127 deletions(-) diff --git a/test/sad_test.cc b/test/sad_test.cc index b93ade328..620cd3a2f 100644 --- a/test/sad_test.cc +++ b/test/sad_test.cc @@ -85,7 +85,7 @@ class SADTestBase : public ::testing::TestWithParam { #endif // CONFIG_VP9_HIGHBITDEPTH } mask_ = (1 << bit_depth_) - 1; - source_stride_ = (params_.width + 31) & ~31; + source_stride_ = (params_.width + 63) & ~63; reference_stride_ = params_.width * 2; rnd_.Reset(ACMRandom::DeterministicSeed()); } @@ -109,7 +109,7 @@ class SADTestBase : public ::testing::TestWithParam { protected: // Handle blocks up to 4 blocks 64x64 with stride up to 128 - static const int kDataAlignment = 16; + static const int kDataAlignment = 32; static const int kDataBlockSize = 64 * 128; static const int kDataBufferSize = 4 * kDataBlockSize; diff --git a/vpx_dsp/x86/sad4d_avx2.c b/vpx_dsp/x86/sad4d_avx2.c index 962b8fb11..2c6b36f17 100644 --- a/vpx_dsp/x86/sad4d_avx2.c +++ b/vpx_dsp/x86/sad4d_avx2.c @@ -11,154 +11,120 @@ #include "./vpx_dsp_rtcd.h" #include "vpx/vpx_integer.h" +static INLINE void calc_final(const __m256i *const sums /*[4]*/, + uint32_t res[4]) { + const __m256i t0 = _mm256_hadd_epi32(sums[0], sums[1]); + const __m256i t1 = _mm256_hadd_epi32(sums[2], sums[3]); + const __m256i t2 = _mm256_hadd_epi32(t0, t1); + const __m128i sum = _mm_add_epi32(_mm256_castsi256_si128(t2), + _mm256_extractf128_si256(t2, 1)); + _mm_storeu_si128((__m128i *)res, sum); +} + void vpx_sad32x32x4d_avx2(const uint8_t *src, int src_stride, const uint8_t *const ref[4], int ref_stride, uint32_t res[4]) { - __m256i src_reg, ref0_reg, ref1_reg, ref2_reg, ref3_reg; - __m256i sum_ref0, sum_ref1, sum_ref2, sum_ref3; - __m256i sum_mlow, sum_mhigh; int i; - const uint8_t *ref0, *ref1, *ref2, *ref3; - - ref0 = ref[0]; - ref1 = ref[1]; - ref2 = ref[2]; - ref3 = ref[3]; - sum_ref0 = _mm256_set1_epi16(0); - sum_ref1 = _mm256_set1_epi16(0); - sum_ref2 = _mm256_set1_epi16(0); - sum_ref3 = _mm256_set1_epi16(0); + const uint8_t *refs[4]; + __m256i sums[4]; + + refs[0] = ref[0]; + refs[1] = ref[1]; + refs[2] = ref[2]; + refs[3] = ref[3]; + sums[0] = _mm256_setzero_si256(); + sums[1] = _mm256_setzero_si256(); + sums[2] = _mm256_setzero_si256(); + sums[3] = _mm256_setzero_si256(); + for (i = 0; i < 32; i++) { + __m256i r[4]; + // load src and all refs - src_reg = _mm256_loadu_si256((const __m256i *)src); - ref0_reg = _mm256_loadu_si256((const __m256i *)ref0); - ref1_reg = _mm256_loadu_si256((const __m256i *)ref1); - ref2_reg = _mm256_loadu_si256((const __m256i *)ref2); - ref3_reg = _mm256_loadu_si256((const __m256i *)ref3); + const __m256i s = _mm256_load_si256((const __m256i *)src); + r[0] = _mm256_loadu_si256((const __m256i *)refs[0]); + r[1] = _mm256_loadu_si256((const __m256i *)refs[1]); + r[2] = _mm256_loadu_si256((const __m256i *)refs[2]); + r[3] = _mm256_loadu_si256((const __m256i *)refs[3]); + // sum of the absolute differences between every ref-i to src - ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg); - ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg); - ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg); - ref3_reg = _mm256_sad_epu8(ref3_reg, src_reg); + r[0] = _mm256_sad_epu8(r[0], s); + r[1] = _mm256_sad_epu8(r[1], s); + r[2] = _mm256_sad_epu8(r[2], s); + r[3] = _mm256_sad_epu8(r[3], s); + // sum every ref-i - sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg); - sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg); - sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg); - sum_ref3 = _mm256_add_epi32(sum_ref3, ref3_reg); + sums[0] = _mm256_add_epi32(sums[0], r[0]); + sums[1] = _mm256_add_epi32(sums[1], r[1]); + sums[2] = _mm256_add_epi32(sums[2], r[2]); + sums[3] = _mm256_add_epi32(sums[3], r[3]); src += src_stride; - ref0 += ref_stride; - ref1 += ref_stride; - ref2 += ref_stride; - ref3 += ref_stride; - } - { - __m128i sum; - // in sum_ref-i the result is saved in the first 4 bytes - // the other 4 bytes are zeroed. - // sum_ref1 and sum_ref3 are shifted left by 4 bytes - sum_ref1 = _mm256_slli_si256(sum_ref1, 4); - sum_ref3 = _mm256_slli_si256(sum_ref3, 4); - - // merge sum_ref0 and sum_ref1 also sum_ref2 and sum_ref3 - sum_ref0 = _mm256_or_si256(sum_ref0, sum_ref1); - sum_ref2 = _mm256_or_si256(sum_ref2, sum_ref3); - - // merge every 64 bit from each sum_ref-i - sum_mlow = _mm256_unpacklo_epi64(sum_ref0, sum_ref2); - sum_mhigh = _mm256_unpackhi_epi64(sum_ref0, sum_ref2); - - // add the low 64 bit to the high 64 bit - sum_mlow = _mm256_add_epi32(sum_mlow, sum_mhigh); - - // add the low 128 bit to the high 128 bit - sum = _mm_add_epi32(_mm256_castsi256_si128(sum_mlow), - _mm256_extractf128_si256(sum_mlow, 1)); - - _mm_storeu_si128((__m128i *)(res), sum); + refs[0] += ref_stride; + refs[1] += ref_stride; + refs[2] += ref_stride; + refs[3] += ref_stride; } + + calc_final(sums, res); } void vpx_sad64x64x4d_avx2(const uint8_t *src, int src_stride, const uint8_t *const ref[4], int ref_stride, uint32_t res[4]) { - __m256i src_reg, srcnext_reg, ref0_reg, ref0next_reg; - __m256i ref1_reg, ref1next_reg, ref2_reg, ref2next_reg; - __m256i ref3_reg, ref3next_reg; - __m256i sum_ref0, sum_ref1, sum_ref2, sum_ref3; - __m256i sum_mlow, sum_mhigh; + __m256i sums[4]; int i; - const uint8_t *ref0, *ref1, *ref2, *ref3; - - ref0 = ref[0]; - ref1 = ref[1]; - ref2 = ref[2]; - ref3 = ref[3]; - sum_ref0 = _mm256_set1_epi16(0); - sum_ref1 = _mm256_set1_epi16(0); - sum_ref2 = _mm256_set1_epi16(0); - sum_ref3 = _mm256_set1_epi16(0); + const uint8_t *refs[4]; + + refs[0] = ref[0]; + refs[1] = ref[1]; + refs[2] = ref[2]; + refs[3] = ref[3]; + sums[0] = _mm256_setzero_si256(); + sums[1] = _mm256_setzero_si256(); + sums[2] = _mm256_setzero_si256(); + sums[3] = _mm256_setzero_si256(); + for (i = 0; i < 64; i++) { + __m256i r_lo[4], r_hi[4]; // load 64 bytes from src and all refs - src_reg = _mm256_loadu_si256((const __m256i *)src); - srcnext_reg = _mm256_loadu_si256((const __m256i *)(src + 32)); - ref0_reg = _mm256_loadu_si256((const __m256i *)ref0); - ref0next_reg = _mm256_loadu_si256((const __m256i *)(ref0 + 32)); - ref1_reg = _mm256_loadu_si256((const __m256i *)ref1); - ref1next_reg = _mm256_loadu_si256((const __m256i *)(ref1 + 32)); - ref2_reg = _mm256_loadu_si256((const __m256i *)ref2); - ref2next_reg = _mm256_loadu_si256((const __m256i *)(ref2 + 32)); - ref3_reg = _mm256_loadu_si256((const __m256i *)ref3); - ref3next_reg = _mm256_loadu_si256((const __m256i *)(ref3 + 32)); + const __m256i s_lo = _mm256_load_si256((const __m256i *)src); + const __m256i s_hi = _mm256_load_si256((const __m256i *)(src + 32)); + r_lo[0] = _mm256_loadu_si256((const __m256i *)refs[0]); + r_hi[0] = _mm256_loadu_si256((const __m256i *)(refs[0] + 32)); + r_lo[1] = _mm256_loadu_si256((const __m256i *)refs[1]); + r_hi[1] = _mm256_loadu_si256((const __m256i *)(refs[1] + 32)); + r_lo[2] = _mm256_loadu_si256((const __m256i *)refs[2]); + r_hi[2] = _mm256_loadu_si256((const __m256i *)(refs[2] + 32)); + r_lo[3] = _mm256_loadu_si256((const __m256i *)refs[3]); + r_hi[3] = _mm256_loadu_si256((const __m256i *)(refs[3] + 32)); + // sum of the absolute differences between every ref-i to src - ref0_reg = _mm256_sad_epu8(ref0_reg, src_reg); - ref1_reg = _mm256_sad_epu8(ref1_reg, src_reg); - ref2_reg = _mm256_sad_epu8(ref2_reg, src_reg); - ref3_reg = _mm256_sad_epu8(ref3_reg, src_reg); - ref0next_reg = _mm256_sad_epu8(ref0next_reg, srcnext_reg); - ref1next_reg = _mm256_sad_epu8(ref1next_reg, srcnext_reg); - ref2next_reg = _mm256_sad_epu8(ref2next_reg, srcnext_reg); - ref3next_reg = _mm256_sad_epu8(ref3next_reg, srcnext_reg); + r_lo[0] = _mm256_sad_epu8(r_lo[0], s_lo); + r_lo[1] = _mm256_sad_epu8(r_lo[1], s_lo); + r_lo[2] = _mm256_sad_epu8(r_lo[2], s_lo); + r_lo[3] = _mm256_sad_epu8(r_lo[3], s_lo); + r_hi[0] = _mm256_sad_epu8(r_hi[0], s_hi); + r_hi[1] = _mm256_sad_epu8(r_hi[1], s_hi); + r_hi[2] = _mm256_sad_epu8(r_hi[2], s_hi); + r_hi[3] = _mm256_sad_epu8(r_hi[3], s_hi); // sum every ref-i - sum_ref0 = _mm256_add_epi32(sum_ref0, ref0_reg); - sum_ref1 = _mm256_add_epi32(sum_ref1, ref1_reg); - sum_ref2 = _mm256_add_epi32(sum_ref2, ref2_reg); - sum_ref3 = _mm256_add_epi32(sum_ref3, ref3_reg); - sum_ref0 = _mm256_add_epi32(sum_ref0, ref0next_reg); - sum_ref1 = _mm256_add_epi32(sum_ref1, ref1next_reg); - sum_ref2 = _mm256_add_epi32(sum_ref2, ref2next_reg); - sum_ref3 = _mm256_add_epi32(sum_ref3, ref3next_reg); + sums[0] = _mm256_add_epi32(sums[0], r_lo[0]); + sums[1] = _mm256_add_epi32(sums[1], r_lo[1]); + sums[2] = _mm256_add_epi32(sums[2], r_lo[2]); + sums[3] = _mm256_add_epi32(sums[3], r_lo[3]); + sums[0] = _mm256_add_epi32(sums[0], r_hi[0]); + sums[1] = _mm256_add_epi32(sums[1], r_hi[1]); + sums[2] = _mm256_add_epi32(sums[2], r_hi[2]); + sums[3] = _mm256_add_epi32(sums[3], r_hi[3]); + src += src_stride; - ref0 += ref_stride; - ref1 += ref_stride; - ref2 += ref_stride; - ref3 += ref_stride; + refs[0] += ref_stride; + refs[1] += ref_stride; + refs[2] += ref_stride; + refs[3] += ref_stride; } - { - __m128i sum; - - // in sum_ref-i the result is saved in the first 4 bytes - // the other 4 bytes are zeroed. - // sum_ref1 and sum_ref3 are shifted left by 4 bytes - sum_ref1 = _mm256_slli_si256(sum_ref1, 4); - sum_ref3 = _mm256_slli_si256(sum_ref3, 4); - - // merge sum_ref0 and sum_ref1 also sum_ref2 and sum_ref3 - sum_ref0 = _mm256_or_si256(sum_ref0, sum_ref1); - sum_ref2 = _mm256_or_si256(sum_ref2, sum_ref3); - // merge every 64 bit from each sum_ref-i - sum_mlow = _mm256_unpacklo_epi64(sum_ref0, sum_ref2); - sum_mhigh = _mm256_unpackhi_epi64(sum_ref0, sum_ref2); - - // add the low 64 bit to the high 64 bit - sum_mlow = _mm256_add_epi32(sum_mlow, sum_mhigh); - - // add the low 128 bit to the high 128 bit - sum = _mm_add_epi32(_mm256_castsi256_si128(sum_mlow), - _mm256_extractf128_si256(sum_mlow, 1)); - - _mm_storeu_si128((__m128i *)(res), sum); - } + calc_final(sums, res); } -- 2.40.0