if (cpi->sf.use_transform_domain_distortion) {
// Transform domain distortion computation is more accurate as it does
// not involve an inverse transform, but it is less accurate.
- const int ss_txfrm_size = num_4x4_blocks_txsize_log2_lookup[tx_size];
+ const int buffer_length = tx_size_2d[tx_size];
int64_t this_sse;
int tx_type = get_tx_type(pd->plane_type, xd, block, tx_size);
int shift = (MAX_TX_SCALE - get_tx_scale(xd, tx_type, tx_size)) * 2;
tran_low_t *const dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
#if CONFIG_AOM_HIGHBITDEPTH
const int bd = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd : 8;
- *out_dist = av1_highbd_block_error(coeff, dqcoeff, 16 << ss_txfrm_size,
- &this_sse, bd) >>
- shift;
-#else
*out_dist =
- av1_block_error(coeff, dqcoeff, 16 << ss_txfrm_size, &this_sse) >>
+ av1_highbd_block_error(coeff, dqcoeff, buffer_length, &this_sse, bd) >>
shift;
+#else
+ *out_dist =
+ av1_block_error(coeff, dqcoeff, buffer_length, &this_sse) >> shift;
#endif // CONFIG_AOM_HIGHBITDEPTH
*out_sse = this_sse >> shift;
} else {
const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
- const int bsw = 4 * num_4x4_blocks_wide_lookup[tx_bsize];
- const int bsh = 4 * num_4x4_blocks_high_lookup[tx_bsize];
+ const int bsw = block_size_wide[tx_bsize];
+ const int bsh = block_size_high[tx_bsize];
const int src_stride = x->plane[plane].src.stride;
const int dst_stride = xd->plane[plane].dst.stride;
- const int src_idx = 4 * (blk_row * src_stride + blk_col);
- const int dst_idx = 4 * (blk_row * dst_stride + blk_col);
+ // Scale the transform block index to pixel unit.
+ const int src_idx = (blk_row * src_stride + blk_col)
+ << tx_size_wide_log2[0];
+ const int dst_idx = (blk_row * dst_stride + blk_col)
+ << tx_size_wide_log2[0];
const uint8_t *src = &x->plane[plane].src.buf[src_idx];
const uint8_t *dst = &xd->plane[plane].dst.buf[dst_idx];
const tran_low_t *dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
unsigned int tmp;
assert(cpi != NULL);
+ assert(tx_size_wide_log2[0] == tx_size_high_log2[0]);
cpi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &tmp);
*out_sse = (int64_t)tmp * 16;