]> granicus.if.org Git - libvpx/commitdiff
Optimize vp9_tree_probs_from_distribution
author John Koleszar <jkoleszar@google.com>
Sun, 10 Mar 2013 20:39:30 +0000 (13:39 -0700)
committer John Koleszar <jkoleszar@google.com>
Sun, 10 Mar 2013 20:39:30 +0000 (13:39 -0700)
The previous implementation visited each node in the tree multiple times
because it used each symbol's encoding to revisit the branches taken and
increment their counts. Instead, we can traverse the tree depth-first and
calculate the probabilities and branch counts as we walk back up. The
complexity goes from somewhere between O(n log n) and O(n^2) (depending on
how balanced the tree is) to O(n).

Only tested on one clip (256 kbps, CIF); saw a 13% decoding perf improvement.

Note that this optimization should port trivially to VP8 as well. In VP8,
the decoder doesn't use this function, but it does routinely show up
on the profile for realtime encoding.

Change-Id: I4f2848e4f41dc9a7694f73f3e75034bce08d1b12

vp9/common/vp9_entropy.c
vp9/common/vp9_entropymode.c
vp9/common/vp9_entropymv.c
vp9/common/vp9_treecoder.c
vp9/common/vp9_treecoder.h
vp9/encoder/vp9_bitstream.c

index 1e3a7e17e33b68be179c811751ac70de68192229..bc6935313150a7028e809fe566860140000f8f22 100644 (file)
@@ -292,10 +292,9 @@ static void update_coef_probs(vp9_coeff_probs *dst_coef_probs,
         for (l = 0; l < PREV_COEF_CONTEXTS; ++l) {
           if (l >= 3 && k == 0)
             continue;
-          vp9_tree_probs_from_distribution(MAX_ENTROPY_TOKENS,
-                                           vp9_coef_encodings, vp9_coef_tree,
+          vp9_tree_probs_from_distribution(vp9_coef_tree,
                                            coef_probs, branch_ct,
-                                           coef_counts[i][j][k][l]);
+                                           coef_counts[i][j][k][l], 0);
           for (t = 0; t < ENTROPY_NODES; ++t) {
             count = branch_ct[t][0] + branch_ct[t][1];
             count = count > count_sat ? count_sat : count;
index 23b2abef71789389c55669e61e53e224a21a71f5..061c279fa5e457be269c54839a3b9c707d3fcff4 100644 (file)
@@ -302,40 +302,32 @@ struct vp9_token_struct vp9_sub_mv_ref_encoding_array[VP9_SUBMVREFS];
 void vp9_init_mbmode_probs(VP9_COMMON *x) {
   unsigned int bct [VP9_YMODES] [2];      /* num Ymodes > num UV modes */
 
-  vp9_tree_probs_from_distribution(VP9_YMODES, vp9_ymode_encodings,
-                                   vp9_ymode_tree, x->fc.ymode_prob,
-                                   bct, y_mode_cts);
-  vp9_tree_probs_from_distribution(VP9_I32X32_MODES, vp9_sb_ymode_encodings,
-                                   vp9_sb_ymode_tree, x->fc.sb_ymode_prob,
-                                   bct, y_mode_cts);
+  vp9_tree_probs_from_distribution(vp9_ymode_tree, x->fc.ymode_prob,
+                                   bct, y_mode_cts, 0);
+  vp9_tree_probs_from_distribution(vp9_sb_ymode_tree, x->fc.sb_ymode_prob,
+                                   bct, y_mode_cts, 0);
   {
     int i;
     for (i = 0; i < 8; i++) {
-      vp9_tree_probs_from_distribution(VP9_YMODES, vp9_kf_ymode_encodings,
-                                       vp9_kf_ymode_tree, x->kf_ymode_prob[i],
-                                       bct, kf_y_mode_cts[i]);
-      vp9_tree_probs_from_distribution(VP9_I32X32_MODES,
-                                       vp9_sb_kf_ymode_encodings,
-                                       vp9_sb_kf_ymode_tree,
+      vp9_tree_probs_from_distribution(vp9_kf_ymode_tree, x->kf_ymode_prob[i],
+                                       bct, kf_y_mode_cts[i], 0);
+      vp9_tree_probs_from_distribution(vp9_sb_kf_ymode_tree,
                                        x->sb_kf_ymode_prob[i], bct,
-                                       kf_y_mode_cts[i]);
+                                       kf_y_mode_cts[i], 0);
     }
   }
   {
     int i;
     for (i = 0; i < VP9_YMODES; i++) {
-      vp9_tree_probs_from_distribution(VP9_UV_MODES, vp9_uv_mode_encodings,
-                                       vp9_uv_mode_tree, x->kf_uv_mode_prob[i],
-                                       bct, kf_uv_mode_cts[i]);
-      vp9_tree_probs_from_distribution(VP9_UV_MODES, vp9_uv_mode_encodings,
-                                       vp9_uv_mode_tree, x->fc.uv_mode_prob[i],
-                                       bct, uv_mode_cts[i]);
+      vp9_tree_probs_from_distribution(vp9_uv_mode_tree, x->kf_uv_mode_prob[i],
+                                       bct, kf_uv_mode_cts[i], 0);
+      vp9_tree_probs_from_distribution(vp9_uv_mode_tree, x->fc.uv_mode_prob[i],
+                                       bct, uv_mode_cts[i], 0);
     }
   }
 
-  vp9_tree_probs_from_distribution(VP9_I8X8_MODES, vp9_i8x8_mode_encodings,
-                                   vp9_i8x8_mode_tree, x->fc.i8x8_mode_prob,
-                                   bct, i8x8_mode_cts);
+  vp9_tree_probs_from_distribution(vp9_i8x8_mode_tree, x->fc.i8x8_mode_prob,
+                                   bct, i8x8_mode_cts, 0);
 
   vpx_memcpy(x->fc.sub_mv_ref_prob, vp9_sub_mv_ref_prob2,
              sizeof(vp9_sub_mv_ref_prob2));
@@ -355,8 +347,7 @@ static void intra_bmode_probs_from_distribution(
   vp9_prob p[VP9_NKF_BINTRAMODES - 1],
   unsigned int branch_ct[VP9_NKF_BINTRAMODES - 1][2],
   const unsigned int events[VP9_NKF_BINTRAMODES]) {
-  vp9_tree_probs_from_distribution(VP9_NKF_BINTRAMODES, vp9_bmode_encodings,
-                                   vp9_bmode_tree, p, branch_ct, events);
+  vp9_tree_probs_from_distribution(vp9_bmode_tree, p, branch_ct, events, 0);
 }
 
 void vp9_default_bmode_probs(vp9_prob p[VP9_NKF_BINTRAMODES - 1]) {
@@ -368,8 +359,7 @@ static void intra_kf_bmode_probs_from_distribution(
   vp9_prob p[VP9_KF_BINTRAMODES - 1],
   unsigned int branch_ct[VP9_KF_BINTRAMODES - 1][2],
   const unsigned int events[VP9_KF_BINTRAMODES]) {
-  vp9_tree_probs_from_distribution(VP9_KF_BINTRAMODES, vp9_kf_bmode_encodings,
-                                   vp9_kf_bmode_tree, p, branch_ct, events);
+  vp9_tree_probs_from_distribution(vp9_kf_bmode_tree, p, branch_ct, events, 0);
 }
 
 void vp9_kf_default_bmode_probs(vp9_prob p[VP9_KF_BINTRAMODES]
@@ -538,17 +528,17 @@ void print_mode_contexts(VP9_COMMON *pc) {
 
 #define MODE_COUNT_SAT 20
 #define MODE_MAX_UPDATE_FACTOR 144
-static void update_mode_probs(int n_modes, struct vp9_token_struct *encoding,
+static void update_mode_probs(int n_modes,
                               const vp9_tree_index *tree, unsigned int *cnt,
-                              vp9_prob *pre_probs, vp9_prob *dst_probs) {
+                              vp9_prob *pre_probs, vp9_prob *dst_probs,
+                              unsigned int tok0_offset) {
 #define MAX_PROBS 32
   vp9_prob probs[MAX_PROBS];
   unsigned int branch_ct[MAX_PROBS][2];
   int t, count, factor;
 
   assert(n_modes - 1 < MAX_PROBS);
-  vp9_tree_probs_from_distribution(n_modes, encoding, tree, probs,
-                                   branch_ct, cnt);
+  vp9_tree_probs_from_distribution(tree, probs, branch_ct, cnt, tok0_offset);
   for (t = 0; t < n_modes - 1; ++t) {
     count = branch_ct[t][0] + branch_ct[t][1];
     count = count > MODE_COUNT_SAT ? MODE_COUNT_SAT : count;
@@ -604,31 +594,32 @@ void vp9_adapt_mode_probs(VP9_COMMON *cm) {
 #endif
 #endif
 
-  update_mode_probs(VP9_YMODES, vp9_ymode_encodings, vp9_ymode_tree,
+  update_mode_probs(VP9_YMODES, vp9_ymode_tree,
                     cm->fc.ymode_counts, cm->fc.pre_ymode_prob,
-                    cm->fc.ymode_prob);
-  update_mode_probs(VP9_I32X32_MODES, vp9_sb_ymode_encodings, vp9_sb_ymode_tree,
+                    cm->fc.ymode_prob, 0);
+  update_mode_probs(VP9_I32X32_MODES, vp9_sb_ymode_tree,
                     cm->fc.sb_ymode_counts, cm->fc.pre_sb_ymode_prob,
-                    cm->fc.sb_ymode_prob);
+                    cm->fc.sb_ymode_prob, 0);
   for (i = 0; i < VP9_YMODES; ++i) {
-    update_mode_probs(VP9_UV_MODES, vp9_uv_mode_encodings, vp9_uv_mode_tree,
+    update_mode_probs(VP9_UV_MODES, vp9_uv_mode_tree,
                       cm->fc.uv_mode_counts[i], cm->fc.pre_uv_mode_prob[i],
-                      cm->fc.uv_mode_prob[i]);
+                      cm->fc.uv_mode_prob[i], 0);
   }
-  update_mode_probs(VP9_NKF_BINTRAMODES, vp9_bmode_encodings, vp9_bmode_tree,
+  update_mode_probs(VP9_NKF_BINTRAMODES, vp9_bmode_tree,
                     cm->fc.bmode_counts, cm->fc.pre_bmode_prob,
-                    cm->fc.bmode_prob);
-  update_mode_probs(VP9_I8X8_MODES, vp9_i8x8_mode_encodings,
+                    cm->fc.bmode_prob, 0);
+  update_mode_probs(VP9_I8X8_MODES,
                     vp9_i8x8_mode_tree, cm->fc.i8x8_mode_counts,
-                    cm->fc.pre_i8x8_mode_prob, cm->fc.i8x8_mode_prob);
+                    cm->fc.pre_i8x8_mode_prob, cm->fc.i8x8_mode_prob, 0);
   for (i = 0; i < SUBMVREF_COUNT; ++i) {
-    update_mode_probs(VP9_SUBMVREFS, vp9_sub_mv_ref_encoding_array,
+    update_mode_probs(VP9_SUBMVREFS,
                       vp9_sub_mv_ref_tree, cm->fc.sub_mv_ref_counts[i],
-                      cm->fc.pre_sub_mv_ref_prob[i], cm->fc.sub_mv_ref_prob[i]);
+                      cm->fc.pre_sub_mv_ref_prob[i], cm->fc.sub_mv_ref_prob[i],
+                      LEFT4X4);
   }
-  update_mode_probs(VP9_NUMMBSPLITS, vp9_mbsplit_encodings, vp9_mbsplit_tree,
+  update_mode_probs(VP9_NUMMBSPLITS, vp9_mbsplit_tree,
                     cm->fc.mbsplit_counts, cm->fc.pre_mbsplit_prob,
-                    cm->fc.mbsplit_prob);
+                    cm->fc.mbsplit_prob, 0);
 #if CONFIG_COMP_INTERINTRA_PRED
   if (cm->use_interintra) {
     int factor, interintra_prob, count;
index 99e3c2e8c2424658fa8432d70bccff2bec43e540..ab87dfee2e6e0b6b8db20b6eb7beb89745cca643 100644 (file)
@@ -242,29 +242,23 @@ void vp9_counts_to_nmv_context(
     unsigned int (*branch_ct_hp)[2]) {
   int i, j, k;
   vp9_counts_process(NMVcount, usehp);
-  vp9_tree_probs_from_distribution(MV_JOINTS,
-                                   vp9_mv_joint_encodings,
-                                   vp9_mv_joint_tree,
+  vp9_tree_probs_from_distribution(vp9_mv_joint_tree,
                                    prob->joints,
                                    branch_ct_joint,
-                                   NMVcount->joints);
+                                   NMVcount->joints, 0);
   for (i = 0; i < 2; ++i) {
     prob->comps[i].sign = get_binary_prob(NMVcount->comps[i].sign[0],
                                           NMVcount->comps[i].sign[1]);
     branch_ct_sign[i][0] = NMVcount->comps[i].sign[0];
     branch_ct_sign[i][1] = NMVcount->comps[i].sign[1];
-    vp9_tree_probs_from_distribution(MV_CLASSES,
-                                     vp9_mv_class_encodings,
-                                     vp9_mv_class_tree,
+    vp9_tree_probs_from_distribution(vp9_mv_class_tree,
                                      prob->comps[i].classes,
                                      branch_ct_classes[i],
-                                     NMVcount->comps[i].classes);
-    vp9_tree_probs_from_distribution(CLASS0_SIZE,
-                                     vp9_mv_class0_encodings,
-                                     vp9_mv_class0_tree,
+                                     NMVcount->comps[i].classes, 0);
+    vp9_tree_probs_from_distribution(vp9_mv_class0_tree,
                                      prob->comps[i].class0,
                                      branch_ct_class0[i],
-                                     NMVcount->comps[i].class0);
+                                     NMVcount->comps[i].class0, 0);
     for (j = 0; j < MV_OFFSET_BITS; ++j) {
       prob->comps[i].bits[j] = get_binary_prob(NMVcount->comps[i].bits[j][0],
                                                NMVcount->comps[i].bits[j][1]);
@@ -274,19 +268,15 @@ void vp9_counts_to_nmv_context(
   }
   for (i = 0; i < 2; ++i) {
     for (k = 0; k < CLASS0_SIZE; ++k) {
-      vp9_tree_probs_from_distribution(4,
-                                       vp9_mv_fp_encodings,
-                                       vp9_mv_fp_tree,
+      vp9_tree_probs_from_distribution(vp9_mv_fp_tree,
                                        prob->comps[i].class0_fp[k],
                                        branch_ct_class0_fp[i][k],
-                                       NMVcount->comps[i].class0_fp[k]);
+                                       NMVcount->comps[i].class0_fp[k], 0);
     }
-    vp9_tree_probs_from_distribution(4,
-                                     vp9_mv_fp_encodings,
-                                     vp9_mv_fp_tree,
+    vp9_tree_probs_from_distribution(vp9_mv_fp_tree,
                                      prob->comps[i].fp,
                                      branch_ct_fp[i],
-                                     NMVcount->comps[i].fp);
+                                     NMVcount->comps[i].fp, 0);
   }
   if (usehp) {
     for (i = 0; i < 2; ++i) {
index fbc8a38cd41df1ed79697059bd6d4dac90149c0b..6e2597954b454e3cdb516e1f5d4df72a4d06ea1d 100644 (file)
@@ -48,66 +48,37 @@ void vp9_tokens_from_tree_offset(struct vp9_token_struct *p, vp9_tree t,
   tree2tok(p - offset, t, 0, 0, 0);
 }
 
-static void branch_counts(
-  int n,                      /* n = size of alphabet */
-  vp9_token tok               [ /* n */ ],
-  vp9_tree tree,
-  unsigned int branch_ct       [ /* n-1 */ ] [2],
-  const unsigned int num_events[ /* n */ ]
-) {
-  const int tree_len = n - 1;
-  int t = 0;
-
-#if CONFIG_DEBUG
-  assert(tree_len);
-#endif
-
-  do {
-    branch_ct[t][0] = branch_ct[t][1] = 0;
-  } while (++t < tree_len);
-
-  t = 0;
-
-  do {
-    int L = tok[t].Len;
-    const int enc = tok[t].value;
-    const unsigned int ct = num_events[t];
-
-    vp9_tree_index i = 0;
-
-    do {
-      const int b = (enc >> --L) & 1;
-      const int j = i >> 1;
-#if CONFIG_DEBUG
-      assert(j < tree_len  &&  0 <= L);
-#endif
-
-      branch_ct [j] [b] += ct;
-      i = tree[ i + b];
-    } while (i > 0);
-
-#if CONFIG_DEBUG
-    assert(!L);
-#endif
-  } while (++t < n);
-
+static unsigned int convert_distribution(unsigned int i,
+                                         vp9_tree tree,
+                                         vp9_prob probs[],
+                                         unsigned int branch_ct[][2],
+                                         const unsigned int num_events[],
+                                         unsigned int tok0_offset) {
+  unsigned int left, right;
+
+  if (tree[i] <= 0) {
+    left = num_events[-tree[i] - tok0_offset];
+  } else {
+    left = convert_distribution(tree[i], tree, probs, branch_ct,
+                                num_events, tok0_offset);
+  }
+  if (tree[i + 1] <= 0) {
+    right = num_events[-tree[i + 1] - tok0_offset];
+  } else {
+    right = convert_distribution(tree[i + 1], tree, probs, branch_ct,
+                                num_events, tok0_offset);
+  }
+  probs[i>>1] = get_binary_prob(left, right);
+  branch_ct[i>>1][0] = left;
+  branch_ct[i>>1][1] = right;
+  return left + right;
 }
 
-
 void vp9_tree_probs_from_distribution(
-  int n,                      /* n = size of alphabet */
-  vp9_token tok               [ /* n */ ],
   vp9_tree tree,
   vp9_prob probs          [ /* n-1 */ ],
   unsigned int branch_ct       [ /* n-1 */ ] [2],
-  const unsigned int num_events[ /* n */ ]
-) {
-  const int tree_len = n - 1;
-  int t = 0;
-
-  branch_counts(n, tok, tree, branch_ct, num_events);
-
-  do {
-    probs[t] = get_binary_prob(branch_ct[t][0], branch_ct[t][1]);
-  } while (++t < tree_len);
+  const unsigned int num_events[ /* n */ ],
+  unsigned int tok0_offset) {
+  convert_distribution(0, tree, probs, branch_ct, num_events, tok0_offset);
 }
index f9f1d135e542dba3f9ae1d574ff6d7df80fa6b19..9297d5280157691ec8b490508079a44566d511a1 100644 (file)
@@ -47,12 +47,11 @@ void vp9_tokens_from_tree_offset(struct vp9_token_struct *, vp9_tree,
    taken for each node on the tree; this facilitiates decisions as to
    probability updates. */
 
-void vp9_tree_probs_from_distribution(int n,  /* n = size of alphabet */
-                                      vp9_token tok[ /* n */ ],
-                                      vp9_tree tree,
+void vp9_tree_probs_from_distribution(vp9_tree tree,
                                       vp9_prob probs[ /* n - 1 */ ],
                                       unsigned int branch_ct[ /* n - 1 */ ][2],
-                                      const unsigned int num_events[ /* n */ ]);
+                                      const unsigned int num_events[ /* n */ ],
+                                      unsigned int tok0_offset);
 
 static INLINE vp9_prob clip_prob(int p) {
   return (p > 255) ? 255u : (p < 1) ? 1u : p;
index b05da870c8183066e7871ba174ab3cb086986c39..fcbd3a1d6732a33a25d50d279c94a430629d4248 100644 (file)
@@ -110,8 +110,8 @@ static void update_mode(
   unsigned int new_b = 0, old_b = 0;
   int i = 0;
 
-  vp9_tree_probs_from_distribution(n--, tok, tree,
-                                   Pnew, bct, num_events);
+  vp9_tree_probs_from_distribution(tree, Pnew, bct, num_events, 0);
+  n--;
 
   do {
     new_b += cost_branch(bct[i], Pnew[i]);
@@ -167,10 +167,9 @@ static void update_switchable_interp_probs(VP9_COMP *cpi,
   int i, j;
   for (j = 0; j <= VP9_SWITCHABLE_FILTERS; ++j) {
     vp9_tree_probs_from_distribution(
-        VP9_SWITCHABLE_FILTERS,
-        vp9_switchable_interp_encodings, vp9_switchable_interp_tree,
+        vp9_switchable_interp_tree,
         pc->fc.switchable_interp_prob[j], branch_ct,
-        cpi->switchable_interp_count[j]);
+        cpi->switchable_interp_count[j], 0);
     for (i = 0; i < VP9_SWITCHABLE_FILTERS - 1; ++i) {
       if (pc->fc.switchable_interp_prob[j][i] < 1)
         pc->fc.switchable_interp_prob[j][i] = 1;
@@ -1189,11 +1188,10 @@ static void build_tree_distribution(vp9_coeff_probs *coef_probs,
         for (l = 0; l < PREV_COEF_CONTEXTS; ++l) {
           if (l >= 3 && k == 0)
             continue;
-          vp9_tree_probs_from_distribution(MAX_ENTROPY_TOKENS,
-                                           vp9_coef_encodings, vp9_coef_tree,
+          vp9_tree_probs_from_distribution(vp9_coef_tree,
                                            coef_probs[i][j][k][l],
                                            coef_branch_ct[i][j][k][l],
-                                           coef_counts[i][j][k][l]);
+                                           coef_counts[i][j][k][l], 0);
 #ifdef ENTROPY_STATS
         if (!cpi->dummy_packing)
           for (t = 0; t < MAX_ENTROPY_TOKENS; ++t)