]> granicus.if.org Git - libvpx/commitdiff
Rewrite reference frame costing in the RD loop.
authorRonald S. Bultje <rbultje@google.com>
Tue, 15 May 2012 00:39:42 +0000 (17:39 -0700)
committerRonald S. Bultje <rbultje@google.com>
Tue, 15 May 2012 22:32:44 +0000 (15:32 -0700)
I now see I didn't write a very long description, so let's do it
here then. We took a pretty big quality hit (0.1-0.2%) from my
recent fix of the inversion of arguments to vp8_cost_bit() in the
RD reference frame costing. I looked into it and basically the
costing prevented us from switching reference frames. This is of
course silly, since each frame codes its own prob_intra_coded, so
using last frame cost indications as a limiting factor can never
be right.

Here, I've rewritten that code to estimate costings based partially
on statistics from progress on current frame encoding. Overall,
this gives us a ~0.2%-0.3% improvement over what we had previously
before my argument-inversion-fix, and thus about ~0.4% over current
git (on derf-set), and a little more (0.5-1.0%) on HD/STD-HD/YT.

Change-Id: I79ebd4ccec4d6edbf0e152d9590d103ba2747775

vp8/common/pred_common.c
vp8/encoder/bitstream.c
vp8/encoder/encodeframe.c
vp8/encoder/mbgraph.c
vp8/encoder/onyx_int.h
vp8/encoder/rdopt.c

index 0e01e6c8a983a4d9027c8c686d5c1fcdfb43a446..f7222a7ecf581bf1b079e3b6907aae354d03f0bf 100644 (file)
@@ -277,7 +277,7 @@ void calc_ref_probs( int * count, vp8_prob * probs )
     tot_count = count[0] + count[1] + count[2] + count[3];
     if ( tot_count )
     {
-        probs[0] = (vp8_prob)((count[0] * 255) / tot_count);
+        probs[0] = (vp8_prob)((count[0] * 255 + (tot_count >> 1)) / tot_count);
         probs[0] += !probs[0];
     }
     else
@@ -286,7 +286,7 @@ void calc_ref_probs( int * count, vp8_prob * probs )
     tot_count -= count[0];
     if ( tot_count )
     {
-        probs[1] = (vp8_prob)((count[1] * 255) / tot_count);
+        probs[1] = (vp8_prob)((count[1] * 255 + (tot_count >> 1)) / tot_count);
         probs[1] += !probs[1];
     }
     else
@@ -295,7 +295,7 @@ void calc_ref_probs( int * count, vp8_prob * probs )
     tot_count -= count[1];
     if ( tot_count )
     {
-        probs[2] = (vp8_prob)((count[2] * 255) / tot_count);
+        probs[2] = (vp8_prob)((count[2] * 255 + (tot_count >> 1)) / tot_count);
         probs[2] += !probs[2];
     }
     else
index 6ebf282ed3a7ae288f6f048259b116e7ded95487..6033f790e6f6df4b43063776f78bc2c2d4081c76 100644 (file)
@@ -219,21 +219,11 @@ void update_skip_probs(VP8_COMP *cpi)
 static void update_refpred_stats( VP8_COMP *cpi )
 {
     VP8_COMMON *const cm = & cpi->common;
-    MACROBLOCKD *const xd = & cpi->mb.e_mbd;
-
-    int mb_row, mb_col;
     int i;
     int tot_count;
-    int ref_pred_count[PREDICTION_PROBS][2];
     vp8_prob new_pred_probs[PREDICTION_PROBS];
-    unsigned char pred_context;
-    unsigned char pred_flag;
-
     int old_cost, new_cost;
 
-    // Clear the prediction hit counters
-    vpx_memset(ref_pred_count, 0, sizeof(ref_pred_count));
-
     // Set the prediction probability structures to defaults
     if ( cm->frame_type == KEY_FRAME )
     {
@@ -247,47 +237,17 @@ static void update_refpred_stats( VP8_COMP *cpi )
     }
     else
     {
-        // For non-key frames.......
-
-        // Scan through the macroblocks and collate prediction counts.
-        xd->mode_info_context = cm->mi;
-        for (mb_row = 0; mb_row < cm->mb_rows; mb_row++)
-        {
-            for (mb_col = 0; mb_col < cm->mb_cols; mb_col++)
-            {
-                // Get the prediction context and status
-                pred_flag = get_pred_flag( xd, PRED_REF );
-                pred_context = get_pred_context( cm, xd, PRED_REF );
-
-                // Count prediction success
-                ref_pred_count[pred_context][pred_flag]++;
-
-                // Step on to the next mb
-                xd->mode_info_context++;
-            }
-
-            // this is to account for the border in mode_info_context
-            xd->mode_info_context++;
-        }
-
         // From the prediction counts set the probabilities for each context
         for ( i = 0; i < PREDICTION_PROBS; i++ )
         {
-            // MB reference frame not relevent to key frame encoding
-            if ( cm->frame_type != KEY_FRAME )
+            tot_count = cpi->ref_pred_count[i][0] + cpi->ref_pred_count[i][1];
+            if ( tot_count )
             {
-                // Work out the probabilities for the reference frame predictor
-                tot_count = ref_pred_count[i][0] + ref_pred_count[i][1];
-                if ( tot_count )
-                {
-                    new_pred_probs[i] =
-                        ( ref_pred_count[i][0] * 255 ) / tot_count;
+                new_pred_probs[i] =
+                    ( cpi->ref_pred_count[i][0] * 255 + (tot_count >> 1)) / tot_count;
 
-                    // Clamp to minimum allowed value
-                    new_pred_probs[i] += !new_pred_probs[i];
-                }
-                else
-                    new_pred_probs[i] = 128;
+                // Clamp to minimum allowed value
+                new_pred_probs[i] += !new_pred_probs[i];
             }
             else
                 new_pred_probs[i] = 128;
@@ -295,12 +255,12 @@ static void update_refpred_stats( VP8_COMP *cpi )
             // Decide whether or not to update the reference frame probs.
             // Returned costs are in 1/256 bit units.
             old_cost =
-                (ref_pred_count[i][0] * vp8_cost_zero(cm->ref_pred_probs[i])) +
-                (ref_pred_count[i][1] * vp8_cost_one(cm->ref_pred_probs[i]));
+                (cpi->ref_pred_count[i][0] * vp8_cost_zero(cm->ref_pred_probs[i])) +
+                (cpi->ref_pred_count[i][1] * vp8_cost_one(cm->ref_pred_probs[i]));
 
             new_cost =
-                (ref_pred_count[i][0] * vp8_cost_zero(new_pred_probs[i])) +
-                (ref_pred_count[i][1] * vp8_cost_one(new_pred_probs[i]));
+                (cpi->ref_pred_count[i][0] * vp8_cost_zero(new_pred_probs[i])) +
+                (cpi->ref_pred_count[i][1] * vp8_cost_one(new_pred_probs[i]));
 
             // Cost saving must be >= 8 bits (2048 in these units)
             if ( (old_cost - new_cost) >= 2048 )
index 80d1f7045ed846674987963019cad0aad8d08772..364d6ed55bafc0ddfa20fa83a9dd1cb0e1d9141b 100644 (file)
@@ -620,6 +620,23 @@ static void pick_mb_modes (VP8_COMP *cpi,
         }
         else
         {
+            int seg_id;
+
+            if (xd->segmentation_enabled && cpi->seg0_cnt > 0 &&
+                !segfeature_active( xd, 0, SEG_LVL_REF_FRAME ) &&
+                segfeature_active( xd, 1, SEG_LVL_REF_FRAME ) &&
+                check_segref(xd, 1, INTRA_FRAME)  +
+                check_segref(xd, 1, LAST_FRAME)   +
+                check_segref(xd, 1, GOLDEN_FRAME) +
+                check_segref(xd, 1, ALTREF_FRAME) == 1)
+            {
+                cpi->seg0_progress = (cpi->seg0_idx << 16) / cpi->seg0_cnt;
+            }
+            else
+            {
+                cpi->seg0_progress = (((mb_col & ~1) * 2 + (mb_row & ~1) * cm->mb_cols + i) << 16) / cm->MBs;
+            }
+
             *totalrate += vp8cx_pick_mode_inter_macroblock(cpi, x,
                                                            recon_yoffset,
                                                            recon_uvoffset);
@@ -627,6 +644,26 @@ static void pick_mb_modes (VP8_COMP *cpi,
             // Dummy encode, do not do the tokenization
             vp8cx_encode_inter_macroblock(cpi, x, tp,
                                          recon_yoffset, recon_uvoffset, 0);
+
+            seg_id = xd->mode_info_context->mbmi.segment_id;
+            if (cpi->mb.e_mbd.segmentation_enabled && seg_id == 0)
+            {
+                cpi->seg0_idx++;
+            }
+            if (!xd->segmentation_enabled ||
+                !segfeature_active( xd, seg_id, SEG_LVL_REF_FRAME ) ||
+                check_segref(xd, seg_id, INTRA_FRAME)  +
+                check_segref(xd, seg_id, LAST_FRAME)   +
+                check_segref(xd, seg_id, GOLDEN_FRAME) +
+                check_segref(xd, seg_id, ALTREF_FRAME) > 1)
+            {
+                // Get the prediction context and status
+                int pred_flag = get_pred_flag( xd, PRED_REF );
+                int pred_context = get_pred_context( cm, xd, PRED_REF );
+
+                // Count prediction success
+                cpi->ref_pred_count[pred_context][pred_flag]++;
+            }
         }
 
         // Keep a copy of the updated left context
@@ -989,6 +1026,8 @@ void init_encode_frame_mb_context(VP8_COMP *cpi)
     x->mb_activity_ptr = cpi->mb_activity_map;
 
     x->act_zbin_adj = 0;
+    cpi->seg0_idx = 0;
+    vpx_memset(cpi->ref_pred_count, 0, sizeof(cpi->ref_pred_count));
 
     x->partition_info = x->pi;
 
index ebdb9b504cd903945d57b1649d1402d9a4dacd3b..d2d3b6bdcdf6b0ff4fdeaef12c5f7d2f3d75eb62 100644 (file)
@@ -502,6 +502,7 @@ void separate_arf_mbs
         else
             cpi->static_mb_pct = 0;
 
+        cpi->seg0_cnt = ncnt[0];
         vp8_enable_segmentation((VP8_PTR) cpi);
     }
     else
index 7f6545ed2cf14b21424b4520b5a125a78c35a061..feefc249aa533d1343e6c5374e307d249ab4d36a 100644 (file)
@@ -485,6 +485,8 @@ typedef struct VP8_COMP
     MBGRAPH_FRAME_STATS mbgraph_stats[MAX_LAG_BUFFERS];
     int mbgraph_n_frames;             // number of frames filled in the above
     int static_mb_pct;                // % forced skip mbs by segmentation
+    int seg0_progress, seg0_idx, seg0_cnt;
+    int ref_pred_count[3][2];
 
     int decimation_factor;
     int decimation_count;
index a87d1afe383cd32fd4ee68913d2e853469a6cfca..0caa4c73d9a6ed21eed1da82a8322ec3cc643dab 100644 (file)
@@ -2581,7 +2581,66 @@ static void set_i8x8_block_modes(MACROBLOCK *x, int modes[2][4])
     }
 }
 
-void vp8_estimate_ref_frame_costs(VP8_COMP *cpi, unsigned int * ref_costs )
+extern void calc_ref_probs( int * count, vp8_prob * probs );
+static void estimate_curframe_refprobs(VP8_COMP *cpi, vp8_prob mod_refprobs[3], int pred_ref)
+{
+    int norm_cnt[MAX_REF_FRAMES];
+    const int *const rfct = cpi->count_mb_ref_frame_usage;
+    int intra_count = rfct[INTRA_FRAME];
+    int last_count  = rfct[LAST_FRAME];
+    int gf_count    = rfct[GOLDEN_FRAME];
+    int arf_count   = rfct[ALTREF_FRAME];
+
+    // Work out modified reference frame probabilities to use where prediction
+    // of the reference frame fails
+    if (pred_ref == INTRA_FRAME)
+    {
+        norm_cnt[0] = 0;
+        norm_cnt[1] = last_count;
+        norm_cnt[2] = gf_count;
+        norm_cnt[3] = arf_count;
+        calc_ref_probs( norm_cnt, mod_refprobs );
+        mod_refprobs[0] = 0;    // This branch implicit
+    }
+    else if (pred_ref == LAST_FRAME)
+    {
+        norm_cnt[0] = intra_count;
+        norm_cnt[1] = 0;
+        norm_cnt[2] = gf_count;
+        norm_cnt[3] = arf_count;
+        calc_ref_probs( norm_cnt, mod_refprobs);
+        mod_refprobs[1] = 0;    // This branch implicit
+    }
+    else if (pred_ref == GOLDEN_FRAME)
+    {
+        norm_cnt[0] = intra_count;
+        norm_cnt[1] = last_count;
+        norm_cnt[2] = 0;
+        norm_cnt[3] = arf_count;
+        calc_ref_probs( norm_cnt, mod_refprobs );
+        mod_refprobs[2] = 0;  // This branch implicit
+    }
+    else
+    {
+        norm_cnt[0] = intra_count;
+        norm_cnt[1] = last_count;
+        norm_cnt[2] = gf_count;
+        norm_cnt[3] = 0;
+        calc_ref_probs( norm_cnt, mod_refprobs );
+        mod_refprobs[2] = 0;  // This branch implicit
+    }
+}
+
+static __inline unsigned weighted_cost(vp8_prob *tab0, vp8_prob *tab1, int idx, int val, int weight)
+{
+    unsigned cost0 = tab0[idx] ? vp8_cost_bit(tab0[idx], val) : 0;
+    unsigned cost1 = tab1[idx] ? vp8_cost_bit(tab1[idx], val) : 0;
+    // weight is 16-bit fixed point, so this basically calculates:
+    // 0.5 + weight * cost1 + (1.0 - weight) * cost0
+    return (0x8000 + weight * cost1 + (0x10000 - weight) * cost0) >> 16;
+}
+
+static void vp8_estimate_ref_frame_costs(VP8_COMP *cpi, int segment_id, unsigned int * ref_costs )
 {
     VP8_COMMON *cm = &cpi->common;
     MACROBLOCKD *xd = &cpi->mb.e_mbd;
@@ -2590,47 +2649,87 @@ void vp8_estimate_ref_frame_costs(VP8_COMP *cpi, unsigned int * ref_costs )
     unsigned int cost;
     int pred_ref ;
     int pred_flag;
+    int pred_ctx ;
     int i;
+    int tot_count;
+
+    vp8_prob pred_prob, new_pred_prob;
+    int seg_ref_active;
+    int seg_ref_count = 0;
+    seg_ref_active = segfeature_active( xd,
+                                       segment_id,
+                                       SEG_LVL_REF_FRAME );
 
-    vp8_prob pred_prob;
+    if ( seg_ref_active )
+    {
+        seg_ref_count = check_segref( xd, segment_id, INTRA_FRAME )  +
+                        check_segref( xd, segment_id, LAST_FRAME )   +
+                        check_segref( xd, segment_id, GOLDEN_FRAME ) +
+                        check_segref( xd, segment_id, ALTREF_FRAME );
+    }
 
     // Get the predicted reference for this mb
     pred_ref = get_pred_ref( cm, xd );
 
-    // Get the context probability for the prediction flag
+    // Get the context probability for the prediction flag (based on last frame)
     pred_prob = get_pred_prob( cm, xd, PRED_REF );
 
+    // Predict probability for current frame based on stats so far
+    pred_ctx = get_pred_context(cm, xd, PRED_REF);
+    tot_count = cpi->ref_pred_count[pred_ctx][0] + cpi->ref_pred_count[pred_ctx][1];
+    if ( tot_count )
+    {
+        new_pred_prob =
+            ( cpi->ref_pred_count[pred_ctx][0] * 255 + (tot_count >> 1)) / tot_count;
+        new_pred_prob += !new_pred_prob;
+    }
+    else
+        new_pred_prob = 128;
+
     // Get the set of probabilities to use if prediction fails
     mod_refprobs = cm->mod_refprobs[pred_ref];
 
     // For each possible selected reference frame work out a cost.
-    // TODO: correct handling of costs if segment indicates only a subset of
-    // reference frames are allowed... though mostly this should come out
-    // in the wash.
     for ( i = 0; i < MAX_REF_FRAMES; i++ )
     {
-        pred_flag = (i == pred_ref);
-
-        // Get the prediction for the current mb
-        cost = vp8_cost_bit( pred_prob, pred_flag );
-
-        // for incorrectly predicted cases
-        if ( ! pred_flag )
+        if (seg_ref_active && seg_ref_count == 1)
         {
-            if ( mod_refprobs[0] )
-                cost += vp8_cost_bit( mod_refprobs[0], (i != INTRA_FRAME) );
+            cost = 0;
+        }
+        else
+        {
+            pred_flag = (i == pred_ref);
+
+            // Get the prediction for the current mb
+            cost = weighted_cost(&pred_prob, &new_pred_prob, 0,
+                                 pred_flag, cpi->seg0_progress);
+            if (cost > 1024) cost = 768; // i.e. account for 4 bits max.
 
-            // Inter coded
-            if (i != INTRA_FRAME)
+            // for incorrectly predicted cases
+            if ( ! pred_flag )
             {
-                if ( mod_refprobs[1] )
-                    cost += vp8_cost_bit( mod_refprobs[1], (i != LAST_FRAME) );
+                vp8_prob curframe_mod_refprobs[3];
 
-                if (i != LAST_FRAME)
+                if (cpi->seg0_progress)
                 {
-                    if ( mod_refprobs[2] )
-                        cost += vp8_cost_bit( mod_refprobs[2],
-                                             (i != GOLDEN_FRAME));
+                    estimate_curframe_refprobs(cpi, curframe_mod_refprobs, pred_ref);
+                }
+                else
+                {
+                    vpx_memset(curframe_mod_refprobs, 0, sizeof(curframe_mod_refprobs));
+                }
+
+                cost += weighted_cost(mod_refprobs, curframe_mod_refprobs, 0,
+                                      (i != INTRA_FRAME), cpi->seg0_progress);
+                if (i != INTRA_FRAME)
+                {
+                    cost += weighted_cost(mod_refprobs, curframe_mod_refprobs, 1,
+                                          (i != LAST_FRAME), cpi->seg0_progress);
+                    if (i != LAST_FRAME)
+                    {
+                        cost += weighted_cost(mod_refprobs, curframe_mod_refprobs, 2,
+                                              (i != GOLDEN_FRAME), cpi->seg0_progress);
+                    }
                 }
             }
         }
@@ -2819,7 +2918,7 @@ void vp8_rd_pick_inter_mode(VP8_COMP *cpi, MACROBLOCK *x, int recon_yoffset, int
 
     // Get estimates of reference frame costs for each reference frame
     // that depend on the current prediction etc.
-    vp8_estimate_ref_frame_costs( cpi, ref_costs );
+    vp8_estimate_ref_frame_costs( cpi, segment_id, ref_costs );
 
     for (mode_index = 0; mode_index < MAX_MODES; mode_index++)
     {