Add vpx_highbd_sad64x{64,32}_avg_avx2.
authorScott LaVarnway <slavarnway@google.com>
Wed, 12 Oct 2022 17:26:43 +0000 (10:26 -0700)
committerScott LaVarnway <slavarnway@google.com>
Wed, 12 Oct 2022 18:43:39 +0000 (11:43 -0700)
~2.8x faster than the sse2 version.

Bug: b/245917257

Change-Id: Ib727ba8a8c8fa4df450bafdde30ed99fd283f06d

test/sad_test.cc
vpx_dsp/vpx_dsp_rtcd_defs.pl
vpx_dsp/x86/highbd_sad_avx2.c

index a3c2952d63aeff2dcd109fd7d3703e2b519d8287..29e3f57f5e27712bb893f1633de30e94bce294e0 100644 (file)
@@ -1120,18 +1120,24 @@ const SadMxNAvgParam avg_avx2_tests[] = {
   SadMxNAvgParam(32, 32, &vpx_sad32x32_avg_avx2),
   SadMxNAvgParam(32, 16, &vpx_sad32x16_avg_avx2),
 #if CONFIG_VP9_HIGHBITDEPTH
+  SadMxNAvgParam(64, 64, &vpx_highbd_sad64x64_avg_avx2, 8),
+  SadMxNAvgParam(64, 32, &vpx_highbd_sad64x32_avg_avx2, 8),
   SadMxNAvgParam(32, 64, &vpx_highbd_sad32x64_avg_avx2, 8),
   SadMxNAvgParam(32, 32, &vpx_highbd_sad32x32_avg_avx2, 8),
   SadMxNAvgParam(32, 16, &vpx_highbd_sad32x16_avg_avx2, 8),
   SadMxNAvgParam(16, 32, &vpx_highbd_sad16x32_avg_avx2, 8),
   SadMxNAvgParam(16, 16, &vpx_highbd_sad16x16_avg_avx2, 8),
   SadMxNAvgParam(16, 8, &vpx_highbd_sad16x8_avg_avx2, 8),
+  SadMxNAvgParam(64, 64, &vpx_highbd_sad64x64_avg_avx2, 10),
+  SadMxNAvgParam(64, 32, &vpx_highbd_sad64x32_avg_avx2, 10),
   SadMxNAvgParam(32, 64, &vpx_highbd_sad32x64_avg_avx2, 10),
   SadMxNAvgParam(32, 32, &vpx_highbd_sad32x32_avg_avx2, 10),
   SadMxNAvgParam(32, 16, &vpx_highbd_sad32x16_avg_avx2, 10),
   SadMxNAvgParam(16, 32, &vpx_highbd_sad16x32_avg_avx2, 10),
   SadMxNAvgParam(16, 16, &vpx_highbd_sad16x16_avg_avx2, 10),
   SadMxNAvgParam(16, 8, &vpx_highbd_sad16x8_avg_avx2, 10),
+  SadMxNAvgParam(64, 64, &vpx_highbd_sad64x64_avg_avx2, 12),
+  SadMxNAvgParam(64, 32, &vpx_highbd_sad64x32_avg_avx2, 12),
   SadMxNAvgParam(32, 64, &vpx_highbd_sad32x64_avg_avx2, 12),
   SadMxNAvgParam(32, 32, &vpx_highbd_sad32x32_avg_avx2, 12),
   SadMxNAvgParam(32, 16, &vpx_highbd_sad32x16_avg_avx2, 12),
index 4db6de37b6b98a83df07bb549abf1fc9d58cf41f..5fe9c1287997f0a0616404910d79182ff4675837 100644 (file)
@@ -995,10 +995,10 @@ if (vpx_config("CONFIG_VP9_HIGHBITDEPTH") eq "yes") {
   add_proto qw/void vpx_highbd_minmax_8x8/, "const uint8_t *s8, int p, const uint8_t *d8, int dp, int *min, int *max";
 
   add_proto qw/unsigned int vpx_highbd_sad64x64_avg/, "const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, int ref_stride, const uint8_t *second_pred";
-  specialize qw/vpx_highbd_sad64x64_avg sse2 neon/;
+  specialize qw/vpx_highbd_sad64x64_avg sse2 neon avx2/;
 
   add_proto qw/unsigned int vpx_highbd_sad64x32_avg/, "const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, int ref_stride, const uint8_t *second_pred";
-  specialize qw/vpx_highbd_sad64x32_avg sse2 neon/;
+  specialize qw/vpx_highbd_sad64x32_avg sse2 neon avx2/;
 
   add_proto qw/unsigned int vpx_highbd_sad32x64_avg/, "const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, int ref_stride, const uint8_t *second_pred";
   specialize qw/vpx_highbd_sad32x64_avg sse2 neon avx2/;
index 24ebe4e94a85004674f537f4f28e95efa3b18f05..7533ccfddb382eee1504dc2a81a94aa58d1ecded 100644 (file)
@@ -225,6 +225,83 @@ unsigned int vpx_highbd_sad16x8_avx2(const uint8_t *src8_ptr, int src_stride,
 }
 
 // AVG -------------------------------------------------------------------------
+static VPX_FORCE_INLINE void highbd_sad64xH_avg(__m256i *sums_16,
+                                                const uint16_t *src,
+                                                int src_stride, uint16_t *ref,
+                                                int ref_stride, uint16_t *sec,
+                                                int height) {
+  int i;
+  for (i = 0; i < height; ++i) {
+    // load src and all ref[]
+    const __m256i s0 = _mm256_load_si256((const __m256i *)src);
+    const __m256i s1 = _mm256_load_si256((const __m256i *)(src + 16));
+    const __m256i s2 = _mm256_load_si256((const __m256i *)(src + 32));
+    const __m256i s3 = _mm256_load_si256((const __m256i *)(src + 48));
+    const __m256i r0 = _mm256_loadu_si256((const __m256i *)ref);
+    const __m256i r1 = _mm256_loadu_si256((const __m256i *)(ref + 16));
+    const __m256i r2 = _mm256_loadu_si256((const __m256i *)(ref + 32));
+    const __m256i r3 = _mm256_loadu_si256((const __m256i *)(ref + 48));
+    const __m256i x0 = _mm256_loadu_si256((const __m256i *)sec);
+    const __m256i x1 = _mm256_loadu_si256((const __m256i *)(sec + 16));
+    const __m256i x2 = _mm256_loadu_si256((const __m256i *)(sec + 32));
+    const __m256i x3 = _mm256_loadu_si256((const __m256i *)(sec + 48));
+    const __m256i avg0 = _mm256_avg_epu16(r0, x0);
+    const __m256i avg1 = _mm256_avg_epu16(r1, x1);
+    const __m256i avg2 = _mm256_avg_epu16(r2, x2);
+    const __m256i avg3 = _mm256_avg_epu16(r3, x3);
+    // absolute differences between every ref/pred avg to src
+    const __m256i abs_diff0 = _mm256_abs_epi16(_mm256_sub_epi16(avg0, s0));
+    const __m256i abs_diff1 = _mm256_abs_epi16(_mm256_sub_epi16(avg1, s1));
+    const __m256i abs_diff2 = _mm256_abs_epi16(_mm256_sub_epi16(avg2, s2));
+    const __m256i abs_diff3 = _mm256_abs_epi16(_mm256_sub_epi16(avg3, s3));
+    // sum every abs diff
+    *sums_16 =
+        _mm256_add_epi16(*sums_16, _mm256_add_epi16(abs_diff0, abs_diff1));
+    *sums_16 =
+        _mm256_add_epi16(*sums_16, _mm256_add_epi16(abs_diff2, abs_diff3));
+
+    src += src_stride;
+    ref += ref_stride;
+    sec += 64;
+  }
+}
+
+#define HIGHBD_SAD64XN_AVG(n)                                                 \
+  unsigned int vpx_highbd_sad64x##n##_avg_avx2(                               \
+      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,         \
+      int ref_stride, const uint8_t *second_pred) {                           \
+    const uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
+    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                             \
+    uint16_t *sec = CONVERT_TO_SHORTPTR(second_pred);                         \
+    __m256i sums_32 = _mm256_setzero_si256();                                 \
+    int i;                                                                    \
+                                                                              \
+    for (i = 0; i < (n / 2); ++i) {                                           \
+      __m256i sums_16 = _mm256_setzero_si256();                               \
+                                                                              \
+      highbd_sad64xH_avg(&sums_16, src, src_stride, ref, ref_stride, sec, 2); \
+                                                                              \
+      /* sums_16 will outrange after 2 rows, so add current sums_16 to        \
+       * sums_32*/                                                            \
+      sums_32 = _mm256_add_epi32(                                             \
+          sums_32,                                                            \
+          _mm256_add_epi32(                                                   \
+              _mm256_cvtepu16_epi32(_mm256_castsi256_si128(sums_16)),         \
+              _mm256_cvtepu16_epi32(_mm256_extractf128_si256(sums_16, 1))));  \
+                                                                              \
+      src += src_stride << 1;                                                 \
+      ref += ref_stride << 1;                                                 \
+      sec += 64 << 1;                                                         \
+    }                                                                         \
+    return calc_final(sums_32);                                               \
+  }
+
+// 64x64
+HIGHBD_SAD64XN_AVG(64)
+
+// 64x32
+HIGHBD_SAD64XN_AVG(32)
+
 static VPX_FORCE_INLINE void highbd_sad32xH_avg(__m256i *sums_16,
                                                 const uint16_t *src,
                                                 int src_stride, uint16_t *ref,