From 0f80b1f754adae3af799586048ce954b1111cc52 Mon Sep 17 00:00:00 2001
From: Yi Luo
Date: Mon, 11 Apr 2016 10:49:43 -0700
Subject: [PATCH] Optimized HBD block subtraction for all block sizes

- The interface function dispatches to a local MxN kernel chosen by
  block size.
- Repeated calls (i.e. warm cache, no cache-line misses) show speedups
  of ~63% to ~340% over the C code.
- Overall encoder speed improvement: ~0.9%.

Change-Id: Ieff8f3d192415c61d6d58d8b99bb2a722004823f
---
Note: a standalone scalar reference sketch of the computation follows
the diffstat below, for review context.

 test/subtract_test.cc              | 151 ++++++++++++
 vpx_dsp/vpx_dsp.mk                 |   5 +
 vpx_dsp/vpx_dsp_rtcd_defs.pl       |   6 +-
 vpx_dsp/x86/highbd_subtract_sse2.c | 371 +++++++++++++++++++++++++++++
 4 files changed, 529 insertions(+), 4 deletions(-)
 create mode 100644 vpx_dsp/x86/highbd_subtract_sse2.c
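For review context, here is a minimal standalone sketch of the operation
every WxH kernel must reproduce: diff[r][c] = src[r][c] - pred[r][c] on
16-bit samples, mirroring the generic vpx_highbd_subtract_block_c path.
The harness is illustrative only (scalar_subtract and the sample values
are not part of the patch). For bit depths up to 12, |src - pred| <= 4095,
so every difference fits in int16_t and bd can be ignored, as the SSE2
code below does.

    #include <stddef.h>
    #include <stdint.h>
    #include <stdio.h>

    /* Scalar reference: subtract one element at a time. */
    static void scalar_subtract(int rows, int cols, int16_t *diff,
                                ptrdiff_t diff_stride, const uint16_t *src,
                                ptrdiff_t src_stride, const uint16_t *pred,
                                ptrdiff_t pred_stride) {
      int r, c;
      for (r = 0; r < rows; ++r)
        for (c = 0; c < cols; ++c)
          diff[r * diff_stride + c] =
              (int16_t)(src[r * src_stride + c] - pred[r * pred_stride + c]);
    }

    int main(void) {
      /* One 4x4 block, stride == width. */
      const uint16_t src[16] = { 4095, 0, 7, 9, 1, 2, 3, 4,
                                 5, 6, 7, 8, 9, 10, 11, 12 };
      const uint16_t pred[16] = { 0, 4095, 7, 4, 1, 1, 1, 1,
                                  1, 1, 1, 1, 1, 1, 1, 1 };
      int16_t diff[16];
      scalar_subtract(4, 4, diff, 4, src, 4, pred, 4);
      printf("%d %d %d %d\n", diff[0], diff[1], diff[2], diff[3]);
      /* Prints: 4095 -4095 0 5 */
      return 0;
    }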
diff --git a/test/subtract_test.cc b/test/subtract_test.cc
index a3f015277..48edf1e6e 100644
--- a/test/subtract_test.cc
+++ b/test/subtract_test.cc
@@ -15,12 +15,16 @@
 #include "test/acm_random.h"
 #include "test/clear_system_state.h"
 #include "test/register_state_check.h"
+#include "test/util.h"
 #if CONFIG_VP10
 #include "vp10/common/blockd.h"
 #elif CONFIG_VP9
 #include "vp9/common/vp9_blockd.h"
 #endif
 #include "vpx_mem/vpx_mem.h"
+#include "vpx_ports/mem.h"
+
+#define USE_SPEED_TEST (0)
 
 typedef void (*SubtractFunc)(int rows, int cols,
                              int16_t *diff_ptr, ptrdiff_t diff_stride,
@@ -108,4 +112,151 @@ INSTANTIATE_TEST_CASE_P(NEON, VP9SubtractBlockTest,
 INSTANTIATE_TEST_CASE_P(MSA, VP9SubtractBlockTest,
                         ::testing::Values(vpx_subtract_block_msa));
 #endif
+
+typedef void (*HBDSubtractFunc)(int rows, int cols,
+                                int16_t *diff_ptr, ptrdiff_t diff_stride,
+                                const uint8_t *src_ptr, ptrdiff_t src_stride,
+                                const uint8_t *pred_ptr, ptrdiff_t pred_stride,
+                                int bd);
+
+using ::std::tr1::get;
+using ::std::tr1::make_tuple;
+using ::std::tr1::tuple;
+
+// <width, height, bit depth, subtract func>
+typedef tuple<int, int, int, HBDSubtractFunc> Params;
+
+#if CONFIG_VP9_HIGHBITDEPTH
+class VP10HBDSubtractBlockTest : public ::testing::TestWithParam<Params> {
+ public:
+  virtual void SetUp() {
+    block_width_ = GET_PARAM(0);
+    block_height_ = GET_PARAM(1);
+    bit_depth_ = static_cast<vpx_bit_depth_t>(GET_PARAM(2));
+    func_ = GET_PARAM(3);
+
+    rnd_.Reset(ACMRandom::DeterministicSeed());
+
+    const size_t max_width = 128;
+    const size_t max_block_size = max_width * max_width;
+    src_ = CONVERT_TO_BYTEPTR(reinterpret_cast<uint16_t *>(
+        vpx_memalign(16, max_block_size * sizeof(uint16_t))));
+    pred_ = CONVERT_TO_BYTEPTR(reinterpret_cast<uint16_t *>(
+        vpx_memalign(16, max_block_size * sizeof(uint16_t))));
+    diff_ = reinterpret_cast<int16_t *>(
+        vpx_memalign(16, max_block_size * sizeof(int16_t)));
+  }
+
+  virtual void TearDown() {
+    vpx_free(CONVERT_TO_SHORTPTR(src_));
+    vpx_free(CONVERT_TO_SHORTPTR(pred_));
+    vpx_free(diff_);
+  }
+
+ protected:
+  void RunForSpeed();
+  void CheckResult();
+
+ private:
+  ACMRandom rnd_;
+  int block_height_;
+  int block_width_;
+  vpx_bit_depth_t bit_depth_;
+  HBDSubtractFunc func_;
+  uint8_t *src_;
+  uint8_t *pred_;
+  int16_t *diff_;
+};
+
+void VP10HBDSubtractBlockTest::RunForSpeed() {
+  const int test_num = 200000;
+  const int max_width = 128;
+  const int max_block_size = max_width * max_width;
+  const int mask = (1 << bit_depth_) - 1;
+  int i, j;
+
+  for (j = 0; j < max_block_size; ++j) {
+    CONVERT_TO_SHORTPTR(src_)[j] = rnd_.Rand16() & mask;
+    CONVERT_TO_SHORTPTR(pred_)[j] = rnd_.Rand16() & mask;
+  }
+
+  for (i = 0; i < test_num; ++i) {
+    func_(block_height_, block_width_, diff_, block_width_,
+          src_, block_width_, pred_, block_width_, bit_depth_);
+  }
+}
+
+void VP10HBDSubtractBlockTest::CheckResult() {
+  const int test_num = 100;
+  const int max_width = 128;
+  const int max_block_size = max_width * max_width;
+  const int mask = (1 << bit_depth_) - 1;
+  int i, j;
+
+  for (i = 0; i < test_num; ++i) {
+    for (j = 0; j < max_block_size; ++j) {
+      CONVERT_TO_SHORTPTR(src_)[j] = rnd_.Rand16() & mask;
+      CONVERT_TO_SHORTPTR(pred_)[j] = rnd_.Rand16() & mask;
+    }
+
+    func_(block_height_, block_width_, diff_, block_width_,
+          src_, block_width_, pred_, block_width_, bit_depth_);
+
+    for (int r = 0; r < block_height_; ++r) {
+      for (int c = 0; c < block_width_; ++c) {
+        EXPECT_EQ(diff_[r * block_width_ + c],
+                  (CONVERT_TO_SHORTPTR(src_)[r * block_width_ + c] -
+                   CONVERT_TO_SHORTPTR(pred_)[r * block_width_ + c]))
+            << "r = " << r << ", c = " << c << ", test: " << i;
+      }
+    }
+  }
+}
+
+TEST_P(VP10HBDSubtractBlockTest, CheckResult) {
+  CheckResult();
+}
+
+#if USE_SPEED_TEST
+TEST_P(VP10HBDSubtractBlockTest, CheckSpeed) {
+  RunForSpeed();
+}
+#endif  // USE_SPEED_TEST
+
+#if HAVE_SSE2
+INSTANTIATE_TEST_CASE_P(SSE2, VP10HBDSubtractBlockTest, ::testing::Values(
+    make_tuple(4, 4, 12, vpx_highbd_subtract_block_sse2),
+    make_tuple(4, 4, 12, vpx_highbd_subtract_block_c),
+    make_tuple(4, 8, 12, vpx_highbd_subtract_block_sse2),
+    make_tuple(4, 8, 12, vpx_highbd_subtract_block_c),
+    make_tuple(8, 4, 12, vpx_highbd_subtract_block_sse2),
+    make_tuple(8, 4, 12, vpx_highbd_subtract_block_c),
+    make_tuple(8, 8, 12, vpx_highbd_subtract_block_sse2),
+    make_tuple(8, 8, 12, vpx_highbd_subtract_block_c),
+    make_tuple(8, 16, 12, vpx_highbd_subtract_block_sse2),
+    make_tuple(8, 16, 12, vpx_highbd_subtract_block_c),
+    make_tuple(16, 8, 12, vpx_highbd_subtract_block_sse2),
+    make_tuple(16, 8, 12, vpx_highbd_subtract_block_c),
+    make_tuple(16, 16, 12, vpx_highbd_subtract_block_sse2),
+    make_tuple(16, 16, 12, vpx_highbd_subtract_block_c),
+    make_tuple(16, 32, 12, vpx_highbd_subtract_block_sse2),
+    make_tuple(16, 32, 12, vpx_highbd_subtract_block_c),
+    make_tuple(32, 16, 12, vpx_highbd_subtract_block_sse2),
+    make_tuple(32, 16, 12, vpx_highbd_subtract_block_c),
+    make_tuple(32, 32, 12, vpx_highbd_subtract_block_sse2),
+    make_tuple(32, 32, 12, vpx_highbd_subtract_block_c),
+    make_tuple(32, 64, 12, vpx_highbd_subtract_block_sse2),
+    make_tuple(32, 64, 12, vpx_highbd_subtract_block_c),
+    make_tuple(64, 32, 12, vpx_highbd_subtract_block_sse2),
+    make_tuple(64, 32, 12, vpx_highbd_subtract_block_c),
+    make_tuple(64, 64, 12, vpx_highbd_subtract_block_sse2),
+    make_tuple(64, 64, 12, vpx_highbd_subtract_block_c),
+    make_tuple(64, 128, 12, vpx_highbd_subtract_block_sse2),
+    make_tuple(64, 128, 12, vpx_highbd_subtract_block_c),
+    make_tuple(128, 64, 12, vpx_highbd_subtract_block_sse2),
+    make_tuple(128, 64, 12, vpx_highbd_subtract_block_c),
+    make_tuple(128, 128, 12, vpx_highbd_subtract_block_sse2),
+    make_tuple(128, 128, 12, vpx_highbd_subtract_block_c)));
+#endif  // HAVE_SSE2
+#endif  // CONFIG_VP9_HIGHBITDEPTH
 }  // namespace
diff --git a/vpx_dsp/vpx_dsp.mk b/vpx_dsp/vpx_dsp.mk
index e37184965..2f430334e 100644
--- a/vpx_dsp/vpx_dsp.mk
+++ b/vpx_dsp/vpx_dsp.mk
@@ -266,6 +266,11 @@ DSP_SRCS-$(HAVE_SSSE3) += x86/avg_ssse3_x86_64.asm
 endif
 endif
 
+# high bit depth subtract
+ifeq ($(CONFIG_VP9_HIGHBITDEPTH),yes)
+DSP_SRCS-$(HAVE_SSE2) += x86/highbd_subtract_sse2.c
+endif
+
 endif  # CONFIG_VP9_ENCODER || CONFIG_VP10_ENCODER
 
 ifeq ($(CONFIG_VP10_ENCODER),yes)
diff --git a/vpx_dsp/vpx_dsp_rtcd_defs.pl b/vpx_dsp/vpx_dsp_rtcd_defs.pl
index d01e81d0c..4ff348d6e 100644
--- a/vpx_dsp/vpx_dsp_rtcd_defs.pl
+++ b/vpx_dsp/vpx_dsp_rtcd_defs.pl
@@ -965,10 +965,6 @@ if (vpx_config("CONFIG_ENCODERS") eq "yes") {
 #
 add_proto qw/void vpx_subtract_block/, "int rows, int cols, int16_t *diff_ptr, ptrdiff_t diff_stride, const uint8_t *src_ptr, ptrdiff_t src_stride, const uint8_t *pred_ptr, ptrdiff_t pred_stride";
 specialize qw/vpx_subtract_block neon msa/, "$sse2_x86inc";
 
-if (vpx_config("CONFIG_VP9_HIGHBITDEPTH") eq "yes") {
-  add_proto qw/void vpx_highbd_subtract_block/, "int rows, int cols, int16_t *diff_ptr, ptrdiff_t diff_stride, const uint8_t *src_ptr, ptrdiff_t src_stride, const uint8_t *pred_ptr, ptrdiff_t pred_stride, int bd";
-  specialize qw/vpx_highbd_subtract_block/;
-}
 if (vpx_config("CONFIG_VP10_ENCODER") eq "yes") {
 #
@@ -991,6 +987,8 @@ if ((vpx_config("CONFIG_VP9_ENCODER") eq "yes") || (vpx_config("CONFIG_VP10_ENCO
   specialize qw/vpx_highbd_avg_8x8/;
   add_proto qw/unsigned int vpx_highbd_avg_4x4/, "const uint8_t *, int p";
   specialize qw/vpx_highbd_avg_4x4/;
+  add_proto qw/void vpx_highbd_subtract_block/, "int rows, int cols, int16_t *diff_ptr, ptrdiff_t diff_stride, const uint8_t *src_ptr, ptrdiff_t src_stride, const uint8_t *pred_ptr, ptrdiff_t pred_stride, int bd";
+  specialize qw/vpx_highbd_subtract_block sse2/;
 }
 
 #
diff --git a/vpx_dsp/x86/highbd_subtract_sse2.c b/vpx_dsp/x86/highbd_subtract_sse2.c
new file mode 100644
index 000000000..33e464b78
--- /dev/null
+++ b/vpx_dsp/x86/highbd_subtract_sse2.c
@@ -0,0 +1,371 @@
+/*
+ *  Copyright (c) 2016 The WebM project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include <assert.h>
+#include <emmintrin.h>
+#include <stddef.h>
+
+#include "./vpx_config.h"
+#include "./vpx_dsp_rtcd.h"
+
+typedef void (*SubtractWxHFuncType)(
+    int16_t *diff, ptrdiff_t diff_stride,
+    const uint16_t *src, ptrdiff_t src_stride,
+    const uint16_t *pred, ptrdiff_t pred_stride);
+
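+/* 4-wide kernels: each row loads a full XMM register (8 samples) but only
+   the low 64 bits (4 x int16_t) of the result are stored to diff. */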
+static void subtract_4x4(int16_t *diff, ptrdiff_t diff_stride,
+                         const uint16_t *src, ptrdiff_t src_stride,
+                         const uint16_t *pred, ptrdiff_t pred_stride) {
+  __m128i u0, u1, u2, u3;
+  __m128i v0, v1, v2, v3;
+  __m128i x0, x1, x2, x3;
+  int64_t *store_diff = (int64_t *) (diff + 0 * diff_stride);
+
+  u0 = _mm_loadu_si128((__m128i const *) (src + 0 * src_stride));
+  u1 = _mm_loadu_si128((__m128i const *) (src + 1 * src_stride));
+  u2 = _mm_loadu_si128((__m128i const *) (src + 2 * src_stride));
+  u3 = _mm_loadu_si128((__m128i const *) (src + 3 * src_stride));
+
+  v0 = _mm_loadu_si128((__m128i const *) (pred + 0 * pred_stride));
+  v1 = _mm_loadu_si128((__m128i const *) (pred + 1 * pred_stride));
+  v2 = _mm_loadu_si128((__m128i const *) (pred + 2 * pred_stride));
+  v3 = _mm_loadu_si128((__m128i const *) (pred + 3 * pred_stride));
+
+  x0 = _mm_sub_epi16(u0, v0);
+  x1 = _mm_sub_epi16(u1, v1);
+  x2 = _mm_sub_epi16(u2, v2);
+  x3 = _mm_sub_epi16(u3, v3);
+
+  _mm_storel_epi64((__m128i *)store_diff, x0);
+  store_diff = (int64_t *) (diff + 1 * diff_stride);
+  _mm_storel_epi64((__m128i *)store_diff, x1);
+  store_diff = (int64_t *) (diff + 2 * diff_stride);
+  _mm_storel_epi64((__m128i *)store_diff, x2);
+  store_diff = (int64_t *) (diff + 3 * diff_stride);
+  _mm_storel_epi64((__m128i *)store_diff, x3);
+}
+
+static void subtract_4x8(int16_t *diff, ptrdiff_t diff_stride,
+                         const uint16_t *src, ptrdiff_t src_stride,
+                         const uint16_t *pred, ptrdiff_t pred_stride) {
+  __m128i u0, u1, u2, u3, u4, u5, u6, u7;
+  __m128i v0, v1, v2, v3, v4, v5, v6, v7;
+  __m128i x0, x1, x2, x3, x4, x5, x6, x7;
+  int64_t *store_diff = (int64_t *) (diff + 0 * diff_stride);
+
+  u0 = _mm_loadu_si128((__m128i const *) (src + 0 * src_stride));
+  u1 = _mm_loadu_si128((__m128i const *) (src + 1 * src_stride));
+  u2 = _mm_loadu_si128((__m128i const *) (src + 2 * src_stride));
+  u3 = _mm_loadu_si128((__m128i const *) (src + 3 * src_stride));
+  u4 = _mm_loadu_si128((__m128i const *) (src + 4 * src_stride));
+  u5 = _mm_loadu_si128((__m128i const *) (src + 5 * src_stride));
+  u6 = _mm_loadu_si128((__m128i const *) (src + 6 * src_stride));
+  u7 = _mm_loadu_si128((__m128i const *) (src + 7 * src_stride));
+
+  v0 = _mm_loadu_si128((__m128i const *) (pred + 0 * pred_stride));
+  v1 = _mm_loadu_si128((__m128i const *) (pred + 1 * pred_stride));
+  v2 = _mm_loadu_si128((__m128i const *) (pred + 2 * pred_stride));
+  v3 = _mm_loadu_si128((__m128i const *) (pred + 3 * pred_stride));
+  v4 = _mm_loadu_si128((__m128i const *) (pred + 4 * pred_stride));
+  v5 = _mm_loadu_si128((__m128i const *) (pred + 5 * pred_stride));
+  v6 = _mm_loadu_si128((__m128i const *) (pred + 6 * pred_stride));
+  v7 = _mm_loadu_si128((__m128i const *) (pred + 7 * pred_stride));
+
+  x0 = _mm_sub_epi16(u0, v0);
+  x1 = _mm_sub_epi16(u1, v1);
+  x2 = _mm_sub_epi16(u2, v2);
+  x3 = _mm_sub_epi16(u3, v3);
+  x4 = _mm_sub_epi16(u4, v4);
+  x5 = _mm_sub_epi16(u5, v5);
+  x6 = _mm_sub_epi16(u6, v6);
+  x7 = _mm_sub_epi16(u7, v7);
+
+  _mm_storel_epi64((__m128i *)store_diff, x0);
+  store_diff = (int64_t *) (diff + 1 * diff_stride);
+  _mm_storel_epi64((__m128i *)store_diff, x1);
+  store_diff = (int64_t *) (diff + 2 * diff_stride);
+  _mm_storel_epi64((__m128i *)store_diff, x2);
+  store_diff = (int64_t *) (diff + 3 * diff_stride);
+  _mm_storel_epi64((__m128i *)store_diff, x3);
+  store_diff = (int64_t *) (diff + 4 * diff_stride);
+  _mm_storel_epi64((__m128i *)store_diff, x4);
+  store_diff = (int64_t *) (diff + 5 * diff_stride);
+  _mm_storel_epi64((__m128i *)store_diff, x5);
+  store_diff = (int64_t *) (diff + 6 * diff_stride);
+  _mm_storel_epi64((__m128i *)store_diff, x6);
+  store_diff = (int64_t *) (diff + 7 * diff_stride);
+  _mm_storel_epi64((__m128i *)store_diff, x7);
+}
+
+static void subtract_8x4(int16_t *diff, ptrdiff_t diff_stride,
+                         const uint16_t *src, ptrdiff_t src_stride,
+                         const uint16_t *pred, ptrdiff_t pred_stride) {
+  __m128i u0, u1, u2, u3;
+  __m128i v0, v1, v2, v3;
+  __m128i x0, x1, x2, x3;
+
+  u0 = _mm_loadu_si128((__m128i const *) (src + 0 * src_stride));
+  u1 = _mm_loadu_si128((__m128i const *) (src + 1 * src_stride));
+  u2 = _mm_loadu_si128((__m128i const *) (src + 2 * src_stride));
+  u3 = _mm_loadu_si128((__m128i const *) (src + 3 * src_stride));
+
+  v0 = _mm_loadu_si128((__m128i const *) (pred + 0 * pred_stride));
+  v1 = _mm_loadu_si128((__m128i const *) (pred + 1 * pred_stride));
+  v2 = _mm_loadu_si128((__m128i const *) (pred + 2 * pred_stride));
+  v3 = _mm_loadu_si128((__m128i const *) (pred + 3 * pred_stride));
+
+  x0 = _mm_sub_epi16(u0, v0);
+  x1 = _mm_sub_epi16(u1, v1);
+  x2 = _mm_sub_epi16(u2, v2);
+  x3 = _mm_sub_epi16(u3, v3);
+
+  _mm_storeu_si128((__m128i *) (diff + 0 * diff_stride), x0);
+  _mm_storeu_si128((__m128i *) (diff + 1 * diff_stride), x1);
+  _mm_storeu_si128((__m128i *) (diff + 2 * diff_stride), x2);
+  _mm_storeu_si128((__m128i *) (diff + 3 * diff_stride), x3);
+}
+
+static void subtract_8x8(int16_t *diff, ptrdiff_t diff_stride,
+                         const uint16_t *src, ptrdiff_t src_stride,
+                         const uint16_t *pred, ptrdiff_t pred_stride) {
+  __m128i u0, u1, u2, u3, u4, u5, u6, u7;
+  __m128i v0, v1, v2, v3, v4, v5, v6, v7;
+  __m128i x0, x1, x2, x3, x4, x5, x6, x7;
+
+  u0 = _mm_loadu_si128((__m128i const *) (src + 0 * src_stride));
+  u1 = _mm_loadu_si128((__m128i const *) (src + 1 * src_stride));
+  u2 = _mm_loadu_si128((__m128i const *) (src + 2 * src_stride));
+  u3 = _mm_loadu_si128((__m128i const *) (src + 3 * src_stride));
+  u4 = _mm_loadu_si128((__m128i const *) (src + 4 * src_stride));
+  u5 = _mm_loadu_si128((__m128i const *) (src + 5 * src_stride));
+  u6 = _mm_loadu_si128((__m128i const *) (src + 6 * src_stride));
+  u7 = _mm_loadu_si128((__m128i const *) (src + 7 * src_stride));
+
+  v0 = _mm_loadu_si128((__m128i const *) (pred + 0 * pred_stride));
+  v1 = _mm_loadu_si128((__m128i const *) (pred + 1 * pred_stride));
+  v2 = _mm_loadu_si128((__m128i const *) (pred + 2 * pred_stride));
+  v3 = _mm_loadu_si128((__m128i const *) (pred + 3 * pred_stride));
+  v4 = _mm_loadu_si128((__m128i const *) (pred + 4 * pred_stride));
+  v5 = _mm_loadu_si128((__m128i const *) (pred + 5 * pred_stride));
+  v6 = _mm_loadu_si128((__m128i const *) (pred + 6 * pred_stride));
+  v7 = _mm_loadu_si128((__m128i const *) (pred + 7 * pred_stride));
+
+  x0 = _mm_sub_epi16(u0, v0);
+  x1 = _mm_sub_epi16(u1, v1);
+  x2 = _mm_sub_epi16(u2, v2);
+  x3 = _mm_sub_epi16(u3, v3);
+  x4 = _mm_sub_epi16(u4, v4);
+  x5 = _mm_sub_epi16(u5, v5);
+  x6 = _mm_sub_epi16(u6, v6);
+  x7 = _mm_sub_epi16(u7, v7);
+
+  _mm_storeu_si128((__m128i *) (diff + 0 * diff_stride), x0);
+  _mm_storeu_si128((__m128i *) (diff + 1 * diff_stride), x1);
+  _mm_storeu_si128((__m128i *) (diff + 2 * diff_stride), x2);
+  _mm_storeu_si128((__m128i *) (diff + 3 * diff_stride), x3);
+  _mm_storeu_si128((__m128i *) (diff + 4 * diff_stride), x4);
+  _mm_storeu_si128((__m128i *) (diff + 5 * diff_stride), x5);
+  _mm_storeu_si128((__m128i *) (diff + 6 * diff_stride), x6);
+  _mm_storeu_si128((__m128i *) (diff + 7 * diff_stride), x7);
+}
+
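+/* Larger blocks are composed from the kernels above: two vertical halves
+   (pointers advanced by half the rows) or two side-by-side halves (by cols). */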
+static void subtract_8x16(int16_t *diff, ptrdiff_t diff_stride,
+                          const uint16_t *src, ptrdiff_t src_stride,
+                          const uint16_t *pred, ptrdiff_t pred_stride) {
+  subtract_8x8(diff, diff_stride, src, src_stride, pred, pred_stride);
+  diff += diff_stride << 3;
+  src += src_stride << 3;
+  pred += pred_stride << 3;
+  subtract_8x8(diff, diff_stride, src, src_stride, pred, pred_stride);
+}
+
+static void subtract_16x8(int16_t *diff, ptrdiff_t diff_stride,
+                          const uint16_t *src, ptrdiff_t src_stride,
+                          const uint16_t *pred, ptrdiff_t pred_stride) {
+  subtract_8x8(diff, diff_stride, src, src_stride, pred, pred_stride);
+  diff += 8;
+  src += 8;
+  pred += 8;
+  subtract_8x8(diff, diff_stride, src, src_stride, pred, pred_stride);
+}
+
+static void subtract_16x16(int16_t *diff, ptrdiff_t diff_stride,
+                           const uint16_t *src, ptrdiff_t src_stride,
+                           const uint16_t *pred, ptrdiff_t pred_stride) {
+  subtract_16x8(diff, diff_stride, src, src_stride, pred, pred_stride);
+  diff += diff_stride << 3;
+  src += src_stride << 3;
+  pred += pred_stride << 3;
+  subtract_16x8(diff, diff_stride, src, src_stride, pred, pred_stride);
+}
+
+static void subtract_16x32(int16_t *diff, ptrdiff_t diff_stride,
+                           const uint16_t *src, ptrdiff_t src_stride,
+                           const uint16_t *pred, ptrdiff_t pred_stride) {
+  subtract_16x16(diff, diff_stride, src, src_stride, pred, pred_stride);
+  diff += diff_stride << 4;
+  src += src_stride << 4;
+  pred += pred_stride << 4;
+  subtract_16x16(diff, diff_stride, src, src_stride, pred, pred_stride);
+}
+
+static void subtract_32x16(int16_t *diff, ptrdiff_t diff_stride,
+                           const uint16_t *src, ptrdiff_t src_stride,
+                           const uint16_t *pred, ptrdiff_t pred_stride) {
+  subtract_16x16(diff, diff_stride, src, src_stride, pred, pred_stride);
+  diff += 16;
+  src += 16;
+  pred += 16;
+  subtract_16x16(diff, diff_stride, src, src_stride, pred, pred_stride);
+}
+
+static void subtract_32x32(int16_t *diff, ptrdiff_t diff_stride,
+                           const uint16_t *src, ptrdiff_t src_stride,
+                           const uint16_t *pred, ptrdiff_t pred_stride) {
+  subtract_32x16(diff, diff_stride, src, src_stride, pred, pred_stride);
+  diff += diff_stride << 4;
+  src += src_stride << 4;
+  pred += pred_stride << 4;
+  subtract_32x16(diff, diff_stride, src, src_stride, pred, pred_stride);
+}
+
+static void subtract_32x64(int16_t *diff, ptrdiff_t diff_stride,
+                           const uint16_t *src, ptrdiff_t src_stride,
+                           const uint16_t *pred, ptrdiff_t pred_stride) {
+  subtract_32x32(diff, diff_stride, src, src_stride, pred, pred_stride);
+  diff += diff_stride << 5;
+  src += src_stride << 5;
+  pred += pred_stride << 5;
+  subtract_32x32(diff, diff_stride, src, src_stride, pred, pred_stride);
+}
+
+static void subtract_64x32(int16_t *diff, ptrdiff_t diff_stride,
+                           const uint16_t *src, ptrdiff_t src_stride,
+                           const uint16_t *pred, ptrdiff_t pred_stride) {
+  subtract_32x32(diff, diff_stride, src, src_stride, pred, pred_stride);
+  diff += 32;
+  src += 32;
+  pred += 32;
+  subtract_32x32(diff, diff_stride, src, src_stride, pred, pred_stride);
+}
+
+static void subtract_64x64(int16_t *diff, ptrdiff_t diff_stride,
+                           const uint16_t *src, ptrdiff_t src_stride,
+                           const uint16_t *pred, ptrdiff_t pred_stride) {
+  subtract_64x32(diff, diff_stride, src, src_stride, pred, pred_stride);
+  diff += diff_stride << 5;
+  src += src_stride << 5;
+  pred += pred_stride << 5;
+  subtract_64x32(diff, diff_stride, src, src_stride, pred, pred_stride);
+}
+
+static void subtract_64x128(int16_t *diff, ptrdiff_t diff_stride,
+                            const uint16_t *src, ptrdiff_t src_stride,
+                            const uint16_t *pred, ptrdiff_t pred_stride) {
+  subtract_64x64(diff, diff_stride, src, src_stride, pred, pred_stride);
+  diff += diff_stride << 6;
+  src += src_stride << 6;
+  pred += pred_stride << 6;
+  subtract_64x64(diff, diff_stride, src, src_stride, pred, pred_stride);
+}
+
+static void subtract_128x64(int16_t *diff, ptrdiff_t diff_stride,
+                            const uint16_t *src, ptrdiff_t src_stride,
+                            const uint16_t *pred, ptrdiff_t pred_stride) {
+  subtract_64x64(diff, diff_stride, src, src_stride, pred, pred_stride);
+  diff += 64;
+  src += 64;
+  pred += 64;
+  subtract_64x64(diff, diff_stride, src, src_stride, pred, pred_stride);
+}
+
+static void subtract_128x128(int16_t *diff, ptrdiff_t diff_stride,
+                             const uint16_t *src, ptrdiff_t src_stride,
+                             const uint16_t *pred, ptrdiff_t pred_stride) {
+  subtract_128x64(diff, diff_stride, src, src_stride, pred, pred_stride);
+  diff += diff_stride << 6;
+  src += src_stride << 6;
+  pred += pred_stride << 6;
+  subtract_128x64(diff, diff_stride, src, src_stride, pred, pred_stride);
+}
+
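+/* Kernels are named subtract_WxH, with W = cols and H = rows. */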
+static SubtractWxHFuncType getSubtractFunc(int rows, int cols) {
+  SubtractWxHFuncType ret_func_ptr = NULL;
+  if (rows == 4) {
+    if (cols == 4) {
+      ret_func_ptr = subtract_4x4;
+    } else if (cols == 8) {
+      ret_func_ptr = subtract_8x4;
+    }
+  } else if (rows == 8) {
+    if (cols == 4) {
+      ret_func_ptr = subtract_4x8;
+    } else if (cols == 8) {
+      ret_func_ptr = subtract_8x8;
+    } else if (cols == 16) {
+      ret_func_ptr = subtract_16x8;
+    }
+  } else if (rows == 16) {
+    if (cols == 8) {
+      ret_func_ptr = subtract_8x16;
+    } else if (cols == 16) {
+      ret_func_ptr = subtract_16x16;
+    } else if (cols == 32) {
+      ret_func_ptr = subtract_32x16;
+    }
+  } else if (rows == 32) {
+    if (cols == 16) {
+      ret_func_ptr = subtract_16x32;
+    } else if (cols == 32) {
+      ret_func_ptr = subtract_32x32;
+    } else if (cols == 64) {
+      ret_func_ptr = subtract_64x32;
+    }
+  } else if (rows == 64) {
+    if (cols == 32) {
+      ret_func_ptr = subtract_32x64;
+    } else if (cols == 64) {
+      ret_func_ptr = subtract_64x64;
+    } else if (cols == 128) {
+      ret_func_ptr = subtract_128x64;
+    }
+  } else if (rows == 128) {
+    if (cols == 64) {
+      ret_func_ptr = subtract_64x128;
+    } else if (cols == 128) {
+      ret_func_ptr = subtract_128x128;
+    }
+  }
+  if (!ret_func_ptr) {
+    assert(0);
+  }
+  return ret_func_ptr;
+}
+
+void vpx_highbd_subtract_block_sse2(
+    int rows, int cols,
+    int16_t *diff, ptrdiff_t diff_stride,
+    const uint8_t *src8, ptrdiff_t src_stride,
+    const uint8_t *pred8, ptrdiff_t pred_stride,
+    int bd) {
+  uint16_t *src = CONVERT_TO_SHORTPTR(src8);
+  uint16_t *pred = CONVERT_TO_SHORTPTR(pred8);
+  SubtractWxHFuncType func;
+  (void) bd;
+
+  func = getSubtractFunc(rows, cols);
+  func(diff, diff_stride, src, src_stride, pred, pred_stride);
+}
-- 
2.40.0