]> granicus.if.org Git - libvpx/commitdiff
Use uniform sampling as initial centers for k-means
authorJingning Han <jingning@google.com>
Fri, 12 Apr 2019 22:59:51 +0000 (15:59 -0700)
committerJingning Han <jingning@google.com>
Fri, 12 Apr 2019 22:59:51 +0000 (15:59 -0700)
The Wiener variance output has been sorted prior to the clustering,
which allows to directly use the uniform sampling as the initial
center points. It avoids empty cluster situations when the samples
are heavily distributed at two far ends and leave the middle empty.

Change-Id: I159fbfa6bbb4aafd19411fd005666d144cca30fc

vp9/encoder/vp9_encodeframe.c

index 933cd3cb43e9856b22e03904b5510b60dd45f7b3..d6adce2f648ff0627aa6cef3b6fc1ea7aeb424de 100644 (file)
@@ -5788,13 +5788,11 @@ int vp9_get_group_idx(double value, double *boundary_ls, int k) {
 
 void vp9_kmeans(double *ctr_ls, double *boundary_ls, int *count_ls, int k,
                 KMEANS_DATA *arr, int size) {
-  double min, max;
-  double step;
   int i, j;
   int itr;
   int group_idx;
-  double sum;
-  int count;
+  double sum[MAX_KMEANS_GROUPS];
+  int count[MAX_KMEANS_GROUPS];
 
   vpx_clear_system_state();
 
@@ -5802,38 +5800,37 @@ void vp9_kmeans(double *ctr_ls, double *boundary_ls, int *count_ls, int k,
 
   qsort(arr, size, sizeof(*arr), compare_kmeans_data);
 
-  min = arr[0].value;
-  max = arr[size - 1].value;
-
   // initialize the center points
-  step = (max - min) * 1. / k;
   for (j = 0; j < k; ++j) {
-    ctr_ls[j] = min + j * step + step / 2;
+    ctr_ls[j] = arr[(size * j) / k].value;
   }
 
   for (itr = 0; itr < 10; ++itr) {
     compute_boundary_ls(ctr_ls, k, boundary_ls);
-    group_idx = 0;
-    count = 0;
-    sum = 0;
+    for (i = 0; i < MAX_KMEANS_GROUPS; ++i) {
+      sum[i] = 0;
+      count[i] = 0;
+    }
+
     for (i = 0; i < size; ++i) {
+      // place samples into clusters
+      group_idx = 0;
       while (arr[i].value >= boundary_ls[group_idx]) {
         ++group_idx;
         if (group_idx == k - 1) {
           break;
         }
       }
+      sum[group_idx] += arr[i].value;
+      ++count[group_idx];
+    }
 
-      sum += arr[i].value;
-      ++count;
+    for (group_idx = 0; group_idx < k; ++group_idx) {
+      if (count[group_idx] > 0)
+        ctr_ls[group_idx] = sum[group_idx] / count[group_idx];
 
-      if (i + 1 == size || arr[i + 1].value >= boundary_ls[group_idx]) {
-        if (count > 0) {
-          ctr_ls[group_idx] = sum / count;
-        }
-        count = 0;
-        sum = 0;
-      }
+      sum[group_idx] = 0;
+      count[group_idx] = 0;
     }
   }