]> granicus.if.org Git - libvpx/commitdiff
Refactor Neon implementation of SAD4D functions
authorSalome Thirot <salome.thirot@arm.com>
Fri, 27 Jan 2023 16:16:16 +0000 (16:16 +0000)
committerSalome Thirot <salome.thirot@arm.com>
Mon, 30 Jan 2023 13:14:54 +0000 (13:14 +0000)
Refactor and optimize the Neon implementation of SAD4D functions -
effectively backporting these libaom changes[1,2].

[1] https://aomedia-review.googlesource.com/c/aom/+/162181
[2] https://aomedia-review.googlesource.com/c/aom/+/162183

Change-Id: Icb04bd841d86f2d0e2596aa7ba86b74f8d2d360b

vpx_dsp/arm/sad4d_neon.c
vpx_dsp/arm/sum_neon.h

index 5fc621aee186ec0c9ea48312d8c2f55d446996e7..5064770ee677034bbd90c78d7964aa4996c51f25 100644 (file)
 #include "vpx_dsp/arm/mem_neon.h"
 #include "vpx_dsp/arm/sum_neon.h"
 
-static INLINE uint8x8_t load_unaligned_2_buffers(const void *const buf0,
-                                                 const void *const buf1) {
-  uint32_t a;
-  uint32x2_t aa;
-  memcpy(&a, buf0, 4);
-  aa = vdup_n_u32(a);
-  memcpy(&a, buf1, 4);
-  aa = vset_lane_u32(a, aa, 1);
-  return vreinterpret_u8_u32(aa);
-}
-
-static INLINE void sad4x_4d(const uint8_t *const src_ptr, const int src_stride,
-                            const uint8_t *const ref_array[4],
-                            const int ref_stride, const int height,
-                            uint32_t sad_array[4]) {
-  int i;
-  uint16x8_t abs[2] = { vdupq_n_u16(0), vdupq_n_u16(0) };
-#if !defined(__aarch64__)
-  uint16x4_t a[2];
-#endif
-  uint32x4_t r;
-
-  assert(!((intptr_t)src_ptr % sizeof(uint32_t)));
-  assert(!(src_stride % sizeof(uint32_t)));
-
-  for (i = 0; i < height; ++i) {
-    const uint8x8_t s = vreinterpret_u8_u32(
-        vld1_dup_u32((const uint32_t *)(src_ptr + i * src_stride)));
-    const uint8x8_t ref01 = load_unaligned_2_buffers(
-        ref_array[0] + i * ref_stride, ref_array[1] + i * ref_stride);
-    const uint8x8_t ref23 = load_unaligned_2_buffers(
-        ref_array[2] + i * ref_stride, ref_array[3] + i * ref_stride);
-    abs[0] = vabal_u8(abs[0], s, ref01);
-    abs[1] = vabal_u8(abs[1], s, ref23);
-  }
-
-#if defined(__aarch64__)
-  abs[0] = vpaddq_u16(abs[0], abs[1]);
-  r = vpaddlq_u16(abs[0]);
-#else
-  a[0] = vpadd_u16(vget_low_u16(abs[0]), vget_high_u16(abs[0]));
-  a[1] = vpadd_u16(vget_low_u16(abs[1]), vget_high_u16(abs[1]));
-  r = vpaddlq_u16(vcombine_u16(a[0], a[1]));
-#endif
-  vst1q_u32(sad_array, r);
-}
-
-void vpx_sad4x4x4d_neon(const uint8_t *src_ptr, int src_stride,
-                        const uint8_t *const ref_array[4], int ref_stride,
-                        uint32_t sad_array[4]) {
-  sad4x_4d(src_ptr, src_stride, ref_array, ref_stride, 4, sad_array);
-}
-
-void vpx_sad4x8x4d_neon(const uint8_t *src_ptr, int src_stride,
-                        const uint8_t *const ref_array[4], int ref_stride,
-                        uint32_t sad_array[4]) {
-  sad4x_4d(src_ptr, src_stride, ref_array, ref_stride, 8, sad_array);
-}
-
-////////////////////////////////////////////////////////////////////////////////
-
-// Can handle 512 pixels' sad sum (such as 16x32 or 32x16)
-static INLINE void sad_512_pel_final_neon(const uint16x8_t sum[4],
-                                          uint32_t sad_array[4]) {
-#if defined(__aarch64__)
-  const uint16x8_t a0 = vpaddq_u16(sum[0], sum[1]);
-  const uint16x8_t a1 = vpaddq_u16(sum[2], sum[3]);
-  const uint16x8_t b0 = vpaddq_u16(a0, a1);
-  const uint32x4_t r = vpaddlq_u16(b0);
-#else
-  const uint16x4_t a0 = vadd_u16(vget_low_u16(sum[0]), vget_high_u16(sum[0]));
-  const uint16x4_t a1 = vadd_u16(vget_low_u16(sum[1]), vget_high_u16(sum[1]));
-  const uint16x4_t a2 = vadd_u16(vget_low_u16(sum[2]), vget_high_u16(sum[2]));
-  const uint16x4_t a3 = vadd_u16(vget_low_u16(sum[3]), vget_high_u16(sum[3]));
-  const uint16x4_t b0 = vpadd_u16(a0, a1);
-  const uint16x4_t b1 = vpadd_u16(a2, a3);
-  const uint32x4_t r = vpaddlq_u16(vcombine_u16(b0, b1));
-#endif
-  vst1q_u32(sad_array, r);
-}
+#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
 
-#if defined(__arm__) || !defined(__ARM_FEATURE_DOTPROD)
-
-// Can handle 1024 pixels' sad sum (such as 32x32)
-static INLINE void sad_1024_pel_final_neon(const uint16x8_t sum[4],
-                                           uint32_t sad_array[4]) {
-#if defined(__aarch64__)
-  const uint16x8_t a0 = vpaddq_u16(sum[0], sum[1]);
-  const uint16x8_t a1 = vpaddq_u16(sum[2], sum[3]);
-  const uint32x4_t b0 = vpaddlq_u16(a0);
-  const uint32x4_t b1 = vpaddlq_u16(a1);
-  const uint32x4_t r = vpaddq_u32(b0, b1);
-  vst1q_u32(sad_array, r);
-#else
-  const uint16x4_t a0 = vpadd_u16(vget_low_u16(sum[0]), vget_high_u16(sum[0]));
-  const uint16x4_t a1 = vpadd_u16(vget_low_u16(sum[1]), vget_high_u16(sum[1]));
-  const uint16x4_t a2 = vpadd_u16(vget_low_u16(sum[2]), vget_high_u16(sum[2]));
-  const uint16x4_t a3 = vpadd_u16(vget_low_u16(sum[3]), vget_high_u16(sum[3]));
-  const uint32x4_t b0 = vpaddlq_u16(vcombine_u16(a0, a1));
-  const uint32x4_t b1 = vpaddlq_u16(vcombine_u16(a2, a3));
-  const uint32x2_t c0 = vpadd_u32(vget_low_u32(b0), vget_high_u32(b0));
-  const uint32x2_t c1 = vpadd_u32(vget_low_u32(b1), vget_high_u32(b1));
-  vst1q_u32(sad_array, vcombine_u32(c0, c1));
-#endif
+static INLINE void sad16_neon(uint8x16_t src, uint8x16_t ref,
+                              uint32x4_t *const sad_sum) {
+  uint8x16_t abs_diff = vabdq_u8(src, ref);
+  *sad_sum = vdotq_u32(*sad_sum, abs_diff, vdupq_n_u8(1));
 }
 
-// Can handle 2048 pixels' sad sum (such as 32x64 or 64x32)
-static INLINE void sad_2048_pel_final_neon(const uint16x8_t sum[4],
-                                           uint32_t sad_array[4]) {
-#if defined(__aarch64__)
-  const uint32x4_t a0 = vpaddlq_u16(sum[0]);
-  const uint32x4_t a1 = vpaddlq_u16(sum[1]);
-  const uint32x4_t a2 = vpaddlq_u16(sum[2]);
-  const uint32x4_t a3 = vpaddlq_u16(sum[3]);
-  const uint32x4_t b0 = vpaddq_u32(a0, a1);
-  const uint32x4_t b1 = vpaddq_u32(a2, a3);
-  const uint32x4_t r = vpaddq_u32(b0, b1);
-  vst1q_u32(sad_array, r);
-#else
-  const uint32x4_t a0 = vpaddlq_u16(sum[0]);
-  const uint32x4_t a1 = vpaddlq_u16(sum[1]);
-  const uint32x4_t a2 = vpaddlq_u16(sum[2]);
-  const uint32x4_t a3 = vpaddlq_u16(sum[3]);
-  const uint32x2_t b0 = vadd_u32(vget_low_u32(a0), vget_high_u32(a0));
-  const uint32x2_t b1 = vadd_u32(vget_low_u32(a1), vget_high_u32(a1));
-  const uint32x2_t b2 = vadd_u32(vget_low_u32(a2), vget_high_u32(a2));
-  const uint32x2_t b3 = vadd_u32(vget_low_u32(a3), vget_high_u32(a3));
-  const uint32x2_t c0 = vpadd_u32(b0, b1);
-  const uint32x2_t c1 = vpadd_u32(b2, b3);
-  vst1q_u32(sad_array, vcombine_u32(c0, c1));
-#endif
+static INLINE void sad64xhx4d_neon(const uint8_t *src, int src_stride,
+                                   const uint8_t *const ref[4], int ref_stride,
+                                   uint32_t res[4], int h) {
+  uint32x4_t res0, res1;
+  uint32x4_t sum_lo[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
+                           vdupq_n_u32(0) };
+  uint32x4_t sum_hi[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
+                           vdupq_n_u32(0) };
+
+  int i = 0;
+  do {
+    uint8x16_t s0, s1, s2, s3;
+
+    s0 = vld1q_u8(src + i * src_stride);
+    sad16_neon(s0, vld1q_u8(ref[0] + i * ref_stride), &sum_lo[0]);
+    sad16_neon(s0, vld1q_u8(ref[1] + i * ref_stride), &sum_lo[1]);
+    sad16_neon(s0, vld1q_u8(ref[2] + i * ref_stride), &sum_lo[2]);
+    sad16_neon(s0, vld1q_u8(ref[3] + i * ref_stride), &sum_lo[3]);
+
+    s1 = vld1q_u8(src + i * src_stride + 16);
+    sad16_neon(s1, vld1q_u8(ref[0] + i * ref_stride + 16), &sum_hi[0]);
+    sad16_neon(s1, vld1q_u8(ref[1] + i * ref_stride + 16), &sum_hi[1]);
+    sad16_neon(s1, vld1q_u8(ref[2] + i * ref_stride + 16), &sum_hi[2]);
+    sad16_neon(s1, vld1q_u8(ref[3] + i * ref_stride + 16), &sum_hi[3]);
+
+    s2 = vld1q_u8(src + i * src_stride + 32);
+    sad16_neon(s2, vld1q_u8(ref[0] + i * ref_stride + 32), &sum_lo[0]);
+    sad16_neon(s2, vld1q_u8(ref[1] + i * ref_stride + 32), &sum_lo[1]);
+    sad16_neon(s2, vld1q_u8(ref[2] + i * ref_stride + 32), &sum_lo[2]);
+    sad16_neon(s2, vld1q_u8(ref[3] + i * ref_stride + 32), &sum_lo[3]);
+
+    s3 = vld1q_u8(src + i * src_stride + 48);
+    sad16_neon(s3, vld1q_u8(ref[0] + i * ref_stride + 48), &sum_hi[0]);
+    sad16_neon(s3, vld1q_u8(ref[1] + i * ref_stride + 48), &sum_hi[1]);
+    sad16_neon(s3, vld1q_u8(ref[2] + i * ref_stride + 48), &sum_hi[2]);
+    sad16_neon(s3, vld1q_u8(ref[3] + i * ref_stride + 48), &sum_hi[3]);
+
+    i++;
+  } while (i < h);
+
+  res0 = vpaddq_u32(vaddq_u32(sum_lo[0], sum_hi[0]),
+                    vaddq_u32(sum_lo[1], sum_hi[1]));
+  res1 = vpaddq_u32(vaddq_u32(sum_lo[2], sum_hi[2]),
+                    vaddq_u32(sum_lo[3], sum_hi[3]));
+  vst1q_u32(res, vpaddq_u32(res0, res1));
 }
 
-// Can handle 4096 pixels' sad sum (such as 64x64)
-static INLINE void sad_4096_pel_final_neon(const uint16x8_t sum[8],
-                                           uint32_t sad_array[4]) {
-#if defined(__aarch64__)
-  const uint32x4_t a0 = vpaddlq_u16(sum[0]);
-  const uint32x4_t a1 = vpaddlq_u16(sum[1]);
-  const uint32x4_t a2 = vpaddlq_u16(sum[2]);
-  const uint32x4_t a3 = vpaddlq_u16(sum[3]);
-  const uint32x4_t a4 = vpaddlq_u16(sum[4]);
-  const uint32x4_t a5 = vpaddlq_u16(sum[5]);
-  const uint32x4_t a6 = vpaddlq_u16(sum[6]);
-  const uint32x4_t a7 = vpaddlq_u16(sum[7]);
-  const uint32x4_t b0 = vaddq_u32(a0, a1);
-  const uint32x4_t b1 = vaddq_u32(a2, a3);
-  const uint32x4_t b2 = vaddq_u32(a4, a5);
-  const uint32x4_t b3 = vaddq_u32(a6, a7);
-  const uint32x4_t c0 = vpaddq_u32(b0, b1);
-  const uint32x4_t c1 = vpaddq_u32(b2, b3);
-  const uint32x4_t r = vpaddq_u32(c0, c1);
-  vst1q_u32(sad_array, r);
-#else
-  const uint32x4_t a0 = vpaddlq_u16(sum[0]);
-  const uint32x4_t a1 = vpaddlq_u16(sum[1]);
-  const uint32x4_t a2 = vpaddlq_u16(sum[2]);
-  const uint32x4_t a3 = vpaddlq_u16(sum[3]);
-  const uint32x4_t a4 = vpaddlq_u16(sum[4]);
-  const uint32x4_t a5 = vpaddlq_u16(sum[5]);
-  const uint32x4_t a6 = vpaddlq_u16(sum[6]);
-  const uint32x4_t a7 = vpaddlq_u16(sum[7]);
-  const uint32x4_t b0 = vaddq_u32(a0, a1);
-  const uint32x4_t b1 = vaddq_u32(a2, a3);
-  const uint32x4_t b2 = vaddq_u32(a4, a5);
-  const uint32x4_t b3 = vaddq_u32(a6, a7);
-  const uint32x2_t c0 = vadd_u32(vget_low_u32(b0), vget_high_u32(b0));
-  const uint32x2_t c1 = vadd_u32(vget_low_u32(b1), vget_high_u32(b1));
-  const uint32x2_t c2 = vadd_u32(vget_low_u32(b2), vget_high_u32(b2));
-  const uint32x2_t c3 = vadd_u32(vget_low_u32(b3), vget_high_u32(b3));
-  const uint32x2_t d0 = vpadd_u32(c0, c1);
-  const uint32x2_t d1 = vpadd_u32(c2, c3);
-  vst1q_u32(sad_array, vcombine_u32(d0, d1));
-#endif
+static INLINE void sad32xhx4d_neon(const uint8_t *src, int src_stride,
+                                   const uint8_t *const ref[4], int ref_stride,
+                                   uint32_t res[4], int h) {
+  uint32x4_t res0, res1;
+  uint32x4_t sum_lo[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
+                           vdupq_n_u32(0) };
+  uint32x4_t sum_hi[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
+                           vdupq_n_u32(0) };
+
+  int i = 0;
+  do {
+    uint8x16_t s0, s1;
+
+    s0 = vld1q_u8(src + i * src_stride);
+    sad16_neon(s0, vld1q_u8(ref[0] + i * ref_stride), &sum_lo[0]);
+    sad16_neon(s0, vld1q_u8(ref[1] + i * ref_stride), &sum_lo[1]);
+    sad16_neon(s0, vld1q_u8(ref[2] + i * ref_stride), &sum_lo[2]);
+    sad16_neon(s0, vld1q_u8(ref[3] + i * ref_stride), &sum_lo[3]);
+
+    s1 = vld1q_u8(src + i * src_stride + 16);
+    sad16_neon(s1, vld1q_u8(ref[0] + i * ref_stride + 16), &sum_hi[0]);
+    sad16_neon(s1, vld1q_u8(ref[1] + i * ref_stride + 16), &sum_hi[1]);
+    sad16_neon(s1, vld1q_u8(ref[2] + i * ref_stride + 16), &sum_hi[2]);
+    sad16_neon(s1, vld1q_u8(ref[3] + i * ref_stride + 16), &sum_hi[3]);
+
+    i++;
+  } while (i < h);
+
+  res0 = vpaddq_u32(vaddq_u32(sum_lo[0], sum_hi[0]),
+                    vaddq_u32(sum_lo[1], sum_hi[1]));
+  res1 = vpaddq_u32(vaddq_u32(sum_lo[2], sum_hi[2]),
+                    vaddq_u32(sum_lo[3], sum_hi[3]));
+  vst1q_u32(res, vpaddq_u32(res0, res1));
 }
 
-#endif
-
-static INLINE void sad8x_4d(const uint8_t *src_ptr, int src_stride,
-                            const uint8_t *const ref_array[4], int ref_stride,
-                            uint32_t sad_array[4], const int height) {
-  int i, j;
-  const uint8_t *ref_loop[4] = { ref_array[0], ref_array[1], ref_array[2],
-                                 ref_array[3] };
-  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
-                        vdupq_n_u16(0) };
-
-  for (i = 0; i < height; ++i) {
-    const uint8x8_t s = vld1_u8(src_ptr);
-    src_ptr += src_stride;
-    for (j = 0; j < 4; ++j) {
-      const uint8x8_t b_u8 = vld1_u8(ref_loop[j]);
-      ref_loop[j] += ref_stride;
-      sum[j] = vabal_u8(sum[j], s, b_u8);
-    }
-  }
-
-  sad_512_pel_final_neon(sum, sad_array);
-}
+static INLINE void sad16xhx4d_neon(const uint8_t *src, int src_stride,
+                                   const uint8_t *const ref[4], int ref_stride,
+                                   uint32_t res[4], int h) {
+  uint32x4_t res0, res1;
+  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
+                        vdupq_n_u32(0) };
 
-void vpx_sad8x4x4d_neon(const uint8_t *src_ptr, int src_stride,
-                        const uint8_t *const ref_array[4], int ref_stride,
-                        uint32_t sad_array[4]) {
-  sad8x_4d(src_ptr, src_stride, ref_array, ref_stride, sad_array, 4);
-}
+  int i = 0;
+  do {
+    const uint8x16_t s = vld1q_u8(src + i * src_stride);
+    sad16_neon(s, vld1q_u8(ref[0] + i * ref_stride), &sum[0]);
+    sad16_neon(s, vld1q_u8(ref[1] + i * ref_stride), &sum[1]);
+    sad16_neon(s, vld1q_u8(ref[2] + i * ref_stride), &sum[2]);
+    sad16_neon(s, vld1q_u8(ref[3] + i * ref_stride), &sum[3]);
 
-void vpx_sad8x8x4d_neon(const uint8_t *src_ptr, int src_stride,
-                        const uint8_t *const ref_array[4], int ref_stride,
-                        uint32_t sad_array[4]) {
-  sad8x_4d(src_ptr, src_stride, ref_array, ref_stride, sad_array, 8);
-}
+    i++;
+  } while (i < h);
 
-void vpx_sad8x16x4d_neon(const uint8_t *src_ptr, int src_stride,
-                         const uint8_t *const ref_array[4], int ref_stride,
-                         uint32_t sad_array[4]) {
-  sad8x_4d(src_ptr, src_stride, ref_array, ref_stride, sad_array, 16);
+  res0 = vpaddq_u32(sum[0], sum[1]);
+  res1 = vpaddq_u32(sum[2], sum[3]);
+  vst1q_u32(res, vpaddq_u32(res0, res1));
 }
 
-////////////////////////////////////////////////////////////////////////////////
-
-#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
+#else  // !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
 
-static INLINE void sad16_neon(const uint8_t *ref_ptr, const uint8x16_t src_ptr,
-                              uint32x4_t *const sum) {
-  const uint8x16_t r = vld1q_u8(ref_ptr);
-  const uint8x16_t diff = vabdq_u8(src_ptr, r);
-  *sum = vdotq_u32(*sum, diff, vdupq_n_u8(1));
+static INLINE void sad16_neon(uint8x16_t src, uint8x16_t ref,
+                              uint16x8_t *const sad_sum) {
+  uint8x16_t abs_diff = vabdq_u8(src, ref);
+  *sad_sum = vpadalq_u8(*sad_sum, abs_diff);
 }
 
-static INLINE void sad16x_4d(const uint8_t *src_ptr, int src_stride,
-                             const uint8_t *const ref_array[4], int ref_stride,
-                             uint32_t sad_array[4], const int height) {
-  int i;
-  uint32x4_t r0, r1;
-  const uint8_t *ref_loop[4] = { ref_array[0], ref_array[1], ref_array[2],
-                                 ref_array[3] };
-  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
-                        vdupq_n_u32(0) };
-
-  for (i = 0; i < height; ++i) {
-    const uint8x16_t s = vld1q_u8(src_ptr + i * src_stride);
-    sad16_neon(ref_loop[0] + i * ref_stride, s, &sum[0]);
-    sad16_neon(ref_loop[1] + i * ref_stride, s, &sum[1]);
-    sad16_neon(ref_loop[2] + i * ref_stride, s, &sum[2]);
-    sad16_neon(ref_loop[3] + i * ref_stride, s, &sum[3]);
-  }
-
-  r0 = vpaddq_u32(sum[0], sum[1]);
-  r1 = vpaddq_u32(sum[2], sum[3]);
-  vst1q_u32(sad_array, vpaddq_u32(r0, r1));
+static INLINE void sad64xhx4d_neon(const uint8_t *src, int src_stride,
+                                   const uint8_t *const ref[4], int ref_stride,
+                                   uint32_t res[4], int h) {
+  int h_tmp = h > 64 ? 64 : h;
+  int i = 0;
+  vst1q_u32(res, vdupq_n_u32(0));
+
+  do {
+    uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                             vdupq_n_u16(0) };
+    uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                             vdupq_n_u16(0) };
+
+    do {
+      uint8x16_t s0, s1, s2, s3;
+
+      s0 = vld1q_u8(src + i * src_stride);
+      sad16_neon(s0, vld1q_u8(ref[0] + i * ref_stride), &sum_lo[0]);
+      sad16_neon(s0, vld1q_u8(ref[1] + i * ref_stride), &sum_lo[1]);
+      sad16_neon(s0, vld1q_u8(ref[2] + i * ref_stride), &sum_lo[2]);
+      sad16_neon(s0, vld1q_u8(ref[3] + i * ref_stride), &sum_lo[3]);
+
+      s1 = vld1q_u8(src + i * src_stride + 16);
+      sad16_neon(s1, vld1q_u8(ref[0] + i * ref_stride + 16), &sum_hi[0]);
+      sad16_neon(s1, vld1q_u8(ref[1] + i * ref_stride + 16), &sum_hi[1]);
+      sad16_neon(s1, vld1q_u8(ref[2] + i * ref_stride + 16), &sum_hi[2]);
+      sad16_neon(s1, vld1q_u8(ref[3] + i * ref_stride + 16), &sum_hi[3]);
+
+      s2 = vld1q_u8(src + i * src_stride + 32);
+      sad16_neon(s2, vld1q_u8(ref[0] + i * ref_stride + 32), &sum_lo[0]);
+      sad16_neon(s2, vld1q_u8(ref[1] + i * ref_stride + 32), &sum_lo[1]);
+      sad16_neon(s2, vld1q_u8(ref[2] + i * ref_stride + 32), &sum_lo[2]);
+      sad16_neon(s2, vld1q_u8(ref[3] + i * ref_stride + 32), &sum_lo[3]);
+
+      s3 = vld1q_u8(src + i * src_stride + 48);
+      sad16_neon(s3, vld1q_u8(ref[0] + i * ref_stride + 48), &sum_hi[0]);
+      sad16_neon(s3, vld1q_u8(ref[1] + i * ref_stride + 48), &sum_hi[1]);
+      sad16_neon(s3, vld1q_u8(ref[2] + i * ref_stride + 48), &sum_hi[2]);
+      sad16_neon(s3, vld1q_u8(ref[3] + i * ref_stride + 48), &sum_hi[3]);
+
+      i++;
+    } while (i < h_tmp);
+
+    res[0] += horizontal_long_add_uint16x8(sum_lo[0], sum_hi[0]);
+    res[1] += horizontal_long_add_uint16x8(sum_lo[1], sum_hi[1]);
+    res[2] += horizontal_long_add_uint16x8(sum_lo[2], sum_hi[2]);
+    res[3] += horizontal_long_add_uint16x8(sum_lo[3], sum_hi[3]);
+
+    h_tmp += 64;
+  } while (i < h);
 }
 
-#else  // !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
-
-static INLINE void sad16_neon(const uint8_t *ref_ptr, const uint8x16_t src_ptr,
-                              uint16x8_t *const sum) {
-  const uint8x16_t r = vld1q_u8(ref_ptr);
-  *sum = vabal_u8(*sum, vget_low_u8(src_ptr), vget_low_u8(r));
-  *sum = vabal_u8(*sum, vget_high_u8(src_ptr), vget_high_u8(r));
+static INLINE void sad32xhx4d_neon(const uint8_t *src, int src_stride,
+                                   const uint8_t *const ref[4], int ref_stride,
+                                   uint32_t res[4], int h) {
+  uint16x8_t sum_lo[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                           vdupq_n_u16(0) };
+  uint16x8_t sum_hi[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                           vdupq_n_u16(0) };
+
+  int i = 0;
+  do {
+    uint8x16_t s0, s1;
+
+    s0 = vld1q_u8(src + i * src_stride);
+    sad16_neon(s0, vld1q_u8(ref[0] + i * ref_stride), &sum_lo[0]);
+    sad16_neon(s0, vld1q_u8(ref[1] + i * ref_stride), &sum_lo[1]);
+    sad16_neon(s0, vld1q_u8(ref[2] + i * ref_stride), &sum_lo[2]);
+    sad16_neon(s0, vld1q_u8(ref[3] + i * ref_stride), &sum_lo[3]);
+
+    s1 = vld1q_u8(src + i * src_stride + 16);
+    sad16_neon(s1, vld1q_u8(ref[0] + i * ref_stride + 16), &sum_hi[0]);
+    sad16_neon(s1, vld1q_u8(ref[1] + i * ref_stride + 16), &sum_hi[1]);
+    sad16_neon(s1, vld1q_u8(ref[2] + i * ref_stride + 16), &sum_hi[2]);
+    sad16_neon(s1, vld1q_u8(ref[3] + i * ref_stride + 16), &sum_hi[3]);
+
+    i++;
+  } while (i < h);
+
+  res[0] = horizontal_long_add_uint16x8(sum_lo[0], sum_hi[0]);
+  res[1] = horizontal_long_add_uint16x8(sum_lo[1], sum_hi[1]);
+  res[2] = horizontal_long_add_uint16x8(sum_lo[2], sum_hi[2]);
+  res[3] = horizontal_long_add_uint16x8(sum_lo[3], sum_hi[3]);
 }
 
-static INLINE void sad16x_4d(const uint8_t *src_ptr, int src_stride,
-                             const uint8_t *const ref_array[4], int ref_stride,
-                             uint32_t sad_array[4], const int height) {
-  int i;
-  const uint8_t *ref_loop[4] = { ref_array[0], ref_array[1], ref_array[2],
-                                 ref_array[3] };
+static INLINE void sad16xhx4d_neon(const uint8_t *src, int src_stride,
+                                   const uint8_t *const ref[4], int ref_stride,
+                                   uint32_t res[4], int h) {
   uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                         vdupq_n_u16(0) };
 
-  for (i = 0; i < height; ++i) {
-    const uint8x16_t s = vld1q_u8(src_ptr);
-    src_ptr += src_stride;
-    /* Manual unrolling here stops the compiler from getting confused. */
-    sad16_neon(ref_loop[0], s, &sum[0]);
-    ref_loop[0] += ref_stride;
-    sad16_neon(ref_loop[1], s, &sum[1]);
-    ref_loop[1] += ref_stride;
-    sad16_neon(ref_loop[2], s, &sum[2]);
-    ref_loop[2] += ref_stride;
-    sad16_neon(ref_loop[3], s, &sum[3]);
-    ref_loop[3] += ref_stride;
-  }
-
-  sad_512_pel_final_neon(sum, sad_array);
+  int i = 0;
+  do {
+    const uint8x16_t s = vld1q_u8(src + i * src_stride);
+    sad16_neon(s, vld1q_u8(ref[0] + i * ref_stride), &sum[0]);
+    sad16_neon(s, vld1q_u8(ref[1] + i * ref_stride), &sum[1]);
+    sad16_neon(s, vld1q_u8(ref[2] + i * ref_stride), &sum[2]);
+    sad16_neon(s, vld1q_u8(ref[3] + i * ref_stride), &sum[3]);
+
+    i++;
+  } while (i < h);
+
+  res[0] = horizontal_add_uint16x8(sum[0]);
+  res[1] = horizontal_add_uint16x8(sum[1]);
+  res[2] = horizontal_add_uint16x8(sum[2]);
+  res[3] = horizontal_add_uint16x8(sum[3]);
 }
 
 #endif  // defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
 
-void vpx_sad16x8x4d_neon(const uint8_t *src_ptr, int src_stride,
-                         const uint8_t *const ref_array[4], int ref_stride,
-                         uint32_t sad_array[4]) {
-  sad16x_4d(src_ptr, src_stride, ref_array, ref_stride, sad_array, 8);
-}
-
-void vpx_sad16x16x4d_neon(const uint8_t *src_ptr, int src_stride,
-                          const uint8_t *const ref_array[4], int ref_stride,
-                          uint32_t sad_array[4]) {
-  sad16x_4d(src_ptr, src_stride, ref_array, ref_stride, sad_array, 16);
-}
-
-void vpx_sad16x32x4d_neon(const uint8_t *src_ptr, int src_stride,
-                          const uint8_t *const ref_array[4], int ref_stride,
-                          uint32_t sad_array[4]) {
-  sad16x_4d(src_ptr, src_stride, ref_array, ref_stride, sad_array, 32);
-}
-
-////////////////////////////////////////////////////////////////////////////////
-
-#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
-
-static INLINE void sad32x_4d(const uint8_t *src_ptr, int src_stride,
-                             const uint8_t *const ref_array[4], int ref_stride,
-                             uint32_t sad_array[4], const int height) {
-  int i;
-  uint32x4_t r0, r1;
-  const uint8_t *ref_loop[4] = { ref_array[0], ref_array[1], ref_array[2],
-                                 ref_array[3] };
-
-  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
-                        vdupq_n_u32(0) };
-
-  for (i = 0; i < height; ++i) {
-    uint8x16_t s;
-
-    s = vld1q_u8(src_ptr + 0 * 16);
-    sad16_neon(ref_loop[0] + 0 * 16, s, &sum[0]);
-    sad16_neon(ref_loop[1] + 0 * 16, s, &sum[1]);
-    sad16_neon(ref_loop[2] + 0 * 16, s, &sum[2]);
-    sad16_neon(ref_loop[3] + 0 * 16, s, &sum[3]);
-
-    s = vld1q_u8(src_ptr + 1 * 16);
-    sad16_neon(ref_loop[0] + 1 * 16, s, &sum[0]);
-    sad16_neon(ref_loop[1] + 1 * 16, s, &sum[1]);
-    sad16_neon(ref_loop[2] + 1 * 16, s, &sum[2]);
-    sad16_neon(ref_loop[3] + 1 * 16, s, &sum[3]);
-
-    src_ptr += src_stride;
-    ref_loop[0] += ref_stride;
-    ref_loop[1] += ref_stride;
-    ref_loop[2] += ref_stride;
-    ref_loop[3] += ref_stride;
-  }
-
-  r0 = vpaddq_u32(sum[0], sum[1]);
-  r1 = vpaddq_u32(sum[2], sum[3]);
-  vst1q_u32(sad_array, vpaddq_u32(r0, r1));
-}
-
-void vpx_sad32x16x4d_neon(const uint8_t *src_ptr, int src_stride,
-                          const uint8_t *const ref_array[4], int ref_stride,
-                          uint32_t sad_array[4]) {
-  sad32x_4d(src_ptr, src_stride, ref_array, ref_stride, sad_array, 16);
-}
-
-void vpx_sad32x32x4d_neon(const uint8_t *src_ptr, int src_stride,
-                          const uint8_t *const ref_array[4], int ref_stride,
-                          uint32_t sad_array[4]) {
-  sad32x_4d(src_ptr, src_stride, ref_array, ref_stride, sad_array, 32);
-}
-
-void vpx_sad32x64x4d_neon(const uint8_t *src_ptr, int src_stride,
-                          const uint8_t *const ref_array[4], int ref_stride,
-                          uint32_t sad_array[4]) {
-  sad32x_4d(src_ptr, src_stride, ref_array, ref_stride, sad_array, 64);
-}
-
-#else  // !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
-
-static INLINE void sad32x_4d(const uint8_t *src_ptr, int src_stride,
-                             const uint8_t *const ref_array[4], int ref_stride,
-                             const int height, uint16x8_t *const sum) {
-  int i;
-  const uint8_t *ref_loop[4] = { ref_array[0], ref_array[1], ref_array[2],
-                                 ref_array[3] };
-
-  sum[0] = sum[1] = sum[2] = sum[3] = vdupq_n_u16(0);
-
-  for (i = 0; i < height; ++i) {
-    uint8x16_t s;
-
-    s = vld1q_u8(src_ptr + 0 * 16);
-    sad16_neon(ref_loop[0] + 0 * 16, s, &sum[0]);
-    sad16_neon(ref_loop[1] + 0 * 16, s, &sum[1]);
-    sad16_neon(ref_loop[2] + 0 * 16, s, &sum[2]);
-    sad16_neon(ref_loop[3] + 0 * 16, s, &sum[3]);
-
-    s = vld1q_u8(src_ptr + 1 * 16);
-    sad16_neon(ref_loop[0] + 1 * 16, s, &sum[0]);
-    sad16_neon(ref_loop[1] + 1 * 16, s, &sum[1]);
-    sad16_neon(ref_loop[2] + 1 * 16, s, &sum[2]);
-    sad16_neon(ref_loop[3] + 1 * 16, s, &sum[3]);
-
-    src_ptr += src_stride;
-    ref_loop[0] += ref_stride;
-    ref_loop[1] += ref_stride;
-    ref_loop[2] += ref_stride;
-    ref_loop[3] += ref_stride;
-  }
+static INLINE void sad8_neon(uint8x8_t src, uint8x8_t ref,
+                             uint16x8_t *const sad_sum) {
+  uint8x8_t abs_diff = vabd_u8(src, ref);
+  *sad_sum = vaddw_u8(*sad_sum, abs_diff);
 }
 
-void vpx_sad32x16x4d_neon(const uint8_t *src_ptr, int src_stride,
-                          const uint8_t *const ref_array[4], int ref_stride,
-                          uint32_t sad_array[4]) {
-  uint16x8_t sum[4];
-  sad32x_4d(src_ptr, src_stride, ref_array, ref_stride, 16, sum);
-  sad_512_pel_final_neon(sum, sad_array);
-}
-
-void vpx_sad32x32x4d_neon(const uint8_t *src_ptr, int src_stride,
-                          const uint8_t *const ref_array[4], int ref_stride,
-                          uint32_t sad_array[4]) {
-  uint16x8_t sum[4];
-  sad32x_4d(src_ptr, src_stride, ref_array, ref_stride, 32, sum);
-  sad_1024_pel_final_neon(sum, sad_array);
-}
+static INLINE void sad8xhx4d_neon(const uint8_t *src, int src_stride,
+                                  const uint8_t *const ref[4], int ref_stride,
+                                  uint32_t res[4], int h) {
+  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                        vdupq_n_u16(0) };
 
-void vpx_sad32x64x4d_neon(const uint8_t *src_ptr, int src_stride,
-                          const uint8_t *const ref_array[4], int ref_stride,
-                          uint32_t sad_array[4]) {
-  uint16x8_t sum[4];
-  sad32x_4d(src_ptr, src_stride, ref_array, ref_stride, 64, sum);
-  sad_2048_pel_final_neon(sum, sad_array);
+  int i = 0;
+  do {
+    const uint8x8_t s = vld1_u8(src + i * src_stride);
+    sad8_neon(s, vld1_u8(ref[0] + i * ref_stride), &sum[0]);
+    sad8_neon(s, vld1_u8(ref[1] + i * ref_stride), &sum[1]);
+    sad8_neon(s, vld1_u8(ref[2] + i * ref_stride), &sum[2]);
+    sad8_neon(s, vld1_u8(ref[3] + i * ref_stride), &sum[3]);
+
+    i++;
+  } while (i < h);
+
+  res[0] = horizontal_add_uint16x8(sum[0]);
+  res[1] = horizontal_add_uint16x8(sum[1]);
+  res[2] = horizontal_add_uint16x8(sum[2]);
+  res[3] = horizontal_add_uint16x8(sum[3]);
 }
 
-#endif  // defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
-
-////////////////////////////////////////////////////////////////////////////////
-
-#if defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
-
-void vpx_sad64x32x4d_neon(const uint8_t *src_ptr, int src_stride,
-                          const uint8_t *const ref_array[4], int ref_stride,
-                          uint32_t sad_array[4]) {
-  int i;
-  uint32x4_t r0, r1;
-  const uint8_t *ref_loop[4] = { ref_array[0], ref_array[1], ref_array[2],
-                                 ref_array[3] };
-  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
-                        vdupq_n_u32(0) };
-
-  for (i = 0; i < 32; ++i) {
-    uint8x16_t s;
-
-    s = vld1q_u8(src_ptr + 0 * 16);
-    sad16_neon(ref_loop[0] + 0 * 16, s, &sum[0]);
-    sad16_neon(ref_loop[1] + 0 * 16, s, &sum[1]);
-    sad16_neon(ref_loop[2] + 0 * 16, s, &sum[2]);
-    sad16_neon(ref_loop[3] + 0 * 16, s, &sum[3]);
-
-    s = vld1q_u8(src_ptr + 1 * 16);
-    sad16_neon(ref_loop[0] + 1 * 16, s, &sum[0]);
-    sad16_neon(ref_loop[1] + 1 * 16, s, &sum[1]);
-    sad16_neon(ref_loop[2] + 1 * 16, s, &sum[2]);
-    sad16_neon(ref_loop[3] + 1 * 16, s, &sum[3]);
-
-    s = vld1q_u8(src_ptr + 2 * 16);
-    sad16_neon(ref_loop[0] + 2 * 16, s, &sum[0]);
-    sad16_neon(ref_loop[1] + 2 * 16, s, &sum[1]);
-    sad16_neon(ref_loop[2] + 2 * 16, s, &sum[2]);
-    sad16_neon(ref_loop[3] + 2 * 16, s, &sum[3]);
-
-    s = vld1q_u8(src_ptr + 3 * 16);
-    sad16_neon(ref_loop[0] + 3 * 16, s, &sum[0]);
-    sad16_neon(ref_loop[1] + 3 * 16, s, &sum[1]);
-    sad16_neon(ref_loop[2] + 3 * 16, s, &sum[2]);
-    sad16_neon(ref_loop[3] + 3 * 16, s, &sum[3]);
-
-    src_ptr += src_stride;
-    ref_loop[0] += ref_stride;
-    ref_loop[1] += ref_stride;
-    ref_loop[2] += ref_stride;
-    ref_loop[3] += ref_stride;
-  }
+static INLINE void sad4xhx4d_neon(const uint8_t *src, int src_stride,
+                                  const uint8_t *const ref[4], int ref_stride,
+                                  uint32_t res[4], int h) {
+  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                        vdupq_n_u16(0) };
 
-  r0 = vpaddq_u32(sum[0], sum[1]);
-  r1 = vpaddq_u32(sum[2], sum[3]);
-  vst1q_u32(sad_array, vpaddq_u32(r0, r1));
+  int i = 0;
+  do {
+    uint32x2_t s, r0, r1, r2, r3;
+    uint32_t s_lo, s_hi, r0_lo, r0_hi, r1_lo, r1_hi, r2_lo, r2_hi, r3_lo, r3_hi;
+
+    memcpy(&s_lo, src + i * src_stride, 4);
+    memcpy(&r0_lo, ref[0] + i * ref_stride, 4);
+    memcpy(&r1_lo, ref[1] + i * ref_stride, 4);
+    memcpy(&r2_lo, ref[2] + i * ref_stride, 4);
+    memcpy(&r3_lo, ref[3] + i * ref_stride, 4);
+    s = vdup_n_u32(s_lo);
+    r0 = vdup_n_u32(r0_lo);
+    r1 = vdup_n_u32(r1_lo);
+    r2 = vdup_n_u32(r2_lo);
+    r3 = vdup_n_u32(r3_lo);
+
+    memcpy(&s_hi, src + (i + 1) * src_stride, 4);
+    memcpy(&r0_hi, ref[0] + (i + 1) * ref_stride, 4);
+    memcpy(&r1_hi, ref[1] + (i + 1) * ref_stride, 4);
+    memcpy(&r2_hi, ref[2] + (i + 1) * ref_stride, 4);
+    memcpy(&r3_hi, ref[3] + (i + 1) * ref_stride, 4);
+    s = vset_lane_u32(s_hi, s, 1);
+    r0 = vset_lane_u32(r0_hi, r0, 1);
+    r1 = vset_lane_u32(r1_hi, r1, 1);
+    r2 = vset_lane_u32(r2_hi, r2, 1);
+    r3 = vset_lane_u32(r3_hi, r3, 1);
+
+    sad8_neon(vreinterpret_u8_u32(s), vreinterpret_u8_u32(r0), &sum[0]);
+    sad8_neon(vreinterpret_u8_u32(s), vreinterpret_u8_u32(r1), &sum[1]);
+    sad8_neon(vreinterpret_u8_u32(s), vreinterpret_u8_u32(r2), &sum[2]);
+    sad8_neon(vreinterpret_u8_u32(s), vreinterpret_u8_u32(r3), &sum[3]);
+
+    i += 2;
+  } while (i < h);
+
+  res[0] = horizontal_add_uint16x8(sum[0]);
+  res[1] = horizontal_add_uint16x8(sum[1]);
+  res[2] = horizontal_add_uint16x8(sum[2]);
+  res[3] = horizontal_add_uint16x8(sum[3]);
 }
 
-void vpx_sad64x64x4d_neon(const uint8_t *src_ptr, int src_stride,
-                          const uint8_t *const ref_array[4], int ref_stride,
-                          uint32_t sad_array[4]) {
-  int i;
-  uint32x4_t r0, r1, r2, r3;
-  const uint8_t *ref_loop[4] = { ref_array[0], ref_array[1], ref_array[2],
-                                 ref_array[3] };
-  uint32x4_t sum[8] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
-                        vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
-                        vdupq_n_u32(0), vdupq_n_u32(0) };
-
-  for (i = 0; i < 64; ++i) {
-    uint8x16_t s;
-
-    s = vld1q_u8(src_ptr + 0 * 16);
-    sad16_neon(ref_loop[0] + 0 * 16, s, &sum[0]);
-    sad16_neon(ref_loop[1] + 0 * 16, s, &sum[2]);
-    sad16_neon(ref_loop[2] + 0 * 16, s, &sum[4]);
-    sad16_neon(ref_loop[3] + 0 * 16, s, &sum[6]);
-
-    s = vld1q_u8(src_ptr + 1 * 16);
-    sad16_neon(ref_loop[0] + 1 * 16, s, &sum[0]);
-    sad16_neon(ref_loop[1] + 1 * 16, s, &sum[2]);
-    sad16_neon(ref_loop[2] + 1 * 16, s, &sum[4]);
-    sad16_neon(ref_loop[3] + 1 * 16, s, &sum[6]);
-
-    s = vld1q_u8(src_ptr + 2 * 16);
-    sad16_neon(ref_loop[0] + 2 * 16, s, &sum[1]);
-    sad16_neon(ref_loop[1] + 2 * 16, s, &sum[3]);
-    sad16_neon(ref_loop[2] + 2 * 16, s, &sum[5]);
-    sad16_neon(ref_loop[3] + 2 * 16, s, &sum[7]);
-
-    s = vld1q_u8(src_ptr + 3 * 16);
-    sad16_neon(ref_loop[0] + 3 * 16, s, &sum[1]);
-    sad16_neon(ref_loop[1] + 3 * 16, s, &sum[3]);
-    sad16_neon(ref_loop[2] + 3 * 16, s, &sum[5]);
-    sad16_neon(ref_loop[3] + 3 * 16, s, &sum[7]);
-
-    src_ptr += src_stride;
-    ref_loop[0] += ref_stride;
-    ref_loop[1] += ref_stride;
-    ref_loop[2] += ref_stride;
-    ref_loop[3] += ref_stride;
+#define SAD_WXH_4D_NEON(w, h)                                                  \
+  void vpx_sad##w##x##h##x4d_neon(const uint8_t *src, int src_stride,          \
+                                  const uint8_t *const ref[4], int ref_stride, \
+                                  uint32_t res[4]) {                           \
+    sad##w##xhx4d_neon(src, src_stride, ref, ref_stride, res, (h));            \
   }
 
-  r0 = vpaddq_u32(sum[0], sum[1]);
-  r1 = vpaddq_u32(sum[2], sum[3]);
-  r2 = vpaddq_u32(sum[4], sum[5]);
-  r3 = vpaddq_u32(sum[6], sum[7]);
-  r0 = vpaddq_u32(r0, r1);
-  r1 = vpaddq_u32(r2, r3);
-  vst1q_u32(sad_array, vpaddq_u32(r0, r1));
-}
+SAD_WXH_4D_NEON(4, 4)
+SAD_WXH_4D_NEON(4, 8)
 
-#else  // !(defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD))
+SAD_WXH_4D_NEON(8, 4)
+SAD_WXH_4D_NEON(8, 8)
+SAD_WXH_4D_NEON(8, 16)
 
-void vpx_sad64x32x4d_neon(const uint8_t *src_ptr, int src_stride,
-                          const uint8_t *const ref_array[4], int ref_stride,
-                          uint32_t sad_array[4]) {
-  int i;
-  const uint8_t *ref_loop[4] = { ref_array[0], ref_array[1], ref_array[2],
-                                 ref_array[3] };
-  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
-                        vdupq_n_u16(0) };
+SAD_WXH_4D_NEON(16, 8)
+SAD_WXH_4D_NEON(16, 16)
+SAD_WXH_4D_NEON(16, 32)
 
-  for (i = 0; i < 32; ++i) {
-    uint8x16_t s;
-
-    s = vld1q_u8(src_ptr + 0 * 16);
-    sad16_neon(ref_loop[0] + 0 * 16, s, &sum[0]);
-    sad16_neon(ref_loop[1] + 0 * 16, s, &sum[1]);
-    sad16_neon(ref_loop[2] + 0 * 16, s, &sum[2]);
-    sad16_neon(ref_loop[3] + 0 * 16, s, &sum[3]);
-
-    s = vld1q_u8(src_ptr + 1 * 16);
-    sad16_neon(ref_loop[0] + 1 * 16, s, &sum[0]);
-    sad16_neon(ref_loop[1] + 1 * 16, s, &sum[1]);
-    sad16_neon(ref_loop[2] + 1 * 16, s, &sum[2]);
-    sad16_neon(ref_loop[3] + 1 * 16, s, &sum[3]);
-
-    s = vld1q_u8(src_ptr + 2 * 16);
-    sad16_neon(ref_loop[0] + 2 * 16, s, &sum[0]);
-    sad16_neon(ref_loop[1] + 2 * 16, s, &sum[1]);
-    sad16_neon(ref_loop[2] + 2 * 16, s, &sum[2]);
-    sad16_neon(ref_loop[3] + 2 * 16, s, &sum[3]);
-
-    s = vld1q_u8(src_ptr + 3 * 16);
-    sad16_neon(ref_loop[0] + 3 * 16, s, &sum[0]);
-    sad16_neon(ref_loop[1] + 3 * 16, s, &sum[1]);
-    sad16_neon(ref_loop[2] + 3 * 16, s, &sum[2]);
-    sad16_neon(ref_loop[3] + 3 * 16, s, &sum[3]);
-
-    src_ptr += src_stride;
-    ref_loop[0] += ref_stride;
-    ref_loop[1] += ref_stride;
-    ref_loop[2] += ref_stride;
-    ref_loop[3] += ref_stride;
-  }
+SAD_WXH_4D_NEON(32, 16)
+SAD_WXH_4D_NEON(32, 32)
+SAD_WXH_4D_NEON(32, 64)
 
-  sad_2048_pel_final_neon(sum, sad_array);
-}
+SAD_WXH_4D_NEON(64, 32)
+SAD_WXH_4D_NEON(64, 64)
 
-void vpx_sad64x64x4d_neon(const uint8_t *src_ptr, int src_stride,
-                          const uint8_t *const ref_array[4], int ref_stride,
-                          uint32_t sad_array[4]) {
-  int i;
-  const uint8_t *ref_loop[4] = { ref_array[0], ref_array[1], ref_array[2],
-                                 ref_array[3] };
-  uint16x8_t sum[8] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
-                        vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
-                        vdupq_n_u16(0), vdupq_n_u16(0) };
-
-  for (i = 0; i < 64; ++i) {
-    uint8x16_t s;
-
-    s = vld1q_u8(src_ptr + 0 * 16);
-    sad16_neon(ref_loop[0] + 0 * 16, s, &sum[0]);
-    sad16_neon(ref_loop[1] + 0 * 16, s, &sum[2]);
-    sad16_neon(ref_loop[2] + 0 * 16, s, &sum[4]);
-    sad16_neon(ref_loop[3] + 0 * 16, s, &sum[6]);
-
-    s = vld1q_u8(src_ptr + 1 * 16);
-    sad16_neon(ref_loop[0] + 1 * 16, s, &sum[0]);
-    sad16_neon(ref_loop[1] + 1 * 16, s, &sum[2]);
-    sad16_neon(ref_loop[2] + 1 * 16, s, &sum[4]);
-    sad16_neon(ref_loop[3] + 1 * 16, s, &sum[6]);
-
-    s = vld1q_u8(src_ptr + 2 * 16);
-    sad16_neon(ref_loop[0] + 2 * 16, s, &sum[1]);
-    sad16_neon(ref_loop[1] + 2 * 16, s, &sum[3]);
-    sad16_neon(ref_loop[2] + 2 * 16, s, &sum[5]);
-    sad16_neon(ref_loop[3] + 2 * 16, s, &sum[7]);
-
-    s = vld1q_u8(src_ptr + 3 * 16);
-    sad16_neon(ref_loop[0] + 3 * 16, s, &sum[1]);
-    sad16_neon(ref_loop[1] + 3 * 16, s, &sum[3]);
-    sad16_neon(ref_loop[2] + 3 * 16, s, &sum[5]);
-    sad16_neon(ref_loop[3] + 3 * 16, s, &sum[7]);
-
-    src_ptr += src_stride;
-    ref_loop[0] += ref_stride;
-    ref_loop[1] += ref_stride;
-    ref_loop[2] += ref_stride;
-    ref_loop[3] += ref_stride;
-  }
-
-  sad_4096_pel_final_neon(sum, sad_array);
-}
-
-#endif  // defined(__aarch64__) && defined(__ARM_FEATURE_DOTPROD)
+#undef SAD_WXH_4D_NEON
index 9a7c424e8e4af28d3773d07dae5bb218373bd47b..5f20f9d99a49e214b7a897f44250eaa467359a5b 100644 (file)
@@ -40,6 +40,23 @@ static INLINE uint32_t horizontal_add_uint16x8(const uint16x8_t a) {
 #endif
 }
 
+static INLINE uint32_t horizontal_long_add_uint16x8(const uint16x8_t vec_lo,
+                                                    const uint16x8_t vec_hi) {
+#if defined(__aarch64__)
+  return vaddlvq_u16(vec_lo) + vaddlvq_u16(vec_hi);
+#else
+  const uint32x4_t vec_l_lo =
+      vaddl_u16(vget_low_u16(vec_lo), vget_high_u16(vec_lo));
+  const uint32x4_t vec_l_hi =
+      vaddl_u16(vget_low_u16(vec_hi), vget_high_u16(vec_hi));
+  const uint32x4_t a = vaddq_u32(vec_l_lo, vec_l_hi);
+  const uint64x2_t b = vpaddlq_u32(a);
+  const uint32x2_t c = vadd_u32(vreinterpret_u32_u64(vget_low_u64(b)),
+                                vreinterpret_u32_u64(vget_high_u64(b)));
+  return vget_lane_u32(c, 0);
+#endif
+}
+
 static INLINE int32_t horizontal_add_int32x2(const int32x2_t a) {
 #if defined(__aarch64__)
   return vaddv_s32(a);