From: Yi Luo
Date: Mon, 17 Oct 2016 18:18:50 +0000 (-0700)
Subject: Fix the overflow of av1_fht32x32() in 2D DCT_DCT
X-Git-Url: https://granicus.if.org/sourcecode?a=commitdiff_plain;h=157e45a44b3fdf46f7d086f966374984ffe08312;p=libvpx

Fix the overflow of av1_fht32x32() in 2D DCT_DCT

- Use a range check function to avoid DCT_DCT overflow.
  The column txfm side scaling/rounding still needs to be re-developed;
  for now we prefer to maintain the current BDRate level.
- Encoder user-level time reduction is <1%, owing to av1_fht32x32_avx2.
- Add a MemCheck unit test and an fdct32() unit test.

Change-Id: I1e67030f67bc637859798ebe2f6698afffb8531c
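The core of the fix is a cheap range test on the residual block: small-range
inputs take the new high-precision fdct32 path, while anything larger falls
back to aom_fdct32x32_c(), which this patch uses as the overflow-safe path.
A minimal, self-contained sketch of that decision follows; exceeds_range()
and main() are illustrative stand-ins (not libaom API), and the (1 << 6) - 1
bound and 1 << 10 sample count mirror the check added in av1/encoder/dct.c,
which scans the first 1 << 10 samples from input:

    #include <stdint.h>
    #include <stdio.h>
    #include <stdlib.h>

    /* Illustrative stand-in for range_check_dct32x32() in av1/encoder/dct.c:
     * returns 1 if any of the first n samples lies outside [-bound, bound]. */
    static int exceeds_range(const int16_t *in, int16_t bound, int n) {
      int i;
      for (i = 0; i < n; ++i) {
        if (abs(in[i]) > bound) return 1;
      }
      return 0;
    }

    int main(void) {
      int16_t block[32 * 32] = { 0 };
      /* All-small input: the new high-precision DCT_DCT path would run. */
      printf("fallback: %d\n", exceeds_range(block, (1 << 6) - 1, 1 << 10));
      /* One sample beyond 63: the safe aom_fdct32x32_c() path would run. */
      block[100] = 64;
      printf("fallback: %d\n", exceeds_range(block, (1 << 6) - 1, 1 << 10));
      return 0;
    }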
---

diff --git a/av1/common/av1_rtcd_defs.pl b/av1/common/av1_rtcd_defs.pl
index 3e1b02f24..d2a91dbed 100644
--- a/av1/common/av1_rtcd_defs.pl
+++ b/av1/common/av1_rtcd_defs.pl
@@ -391,6 +391,9 @@ specialize qw/av1_fht8x8 sse2/;
 add_proto qw/void av1_fht16x16/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
 specialize qw/av1_fht16x16 sse2 avx2/;

+add_proto qw/void av1_fht32x32/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
+specialize qw/av1_fht32x32 avx2/;
+
 if (aom_config("CONFIG_EXT_TX") eq "yes") {
   add_proto qw/void av1_fht4x8/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
   specialize qw/av1_fht4x8 sse2/;
@@ -409,9 +412,6 @@ if (aom_config("CONFIG_EXT_TX") eq "yes") {

   add_proto qw/void av1_fht32x16/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
   specialize qw/av1_fht32x16 sse2/;
-
-  add_proto qw/void av1_fht32x32/, "const int16_t *input, tran_low_t *output, int stride, int tx_type";
-  specialize qw/av1_fht32x32 avx2/;
 }

 if (aom_config("CONFIG_EMULATE_HARDWARE") eq "yes") {
diff --git a/av1/encoder/dct.c b/av1/encoder/dct.c
index 63b71a5bb..221e3cda9 100644
--- a/av1/encoder/dct.c
+++ b/av1/encoder/dct.c
@@ -325,7 +325,6 @@ static void fdct16(const tran_low_t *input, tran_low_t *output) {
   range_check(output, 16, 16);
 }

-#if CONFIG_EXT_TX
 static void fdct32(const tran_low_t *input, tran_low_t *output) {
   tran_high_t temp;
   tran_low_t step[32];
@@ -723,7 +722,6 @@ static void fdct32(const tran_low_t *input, tran_low_t *output) {
   range_check(output, 32, 18);
 }

-#endif  // CONFIG_EXT_TX

 static void fadst4(const tran_low_t *input, tran_low_t *output) {
   tran_high_t x0, x1, x2, x3;
@@ -1809,57 +1807,74 @@ void av1_highbd_fht16x16_c(const int16_t *input, tran_low_t *output, int stride,
 }
 #endif  // CONFIG_AOM_HIGHBITDEPTH

-#if CONFIG_EXT_TX
+// TODO(luoyi): This function is added to avoid DCT_DCT overflow.
+// Remove it once the column txfm output is scaled correctly.
+static INLINE int range_check_dct32x32(const int16_t *input, int16_t bound,
+                                       int size) {
+  int i;
+  for (i = 0; i < size; ++i) {
+    if (abs(input[i]) > bound) return 1;
+  }
+  return 0;
+}
+
 void av1_fht32x32_c(const int16_t *input, tran_low_t *output, int stride,
                     int tx_type) {
-  if (tx_type == DCT_DCT) {
-    aom_fdct32x32_c(input, output, stride);
-  } else {
-    static const transform_2d FHT[] = {
-      { fdct32, fdct32 },              // DCT_DCT
-      { fhalfright32, fdct32 },        // ADST_DCT
-      { fdct32, fhalfright32 },        // DCT_ADST
-      { fhalfright32, fhalfright32 },  // ADST_ADST
-      { fhalfright32, fdct32 },        // FLIPADST_DCT
-      { fdct32, fhalfright32 },        // DCT_FLIPADST
-      { fhalfright32, fhalfright32 },  // FLIPADST_FLIPADST
-      { fhalfright32, fhalfright32 },  // ADST_FLIPADST
-      { fhalfright32, fhalfright32 },  // FLIPADST_ADST
-      { fidtx32, fidtx32 },            // IDTX
-      { fdct32, fidtx32 },             // V_DCT
-      { fidtx32, fdct32 },             // H_DCT
-      { fhalfright32, fidtx32 },       // V_ADST
-      { fidtx32, fhalfright32 },       // H_ADST
-      { fhalfright32, fidtx32 },       // V_FLIPADST
-      { fidtx32, fhalfright32 },       // H_FLIPADST
-    };
-    const transform_2d ht = FHT[tx_type];
-    tran_low_t out[1024];
-    int i, j;
-    tran_low_t temp_in[32], temp_out[32];
+  static const transform_2d FHT[] = {
+    { fdct32, fdct32 },  // DCT_DCT
+#if CONFIG_EXT_TX
+    { fhalfright32, fdct32 },        // ADST_DCT
+    { fdct32, fhalfright32 },        // DCT_ADST
+    { fhalfright32, fhalfright32 },  // ADST_ADST
+    { fhalfright32, fdct32 },        // FLIPADST_DCT
+    { fdct32, fhalfright32 },        // DCT_FLIPADST
+    { fhalfright32, fhalfright32 },  // FLIPADST_FLIPADST
+    { fhalfright32, fhalfright32 },  // ADST_FLIPADST
+    { fhalfright32, fhalfright32 },  // FLIPADST_ADST
+    { fidtx32, fidtx32 },            // IDTX
+    { fdct32, fidtx32 },             // V_DCT
+    { fidtx32, fdct32 },             // H_DCT
+    { fhalfright32, fidtx32 },       // V_ADST
+    { fidtx32, fhalfright32 },       // H_ADST
+    { fhalfright32, fidtx32 },       // V_FLIPADST
+    { fidtx32, fhalfright32 },       // H_FLIPADST
+#endif
+  };
+  const transform_2d ht = FHT[tx_type];
+  tran_low_t out[1024];
+  int i, j;
+  tran_low_t temp_in[32], temp_out[32];

-    int16_t flipped_input[32 * 32];
-    maybe_flip_input(&input, &stride, 32, 32, flipped_input, tx_type);
+#if CONFIG_EXT_TX
+  int16_t flipped_input[32 * 32];
+  maybe_flip_input(&input, &stride, 32, 32, flipped_input, tx_type);
+#endif

-    // Columns
-    for (i = 0; i < 32; ++i) {
-      for (j = 0; j < 32; ++j) temp_in[j] = input[j * stride + i] * 4;
-      ht.cols(temp_in, temp_out);
-      for (j = 0; j < 32; ++j)
-        out[j * 32 + i] = (temp_out[j] + 1 + (temp_out[j] > 0)) >> 2;
+  if (DCT_DCT == tx_type) {
+    if (range_check_dct32x32(input, (1 << 6) - 1, 1 << 10)) {
+      aom_fdct32x32_c(input, output, stride);
+      return;
     }
+  }
+  // Columns
+  for (i = 0; i < 32; ++i) {
+    for (j = 0; j < 32; ++j) temp_in[j] = input[j * stride + i] * 4;
+    ht.cols(temp_in, temp_out);
+    for (j = 0; j < 32; ++j)
+      out[j * 32 + i] = (temp_out[j] + 1 + (temp_out[j] > 0)) >> 2;
+  }

-    // Rows
-    for (i = 0; i < 32; ++i) {
-      for (j = 0; j < 32; ++j) temp_in[j] = out[j + i * 32];
-      ht.rows(temp_in, temp_out);
-      for (j = 0; j < 32; ++j)
-        output[j + i * 32] =
-            (tran_low_t)((temp_out[j] + 1 + (temp_out[j] < 0)) >> 2);
-    }
+  // Rows
+  for (i = 0; i < 32; ++i) {
+    for (j = 0; j < 32; ++j) temp_in[j] = out[j + i * 32];
+    ht.rows(temp_in, temp_out);
+    for (j = 0; j < 32; ++j)
+      output[j + i * 32] =
+          (tran_low_t)((temp_out[j] + 1 + (temp_out[j] < 0)) >> 2);
   }
 }

+#if CONFIG_EXT_TX
 // Forward identity transform.
 void av1_fwd_idtx_c(const int16_t *src_diff, tran_low_t *coeff, int stride,
                     int bs, int tx_type) {
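The new path scales each input column by 4, then applies the biased rounding
shifts seen above between the two fdct32 passes. The following self-contained
illustration of those two rounding rules is a sketch of the arithmetic only,
with hypothetical helper names (round_col, round_row), not libaom code; like
libaom, it assumes arithmetic right shift of negative values:

    #include <stdint.h>
    #include <stdio.h>

    typedef int32_t tran_low_t;

    /* Column-pass rounding: add 1, plus 1 more for positive values, then
     * shift right by 2 (divide by 4 with a positive-side bias). Assumes
     * arithmetic right shift for negative operands. */
    static tran_low_t round_col(tran_low_t x) { return (x + 1 + (x > 0)) >> 2; }

    /* Row-pass rounding: the extra unit goes to negative values instead. */
    static tran_low_t round_row(tran_low_t x) { return (x + 1 + (x < 0)) >> 2; }

    int main(void) {
      const tran_low_t samples[] = { -6, -5, -1, 0, 1, 5, 6 };
      int i;
      for (i = 0; i < 7; ++i)
        printf("x=%3d  col=%3d  row=%3d\n", samples[i],
               round_col(samples[i]), round_row(samples[i]));
      return 0;
    }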
diff --git a/av1/encoder/hybrid_fwd_txfm.c b/av1/encoder/hybrid_fwd_txfm.c
index 1103c4b46..6d5eccd06 100644
--- a/av1/encoder/hybrid_fwd_txfm.c
+++ b/av1/encoder/hybrid_fwd_txfm.c
@@ -21,7 +21,7 @@ static INLINE void fdct32x32(int rd_transform, const int16_t *src,
   if (rd_transform)
     aom_fdct32x32_rd(src, dst, src_stride);
   else
-    aom_fdct32x32(src, dst, src_stride);
+    av1_fht32x32(src, dst, src_stride, DCT_DCT);
 }

 static void fwd_txfm_4x4(const int16_t *src_diff, tran_low_t *coeff,
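With this change the non-RD 32x32 DCT_DCT is routed through av1_fht32x32(),
which applies the range check and falls back internally; the _rd variant used
during rate-distortion search stays on aom_fdct32x32_rd(). A compilable sketch
of the resulting dispatch follows; the stub bodies and the DCT_DCT constant
are illustrative stand-ins for the real libaom definitions:

    #include <stdint.h>

    typedef int32_t tran_low_t;
    enum { DCT_DCT = 0 }; /* first entry of libaom's TX_TYPE enum */

    /* Stubs standing in for the real transforms; empty here only so the
     * sketch compiles on its own. */
    static void aom_fdct32x32_rd(const int16_t *src, tran_low_t *dst,
                                 int stride) {
      (void)src; (void)dst; (void)stride;
    }
    static void av1_fht32x32(const int16_t *src, tran_low_t *dst, int stride,
                             int tx_type) {
      (void)src; (void)dst; (void)stride; (void)tx_type;
    }

    /* Mirrors the patched fdct32x32() wrapper: only the non-RD call changes,
     * so final encoding uses the overflow-checked DCT_DCT path. */
    static void fdct32x32(int rd_transform, const int16_t *src,
                          tran_low_t *dst, int src_stride) {
      if (rd_transform)
        aom_fdct32x32_rd(src, dst, src_stride);
      else
        av1_fht32x32(src, dst, src_stride, DCT_DCT);
    }

    int main(void) {
      int16_t src[32 * 32] = { 0 };
      tran_low_t dst[32 * 32];
      fdct32x32(0, src, dst, 32); /* non-RD -> av1_fht32x32(DCT_DCT) */
      return 0;
    }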
diff --git a/av1/encoder/x86/hybrid_fwd_txfm_avx2.c b/av1/encoder/x86/hybrid_fwd_txfm_avx2.c
index 69bf89af6..928af1355 100644
--- a/av1/encoder/x86/hybrid_fwd_txfm_avx2.c
+++ b/av1/encoder/x86/hybrid_fwd_txfm_avx2.c
@@ -198,8 +198,8 @@ static void mm256_transpose_16x16(__m256i *in) {
   in[15] = _mm256_permute2x128_si256(tr0_7, tr0_f, 0x31);
 }

-static void load_buffer_16x16(const int16_t *input, int stride, int flipud,
-                              int fliplr, __m256i *in) {
+static INLINE void load_buffer_16x16(const int16_t *input, int stride,
+                                     int flipud, int fliplr, __m256i *in) {
   if (!flipud) {
     in[0] = _mm256_loadu_si256((const __m256i *)(input + 0 * stride));
     in[1] = _mm256_loadu_si256((const __m256i *)(input + 1 * stride));
@@ -1273,7 +1273,6 @@ void aom_fdct32x32_1_avx2(const int16_t *input, tran_low_t *output,
   _mm256_zeroupper();
 }

-#if CONFIG_EXT_TX
 static void mm256_vectors_swap(__m256i *a0, __m256i *a1, const int size) {
   int i = 0;
   __m256i temp;
@@ -1622,7 +1621,6 @@ static void fdct32_avx2(__m256i *in0, __m256i *in1) {
   mm256_transpose_32x32(in0, in1);
 }

-#endif  // CONFIG_EXT_TX

 static INLINE void write_buffer_32x32(const __m256i *in0, const __m256i *in1,
                                       int stride, tran_low_t *output) {
@@ -1667,9 +1665,11 @@ static void fhalfright32_avx2(__m256i *in0, __m256i *in1) {
   mm256_vectors_swap(in1, &in1[16], 16);
   mm256_transpose_32x32(in0, in1);
 }
+#endif  // CONFIG_EXT_TX

-static void load_buffer_32x32(const int16_t *input, int stride, int flipud,
-                              int fliplr, __m256i *in0, __m256i *in1) {
+static INLINE void load_buffer_32x32(const int16_t *input, int stride,
+                                     int flipud, int fliplr, __m256i *in0,
+                                     __m256i *in1) {
   // Load 4 16x16 blocks
   const int16_t *topL = input;
   const int16_t *topR = input + 16;
@@ -1708,7 +1708,6 @@ static void load_buffer_32x32(const int16_t *input, int stride, int flipud,
   load_buffer_16x16(topR, stride, flipud, fliplr, in1);
   load_buffer_16x16(botR, stride, flipud, fliplr, in1 + 16);
 }
-#endif  // CONFIG_EXT_TX

 static void nr_right_shift_32x32_16col(__m256i *in) {
   int i = 0;
@@ -1729,8 +1728,7 @@ static void nr_right_shift_32x32(__m256i *in0, __m256i *in1) {
   nr_right_shift_32x32_16col(in1);
 }

-#if CONFIG_EXT_TX
-static void pr_right_shift_32x32_16col(__m256i *in) {
+static INLINE void pr_right_shift_32x32_16col(__m256i *in) {
   int i = 0;
   const __m256i zero = _mm256_setzero_si256();
   const __m256i one = _mm256_set1_epi16(1);
@@ -1745,11 +1743,12 @@ static void pr_right_shift_32x32_16col(__m256i *in) {
 }

 // Positive rounding
-static void pr_right_shift_32x32(__m256i *in0, __m256i *in1) {
+static INLINE void pr_right_shift_32x32(__m256i *in0, __m256i *in1) {
   pr_right_shift_32x32_16col(in0);
   pr_right_shift_32x32_16col(in1);
 }

+#if CONFIG_EXT_TX
 static void fidtx32_avx2(__m256i *in0, __m256i *in1) {
   int i = 0;
   while (i < 32) {
@@ -1761,23 +1760,42 @@ static void fidtx32_avx2(__m256i *in0, __m256i *in1) {
 }
 #endif

+static INLINE int range_check_dct32x32(const __m256i *in0, const __m256i *in1,
+                                       int row) {
+  __m256i value, bits0, bits1;
+  const __m256i bound = _mm256_set1_epi16((1 << 6) - 1);
+  int flag;
+  int i = 0;
+
+  while (i < row) {
+    value = _mm256_abs_epi16(in0[i]);
+    bits0 = _mm256_cmpgt_epi16(value, bound);
+    value = _mm256_abs_epi16(in1[i]);
+    bits1 = _mm256_cmpgt_epi16(value, bound);
+    bits0 = _mm256_or_si256(bits0, bits1);
+    flag = _mm256_movemask_epi8(bits0);
+    if (flag) return 1;
+    i++;
+  }
+  return 0;
+}
+
 void av1_fht32x32_avx2(const int16_t *input, tran_low_t *output, int stride,
                        int tx_type) {
   __m256i in0[32];  // left 32 columns
   __m256i in1[32];  // right 32 columns

-  (void)input;
-  (void)stride;
   switch (tx_type) {
-// TODO(luoyi): For DCT_DCT, fwd_txfm_32x32() uses aom set. But this
-// function has better speed. The replacement must work with the
-// corresponding inverse transform.
-// case DCT_DCT:
-//   load_buffer_32x32(input, stride, 0, 0, in0, in1);
-//   fdct32_avx2(in0, in1);
-//   pr_right_shift_32x32(in0, in1);
-//   fdct32_avx2(in0, in1);
-//   break;
+    case DCT_DCT:
+      load_buffer_32x32(input, stride, 0, 0, in0, in1);
+      if (range_check_dct32x32(in0, in1, 32)) {
+        aom_fdct32x32_avx2(input, output, stride);
+        return;
+      }
+      fdct32_avx2(in0, in1);
+      pr_right_shift_32x32(in0, in1);
+      fdct32_avx2(in0, in1);
+      break;
 #if CONFIG_EXT_TX
     case ADST_DCT:
       load_buffer_32x32(input, stride, 0, 0, in0, in1);
diff --git a/test/av1_dct_test.cc b/test/av1_dct_test.cc
index ac1a55196..d5c23f6ce 100644
--- a/test/av1_dct_test.cc
+++ b/test/av1_dct_test.cc
@@ -102,5 +102,6 @@ INSTANTIATE_TEST_CASE_P(
     C, AV1FwdTxfm,
     ::testing::Values(FdctParam(&fdct4, &reference_dct_1d, 4, 1),
                       FdctParam(&fdct8, &reference_dct_1d, 8, 1),
-                      FdctParam(&fdct16, &reference_dct_1d, 16, 2)));
+                      FdctParam(&fdct16, &reference_dct_1d, 16, 2),
+                      FdctParam(&fdct32, &reference_dct_1d, 32, 3)));
 }  // namespace
diff --git a/test/fht32x32_test.cc b/test/fht32x32_test.cc
index a949ebf7a..3d07b44da 100644
--- a/test/fht32x32_test.cc
+++ b/test/fht32x32_test.cc
@@ -69,6 +69,7 @@ class AV1Trans32x32HT : public libaom_test::TransformTestBase,
     inv_txfm_ = GET_PARAM(1);
     tx_type_ = GET_PARAM(2);
     pitch_ = 32;
+    height_ = 32;
     fwd_txfm_ref = fht32x32_ref;
     bit_depth_ = GET_PARAM(3);
     mask_ = (1 << bit_depth_) - 1;
@@ -90,6 +91,7 @@ class AV1Trans32x32HT : public libaom_test::TransformTestBase,
 };

 TEST_P(AV1Trans32x32HT, CoeffCheck) { RunCoeffCheck(); }
+TEST_P(AV1Trans32x32HT, MemCheck) { RunMemCheck(); }

 #if CONFIG_AOM_HIGHBITDEPTH
 class AV1HighbdTrans32x32HT
@@ -164,8 +166,7 @@ using std::tr1::make_tuple;

 #if HAVE_AVX2
 const Ht32x32Param kArrayHt32x32Param_avx2[] = {
-  // TODO(luoyi): DCT_DCT tx_type is not enabled in av1_fht32x32_c(avx2) yet.
-  // make_tuple(&av1_fht32x32_avx2, dummy_inv_txfm, 0, AOM_BITS_8, 1024),
+  make_tuple(&av1_fht32x32_avx2, dummy_inv_txfm, 0, AOM_BITS_8, 1024),
   make_tuple(&av1_fht32x32_avx2, dummy_inv_txfm, 1, AOM_BITS_8, 1024),
   make_tuple(&av1_fht32x32_avx2, dummy_inv_txfm, 2, AOM_BITS_8, 1024),
   make_tuple(&av1_fht32x32_avx2, dummy_inv_txfm, 3, AOM_BITS_8, 1024),
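The AVX2 range check above tests 16 lanes at a time: absolute value, compare
against 63, OR the two half-block results, and movemask to learn whether any
lane tripped. A standalone demo of that idiom follows (compile with -mavx2);
the buffer and the row_exceeds_range() name are illustrative, not libaom code:

    #include <immintrin.h>
    #include <stdint.h>
    #include <stdio.h>

    /* Vectorized "any |x| > 63?" over one row of 16 int16 lanes, using the
     * same abs/cmpgt/movemask idiom as the patch's range_check_dct32x32(). */
    static int row_exceeds_range(const int16_t *row) {
      const __m256i bound = _mm256_set1_epi16((1 << 6) - 1);
      const __m256i v = _mm256_loadu_si256((const __m256i *)row);
      const __m256i mag = _mm256_abs_epi16(v);
      const __m256i gt = _mm256_cmpgt_epi16(mag, bound);
      /* movemask gathers the per-byte sign bits; nonzero => some lane > 63. */
      return _mm256_movemask_epi8(gt) != 0;
    }

    int main(void) {
      int16_t row[16] = { 0 };
      printf("exceeds: %d\n", row_exceeds_range(row)); /* prints 0 */
      row[7] = -64; /* |-64| > 63 trips the check */
      printf("exceeds: %d\n", row_exceeds_range(row)); /* prints 1 */
      return 0;
    }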