]> granicus.if.org Git - libvpx/commitdiff
Use ABD and UDOT to implement Neon sad_4d functions
authorJonathan Wright <jonathan.wright@arm.com>
Mon, 10 May 2021 11:22:03 +0000 (12:22 +0100)
committerJonathan Wright <jonathan.wright@arm.com>
Mon, 10 May 2021 14:20:29 +0000 (15:20 +0100)
Implementing sad16_neon using ABD, UDOT instead of ABAL, ABAL2 saves
a cycle and removes resource contention for a single SIMD pipe on
modern out-of-order Arm CPUs. The UDOT accumulation into 32-bit
elements also allows for a faster reduction at the end of each SAD
function.

The existing implementation is retained for CPUs that do not
implement the Armv8.4-A UDOT instruction, and CPUs executing in
AArch32 mode.

Bug: b/181236880
Change-Id: Ibd0da46e86751d2f808c7b1e424f82b046a1aa6f

vpx_dsp/arm/sad4d_neon.c

index 256bc41ce7aa7bd6691ac40cd4185cadfdf2a44e..5c7a0fcaf00931082316065d547f949bf946b16d 100644 (file)
@@ -98,6 +98,8 @@ static INLINE void sad_512_pel_final_neon(const uint16x8_t *sum /*[4]*/,
   vst1q_u32(res, r);
 }
 
+#if defined(__arm__) || !defined(__ARM_FEATURE_DOTPROD)
+
 // Can handle 1024 pixels' sad sum (such as 32x32)
 static INLINE void sad_1024_pel_final_neon(const uint16x8_t *sum /*[4]*/,
                                            uint32_t *const res) {
@@ -191,6 +193,8 @@ static INLINE void sad_4096_pel_final_neon(const uint16x8_t *sum /*[8]*/,
 #endif
 }
 
+#endif
+
 static INLINE void sad8x_4d(const uint8_t *src_ptr, int src_stride,
                             const uint8_t *const ref_array[4], int ref_stride,
                             uint32_t *res, const int height) {
@@ -233,6 +237,41 @@ void vpx_sad8x16x4d_neon(const uint8_t *src_ptr, int src_stride,
 
 ////////////////////////////////////////////////////////////////////////////////
 
+#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD) && \
+    (__ARM_FEATURE_DOTPROD == 1)
+
+static INLINE void sad16_neon(const uint8_t *ref_ptr, const uint8x16_t src_ptr,
+                              uint32x4_t *const sum) {
+  const uint8x16_t r = vld1q_u8(ref_ptr);
+  const uint8x16_t diff = vabdq_u8(src_ptr, r);
+  *sum = vdotq_u32(*sum, diff, vdupq_n_u8(1));
+}
+
+static INLINE void sad16x_4d(const uint8_t *src_ptr, int src_stride,
+                             const uint8_t *const ref_array[4], int ref_stride,
+                             uint32_t *res, const int height) {
+  int i;
+  uint32x4_t r0, r1;
+  const uint8_t *ref_loop[4] = { ref_array[0], ref_array[1], ref_array[2],
+                                 ref_array[3] };
+  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
+                        vdupq_n_u32(0) };
+
+  for (i = 0; i < height; ++i) {
+    const uint8x16_t s = vld1q_u8(src_ptr + i * src_stride);
+    sad16_neon(ref_loop[0] + i * ref_stride, s, &sum[0]);
+    sad16_neon(ref_loop[1] + i * ref_stride, s, &sum[1]);
+    sad16_neon(ref_loop[2] + i * ref_stride, s, &sum[2]);
+    sad16_neon(ref_loop[3] + i * ref_stride, s, &sum[3]);
+  }
+
+  r0 = vpaddq_u32(sum[0], sum[1]);
+  r1 = vpaddq_u32(sum[2], sum[3]);
+  vst1q_u32(res, vpaddq_u32(r0, r1));
+}
+
+#else
+
 static INLINE void sad16_neon(const uint8_t *ref_ptr, const uint8x16_t src_ptr,
                               uint16x8_t *const sum) {
   const uint8x16_t r = vld1q_u8(ref_ptr);
@@ -266,6 +305,8 @@ static INLINE void sad16x_4d(const uint8_t *src_ptr, int src_stride,
   sad_512_pel_final_neon(sum, res);
 }
 
+#endif
+
 void vpx_sad16x8x4d_neon(const uint8_t *src_ptr, int src_stride,
                          const uint8_t *const ref_array[4], int ref_stride,
                          uint32_t *res) {
@@ -286,6 +327,67 @@ void vpx_sad16x32x4d_neon(const uint8_t *src_ptr, int src_stride,
 
 ////////////////////////////////////////////////////////////////////////////////
 
+#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD) && \
+    (__ARM_FEATURE_DOTPROD == 1)
+
+static INLINE void sad32x_4d(const uint8_t *src_ptr, int src_stride,
+                             const uint8_t *const ref_array[4], int ref_stride,
+                             uint32_t *res, const int height) {
+  int i;
+  uint32x4_t r0, r1;
+  const uint8_t *ref_loop[4] = { ref_array[0], ref_array[1], ref_array[2],
+                                 ref_array[3] };
+
+  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
+                        vdupq_n_u32(0) };
+
+  for (i = 0; i < height; ++i) {
+    uint8x16_t s;
+
+    s = vld1q_u8(src_ptr + 0 * 16);
+    sad16_neon(ref_loop[0] + 0 * 16, s, &sum[0]);
+    sad16_neon(ref_loop[1] + 0 * 16, s, &sum[1]);
+    sad16_neon(ref_loop[2] + 0 * 16, s, &sum[2]);
+    sad16_neon(ref_loop[3] + 0 * 16, s, &sum[3]);
+
+    s = vld1q_u8(src_ptr + 1 * 16);
+    sad16_neon(ref_loop[0] + 1 * 16, s, &sum[0]);
+    sad16_neon(ref_loop[1] + 1 * 16, s, &sum[1]);
+    sad16_neon(ref_loop[2] + 1 * 16, s, &sum[2]);
+    sad16_neon(ref_loop[3] + 1 * 16, s, &sum[3]);
+
+    src_ptr += src_stride;
+    ref_loop[0] += ref_stride;
+    ref_loop[1] += ref_stride;
+    ref_loop[2] += ref_stride;
+    ref_loop[3] += ref_stride;
+  }
+
+  r0 = vpaddq_u32(sum[0], sum[1]);
+  r1 = vpaddq_u32(sum[2], sum[3]);
+  vst1q_u32(res, vpaddq_u32(r0, r1));
+}
+
+void vpx_sad32x16x4d_neon(const uint8_t *src_ptr, int src_stride,
+                          const uint8_t *const ref_array[4], int ref_stride,
+                          uint32_t *res) {
+  sad32x_4d(src_ptr, src_stride, ref_array, ref_stride, res, 16);
+}
+
+void vpx_sad32x32x4d_neon(const uint8_t *src_ptr, int src_stride,
+                          const uint8_t *const ref_array[4], int ref_stride,
+                          uint32_t *res) {
+  sad32x_4d(src_ptr, src_stride, ref_array, ref_stride, res, 32);
+}
+
+void vpx_sad32x64x4d_neon(const uint8_t *src_ptr, int src_stride,
+                          const uint8_t *const ref_array[4], int ref_stride,
+                          uint32_t *res) {
+  sad32x_4d(src_ptr, src_stride, ref_array, ref_stride, res, 64);
+}
+
+#else
+
 static INLINE void sad32x_4d(const uint8_t *src_ptr, int src_stride,
                              const uint8_t *const ref_array[4], int ref_stride,
                              const int height, uint16x8_t *const sum) {
@@ -342,8 +444,118 @@ void vpx_sad32x64x4d_neon(const uint8_t *src_ptr, int src_stride,
   sad_2048_pel_final_neon(sum, res);
 }
 
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////
 
+#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD) && \
+    (__ARM_FEATURE_DOTPROD == 1)
+
+void vpx_sad64x32x4d_neon(const uint8_t *src_ptr, int src_stride,
+                          const uint8_t *const ref_array[4], int ref_stride,
+                          uint32_t *res) {
+  int i;
+  uint32x4_t r0, r1;
+  const uint8_t *ref_loop[4] = { ref_array[0], ref_array[1], ref_array[2],
+                                 ref_array[3] };
+  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
+                        vdupq_n_u32(0) };
+
+  for (i = 0; i < 32; ++i) {
+    uint8x16_t s;
+
+    s = vld1q_u8(src_ptr + 0 * 16);
+    sad16_neon(ref_loop[0] + 0 * 16, s, &sum[0]);
+    sad16_neon(ref_loop[1] + 0 * 16, s, &sum[1]);
+    sad16_neon(ref_loop[2] + 0 * 16, s, &sum[2]);
+    sad16_neon(ref_loop[3] + 0 * 16, s, &sum[3]);
+
+    s = vld1q_u8(src_ptr + 1 * 16);
+    sad16_neon(ref_loop[0] + 1 * 16, s, &sum[0]);
+    sad16_neon(ref_loop[1] + 1 * 16, s, &sum[1]);
+    sad16_neon(ref_loop[2] + 1 * 16, s, &sum[2]);
+    sad16_neon(ref_loop[3] + 1 * 16, s, &sum[3]);
+
+    s = vld1q_u8(src_ptr + 2 * 16);
+    sad16_neon(ref_loop[0] + 2 * 16, s, &sum[0]);
+    sad16_neon(ref_loop[1] + 2 * 16, s, &sum[1]);
+    sad16_neon(ref_loop[2] + 2 * 16, s, &sum[2]);
+    sad16_neon(ref_loop[3] + 2 * 16, s, &sum[3]);
+
+    s = vld1q_u8(src_ptr + 3 * 16);
+    sad16_neon(ref_loop[0] + 3 * 16, s, &sum[0]);
+    sad16_neon(ref_loop[1] + 3 * 16, s, &sum[1]);
+    sad16_neon(ref_loop[2] + 3 * 16, s, &sum[2]);
+    sad16_neon(ref_loop[3] + 3 * 16, s, &sum[3]);
+
+    src_ptr += src_stride;
+    ref_loop[0] += ref_stride;
+    ref_loop[1] += ref_stride;
+    ref_loop[2] += ref_stride;
+    ref_loop[3] += ref_stride;
+  }
+
+  r0 = vpaddq_u32(sum[0], sum[1]);
+  r1 = vpaddq_u32(sum[2], sum[3]);
+  vst1q_u32(res, vpaddq_u32(r0, r1));
+}
+
+void vpx_sad64x64x4d_neon(const uint8_t *src_ptr, int src_stride,
+                          const uint8_t *const ref_array[4], int ref_stride,
+                          uint32_t *res) {
+  int i;
+  uint32x4_t r0, r1, r2, r3;
+  const uint8_t *ref_loop[4] = { ref_array[0], ref_array[1], ref_array[2],
+                                 ref_array[3] };
+  uint32x4_t sum[8] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
+                        vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
+                        vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  for (i = 0; i < 64; ++i) {
+    uint8x16_t s;
+
+    s = vld1q_u8(src_ptr + 0 * 16);
+    sad16_neon(ref_loop[0] + 0 * 16, s, &sum[0]);
+    sad16_neon(ref_loop[1] + 0 * 16, s, &sum[2]);
+    sad16_neon(ref_loop[2] + 0 * 16, s, &sum[4]);
+    sad16_neon(ref_loop[3] + 0 * 16, s, &sum[6]);
+
+    s = vld1q_u8(src_ptr + 1 * 16);
+    sad16_neon(ref_loop[0] + 1 * 16, s, &sum[0]);
+    sad16_neon(ref_loop[1] + 1 * 16, s, &sum[2]);
+    sad16_neon(ref_loop[2] + 1 * 16, s, &sum[4]);
+    sad16_neon(ref_loop[3] + 1 * 16, s, &sum[6]);
+
+    s = vld1q_u8(src_ptr + 2 * 16);
+    sad16_neon(ref_loop[0] + 2 * 16, s, &sum[1]);
+    sad16_neon(ref_loop[1] + 2 * 16, s, &sum[3]);
+    sad16_neon(ref_loop[2] + 2 * 16, s, &sum[5]);
+    sad16_neon(ref_loop[3] + 2 * 16, s, &sum[7]);
+
+    s = vld1q_u8(src_ptr + 3 * 16);
+    sad16_neon(ref_loop[0] + 3 * 16, s, &sum[1]);
+    sad16_neon(ref_loop[1] + 3 * 16, s, &sum[3]);
+    sad16_neon(ref_loop[2] + 3 * 16, s, &sum[5]);
+    sad16_neon(ref_loop[3] + 3 * 16, s, &sum[7]);
+
+    src_ptr += src_stride;
+    ref_loop[0] += ref_stride;
+    ref_loop[1] += ref_stride;
+    ref_loop[2] += ref_stride;
+    ref_loop[3] += ref_stride;
+  }
+
+  r0 = vpaddq_u32(sum[0], sum[1]);
+  r1 = vpaddq_u32(sum[2], sum[3]);
+  r2 = vpaddq_u32(sum[4], sum[5]);
+  r3 = vpaddq_u32(sum[6], sum[7]);
+  r0 = vpaddq_u32(r0, r1);
+  r1 = vpaddq_u32(r2, r3);
+  vst1q_u32(res, vpaddq_u32(r0, r1));
+}
+
+#else
+
 void vpx_sad64x32x4d_neon(const uint8_t *src_ptr, int src_stride,
                           const uint8_t *const ref_array[4], int ref_stride,
                           uint32_t *res) {
@@ -436,3 +648,5 @@ void vpx_sad64x64x4d_neon(const uint8_t *src_ptr, int src_stride,
 
   sad_4096_pel_final_neon(sum, res);
 }
+
+#endif