]> granicus.if.org Git - libvpx/commitdiff
Account for context information for partition rate estimate
authorJingning Han <jingning@google.com>
Tue, 28 Apr 2015 19:16:51 +0000 (12:16 -0700)
committerJingning Han <jingning@google.com>
Tue, 9 Jun 2015 22:53:55 +0000 (15:53 -0700)
This commit allows the encoder to account for the boundary block
information to estimate the transform block partitiion rate cost
in the rate-distortion optimization scheme.

Change-Id: Idb79cf936d96cdd15bcba27e47318295413a5f5d

vp9/decoder/vp9_decodemv.c
vp9/encoder/vp9_bitstream.c
vp9/encoder/vp9_encodeframe.c
vp9/encoder/vp9_rdopt.c

index efb730ca683392f5841e8e04572032d7274c6dac..60c8b467582afb0e2a24856c326dac14bad03826 100644 (file)
@@ -627,6 +627,8 @@ static void read_inter_frame_mode_info(VP9Decoder *const pbi,
   mbmi->segment_id = read_inter_segment_id(cm, xd, mi_row, mi_col, r);
   mbmi->skip = read_skip(cm, xd, counts, mbmi->segment_id, r);
   inter_block = read_is_inter_block(cm, xd, counts, mbmi->segment_id, r);
+  xd->above_txfm_context = cm->above_txfm_context + mi_col;
+  xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
 
   if (mbmi->sb_type >= BLOCK_8X8 && cm->tx_mode == TX_MODE_SELECT &&
       !mbmi->skip && inter_block) {
@@ -635,8 +637,6 @@ static void read_inter_frame_mode_info(VP9Decoder *const pbi,
     int width  = num_4x4_blocks_wide_lookup[bsize];
     int height = num_4x4_blocks_high_lookup[bsize];
     int idx, idy;
-    xd->above_txfm_context = cm->above_txfm_context + mi_col;
-    xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
     for (idy = 0; idy < height; idy += bh)
       for (idx = 0; idx < width; idx += bh)
         read_tx_size_inter(cm, xd, counts, max_txsize_lookup[mbmi->sb_type],
@@ -649,6 +649,9 @@ static void read_inter_frame_mode_info(VP9Decoder *const pbi,
       mbmi->inter_tx_size[i] = mbmi->tx_size;
   }
 
+  if (mbmi->sb_type < BLOCK_8X8)
+    txfm_partition_update(xd, 0, 0, TX_4X4);
+
   if (inter_block)
     read_inter_block_mode_info(pbi, xd, counts, tile, mi, mi_row, mi_col, r);
   else
index 4591710f07780d3bddaecb3721aab5fcc3ec7e4a..5f465de1aeb6c7782811d0c67a50498822d299f9 100644 (file)
@@ -334,6 +334,9 @@ static void pack_inter_mode_mvs(VP9_COMP *cpi, const MODE_INFO *mi,
     }
   }
 
+  if (bsize < BLOCK_8X8)
+    txfm_partition_update(xd, 0, 0, TX_4X4);
+
   if (!is_inter) {
     if (bsize >= BLOCK_8X8) {
       write_intra_mode(w, mode, cm->fc->y_mode_prob[size_group_lookup[bsize]]);
index 63da0917a4eb564db3a220be4de001e8bae5f8b8..3cdbfed8c956cea5289e8b49ef8ac7f74559860b 100644 (file)
@@ -1187,6 +1187,7 @@ static void restore_context(MACROBLOCK *const x, int mi_row, int mi_col,
                             ENTROPY_CONTEXT a[16 * MAX_MB_PLANE],
                             ENTROPY_CONTEXT l[16 * MAX_MB_PLANE],
                             PARTITION_CONTEXT sa[8], PARTITION_CONTEXT sl[8],
+                            TXFM_CONTEXT ta[8], TXFM_CONTEXT tl[8],
                             BLOCK_SIZE bsize) {
   MACROBLOCKD *const xd = &x->e_mbd;
   int p;
@@ -1211,12 +1212,17 @@ static void restore_context(MACROBLOCK *const x, int mi_row, int mi_col,
              sizeof(*xd->above_seg_context) * mi_width);
   vpx_memcpy(xd->left_seg_context + (mi_row & MI_MASK), sl,
              sizeof(xd->left_seg_context[0]) * mi_height);
+  vpx_memcpy(xd->above_txfm_context, ta,
+             sizeof(*xd->above_txfm_context) * mi_width);
+  vpx_memcpy(xd->left_txfm_context, tl,
+             sizeof(*xd->left_txfm_context) * mi_height);
 }
 
 static void save_context(MACROBLOCK *const x, int mi_row, int mi_col,
                          ENTROPY_CONTEXT a[16 * MAX_MB_PLANE],
                          ENTROPY_CONTEXT l[16 * MAX_MB_PLANE],
                          PARTITION_CONTEXT sa[8], PARTITION_CONTEXT sl[8],
+                         TXFM_CONTEXT ta[8], TXFM_CONTEXT tl[8],
                          BLOCK_SIZE bsize) {
   const MACROBLOCKD *const xd = &x->e_mbd;
   int p;
@@ -1243,6 +1249,10 @@ static void save_context(MACROBLOCK *const x, int mi_row, int mi_col,
              sizeof(*xd->above_seg_context) * mi_width);
   vpx_memcpy(sl, xd->left_seg_context + (mi_row & MI_MASK),
              sizeof(xd->left_seg_context[0]) * mi_height);
+  vpx_memcpy(ta, xd->above_txfm_context,
+             sizeof(*xd->above_txfm_context) * mi_width);
+  vpx_memcpy(tl, xd->left_txfm_context,
+             sizeof(*xd->left_txfm_context) * mi_height);
 }
 
 static void encode_b(VP9_COMP *cpi, const TileInfo *const tile,
@@ -1693,6 +1703,7 @@ static void rd_use_partition(VP9_COMP *cpi,
   BLOCK_SIZE subsize;
   ENTROPY_CONTEXT l[16 * MAX_MB_PLANE], a[16 * MAX_MB_PLANE];
   PARTITION_CONTEXT sl[8], sa[8];
+  TXFM_CONTEXT tl[8], ta[8];
   RD_COST last_part_rdc, none_rdc, chosen_rdc;
   BLOCK_SIZE sub_subsize = BLOCK_4X4;
   int splits_below = 0;
@@ -1714,7 +1725,9 @@ static void rd_use_partition(VP9_COMP *cpi,
   subsize = get_subsize(bsize, partition);
 
   pc_tree->partitioning = partition;
-  save_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+  xd->above_txfm_context = cm->above_txfm_context + mi_col;
+  xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+  save_context(x, mi_row, mi_col, a, l, sa, sl, ta, tl, bsize);
 
   if (bsize == BLOCK_16X16 && cpi->oxcf.aq_mode) {
     set_offsets(cpi, tile_info, x, mi_row, mi_col, bsize);
@@ -1754,7 +1767,7 @@ static void rd_use_partition(VP9_COMP *cpi,
                                  none_rdc.dist);
       }
 
-      restore_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+      restore_context(x, mi_row, mi_col, a, l, sa, sl, ta, tl, bsize);
       mi_8x8[0].src_mi->mbmi.sb_type = bs_type;
       pc_tree->partitioning = partition;
     }
@@ -1865,7 +1878,7 @@ static void rd_use_partition(VP9_COMP *cpi,
     BLOCK_SIZE split_subsize = get_subsize(bsize, PARTITION_SPLIT);
     chosen_rdc.rate = 0;
     chosen_rdc.dist = 0;
-    restore_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+    restore_context(x, mi_row, mi_col, a, l, sa, sl, ta, tl, bsize);
     pc_tree->partitioning = PARTITION_SPLIT;
 
     // Split partition.
@@ -1875,17 +1888,18 @@ static void rd_use_partition(VP9_COMP *cpi,
       RD_COST tmp_rdc;
       ENTROPY_CONTEXT l[16 * MAX_MB_PLANE], a[16 * MAX_MB_PLANE];
       PARTITION_CONTEXT sl[8], sa[8];
+      TXFM_CONTEXT tl[8], ta[8];
 
       if ((mi_row + y_idx >= cm->mi_rows) || (mi_col + x_idx >= cm->mi_cols))
         continue;
 
-      save_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+      save_context(x, mi_row, mi_col, a, l, sa, sl, ta, tl, bsize);
       pc_tree->split[i]->partitioning = PARTITION_NONE;
       rd_pick_sb_modes(cpi, tile_data, x,
                        mi_row + y_idx, mi_col + x_idx, &tmp_rdc,
                        split_subsize, &pc_tree->split[i]->none, INT64_MAX);
 
-      restore_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+      restore_context(x, mi_row, mi_col, a, l, sa, sl, ta, tl, bsize);
 
       if (tmp_rdc.rate == INT_MAX || tmp_rdc.dist == INT64_MAX) {
         vp9_rd_cost_reset(&chosen_rdc);
@@ -1925,7 +1939,9 @@ static void rd_use_partition(VP9_COMP *cpi,
     chosen_rdc = none_rdc;
   }
 
-  restore_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+  xd->above_txfm_context = cm->above_txfm_context + mi_col;
+  xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+  restore_context(x, mi_row, mi_col, a, l, sa, sl, ta, tl, bsize);
 
   // We must have chosen a partitioning and encoding or we'll fail later on.
   // No other opportunities for success.
@@ -2279,6 +2295,7 @@ static void rd_pick_partition(VP9_COMP *cpi, ThreadData *td,
   const int mi_step = num_8x8_blocks_wide_lookup[bsize] / 2;
   ENTROPY_CONTEXT l[16 * MAX_MB_PLANE], a[16 * MAX_MB_PLANE];
   PARTITION_CONTEXT sl[8], sa[8];
+  TXFM_CONTEXT tl[8], ta[8];
   TOKENEXTRA *tp_orig = *tp;
   PICK_MODE_CONTEXT *ctx = &pc_tree->none;
   int i, pl;
@@ -2344,7 +2361,9 @@ static void rd_pick_partition(VP9_COMP *cpi, ThreadData *td,
     partition_vert_allowed &= force_vert_split;
   }
 
-  save_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+  xd->above_txfm_context = cm->above_txfm_context + mi_col;
+  xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+  save_context(x, mi_row, mi_col, a, l, sa, sl, ta, tl, bsize);
 
 #if CONFIG_FP_MB_STATS
   if (cpi->use_fp_mb_stats) {
@@ -2490,7 +2509,7 @@ static void rd_pick_partition(VP9_COMP *cpi, ThreadData *td,
 #endif
       }
     }
-    restore_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+    restore_context(x, mi_row, mi_col, a, l, sa, sl, ta, tl, bsize);
   }
 
   // store estimated motion vector
@@ -2555,7 +2574,9 @@ static void rd_pick_partition(VP9_COMP *cpi, ThreadData *td,
       if (cpi->sf.less_rectangular_check)
         do_rect &= !partition_none_allowed;
     }
-    restore_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+    xd->above_txfm_context = cm->above_txfm_context + mi_col;
+    xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+    restore_context(x, mi_row, mi_col, a, l, sa, sl, ta, tl, bsize);
   }
 
   // PARTITION_HORZ
@@ -2582,6 +2603,10 @@ static void rd_pick_partition(VP9_COMP *cpi, ThreadData *td,
           partition_none_allowed)
         pc_tree->horizontal[1].pred_interp_filter =
             ctx->mic.mbmi.interp_filter;
+
+      xd->above_txfm_context = cm->above_txfm_context + mi_col;
+      xd->left_txfm_context = xd->left_txfm_context_buffer +
+                                ((mi_row + mi_step) & 0x07);
       rd_pick_sb_modes(cpi, tile_data, x, mi_row + mi_step, mi_col,
                        &this_rdc, subsize, &pc_tree->horizontal[1],
                        best_rdc.rdcost - sum_rdc.rdcost);
@@ -2603,7 +2628,9 @@ static void rd_pick_partition(VP9_COMP *cpi, ThreadData *td,
         pc_tree->partitioning = PARTITION_HORZ;
       }
     }
-    restore_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+    xd->above_txfm_context = cm->above_txfm_context + mi_col;
+    xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+    restore_context(x, mi_row, mi_col, a, l, sa, sl, ta, tl, bsize);
   }
   // PARTITION_VERT
   if (partition_vert_allowed && do_rect) {
@@ -2629,6 +2656,8 @@ static void rd_pick_partition(VP9_COMP *cpi, ThreadData *td,
           partition_none_allowed)
         pc_tree->vertical[1].pred_interp_filter =
             ctx->mic.mbmi.interp_filter;
+      xd->above_txfm_context = cm->above_txfm_context + mi_col + mi_step;
+      xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
       rd_pick_sb_modes(cpi, tile_data, x, mi_row, mi_col + mi_step,
                        &this_rdc, subsize,
                        &pc_tree->vertical[1], best_rdc.rdcost - sum_rdc.rdcost);
@@ -2651,7 +2680,9 @@ static void rd_pick_partition(VP9_COMP *cpi, ThreadData *td,
         pc_tree->partitioning = PARTITION_VERT;
       }
     }
-    restore_context(x, mi_row, mi_col, a, l, sa, sl, bsize);
+    xd->above_txfm_context = cm->above_txfm_context + mi_col;
+    xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+    restore_context(x, mi_row, mi_col, a, l, sa, sl, ta, tl, bsize);
   }
 
   // TODO(jbb): This code added so that we avoid static analysis
@@ -4029,7 +4060,8 @@ static void sum_intra_stats(FRAME_COUNTS *counts, const MODE_INFO *mi) {
 }
 
 static void update_txfm_count(MACROBLOCKD *xd, FRAME_COUNTS *counts,
-                              TX_SIZE tx_size, int blk_row, int blk_col) {
+                              TX_SIZE tx_size, int blk_row, int blk_col,
+                              int dry_run) {
   MB_MODE_INFO *mbmi = &xd->mi[0].src_mi->mbmi;
   int tx_idx = (blk_row / 2) * 8 + (blk_col / 2);
   int max_blocks_high = num_4x4_blocks_high_lookup[mbmi->sb_type];
@@ -4046,14 +4078,16 @@ static void update_txfm_count(MACROBLOCKD *xd, FRAME_COUNTS *counts,
     return;
 
   if (tx_size == plane_tx_size) {
-    ++counts->txfm_partition[ctx][0];
+    if (!dry_run)
+      ++counts->txfm_partition[ctx][0];
     mbmi->tx_size = tx_size;
     txfm_partition_update(xd, blk_row, blk_col, tx_size);
   } else {
     BLOCK_SIZE bsize = txsize_to_bsize[tx_size];
     int bh = num_4x4_blocks_high_lookup[bsize];
     int i;
-    ++counts->txfm_partition[ctx][1];
+    if (!dry_run)
+      ++counts->txfm_partition[ctx][1];
 
     if (tx_size == TX_8X8) {
       mbmi->inter_tx_size[tx_idx] = TX_4X4;
@@ -4066,7 +4100,7 @@ static void update_txfm_count(MACROBLOCKD *xd, FRAME_COUNTS *counts,
       int offsetr = (i >> 1) * bh / 2;
       int offsetc = (i & 0x01) * bh / 2;
       update_txfm_count(xd, counts, tx_size - 1,
-                        blk_row + offsetr, blk_col + offsetc);
+                        blk_row + offsetr, blk_col + offsetc, dry_run);
     }
   }
 }
@@ -4134,6 +4168,9 @@ static void encode_superblock(VP9_COMP *cpi, ThreadData *td,
     vp9_tokenize_sb_inter(cpi, td, t, !output_enabled, MAX(bsize, BLOCK_8X8));
   }
 
+  xd->above_txfm_context = cm->above_txfm_context + mi_col;
+  xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
+
   if (output_enabled) {
     if (cm->tx_mode == TX_MODE_SELECT &&
         mbmi->sb_type >= BLOCK_8X8  &&
@@ -4147,12 +4184,10 @@ static void encode_superblock(VP9_COMP *cpi, ThreadData *td,
         int width  = num_4x4_blocks_wide_lookup[bsize];
         int height = num_4x4_blocks_high_lookup[bsize];
         int idx, idy;
-        xd->above_txfm_context = cm->above_txfm_context + mi_col;
-        xd->left_txfm_context = xd->left_txfm_context_buffer + (mi_row & 0x07);
         for (idy = 0; idy < height; idy += bh)
           for (idx = 0; idx < width; idx += bh)
             update_txfm_count(xd, td->counts, max_txsize_lookup[mbmi->sb_type],
-                              idy, idx);
+                              idy, idx, 0);
       }
     } else {
       int x, y;
@@ -4172,5 +4207,24 @@ static void encode_superblock(VP9_COMP *cpi, ThreadData *td,
     }
     ++td->counts->tx.tx_totals[mbmi->tx_size];
     ++td->counts->tx.tx_totals[get_uv_tx_size(mbmi, &xd->plane[1])];
+  } else {
+    if (cm->tx_mode == TX_MODE_SELECT &&
+        mbmi->sb_type >= BLOCK_8X8  &&
+        !(is_inter_block(mbmi) && (mbmi->skip || seg_skip))) {
+      if (is_inter_block(mbmi)) {
+        BLOCK_SIZE txb_size = txsize_to_bsize[max_txsize_lookup[bsize]];
+        int bh = num_4x4_blocks_wide_lookup[txb_size];
+        int width  = num_4x4_blocks_wide_lookup[bsize];
+        int height = num_4x4_blocks_high_lookup[bsize];
+        int idx, idy;
+        for (idy = 0; idy < height; idy += bh)
+          for (idx = 0; idx < width; idx += bh)
+            update_txfm_count(xd, td->counts, max_txsize_lookup[mbmi->sb_type],
+                              idy, idx, 1);
+      }
+    }
   }
+
+  if (mbmi->sb_type < BLOCK_8X8)
+    txfm_partition_update(xd, 0, 0, TX_4X4);
 }
index 8c25756f624317bc5c9d3300107038a3b1209630..cd00410b55e633b767cfdc463fb3cc76412bb0bd 100644 (file)
@@ -1267,7 +1267,6 @@ static void tx_block_rd_b(VP9_COMP const *cpi, MACROBLOCK *x, TX_SIZE tx_size,
     }
   }
   *dist += (int64_t)tmp_sse * 16;
-
 #endif  // CONFIG_VP9_HIGHBITDEPTH
 
   switch (tx_size) {
@@ -1314,14 +1313,25 @@ static void select_tx_block(const VP9_COMP *cpi, MACROBLOCK *x,
                (blk_col >> (1 - pd->subsampling_x));
   int max_blocks_high = num_4x4_blocks_high_lookup[plane_bsize];
   int max_blocks_wide = num_4x4_blocks_wide_lookup[plane_bsize];
+  int mi_width = num_8x8_blocks_wide_lookup[txb_bsize];
+  int mi_height = num_8x8_blocks_high_lookup[txb_bsize];
   int64_t this_rd = INT64_MAX;
   ENTROPY_CONTEXT ctxa[16], ctxl[16];
   ENTROPY_CONTEXT *pta = ta + (blk_col >> pd->subsampling_x);
   ENTROPY_CONTEXT *ptl = tl + (blk_row >> pd->subsampling_y);
+  TXFM_CONTEXT txa[8], txl[8];
+  TXFM_CONTEXT stxa[8], stxl[8];
+  int ctx = txfm_partition_context(xd, blk_row, blk_col, tx_size);
 
   vpx_memcpy(ctxa, ta, sizeof(ENTROPY_CONTEXT) * max_blocks_wide);
   vpx_memcpy(ctxl, tl, sizeof(ENTROPY_CONTEXT) * max_blocks_high);
 
+  // Store the above and left transform block partition context.
+  vpx_memcpy(stxa, xd->above_txfm_context + (blk_col / 2),
+             sizeof(*xd->above_txfm_context) * mi_width);
+  vpx_memcpy(stxl, xd->left_txfm_context + (blk_row / 2),
+             sizeof(*xd->left_txfm_context) * mi_height);
+
   if (xd->mb_to_bottom_edge < 0)
     max_blocks_high += xd->mb_to_bottom_edge >> (5 + pd->subsampling_y);
   if (xd->mb_to_right_edge < 0)
@@ -1341,8 +1351,9 @@ static void select_tx_block(const VP9_COMP *cpi, MACROBLOCK *x,
   if (cpi->common.tx_mode == TX_MODE_SELECT || tx_size == TX_4X4) {
     tx_block_rd_b(cpi, x, tx_size, blk_row, blk_col, plane, block,
                   plane_bsize, ta, tl, rate, dist, bsse, skip);
-    if (tx_size > TX_4X4)
-      *rate += 256;
+    txfm_partition_update(xd, blk_row, blk_col, tx_size);
+    if (tx_size >= TX_8X8)
+      *rate += vp9_cost_bit(cpi->common.fc->txfm_partition_prob[ctx], 0);
     this_rd = RDCOST(x->rdmult, x->rddiv, *rate, *dist);
   }
 
@@ -1351,11 +1362,24 @@ static void select_tx_block(const VP9_COMP *cpi, MACROBLOCK *x,
     int bh = num_4x4_blocks_high_lookup[bsize];
     int sub_step = 1 << (2 *(tx_size - 1));
     int i;
-    int this_rate, sum_rate = 256;
+    int this_rate, sum_rate;
     int64_t this_dist, sum_dist = 0;
     int64_t this_bsse, sum_bsse = 0;
     int this_skip, all_skip = 1;
     int64_t sum_rd;
+
+    vpx_memcpy(txa, xd->above_txfm_context + (blk_col / 2),
+               sizeof(*xd->above_txfm_context) * mi_width);
+    vpx_memcpy(txl, xd->left_txfm_context + (blk_row / 2),
+               sizeof(*xd->left_txfm_context) * mi_height);
+
+    // Restore the above and left transform block partition context.
+    vpx_memcpy(xd->above_txfm_context + (blk_col / 2), stxa,
+               sizeof(*xd->above_txfm_context) * mi_width);
+    vpx_memcpy(xd->left_txfm_context + (blk_row / 2), stxl,
+               sizeof(*xd->left_txfm_context) * mi_height);
+
+    sum_rate = vp9_cost_bit(cpi->common.fc->txfm_partition_prob[ctx], 1);
     for (i = 0; i < 4; ++i) {
       int offsetr = (i >> 1) * bh / 2;
       int offsetc = (i & 0x01) * bh / 2;
@@ -1377,6 +1401,10 @@ static void select_tx_block(const VP9_COMP *cpi, MACROBLOCK *x,
         for (idx = blk_col; idx < blk_col + bh; idx += 2)
           mbmi->inter_tx_size[(idy / 2) * 8 + (idx / 2)] = tx_size;
       mbmi->tx_size = tx_size;
+      vpx_memcpy(xd->above_txfm_context + (blk_col / 2), txa,
+                 sizeof(*xd->above_txfm_context) * mi_width);
+      vpx_memcpy(xd->left_txfm_context + (blk_row / 2), txl,
+                 sizeof(*xd->left_txfm_context) * mi_height);
     } else {
       *rate = sum_rate;
       *dist = sum_dist;
@@ -3171,6 +3199,14 @@ static int64_t handle_inter_mode(VP9_COMP *cpi, MACROBLOCK *x,
     int skippable_y, skippable_uv;
     int64_t sseuv = INT64_MAX;
     int64_t rdcosty = INT64_MAX;
+    TXFM_CONTEXT ta[8], tl[8];
+    int mi_width = num_8x8_blocks_wide_lookup[bsize];
+    int mi_height = num_8x8_blocks_high_lookup[bsize];
+
+    vpx_memcpy(ta, xd->above_txfm_context,
+               sizeof(*xd->above_txfm_context) * mi_width);
+    vpx_memcpy(tl, xd->left_txfm_context,
+               sizeof(*xd->left_txfm_context) * mi_height);
 
     // Y cost and distortion
     vp9_subtract_plane(x, bsize, 0);
@@ -3178,6 +3214,10 @@ static int64_t handle_inter_mode(VP9_COMP *cpi, MACROBLOCK *x,
     if (cm->tx_mode == TX_MODE_SELECT) {
       inter_block_yrd(cpi, x, rate_y, &distortion_y, &skippable_y, psse,
                       bsize, ref_best_rd);
+      vpx_memcpy(xd->above_txfm_context, ta,
+                 sizeof(*xd->above_txfm_context) * mi_width);
+      vpx_memcpy(xd->left_txfm_context, tl,
+                 sizeof(*xd->left_txfm_context) * mi_height);
     } else {
       super_block_yrd(cpi, x, rate_y, &distortion_y, &skippable_y, psse,
                       bsize, txfm_cache, ref_best_rd);