]> granicus.if.org Git - libvpx/commitdiff
Refactor Neon implementation of SAD functions
authorSalome Thirot <salome.thirot@arm.com>
Tue, 24 Jan 2023 14:27:14 +0000 (14:27 +0000)
committerSalome Thirot <salome.thirot@arm.com>
Wed, 25 Jan 2023 15:35:51 +0000 (15:35 +0000)
Refactor and optimize the Neon implementation of SAD functions -
effectively backporting these libaom changes[1,2,3].

[1] https://aomedia-review.googlesource.com/c/aom/+/161921
[2] https://aomedia-review.googlesource.com/c/aom/+/161923
[3] https://aomedia-review.googlesource.com/c/aom/+/166963

Change-Id: I2d72fd0f27d61a3e31a78acd33172e2afb044cb8

vpx_dsp/arm/sad_neon.c

index ad575d4aaef823e9b5b5bc6684b67a71990dbe47..7336edb69403a6375ac4ef1236afe3738e4d0b1f 100644 (file)
 #include "vpx_dsp/arm/mem_neon.h"
 #include "vpx_dsp/arm/sum_neon.h"
 
-uint32_t vpx_sad4x4_neon(const uint8_t *src_ptr, int src_stride,
-                         const uint8_t *ref_ptr, int ref_stride) {
-  const uint8x16_t src_u8 = load_unaligned_u8q(src_ptr, src_stride);
-  const uint8x16_t ref_u8 = load_unaligned_u8q(ref_ptr, ref_stride);
 #if defined(__ARM_FEATURE_DOTPROD)
-  const uint8x16_t sad_u8 = vabdq_u8(src_u8, ref_u8);
-  const uint32x4_t dp = vdotq_u32(vdupq_n_u32(0), sad_u8, vdupq_n_u8(1));
-  return horizontal_add_uint32x4(dp);
-#else
-  uint16x8_t abs = vabdl_u8(vget_low_u8(src_u8), vget_low_u8(ref_u8));
-  abs = vabal_u8(abs, vget_high_u8(src_u8), vget_high_u8(ref_u8));
-  return horizontal_add_uint16x8(abs);
-#endif
-}
 
-uint32_t vpx_sad4x4_avg_neon(const uint8_t *src_ptr, int src_stride,
-                             const uint8_t *ref_ptr, int ref_stride,
-                             const uint8_t *second_pred) {
-  const uint8x16_t src_u8 = load_unaligned_u8q(src_ptr, src_stride);
-  const uint8x16_t ref_u8 = load_unaligned_u8q(ref_ptr, ref_stride);
-  const uint8x16_t second_pred_u8 = vld1q_u8(second_pred);
-  const uint8x16_t avg = vrhaddq_u8(ref_u8, second_pred_u8);
-#if defined(__ARM_FEATURE_DOTPROD)
-  const uint8x16_t sad_u8 = vabdq_u8(src_u8, avg);
-  const uint32x4_t prod = vdotq_u32(vdupq_n_u32(0), sad_u8, vdupq_n_u8(1));
-  return horizontal_add_uint32x4(prod);
-#else
-  uint16x8_t abs = vabdl_u8(vget_low_u8(src_u8), vget_low_u8(avg));
-  abs = vabal_u8(abs, vget_high_u8(src_u8), vget_high_u8(avg));
-  return horizontal_add_uint16x8(abs);
-#endif
-}
+static INLINE unsigned int sadwxh_neon(const uint8_t *src_ptr, int src_stride,
+                                       const uint8_t *ref_ptr, int ref_stride,
+                                       int w, int h) {
+  // Only two accumulators are required for optimal instruction throughput of
+  // the ABD, UDOT sequence on CPUs with either 2 or 4 Neon pipes.
+  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
 
-uint32_t vpx_sad4x8_neon(const uint8_t *src_ptr, int src_stride,
-                         const uint8_t *ref_ptr, int ref_stride) {
-#if defined(__ARM_FEATURE_DOTPROD)
-  uint32x4_t prod = vdupq_n_u32(0);
-  const uint8x16_t ones = vdupq_n_u8(1);
-  const uint8x16_t src1_u8 = load_unaligned_u8q(src_ptr, src_stride);
-  const uint8x16_t ref1_u8 = load_unaligned_u8q(ref_ptr, ref_stride);
-  const uint8x16_t src2_u8 =
-      load_unaligned_u8q(src_ptr + 4 * src_stride, src_stride);
-  const uint8x16_t ref2_u8 =
-      load_unaligned_u8q(ref_ptr + 4 * ref_stride, ref_stride);
-  const uint8x16_t sad1_u8 = vabdq_u8(src1_u8, ref1_u8);
-  const uint8x16_t sad2_u8 = vabdq_u8(src2_u8, ref2_u8);
-  prod = vdotq_u32(prod, sad1_u8, ones);
-  prod = vdotq_u32(prod, sad2_u8, ones);
-  return horizontal_add_uint32x4(prod);
-#else
-  int i;
-  uint16x8_t abs = vdupq_n_u16(0);
-  for (i = 0; i < 8; i += 4) {
-    const uint8x16_t src_u8 = load_unaligned_u8q(src_ptr, src_stride);
-    const uint8x16_t ref_u8 = load_unaligned_u8q(ref_ptr, ref_stride);
-    src_ptr += 4 * src_stride;
-    ref_ptr += 4 * ref_stride;
-    abs = vabal_u8(abs, vget_low_u8(src_u8), vget_low_u8(ref_u8));
-    abs = vabal_u8(abs, vget_high_u8(src_u8), vget_high_u8(ref_u8));
-  }
+  int i = h;
+  do {
+    int j = 0;
+    do {
+      uint8x16_t s0, s1, r0, r1, diff0, diff1;
 
-  return horizontal_add_uint16x8(abs);
-#endif
+      s0 = vld1q_u8(src_ptr + j);
+      r0 = vld1q_u8(ref_ptr + j);
+      diff0 = vabdq_u8(s0, r0);
+      sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));
+
+      s1 = vld1q_u8(src_ptr + j + 16);
+      r1 = vld1q_u8(ref_ptr + j + 16);
+      diff1 = vabdq_u8(s1, r1);
+      sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));
+
+      j += 32;
+    } while (j < w);
+
+    src_ptr += src_stride;
+    ref_ptr += ref_stride;
+  } while (--i != 0);
+
+  return horizontal_add_uint32x4(vaddq_u32(sum[0], sum[1]));
 }
 
-uint32_t vpx_sad4x8_avg_neon(const uint8_t *src_ptr, int src_stride,
-                             const uint8_t *ref_ptr, int ref_stride,
-                             const uint8_t *second_pred) {
-#if defined(__ARM_FEATURE_DOTPROD)
-  uint32x4_t prod = vdupq_n_u32(0);
-  const uint8x16_t ones = vdupq_n_u8(1);
-  const uint8x16_t src1_u8 = load_unaligned_u8q(src_ptr, src_stride);
-  const uint8x16_t ref1_u8 = load_unaligned_u8q(ref_ptr, ref_stride);
-  const uint8x16_t src2_u8 =
-      load_unaligned_u8q(src_ptr + 4 * src_stride, src_stride);
-  const uint8x16_t ref2_u8 =
-      load_unaligned_u8q(ref_ptr + 4 * ref_stride, ref_stride);
-  const uint8x16_t second_pred1_u8 = vld1q_u8(second_pred);
-  const uint8x16_t second_pred2_u8 = vld1q_u8(second_pred + 16);
-  const uint8x16_t avg1 = vrhaddq_u8(ref1_u8, second_pred1_u8);
-  const uint8x16_t avg2 = vrhaddq_u8(ref2_u8, second_pred2_u8);
-  const uint8x16_t sad1_u8 = vabdq_u8(src1_u8, avg1);
-  const uint8x16_t sad2_u8 = vabdq_u8(src2_u8, avg2);
-  prod = vdotq_u32(prod, sad1_u8, ones);
-  prod = vdotq_u32(prod, sad2_u8, ones);
-  return horizontal_add_uint32x4(prod);
-#else
-  int i;
-  uint16x8_t abs = vdupq_n_u16(0);
-  for (i = 0; i < 8; i += 4) {
-    const uint8x16_t src_u8 = load_unaligned_u8q(src_ptr, src_stride);
-    const uint8x16_t ref_u8 = load_unaligned_u8q(ref_ptr, ref_stride);
-    const uint8x16_t second_pred_u8 = vld1q_u8(second_pred);
-    const uint8x16_t avg = vrhaddq_u8(ref_u8, second_pred_u8);
-    src_ptr += 4 * src_stride;
-    ref_ptr += 4 * ref_stride;
-    second_pred += 16;
-    abs = vabal_u8(abs, vget_low_u8(src_u8), vget_low_u8(avg));
-    abs = vabal_u8(abs, vget_high_u8(src_u8), vget_high_u8(avg));
-  }
+static INLINE unsigned int sad64xh_neon(const uint8_t *src_ptr, int src_stride,
+                                        const uint8_t *ref_ptr, int ref_stride,
+                                        int h) {
+  return sadwxh_neon(src_ptr, src_stride, ref_ptr, ref_stride, 64, h);
+}
 
-  return horizontal_add_uint16x8(abs);
-#endif
+static INLINE unsigned int sad32xh_neon(const uint8_t *src_ptr, int src_stride,
+                                        const uint8_t *ref_ptr, int ref_stride,
+                                        int h) {
+  return sadwxh_neon(src_ptr, src_stride, ref_ptr, ref_stride, 32, h);
 }
 
-#if defined(__ARM_FEATURE_DOTPROD)
-static INLINE uint32x2_t sad8x(const uint8_t *src_ptr, int src_stride,
-                               const uint8_t *ref_ptr, int ref_stride,
-                               const int height) {
-  int i;
-  uint32x2_t prod = vdup_n_u32(0);
-  const uint8x8_t ones = vdup_n_u8(1);
-  for (i = 0; i < height; ++i) {
-    const uint8x8_t a_u8 = vld1_u8(src_ptr);
-    const uint8x8_t b_u8 = vld1_u8(ref_ptr);
-    const uint8x8_t sad_u8 = vabd_u8(a_u8, b_u8);
+static INLINE unsigned int sad16xh_neon(const uint8_t *src_ptr, int src_stride,
+                                        const uint8_t *ref_ptr, int ref_stride,
+                                        int h) {
+  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int i = h / 2;
+  do {
+    uint8x16_t s0, s1, r0, r1, diff0, diff1;
+
+    s0 = vld1q_u8(src_ptr);
+    r0 = vld1q_u8(ref_ptr);
+    diff0 = vabdq_u8(s0, r0);
+    sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));
+
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    prod = vdot_u32(prod, sad_u8, ones);
-  }
-  return prod;
-}
 
-static INLINE uint32x2_t sad8x_avg(const uint8_t *src_ptr, int src_stride,
-                                   const uint8_t *ref_ptr, int ref_stride,
-                                   const uint8_t *second_pred,
-                                   const int height) {
-  int i;
-  uint32x2_t prod = vdup_n_u32(0);
-  const uint8x8_t ones = vdup_n_u8(1);
-  for (i = 0; i < height; ++i) {
-    const uint8x8_t a_u8 = vld1_u8(src_ptr);
-    const uint8x8_t b_u8 = vld1_u8(ref_ptr);
-    const uint8x8_t c_u8 = vld1_u8(second_pred);
-    const uint8x8_t avg = vrhadd_u8(b_u8, c_u8);
-    const uint8x8_t sad_u8 = vabd_u8(a_u8, avg);
+    s1 = vld1q_u8(src_ptr);
+    r1 = vld1q_u8(ref_ptr);
+    diff1 = vabdq_u8(s1, r1);
+    sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));
+
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    second_pred += 8;
-    prod = vdot_u32(prod, sad_u8, ones);
-  }
-  return prod;
-}
+  } while (--i != 0);
 
-#define SAD8XN(n)                                                            \
-  uint32_t vpx_sad8x##n##_neon(const uint8_t *src_ptr, int src_stride,       \
-                               const uint8_t *ref_ptr, int ref_stride) {     \
-    const uint32x2_t prod =                                                  \
-        sad8x(src_ptr, src_stride, ref_ptr, ref_stride, n);                  \
-    return horizontal_add_uint32x2(prod);                                    \
-  }                                                                          \
-                                                                             \
-  uint32_t vpx_sad8x##n##_avg_neon(const uint8_t *src_ptr, int src_stride,   \
-                                   const uint8_t *ref_ptr, int ref_stride,   \
-                                   const uint8_t *second_pred) {             \
-    const uint32x2_t prod =                                                  \
-        sad8x_avg(src_ptr, src_stride, ref_ptr, ref_stride, second_pred, n); \
-    return horizontal_add_uint32x2(prod);                                    \
-  }
+  return horizontal_add_uint32x4(vaddq_u32(sum[0], sum[1]));
+}
 
 #else  // !defined(__ARM_FEATURE_DOTPROD)
-static INLINE uint16x8_t sad8x(const uint8_t *src_ptr, int src_stride,
-                               const uint8_t *ref_ptr, int ref_stride,
-                               const int height) {
-  int i;
-  uint16x8_t abs = vdupq_n_u16(0);
-
-  for (i = 0; i < height; ++i) {
-    const uint8x8_t a_u8 = vld1_u8(src_ptr);
-    const uint8x8_t b_u8 = vld1_u8(ref_ptr);
+
+static INLINE unsigned int sad64xh_neon(const uint8_t *src_ptr, int src_stride,
+                                        const uint8_t *ref_ptr, int ref_stride,
+                                        int h) {
+  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                        vdupq_n_u16(0) };
+  uint32x4_t sum_u32;
+
+  int i = h;
+  do {
+    uint8x16_t s0, s1, s2, s3, r0, r1, r2, r3;
+    uint8x16_t diff0, diff1, diff2, diff3;
+
+    s0 = vld1q_u8(src_ptr);
+    r0 = vld1q_u8(ref_ptr);
+    diff0 = vabdq_u8(s0, r0);
+    sum[0] = vpadalq_u8(sum[0], diff0);
+
+    s1 = vld1q_u8(src_ptr + 16);
+    r1 = vld1q_u8(ref_ptr + 16);
+    diff1 = vabdq_u8(s1, r1);
+    sum[1] = vpadalq_u8(sum[1], diff1);
+
+    s2 = vld1q_u8(src_ptr + 32);
+    r2 = vld1q_u8(ref_ptr + 32);
+    diff2 = vabdq_u8(s2, r2);
+    sum[2] = vpadalq_u8(sum[2], diff2);
+
+    s3 = vld1q_u8(src_ptr + 48);
+    r3 = vld1q_u8(ref_ptr + 48);
+    diff3 = vabdq_u8(s3, r3);
+    sum[3] = vpadalq_u8(sum[3], diff3);
+
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    abs = vabal_u8(abs, a_u8, b_u8);
-  }
-  return abs;
+  } while (--i != 0);
+
+  sum_u32 = vpaddlq_u16(sum[0]);
+  sum_u32 = vpadalq_u16(sum_u32, sum[1]);
+  sum_u32 = vpadalq_u16(sum_u32, sum[2]);
+  sum_u32 = vpadalq_u16(sum_u32, sum[3]);
+
+  return horizontal_add_uint32x4(sum_u32);
 }
 
-static INLINE uint16x8_t sad8x_avg(const uint8_t *src_ptr, int src_stride,
-                                   const uint8_t *ref_ptr, int ref_stride,
-                                   const uint8_t *second_pred,
-                                   const int height) {
-  int i;
-  uint16x8_t abs = vdupq_n_u16(0);
-
-  for (i = 0; i < height; ++i) {
-    const uint8x8_t a_u8 = vld1_u8(src_ptr);
-    const uint8x8_t b_u8 = vld1_u8(ref_ptr);
-    const uint8x8_t c_u8 = vld1_u8(second_pred);
-    const uint8x8_t avg = vrhadd_u8(b_u8, c_u8);
+static INLINE unsigned int sad32xh_neon(const uint8_t *src_ptr, int src_stride,
+                                        const uint8_t *ref_ptr, int ref_stride,
+                                        int h) {
+  uint32x4_t sum = vdupq_n_u32(0);
+
+  int i = h;
+  do {
+    uint8x16_t s0 = vld1q_u8(src_ptr);
+    uint8x16_t r0 = vld1q_u8(ref_ptr);
+    uint8x16_t diff0 = vabdq_u8(s0, r0);
+    uint16x8_t sum0 = vpaddlq_u8(diff0);
+
+    uint8x16_t s1 = vld1q_u8(src_ptr + 16);
+    uint8x16_t r1 = vld1q_u8(ref_ptr + 16);
+    uint8x16_t diff1 = vabdq_u8(s1, r1);
+    uint16x8_t sum1 = vpaddlq_u8(diff1);
+
+    sum = vpadalq_u16(sum, sum0);
+    sum = vpadalq_u16(sum, sum1);
+
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    second_pred += 8;
-    abs = vabal_u8(abs, a_u8, avg);
-  }
-  return abs;
+  } while (--i != 0);
+
+  return horizontal_add_uint32x4(sum);
 }
 
-#define SAD8XN(n)                                                              \
-  uint32_t vpx_sad8x##n##_neon(const uint8_t *src_ptr, int src_stride,         \
-                               const uint8_t *ref_ptr, int ref_stride) {       \
-    const uint16x8_t abs = sad8x(src_ptr, src_stride, ref_ptr, ref_stride, n); \
-    return horizontal_add_uint16x8(abs);                                       \
-  }                                                                            \
-                                                                               \
-  uint32_t vpx_sad8x##n##_avg_neon(const uint8_t *src_ptr, int src_stride,     \
-                                   const uint8_t *ref_ptr, int ref_stride,     \
-                                   const uint8_t *second_pred) {               \
-    const uint16x8_t abs =                                                     \
-        sad8x_avg(src_ptr, src_stride, ref_ptr, ref_stride, second_pred, n);   \
-    return horizontal_add_uint16x8(abs);                                       \
-  }
-#endif  // defined(__ARM_FEATURE_DOTPROD)
+static INLINE unsigned int sad16xh_neon(const uint8_t *src_ptr, int src_stride,
+                                        const uint8_t *ref_ptr, int ref_stride,
+                                        int h) {
+  uint16x8_t sum = vdupq_n_u16(0);
 
-SAD8XN(4)
-SAD8XN(8)
-SAD8XN(16)
+  int i = h;
+  do {
+    uint8x16_t s = vld1q_u8(src_ptr);
+    uint8x16_t r = vld1q_u8(ref_ptr);
+
+    uint8x16_t diff = vabdq_u8(s, r);
+    sum = vpadalq_u8(sum, diff);
 
-#if defined(__ARM_FEATURE_DOTPROD)
-static INLINE uint32x4_t sad16x(const uint8_t *src_ptr, int src_stride,
-                                const uint8_t *ref_ptr, int ref_stride,
-                                const int height) {
-  int i;
-  uint32x4_t prod = vdupq_n_u32(0);
-  const uint8x16_t ones = vdupq_n_u8(1);
-  for (i = 0; i < height; ++i) {
-    const uint8x16_t src_u8 = vld1q_u8(src_ptr);
-    const uint8x16_t ref_u8 = vld1q_u8(ref_ptr);
-    const uint8x16_t sad_u8 = vabdq_u8(src_u8, ref_u8);
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    prod = vdotq_u32(prod, sad_u8, ones);
-  }
-  return prod;
+  } while (--i != 0);
+
+  return horizontal_add_uint16x8(sum);
 }
 
-static INLINE uint32x4_t sad16x_avg(const uint8_t *src_ptr, int src_stride,
-                                    const uint8_t *ref_ptr, int ref_stride,
-                                    const uint8_t *second_pred,
-                                    const int height) {
-  int i;
-  uint32x4_t prod = vdupq_n_u32(0);
-  const uint8x16_t ones = vdupq_n_u8(1);
-  for (i = 0; i < height; ++i) {
-    const uint8x16_t a_u8 = vld1q_u8(src_ptr);
-    const uint8x16_t b_u8 = vld1q_u8(ref_ptr);
-    const uint8x16_t c_u8 = vld1q_u8(second_pred);
-    const uint8x16_t avg = vrhaddq_u8(b_u8, c_u8);
-    const uint8x16_t sad_u8 = vabdq_u8(a_u8, avg);
+#endif  // defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE unsigned int sad8xh_neon(const uint8_t *src_ptr, int src_stride,
+                                       const uint8_t *ref_ptr, int ref_stride,
+                                       int h) {
+  uint16x8_t sum = vdupq_n_u16(0);
+
+  int i = h;
+  do {
+    uint8x8_t s = vld1_u8(src_ptr);
+    uint8x8_t r = vld1_u8(ref_ptr);
+
+    sum = vabal_u8(sum, s, r);
+
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    second_pred += 16;
-    prod = vdotq_u32(prod, sad_u8, ones);
-  }
-  return prod;
+  } while (--i != 0);
+
+  return horizontal_add_uint16x8(sum);
 }
 
-#define SAD16XN(n)                                                            \
-  uint32_t vpx_sad16x##n##_neon(const uint8_t *src_ptr, int src_stride,       \
-                                const uint8_t *ref_ptr, int ref_stride) {     \
-    const uint32x4_t prod =                                                   \
-        sad16x(src_ptr, src_stride, ref_ptr, ref_stride, n);                  \
-    return horizontal_add_uint32x4(prod);                                     \
-  }                                                                           \
-                                                                              \
-  uint32_t vpx_sad16x##n##_avg_neon(const uint8_t *src_ptr, int src_stride,   \
-                                    const uint8_t *ref_ptr, int ref_stride,   \
-                                    const uint8_t *second_pred) {             \
-    const uint32x4_t prod =                                                   \
-        sad16x_avg(src_ptr, src_stride, ref_ptr, ref_stride, second_pred, n); \
-    return horizontal_add_uint32x4(prod);                                     \
-  }
-#else  // !defined(__ARM_FEATURE_DOTPROD)
-static INLINE uint16x8_t sad16x(const uint8_t *src_ptr, int src_stride,
-                                const uint8_t *ref_ptr, int ref_stride,
-                                const int height) {
-  int i;
-  uint16x8_t abs = vdupq_n_u16(0);
-
-  for (i = 0; i < height; ++i) {
-    const uint8x16_t a_u8 = vld1q_u8(src_ptr);
-    const uint8x16_t b_u8 = vld1q_u8(ref_ptr);
+static INLINE unsigned int sad4xh_neon(const uint8_t *src_ptr, int src_stride,
+                                       const uint8_t *ref_ptr, int ref_stride,
+                                       int h) {
+  uint16x8_t sum = vdupq_n_u16(0);
+
+  int i = h / 2;
+  do {
+    uint32x2_t s, r;
+    uint32_t s0, s1, r0, r1;
+
+    memcpy(&s0, src_ptr, 4);
+    memcpy(&r0, ref_ptr, 4);
+    s = vdup_n_u32(s0);
+    r = vdup_n_u32(r0);
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    abs = vabal_u8(abs, vget_low_u8(a_u8), vget_low_u8(b_u8));
-    abs = vabal_u8(abs, vget_high_u8(a_u8), vget_high_u8(b_u8));
-  }
-  return abs;
-}
 
-static INLINE uint16x8_t sad16x_avg(const uint8_t *src_ptr, int src_stride,
-                                    const uint8_t *ref_ptr, int ref_stride,
-                                    const uint8_t *second_pred,
-                                    const int height) {
-  int i;
-  uint16x8_t abs = vdupq_n_u16(0);
-
-  for (i = 0; i < height; ++i) {
-    const uint8x16_t a_u8 = vld1q_u8(src_ptr);
-    const uint8x16_t b_u8 = vld1q_u8(ref_ptr);
-    const uint8x16_t c_u8 = vld1q_u8(second_pred);
-    const uint8x16_t avg = vrhaddq_u8(b_u8, c_u8);
+    memcpy(&s1, src_ptr, 4);
+    memcpy(&r1, ref_ptr, 4);
+    s = vset_lane_u32(s1, s, 1);
+    r = vset_lane_u32(r1, r, 1);
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    second_pred += 16;
-    abs = vabal_u8(abs, vget_low_u8(a_u8), vget_low_u8(avg));
-    abs = vabal_u8(abs, vget_high_u8(a_u8), vget_high_u8(avg));
-  }
-  return abs;
+
+    sum = vabal_u8(sum, vreinterpret_u8_u32(s), vreinterpret_u8_u32(r));
+  } while (--i != 0);
+
+  return horizontal_add_uint16x8(sum);
 }
 
-#define SAD16XN(n)                                                            \
-  uint32_t vpx_sad16x##n##_neon(const uint8_t *src_ptr, int src_stride,       \
-                                const uint8_t *ref_ptr, int ref_stride) {     \
-    const uint16x8_t abs =                                                    \
-        sad16x(src_ptr, src_stride, ref_ptr, ref_stride, n);                  \
-    return horizontal_add_uint16x8(abs);                                      \
-  }                                                                           \
-                                                                              \
-  uint32_t vpx_sad16x##n##_avg_neon(const uint8_t *src_ptr, int src_stride,   \
-                                    const uint8_t *ref_ptr, int ref_stride,   \
-                                    const uint8_t *second_pred) {             \
-    const uint16x8_t abs =                                                    \
-        sad16x_avg(src_ptr, src_stride, ref_ptr, ref_stride, second_pred, n); \
-    return horizontal_add_uint16x8(abs);                                      \
+#define SAD_WXH_NEON(w, h)                                                   \
+  unsigned int vpx_sad##w##x##h##_neon(const uint8_t *src, int src_stride,   \
+                                       const uint8_t *ref, int ref_stride) { \
+    return sad##w##xh_neon(src, src_stride, ref, ref_stride, (h));           \
   }
-#endif  // defined(__ARM_FEATURE_DOTPROD)
 
-SAD16XN(8)
-SAD16XN(16)
-SAD16XN(32)
+SAD_WXH_NEON(4, 4)
+SAD_WXH_NEON(4, 8)
+
+SAD_WXH_NEON(8, 4)
+SAD_WXH_NEON(8, 8)
+SAD_WXH_NEON(8, 16)
+
+SAD_WXH_NEON(16, 8)
+SAD_WXH_NEON(16, 16)
+SAD_WXH_NEON(16, 32)
+
+SAD_WXH_NEON(32, 16)
+SAD_WXH_NEON(32, 32)
+SAD_WXH_NEON(32, 64)
+
+SAD_WXH_NEON(64, 32)
+SAD_WXH_NEON(64, 64)
 
 #if defined(__ARM_FEATURE_DOTPROD)
-static INLINE uint32x4_t sad32x(const uint8_t *src_ptr, int src_stride,
-                                const uint8_t *ref_ptr, int ref_stride,
-                                const int height) {
-  int i;
-  uint32x4_t prod = vdupq_n_u32(0);
-  const uint8x16_t ones = vdupq_n_u8(1);
-  for (i = 0; i < height; ++i) {
-    const uint8x16_t a_lo = vld1q_u8(src_ptr);
-    const uint8x16_t a_hi = vld1q_u8(src_ptr + 16);
-    const uint8x16_t b_lo = vld1q_u8(ref_ptr);
-    const uint8x16_t b_hi = vld1q_u8(ref_ptr + 16);
-    const uint8x16_t sad_lo_u8 = vabdq_u8(a_lo, b_lo);
-    const uint8x16_t sad_hi_u8 = vabdq_u8(a_hi, b_hi);
+
+static INLINE unsigned int sadwxh_avg_neon(const uint8_t *src_ptr,
+                                           int src_stride,
+                                           const uint8_t *ref_ptr,
+                                           int ref_stride, int w, int h,
+                                           const uint8_t *second_pred) {
+  // Only two accumulators are required for optimal instruction throughput of
+  // the ABD, UDOT sequence on CPUs with either 2 or 4 Neon pipes.
+  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int i = h;
+  do {
+    int j = 0;
+    do {
+      uint8x16_t s0, s1, r0, r1, p0, p1, avg0, avg1, diff0, diff1;
+
+      s0 = vld1q_u8(src_ptr + j);
+      r0 = vld1q_u8(ref_ptr + j);
+      p0 = vld1q_u8(second_pred);
+      avg0 = vrhaddq_u8(r0, p0);
+      diff0 = vabdq_u8(s0, avg0);
+      sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));
+
+      s1 = vld1q_u8(src_ptr + j + 16);
+      r1 = vld1q_u8(ref_ptr + j + 16);
+      p1 = vld1q_u8(second_pred + 16);
+      avg1 = vrhaddq_u8(r1, p1);
+      diff1 = vabdq_u8(s1, avg1);
+      sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));
+
+      j += 32;
+      second_pred += 32;
+    } while (j < w);
+
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    prod = vdotq_u32(prod, sad_lo_u8, ones);
-    prod = vdotq_u32(prod, sad_hi_u8, ones);
-  }
-  return prod;
+  } while (--i != 0);
+
+  return horizontal_add_uint32x4(vaddq_u32(sum[0], sum[1]));
+}
+
+static INLINE unsigned int sad64xh_avg_neon(const uint8_t *src_ptr,
+                                            int src_stride,
+                                            const uint8_t *ref_ptr,
+                                            int ref_stride, int h,
+                                            const uint8_t *second_pred) {
+  return sadwxh_avg_neon(src_ptr, src_stride, ref_ptr, ref_stride, 64, h,
+                         second_pred);
 }
 
-static INLINE uint32x4_t sad32x_avg(const uint8_t *src_ptr, int src_stride,
-                                    const uint8_t *ref_ptr, int ref_stride,
-                                    const uint8_t *second_pred,
-                                    const int height) {
-  int i;
-  uint32x4_t prod = vdupq_n_u32(0);
-  const uint8x16_t ones = vdupq_n_u8(1);
-  for (i = 0; i < height; ++i) {
-    const uint8x16_t a_lo = vld1q_u8(src_ptr);
-    const uint8x16_t a_hi = vld1q_u8(src_ptr + 16);
-    const uint8x16_t b_lo = vld1q_u8(ref_ptr);
-    const uint8x16_t b_hi = vld1q_u8(ref_ptr + 16);
-    const uint8x16_t c_lo = vld1q_u8(second_pred);
-    const uint8x16_t c_hi = vld1q_u8(second_pred + 16);
-    const uint8x16_t avg_lo = vrhaddq_u8(b_lo, c_lo);
-    const uint8x16_t avg_hi = vrhaddq_u8(b_hi, c_hi);
-    const uint8x16_t sad_lo_u8 = vabdq_u8(a_lo, avg_lo);
-    const uint8x16_t sad_hi_u8 = vabdq_u8(a_hi, avg_hi);
+static INLINE unsigned int sad32xh_avg_neon(const uint8_t *src_ptr,
+                                            int src_stride,
+                                            const uint8_t *ref_ptr,
+                                            int ref_stride, int h,
+                                            const uint8_t *second_pred) {
+  return sadwxh_avg_neon(src_ptr, src_stride, ref_ptr, ref_stride, 32, h,
+                         second_pred);
+}
+
+static INLINE unsigned int sad16xh_avg_neon(const uint8_t *src_ptr,
+                                            int src_stride,
+                                            const uint8_t *ref_ptr,
+                                            int ref_stride, int h,
+                                            const uint8_t *second_pred) {
+  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int i = h / 2;
+  do {
+    uint8x16_t s0, s1, r0, r1, p0, p1, avg0, avg1, diff0, diff1;
+
+    s0 = vld1q_u8(src_ptr);
+    r0 = vld1q_u8(ref_ptr);
+    p0 = vld1q_u8(second_pred);
+    avg0 = vrhaddq_u8(r0, p0);
+    diff0 = vabdq_u8(s0, avg0);
+    sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));
+
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    second_pred += 32;
-    prod = vdotq_u32(prod, sad_lo_u8, ones);
-    prod = vdotq_u32(prod, sad_hi_u8, ones);
-  }
-  return prod;
+    second_pred += 16;
+
+    s1 = vld1q_u8(src_ptr);
+    r1 = vld1q_u8(ref_ptr);
+    p1 = vld1q_u8(second_pred);
+    avg1 = vrhaddq_u8(r1, p1);
+    diff1 = vabdq_u8(s1, avg1);
+    sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));
+
+    src_ptr += src_stride;
+    ref_ptr += ref_stride;
+    second_pred += 16;
+  } while (--i != 0);
+
+  return horizontal_add_uint32x4(vaddq_u32(sum[0], sum[1]));
 }
 
-#define SAD32XN(n)                                                            \
-  uint32_t vpx_sad32x##n##_neon(const uint8_t *src_ptr, int src_stride,       \
-                                const uint8_t *ref_ptr, int ref_stride) {     \
-    const uint32x4_t prod =                                                   \
-        sad32x(src_ptr, src_stride, ref_ptr, ref_stride, n);                  \
-    return horizontal_add_uint32x4(prod);                                     \
-  }                                                                           \
-                                                                              \
-  uint32_t vpx_sad32x##n##_avg_neon(const uint8_t *src_ptr, int src_stride,   \
-                                    const uint8_t *ref_ptr, int ref_stride,   \
-                                    const uint8_t *second_pred) {             \
-    const uint32x4_t prod =                                                   \
-        sad32x_avg(src_ptr, src_stride, ref_ptr, ref_stride, second_pred, n); \
-    return horizontal_add_uint32x4(prod);                                     \
-  }
+#else  // !defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE unsigned int sad64xh_avg_neon(const uint8_t *src_ptr,
+                                            int src_stride,
+                                            const uint8_t *ref_ptr,
+                                            int ref_stride, int h,
+                                            const uint8_t *second_pred) {
+  uint16x8_t sum[4] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
+                        vdupq_n_u16(0) };
+  uint32x4_t sum_u32;
+
+  int i = h;
+  do {
+    uint8x16_t s0, s1, s2, s3, r0, r1, r2, r3, p0, p1, p2, p3;
+    uint8x16_t avg0, avg1, avg2, avg3, diff0, diff1, diff2, diff3;
+
+    s0 = vld1q_u8(src_ptr);
+    r0 = vld1q_u8(ref_ptr);
+    p0 = vld1q_u8(second_pred);
+    avg0 = vrhaddq_u8(r0, p0);
+    diff0 = vabdq_u8(s0, avg0);
+    sum[0] = vpadalq_u8(sum[0], diff0);
+
+    s1 = vld1q_u8(src_ptr + 16);
+    r1 = vld1q_u8(ref_ptr + 16);
+    p1 = vld1q_u8(second_pred + 16);
+    avg1 = vrhaddq_u8(r1, p1);
+    diff1 = vabdq_u8(s1, avg1);
+    sum[1] = vpadalq_u8(sum[1], diff1);
+
+    s2 = vld1q_u8(src_ptr + 32);
+    r2 = vld1q_u8(ref_ptr + 32);
+    p2 = vld1q_u8(second_pred + 32);
+    avg2 = vrhaddq_u8(r2, p2);
+    diff2 = vabdq_u8(s2, avg2);
+    sum[2] = vpadalq_u8(sum[2], diff2);
+
+    s3 = vld1q_u8(src_ptr + 48);
+    r3 = vld1q_u8(ref_ptr + 48);
+    p3 = vld1q_u8(second_pred + 48);
+    avg3 = vrhaddq_u8(r3, p3);
+    diff3 = vabdq_u8(s3, avg3);
+    sum[3] = vpadalq_u8(sum[3], diff3);
 
-#else  // defined(__ARM_FEATURE_DOTPROD)
-static INLINE uint16x8_t sad32x(const uint8_t *src_ptr, int src_stride,
-                                const uint8_t *ref_ptr, int ref_stride,
-                                const int height) {
-  int i;
-  uint16x8_t abs = vdupq_n_u16(0);
-
-  for (i = 0; i < height; ++i) {
-    const uint8x16_t a_lo = vld1q_u8(src_ptr);
-    const uint8x16_t a_hi = vld1q_u8(src_ptr + 16);
-    const uint8x16_t b_lo = vld1q_u8(ref_ptr);
-    const uint8x16_t b_hi = vld1q_u8(ref_ptr + 16);
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    abs = vabal_u8(abs, vget_low_u8(a_lo), vget_low_u8(b_lo));
-    abs = vabal_u8(abs, vget_high_u8(a_lo), vget_high_u8(b_lo));
-    abs = vabal_u8(abs, vget_low_u8(a_hi), vget_low_u8(b_hi));
-    abs = vabal_u8(abs, vget_high_u8(a_hi), vget_high_u8(b_hi));
-  }
-  return abs;
+    second_pred += 64;
+  } while (--i != 0);
+
+  sum_u32 = vpaddlq_u16(sum[0]);
+  sum_u32 = vpadalq_u16(sum_u32, sum[1]);
+  sum_u32 = vpadalq_u16(sum_u32, sum[2]);
+  sum_u32 = vpadalq_u16(sum_u32, sum[3]);
+
+  return horizontal_add_uint32x4(sum_u32);
 }
 
-static INLINE uint16x8_t sad32x_avg(const uint8_t *src_ptr, int src_stride,
-                                    const uint8_t *ref_ptr, int ref_stride,
-                                    const uint8_t *second_pred,
-                                    const int height) {
-  int i;
-  uint16x8_t abs = vdupq_n_u16(0);
-
-  for (i = 0; i < height; ++i) {
-    const uint8x16_t a_lo = vld1q_u8(src_ptr);
-    const uint8x16_t a_hi = vld1q_u8(src_ptr + 16);
-    const uint8x16_t b_lo = vld1q_u8(ref_ptr);
-    const uint8x16_t b_hi = vld1q_u8(ref_ptr + 16);
-    const uint8x16_t c_lo = vld1q_u8(second_pred);
-    const uint8x16_t c_hi = vld1q_u8(second_pred + 16);
-    const uint8x16_t avg_lo = vrhaddq_u8(b_lo, c_lo);
-    const uint8x16_t avg_hi = vrhaddq_u8(b_hi, c_hi);
+static INLINE unsigned int sad32xh_avg_neon(const uint8_t *src_ptr,
+                                            int src_stride,
+                                            const uint8_t *ref_ptr,
+                                            int ref_stride, int h,
+                                            const uint8_t *second_pred) {
+  uint32x4_t sum = vdupq_n_u32(0);
+
+  int i = h;
+  do {
+    uint8x16_t s0 = vld1q_u8(src_ptr);
+    uint8x16_t r0 = vld1q_u8(ref_ptr);
+    uint8x16_t p0 = vld1q_u8(second_pred);
+    uint8x16_t avg0 = vrhaddq_u8(r0, p0);
+    uint8x16_t diff0 = vabdq_u8(s0, avg0);
+    uint16x8_t sum0 = vpaddlq_u8(diff0);
+
+    uint8x16_t s1 = vld1q_u8(src_ptr + 16);
+    uint8x16_t r1 = vld1q_u8(ref_ptr + 16);
+    uint8x16_t p1 = vld1q_u8(second_pred + 16);
+    uint8x16_t avg1 = vrhaddq_u8(r1, p1);
+    uint8x16_t diff1 = vabdq_u8(s1, avg1);
+    uint16x8_t sum1 = vpaddlq_u8(diff1);
+
+    sum = vpadalq_u16(sum, sum0);
+    sum = vpadalq_u16(sum, sum1);
+
     src_ptr += src_stride;
     ref_ptr += ref_stride;
     second_pred += 32;
-    abs = vabal_u8(abs, vget_low_u8(a_lo), vget_low_u8(avg_lo));
-    abs = vabal_u8(abs, vget_high_u8(a_lo), vget_high_u8(avg_lo));
-    abs = vabal_u8(abs, vget_low_u8(a_hi), vget_low_u8(avg_hi));
-    abs = vabal_u8(abs, vget_high_u8(a_hi), vget_high_u8(avg_hi));
-  }
-  return abs;
+  } while (--i != 0);
+
+  return horizontal_add_uint32x4(sum);
 }
 
-#define SAD32XN(n)                                                            \
-  uint32_t vpx_sad32x##n##_neon(const uint8_t *src_ptr, int src_stride,       \
-                                const uint8_t *ref_ptr, int ref_stride) {     \
-    const uint16x8_t abs =                                                    \
-        sad32x(src_ptr, src_stride, ref_ptr, ref_stride, n);                  \
-    return horizontal_add_uint16x8(abs);                                      \
-  }                                                                           \
-                                                                              \
-  uint32_t vpx_sad32x##n##_avg_neon(const uint8_t *src_ptr, int src_stride,   \
-                                    const uint8_t *ref_ptr, int ref_stride,   \
-                                    const uint8_t *second_pred) {             \
-    const uint16x8_t abs =                                                    \
-        sad32x_avg(src_ptr, src_stride, ref_ptr, ref_stride, second_pred, n); \
-    return horizontal_add_uint16x8(abs);                                      \
-  }
-#endif  // defined(__ARM_FEATURE_DOTPROD)
+static INLINE unsigned int sad16xh_avg_neon(const uint8_t *src_ptr,
+                                            int src_stride,
+                                            const uint8_t *ref_ptr,
+                                            int ref_stride, int h,
+                                            const uint8_t *second_pred) {
+  uint16x8_t sum = vdupq_n_u16(0);
 
-SAD32XN(16)
-SAD32XN(32)
-SAD32XN(64)
+  int i = h;
+  do {
+    uint8x16_t s = vld1q_u8(src_ptr);
+    uint8x16_t r = vld1q_u8(ref_ptr);
+    uint8x16_t p = vld1q_u8(second_pred);
+
+    uint8x16_t avg = vrhaddq_u8(r, p);
+    uint8x16_t diff = vabdq_u8(s, avg);
+    sum = vpadalq_u8(sum, diff);
 
-#if defined(__ARM_FEATURE_DOTPROD)
-static INLINE uint32x4_t sad64x(const uint8_t *src_ptr, int src_stride,
-                                const uint8_t *ref_ptr, int ref_stride,
-                                const int height) {
-  int i;
-  uint32x4_t prod = vdupq_n_u32(0);
-  const uint8x16_t ones = vdupq_n_u8(1);
-  for (i = 0; i < height; ++i) {
-    const uint8x16_t a_0 = vld1q_u8(src_ptr);
-    const uint8x16_t a_1 = vld1q_u8(src_ptr + 16);
-    const uint8x16_t a_2 = vld1q_u8(src_ptr + 32);
-    const uint8x16_t a_3 = vld1q_u8(src_ptr + 48);
-    const uint8x16_t b_0 = vld1q_u8(ref_ptr);
-    const uint8x16_t b_1 = vld1q_u8(ref_ptr + 16);
-    const uint8x16_t b_2 = vld1q_u8(ref_ptr + 32);
-    const uint8x16_t b_3 = vld1q_u8(ref_ptr + 48);
-    const uint8x16_t sad_0_u8 = vabdq_u8(a_0, b_0);
-    const uint8x16_t sad_1_u8 = vabdq_u8(a_1, b_1);
-    const uint8x16_t sad_2_u8 = vabdq_u8(a_2, b_2);
-    const uint8x16_t sad_3_u8 = vabdq_u8(a_3, b_3);
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    prod = vdotq_u32(prod, sad_0_u8, ones);
-    prod = vdotq_u32(prod, sad_1_u8, ones);
-    prod = vdotq_u32(prod, sad_2_u8, ones);
-    prod = vdotq_u32(prod, sad_3_u8, ones);
-  }
-  return prod;
+    second_pred += 16;
+  } while (--i != 0);
+
+  return horizontal_add_uint16x8(sum);
 }
 
-static INLINE uint32x4_t sad64x_avg(const uint8_t *src_ptr, int src_stride,
-                                    const uint8_t *ref_ptr, int ref_stride,
-                                    const uint8_t *second_pred,
-                                    const int height) {
-  int i;
-  uint32x4_t prod = vdupq_n_u32(0);
-  const uint8x16_t ones = vdupq_n_u8(1);
-  for (i = 0; i < height; ++i) {
-    const uint8x16_t a_0 = vld1q_u8(src_ptr);
-    const uint8x16_t a_1 = vld1q_u8(src_ptr + 16);
-    const uint8x16_t a_2 = vld1q_u8(src_ptr + 32);
-    const uint8x16_t a_3 = vld1q_u8(src_ptr + 48);
-    const uint8x16_t b_0 = vld1q_u8(ref_ptr);
-    const uint8x16_t b_1 = vld1q_u8(ref_ptr + 16);
-    const uint8x16_t b_2 = vld1q_u8(ref_ptr + 32);
-    const uint8x16_t b_3 = vld1q_u8(ref_ptr + 48);
-    const uint8x16_t c_0 = vld1q_u8(second_pred);
-    const uint8x16_t c_1 = vld1q_u8(second_pred + 16);
-    const uint8x16_t c_2 = vld1q_u8(second_pred + 32);
-    const uint8x16_t c_3 = vld1q_u8(second_pred + 48);
-    const uint8x16_t avg_0 = vrhaddq_u8(b_0, c_0);
-    const uint8x16_t avg_1 = vrhaddq_u8(b_1, c_1);
-    const uint8x16_t avg_2 = vrhaddq_u8(b_2, c_2);
-    const uint8x16_t avg_3 = vrhaddq_u8(b_3, c_3);
-    const uint8x16_t sad_0_u8 = vabdq_u8(a_0, avg_0);
-    const uint8x16_t sad_1_u8 = vabdq_u8(a_1, avg_1);
-    const uint8x16_t sad_2_u8 = vabdq_u8(a_2, avg_2);
-    const uint8x16_t sad_3_u8 = vabdq_u8(a_3, avg_3);
+#endif  // defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE unsigned int sad8xh_avg_neon(const uint8_t *src_ptr,
+                                           int src_stride,
+                                           const uint8_t *ref_ptr,
+                                           int ref_stride, int h,
+                                           const uint8_t *second_pred) {
+  uint16x8_t sum = vdupq_n_u16(0);
+
+  int i = h;
+  do {
+    uint8x8_t s = vld1_u8(src_ptr);
+    uint8x8_t r = vld1_u8(ref_ptr);
+    uint8x8_t p = vld1_u8(second_pred);
+
+    uint8x8_t avg = vrhadd_u8(r, p);
+    sum = vabal_u8(sum, s, avg);
+
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    second_pred += 64;
-    prod = vdotq_u32(prod, sad_0_u8, ones);
-    prod = vdotq_u32(prod, sad_1_u8, ones);
-    prod = vdotq_u32(prod, sad_2_u8, ones);
-    prod = vdotq_u32(prod, sad_3_u8, ones);
-  }
-  return prod;
+    second_pred += 8;
+  } while (--i != 0);
+
+  return horizontal_add_uint16x8(sum);
 }
-#else   // !defined(__ARM_FEATURE_DOTPROD)
-static INLINE uint32x4_t sad64x(const uint8_t *src_ptr, int src_stride,
-                                const uint8_t *ref_ptr, int ref_stride,
-                                const int height) {
-  int i;
-  uint16x8_t abs_0 = vdupq_n_u16(0);
-  uint16x8_t abs_1 = vdupq_n_u16(0);
-
-  for (i = 0; i < height; ++i) {
-    const uint8x16_t a_0 = vld1q_u8(src_ptr);
-    const uint8x16_t a_1 = vld1q_u8(src_ptr + 16);
-    const uint8x16_t a_2 = vld1q_u8(src_ptr + 32);
-    const uint8x16_t a_3 = vld1q_u8(src_ptr + 48);
-    const uint8x16_t b_0 = vld1q_u8(ref_ptr);
-    const uint8x16_t b_1 = vld1q_u8(ref_ptr + 16);
-    const uint8x16_t b_2 = vld1q_u8(ref_ptr + 32);
-    const uint8x16_t b_3 = vld1q_u8(ref_ptr + 48);
+
+static INLINE unsigned int sad4xh_avg_neon(const uint8_t *src_ptr,
+                                           int src_stride,
+                                           const uint8_t *ref_ptr,
+                                           int ref_stride, int h,
+                                           const uint8_t *second_pred) {
+  uint16x8_t sum = vdupq_n_u16(0);
+
+  int i = h / 2;
+  do {
+    uint32x2_t s, r;
+    uint32_t s0, s1, r0, r1;
+    uint8x8_t p, avg;
+
+    memcpy(&s0, src_ptr, 4);
+    memcpy(&r0, ref_ptr, 4);
+    s = vdup_n_u32(s0);
+    r = vdup_n_u32(r0);
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    abs_0 = vabal_u8(abs_0, vget_low_u8(a_0), vget_low_u8(b_0));
-    abs_0 = vabal_u8(abs_0, vget_high_u8(a_0), vget_high_u8(b_0));
-    abs_0 = vabal_u8(abs_0, vget_low_u8(a_1), vget_low_u8(b_1));
-    abs_0 = vabal_u8(abs_0, vget_high_u8(a_1), vget_high_u8(b_1));
-    abs_1 = vabal_u8(abs_1, vget_low_u8(a_2), vget_low_u8(b_2));
-    abs_1 = vabal_u8(abs_1, vget_high_u8(a_2), vget_high_u8(b_2));
-    abs_1 = vabal_u8(abs_1, vget_low_u8(a_3), vget_low_u8(b_3));
-    abs_1 = vabal_u8(abs_1, vget_high_u8(a_3), vget_high_u8(b_3));
-  }
-
-  {
-    const uint32x4_t sum = vpaddlq_u16(abs_0);
-    return vpadalq_u16(sum, abs_1);
-  }
-}
 
-static INLINE uint32x4_t sad64x_avg(const uint8_t *src_ptr, int src_stride,
-                                    const uint8_t *ref_ptr, int ref_stride,
-                                    const uint8_t *second_pred,
-                                    const int height) {
-  int i;
-  uint16x8_t abs_0 = vdupq_n_u16(0);
-  uint16x8_t abs_1 = vdupq_n_u16(0);
-
-  for (i = 0; i < height; ++i) {
-    const uint8x16_t a_0 = vld1q_u8(src_ptr);
-    const uint8x16_t a_1 = vld1q_u8(src_ptr + 16);
-    const uint8x16_t a_2 = vld1q_u8(src_ptr + 32);
-    const uint8x16_t a_3 = vld1q_u8(src_ptr + 48);
-    const uint8x16_t b_0 = vld1q_u8(ref_ptr);
-    const uint8x16_t b_1 = vld1q_u8(ref_ptr + 16);
-    const uint8x16_t b_2 = vld1q_u8(ref_ptr + 32);
-    const uint8x16_t b_3 = vld1q_u8(ref_ptr + 48);
-    const uint8x16_t c_0 = vld1q_u8(second_pred);
-    const uint8x16_t c_1 = vld1q_u8(second_pred + 16);
-    const uint8x16_t c_2 = vld1q_u8(second_pred + 32);
-    const uint8x16_t c_3 = vld1q_u8(second_pred + 48);
-    const uint8x16_t avg_0 = vrhaddq_u8(b_0, c_0);
-    const uint8x16_t avg_1 = vrhaddq_u8(b_1, c_1);
-    const uint8x16_t avg_2 = vrhaddq_u8(b_2, c_2);
-    const uint8x16_t avg_3 = vrhaddq_u8(b_3, c_3);
+    memcpy(&s1, src_ptr, 4);
+    memcpy(&r1, ref_ptr, 4);
+    s = vset_lane_u32(s1, s, 1);
+    r = vset_lane_u32(r1, r, 1);
     src_ptr += src_stride;
     ref_ptr += ref_stride;
-    second_pred += 64;
-    abs_0 = vabal_u8(abs_0, vget_low_u8(a_0), vget_low_u8(avg_0));
-    abs_0 = vabal_u8(abs_0, vget_high_u8(a_0), vget_high_u8(avg_0));
-    abs_0 = vabal_u8(abs_0, vget_low_u8(a_1), vget_low_u8(avg_1));
-    abs_0 = vabal_u8(abs_0, vget_high_u8(a_1), vget_high_u8(avg_1));
-    abs_1 = vabal_u8(abs_1, vget_low_u8(a_2), vget_low_u8(avg_2));
-    abs_1 = vabal_u8(abs_1, vget_high_u8(a_2), vget_high_u8(avg_2));
-    abs_1 = vabal_u8(abs_1, vget_low_u8(a_3), vget_low_u8(avg_3));
-    abs_1 = vabal_u8(abs_1, vget_high_u8(a_3), vget_high_u8(avg_3));
-  }
 
-  {
-    const uint32x4_t sum = vpaddlq_u16(abs_0);
-    return vpadalq_u16(sum, abs_1);
-  }
+    p = vld1_u8(second_pred);
+    avg = vrhadd_u8(vreinterpret_u8_u32(r), p);
+
+    sum = vabal_u8(sum, vreinterpret_u8_u32(s), avg);
+    second_pred += 8;
+  } while (--i != 0);
+
+  return horizontal_add_uint16x8(sum);
 }
-#endif  // defined(__ARM_FEATURE_DOTPROD)
 
-#define SAD64XN(n)                                                            \
-  uint32_t vpx_sad64x##n##_neon(const uint8_t *src_ptr, int src_stride,       \
-                                const uint8_t *ref_ptr, int ref_stride) {     \
-    const uint32x4_t abs =                                                    \
-        sad64x(src_ptr, src_stride, ref_ptr, ref_stride, n);                  \
-    return horizontal_add_uint32x4(abs);                                      \
-  }                                                                           \
-                                                                              \
-  uint32_t vpx_sad64x##n##_avg_neon(const uint8_t *src_ptr, int src_stride,   \
-                                    const uint8_t *ref_ptr, int ref_stride,   \
-                                    const uint8_t *second_pred) {             \
-    const uint32x4_t abs =                                                    \
-        sad64x_avg(src_ptr, src_stride, ref_ptr, ref_stride, second_pred, n); \
-    return horizontal_add_uint32x4(abs);                                      \
+#define SAD_WXH_AVG_NEON(w, h)                                             \
+  uint32_t vpx_sad##w##x##h##_avg_neon(const uint8_t *src, int src_stride, \
+                                       const uint8_t *ref, int ref_stride, \
+                                       const uint8_t *second_pred) {       \
+    return sad##w##xh_avg_neon(src, src_stride, ref, ref_stride, (h),      \
+                               second_pred);                               \
   }
 
-SAD64XN(32)
-SAD64XN(64)
+SAD_WXH_AVG_NEON(4, 4)
+SAD_WXH_AVG_NEON(4, 8)
+
+SAD_WXH_AVG_NEON(8, 4)
+SAD_WXH_AVG_NEON(8, 8)
+SAD_WXH_AVG_NEON(8, 16)
+
+SAD_WXH_AVG_NEON(16, 8)
+SAD_WXH_AVG_NEON(16, 16)
+SAD_WXH_AVG_NEON(16, 32)
+
+SAD_WXH_AVG_NEON(32, 16)
+SAD_WXH_AVG_NEON(32, 32)
+SAD_WXH_AVG_NEON(32, 64)
+
+SAD_WXH_AVG_NEON(64, 32)
+SAD_WXH_AVG_NEON(64, 64)