]> granicus.if.org Git - libvpx/commitdiff
Add Neon implementations of standard bitdepth MSE functions
authorSalome Thirot <salome.thirot@arm.com>
Fri, 24 Feb 2023 18:05:43 +0000 (18:05 +0000)
committerSalome Thirot <salome.thirot@arm.com>
Mon, 27 Feb 2023 18:03:22 +0000 (18:03 +0000)
Currently only vpx_mse16x16 has a Neon implementation. This patch adds
optimized Armv8.0 and Armv8.4 dot-product paths for all block sizes:
8x8, 8x16, 16x8 and 16x16.

Add the corresponding tests as well.

Change-Id: Ib0357fdcdeb05860385fec89633386e34395e260

test/variance_test.cc
vpx_dsp/arm/variance_neon.c
vpx_dsp/vpx_dsp_rtcd_defs.pl

index 33f09209f4c4545b5d030525d040e94f745484cc..a68cfad516a8bf012442d9b65a42144ae46deaea 100644 (file)
@@ -773,6 +773,7 @@ TEST_P(VpxSseTest, RefSse) { RefTestSse(); }
 TEST_P(VpxSseTest, MaxSse) { MaxTestSse(); }
 TEST_P(VpxMseTest, RefMse) { RefTestMse(); }
 TEST_P(VpxMseTest, MaxMse) { MaxTestMse(); }
+TEST_P(VpxMseTest, DISABLED_Speed) { SpeedTest(); }
 TEST_P(VpxVarianceTest, Zero) { ZeroTest(); }
 TEST_P(VpxVarianceTest, Ref) { RefTest(); }
 TEST_P(VpxVarianceTest, RefStride) { RefStrideTest(); }
@@ -1450,8 +1451,10 @@ INSTANTIATE_TEST_SUITE_P(NEON, VpxSseTest,
                                                      &vpx_get4x4sse_cs_neon)));
 
 INSTANTIATE_TEST_SUITE_P(NEON, VpxMseTest,
-                         ::testing::Values(MseParams(4, 4,
-                                                     &vpx_mse16x16_neon)));
+                         ::testing::Values(MseParams(4, 4, &vpx_mse16x16_neon),
+                                           MseParams(4, 3, &vpx_mse16x8_neon),
+                                           MseParams(3, 4, &vpx_mse8x16_neon),
+                                           MseParams(3, 3, &vpx_mse8x8_neon)));
 
 INSTANTIATE_TEST_SUITE_P(
     NEON, VpxVarianceTest,
index 3ccc4e807b0e989d937bf7c3794a494d5c5c9b24..feff980c93ee744a0a37f8826109de791d73f653 100644 (file)
@@ -371,32 +371,66 @@ VARIANCE_WXH_NEON(64, 64, 12)
 
 #if defined(__ARM_FEATURE_DOTPROD)
 
-unsigned int vpx_mse16x16_neon(const unsigned char *src_ptr, int src_stride,
-                               const unsigned char *ref_ptr, int ref_stride,
-                               unsigned int *sse) {
-  int i;
-  uint8x16_t a[2], b[2], abs_diff[2];
-  uint32x4_t sse_vec[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
-
-  for (i = 0; i < 8; i++) {
-    a[0] = vld1q_u8(src_ptr);
+static INLINE unsigned int vpx_mse8xh_neon(const unsigned char *src_ptr,
+                                           int src_stride,
+                                           const unsigned char *ref_ptr,
+                                           int ref_stride, int h,
+                                           unsigned int *sse) {
+  uint32x2_t sse_u32[2] = { vdup_n_u32(0), vdup_n_u32(0) };
+
+  int i = h / 2;
+  do {
+    uint8x8_t s0, s1, r0, r1, diff0, diff1;
+
+    s0 = vld1_u8(src_ptr);
     src_ptr += src_stride;
-    a[1] = vld1q_u8(src_ptr);
+    s1 = vld1_u8(src_ptr);
     src_ptr += src_stride;
-    b[0] = vld1q_u8(ref_ptr);
+    r0 = vld1_u8(ref_ptr);
     ref_ptr += ref_stride;
-    b[1] = vld1q_u8(ref_ptr);
+    r1 = vld1_u8(ref_ptr);
     ref_ptr += ref_stride;
 
-    abs_diff[0] = vabdq_u8(a[0], b[0]);
-    abs_diff[1] = vabdq_u8(a[1], b[1]);
+    diff0 = vabd_u8(s0, r0);
+    diff1 = vabd_u8(s1, r1);
 
-    sse_vec[0] = vdotq_u32(sse_vec[0], abs_diff[0], abs_diff[0]);
-    sse_vec[1] = vdotq_u32(sse_vec[1], abs_diff[1], abs_diff[1]);
-  }
+    sse_u32[0] = vdot_u32(sse_u32[0], diff0, diff0);
+    sse_u32[1] = vdot_u32(sse_u32[1], diff1, diff1);
+  } while (--i != 0);
 
-  *sse = horizontal_add_uint32x4(vaddq_u32(sse_vec[0], sse_vec[1]));
-  return horizontal_add_uint32x4(vaddq_u32(sse_vec[0], sse_vec[1]));
+  *sse = horizontal_add_uint32x2(vadd_u32(sse_u32[0], sse_u32[1]));
+  return *sse;
+}
+
+static INLINE unsigned int vpx_mse16xh_neon(const unsigned char *src_ptr,
+                                            int src_stride,
+                                            const unsigned char *ref_ptr,
+                                            int ref_stride, int h,
+                                            unsigned int *sse) {
+  uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int i = h / 2;
+  do {
+    uint8x16_t s0, s1, r0, r1, diff0, diff1;
+
+    s0 = vld1q_u8(src_ptr);
+    src_ptr += src_stride;
+    s1 = vld1q_u8(src_ptr);
+    src_ptr += src_stride;
+    r0 = vld1q_u8(ref_ptr);
+    ref_ptr += ref_stride;
+    r1 = vld1q_u8(ref_ptr);
+    ref_ptr += ref_stride;
+
+    diff0 = vabdq_u8(s0, r0);
+    diff1 = vabdq_u8(s1, r1);
+
+    sse_u32[0] = vdotq_u32(sse_u32[0], diff0, diff0);
+    sse_u32[1] = vdotq_u32(sse_u32[1], diff1, diff1);
+  } while (--i != 0);
+
+  *sse = horizontal_add_uint32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
+  return *sse;
 }
 
 unsigned int vpx_get4x4sse_cs_neon(const unsigned char *src_ptr, int src_stride,
@@ -435,58 +469,67 @@ unsigned int vpx_get4x4sse_cs_neon(const unsigned char *src_ptr, int src_stride,
 
 #else  // !defined(__ARM_FEATURE_DOTPROD)
 
-unsigned int vpx_mse16x16_neon(const unsigned char *src_ptr, int src_stride,
-                               const unsigned char *ref_ptr, int ref_stride,
-                               unsigned int *sse) {
-  int i;
-  uint8x16_t a[2], b[2];
-  int16x4_t diff_lo[4], diff_hi[4];
-  uint16x8_t diff[4];
-  int32x4_t sse_vec[4] = { vdupq_n_s32(0), vdupq_n_s32(0), vdupq_n_s32(0),
-                           vdupq_n_s32(0) };
+static INLINE unsigned int vpx_mse8xh_neon(const unsigned char *src_ptr,
+                                           int src_stride,
+                                           const unsigned char *ref_ptr,
+                                           int ref_stride, int h,
+                                           unsigned int *sse) {
+  uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
 
-  for (i = 0; i < 8; i++) {
-    a[0] = vld1q_u8(src_ptr);
+  int i = h / 2;
+  do {
+    uint8x8_t s0, s1, r0, r1, diff0, diff1;
+    uint16x8_t sse0, sse1;
+
+    s0 = vld1_u8(src_ptr);
     src_ptr += src_stride;
-    a[1] = vld1q_u8(src_ptr);
+    s1 = vld1_u8(src_ptr);
     src_ptr += src_stride;
-    b[0] = vld1q_u8(ref_ptr);
+    r0 = vld1_u8(ref_ptr);
     ref_ptr += ref_stride;
-    b[1] = vld1q_u8(ref_ptr);
+    r1 = vld1_u8(ref_ptr);
     ref_ptr += ref_stride;
 
-    diff[0] = vsubl_u8(vget_low_u8(a[0]), vget_low_u8(b[0]));
-    diff[1] = vsubl_u8(vget_high_u8(a[0]), vget_high_u8(b[0]));
-    diff[2] = vsubl_u8(vget_low_u8(a[1]), vget_low_u8(b[1]));
-    diff[3] = vsubl_u8(vget_high_u8(a[1]), vget_high_u8(b[1]));
-
-    diff_lo[0] = vreinterpret_s16_u16(vget_low_u16(diff[0]));
-    diff_lo[1] = vreinterpret_s16_u16(vget_low_u16(diff[1]));
-    sse_vec[0] = vmlal_s16(sse_vec[0], diff_lo[0], diff_lo[0]);
-    sse_vec[1] = vmlal_s16(sse_vec[1], diff_lo[1], diff_lo[1]);
-
-    diff_lo[2] = vreinterpret_s16_u16(vget_low_u16(diff[2]));
-    diff_lo[3] = vreinterpret_s16_u16(vget_low_u16(diff[3]));
-    sse_vec[2] = vmlal_s16(sse_vec[2], diff_lo[2], diff_lo[2]);
-    sse_vec[3] = vmlal_s16(sse_vec[3], diff_lo[3], diff_lo[3]);
-
-    diff_hi[0] = vreinterpret_s16_u16(vget_high_u16(diff[0]));
-    diff_hi[1] = vreinterpret_s16_u16(vget_high_u16(diff[1]));
-    sse_vec[0] = vmlal_s16(sse_vec[0], diff_hi[0], diff_hi[0]);
-    sse_vec[1] = vmlal_s16(sse_vec[1], diff_hi[1], diff_hi[1]);
-
-    diff_hi[2] = vreinterpret_s16_u16(vget_high_u16(diff[2]));
-    diff_hi[3] = vreinterpret_s16_u16(vget_high_u16(diff[3]));
-    sse_vec[2] = vmlal_s16(sse_vec[2], diff_hi[2], diff_hi[2]);
-    sse_vec[3] = vmlal_s16(sse_vec[3], diff_hi[3], diff_hi[3]);
-  }
+    diff0 = vabd_u8(s0, r0);
+    diff1 = vabd_u8(s1, r1);
+
+    sse0 = vmull_u8(diff0, diff0);
+    sse_u32[0] = vpadalq_u16(sse_u32[0], sse0);
+    sse1 = vmull_u8(diff1, diff1);
+    sse_u32[1] = vpadalq_u16(sse_u32[1], sse1);
+  } while (--i != 0);
+
+  *sse = horizontal_add_uint32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
+  return *sse;
+}
+
+static INLINE unsigned int vpx_mse16xh_neon(const unsigned char *src_ptr,
+                                            int src_stride,
+                                            const unsigned char *ref_ptr,
+                                            int ref_stride, int h,
+                                            unsigned int *sse) {
+  uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int i = h;
+  do {
+    uint8x16_t s, r, diff;
+    uint16x8_t sse0, sse1;
 
-  sse_vec[0] = vaddq_s32(sse_vec[0], sse_vec[1]);
-  sse_vec[2] = vaddq_s32(sse_vec[2], sse_vec[3]);
-  sse_vec[0] = vaddq_s32(sse_vec[0], sse_vec[2]);
+    s = vld1q_u8(src_ptr);
+    src_ptr += src_stride;
+    r = vld1q_u8(ref_ptr);
+    ref_ptr += ref_stride;
 
-  *sse = horizontal_add_uint32x4(vreinterpretq_u32_s32(sse_vec[0]));
-  return horizontal_add_uint32x4(vreinterpretq_u32_s32(sse_vec[0]));
+    diff = vabdq_u8(s, r);
+
+    sse0 = vmull_u8(vget_low_u8(diff), vget_low_u8(diff));
+    sse_u32[0] = vpadalq_u16(sse_u32[0], sse0);
+    sse1 = vmull_u8(vget_high_u8(diff), vget_high_u8(diff));
+    sse_u32[1] = vpadalq_u16(sse_u32[1], sse1);
+  } while (--i != 0);
+
+  *sse = horizontal_add_uint32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
+  return *sse;
 }
 
 unsigned int vpx_get4x4sse_cs_neon(const unsigned char *src_ptr, int src_stride,
@@ -531,3 +574,16 @@ unsigned int vpx_get4x4sse_cs_neon(const unsigned char *src_ptr, int src_stride,
 }
 
 #endif  // defined(__ARM_FEATURE_DOTPROD)
+
+#define VPX_MSE_WXH_NEON(w, h)                                              \
+  unsigned int vpx_mse##w##x##h##_neon(                                     \
+      const unsigned char *src_ptr, int src_stride,                         \
+      const unsigned char *ref_ptr, int ref_stride, unsigned int *sse) {    \
+    return vpx_mse##w##xh_neon(src_ptr, src_stride, ref_ptr, ref_stride, h, \
+                               sse);                                        \
+  }
+
+VPX_MSE_WXH_NEON(8, 8)
+VPX_MSE_WXH_NEON(8, 16)
+VPX_MSE_WXH_NEON(16, 8)
+VPX_MSE_WXH_NEON(16, 16)
index eef72249e0be5a628f8442ffbde1ceab5dba33ba..0ad3cbe6b2ec8e13a327a3def6332452f1cb01c3 100644 (file)
@@ -1141,13 +1141,13 @@ add_proto qw/unsigned int vpx_mse16x16/, "const uint8_t *src_ptr, int src_stride
   specialize qw/vpx_mse16x16 sse2 avx2 neon msa mmi vsx lsx/;
 
 add_proto qw/unsigned int vpx_mse16x8/, "const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
-  specialize qw/vpx_mse16x8 sse2 avx2 msa mmi vsx/;
+  specialize qw/vpx_mse16x8 sse2 avx2 neon msa mmi vsx/;
 
 add_proto qw/unsigned int vpx_mse8x16/, "const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
-  specialize qw/vpx_mse8x16 sse2 msa mmi vsx/;
+  specialize qw/vpx_mse8x16 sse2 neon msa mmi vsx/;
 
 add_proto qw/unsigned int vpx_mse8x8/, "const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
-  specialize qw/vpx_mse8x8 sse2 msa mmi vsx/;
+  specialize qw/vpx_mse8x8 sse2 neon msa mmi vsx/;
 
 add_proto qw/unsigned int vpx_get_mb_ss/, "const int16_t *";
   specialize qw/vpx_get_mb_ss sse2 msa vsx/;