]> granicus.if.org Git - libvpx/commitdiff
Simplify vp9_adapt_nmv_probs
authorJohn Koleszar <jkoleszar@google.com>
Mon, 11 Mar 2013 16:34:05 +0000 (09:34 -0700)
committerJohn Koleszar <jkoleszar@google.com>
Mon, 11 Mar 2013 16:44:22 +0000 (09:44 -0700)
Remove the temporary branch count arrays and build the adapted probabilities
while walking the tree. Gives an additional 1.5% or so on CIF.

Change-Id: I875d61e5e0ec778e5d2f7f9d0837b989a91cf3a3

vp9/common/vp9_entropymv.c

index ab87dfee2e6e0b6b8db20b6eb7beb89745cca643..56cebff8f80d9a845f1ea777496a1b30321adb34 100644 (file)
@@ -211,14 +211,16 @@ void vp9_increment_nmv(const MV *mv, const MV *ref, nmv_context_counts *mvctx,
   }
 }
 
-static void adapt_prob(vp9_prob *dest, vp9_prob prep, vp9_prob newp,
+static void adapt_prob(vp9_prob *dest, vp9_prob prep,
                        unsigned int ct[2]) {
   int count = ct[0] + ct[1];
-
   if (count) {
+    vp9_prob newp = get_binary_prob(ct[0], ct[1]);
     count = count > MV_COUNT_SAT ? MV_COUNT_SAT : count;
     *dest = weighted_prob(prep, newp,
                           MV_MAX_UPDATE_FACTOR * count / MV_COUNT_SAT);
+  } else {
+    *dest = prep;
   }
 }
 
@@ -294,18 +296,43 @@ void vp9_counts_to_nmv_context(
   }
 }
 
+static unsigned int adapt_probs(unsigned int i,
+                                vp9_tree tree,
+                                vp9_prob this_probs[],
+                                const vp9_prob last_probs[],
+                                const unsigned int num_events[]) {
+  unsigned int left, right, weight;
+  vp9_prob this_prob;
+
+  if (tree[i] <= 0) {
+    left = num_events[-tree[i]];
+  } else {
+    left = adapt_probs(tree[i], tree, this_probs, last_probs,
+                       num_events);
+  }
+  if (tree[i + 1] <= 0) {
+    right = num_events[-tree[i + 1]];
+  } else {
+    right = adapt_probs(tree[i + 1], tree, this_probs, last_probs,
+                        num_events);
+  }
+
+  weight = left + right;
+  if (weight) {
+    this_prob = get_binary_prob(left, right);
+    weight = weight > MV_COUNT_SAT ? MV_COUNT_SAT : weight;
+    this_prob = weighted_prob(last_probs[i>>1], this_prob,
+                              MV_MAX_UPDATE_FACTOR * weight / MV_COUNT_SAT);
+  } else {
+    this_prob = last_probs[i >> 1];
+  }
+  this_probs[i>>1] = this_prob;
+  return left + right;
+}
+
+
 void vp9_adapt_nmv_probs(VP9_COMMON *cm, int usehp) {
-  int i, j, k;
-  nmv_context prob;
-  unsigned int branch_ct_joint[MV_JOINTS - 1][2];
-  unsigned int branch_ct_sign[2][2];
-  unsigned int branch_ct_classes[2][MV_CLASSES - 1][2];
-  unsigned int branch_ct_class0[2][CLASS0_SIZE - 1][2];
-  unsigned int branch_ct_bits[2][MV_OFFSET_BITS][2];
-  unsigned int branch_ct_class0_fp[2][CLASS0_SIZE][4 - 1][2];
-  unsigned int branch_ct_fp[2][4 - 1][2];
-  unsigned int branch_ct_class0_hp[2][2];
-  unsigned int branch_ct_hp[2][2];
+  int i, j;
 #ifdef MV_COUNT_TESTING
   printf("joints count: ");
   for (j = 0; j < MV_JOINTS; ++j) printf("%d ", cm->fc.NMVcount.joints[j]);
@@ -366,75 +393,48 @@ void vp9_adapt_nmv_probs(VP9_COMMON *cm, int usehp) {
   smooth_counts(&cm->fc.NMVcount.comps[0]);
   smooth_counts(&cm->fc.NMVcount.comps[1]);
 #endif
-  vp9_counts_to_nmv_context(&cm->fc.NMVcount,
-                            &prob,
-                            usehp,
-                            branch_ct_joint,
-                            branch_ct_sign,
-                            branch_ct_classes,
-                            branch_ct_class0,
-                            branch_ct_bits,
-                            branch_ct_class0_fp,
-                            branch_ct_fp,
-                            branch_ct_class0_hp,
-                            branch_ct_hp);
-
-  for (j = 0; j < MV_JOINTS - 1; ++j) {
-    adapt_prob(&cm->fc.nmvc.joints[j],
-               cm->fc.pre_nmvc.joints[j],
-               prob.joints[j],
-               branch_ct_joint[j]);
-  }
+  vp9_counts_process(&cm->fc.NMVcount, usehp);
+
+  adapt_probs(0, vp9_mv_joint_tree,
+              cm->fc.nmvc.joints, cm->fc.pre_nmvc.joints,
+              cm->fc.NMVcount.joints);
+
   for (i = 0; i < 2; ++i) {
     adapt_prob(&cm->fc.nmvc.comps[i].sign,
                cm->fc.pre_nmvc.comps[i].sign,
-               prob.comps[i].sign,
-               branch_ct_sign[i]);
-    for (j = 0; j < MV_CLASSES - 1; ++j) {
-      adapt_prob(&cm->fc.nmvc.comps[i].classes[j],
-                 cm->fc.pre_nmvc.comps[i].classes[j],
-                 prob.comps[i].classes[j],
-                 branch_ct_classes[i][j]);
-    }
-    for (j = 0; j < CLASS0_SIZE - 1; ++j) {
-      adapt_prob(&cm->fc.nmvc.comps[i].class0[j],
-                 cm->fc.pre_nmvc.comps[i].class0[j],
-                 prob.comps[i].class0[j],
-                 branch_ct_class0[i][j]);
-    }
+               cm->fc.NMVcount.comps[i].sign);
+    adapt_probs(0, vp9_mv_class_tree,
+                cm->fc.nmvc.comps[i].classes, cm->fc.pre_nmvc.comps[i].classes,
+                cm->fc.NMVcount.comps[i].classes);
+    adapt_probs(0, vp9_mv_class0_tree,
+                cm->fc.nmvc.comps[i].class0, cm->fc.pre_nmvc.comps[i].class0,
+                cm->fc.NMVcount.comps[i].class0);
     for (j = 0; j < MV_OFFSET_BITS; ++j) {
       adapt_prob(&cm->fc.nmvc.comps[i].bits[j],
                  cm->fc.pre_nmvc.comps[i].bits[j],
-                 prob.comps[i].bits[j],
-                 branch_ct_bits[i][j]);
+                 cm->fc.NMVcount.comps[i].bits[j]);
     }
   }
   for (i = 0; i < 2; ++i) {
     for (j = 0; j < CLASS0_SIZE; ++j) {
-      for (k = 0; k < 3; ++k) {
-        adapt_prob(&cm->fc.nmvc.comps[i].class0_fp[j][k],
-                   cm->fc.pre_nmvc.comps[i].class0_fp[j][k],
-                   prob.comps[i].class0_fp[j][k],
-                   branch_ct_class0_fp[i][j][k]);
-      }
-    }
-    for (j = 0; j < 3; ++j) {
-      adapt_prob(&cm->fc.nmvc.comps[i].fp[j],
-                 cm->fc.pre_nmvc.comps[i].fp[j],
-                 prob.comps[i].fp[j],
-                 branch_ct_fp[i][j]);
+      adapt_probs(0, vp9_mv_fp_tree,
+                  cm->fc.nmvc.comps[i].class0_fp[j],
+                  cm->fc.pre_nmvc.comps[i].class0_fp[j],
+                  cm->fc.NMVcount.comps[i].class0_fp[j]);
     }
+    adapt_probs(0, vp9_mv_fp_tree,
+                cm->fc.nmvc.comps[i].fp,
+                cm->fc.pre_nmvc.comps[i].fp,
+                cm->fc.NMVcount.comps[i].fp);
   }
   if (usehp) {
     for (i = 0; i < 2; ++i) {
       adapt_prob(&cm->fc.nmvc.comps[i].class0_hp,
                  cm->fc.pre_nmvc.comps[i].class0_hp,
-                 prob.comps[i].class0_hp,
-                 branch_ct_class0_hp[i]);
+                 cm->fc.NMVcount.comps[i].class0_hp);
       adapt_prob(&cm->fc.nmvc.comps[i].hp,
                  cm->fc.pre_nmvc.comps[i].hp,
-                 prob.comps[i].hp,
-                 branch_ct_hp[i]);
+                 cm->fc.NMVcount.comps[i].hp);
     }
   }
 }