]> granicus.if.org Git - libvpx/commitdiff
Optimize Neon implementation of high bitdepth MSE functions
authorSalome Thirot <salome.thirot@arm.com>
Mon, 27 Feb 2023 17:58:18 +0000 (17:58 +0000)
committerSalome Thirot <salome.thirot@arm.com>
Wed, 1 Mar 2023 13:35:03 +0000 (13:35 +0000)
Currently MSE functions just call the variance helpers but don't
actually use the computed sum. This patch adds dedicated helpers to
perform the computation of sse.

Add the corresponding tests as well.

Change-Id: I96a8590e3410e84d77f7187344688e02efe03902

test/variance_test.cc
vpx_dsp/arm/highbd_variance_neon.c

index a68cfad516a8bf012442d9b65a42144ae46deaea..1359bc4baf5a406f6c5421bedeee85800a8222c1 100644 (file)
@@ -1507,6 +1507,22 @@ INSTANTIATE_TEST_SUITE_P(
         SubpelAvgVarianceParams(2, 2, &vpx_sub_pixel_avg_variance4x4_neon, 0)));
 
 #if CONFIG_VP9_HIGHBITDEPTH
+INSTANTIATE_TEST_SUITE_P(
+    NEON, VpxHBDMseTest,
+    ::testing::Values(
+        MseParams(4, 4, &vpx_highbd_12_mse16x16_neon, VPX_BITS_12),
+        MseParams(4, 3, &vpx_highbd_12_mse16x8_neon, VPX_BITS_12),
+        MseParams(3, 4, &vpx_highbd_12_mse8x16_neon, VPX_BITS_12),
+        MseParams(3, 3, &vpx_highbd_12_mse8x8_neon, VPX_BITS_12),
+        MseParams(4, 4, &vpx_highbd_10_mse16x16_neon, VPX_BITS_10),
+        MseParams(4, 3, &vpx_highbd_10_mse16x8_neon, VPX_BITS_10),
+        MseParams(3, 4, &vpx_highbd_10_mse8x16_neon, VPX_BITS_10),
+        MseParams(3, 3, &vpx_highbd_10_mse8x8_neon, VPX_BITS_10),
+        MseParams(4, 4, &vpx_highbd_8_mse16x16_neon, VPX_BITS_8),
+        MseParams(4, 3, &vpx_highbd_8_mse16x8_neon, VPX_BITS_8),
+        MseParams(3, 4, &vpx_highbd_8_mse8x16_neon, VPX_BITS_8),
+        MseParams(3, 3, &vpx_highbd_8_mse8x8_neon, VPX_BITS_8)));
+
 INSTANTIATE_TEST_SUITE_P(
     NEON, VpxHBDVarianceTest,
     ::testing::Values(
index 89bd5c579d1326bd4e4af9465d2f457cb0fd5f50..d0b366c95b0d80e12be9d1098ea0979461d0d487 100644 (file)
@@ -351,50 +351,159 @@ HBD_VARIANCE_WXH_12_XLARGE_NEON(64, 64)
     *sum = (int)ROUND_POWER_OF_TWO(sum_long, 4);                      \
   }
 
-#define HIGHBD_MSE(w, h)                                              \
-  uint32_t vpx_highbd_8_mse##w##x##h##_neon(                          \
-      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
-      int ref_stride, uint32_t *sse) {                                \
-    uint64_t sse_long = 0;                                            \
-    int64_t sum_long = 0;                                             \
-    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
-    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
-    highbd_variance_##w##xh_neon(src, src_stride, ref, ref_stride, h, \
-                                 &sse_long, &sum_long);               \
-    *sse = (uint32_t)sse_long;                                        \
-    return *sse;                                                      \
-  }                                                                   \
-                                                                      \
-  uint32_t vpx_highbd_10_mse##w##x##h##_neon(                         \
-      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
-      int ref_stride, uint32_t *sse) {                                \
-    uint64_t sse_long = 0;                                            \
-    int64_t sum_long = 0;                                             \
-    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
-    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
-    highbd_variance_##w##xh_neon(src, src_stride, ref, ref_stride, h, \
-                                 &sse_long, &sum_long);               \
-    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 4);                 \
-    return *sse;                                                      \
-  }                                                                   \
-                                                                      \
-  uint32_t vpx_highbd_12_mse##w##x##h##_neon(                         \
-      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr, \
-      int ref_stride, uint32_t *sse) {                                \
-    uint64_t sse_long = 0;                                            \
-    int64_t sum_long = 0;                                             \
-    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                     \
-    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                     \
-    highbd_variance_##w##xh_neon(src, src_stride, ref, ref_stride, h, \
-                                 &sse_long, &sum_long);               \
-    *sse = (uint32_t)ROUND_POWER_OF_TWO(sse_long, 8);                 \
-    return *sse;                                                      \
-  }
-
 HIGHBD_GET_VAR(8)
 HIGHBD_GET_VAR(16)
 
-HIGHBD_MSE(16, 16)
-HIGHBD_MSE(16, 8)
-HIGHBD_MSE(8, 16)
-HIGHBD_MSE(8, 8)
+static INLINE uint32_t highbd_mse_wxh_neon(const uint16_t *src_ptr,
+                                           int src_stride,
+                                           const uint16_t *ref_ptr,
+                                           int ref_stride, int w, int h,
+                                           unsigned int *sse) {
+  uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
+
+  int i = h;
+  do {
+    int j = 0;
+    do {
+      uint16x8_t s = vld1q_u16(src_ptr + j);
+      uint16x8_t r = vld1q_u16(ref_ptr + j);
+
+      uint16x8_t diff = vabdq_u16(s, r);
+
+      sse_u32[0] =
+          vmlal_u16(sse_u32[0], vget_low_u16(diff), vget_low_u16(diff));
+      sse_u32[1] =
+          vmlal_u16(sse_u32[1], vget_high_u16(diff), vget_high_u16(diff));
+
+      j += 8;
+    } while (j < w);
+
+    src_ptr += src_stride;
+    ref_ptr += ref_stride;
+  } while (--i != 0);
+
+  *sse = horizontal_add_uint32x4(vaddq_u32(sse_u32[0], sse_u32[1]));
+  return *sse;
+}
+
+#if defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE uint32_t highbd_mse8_8xh_neon(const uint16_t *src_ptr,
+                                            int src_stride,
+                                            const uint16_t *ref_ptr,
+                                            int ref_stride, int h,
+                                            unsigned int *sse) {
+  uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+  int i = h / 2;
+  do {
+    uint16x8_t s0, s1, r0, r1;
+    uint8x16_t s, r, diff;
+
+    s0 = vld1q_u16(src_ptr);
+    src_ptr += src_stride;
+    s1 = vld1q_u16(src_ptr);
+    src_ptr += src_stride;
+    r0 = vld1q_u16(ref_ptr);
+    ref_ptr += ref_stride;
+    r1 = vld1q_u16(ref_ptr);
+    ref_ptr += ref_stride;
+
+    s = vcombine_u8(vmovn_u16(s0), vmovn_u16(s1));
+    r = vcombine_u8(vmovn_u16(r0), vmovn_u16(r1));
+
+    diff = vabdq_u8(s, r);
+    sse_u32 = vdotq_u32(sse_u32, diff, diff);
+  } while (--i != 0);
+
+  *sse = horizontal_add_uint32x4(sse_u32);
+  return *sse;
+}
+
+static INLINE uint32_t highbd_mse8_16xh_neon(const uint16_t *src_ptr,
+                                             int src_stride,
+                                             const uint16_t *ref_ptr,
+                                             int ref_stride, int h,
+                                             unsigned int *sse) {
+  uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+  int i = h;
+  do {
+    uint16x8_t s0, s1, r0, r1;
+    uint8x16_t s, r, diff;
+
+    s0 = vld1q_u16(src_ptr);
+    s1 = vld1q_u16(src_ptr + 8);
+    r0 = vld1q_u16(ref_ptr);
+    r1 = vld1q_u16(ref_ptr + 8);
+
+    s = vcombine_u8(vmovn_u16(s0), vmovn_u16(s1));
+    r = vcombine_u8(vmovn_u16(r0), vmovn_u16(r1));
+
+    diff = vabdq_u8(s, r);
+    sse_u32 = vdotq_u32(sse_u32, diff, diff);
+
+    src_ptr += src_stride;
+    ref_ptr += ref_stride;
+  } while (--i != 0);
+
+  *sse = horizontal_add_uint32x4(sse_u32);
+  return *sse;
+}
+
+#else  // !defined(__ARM_FEATURE_DOTPROD)
+
+static INLINE uint32_t highbd_mse8_8xh_neon(const uint16_t *src_ptr,
+                                            int src_stride,
+                                            const uint16_t *ref_ptr,
+                                            int ref_stride, int h,
+                                            unsigned int *sse) {
+  return highbd_mse_wxh_neon(src_ptr, src_stride, ref_ptr, ref_stride, 8, h,
+                             sse);
+}
+
+static INLINE uint32_t highbd_mse8_16xh_neon(const uint16_t *src_ptr,
+                                             int src_stride,
+                                             const uint16_t *ref_ptr,
+                                             int ref_stride, int h,
+                                             unsigned int *sse) {
+  return highbd_mse_wxh_neon(src_ptr, src_stride, ref_ptr, ref_stride, 16, h,
+                             sse);
+}
+
+#endif  // defined(__ARM_FEATURE_DOTPROD)
+
+#define HIGHBD_MSE_WXH_NEON(w, h)                                       \
+  uint32_t vpx_highbd_8_mse##w##x##h##_neon(                            \
+      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,   \
+      int ref_stride, uint32_t *sse) {                                  \
+    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
+    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                       \
+    highbd_mse8_##w##xh_neon(src, src_stride, ref, ref_stride, h, sse); \
+    return *sse;                                                        \
+  }                                                                     \
+                                                                        \
+  uint32_t vpx_highbd_10_mse##w##x##h##_neon(                           \
+      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,   \
+      int ref_stride, uint32_t *sse) {                                  \
+    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
+    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                       \
+    highbd_mse_wxh_neon(src, src_stride, ref, ref_stride, w, h, sse);   \
+    *sse = ROUND_POWER_OF_TWO(*sse, 4);                                 \
+    return *sse;                                                        \
+  }                                                                     \
+                                                                        \
+  uint32_t vpx_highbd_12_mse##w##x##h##_neon(                           \
+      const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,   \
+      int ref_stride, uint32_t *sse) {                                  \
+    uint16_t *src = CONVERT_TO_SHORTPTR(src_ptr);                       \
+    uint16_t *ref = CONVERT_TO_SHORTPTR(ref_ptr);                       \
+    highbd_mse_wxh_neon(src, src_stride, ref, ref_stride, w, h, sse);   \
+    *sse = ROUND_POWER_OF_TWO(*sse, 8);                                 \
+    return *sse;                                                        \
+  }
+
+HIGHBD_MSE_WXH_NEON(16, 16)
+HIGHBD_MSE_WXH_NEON(16, 8)
+HIGHBD_MSE_WXH_NEON(8, 16)
+HIGHBD_MSE_WXH_NEON(8, 8)