]> granicus.if.org Git - libvpx/commitdiff
[NEON] Move helper functions for reuse
authorKonstantinos Margaritis <konma@vectorcamp.gr>
Thu, 6 Oct 2022 13:05:01 +0000 (13:05 +0000)
committerKonstantinos Margaritis <konma@vectorcamp.gr>
Wed, 12 Oct 2022 08:11:53 +0000 (08:11 +0000)
Move all butterfly functions to fdct_neon.h
Slightly optimize load/scale/cross functions
in fdct 16x16.
These will be reused in highbd variants.

Change-Id: I28b6e0cc240304bab6b94d9c3f33cca77b8cb073

vpx_dsp/arm/fdct16x16_neon.c
vpx_dsp/arm/fdct16x16_neon.h
vpx_dsp/arm/fdct32x32_neon.c
vpx_dsp/arm/fdct4x4_neon.c [moved from vpx_dsp/arm/fdct_neon.c with 100% similarity]
vpx_dsp/arm/fdct8x8_neon.c [moved from vpx_dsp/arm/fwd_txfm_neon.c with 100% similarity]
vpx_dsp/arm/fdct_neon.h
vpx_dsp/vpx_dsp.mk

index 5cccb6a64a187d11acd6e7286bbbcd7e143c9349..0b0ce223db10cf6c8e37d46eb301518ad92c8aeb 100644 (file)
@@ -35,13 +35,13 @@ void vpx_fdct16x16_neon(const int16_t *input, tran_low_t *output, int stride) {
   int16x8_t temp3[16];
 
   // Left half.
-  load(input, stride, temp0);
-  cross_input(temp0, temp1, 0);
+  load_cross(input, stride, temp0);
+  scale_input(temp0, temp1);
   vpx_fdct16x16_body(temp1, temp0);
 
   // Right half.
-  load(input + 8, stride, temp1);
-  cross_input(temp1, temp2, 0);
+  load_cross(input + 8, stride, temp1);
+  scale_input(temp1, temp2);
   vpx_fdct16x16_body(temp2, temp1);
 
   // Transpose top left and top right quarters into one contiguous location to
@@ -49,7 +49,7 @@ void vpx_fdct16x16_neon(const int16_t *input, tran_low_t *output, int stride) {
   transpose_s16_8x8_new(&temp0[0], &temp2[0]);
   transpose_s16_8x8_new(&temp1[0], &temp2[8]);
   partial_round_shift(temp2);
-  cross_input(temp2, temp3, 1);
+  cross_input(temp2, temp3);
   vpx_fdct16x16_body(temp3, temp2);
   transpose_s16_8x8(&temp2[0], &temp2[1], &temp2[2], &temp2[3], &temp2[4],
                     &temp2[5], &temp2[6], &temp2[7]);
@@ -65,7 +65,7 @@ void vpx_fdct16x16_neon(const int16_t *input, tran_low_t *output, int stride) {
   transpose_s16_8x8(&temp1[8], &temp1[9], &temp1[10], &temp1[11], &temp1[12],
                     &temp1[13], &temp1[14], &temp1[15]);
   partial_round_shift(temp1);
-  cross_input(temp1, temp0, 1);
+  cross_input(temp1, temp0);
   vpx_fdct16x16_body(temp0, temp1);
   transpose_s16_8x8(&temp1[0], &temp1[1], &temp1[2], &temp1[3], &temp1[4],
                     &temp1[5], &temp1[6], &temp1[7]);
index 5ce74cdf415783472a95ccc551bf05491d8c898b..7fc2c6e7e87430a16bb615ae519c692b626df32e 100644 (file)
@@ -13,6 +13,8 @@
 
 #include <arm_neon.h>
 
+#include "fdct_neon.h"
+
 static INLINE void load(const int16_t *a, int stride, int16x8_t *b /*[16]*/) {
   b[0] = vld1q_s16(a);
   a += stride;
@@ -72,45 +74,67 @@ static INLINE void store(tran_low_t *a, const int16x8_t *b /*[8]*/) {
 // To maybe reduce register usage this could be combined with the load() step to
 // get the first 4 and last 4 values, cross those, then load the middle 8 values
 // and cross them.
+static INLINE void scale_input(const int16x8_t *a /*[16]*/,
+                               int16x8_t *b /*[16]*/) {
+  b[0] = vshlq_n_s16(a[0], 2);
+  b[1] = vshlq_n_s16(a[1], 2);
+  b[2] = vshlq_n_s16(a[2], 2);
+  b[3] = vshlq_n_s16(a[3], 2);
+  b[4] = vshlq_n_s16(a[4], 2);
+  b[5] = vshlq_n_s16(a[5], 2);
+  b[6] = vshlq_n_s16(a[6], 2);
+  b[7] = vshlq_n_s16(a[7], 2);
+
+  b[8] = vshlq_n_s16(a[8], 2);
+  b[9] = vshlq_n_s16(a[9], 2);
+  b[10] = vshlq_n_s16(a[10], 2);
+  b[11] = vshlq_n_s16(a[11], 2);
+  b[12] = vshlq_n_s16(a[12], 2);
+  b[13] = vshlq_n_s16(a[13], 2);
+  b[14] = vshlq_n_s16(a[14], 2);
+  b[15] = vshlq_n_s16(a[15], 2);
+}
+
 static INLINE void cross_input(const int16x8_t *a /*[16]*/,
-                               int16x8_t *b /*[16]*/, const int pass) {
-  if (pass == 0) {
-    b[0] = vshlq_n_s16(vaddq_s16(a[0], a[15]), 2);
-    b[1] = vshlq_n_s16(vaddq_s16(a[1], a[14]), 2);
-    b[2] = vshlq_n_s16(vaddq_s16(a[2], a[13]), 2);
-    b[3] = vshlq_n_s16(vaddq_s16(a[3], a[12]), 2);
-    b[4] = vshlq_n_s16(vaddq_s16(a[4], a[11]), 2);
-    b[5] = vshlq_n_s16(vaddq_s16(a[5], a[10]), 2);
-    b[6] = vshlq_n_s16(vaddq_s16(a[6], a[9]), 2);
-    b[7] = vshlq_n_s16(vaddq_s16(a[7], a[8]), 2);
+                               int16x8_t *b /*[16]*/) {
+  b[0] = vaddq_s16(a[0], a[15]);
+  b[1] = vaddq_s16(a[1], a[14]);
+  b[2] = vaddq_s16(a[2], a[13]);
+  b[3] = vaddq_s16(a[3], a[12]);
+  b[4] = vaddq_s16(a[4], a[11]);
+  b[5] = vaddq_s16(a[5], a[10]);
+  b[6] = vaddq_s16(a[6], a[9]);
+  b[7] = vaddq_s16(a[7], a[8]);
+
+  b[8] = vsubq_s16(a[7], a[8]);
+  b[9] = vsubq_s16(a[6], a[9]);
+  b[10] = vsubq_s16(a[5], a[10]);
+  b[11] = vsubq_s16(a[4], a[11]);
+  b[12] = vsubq_s16(a[3], a[12]);
+  b[13] = vsubq_s16(a[2], a[13]);
+  b[14] = vsubq_s16(a[1], a[14]);
+  b[15] = vsubq_s16(a[0], a[15]);
+}
 
-    b[8] = vshlq_n_s16(vsubq_s16(a[7], a[8]), 2);
-    b[9] = vshlq_n_s16(vsubq_s16(a[6], a[9]), 2);
-    b[10] = vshlq_n_s16(vsubq_s16(a[5], a[10]), 2);
-    b[11] = vshlq_n_s16(vsubq_s16(a[4], a[11]), 2);
-    b[12] = vshlq_n_s16(vsubq_s16(a[3], a[12]), 2);
-    b[13] = vshlq_n_s16(vsubq_s16(a[2], a[13]), 2);
-    b[14] = vshlq_n_s16(vsubq_s16(a[1], a[14]), 2);
-    b[15] = vshlq_n_s16(vsubq_s16(a[0], a[15]), 2);
-  } else {
-    b[0] = vaddq_s16(a[0], a[15]);
-    b[1] = vaddq_s16(a[1], a[14]);
-    b[2] = vaddq_s16(a[2], a[13]);
-    b[3] = vaddq_s16(a[3], a[12]);
-    b[4] = vaddq_s16(a[4], a[11]);
-    b[5] = vaddq_s16(a[5], a[10]);
-    b[6] = vaddq_s16(a[6], a[9]);
-    b[7] = vaddq_s16(a[7], a[8]);
+static INLINE void load_cross(const int16_t *a, int stride,
+                              int16x8_t *b /*[16]*/) {
+  b[0] = vaddq_s16(vld1q_s16(a + 0 * stride), vld1q_s16(a + 15 * stride));
+  b[1] = vaddq_s16(vld1q_s16(a + 1 * stride), vld1q_s16(a + 14 * stride));
+  b[2] = vaddq_s16(vld1q_s16(a + 2 * stride), vld1q_s16(a + 13 * stride));
+  b[3] = vaddq_s16(vld1q_s16(a + 3 * stride), vld1q_s16(a + 12 * stride));
+  b[4] = vaddq_s16(vld1q_s16(a + 4 * stride), vld1q_s16(a + 11 * stride));
+  b[5] = vaddq_s16(vld1q_s16(a + 5 * stride), vld1q_s16(a + 10 * stride));
+  b[6] = vaddq_s16(vld1q_s16(a + 6 * stride), vld1q_s16(a + 9 * stride));
+  b[7] = vaddq_s16(vld1q_s16(a + 7 * stride), vld1q_s16(a + 8 * stride));
 
-    b[8] = vsubq_s16(a[7], a[8]);
-    b[9] = vsubq_s16(a[6], a[9]);
-    b[10] = vsubq_s16(a[5], a[10]);
-    b[11] = vsubq_s16(a[4], a[11]);
-    b[12] = vsubq_s16(a[3], a[12]);
-    b[13] = vsubq_s16(a[2], a[13]);
-    b[14] = vsubq_s16(a[1], a[14]);
-    b[15] = vsubq_s16(a[0], a[15]);
-  }
+  b[8] = vsubq_s16(vld1q_s16(a + 7 * stride), vld1q_s16(a + 8 * stride));
+  b[9] = vsubq_s16(vld1q_s16(a + 6 * stride), vld1q_s16(a + 9 * stride));
+  b[10] = vsubq_s16(vld1q_s16(a + 5 * stride), vld1q_s16(a + 10 * stride));
+  b[11] = vsubq_s16(vld1q_s16(a + 4 * stride), vld1q_s16(a + 11 * stride));
+  b[12] = vsubq_s16(vld1q_s16(a + 3 * stride), vld1q_s16(a + 12 * stride));
+  b[13] = vsubq_s16(vld1q_s16(a + 2 * stride), vld1q_s16(a + 13 * stride));
+  b[14] = vsubq_s16(vld1q_s16(a + 1 * stride), vld1q_s16(a + 14 * stride));
+  b[15] = vsubq_s16(vld1q_s16(a + 0 * stride), vld1q_s16(a + 15 * stride));
 }
 
 // Quarter round at the beginning of the second pass. Can't use vrshr (rounding)
@@ -135,45 +159,6 @@ static INLINE void partial_round_shift(int16x8_t *a /*[16]*/) {
   a[15] = vshrq_n_s16(vaddq_s16(a[15], one), 2);
 }
 
-// fdct_round_shift((a +/- b) * c)
-static INLINE void butterfly_one_coeff(const int16x8_t a, const int16x8_t b,
-                                       const tran_high_t c, int16x8_t *add,
-                                       int16x8_t *sub) {
-  const int32x4_t a0 = vmull_n_s16(vget_low_s16(a), c);
-  const int32x4_t a1 = vmull_n_s16(vget_high_s16(a), c);
-  const int32x4_t sum0 = vmlal_n_s16(a0, vget_low_s16(b), c);
-  const int32x4_t sum1 = vmlal_n_s16(a1, vget_high_s16(b), c);
-  const int32x4_t diff0 = vmlsl_n_s16(a0, vget_low_s16(b), c);
-  const int32x4_t diff1 = vmlsl_n_s16(a1, vget_high_s16(b), c);
-  const int16x4_t rounded0 = vqrshrn_n_s32(sum0, 14);
-  const int16x4_t rounded1 = vqrshrn_n_s32(sum1, 14);
-  const int16x4_t rounded2 = vqrshrn_n_s32(diff0, 14);
-  const int16x4_t rounded3 = vqrshrn_n_s32(diff1, 14);
-  *add = vcombine_s16(rounded0, rounded1);
-  *sub = vcombine_s16(rounded2, rounded3);
-}
-
-// fdct_round_shift(a * c0 +/- b * c1)
-static INLINE void butterfly_two_coeff(const int16x8_t a, const int16x8_t b,
-                                       const tran_coef_t c0,
-                                       const tran_coef_t c1, int16x8_t *add,
-                                       int16x8_t *sub) {
-  const int32x4_t a0 = vmull_n_s16(vget_low_s16(a), c0);
-  const int32x4_t a1 = vmull_n_s16(vget_high_s16(a), c0);
-  const int32x4_t a2 = vmull_n_s16(vget_low_s16(a), c1);
-  const int32x4_t a3 = vmull_n_s16(vget_high_s16(a), c1);
-  const int32x4_t sum0 = vmlal_n_s16(a2, vget_low_s16(b), c0);
-  const int32x4_t sum1 = vmlal_n_s16(a3, vget_high_s16(b), c0);
-  const int32x4_t diff0 = vmlsl_n_s16(a0, vget_low_s16(b), c1);
-  const int32x4_t diff1 = vmlsl_n_s16(a1, vget_high_s16(b), c1);
-  const int16x4_t rounded0 = vqrshrn_n_s32(sum0, 14);
-  const int16x4_t rounded1 = vqrshrn_n_s32(sum1, 14);
-  const int16x4_t rounded2 = vqrshrn_n_s32(diff0, 14);
-  const int16x4_t rounded3 = vqrshrn_n_s32(diff1, 14);
-  *add = vcombine_s16(rounded0, rounded1);
-  *sub = vcombine_s16(rounded2, rounded3);
-}
-
 // Main body of fdct16x16.
 static void vpx_fdct16x16_body(const int16x8_t *in /*[16]*/,
                                int16x8_t *out /*[16]*/) {
index de74e6630b787fdab418d7f0e25b2955723b0b1f..51d81bd085ec52a00b1ef7e5a7858f72e0b63194 100644 (file)
@@ -15,6 +15,7 @@
 #include "vpx_dsp/txfm_common.h"
 #include "vpx_dsp/arm/mem_neon.h"
 #include "vpx_dsp/arm/transpose_neon.h"
+#include "vpx_dsp/arm/fdct_neon.h"
 
 // Most gcc 4.9 distributions outside of Android do not generate correct code
 // for this function.
@@ -194,54 +195,6 @@ static INLINE void store(tran_low_t *a, const int16x8_t *b) {
 
 #undef STORE_S16
 
-// fdct_round_shift((a +/- b) * c)
-static INLINE void butterfly_one_coeff(const int16x8_t a, const int16x8_t b,
-                                       const tran_high_t constant,
-                                       int16x8_t *add, int16x8_t *sub) {
-  const int32x4_t a0 = vmull_n_s16(vget_low_s16(a), constant);
-  const int32x4_t a1 = vmull_n_s16(vget_high_s16(a), constant);
-  const int32x4_t sum0 = vmlal_n_s16(a0, vget_low_s16(b), constant);
-  const int32x4_t sum1 = vmlal_n_s16(a1, vget_high_s16(b), constant);
-  const int32x4_t diff0 = vmlsl_n_s16(a0, vget_low_s16(b), constant);
-  const int32x4_t diff1 = vmlsl_n_s16(a1, vget_high_s16(b), constant);
-  const int16x4_t rounded0 = vqrshrn_n_s32(sum0, DCT_CONST_BITS);
-  const int16x4_t rounded1 = vqrshrn_n_s32(sum1, DCT_CONST_BITS);
-  const int16x4_t rounded2 = vqrshrn_n_s32(diff0, DCT_CONST_BITS);
-  const int16x4_t rounded3 = vqrshrn_n_s32(diff1, DCT_CONST_BITS);
-  *add = vcombine_s16(rounded0, rounded1);
-  *sub = vcombine_s16(rounded2, rounded3);
-}
-
-// fdct_round_shift(a * c0 +/- b * c1)
-static INLINE void butterfly_two_coeff(const int16x8_t a, const int16x8_t b,
-                                       const tran_coef_t constant0,
-                                       const tran_coef_t constant1,
-                                       int16x8_t *add, int16x8_t *sub) {
-  const int32x4_t a0 = vmull_n_s16(vget_low_s16(a), constant0);
-  const int32x4_t a1 = vmull_n_s16(vget_high_s16(a), constant0);
-  const int32x4_t a2 = vmull_n_s16(vget_low_s16(a), constant1);
-  const int32x4_t a3 = vmull_n_s16(vget_high_s16(a), constant1);
-  const int32x4_t sum0 = vmlal_n_s16(a2, vget_low_s16(b), constant0);
-  const int32x4_t sum1 = vmlal_n_s16(a3, vget_high_s16(b), constant0);
-  const int32x4_t diff0 = vmlsl_n_s16(a0, vget_low_s16(b), constant1);
-  const int32x4_t diff1 = vmlsl_n_s16(a1, vget_high_s16(b), constant1);
-  const int16x4_t rounded0 = vqrshrn_n_s32(sum0, DCT_CONST_BITS);
-  const int16x4_t rounded1 = vqrshrn_n_s32(sum1, DCT_CONST_BITS);
-  const int16x4_t rounded2 = vqrshrn_n_s32(diff0, DCT_CONST_BITS);
-  const int16x4_t rounded3 = vqrshrn_n_s32(diff1, DCT_CONST_BITS);
-  *add = vcombine_s16(rounded0, rounded1);
-  *sub = vcombine_s16(rounded2, rounded3);
-}
-
-// Add 2 if positive, 1 if negative, and shift by 2.
-// In practice, subtract the sign bit, then shift with rounding.
-static INLINE int16x8_t sub_round_shift(const int16x8_t a) {
-  const uint16x8_t a_u16 = vreinterpretq_u16_s16(a);
-  const uint16x8_t a_sign_u16 = vshrq_n_u16(a_u16, 15);
-  const int16x8_t a_sign_s16 = vreinterpretq_s16_u16(a_sign_u16);
-  return vrshrq_n_s16(vsubq_s16(a, a_sign_s16), 2);
-}
-
 static void dct_body_first_pass(const int16x8_t *in, int16x8_t *out) {
   int16x8_t a[32];
   int16x8_t b[32];
@@ -562,23 +515,6 @@ static void dct_body_first_pass(const int16x8_t *in, int16x8_t *out) {
     b##_hi[b_index] = vsubq_s32(a##_hi[left_index], a##_hi[right_index]); \
   } while (0)
 
-// Like butterfly_one_coeff, but don't narrow results.
-static INLINE void butterfly_one_coeff_s16_s32(
-    const int16x8_t a, const int16x8_t b, const tran_high_t constant,
-    int32x4_t *add_lo, int32x4_t *add_hi, int32x4_t *sub_lo,
-    int32x4_t *sub_hi) {
-  const int32x4_t a0 = vmull_n_s16(vget_low_s16(a), constant);
-  const int32x4_t a1 = vmull_n_s16(vget_high_s16(a), constant);
-  const int32x4_t sum0 = vmlal_n_s16(a0, vget_low_s16(b), constant);
-  const int32x4_t sum1 = vmlal_n_s16(a1, vget_high_s16(b), constant);
-  const int32x4_t diff0 = vmlsl_n_s16(a0, vget_low_s16(b), constant);
-  const int32x4_t diff1 = vmlsl_n_s16(a1, vget_high_s16(b), constant);
-  *add_lo = vrshrq_n_s32(sum0, DCT_CONST_BITS);
-  *add_hi = vrshrq_n_s32(sum1, DCT_CONST_BITS);
-  *sub_lo = vrshrq_n_s32(diff0, DCT_CONST_BITS);
-  *sub_hi = vrshrq_n_s32(diff1, DCT_CONST_BITS);
-}
-
 #define BUTTERFLY_ONE_S16_S32(a, left_index, right_index, constant, b,   \
                               add_index, sub_index)                      \
   do {                                                                   \
@@ -587,23 +523,6 @@ static INLINE void butterfly_one_coeff_s16_s32(
                                 &b##_lo[sub_index], &b##_hi[sub_index]); \
   } while (0)
 
-// Like butterfly_one_coeff, but with s32.
-static INLINE void butterfly_one_coeff_s32(
-    const int32x4_t a_lo, const int32x4_t a_hi, const int32x4_t b_lo,
-    const int32x4_t b_hi, const int32_t constant, int32x4_t *add_lo,
-    int32x4_t *add_hi, int32x4_t *sub_lo, int32x4_t *sub_hi) {
-  const int32x4_t a_lo_0 = vmulq_n_s32(a_lo, constant);
-  const int32x4_t a_hi_0 = vmulq_n_s32(a_hi, constant);
-  const int32x4_t sum0 = vmlaq_n_s32(a_lo_0, b_lo, constant);
-  const int32x4_t sum1 = vmlaq_n_s32(a_hi_0, b_hi, constant);
-  const int32x4_t diff0 = vmlsq_n_s32(a_lo_0, b_lo, constant);
-  const int32x4_t diff1 = vmlsq_n_s32(a_hi_0, b_hi, constant);
-  *add_lo = vrshrq_n_s32(sum0, DCT_CONST_BITS);
-  *add_hi = vrshrq_n_s32(sum1, DCT_CONST_BITS);
-  *sub_lo = vrshrq_n_s32(diff0, DCT_CONST_BITS);
-  *sub_hi = vrshrq_n_s32(diff1, DCT_CONST_BITS);
-}
-
 #define BUTTERFLY_ONE_S32(a, left_index, right_index, constant, b, add_index, \
                           sub_index)                                          \
   do {                                                                        \
@@ -613,26 +532,6 @@ static INLINE void butterfly_one_coeff_s32(
                             &b##_lo[sub_index], &b##_hi[sub_index]);          \
   } while (0)
 
-// Like butterfly_two_coeff, but with s32.
-static INLINE void butterfly_two_coeff_s32(
-    const int32x4_t a_lo, const int32x4_t a_hi, const int32x4_t b_lo,
-    const int32x4_t b_hi, const int32_t constant0, const int32_t constant1,
-    int32x4_t *add_lo, int32x4_t *add_hi, int32x4_t *sub_lo,
-    int32x4_t *sub_hi) {
-  const int32x4_t a0 = vmulq_n_s32(a_lo, constant0);
-  const int32x4_t a1 = vmulq_n_s32(a_hi, constant0);
-  const int32x4_t a2 = vmulq_n_s32(a_lo, constant1);
-  const int32x4_t a3 = vmulq_n_s32(a_hi, constant1);
-  const int32x4_t sum0 = vmlaq_n_s32(a2, b_lo, constant0);
-  const int32x4_t sum1 = vmlaq_n_s32(a3, b_hi, constant0);
-  const int32x4_t diff0 = vmlsq_n_s32(a0, b_lo, constant1);
-  const int32x4_t diff1 = vmlsq_n_s32(a1, b_hi, constant1);
-  *add_lo = vrshrq_n_s32(sum0, DCT_CONST_BITS);
-  *add_hi = vrshrq_n_s32(sum1, DCT_CONST_BITS);
-  *sub_lo = vrshrq_n_s32(diff0, DCT_CONST_BITS);
-  *sub_hi = vrshrq_n_s32(diff1, DCT_CONST_BITS);
-}
-
 #define BUTTERFLY_TWO_S32(a, left_index, right_index, left_constant,           \
                           right_constant, b, add_index, sub_index)             \
   do {                                                                         \
@@ -643,24 +542,6 @@ static INLINE void butterfly_two_coeff_s32(
                             &b##_hi[sub_index]);                               \
   } while (0)
 
-// Add 1 if positive, 2 if negative, and shift by 2.
-// In practice, add 1, then add the sign bit, then shift without rounding.
-static INLINE int16x8_t add_round_shift_s32(const int32x4_t a_lo,
-                                            const int32x4_t a_hi) {
-  const int32x4_t one = vdupq_n_s32(1);
-  const uint32x4_t a_lo_u32 = vreinterpretq_u32_s32(a_lo);
-  const uint32x4_t a_lo_sign_u32 = vshrq_n_u32(a_lo_u32, 31);
-  const int32x4_t a_lo_sign_s32 = vreinterpretq_s32_u32(a_lo_sign_u32);
-  const int16x4_t b_lo =
-      vshrn_n_s32(vqaddq_s32(vqaddq_s32(a_lo, a_lo_sign_s32), one), 2);
-  const uint32x4_t a_hi_u32 = vreinterpretq_u32_s32(a_hi);
-  const uint32x4_t a_hi_sign_u32 = vshrq_n_u32(a_hi_u32, 31);
-  const int32x4_t a_hi_sign_s32 = vreinterpretq_s32_u32(a_hi_sign_u32);
-  const int16x4_t b_hi =
-      vshrn_n_s32(vqaddq_s32(vqaddq_s32(a_hi, a_hi_sign_s32), one), 2);
-  return vcombine_s16(b_lo, b_hi);
-}
-
 static void dct_body_second_pass(const int16x8_t *in, int16x8_t *out) {
   int16x8_t a[32];
   int16x8_t b[32];
@@ -967,16 +848,6 @@ static void dct_body_second_pass(const int16x8_t *in, int16x8_t *out) {
   out[3] = add_round_shift_s32(d_lo[3], d_hi[3]);
 }
 
-// Add 1 if positive, 2 if negative, and shift by 2.
-// In practice, add 1, then add the sign bit, then shift without rounding.
-static INLINE int16x8_t add_round_shift_s16(const int16x8_t a) {
-  const int16x8_t one = vdupq_n_s16(1);
-  const uint16x8_t a_u16 = vreinterpretq_u16_s16(a);
-  const uint16x8_t a_sign_u16 = vshrq_n_u16(a_u16, 15);
-  const int16x8_t a_sign_s16 = vreinterpretq_s16_u16(a_sign_u16);
-  return vshrq_n_s16(vaddq_s16(vaddq_s16(a, a_sign_s16), one), 2);
-}
-
 static void dct_body_second_pass_rd(const int16x8_t *in, int16x8_t *out) {
   int16x8_t a[32];
   int16x8_t b[32];
@@ -1279,42 +1150,6 @@ static void dct_body_second_pass_rd(const int16x8_t *in, int16x8_t *out) {
 #undef BUTTERFLY_ONE_S32
 #undef BUTTERFLY_TWO_S32
 
-// Transpose 8x8 to a new location. Don't use transpose_neon.h because those
-// are all in-place.
-// TODO(johannkoenig): share with other fdcts.
-static INLINE void transpose_8x8(const int16x8_t *a, int16x8_t *b) {
-  // Swap 16 bit elements.
-  const int16x8x2_t c0 = vtrnq_s16(a[0], a[1]);
-  const int16x8x2_t c1 = vtrnq_s16(a[2], a[3]);
-  const int16x8x2_t c2 = vtrnq_s16(a[4], a[5]);
-  const int16x8x2_t c3 = vtrnq_s16(a[6], a[7]);
-
-  // Swap 32 bit elements.
-  const int32x4x2_t d0 = vtrnq_s32(vreinterpretq_s32_s16(c0.val[0]),
-                                   vreinterpretq_s32_s16(c1.val[0]));
-  const int32x4x2_t d1 = vtrnq_s32(vreinterpretq_s32_s16(c0.val[1]),
-                                   vreinterpretq_s32_s16(c1.val[1]));
-  const int32x4x2_t d2 = vtrnq_s32(vreinterpretq_s32_s16(c2.val[0]),
-                                   vreinterpretq_s32_s16(c3.val[0]));
-  const int32x4x2_t d3 = vtrnq_s32(vreinterpretq_s32_s16(c2.val[1]),
-                                   vreinterpretq_s32_s16(c3.val[1]));
-
-  // Swap 64 bit elements
-  const int16x8x2_t e0 = vpx_vtrnq_s64_to_s16(d0.val[0], d2.val[0]);
-  const int16x8x2_t e1 = vpx_vtrnq_s64_to_s16(d1.val[0], d3.val[0]);
-  const int16x8x2_t e2 = vpx_vtrnq_s64_to_s16(d0.val[1], d2.val[1]);
-  const int16x8x2_t e3 = vpx_vtrnq_s64_to_s16(d1.val[1], d3.val[1]);
-
-  b[0] = e0.val[0];
-  b[1] = e1.val[0];
-  b[2] = e2.val[0];
-  b[3] = e3.val[0];
-  b[4] = e0.val[1];
-  b[5] = e1.val[1];
-  b[6] = e2.val[1];
-  b[7] = e3.val[1];
-}
-
 void vpx_fdct32x32_neon(const int16_t *input, tran_low_t *output, int stride) {
   int16x8_t temp0[32];
   int16x8_t temp1[32];
@@ -1337,10 +1172,10 @@ void vpx_fdct32x32_neon(const int16_t *input, tran_low_t *output, int stride) {
   dct_body_first_pass(temp0, temp4);
 
   // Generate the top row by munging the first set of 8 from each one together.
-  transpose_8x8(&temp1[0], &temp0[0]);
-  transpose_8x8(&temp2[0], &temp0[8]);
-  transpose_8x8(&temp3[0], &temp0[16]);
-  transpose_8x8(&temp4[0], &temp0[24]);
+  transpose_s16_8x8_new(&temp1[0], &temp0[0]);
+  transpose_s16_8x8_new(&temp2[0], &temp0[8]);
+  transpose_s16_8x8_new(&temp3[0], &temp0[16]);
+  transpose_s16_8x8_new(&temp4[0], &temp0[24]);
 
   dct_body_second_pass(temp0, temp5);
 
@@ -1355,10 +1190,10 @@ void vpx_fdct32x32_neon(const int16_t *input, tran_low_t *output, int stride) {
   store(output, temp5);
 
   // Second row of 8x32.
-  transpose_8x8(&temp1[8], &temp0[0]);
-  transpose_8x8(&temp2[8], &temp0[8]);
-  transpose_8x8(&temp3[8], &temp0[16]);
-  transpose_8x8(&temp4[8], &temp0[24]);
+  transpose_s16_8x8_new(&temp1[8], &temp0[0]);
+  transpose_s16_8x8_new(&temp2[8], &temp0[8]);
+  transpose_s16_8x8_new(&temp3[8], &temp0[16]);
+  transpose_s16_8x8_new(&temp4[8], &temp0[24]);
 
   dct_body_second_pass(temp0, temp5);
 
@@ -1373,10 +1208,10 @@ void vpx_fdct32x32_neon(const int16_t *input, tran_low_t *output, int stride) {
   store(output + 8 * 32, temp5);
 
   // Third row of 8x32
-  transpose_8x8(&temp1[16], &temp0[0]);
-  transpose_8x8(&temp2[16], &temp0[8]);
-  transpose_8x8(&temp3[16], &temp0[16]);
-  transpose_8x8(&temp4[16], &temp0[24]);
+  transpose_s16_8x8_new(&temp1[16], &temp0[0]);
+  transpose_s16_8x8_new(&temp2[16], &temp0[8]);
+  transpose_s16_8x8_new(&temp3[16], &temp0[16]);
+  transpose_s16_8x8_new(&temp4[16], &temp0[24]);
 
   dct_body_second_pass(temp0, temp5);
 
@@ -1391,10 +1226,10 @@ void vpx_fdct32x32_neon(const int16_t *input, tran_low_t *output, int stride) {
   store(output + 16 * 32, temp5);
 
   // Final row of 8x32.
-  transpose_8x8(&temp1[24], &temp0[0]);
-  transpose_8x8(&temp2[24], &temp0[8]);
-  transpose_8x8(&temp3[24], &temp0[16]);
-  transpose_8x8(&temp4[24], &temp0[24]);
+  transpose_s16_8x8_new(&temp1[24], &temp0[0]);
+  transpose_s16_8x8_new(&temp2[24], &temp0[8]);
+  transpose_s16_8x8_new(&temp3[24], &temp0[16]);
+  transpose_s16_8x8_new(&temp4[24], &temp0[24]);
 
   dct_body_second_pass(temp0, temp5);
 
@@ -1432,10 +1267,10 @@ void vpx_fdct32x32_rd_neon(const int16_t *input, tran_low_t *output,
   dct_body_first_pass(temp0, temp4);
 
   // Generate the top row by munging the first set of 8 from each one together.
-  transpose_8x8(&temp1[0], &temp0[0]);
-  transpose_8x8(&temp2[0], &temp0[8]);
-  transpose_8x8(&temp3[0], &temp0[16]);
-  transpose_8x8(&temp4[0], &temp0[24]);
+  transpose_s16_8x8_new(&temp1[0], &temp0[0]);
+  transpose_s16_8x8_new(&temp2[0], &temp0[8]);
+  transpose_s16_8x8_new(&temp3[0], &temp0[16]);
+  transpose_s16_8x8_new(&temp4[0], &temp0[24]);
 
   dct_body_second_pass_rd(temp0, temp5);
 
@@ -1450,10 +1285,10 @@ void vpx_fdct32x32_rd_neon(const int16_t *input, tran_low_t *output,
   store(output, temp5);
 
   // Second row of 8x32.
-  transpose_8x8(&temp1[8], &temp0[0]);
-  transpose_8x8(&temp2[8], &temp0[8]);
-  transpose_8x8(&temp3[8], &temp0[16]);
-  transpose_8x8(&temp4[8], &temp0[24]);
+  transpose_s16_8x8_new(&temp1[8], &temp0[0]);
+  transpose_s16_8x8_new(&temp2[8], &temp0[8]);
+  transpose_s16_8x8_new(&temp3[8], &temp0[16]);
+  transpose_s16_8x8_new(&temp4[8], &temp0[24]);
 
   dct_body_second_pass_rd(temp0, temp5);
 
@@ -1468,10 +1303,10 @@ void vpx_fdct32x32_rd_neon(const int16_t *input, tran_low_t *output,
   store(output + 8 * 32, temp5);
 
   // Third row of 8x32
-  transpose_8x8(&temp1[16], &temp0[0]);
-  transpose_8x8(&temp2[16], &temp0[8]);
-  transpose_8x8(&temp3[16], &temp0[16]);
-  transpose_8x8(&temp4[16], &temp0[24]);
+  transpose_s16_8x8_new(&temp1[16], &temp0[0]);
+  transpose_s16_8x8_new(&temp2[16], &temp0[8]);
+  transpose_s16_8x8_new(&temp3[16], &temp0[16]);
+  transpose_s16_8x8_new(&temp4[16], &temp0[24]);
 
   dct_body_second_pass_rd(temp0, temp5);
 
@@ -1486,10 +1321,10 @@ void vpx_fdct32x32_rd_neon(const int16_t *input, tran_low_t *output,
   store(output + 16 * 32, temp5);
 
   // Final row of 8x32.
-  transpose_8x8(&temp1[24], &temp0[0]);
-  transpose_8x8(&temp2[24], &temp0[8]);
-  transpose_8x8(&temp3[24], &temp0[16]);
-  transpose_8x8(&temp4[24], &temp0[24]);
+  transpose_s16_8x8_new(&temp1[24], &temp0[0]);
+  transpose_s16_8x8_new(&temp2[24], &temp0[8]);
+  transpose_s16_8x8_new(&temp3[24], &temp0[16]);
+  transpose_s16_8x8_new(&temp4[24], &temp0[24]);
 
   dct_body_second_pass_rd(temp0, temp5);
 
index 28d7d86bf8d0d2cac23514c85577b63d8a963e3c..056cae408395f7b4e17010f0e29fa6a6b46bb1ce 100644 (file)
 
 #include <arm_neon.h>
 
+// fdct_round_shift((a +/- b) * c)
+static INLINE void butterfly_one_coeff(const int16x8_t a, const int16x8_t b,
+                                       const tran_high_t constant,
+                                       int16x8_t *add, int16x8_t *sub) {
+  const int32x4_t a0 = vmull_n_s16(vget_low_s16(a), constant);
+  const int32x4_t a1 = vmull_n_s16(vget_high_s16(a), constant);
+  const int32x4_t sum0 = vmlal_n_s16(a0, vget_low_s16(b), constant);
+  const int32x4_t sum1 = vmlal_n_s16(a1, vget_high_s16(b), constant);
+  const int32x4_t diff0 = vmlsl_n_s16(a0, vget_low_s16(b), constant);
+  const int32x4_t diff1 = vmlsl_n_s16(a1, vget_high_s16(b), constant);
+  const int16x4_t rounded0 = vqrshrn_n_s32(sum0, DCT_CONST_BITS);
+  const int16x4_t rounded1 = vqrshrn_n_s32(sum1, DCT_CONST_BITS);
+  const int16x4_t rounded2 = vqrshrn_n_s32(diff0, DCT_CONST_BITS);
+  const int16x4_t rounded3 = vqrshrn_n_s32(diff1, DCT_CONST_BITS);
+  *add = vcombine_s16(rounded0, rounded1);
+  *sub = vcombine_s16(rounded2, rounded3);
+}
+
+// fdct_round_shift(a * c0 +/- b * c1)
+static INLINE void butterfly_two_coeff(const int16x8_t a, const int16x8_t b,
+                                       const tran_coef_t constant0,
+                                       const tran_coef_t constant1,
+                                       int16x8_t *add, int16x8_t *sub) {
+  const int32x4_t a0 = vmull_n_s16(vget_low_s16(a), constant0);
+  const int32x4_t a1 = vmull_n_s16(vget_high_s16(a), constant0);
+  const int32x4_t a2 = vmull_n_s16(vget_low_s16(a), constant1);
+  const int32x4_t a3 = vmull_n_s16(vget_high_s16(a), constant1);
+  const int32x4_t sum0 = vmlal_n_s16(a2, vget_low_s16(b), constant0);
+  const int32x4_t sum1 = vmlal_n_s16(a3, vget_high_s16(b), constant0);
+  const int32x4_t diff0 = vmlsl_n_s16(a0, vget_low_s16(b), constant1);
+  const int32x4_t diff1 = vmlsl_n_s16(a1, vget_high_s16(b), constant1);
+  const int16x4_t rounded0 = vqrshrn_n_s32(sum0, DCT_CONST_BITS);
+  const int16x4_t rounded1 = vqrshrn_n_s32(sum1, DCT_CONST_BITS);
+  const int16x4_t rounded2 = vqrshrn_n_s32(diff0, DCT_CONST_BITS);
+  const int16x4_t rounded3 = vqrshrn_n_s32(diff1, DCT_CONST_BITS);
+  *add = vcombine_s16(rounded0, rounded1);
+  *sub = vcombine_s16(rounded2, rounded3);
+}
+
+// Add 2 if positive, 1 if negative, and shift by 2.
+// In practice, subtract the sign bit, then shift with rounding.
+static INLINE int16x8_t sub_round_shift(const int16x8_t a) {
+  const uint16x8_t a_u16 = vreinterpretq_u16_s16(a);
+  const uint16x8_t a_sign_u16 = vshrq_n_u16(a_u16, 15);
+  const int16x8_t a_sign_s16 = vreinterpretq_s16_u16(a_sign_u16);
+  return vrshrq_n_s16(vsubq_s16(a, a_sign_s16), 2);
+}
+
+// Like butterfly_one_coeff, but don't narrow results.
+static INLINE void butterfly_one_coeff_s16_s32(
+    const int16x8_t a, const int16x8_t b, const tran_high_t constant,
+    int32x4_t *add_lo, int32x4_t *add_hi, int32x4_t *sub_lo,
+    int32x4_t *sub_hi) {
+  const int32x4_t a0 = vmull_n_s16(vget_low_s16(a), constant);
+  const int32x4_t a1 = vmull_n_s16(vget_high_s16(a), constant);
+  const int32x4_t sum0 = vmlal_n_s16(a0, vget_low_s16(b), constant);
+  const int32x4_t sum1 = vmlal_n_s16(a1, vget_high_s16(b), constant);
+  const int32x4_t diff0 = vmlsl_n_s16(a0, vget_low_s16(b), constant);
+  const int32x4_t diff1 = vmlsl_n_s16(a1, vget_high_s16(b), constant);
+  *add_lo = vrshrq_n_s32(sum0, DCT_CONST_BITS);
+  *add_hi = vrshrq_n_s32(sum1, DCT_CONST_BITS);
+  *sub_lo = vrshrq_n_s32(diff0, DCT_CONST_BITS);
+  *sub_hi = vrshrq_n_s32(diff1, DCT_CONST_BITS);
+}
+
+// Like butterfly_one_coeff, but with s32.
+static INLINE void butterfly_one_coeff_s32(
+    const int32x4_t a_lo, const int32x4_t a_hi, const int32x4_t b_lo,
+    const int32x4_t b_hi, const int32_t constant, int32x4_t *add_lo,
+    int32x4_t *add_hi, int32x4_t *sub_lo, int32x4_t *sub_hi) {
+  const int32x4_t a_lo_0 = vmulq_n_s32(a_lo, constant);
+  const int32x4_t a_hi_0 = vmulq_n_s32(a_hi, constant);
+  const int32x4_t sum0 = vmlaq_n_s32(a_lo_0, b_lo, constant);
+  const int32x4_t sum1 = vmlaq_n_s32(a_hi_0, b_hi, constant);
+  const int32x4_t diff0 = vmlsq_n_s32(a_lo_0, b_lo, constant);
+  const int32x4_t diff1 = vmlsq_n_s32(a_hi_0, b_hi, constant);
+  *add_lo = vrshrq_n_s32(sum0, DCT_CONST_BITS);
+  *add_hi = vrshrq_n_s32(sum1, DCT_CONST_BITS);
+  *sub_lo = vrshrq_n_s32(diff0, DCT_CONST_BITS);
+  *sub_hi = vrshrq_n_s32(diff1, DCT_CONST_BITS);
+}
+
+// Like butterfly_two_coeff, but with s32.
+static INLINE void butterfly_two_coeff_s32(
+    const int32x4_t a_lo, const int32x4_t a_hi, const int32x4_t b_lo,
+    const int32x4_t b_hi, const int32_t constant0, const int32_t constant1,
+    int32x4_t *add_lo, int32x4_t *add_hi, int32x4_t *sub_lo,
+    int32x4_t *sub_hi) {
+  const int32x4_t a0 = vmulq_n_s32(a_lo, constant0);
+  const int32x4_t a1 = vmulq_n_s32(a_hi, constant0);
+  const int32x4_t a2 = vmulq_n_s32(a_lo, constant1);
+  const int32x4_t a3 = vmulq_n_s32(a_hi, constant1);
+  const int32x4_t sum0 = vmlaq_n_s32(a2, b_lo, constant0);
+  const int32x4_t sum1 = vmlaq_n_s32(a3, b_hi, constant0);
+  const int32x4_t diff0 = vmlsq_n_s32(a0, b_lo, constant1);
+  const int32x4_t diff1 = vmlsq_n_s32(a1, b_hi, constant1);
+  *add_lo = vrshrq_n_s32(sum0, DCT_CONST_BITS);
+  *add_hi = vrshrq_n_s32(sum1, DCT_CONST_BITS);
+  *sub_lo = vrshrq_n_s32(diff0, DCT_CONST_BITS);
+  *sub_hi = vrshrq_n_s32(diff1, DCT_CONST_BITS);
+}
+
+// Add 1 if positive, 2 if negative, and shift by 2.
+// In practice, add 1, then add the sign bit, then shift without rounding.
+static INLINE int16x8_t add_round_shift_s16(const int16x8_t a) {
+  const int16x8_t one = vdupq_n_s16(1);
+  const uint16x8_t a_u16 = vreinterpretq_u16_s16(a);
+  const uint16x8_t a_sign_u16 = vshrq_n_u16(a_u16, 15);
+  const int16x8_t a_sign_s16 = vreinterpretq_s16_u16(a_sign_u16);
+  return vshrq_n_s16(vaddq_s16(vaddq_s16(a, a_sign_s16), one), 2);
+}
+
+// Add 1 if positive, 2 if negative, and shift by 2.
+// In practice, add 1, then add the sign bit, then shift without rounding.
+static INLINE int16x8_t add_round_shift_s32(const int32x4_t a_lo,
+                                            const int32x4_t a_hi) {
+  const int32x4_t one = vdupq_n_s32(1);
+  const uint32x4_t a_lo_u32 = vreinterpretq_u32_s32(a_lo);
+  const uint32x4_t a_lo_sign_u32 = vshrq_n_u32(a_lo_u32, 31);
+  const int32x4_t a_lo_sign_s32 = vreinterpretq_s32_u32(a_lo_sign_u32);
+  const int16x4_t b_lo =
+      vshrn_n_s32(vqaddq_s32(vqaddq_s32(a_lo, a_lo_sign_s32), one), 2);
+  const uint32x4_t a_hi_u32 = vreinterpretq_u32_s32(a_hi);
+  const uint32x4_t a_hi_sign_u32 = vshrq_n_u32(a_hi_u32, 31);
+  const int32x4_t a_hi_sign_s32 = vreinterpretq_s32_u32(a_hi_sign_u32);
+  const int16x4_t b_hi =
+      vshrn_n_s32(vqaddq_s32(vqaddq_s32(a_hi, a_hi_sign_s32), one), 2);
+  return vcombine_s16(b_lo, b_hi);
+}
+
 static INLINE void vpx_fdct4x4_pass1_neon(int16x4_t *in) {
   const int16x8_t input_01 = vcombine_s16(in[0], in[1]);
   const int16x8_t input_32 = vcombine_s16(in[3], in[2]);
index 32d21e03f7ba8109a67ef14a6e1d8477cf7638d0..1fd9495cf99c64e06423b88c0aabda93c537276f 100644 (file)
@@ -227,11 +227,11 @@ ifeq ($(VPX_ARCH_X86_64),yes)
 DSP_SRCS-$(HAVE_SSSE3)  += x86/fwd_txfm_ssse3_x86_64.asm
 endif
 DSP_SRCS-$(HAVE_AVX2)   += x86/fwd_dct32x32_impl_avx2.h
-DSP_SRCS-$(HAVE_NEON)   += arm/fdct_neon.c
+DSP_SRCS-$(HAVE_NEON)   += arm/fdct4x4_neon.c
+DSP_SRCS-$(HAVE_NEON)   += arm/fdct8x8_neon.c
 DSP_SRCS-$(HAVE_NEON)   += arm/fdct16x16_neon.c
 DSP_SRCS-$(HAVE_NEON)   += arm/fdct32x32_neon.c
 DSP_SRCS-$(HAVE_NEON)   += arm/fdct_partial_neon.c
-DSP_SRCS-$(HAVE_NEON)   += arm/fwd_txfm_neon.c
 DSP_SRCS-$(HAVE_MSA)    += mips/fwd_txfm_msa.h
 DSP_SRCS-$(HAVE_MSA)    += mips/fwd_txfm_msa.c
 DSP_SRCS-$(HAVE_LSX)    += loongarch/fwd_txfm_lsx.h