Optimize vpx_highbd_comp_avg_pred_neon
authorSalome Thirot <salome.thirot@arm.com>
Fri, 10 Feb 2023 10:50:47 +0000 (10:50 +0000)
committerSalome Thirot <salome.thirot@arm.com>
Mon, 13 Feb 2023 20:23:14 +0000 (20:23 +0000)
Optimize the implementation of vpx_highbd_comp_avg_pred_neon by making
use of the URHADD instruction to compute the average.

Change-Id: Id74a6d9c33e89bc548c3c7ecace59af69051b4a7

vpx_dsp/arm/highbd_avg_pred_neon.c

index 04dbaad7a3f9abdb866414d05b8342b8b2460b3e..3063acbb3e3fa73d4092cdcdbd4d51392b56328e 100644 (file)
@@ -9,6 +9,7 @@
  */
 
 #include <arm_neon.h>
+#include <assert.h>
 
 #include "./vpx_dsp_rtcd.h"
 #include "./vpx_config.h"
 void vpx_highbd_comp_avg_pred_neon(uint16_t *comp_pred, const uint16_t *pred,
                                    int width, int height, const uint16_t *ref,
                                    int ref_stride) {
-  int i, j;
-  uint32x4_t one_u32 = vdupq_n_u32(1);
-  if (width >= 8) {
-    for (i = 0; i < height; ++i) {
-      for (j = 0; j < width; j += 8) {
-        const uint16x8_t pred_u16 = vld1q_u16(&pred[j]);
-        const uint16x8_t ref_u16 = vld1q_u16(&ref[j]);
-        const uint32x4_t sum1_u32 =
-            vaddl_u16(vget_low_u16(pred_u16), vget_low_u16(ref_u16));
-        const uint32x4_t sum2_u32 =
-            vaddl_u16(vget_high_u16(pred_u16), vget_high_u16(ref_u16));
-        const uint16x4_t sum1_u16 =
-            vshrn_n_u32(vaddq_u32(sum1_u32, one_u32), 1);
-        const uint16x4_t sum2_u16 =
-            vshrn_n_u32(vaddq_u32(sum2_u32, one_u32), 1);
-        const uint16x8_t vcomp_pred = vcombine_u16(sum1_u16, sum2_u16);
-        vst1q_u16(&comp_pred[j], vcomp_pred);
-      }
+  int i = height;
+  if (width > 8) {
+    do {
+      int j = 0;
+      do {
+        const uint16x8_t p = vld1q_u16(pred + j);
+        const uint16x8_t r = vld1q_u16(ref + j);
+
+        uint16x8_t avg = vrhaddq_u16(p, r);
+        vst1q_u16(comp_pred + j, avg);
+
+        j += 8;
+      } while (j < width);
+
+      comp_pred += width;
+      pred += width;
+      ref += ref_stride;
+    } while (--i != 0);
+  } else if (width == 8) {
+    do {
+      const uint16x8_t p = vld1q_u16(pred);
+      const uint16x8_t r = vld1q_u16(ref);
+
+      uint16x8_t avg = vrhaddq_u16(p, r);
+      vst1q_u16(comp_pred, avg);
+
       comp_pred += width;
       pred += width;
       ref += ref_stride;
-    }
+    } while (--i != 0);
   } else {
-    assert(width >= 4);
-    for (i = 0; i < height; ++i) {
-      for (j = 0; j < width; j += 4) {
-        const uint16x4_t pred_u16 = vld1_u16(&pred[j]);
-        const uint16x4_t ref_u16 = vld1_u16(&ref[j]);
-        const uint32x4_t sum_u32 = vaddl_u16(pred_u16, ref_u16);
-        const uint16x4_t vcomp_pred =
-            vshrn_n_u32(vaddq_u32(sum_u32, one_u32), 1);
-        vst1_u16(&comp_pred[j], vcomp_pred);
-      }
+    assert(width == 4);
+    do {
+      const uint16x4_t p = vld1_u16(pred);
+      const uint16x4_t r = vld1_u16(ref);
+
+      uint16x4_t avg = vrhadd_u16(p, r);
+      vst1_u16(comp_pred, avg);
+
       comp_pred += width;
       pred += width;
       ref += ref_stride;
-    }
+    } while (--i != 0);
   }
 }