]> granicus.if.org Git - libvpx/commitdiff
Enable tile-adaptive restoration
authorDebargha Mukherjee <debargha@google.com>
Thu, 8 Sep 2016 22:15:17 +0000 (15:15 -0700)
committerDebargha Mukherjee <debargha@google.com>
Sat, 17 Sep 2016 16:46:28 +0000 (09:46 -0700)
Includes a major refactoring/enhancement to support
tile-adaptive switchable restoration. The framework can be
readily extended to add more restoration schemes in the
future. Also includes various cleanups and fixes.

Specifically the framework allows restoration to be conducted
on tiles such that each tile can be either left unrestored, or
use bilateral or wiener filtering.

There is a modest improvemnt in coding efficiency (0.1 - 0.2%).

Further enhancements will be added subsequently to improve coding
efficiency and complexity.

Change-Id: I5ebedb04785ce1ef6f324abe209e925c2d6cbe8a

16 files changed:
aom_dsp/psnr.c
aom_dsp/psnr.h
av1/common/alloccommon.c
av1/common/entropymode.c
av1/common/entropymode.h
av1/common/enums.h
av1/common/loopfilter.c
av1/common/loopfilter.h
av1/common/onyxc_int.h
av1/common/restoration.c
av1/common/restoration.h
av1/decoder/decodeframe.c
av1/encoder/bitstream.c
av1/encoder/encoder.h
av1/encoder/pickrst.c
av1/encoder/rd.c

index e69ffb442c2eadc03433b1ec5ccb3402b085502d..db789a33e7d80193af8f65645f57c72d74e24d21 100644 (file)
@@ -177,6 +177,14 @@ static int64_t highbd_get_sse(const uint8_t *a, int a_stride, const uint8_t *b,
 }
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 
+int64_t aom_get_y_sse_part(const YV12_BUFFER_CONFIG *a,
+                           const YV12_BUFFER_CONFIG *b,
+                           int hstart, int width, int vstart, int height) {
+  return get_sse(a->y_buffer + vstart * a->y_stride + hstart, a->y_stride,
+                 b->y_buffer + vstart * b->y_stride + hstart, b->y_stride,
+                 width, height);
+}
+
 int64_t aom_get_y_sse(const YV12_BUFFER_CONFIG *a,
                       const YV12_BUFFER_CONFIG *b) {
   assert(a->y_crop_width == b->y_crop_width);
@@ -205,6 +213,16 @@ int64_t aom_get_v_sse(const YV12_BUFFER_CONFIG *a,
 }
 
 #if CONFIG_AOM_HIGHBITDEPTH
+int64_t aom_highbd_get_y_sse_part(const YV12_BUFFER_CONFIG *a,
+                                  const YV12_BUFFER_CONFIG *b,
+                                  int hstart, int width,
+                                  int vstart, int height) {
+  return highbd_get_sse(
+      a->y_buffer + vstart * a->y_stride + hstart, a->y_stride,
+      b->y_buffer + vstart * b->y_stride + hstart, b->y_stride,
+      width, height);
+}
+
 int64_t aom_highbd_get_y_sse(const YV12_BUFFER_CONFIG *a,
                              const YV12_BUFFER_CONFIG *b) {
   assert(a->y_crop_width == b->y_crop_width);
index 29ccb5f0537668ca71b3d7f4b5fb9326c0bd06ec..303573bbd9a61f257a9e3201fb15082a6be5f97e 100644 (file)
@@ -35,10 +35,17 @@ typedef struct {
 * \param[in]    sse           Sum of squared errors
 */
 double aom_sse_to_psnr(double samples, double peak, double sse);
+int64_t aom_get_y_sse_part(const YV12_BUFFER_CONFIG *a,
+                           const YV12_BUFFER_CONFIG *b,
+                           int hstart, int width, int vstart, int height);
 int64_t aom_get_y_sse(const YV12_BUFFER_CONFIG *a, const YV12_BUFFER_CONFIG *b);
 int64_t aom_get_u_sse(const YV12_BUFFER_CONFIG *a, const YV12_BUFFER_CONFIG *b);
 int64_t aom_get_v_sse(const YV12_BUFFER_CONFIG *a, const YV12_BUFFER_CONFIG *b);
 #if CONFIG_AOM_HIGHBITDEPTH
+int64_t aom_highbd_get_y_sse_part(const YV12_BUFFER_CONFIG *a,
+                                  const YV12_BUFFER_CONFIG *b,
+                                  int hstart, int width,
+                                  int vstart, int height);
 int64_t aom_highbd_get_y_sse(const YV12_BUFFER_CONFIG *a,
                              const YV12_BUFFER_CONFIG *b);
 int64_t aom_highbd_get_u_sse(const YV12_BUFFER_CONFIG *a,
index dab01168f70fafd24ee412471decd8d223cbe2a4..db4fbf71e76c775171ebc4057e2931220648629b 100644 (file)
@@ -83,6 +83,8 @@ void av1_free_ref_frame_buffers(BufferPool *pool) {
 
 #if CONFIG_LOOP_RESTORATION
 void av1_free_restoration_buffers(AV1_COMMON *cm) {
+  aom_free(cm->rst_info.restoration_type);
+  cm->rst_info.restoration_type = NULL;
   aom_free(cm->rst_info.bilateral_level);
   cm->rst_info.bilateral_level = NULL;
   aom_free(cm->rst_info.vfilter);
index e1593e30e8b8e223e918c6ca2759f0f984938943..856fa356150ae1084fb517f14d34d3b93258fba6 100644 (file)
@@ -857,6 +857,17 @@ static const aom_prob
       },
     };
 
+#if CONFIG_LOOP_RESTORATION
+const aom_tree_index
+    av1_switchable_restore_tree[TREE_SIZE(RESTORE_SWITCHABLE_TYPES)] = {
+      -RESTORE_NONE, 2,
+      -RESTORE_BILATERAL, -RESTORE_WIENER,
+    };
+
+static const aom_prob
+    default_switchable_restore_prob[RESTORE_SWITCHABLE_TYPES - 1] = {32, 128};
+#endif  // CONFIG_LOOP_RESTORATION
+
 #if CONFIG_EXT_TX && CONFIG_RECT_TX && CONFIG_VAR_TX
 // the probability of (0) using recursive square tx partition vs.
 // (1) biggest rect tx for 4X8-8X4/8X16-16X8/16X32-32X16 blocks
@@ -1340,6 +1351,9 @@ static void init_mode_probs(FRAME_CONTEXT *fc) {
 #endif  // CONFIG_EXT_INTRA
   av1_copy(fc->inter_ext_tx_prob, default_inter_ext_tx_prob);
   av1_copy(fc->intra_ext_tx_prob, default_intra_ext_tx_prob);
+#if CONFIG_LOOP_RESTORATION
+  av1_copy(fc->switchable_restore_prob, default_switchable_restore_prob);
+#endif  // CONFIG_LOOP_RESTORATION
 }
 
 #if CONFIG_EXT_INTERP
index 7968484fbb910cefbcd0dd79d38d9444b8863452..c389e182dbadaa7cf67d1a04077e5ea2718da3cf 100644 (file)
@@ -128,6 +128,9 @@ typedef struct frame_contexts {
 #if CONFIG_GLOBAL_MOTION
   aom_prob global_motion_types_prob[GLOBAL_MOTION_TYPES - 1];
 #endif  // CONFIG_GLOBAL_MOTION
+#if CONFIG_LOOP_RESTORATION
+  aom_prob switchable_restore_prob[RESTORE_SWITCHABLE_TYPES - 1];
+#endif  // CONFIG_LOOP_RESTORATION
 } FRAME_CONTEXT;
 
 typedef struct FRAME_COUNTS {
@@ -263,6 +266,13 @@ extern const aom_tree_index av1_ext_tx_tree[TREE_SIZE(TX_TYPES)];
 extern const aom_tree_index av1_motvar_tree[TREE_SIZE(MOTION_VARIATIONS)];
 #endif  // CONFIG_OBMC || CONFIG_WARPED_MOTION
 
+#if CONFIG_LOOP_RESTORATION
+#define RESTORE_NONE_BILATERAL_PROB 16
+#define RESTORE_NONE_WIENER_PROB 64
+extern const aom_tree_index
+    av1_switchable_restore_tree[TREE_SIZE(RESTORE_SWITCHABLE_TYPES)];
+#endif  // CONFIG_LOOP_RESTORATION
+
 void av1_setup_past_independence(struct AV1Common *cm);
 
 void av1_adapt_intra_frame_probs(struct AV1Common *cm);
index b1ac2a01acb2c60b371a58370b3e8c4bdbd89530..c9d321139da37241a1b63c791713f82255ab889f 100644 (file)
@@ -433,6 +433,16 @@ typedef TX_SIZE TXFM_CONTEXT;
 #define MAX_SUPERTX_BLOCK_SIZE BLOCK_32X32
 #endif  // CONFIG_SUPERTX
 
+#if CONFIG_LOOP_RESTORATION
+typedef enum {
+  RESTORE_NONE,
+  RESTORE_BILATERAL,
+  RESTORE_WIENER,
+  RESTORE_SWITCHABLE,
+  RESTORE_SWITCHABLE_TYPES = RESTORE_SWITCHABLE,
+  RESTORE_TYPES,
+} RestorationType;
+#endif  // CONFIG_LOOP_RESTORATION
 #ifdef __cplusplus
 }  // extern "C"
 #endif
index 2147bb8267e3d7bb6ab69532201b53e786db52ad..f45f3db2921cfe658527944d44048fc5edd79ba1 100644 (file)
@@ -16,7 +16,6 @@
 #include "av1/common/loopfilter.h"
 #include "av1/common/onyxc_int.h"
 #include "av1/common/reconinter.h"
-#include "av1/common/restoration.h"
 #include "aom_dsp/aom_dsp_common.h"
 #include "aom_mem/aom_mem.h"
 #include "aom_ports/mem.h"
index ae0ef8a34aa220d664b92df9d0b1ff41b0c13831..975cbdf197f922195bebfc980592c9b6d869b378 100644 (file)
@@ -16,7 +16,6 @@
 #include "./aom_config.h"
 
 #include "av1/common/blockd.h"
-#include "av1/common/restoration.h"
 #include "av1/common/seg_common.h"
 
 #ifdef __cplusplus
index a14b34f5ce47c4203785af49dfadac889ff93152..6cd6cbeaf6b5e2dbfd17c0f8961bd4dc9f52cbfc 100644 (file)
@@ -25,7 +25,9 @@
 #include "av1/common/frame_buffers.h"
 #include "av1/common/quant_common.h"
 #include "av1/common/tile_common.h"
+#if CONFIG_LOOP_RESTORATION
 #include "av1/common/restoration.h"
+#endif  // CONFIG_LOOP_RESTORATION
 
 #ifdef __cplusplus
 extern "C" {
index d50181ed10a2027954d08a67a269917e3551aae3..4f44e126ff23ddaee7e751384b8c45de88ba47b9 100644 (file)
@@ -70,36 +70,6 @@ static INLINE BilateralParamsType av1_bilateral_level_to_params(int index,
             : bilateral_level_to_params_arr[index];
 }
 
-typedef struct TileParams {
-  int width;
-  int height;
-} TileParams;
-
-static TileParams restoration_tile_sizes[RESTORATION_TILESIZES] = {
-  { 64, 64 }, { 128, 128 }, { 256, 256 }
-};
-
-void av1_get_restoration_tile_size(int tilesize, int width, int height,
-                                   int *tile_width, int *tile_height,
-                                   int *nhtiles, int *nvtiles) {
-  *tile_width = (tilesize < 0)
-                    ? width
-                    : AOMMIN(restoration_tile_sizes[tilesize].width, width);
-  *tile_height = (tilesize < 0)
-                     ? height
-                     : AOMMIN(restoration_tile_sizes[tilesize].height, height);
-  *nhtiles = (width + (*tile_width >> 1)) / *tile_width;
-  *nvtiles = (height + (*tile_height >> 1)) / *tile_height;
-}
-
-int av1_get_restoration_ntiles(int tilesize, int width, int height) {
-  int nhtiles, nvtiles;
-  int tile_width, tile_height;
-  av1_get_restoration_tile_size(tilesize, width, height, &tile_width,
-                                &tile_height, &nhtiles, &nvtiles);
-  return (nhtiles * nvtiles);
-}
-
 void av1_loop_restoration_precal() {
   int i;
   for (i = 0; i < BILATERAL_LEVELS_KF; i++) {
@@ -169,90 +139,75 @@ int av1_bilateral_level_bits(const AV1_COMMON *const cm) {
 void av1_loop_restoration_init(RestorationInternal *rst, RestorationInfo *rsi,
                                int kf, int width, int height) {
   int i, tile_idx;
-  rst->restoration_type = rsi->restoration_type;
+  rst->rsi = rsi;
+  rst->keyframe = kf;
   rst->subsampling_x = 0;
   rst->subsampling_y = 0;
-  if (rsi->restoration_type == RESTORE_BILATERAL) {
-    rst->tilesize_index = BILATERAL_TILESIZE;
-    rst->ntiles =
-        av1_get_restoration_ntiles(rst->tilesize_index, width, height);
-    av1_get_restoration_tile_size(rst->tilesize_index, width, height,
-                                  &rst->tile_width, &rst->tile_height,
-                                  &rst->nhtiles, &rst->nvtiles);
-    rst->bilateral_level = rsi->bilateral_level;
-    rst->wr_lut = (uint8_t **)malloc(sizeof(*rst->wr_lut) * rst->ntiles);
-    assert(rst->wr_lut != NULL);
-    rst->wx_lut = (uint8_t(**)[RESTORATION_WIN])malloc(sizeof(*rst->wx_lut) *
-                                                       rst->ntiles);
-    assert(rst->wx_lut != NULL);
-    for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
-      const int level = rsi->bilateral_level[tile_idx];
-      if (level >= 0) {
-        rst->wr_lut[tile_idx] = kf ? bilateral_filter_coeffs_r_kf[level]
-                                   : bilateral_filter_coeffs_r[level];
-        rst->wx_lut[tile_idx] = kf ? bilateral_filter_coeffs_s_kf[level]
-                                   : bilateral_filter_coeffs_s[level];
-      }
-    }
-  } else if (rsi->restoration_type == RESTORE_WIENER) {
-    rst->tilesize_index = WIENER_TILESIZE;
-    rst->ntiles =
-        av1_get_restoration_ntiles(rst->tilesize_index, width, height);
-    av1_get_restoration_tile_size(rst->tilesize_index, width, height,
-                                  &rst->tile_width, &rst->tile_height,
-                                  &rst->nhtiles, &rst->nvtiles);
-    rst->wiener_level = rsi->wiener_level;
-    rst->vfilter =
-        (int(*)[RESTORATION_WIN])malloc(sizeof(*rst->vfilter) * rst->ntiles);
-    assert(rst->vfilter != NULL);
-    rst->hfilter =
-        (int(*)[RESTORATION_WIN])malloc(sizeof(*rst->hfilter) * rst->ntiles);
-    assert(rst->hfilter != NULL);
+  rst->ntiles =
+      av1_get_rest_ntiles(width, height, &rst->tile_width,
+                          &rst->tile_height, &rst->nhtiles, &rst->nvtiles);
+  if (rsi->frame_restoration_type == RESTORE_WIENER) {
     for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
-      rst->vfilter[tile_idx][RESTORATION_HALFWIN] =
-          rst->hfilter[tile_idx][RESTORATION_HALFWIN] = RESTORATION_FILT_STEP;
+      rsi->vfilter[tile_idx][RESTORATION_HALFWIN] =
+          rsi->hfilter[tile_idx][RESTORATION_HALFWIN] = RESTORATION_FILT_STEP;
       for (i = 0; i < RESTORATION_HALFWIN; ++i) {
-        rst->vfilter[tile_idx][i] =
-            rst->vfilter[tile_idx][RESTORATION_WIN - 1 - i] =
-                rsi->vfilter[tile_idx][i];
-        rst->hfilter[tile_idx][i] =
-            rst->hfilter[tile_idx][RESTORATION_WIN - 1 - i] =
-                rsi->hfilter[tile_idx][i];
-        rst->vfilter[tile_idx][RESTORATION_HALFWIN] -=
+        rsi->vfilter[tile_idx][RESTORATION_WIN - 1 - i] =
+            rsi->vfilter[tile_idx][i];
+        rsi->hfilter[tile_idx][RESTORATION_WIN - 1 - i] =
+            rsi->hfilter[tile_idx][i];
+        rsi->vfilter[tile_idx][RESTORATION_HALFWIN] -=
             2 * rsi->vfilter[tile_idx][i];
-        rst->hfilter[tile_idx][RESTORATION_HALFWIN] -=
+        rsi->hfilter[tile_idx][RESTORATION_HALFWIN] -=
             2 * rsi->hfilter[tile_idx][i];
       }
     }
+  } else if (rsi->frame_restoration_type == RESTORE_SWITCHABLE) {
+    for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
+      if (rsi->restoration_type[tile_idx] == RESTORE_WIENER) {
+        rsi->vfilter[tile_idx][RESTORATION_HALFWIN] =
+            rsi->hfilter[tile_idx][RESTORATION_HALFWIN] = RESTORATION_FILT_STEP;
+        for (i = 0; i < RESTORATION_HALFWIN; ++i) {
+          rsi->vfilter[tile_idx][RESTORATION_WIN - 1 - i] =
+              rsi->vfilter[tile_idx][i];
+          rsi->hfilter[tile_idx][RESTORATION_WIN - 1 - i] =
+              rsi->hfilter[tile_idx][i];
+          rsi->vfilter[tile_idx][RESTORATION_HALFWIN] -=
+              2 * rsi->vfilter[tile_idx][i];
+          rsi->hfilter[tile_idx][RESTORATION_HALFWIN] -=
+              2 * rsi->hfilter[tile_idx][i];
+        }
+      }
+    }
   }
 }
 
-static void loop_bilateral_filter(uint8_t *data, int width, int height,
-                                  int stride, RestorationInternal *rst,
-                                  uint8_t *tmpdata, int tmpstride) {
-  int i, j, tile_idx, htile_idx, vtile_idx;
+static void loop_bilateral_filter_tile(uint8_t *data, int tile_idx, int width,
+                                       int height, int stride,
+                                       RestorationInternal *rst,
+                                       uint8_t *tmpdata, int tmpstride) {
+  int i, j, subtile_idx;
   int h_start, h_end, v_start, v_end;
-  int tile_width, tile_height;
-
-  tile_width = rst->tile_width >> rst->subsampling_x;
-  tile_height = rst->tile_height >> rst->subsampling_y;
+  const int tile_width = rst->tile_width >> rst->subsampling_x;
+  const int tile_height = rst->tile_height >> rst->subsampling_y;
 
-  for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
+  for (subtile_idx = 0; subtile_idx < BILATERAL_SUBTILES; ++subtile_idx) {
     uint8_t *data_p, *tmpdata_p;
-    const uint8_t *wr_lut_ = rst->wr_lut[tile_idx] + BILATERAL_AMP_RANGE;
-
-    if (rst->bilateral_level[tile_idx] < 0) continue;
-
-    htile_idx = tile_idx % rst->nhtiles;
-    vtile_idx = tile_idx / rst->nhtiles;
-    h_start =
-        htile_idx * tile_width + ((htile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    h_end = (htile_idx < rst->nhtiles - 1) ? ((htile_idx + 1) * tile_width)
-                                           : (width - RESTORATION_HALFWIN);
-    v_start =
-        vtile_idx * tile_height + ((vtile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    v_end = (vtile_idx < rst->nvtiles - 1) ? ((vtile_idx + 1) * tile_height)
-                                           : (height - RESTORATION_HALFWIN);
+    const int level =
+        rst->rsi->bilateral_level[tile_idx * BILATERAL_SUBTILES + subtile_idx];
+    uint8_t(*wx_lut)[RESTORATION_WIN];
+    uint8_t *wr_lut_;
+
+    if (level < 0) continue;
+    wr_lut_ = (rst->keyframe ? bilateral_filter_coeffs_r_kf[level]
+                             : bilateral_filter_coeffs_r[level]) +
+              BILATERAL_AMP_RANGE;
+    wx_lut = rst->keyframe ? bilateral_filter_coeffs_s_kf[level]
+                           : bilateral_filter_coeffs_s[level];
+
+    av1_get_rest_tile_limits(tile_idx, subtile_idx, BILATERAL_SUBTILE_BITS,
+                             rst->nhtiles, rst->nvtiles, tile_width,
+                             tile_height, width, height, 1, 1, &h_start, &h_end,
+                             &v_start, &v_end);
 
     data_p = data + h_start + v_start * stride;
     tmpdata_p = tmpdata + h_start + v_start * tmpstride;
@@ -264,8 +219,7 @@ static void loop_bilateral_filter(uint8_t *data, int width, int height,
         uint8_t *data_p2 = data_p + j - RESTORATION_HALFWIN * stride;
         for (y = -RESTORATION_HALFWIN; y <= RESTORATION_HALFWIN; ++y) {
           for (x = -RESTORATION_HALFWIN; x <= RESTORATION_HALFWIN; ++x) {
-            wt = (int)rst->wx_lut[tile_idx][y + RESTORATION_HALFWIN]
-                                 [x + RESTORATION_HALFWIN] *
+            wt = (int)wx_lut[y + RESTORATION_HALFWIN][x + RESTORATION_HALFWIN] *
                  (int)wr_lut_[data_p2[x] - data_p[j]];
             wtsum += wt;
             flsum += wt * data_p2[x];
@@ -287,6 +241,16 @@ static void loop_bilateral_filter(uint8_t *data, int width, int height,
   }
 }
 
+static void loop_bilateral_filter(uint8_t *data, int width, int height,
+                                  int stride, RestorationInternal *rst,
+                                  uint8_t *tmpdata, int tmpstride) {
+  int tile_idx;
+  for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
+    loop_bilateral_filter_tile(data, tile_idx, width, height, stride, rst,
+                               tmpdata, tmpstride);
+  }
+}
+
 uint8_t hor_sym_filter(uint8_t *d, int *hfilter) {
   int32_t s =
       (1 << (RESTORATION_FILT_BITS - 1)) + d[0] * hfilter[RESTORATION_HALFWIN];
@@ -305,17 +269,52 @@ uint8_t ver_sym_filter(uint8_t *d, int stride, int *vfilter) {
   return clip_pixel(s >> RESTORATION_FILT_BITS);
 }
 
+static void loop_wiener_filter_tile(uint8_t *data, int tile_idx, int width,
+                                    int height, int stride,
+                                    RestorationInternal *rst, uint8_t *tmpdata,
+                                    int tmpstride) {
+  const int tile_width = rst->tile_width >> rst->subsampling_x;
+  const int tile_height = rst->tile_height >> rst->subsampling_y;
+  int i, j;
+  int h_start, h_end, v_start, v_end;
+  uint8_t *data_p, *tmpdata_p;
+
+  if (rst->rsi->wiener_level[tile_idx] == 0) return;
+  // Filter row-wise
+  av1_get_rest_tile_limits(tile_idx, 0, 0, rst->nhtiles, rst->nvtiles,
+                           tile_width, tile_height, width, height, 1, 0,
+                           &h_start, &h_end, &v_start, &v_end);
+  data_p = data + h_start + v_start * stride;
+  tmpdata_p = tmpdata + h_start + v_start * tmpstride;
+  for (i = 0; i < (v_end - v_start); ++i) {
+    for (j = 0; j < (h_end - h_start); ++j) {
+      *tmpdata_p++ = hor_sym_filter(data_p++, rst->rsi->hfilter[tile_idx]);
+    }
+    data_p += stride - (h_end - h_start);
+    tmpdata_p += tmpstride - (h_end - h_start);
+  }
+  // Filter col-wise
+  av1_get_rest_tile_limits(tile_idx, 0, 0, rst->nhtiles, rst->nvtiles,
+                           tile_width, tile_height, width, height, 0, 1,
+                           &h_start, &h_end, &v_start, &v_end);
+  data_p = data + h_start + v_start * stride;
+  tmpdata_p = tmpdata + h_start + v_start * tmpstride;
+  for (i = 0; i < (v_end - v_start); ++i) {
+    for (j = 0; j < (h_end - h_start); ++j) {
+      *data_p++ =
+          ver_sym_filter(tmpdata_p++, tmpstride, rst->rsi->vfilter[tile_idx]);
+    }
+    data_p += stride - (h_end - h_start);
+    tmpdata_p += tmpstride - (h_end - h_start);
+  }
+}
+
 static void loop_wiener_filter(uint8_t *data, int width, int height, int stride,
                                RestorationInternal *rst, uint8_t *tmpdata,
                                int tmpstride) {
-  int i, j, tile_idx, htile_idx, vtile_idx;
-  int h_start, h_end, v_start, v_end;
-  int tile_width, tile_height;
+  int i, tile_idx;
   uint8_t *data_p, *tmpdata_p;
 
-  tile_width = rst->tile_width >> rst->subsampling_x;
-  tile_height = rst->tile_height >> rst->subsampling_y;
-
   // Initialize tmp buffer
   data_p = data;
   tmpdata_p = tmpdata;
@@ -324,88 +323,65 @@ static void loop_wiener_filter(uint8_t *data, int width, int height, int stride,
     data_p += stride;
     tmpdata_p += tmpstride;
   }
-
-  // Filter row-wise tile-by-tile
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
-    if (rst->wiener_level[tile_idx] == 0) continue;
-    htile_idx = tile_idx % rst->nhtiles;
-    vtile_idx = tile_idx / rst->nhtiles;
-    h_start =
-        htile_idx * tile_width + ((htile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    h_end = (htile_idx < rst->nhtiles - 1) ? ((htile_idx + 1) * tile_width)
-                                           : (width - RESTORATION_HALFWIN);
-    v_start = vtile_idx * tile_height;
-    v_end = (vtile_idx < rst->nvtiles - 1) ? ((vtile_idx + 1) * tile_height)
-                                           : height;
-    data_p = data + h_start + v_start * stride;
-    tmpdata_p = tmpdata + h_start + v_start * tmpstride;
-    for (i = 0; i < (v_end - v_start); ++i) {
-      for (j = 0; j < (h_end - h_start); ++j) {
-        *tmpdata_p++ = hor_sym_filter(data_p++, rst->hfilter[tile_idx]);
-      }
-      data_p += stride - (h_end - h_start);
-      tmpdata_p += tmpstride - (h_end - h_start);
-    }
+    loop_wiener_filter_tile(data, tile_idx, width, height, stride, rst, tmpdata,
+                            tmpstride);
   }
+}
 
-  // Filter column-wise tile-by-tile (bands of thickness RESTORATION_HALFWIN
-  // at top and bottom of tiles allow filtering overlap, and are not optimally
-  // filtered)
+static void loop_switchable_filter(uint8_t *data, int width, int height,
+                                   int stride, RestorationInternal *rst,
+                                   uint8_t *tmpdata, int tmpstride) {
+  int i, tile_idx;
+  uint8_t *data_p, *tmpdata_p;
+
+  // Initialize tmp buffer
+  data_p = data;
+  tmpdata_p = tmpdata;
+  for (i = 0; i < height; ++i) {
+    memcpy(tmpdata_p, data_p, sizeof(*data_p) * width);
+    data_p += stride;
+    tmpdata_p += tmpstride;
+  }
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
-    if (rst->wiener_level[tile_idx] == 0) continue;
-    htile_idx = tile_idx % rst->nhtiles;
-    vtile_idx = tile_idx / rst->nhtiles;
-    h_start = htile_idx * tile_width;
-    h_end =
-        (htile_idx < rst->nhtiles - 1) ? ((htile_idx + 1) * tile_width) : width;
-    v_start =
-        vtile_idx * tile_height + ((vtile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    v_end = (vtile_idx < rst->nvtiles - 1) ? ((vtile_idx + 1) * tile_height)
-                                           : (height - RESTORATION_HALFWIN);
-    data_p = data + h_start + v_start * stride;
-    tmpdata_p = tmpdata + h_start + v_start * tmpstride;
-    for (i = 0; i < (v_end - v_start); ++i) {
-      for (j = 0; j < (h_end - h_start); ++j) {
-        *data_p++ =
-            ver_sym_filter(tmpdata_p++, tmpstride, rst->vfilter[tile_idx]);
-      }
-      data_p += stride - (h_end - h_start);
-      tmpdata_p += tmpstride - (h_end - h_start);
+    if (rst->rsi->restoration_type[tile_idx] == RESTORE_BILATERAL) {
+      loop_bilateral_filter_tile(data, tile_idx, width, height, stride, rst,
+                                 tmpdata, tmpstride);
+    } else if (rst->rsi->restoration_type[tile_idx] == RESTORE_WIENER) {
+      loop_wiener_filter_tile(data, tile_idx, width, height, stride, rst,
+                              tmpdata, tmpstride);
     }
   }
 }
 
 #if CONFIG_AOM_HIGHBITDEPTH
-static void loop_bilateral_filter_highbd(uint8_t *data8, int width, int height,
-                                         int stride, RestorationInternal *rst,
-                                         uint8_t *tmpdata8, int tmpstride,
-                                         int bit_depth) {
-  int i, j, tile_idx, htile_idx, vtile_idx;
+static void loop_bilateral_filter_tile_highbd(uint16_t *data, int tile_idx,
+                                              int width, int height, int stride,
+                                              RestorationInternal *rst,
+                                              uint16_t *tmpdata, int tmpstride,
+                                              int bit_depth) {
+  const int tile_width = rst->tile_width >> rst->subsampling_x;
+  const int tile_height = rst->tile_height >> rst->subsampling_y;
+  int i, j, subtile_idx;
   int h_start, h_end, v_start, v_end;
-  int tile_width, tile_height;
-
-  uint16_t *data = CONVERT_TO_SHORTPTR(data8);
-  uint16_t *tmpdata = CONVERT_TO_SHORTPTR(tmpdata8);
-
-  tile_width = rst->tile_width >> rst->subsampling_x;
-  tile_height = rst->tile_height >> rst->subsampling_y;
 
-  for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
+  for (subtile_idx = 0; subtile_idx < BILATERAL_SUBTILES; ++subtile_idx) {
     uint16_t *data_p, *tmpdata_p;
-    const uint8_t *wr_lut_ = rst->wr_lut[tile_idx] + BILATERAL_AMP_RANGE;
-
-    if (rst->bilateral_level[tile_idx] < 0) continue;
-
-    htile_idx = tile_idx % rst->nhtiles;
-    vtile_idx = tile_idx / rst->nhtiles;
-    h_start =
-        htile_idx * tile_width + ((htile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    h_end = (htile_idx < rst->nhtiles - 1) ? ((htile_idx + 1) * tile_width)
-                                           : (width - RESTORATION_HALFWIN);
-    v_start =
-        vtile_idx * tile_height + ((vtile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    v_end = (vtile_idx < rst->nvtiles - 1) ? ((vtile_idx + 1) * tile_height)
-                                           : (height - RESTORATION_HALFWIN);
+    const int level =
+        rst->rsi->bilateral_level[tile_idx * BILATERAL_SUBTILES + subtile_idx];
+    uint8_t(*wx_lut)[RESTORATION_WIN];
+    uint8_t *wr_lut_;
+
+    if (level < 0) continue;
+    wr_lut_ = (rst->keyframe ? bilateral_filter_coeffs_r_kf[level]
+                             : bilateral_filter_coeffs_r[level]) +
+              BILATERAL_AMP_RANGE;
+    wx_lut = rst->keyframe ? bilateral_filter_coeffs_s_kf[level]
+                           : bilateral_filter_coeffs_s[level];
+    av1_get_rest_tile_limits(tile_idx, subtile_idx, BILATERAL_SUBTILE_BITS,
+                             rst->nhtiles, rst->nvtiles, tile_width,
+                             tile_height, width, height, 1, 1, &h_start, &h_end,
+                             &v_start, &v_end);
 
     data_p = data + h_start + v_start * stride;
     tmpdata_p = tmpdata + h_start + v_start * tmpstride;
@@ -417,8 +393,7 @@ static void loop_bilateral_filter_highbd(uint8_t *data8, int width, int height,
         uint16_t *data_p2 = data_p + j - RESTORATION_HALFWIN * stride;
         for (y = -RESTORATION_HALFWIN; y <= RESTORATION_HALFWIN; ++y) {
           for (x = -RESTORATION_HALFWIN; x <= RESTORATION_HALFWIN; ++x) {
-            wt = (int)rst->wx_lut[tile_idx][y + RESTORATION_HALFWIN]
-                                 [x + RESTORATION_HALFWIN] *
+            wt = (int)wx_lut[y + RESTORATION_HALFWIN][x + RESTORATION_HALFWIN] *
                  (int)wr_lut_[data_p2[x] - data_p[j]];
             wtsum += wt;
             flsum += wt * data_p2[x];
@@ -441,6 +416,20 @@ static void loop_bilateral_filter_highbd(uint8_t *data8, int width, int height,
   }
 }
 
+static void loop_bilateral_filter_highbd(uint8_t *data8, int width, int height,
+                                         int stride, RestorationInternal *rst,
+                                         uint8_t *tmpdata8, int tmpstride,
+                                         int bit_depth) {
+  int tile_idx;
+  uint16_t *data = CONVERT_TO_SHORTPTR(data8);
+  uint16_t *tmpdata = CONVERT_TO_SHORTPTR(tmpdata8);
+
+  for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
+    loop_bilateral_filter_tile_highbd(data, tile_idx, width, height, stride,
+                                      rst, tmpdata, tmpstride, bit_depth);
+  }
+}
+
 uint16_t hor_sym_filter_highbd(uint16_t *d, int *hfilter, int bd) {
   int32_t s =
       (1 << (RESTORATION_FILT_BITS - 1)) + d[0] * hfilter[RESTORATION_HALFWIN];
@@ -459,20 +448,57 @@ uint16_t ver_sym_filter_highbd(uint16_t *d, int stride, int *vfilter, int bd) {
   return clip_pixel_highbd(s >> RESTORATION_FILT_BITS, bd);
 }
 
+static void loop_wiener_filter_tile_highbd(uint16_t *data, int tile_idx,
+                                           int width, int height, int stride,
+                                           RestorationInternal *rst,
+                                           uint16_t *tmpdata, int tmpstride,
+                                           int bit_depth) {
+  const int tile_width = rst->tile_width >> rst->subsampling_x;
+  const int tile_height = rst->tile_height >> rst->subsampling_y;
+  int h_start, h_end, v_start, v_end;
+  int i, j;
+  uint16_t *data_p, *tmpdata_p;
+
+  if (rst->rsi->wiener_level[tile_idx] == 0) return;
+  // Filter row-wise
+  av1_get_rest_tile_limits(tile_idx, 0, 0, rst->nhtiles, rst->nvtiles,
+                           tile_width, tile_height, width, height, 1, 0,
+                           &h_start, &h_end, &v_start, &v_end);
+  data_p = data + h_start + v_start * stride;
+  tmpdata_p = tmpdata + h_start + v_start * tmpstride;
+  for (i = 0; i < (v_end - v_start); ++i) {
+    for (j = 0; j < (h_end - h_start); ++j) {
+      *tmpdata_p++ = hor_sym_filter_highbd(
+          data_p++, rst->rsi->hfilter[tile_idx], bit_depth);
+    }
+    data_p += stride - (h_end - h_start);
+    tmpdata_p += tmpstride - (h_end - h_start);
+  }
+  // Filter col-wise
+  av1_get_rest_tile_limits(tile_idx, 0, 0, rst->nhtiles, rst->nvtiles,
+                           tile_width, tile_height, width, height, 0, 1,
+                           &h_start, &h_end, &v_start, &v_end);
+  data_p = data + h_start + v_start * stride;
+  tmpdata_p = tmpdata + h_start + v_start * tmpstride;
+  for (i = 0; i < (v_end - v_start); ++i) {
+    for (j = 0; j < (h_end - h_start); ++j) {
+      *data_p++ = ver_sym_filter_highbd(tmpdata_p++, tmpstride,
+                                        rst->rsi->vfilter[tile_idx], bit_depth);
+    }
+    data_p += stride - (h_end - h_start);
+    tmpdata_p += tmpstride - (h_end - h_start);
+  }
+}
+
 static void loop_wiener_filter_highbd(uint8_t *data8, int width, int height,
                                       int stride, RestorationInternal *rst,
                                       uint8_t *tmpdata8, int tmpstride,
                                       int bit_depth) {
   uint16_t *data = CONVERT_TO_SHORTPTR(data8);
   uint16_t *tmpdata = CONVERT_TO_SHORTPTR(tmpdata8);
-  int i, j, tile_idx, htile_idx, vtile_idx;
-  int h_start, h_end, v_start, v_end;
-  int tile_width, tile_height;
+  int i, tile_idx;
   uint16_t *data_p, *tmpdata_p;
 
-  tile_width = rst->tile_width >> rst->subsampling_x;
-  tile_height = rst->tile_height >> rst->subsampling_y;
-
   // Initialize tmp buffer
   data_p = data;
   tmpdata_p = tmpdata;
@@ -481,54 +507,36 @@ static void loop_wiener_filter_highbd(uint8_t *data8, int width, int height,
     data_p += stride;
     tmpdata_p += tmpstride;
   }
-
-  // Filter row-wise tile-by-tile
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
-    if (rst->wiener_level[tile_idx] == 0) continue;
-    htile_idx = tile_idx % rst->nhtiles;
-    vtile_idx = tile_idx / rst->nhtiles;
-    h_start =
-        htile_idx * tile_width + ((htile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    h_end = (htile_idx < rst->nhtiles - 1) ? ((htile_idx + 1) * tile_width)
-                                           : (width - RESTORATION_HALFWIN);
-    v_start = vtile_idx * tile_height;
-    v_end = (vtile_idx < rst->nvtiles - 1) ? ((vtile_idx + 1) * tile_height)
-                                           : height;
-    data_p = data + h_start + v_start * stride;
-    tmpdata_p = tmpdata + h_start + v_start * tmpstride;
-    for (i = 0; i < (v_end - v_start); ++i) {
-      for (j = 0; j < (h_end - h_start); ++j) {
-        *tmpdata_p++ =
-            hor_sym_filter_highbd(data_p++, rst->hfilter[tile_idx], bit_depth);
-      }
-      data_p += stride - (h_end - h_start);
-      tmpdata_p += tmpstride - (h_end - h_start);
-    }
+    loop_wiener_filter_tile_highbd(data, tile_idx, width, height, stride, rst,
+                                   tmpdata, tmpstride, bit_depth);
   }
+}
+
+static void loop_switchable_filter_highbd(uint8_t *data8, int width, int height,
+                                          int stride, RestorationInternal *rst,
+                                          uint8_t *tmpdata8, int tmpstride,
+                                          int bit_depth) {
+  uint16_t *data = CONVERT_TO_SHORTPTR(data8);
+  uint16_t *tmpdata = CONVERT_TO_SHORTPTR(tmpdata8);
+  int i, tile_idx;
+  uint16_t *data_p, *tmpdata_p;
 
-  // Filter column-wise tile-by-tile (bands of thickness RESTORATION_HALFWIN
-  // at top and bottom of tiles allow filtering overlap, and are not optimally
-  // filtered)
+  // Initialize tmp buffer
+  data_p = data;
+  tmpdata_p = tmpdata;
+  for (i = 0; i < height; ++i) {
+    memcpy(tmpdata_p, data_p, sizeof(*data_p) * width);
+    data_p += stride;
+    tmpdata_p += tmpstride;
+  }
   for (tile_idx = 0; tile_idx < rst->ntiles; ++tile_idx) {
-    if (rst->wiener_level[tile_idx] == 0) continue;
-    htile_idx = tile_idx % rst->nhtiles;
-    vtile_idx = tile_idx / rst->nhtiles;
-    h_start = htile_idx * tile_width;
-    h_end =
-        (htile_idx < rst->nhtiles - 1) ? ((htile_idx + 1) * tile_width) : width;
-    v_start =
-        vtile_idx * tile_height + ((vtile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    v_end = (vtile_idx < rst->nvtiles - 1) ? ((vtile_idx + 1) * tile_height)
-                                           : (height - RESTORATION_HALFWIN);
-    data_p = data + h_start + v_start * stride;
-    tmpdata_p = tmpdata + h_start + v_start * tmpstride;
-    for (i = 0; i < (v_end - v_start); ++i) {
-      for (j = 0; j < (h_end - h_start); ++j) {
-        *data_p++ = ver_sym_filter_highbd(tmpdata_p++, tmpstride,
-                                          rst->vfilter[tile_idx], bit_depth);
-      }
-      data_p += stride - (h_end - h_start);
-      tmpdata_p += tmpstride - (h_end - h_start);
+    if (rst->rsi->restoration_type[tile_idx] == RESTORE_BILATERAL) {
+      loop_bilateral_filter_tile_highbd(data, tile_idx, width, height, stride,
+                                        rst, tmpdata, tmpstride, bit_depth);
+    } else if (rst->rsi->restoration_type[tile_idx] == RESTORE_WIENER) {
+      loop_wiener_filter_tile_highbd(data, tile_idx, width, height, stride, rst,
+                                     tmpdata, tmpstride, bit_depth);
     }
   }
 }
@@ -545,16 +553,23 @@ void av1_loop_restoration_rows(YV12_BUFFER_CONFIG *frame, AV1_COMMON *cm,
   int yend = end_mi_row << MI_SIZE_LOG2;
   int uvend = yend >> cm->subsampling_y;
   restore_func_type restore_func =
-      cm->rst_internal.restoration_type == RESTORE_BILATERAL
+      cm->rst_internal.rsi->frame_restoration_type == RESTORE_BILATERAL
           ? loop_bilateral_filter
-          : loop_wiener_filter;
+          : (cm->rst_internal.rsi->frame_restoration_type == RESTORE_WIENER
+                 ? loop_wiener_filter
+                 : loop_switchable_filter);
 #if CONFIG_AOM_HIGHBITDEPTH
   restore_func_highbd_type restore_func_highbd =
-      cm->rst_internal.restoration_type == RESTORE_BILATERAL
+      cm->rst_internal.rsi->frame_restoration_type == RESTORE_BILATERAL
           ? loop_bilateral_filter_highbd
-          : loop_wiener_filter_highbd;
+          : (cm->rst_internal.rsi->frame_restoration_type == RESTORE_WIENER
+                 ? loop_wiener_filter_highbd
+                 : loop_switchable_filter_highbd);
 #endif  // CONFIG_AOM_HIGHBITDEPTH
   YV12_BUFFER_CONFIG tmp_buf;
+
+  if (cm->rst_internal.rsi->frame_restoration_type == RESTORE_NONE) return;
+
   memset(&tmp_buf, 0, sizeof(YV12_BUFFER_CONFIG));
 
   yend = AOMMIN(yend, cm->height);
@@ -609,25 +624,13 @@ void av1_loop_restoration_rows(YV12_BUFFER_CONFIG *frame, AV1_COMMON *cm,
 #endif  // CONFIG_AOM_HIGHBITDEPTH
   }
   aom_free_frame_buffer(&tmp_buf);
-  if (cm->rst_internal.restoration_type == RESTORE_BILATERAL) {
-    free(cm->rst_internal.wr_lut);
-    cm->rst_internal.wr_lut = NULL;
-    free(cm->rst_internal.wx_lut);
-    cm->rst_internal.wx_lut = NULL;
-  }
-  if (cm->rst_internal.restoration_type == RESTORE_WIENER) {
-    free(cm->rst_internal.vfilter);
-    cm->rst_internal.vfilter = NULL;
-    free(cm->rst_internal.hfilter);
-    cm->rst_internal.hfilter = NULL;
-  }
 }
 
 void av1_loop_restoration_frame(YV12_BUFFER_CONFIG *frame, AV1_COMMON *cm,
                                 RestorationInfo *rsi, int y_only,
                                 int partial_frame) {
   int start_mi_row, end_mi_row, mi_rows_to_filter;
-  if (rsi->restoration_type != RESTORE_NONE) {
+  if (rsi->frame_restoration_type != RESTORE_NONE) {
     start_mi_row = 0;
     mi_rows_to_filter = cm->mi_rows;
     if (partial_frame && cm->mi_rows > 8) {
index d8a312d5b6c1f33fd07ca2b31e0b431fe6f857eb..3d4802f70cbf74f88293903dc4ee967991bcfb2d 100644 (file)
@@ -26,9 +26,10 @@ extern "C" {
 #define BILATERAL_LEVELS (1 << BILATERAL_LEVEL_BITS)
 // #define DEF_BILATERAL_LEVEL     2
 
-#define RESTORATION_TILESIZES 3
-#define BILATERAL_TILESIZE 1
-#define WIENER_TILESIZE 2
+#define RESTORATION_TILESIZE_SML 128
+#define RESTORATION_TILESIZE_BIG 256
+#define BILATERAL_SUBTILE_BITS 1
+#define BILATERAL_SUBTILES (1 << (2 * BILATERAL_SUBTILE_BITS))
 
 #define RESTORATION_HALFWIN 3
 #define RESTORATION_HALFWIN1 (RESTORATION_HALFWIN + 1)
@@ -56,43 +57,84 @@ extern "C" {
 #define WIENER_FILT_TAP2_MAXV \
   (WIENER_FILT_TAP2_MINV - 1 + (1 << WIENER_FILT_TAP2_BITS))
 
-typedef enum {
-  RESTORE_NONE,
-  RESTORE_BILATERAL,
-  RESTORE_WIENER,
-} RestorationType;
-
 typedef struct {
-  RestorationType restoration_type;
+  RestorationType frame_restoration_type;
+  RestorationType *restoration_type;
   // Bilateral filter
   int *bilateral_level;
   // Wiener filter
   int *wiener_level;
-  int (*vfilter)[RESTORATION_HALFWIN], (*hfilter)[RESTORATION_HALFWIN];
+  int (*vfilter)[RESTORATION_WIN], (*hfilter)[RESTORATION_WIN];
 } RestorationInfo;
 
 typedef struct {
-  RestorationType restoration_type;
+  RestorationInfo *rsi;
+  int keyframe;
   int subsampling_x;
   int subsampling_y;
-  int tilesize_index;
   int ntiles;
   int tile_width, tile_height;
   int nhtiles, nvtiles;
-  // Bilateral filter
-  int *bilateral_level;
-  uint8_t (**wx_lut)[RESTORATION_WIN];
-  uint8_t **wr_lut;
-  // Wiener filter
-  int *wiener_level;
-  int (*vfilter)[RESTORATION_WIN], (*hfilter)[RESTORATION_WIN];
 } RestorationInternal;
 
+static INLINE int get_rest_tilesize(int width, int height) {
+  if (width * height <= 352 * 288)
+    return RESTORATION_TILESIZE_SML;
+  else
+    return RESTORATION_TILESIZE_BIG;
+}
+
+static INLINE int av1_get_rest_ntiles(int width, int height,
+                                      int *tile_width, int *tile_height,
+                                      int *nhtiles, int *nvtiles) {
+  int nhtiles_, nvtiles_;
+  int tile_width_, tile_height_;
+  int tilesize = get_rest_tilesize(width, height);
+  tile_width_ = (tilesize < 0) ? width : AOMMIN(tilesize, width);
+  tile_height_ = (tilesize < 0) ? height : AOMMIN(tilesize, height);
+  nhtiles_ = (width + (tile_width_ >> 1)) / tile_width_;
+  nvtiles_ = (height + (tile_height_ >> 1)) / tile_height_;
+  if (tile_width) *tile_width = tile_width_;
+  if (tile_height) *tile_height = tile_height_;
+  if (nhtiles) *nhtiles = nhtiles_;
+  if (nvtiles) *nvtiles = nvtiles_;
+  return (nhtiles_ * nvtiles_);
+}
+
+static INLINE void av1_get_rest_tile_limits(
+    int tile_idx, int subtile_idx, int subtile_bits, int nhtiles, int nvtiles,
+    int tile_width, int tile_height, int im_width, int im_height, int clamp_h,
+    int clamp_v, int *h_start, int *h_end, int *v_start, int *v_end) {
+  const int htile_idx = tile_idx % nhtiles;
+  const int vtile_idx = tile_idx / nhtiles;
+  *h_start = htile_idx * tile_width;
+  *v_start = vtile_idx * tile_height;
+  *h_end = (htile_idx < nhtiles - 1) ? *h_start + tile_width : im_width;
+  *v_end = (vtile_idx < nvtiles - 1) ? *v_start + tile_height : im_height;
+  if (subtile_bits) {
+    const int num_subtiles_1d = (1 << subtile_bits);
+    const int subtile_width = (*h_end - *h_start) >> subtile_bits;
+    const int subtile_height = (*v_end - *v_start) >> subtile_bits;
+    const int subtile_idx_h = subtile_idx & (num_subtiles_1d - 1);
+    const int subtile_idx_v = subtile_idx >> subtile_bits;
+    *h_start += subtile_idx_h * subtile_width;
+    *v_start += subtile_idx_v * subtile_height;
+    *h_end = subtile_idx_h == num_subtiles_1d - 1 ? *h_end
+                                                  : *h_start + subtile_width;
+    *v_end = subtile_idx_v == num_subtiles_1d - 1 ? *v_end
+                                                  : *v_start + subtile_height;
+  }
+  if (clamp_h) {
+    *h_start = AOMMAX(*h_start, RESTORATION_HALFWIN);
+    *h_end = AOMMIN(*h_end, im_width - RESTORATION_HALFWIN);
+  }
+  if (clamp_v) {
+    *v_start = AOMMAX(*v_start, RESTORATION_HALFWIN);
+    *v_end = AOMMIN(*v_end, im_height - RESTORATION_HALFWIN);
+  }
+}
+
 int av1_bilateral_level_bits(const struct AV1Common *const cm);
-int av1_get_restoration_ntiles(int tilesize, int width, int height);
-void av1_get_restoration_tile_size(int tilesize, int width, int height,
-                                   int *tile_width, int *tile_height,
-                                   int *nhtiles, int *nvtiles);
 void av1_loop_restoration_init(RestorationInternal *rst, RestorationInfo *rsi,
                                int kf, int width, int height);
 void av1_loop_restoration_frame(YV12_BUFFER_CONFIG *frame, struct AV1Common *cm,
index de0b502eaf408834ae86d8a4dac2cdd945cc2002..2f32e94099c101d7cb07db34705c02a45e39b274 100644 (file)
@@ -1899,62 +1899,134 @@ static void setup_segmentation(AV1_COMMON *const cm,
 }
 
 #if CONFIG_LOOP_RESTORATION
-static void setup_restoration(AV1_COMMON *cm, struct aom_read_bit_buffer *rb) {
-  int i;
+static void decode_restoration_mode(AV1_COMMON *cm,
+                                    struct aom_read_bit_buffer *rb) {
   RestorationInfo *rsi = &cm->rst_info;
-  int ntiles;
   if (aom_rb_read_bit(rb)) {
-    if (aom_rb_read_bit(rb)) {
-      rsi->restoration_type = RESTORE_BILATERAL;
-      ntiles =
-          av1_get_restoration_ntiles(BILATERAL_TILESIZE, cm->width, cm->height);
+    rsi->frame_restoration_type =
+        aom_rb_read_bit(rb) ? RESTORE_WIENER : RESTORE_BILATERAL;
+  } else {
+    rsi->frame_restoration_type =
+        aom_rb_read_bit(rb) ? RESTORE_SWITCHABLE : RESTORE_NONE;
+  }
+}
+
+static void decode_restoration(AV1_COMMON *cm, aom_reader *rb) {
+  int i;
+  RestorationInfo *rsi = &cm->rst_info;
+  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height,
+                                         NULL, NULL, NULL, NULL);
+  if (rsi->frame_restoration_type != RESTORE_NONE) {
+    rsi->restoration_type = (RestorationType *)aom_realloc(
+        rsi->restoration_type, sizeof(*rsi->restoration_type) * ntiles);
+    if (rsi->frame_restoration_type == RESTORE_SWITCHABLE) {
       rsi->bilateral_level = (int *)aom_realloc(
-          rsi->bilateral_level, sizeof(*rsi->bilateral_level) * ntiles);
+          rsi->bilateral_level,
+          sizeof(*rsi->bilateral_level) * ntiles * BILATERAL_SUBTILES);
       assert(rsi->bilateral_level != NULL);
+      rsi->wiener_level = (int *)aom_realloc(
+          rsi->wiener_level, sizeof(*rsi->wiener_level) * ntiles);
+      assert(rsi->wiener_level != NULL);
+      rsi->vfilter = (int(*)[RESTORATION_WIN])aom_realloc(
+          rsi->vfilter, sizeof(*rsi->vfilter) * ntiles);
+      assert(rsi->vfilter != NULL);
+      rsi->hfilter = (int(*)[RESTORATION_WIN])aom_realloc(
+          rsi->hfilter, sizeof(*rsi->hfilter) * ntiles);
+      assert(rsi->hfilter != NULL);
       for (i = 0; i < ntiles; ++i) {
-        if (aom_rb_read_bit(rb)) {
-          rsi->bilateral_level[i] =
-              aom_rb_read_literal(rb, av1_bilateral_level_bits(cm));
-        } else {
-          rsi->bilateral_level[i] = -1;
+        rsi->restoration_type[i] = aom_read_tree(
+            rb, av1_switchable_restore_tree, cm->fc->switchable_restore_prob);
+        if (rsi->restoration_type[i] == RESTORE_WIENER) {
+          rsi->wiener_level[i] = 1;
+          rsi->vfilter[i][0] =
+              aom_read_literal(rb, WIENER_FILT_TAP0_BITS) +
+              WIENER_FILT_TAP0_MINV;
+          rsi->vfilter[i][1] =
+              aom_read_literal(rb, WIENER_FILT_TAP1_BITS) +
+              WIENER_FILT_TAP1_MINV;
+          rsi->vfilter[i][2] =
+              aom_read_literal(rb, WIENER_FILT_TAP2_BITS) +
+              WIENER_FILT_TAP2_MINV;
+          rsi->hfilter[i][0] =
+              aom_read_literal(rb, WIENER_FILT_TAP0_BITS) +
+              WIENER_FILT_TAP0_MINV;
+          rsi->hfilter[i][1] =
+              aom_read_literal(rb, WIENER_FILT_TAP1_BITS) +
+              WIENER_FILT_TAP1_MINV;
+          rsi->hfilter[i][2] =
+              aom_read_literal(rb, WIENER_FILT_TAP2_BITS) +
+              WIENER_FILT_TAP2_MINV;
+        } else if (rsi->restoration_type[i] == RESTORE_BILATERAL) {
+          int s;
+          for (s = 0; s < BILATERAL_SUBTILES; ++s) {
+            const int j = i * BILATERAL_SUBTILES + s;
+#if BILATERAL_SUBTILES == 0
+            rsi->bilateral_level[j] =
+                aom_read_literal(rb, av1_bilateral_level_bits(cm));
+#else
+            if (aom_read(rb, RESTORE_NONE_BILATERAL_PROB)) {
+              rsi->bilateral_level[j] =
+                  aom_read_literal(rb, av1_bilateral_level_bits(cm));
+            } else {
+              rsi->bilateral_level[j] = -1;
+            }
+#endif
+          }
         }
       }
-    } else {
-      rsi->restoration_type = RESTORE_WIENER;
-      ntiles =
-          av1_get_restoration_ntiles(WIENER_TILESIZE, cm->width, cm->height);
+    } else if (rsi->frame_restoration_type == RESTORE_WIENER) {
       rsi->wiener_level = (int *)aom_realloc(
           rsi->wiener_level, sizeof(*rsi->wiener_level) * ntiles);
       assert(rsi->wiener_level != NULL);
-      rsi->vfilter = (int(*)[RESTORATION_HALFWIN])aom_realloc(
+      rsi->vfilter = (int(*)[RESTORATION_WIN])aom_realloc(
           rsi->vfilter, sizeof(*rsi->vfilter) * ntiles);
       assert(rsi->vfilter != NULL);
-      rsi->hfilter = (int(*)[RESTORATION_HALFWIN])aom_realloc(
+      rsi->hfilter = (int(*)[RESTORATION_WIN])aom_realloc(
           rsi->hfilter, sizeof(*rsi->hfilter) * ntiles);
       assert(rsi->hfilter != NULL);
       for (i = 0; i < ntiles; ++i) {
-        rsi->wiener_level[i] = aom_rb_read_bit(rb);
-        if (rsi->wiener_level[i]) {
-          rsi->vfilter[i][0] = aom_rb_read_literal(rb, WIENER_FILT_TAP0_BITS) +
+        if (aom_read(rb, RESTORE_NONE_WIENER_PROB)) {
+          rsi->wiener_level[i] = 1;
+          rsi->restoration_type[i] = RESTORE_WIENER;
+          rsi->vfilter[i][0] = aom_read_literal(rb, WIENER_FILT_TAP0_BITS) +
                                WIENER_FILT_TAP0_MINV;
-          rsi->vfilter[i][1] = aom_rb_read_literal(rb, WIENER_FILT_TAP1_BITS) +
+          rsi->vfilter[i][1] = aom_read_literal(rb, WIENER_FILT_TAP1_BITS) +
                                WIENER_FILT_TAP1_MINV;
-          rsi->vfilter[i][2] = aom_rb_read_literal(rb, WIENER_FILT_TAP2_BITS) +
+          rsi->vfilter[i][2] = aom_read_literal(rb, WIENER_FILT_TAP2_BITS) +
                                WIENER_FILT_TAP2_MINV;
-          rsi->hfilter[i][0] = aom_rb_read_literal(rb, WIENER_FILT_TAP0_BITS) +
+          rsi->hfilter[i][0] = aom_read_literal(rb, WIENER_FILT_TAP0_BITS) +
                                WIENER_FILT_TAP0_MINV;
-          rsi->hfilter[i][1] = aom_rb_read_literal(rb, WIENER_FILT_TAP1_BITS) +
+          rsi->hfilter[i][1] = aom_read_literal(rb, WIENER_FILT_TAP1_BITS) +
                                WIENER_FILT_TAP1_MINV;
-          rsi->hfilter[i][2] = aom_rb_read_literal(rb, WIENER_FILT_TAP2_BITS) +
+          rsi->hfilter[i][2] = aom_read_literal(rb, WIENER_FILT_TAP2_BITS) +
                                WIENER_FILT_TAP2_MINV;
         } else {
-          rsi->vfilter[i][0] = rsi->vfilter[i][1] = rsi->vfilter[i][2] = 0;
-          rsi->hfilter[i][0] = rsi->hfilter[i][1] = rsi->hfilter[i][2] = 0;
+          rsi->wiener_level[i] = 0;
+          rsi->restoration_type[i] = RESTORE_NONE;
+        }
+      }
+    } else {
+      rsi->frame_restoration_type = RESTORE_BILATERAL;
+      rsi->bilateral_level = (int *)aom_realloc(
+          rsi->bilateral_level,
+          sizeof(*rsi->bilateral_level) * ntiles * BILATERAL_SUBTILES);
+      assert(rsi->bilateral_level != NULL);
+      for (i = 0; i < ntiles; ++i) {
+        int s;
+        rsi->restoration_type[i] = RESTORE_BILATERAL;
+        for (s = 0; s < BILATERAL_SUBTILES; ++s) {
+          const int j = i * BILATERAL_SUBTILES + s;
+          if (aom_read(rb, RESTORE_NONE_BILATERAL_PROB)) {
+            rsi->bilateral_level[j] =
+                aom_read_literal(rb, av1_bilateral_level_bits(cm));
+          } else {
+            rsi->bilateral_level[j] = -1;
+          }
         }
       }
     }
   } else {
-    rsi->restoration_type = RESTORE_NONE;
+    rsi->frame_restoration_type = RESTORE_NONE;
   }
 }
 #endif  // CONFIG_LOOP_RESTORATION
@@ -3286,7 +3358,7 @@ static size_t read_uncompressed_header(AV1Decoder *pbi,
   setup_dering(cm, rb);
 #endif
 #if CONFIG_LOOP_RESTORATION
-  setup_restoration(cm, rb);
+  decode_restoration_mode(cm, rb);
 #endif  // CONFIG_LOOP_RESTORATION
   setup_quantization(cm, rb);
 #if CONFIG_AOM_HIGHBITDEPTH
@@ -3468,6 +3540,10 @@ static int read_compressed_header(AV1Decoder *pbi, const uint8_t *data,
                        "Failed to allocate compressed header ANS decoder");
 #endif  // !CONFIG_ANS
 
+#if CONFIG_LOOP_RESTORATION
+  decode_restoration(cm, &r);
+#endif
+
   if (cm->tx_mode == TX_MODE_SELECT) {
     for (i = 0; i < TX_SIZES - 1; ++i)
       for (j = 0; j < TX_SIZE_CONTEXTS; ++j)
index 6578c0cffc9169058ce451ff37da812d71dfac92..f09a5cde0ef05fab2c38ea2d2bdba832ff6eb3f1 100644 (file)
@@ -150,6 +150,9 @@ static struct av1_token interintra_mode_encodings[INTERINTRA_MODES];
 #if CONFIG_OBMC || CONFIG_WARPED_MOTION
 static struct av1_token motvar_encodings[MOTION_VARIATIONS];
 #endif  // CONFIG_OBMC || CONFIG_WARPED_MOTION
+#if CONFIG_LOOP_RESTORATION
+static struct av1_token switchable_restore_encodings[RESTORE_SWITCHABLE_TYPES];
+#endif  // CONFIG_LOOP_RESTORATION
 
 void av1_encode_token_init(void) {
 #if CONFIG_EXT_TX
@@ -176,6 +179,10 @@ void av1_encode_token_init(void) {
   av1_tokens_from_tree(global_motion_types_encodings,
                        av1_global_motion_types_tree);
 #endif  // CONFIG_GLOBAL_MOTION
+#if CONFIG_LOOP_RESTORATION
+  av1_tokens_from_tree(switchable_restore_encodings,
+                       av1_switchable_restore_tree);
+#endif  // CONFIG_LOOP_RESTORATION
 }
 
 static void write_intra_mode(aom_writer *w, PREDICTION_MODE mode,
@@ -2420,42 +2427,102 @@ static void update_coef_probs(AV1_COMP *cpi, aom_writer *w) {
 }
 
 #if CONFIG_LOOP_RESTORATION
-static void encode_restoration(AV1_COMMON *cm,
-                               struct aom_write_bit_buffer *wb) {
-  int i;
+static void encode_restoration_mode(AV1_COMMON *cm,
+                                    struct aom_write_bit_buffer *wb) {
   RestorationInfo *rst = &cm->rst_info;
-  aom_wb_write_bit(wb, rst->restoration_type != RESTORE_NONE);
-  if (rst->restoration_type != RESTORE_NONE) {
-    if (rst->restoration_type == RESTORE_BILATERAL) {
+  switch (rst->frame_restoration_type) {
+    case RESTORE_NONE:
+      aom_wb_write_bit(wb, 0);
+      aom_wb_write_bit(wb, 0);
+      break;
+    case RESTORE_SWITCHABLE:
+      aom_wb_write_bit(wb, 0);
+      aom_wb_write_bit(wb, 1);
+      break;
+    case RESTORE_BILATERAL:
       aom_wb_write_bit(wb, 1);
-      for (i = 0; i < cm->rst_internal.ntiles; ++i) {
-        if (rst->bilateral_level[i] >= 0) {
-          aom_wb_write_bit(wb, 1);
-          aom_wb_write_literal(wb, rst->bilateral_level[i],
-                               av1_bilateral_level_bits(cm));
-        } else {
-          aom_wb_write_bit(wb, 0);
-        }
-      }
-    } else {
       aom_wb_write_bit(wb, 0);
+      break;
+    case RESTORE_WIENER:
+      aom_wb_write_bit(wb, 1);
+      aom_wb_write_bit(wb, 1);
+      break;
+    default: assert(0);
+  }
+}
+
+static void encode_restoration(AV1_COMMON *cm, aom_writer *wb) {
+  int i;
+  RestorationInfo *rst = &cm->rst_info;
+  if (rst->frame_restoration_type != RESTORE_NONE) {
+    if (rst->frame_restoration_type == RESTORE_SWITCHABLE) {
+      // RESTORE_SWITCHABLE
       for (i = 0; i < cm->rst_internal.ntiles; ++i) {
-        if (rst->wiener_level[i]) {
-          aom_wb_write_bit(wb, 1);
-          aom_wb_write_literal(wb, rst->vfilter[i][0] - WIENER_FILT_TAP0_MINV,
-                               WIENER_FILT_TAP0_BITS);
-          aom_wb_write_literal(wb, rst->vfilter[i][1] - WIENER_FILT_TAP1_MINV,
-                               WIENER_FILT_TAP1_BITS);
-          aom_wb_write_literal(wb, rst->vfilter[i][2] - WIENER_FILT_TAP2_MINV,
+        av1_write_token(
+            wb, av1_switchable_restore_tree,
+            cm->fc->switchable_restore_prob,
+            &switchable_restore_encodings[rst->restoration_type[i]]);
+        if (rst->restoration_type[i] == RESTORE_NONE) {
+        } else if (rst->restoration_type[i] == RESTORE_BILATERAL) {
+          int s;
+          for (s = 0; s < BILATERAL_SUBTILES; ++s) {
+            const int j = i * BILATERAL_SUBTILES + s;
+#if BILATERAL_SUBTILES == 0
+            aom_write_literal(wb, rst->bilateral_level[j],
+                              av1_bilateral_level_bits(cm));
+#else
+            aom_write(wb, rst->bilateral_level[j] >= 0,
+                      RESTORE_NONE_BILATERAL_PROB);
+            if (rst->bilateral_level[j] >= 0) {
+              aom_write_literal(wb, rst->bilateral_level[j],
+                                av1_bilateral_level_bits(cm));
+            }
+#endif
+          }
+        } else {
+          aom_write_literal(wb, rst->vfilter[i][0] - WIENER_FILT_TAP0_MINV,
+                            WIENER_FILT_TAP0_BITS);
+          aom_write_literal(wb, rst->vfilter[i][1] - WIENER_FILT_TAP1_MINV,
+                            WIENER_FILT_TAP1_BITS);
+          aom_write_literal(wb, rst->vfilter[i][2] - WIENER_FILT_TAP2_MINV,
                                WIENER_FILT_TAP2_BITS);
-          aom_wb_write_literal(wb, rst->hfilter[i][0] - WIENER_FILT_TAP0_MINV,
+          aom_write_literal(wb, rst->hfilter[i][0] - WIENER_FILT_TAP0_MINV,
                                WIENER_FILT_TAP0_BITS);
-          aom_wb_write_literal(wb, rst->hfilter[i][1] - WIENER_FILT_TAP1_MINV,
+          aom_write_literal(wb, rst->hfilter[i][1] - WIENER_FILT_TAP1_MINV,
                                WIENER_FILT_TAP1_BITS);
-          aom_wb_write_literal(wb, rst->hfilter[i][2] - WIENER_FILT_TAP2_MINV,
+          aom_write_literal(wb, rst->hfilter[i][2] - WIENER_FILT_TAP2_MINV,
                                WIENER_FILT_TAP2_BITS);
-        } else {
-          aom_wb_write_bit(wb, 0);
+        }
+      }
+    } else if (rst->frame_restoration_type == RESTORE_BILATERAL) {
+      for (i = 0; i < cm->rst_internal.ntiles; ++i) {
+        int s;
+        for (s = 0; s < BILATERAL_SUBTILES; ++s) {
+          const int j = i * BILATERAL_SUBTILES + s;
+          aom_write(wb, rst->bilateral_level[j] >= 0,
+                    RESTORE_NONE_BILATERAL_PROB);
+          if (rst->bilateral_level[j] >= 0) {
+            aom_write_literal(wb, rst->bilateral_level[j],
+                                 av1_bilateral_level_bits(cm));
+          }
+        }
+      }
+    } else if (rst->frame_restoration_type == RESTORE_WIENER) {
+      for (i = 0; i < cm->rst_internal.ntiles; ++i) {
+        aom_write(wb, rst->wiener_level[i] != 0, RESTORE_NONE_WIENER_PROB);
+        if (rst->wiener_level[i]) {
+          aom_write_literal(wb, rst->vfilter[i][0] - WIENER_FILT_TAP0_MINV,
+                            WIENER_FILT_TAP0_BITS);
+          aom_write_literal(wb, rst->vfilter[i][1] - WIENER_FILT_TAP1_MINV,
+                            WIENER_FILT_TAP1_BITS);
+          aom_write_literal(wb, rst->vfilter[i][2] - WIENER_FILT_TAP2_MINV,
+                            WIENER_FILT_TAP2_BITS);
+          aom_write_literal(wb, rst->hfilter[i][0] - WIENER_FILT_TAP0_MINV,
+                            WIENER_FILT_TAP0_BITS);
+          aom_write_literal(wb, rst->hfilter[i][1] - WIENER_FILT_TAP1_MINV,
+                            WIENER_FILT_TAP1_BITS);
+          aom_write_literal(wb, rst->hfilter[i][2] - WIENER_FILT_TAP2_MINV,
+                            WIENER_FILT_TAP2_BITS);
         }
       }
     }
@@ -3183,7 +3250,7 @@ static void write_uncompressed_header(AV1_COMP *cpi,
   encode_dering(cm->dering_level, wb);
 #endif  // CONFIG_DERING
 #if CONFIG_LOOP_RESTORATION
-  encode_restoration(cm, wb);
+  encode_restoration_mode(cm, wb);
 #endif  // CONFIG_LOOP_RESTORATION
   encode_quantization(cm, wb);
   encode_segmentation(cm, xd, wb);
@@ -3282,6 +3349,11 @@ static uint32_t write_compressed_header(AV1_COMP *cpi, uint8_t *data) {
   header_bc = &real_header_bc;
   aom_start_encode(header_bc, data);
 #endif
+
+#if CONFIG_LOOP_RESTORATION
+  encode_restoration(cm, header_bc);
+#endif  // CONFIG_LOOP_RESTORATION
+
   update_txfm_probs(cm, header_bc, counts);
   update_coef_probs(cpi, header_bc);
 
index 821d2f19eb1bcfae2b01f0c2bc147ee7c18153bb..9902517251dce7017b176115c7f48bf9a2c2ea4a 100644 (file)
@@ -572,6 +572,9 @@ typedef struct AV1_COMP {
 #if CONFIG_EXT_INTRA
   int intra_filter_cost[INTRA_FILTERS + 1][INTRA_FILTERS];
 #endif  // CONFIG_EXT_INTRA
+#if CONFIG_LOOP_RESTORATION
+  int switchable_restore_cost[RESTORE_SWITCHABLE_TYPES];
+#endif  // CONFIG_LOOP_RESTORATION
 
   int multi_arf_allowed;
   int multi_arf_enabled;
index 22bd0195415459f866d4fbe4315d9953b090b5ff..00e46a689562cc4afb5e19931b11850ecd2180fe 100644 (file)
 #include "av1/encoder/pickrst.h"
 #include "av1/encoder/quantize.h"
 
-static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *sd,
+const int frame_level_restore_bits[RESTORE_TYPES] = { 2, 2, 2, 2 };
+
+static int64_t sse_restoration_tile(const YV12_BUFFER_CONFIG *src,
+                                    AV1_COMMON *const cm, int h_start,
+                                    int width, int v_start, int height) {
+  int64_t filt_err;
+#if CONFIG_AOM_HIGHBITDEPTH
+  if (cm->use_highbitdepth) {
+    filt_err = aom_highbd_get_y_sse_part(src, cm->frame_to_show, h_start, width,
+                                         v_start, height);
+  } else {
+    filt_err = aom_get_y_sse_part(src, cm->frame_to_show, h_start, width,
+                                  v_start, height);
+  }
+#else
+  filt_err = aom_get_y_sse_part(src, cm->frame_to_show, h_start, width, v_start,
+                                height);
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+  return filt_err;
+}
+
+static int64_t try_restoration_tile(const YV12_BUFFER_CONFIG *src,
+                                    AV1_COMP *const cpi, RestorationInfo *rsi,
+                                    int partial_frame, int tile_idx,
+                                    int subtile_idx, int subtile_bits) {
+  AV1_COMMON *const cm = &cpi->common;
+  int64_t filt_err;
+  int tile_width, tile_height, nhtiles, nvtiles;
+  int h_start, h_end, v_start, v_end;
+  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
+                                         &tile_height, &nhtiles, &nvtiles);
+  (void)ntiles;
+
+  av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, 1, partial_frame);
+  av1_get_rest_tile_limits(tile_idx, subtile_idx, subtile_bits, nhtiles,
+                           nvtiles, tile_width, tile_height, cm->width,
+                           cm->height, 0, 0, &h_start, &h_end, &v_start,
+                           &v_end);
+  filt_err = sse_restoration_tile(src, cm, h_start, h_end - h_start, v_start,
+                                  v_end - v_start);
+
+  // Re-instate the unfiltered frame
+  aom_yv12_copy_y(&cpi->last_frame_db, cm->frame_to_show);
+  return filt_err;
+}
+
+static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *src,
                                      AV1_COMP *const cpi, RestorationInfo *rsi,
                                      int partial_frame) {
   AV1_COMMON *const cm = &cpi->common;
@@ -36,12 +82,12 @@ static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *sd,
   av1_loop_restoration_frame(cm->frame_to_show, cm, rsi, 1, partial_frame);
 #if CONFIG_AOM_HIGHBITDEPTH
   if (cm->use_highbitdepth) {
-    filt_err = aom_highbd_get_y_sse(sd, cm->frame_to_show);
+    filt_err = aom_highbd_get_y_sse(src, cm->frame_to_show);
   } else {
-    filt_err = aom_get_y_sse(sd, cm->frame_to_show);
+    filt_err = aom_get_y_sse(src, cm->frame_to_show);
   }
 #else
-  filt_err = aom_get_y_sse(sd, cm->frame_to_show);
+  filt_err = aom_get_y_sse(src, cm->frame_to_show);
 #endif  // CONFIG_AOM_HIGHBITDEPTH
 
   // Re-instate the unfiltered frame
@@ -49,20 +95,24 @@ static int64_t try_restoration_frame(const YV12_BUFFER_CONFIG *sd,
   return filt_err;
 }
 
-static int search_bilateral_level(const YV12_BUFFER_CONFIG *sd, AV1_COMP *cpi,
+static int search_bilateral_level(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                                   int filter_level, int partial_frame,
-                                  int *bilateral_level, double *best_cost_ret) {
+                                  int *bilateral_level, double *best_cost_ret,
+                                  double *best_tile_cost) {
   AV1_COMMON *const cm = &cpi->common;
   int i, j, tile_idx;
   int64_t err;
   int bits;
-  double cost, best_cost, cost_norestore, cost_bilateral;
+  double cost, best_cost, cost_norestore, cost_bilateral,
+      cost_norestore_subtile;
   const int bilateral_level_bits = av1_bilateral_level_bits(&cpi->common);
   const int bilateral_levels = 1 << bilateral_level_bits;
   MACROBLOCK *x = &cpi->td.mb;
   RestorationInfo rsi;
-  const int ntiles =
-      av1_get_restoration_ntiles(BILATERAL_TILESIZE, cm->width, cm->height);
+  int tile_width, tile_height, nhtiles, nvtiles;
+  int h_start, h_end, v_start, v_end;
+  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
+                                         &tile_height, &nhtiles, &nvtiles);
 
   //  Make a copy of the unfiltered / processed recon buffer
   aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
@@ -71,53 +121,94 @@ static int search_bilateral_level(const YV12_BUFFER_CONFIG *sd, AV1_COMP *cpi,
   aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);
 
   // RD cost associated with no restoration
-  rsi.restoration_type = RESTORE_NONE;
-  err = try_restoration_frame(sd, cpi, &rsi, partial_frame);
-  bits = 0;
-  cost_norestore =
-      RDCOST_DBL(x->rdmult, x->rddiv, (bits << (AV1_PROB_COST_SHIFT - 4)), err);
-  best_cost = cost_norestore;
+  rsi.frame_restoration_type = RESTORE_NONE;
+  err = try_restoration_frame(src, cpi, &rsi, partial_frame);
+  // err = sse_restoration_tile(src, cm, 0, cm->width, 0, cm->height);
+  bits = frame_level_restore_bits[rsi.frame_restoration_type]
+         << AV1_PROB_COST_SHIFT;
+  cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
 
   // RD cost associated with bilateral filtering
-  rsi.restoration_type = RESTORE_BILATERAL;
-  rsi.bilateral_level =
-      (int *)aom_malloc(sizeof(*rsi.bilateral_level) * ntiles);
+  rsi.frame_restoration_type = RESTORE_BILATERAL;
+  rsi.bilateral_level = (int *)aom_malloc(sizeof(*rsi.bilateral_level) *
+                                          ntiles * BILATERAL_SUBTILES);
   assert(rsi.bilateral_level != NULL);
 
-  for (j = 0; j < ntiles; ++j) bilateral_level[j] = -1;
+  for (j = 0; j < ntiles * BILATERAL_SUBTILES; ++j) bilateral_level[j] = -1;
 
+  // TODO(debargha): This is a pretty inefficient way to find the best
+  // parameters per tile. Needs fixing.
   // Find best filter for each tile
   for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
-    for (j = 0; j < ntiles; ++j) rsi.bilateral_level[j] = -1;
-    best_cost = cost_norestore;
-    for (i = 0; i < bilateral_levels; ++i) {
-      rsi.bilateral_level[tile_idx] = i;
-      err = try_restoration_frame(sd, cpi, &rsi, partial_frame);
-      bits = bilateral_level_bits + 1;
-      // Normally the rate is rate in bits * 256 and dist is sum sq err * 64
-      // when RDCOST is used.  However below we just scale both in the correct
-      // ratios appropriately but not exactly by these values.
-      cost = RDCOST_DBL(x->rdmult, x->rddiv,
-                        (bits << (AV1_PROB_COST_SHIFT - 4)), err);
-      if (cost < best_cost) {
-        bilateral_level[tile_idx] = i;
-        best_cost = cost;
+    int subtile_idx;
+    for (subtile_idx = 0; subtile_idx < BILATERAL_SUBTILES; ++subtile_idx) {
+      const int fulltile_idx = tile_idx * BILATERAL_SUBTILES + subtile_idx;
+      av1_get_rest_tile_limits(tile_idx, subtile_idx, BILATERAL_SUBTILE_BITS,
+                               nhtiles, nvtiles, tile_width, tile_height,
+                               cm->width, cm->height, 0, 0, &h_start, &h_end,
+                               &v_start, &v_end);
+      err = sse_restoration_tile(src, cm, h_start, h_end - h_start, v_start,
+                                 v_end - v_start);
+#if BILATERAL_SUBTILES
+      // #bits when a subtile is not restored
+      bits = av1_cost_bit(RESTORE_NONE_BILATERAL_PROB, 0);
+#else
+      bits = 0;
+#endif
+      cost_norestore_subtile =
+          RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+      best_cost = cost_norestore_subtile;
+      for (j = 0; j < ntiles * BILATERAL_SUBTILES; ++j)
+        rsi.bilateral_level[j] = -1;
+
+      for (i = 0; i < bilateral_levels; ++i) {
+        rsi.bilateral_level[fulltile_idx] = i;
+        err = try_restoration_tile(src, cpi, &rsi, partial_frame, tile_idx,
+                                   subtile_idx, BILATERAL_SUBTILE_BITS);
+        bits = bilateral_level_bits << AV1_PROB_COST_SHIFT;
+        bits += av1_cost_bit(RESTORE_NONE_BILATERAL_PROB, 1);
+        cost = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+        if (cost < best_cost) {
+          bilateral_level[fulltile_idx] = i;
+          best_cost = cost;
+        }
+      }
+    }
+    if (best_tile_cost) {
+      bits = 0;
+      for (j = 0; j < ntiles * BILATERAL_SUBTILES; ++j)
+        rsi.bilateral_level[j] = -1;
+      for (subtile_idx = 0; subtile_idx < BILATERAL_SUBTILES; ++subtile_idx) {
+        const int fulltile_idx = tile_idx * BILATERAL_SUBTILES + subtile_idx;
+        rsi.bilateral_level[fulltile_idx] = bilateral_level[fulltile_idx];
+        if (rsi.bilateral_level[fulltile_idx] >= 0)
+          bits += bilateral_level_bits << AV1_PROB_COST_SHIFT;
+#if BILATERAL_SUBTILES
+        bits += av1_cost_bit(RESTORE_NONE_BILATERAL_PROB,
+                             rsi.bilateral_level[fulltile_idx] >= 0);
+#endif
       }
+      err = try_restoration_tile(src, cpi, &rsi, partial_frame, tile_idx, 0, 0);
+      best_tile_cost[tile_idx] = RDCOST_DBL(
+          x->rdmult, x->rddiv,
+          (bits + cpi->switchable_restore_cost[RESTORE_BILATERAL]) >> 4, err);
     }
   }
   // Find cost for combined configuration
-  bits = 0;
-  for (j = 0; j < ntiles; ++j) {
+  bits = frame_level_restore_bits[rsi.frame_restoration_type]
+         << AV1_PROB_COST_SHIFT;
+  for (j = 0; j < ntiles * BILATERAL_SUBTILES; ++j) {
     rsi.bilateral_level[j] = bilateral_level[j];
     if (rsi.bilateral_level[j] >= 0) {
-      bits += (bilateral_level_bits + 1);
-    } else {
-      bits += 1;
+      bits += bilateral_level_bits << AV1_PROB_COST_SHIFT;
     }
+#if BILATERAL_SUBTILES
+    bits +=
+        av1_cost_bit(RESTORE_NONE_BILATERAL_PROB, rsi.bilateral_level[j] >= 0);
+#endif
   }
-  err = try_restoration_frame(sd, cpi, &rsi, partial_frame);
-  cost_bilateral =
-      RDCOST_DBL(x->rdmult, x->rddiv, (bits << (AV1_PROB_COST_SHIFT - 4)), err);
+  err = try_restoration_frame(src, cpi, &rsi, partial_frame);
+  cost_bilateral = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
 
   aom_free(rsi.bilateral_level);
 
@@ -131,10 +222,11 @@ static int search_bilateral_level(const YV12_BUFFER_CONFIG *sd, AV1_COMP *cpi,
   }
 }
 
-static int search_filter_bilateral_level(const YV12_BUFFER_CONFIG *sd,
+static int search_filter_bilateral_level(const YV12_BUFFER_CONFIG *src,
                                          AV1_COMP *cpi, int partial_frame,
                                          int *filter_best, int *bilateral_level,
-                                         double *best_cost_ret) {
+                                         double *best_cost_ret,
+                                         double *best_tile_cost) {
   const AV1_COMMON *const cm = &cpi->common;
   const struct loopfilter *const lf = &cm->lf;
   const int min_filter_level = 0;
@@ -147,7 +239,8 @@ static int search_filter_bilateral_level(const YV12_BUFFER_CONFIG *sd,
   int bilateral_success[MAX_LOOP_FILTER + 1];
 
   const int ntiles =
-      av1_get_restoration_ntiles(BILATERAL_TILESIZE, cm->width, cm->height);
+      av1_get_rest_ntiles(cm->width, cm->height, NULL, NULL, NULL, NULL);
+  double *tile_cost = (double *)aom_malloc(sizeof(*tile_cost) * ntiles);
 
   // Start the search at the previous frame filter level unless it is now out of
   // range.
@@ -157,13 +250,14 @@ static int search_filter_bilateral_level(const YV12_BUFFER_CONFIG *sd,
   // Set each entry to -1
   for (i = 0; i <= MAX_LOOP_FILTER; ++i) ss_err[i] = -1.0;
 
-  tmp_level = (int *)aom_malloc(sizeof(*tmp_level) * ntiles);
+  tmp_level =
+      (int *)aom_malloc(sizeof(*tmp_level) * ntiles * BILATERAL_SUBTILES);
 
   bilateral_success[filt_mid] = search_bilateral_level(
-      sd, cpi, filt_mid, partial_frame, tmp_level, &best_err);
+      src, cpi, filt_mid, partial_frame, tmp_level, &best_err, best_tile_cost);
   filt_best = filt_mid;
   ss_err[filt_mid] = best_err;
-  for (j = 0; j < ntiles; ++j) {
+  for (j = 0; j < ntiles * BILATERAL_SUBTILES; ++j) {
     bilateral_level[j] = tmp_level[j];
   }
 
@@ -183,8 +277,9 @@ static int search_filter_bilateral_level(const YV12_BUFFER_CONFIG *sd,
     if (filt_direction <= 0 && filt_low != filt_mid) {
       // Get Low filter error score
       if (ss_err[filt_low] < 0) {
-        bilateral_success[filt_low] = search_bilateral_level(
-            sd, cpi, filt_low, partial_frame, tmp_level, &ss_err[filt_low]);
+        bilateral_success[filt_low] =
+            search_bilateral_level(src, cpi, filt_low, partial_frame, tmp_level,
+                                   &ss_err[filt_low], tile_cost);
       }
       // If value is close to the best so far then bias towards a lower loop
       // filter value.
@@ -194,26 +289,29 @@ static int search_filter_bilateral_level(const YV12_BUFFER_CONFIG *sd,
           best_err = ss_err[filt_low];
         }
         filt_best = filt_low;
-        for (j = 0; j < ntiles; ++j) {
+        for (j = 0; j < ntiles * BILATERAL_SUBTILES; ++j) {
           bilateral_level[j] = tmp_level[j];
         }
+        memcpy(best_tile_cost, tile_cost, sizeof(*tile_cost) * ntiles);
       }
     }
 
     // Now look at filt_high
     if (filt_direction >= 0 && filt_high != filt_mid) {
       if (ss_err[filt_high] < 0) {
-        bilateral_success[filt_high] = search_bilateral_level(
-            sd, cpi, filt_high, partial_frame, tmp_level, &ss_err[filt_high]);
+        bilateral_success[filt_high] =
+            search_bilateral_level(src, cpi, filt_high, partial_frame,
+                                   tmp_level, &ss_err[filt_high], tile_cost);
       }
       // If value is significantly better than previous best, bias added against
       // raising filter value
       if (ss_err[filt_high] < (best_err - bias)) {
         best_err = ss_err[filt_high];
         filt_best = filt_high;
-        for (j = 0; j < ntiles; ++j) {
+        for (j = 0; j < ntiles * BILATERAL_SUBTILES; ++j) {
           bilateral_level[j] = tmp_level[j];
         }
+        memcpy(best_tile_cost, tile_cost, sizeof(*tile_cost) * ntiles);
       }
     }
 
@@ -226,12 +324,11 @@ static int search_filter_bilateral_level(const YV12_BUFFER_CONFIG *sd,
       filt_mid = filt_best;
     }
   }
-
   aom_free(tmp_level);
+  aom_free(tile_cost);
 
   // Update best error
   best_err = ss_err[filt_best];
-
   if (best_cost_ret) *best_cost_ret = best_err;
   if (filter_best) *filter_best = filt_best;
 
@@ -546,14 +643,15 @@ static void quantize_sym_filter(double *f, int *fi) {
 
 static int search_wiener_filter(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                                 int filter_level, int partial_frame,
-                                int (*vfilter)[RESTORATION_HALFWIN],
-                                int (*hfilter)[RESTORATION_HALFWIN],
-                                int *process_tile, double *best_cost_ret) {
+                                int (*vfilter)[RESTORATION_WIN],
+                                int (*hfilter)[RESTORATION_WIN],
+                                int *wiener_level, double *best_cost_ret,
+                                double *best_tile_cost) {
   AV1_COMMON *const cm = &cpi->common;
   RestorationInfo rsi;
   int64_t err;
   int bits;
-  double cost_wiener, cost_norestore;
+  double cost_wiener, cost_norestore, cost_norestore_tile;
   MACROBLOCK *x = &cpi->td.mb;
   double M[RESTORATION_WIN2];
   double H[RESTORATION_WIN2 * RESTORATION_WIN2];
@@ -564,56 +662,55 @@ static int search_wiener_filter(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
   const int src_stride = src->y_stride;
   const int dgd_stride = dgd->y_stride;
   double score;
-  int tile_idx, htile_idx, vtile_idx, tile_width, tile_height, nhtiles, nvtiles;
+  int tile_idx, tile_width, tile_height, nhtiles, nvtiles;
   int h_start, h_end, v_start, v_end;
   int i, j;
 
-  const int tilesize = WIENER_TILESIZE;
-  const int ntiles = av1_get_restoration_ntiles(tilesize, width, height);
-
+  const int ntiles = av1_get_rest_ntiles(width, height, &tile_width,
+                                         &tile_height, &nhtiles, &nvtiles);
   assert(width == dgd->y_crop_width);
   assert(height == dgd->y_crop_height);
   assert(width == src->y_crop_width);
   assert(height == src->y_crop_height);
 
-  av1_get_restoration_tile_size(tilesize, width, height, &tile_width,
-                                &tile_height, &nhtiles, &nvtiles);
-
   //  Make a copy of the unfiltered / processed recon buffer
   aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
   av1_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
                         1, partial_frame);
   aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);
 
-  rsi.restoration_type = RESTORE_NONE;
-  err = try_restoration_frame(src, cpi, &rsi, partial_frame);
-  bits = 0;
-  cost_norestore =
-      RDCOST_DBL(x->rdmult, x->rddiv, (bits << (AV1_PROB_COST_SHIFT - 4)), err);
+  rsi.frame_restoration_type = RESTORE_NONE;
+  err = sse_restoration_tile(src, cm, 0, cm->width, 0, cm->height);
+  bits = frame_level_restore_bits[rsi.frame_restoration_type]
+         << AV1_PROB_COST_SHIFT;
+  cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
 
-  rsi.restoration_type = RESTORE_WIENER;
+  rsi.frame_restoration_type = RESTORE_WIENER;
   rsi.vfilter =
-      (int(*)[RESTORATION_HALFWIN])aom_malloc(sizeof(*rsi.vfilter) * ntiles);
+      (int(*)[RESTORATION_WIN])aom_malloc(sizeof(*rsi.vfilter) * ntiles);
   assert(rsi.vfilter != NULL);
   rsi.hfilter =
-      (int(*)[RESTORATION_HALFWIN])aom_malloc(sizeof(*rsi.hfilter) * ntiles);
+      (int(*)[RESTORATION_WIN])aom_malloc(sizeof(*rsi.hfilter) * ntiles);
   assert(rsi.hfilter != NULL);
   rsi.wiener_level = (int *)aom_malloc(sizeof(*rsi.wiener_level) * ntiles);
   assert(rsi.wiener_level != NULL);
 
   // Compute best Wiener filters for each tile
   for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
-    htile_idx = tile_idx % nhtiles;
-    vtile_idx = tile_idx / nhtiles;
-    h_start =
-        htile_idx * tile_width + ((htile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    h_end = (htile_idx < nhtiles - 1) ? ((htile_idx + 1) * tile_width)
-                                      : (width - RESTORATION_HALFWIN);
-    v_start =
-        vtile_idx * tile_height + ((vtile_idx > 0) ? 0 : RESTORATION_HALFWIN);
-    v_end = (vtile_idx < nvtiles - 1) ? ((vtile_idx + 1) * tile_height)
-                                      : (height - RESTORATION_HALFWIN);
-
+    wiener_level[tile_idx] = 0;
+    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
+                             tile_height, width, height, 0, 0, &h_start, &h_end,
+                             &v_start, &v_end);
+    err = sse_restoration_tile(src, cm, h_start, h_end - h_start, v_start,
+                               v_end - v_start);
+    // #bits when a tile is not restored
+    bits = av1_cost_bit(RESTORE_NONE_WIENER_PROB, 0);
+    cost_norestore_tile = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+    if (best_tile_cost) best_tile_cost[tile_idx] = cost_norestore_tile;
+
+    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
+                             tile_height, width, height, 1, 1, &h_start, &h_end,
+                             &v_start, &v_end);
 #if CONFIG_AOM_HIGHBITDEPTH
     if (cm->use_highbitdepth)
       compute_stats_highbd(dgd->y_buffer, src->y_buffer, h_start, h_end,
@@ -626,12 +723,12 @@ static int search_wiener_filter(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
     if (!wiener_decompose_sep_sym(M, H, vfilterd, hfilterd)) {
       for (i = 0; i < RESTORATION_HALFWIN; ++i)
         rsi.vfilter[tile_idx][i] = rsi.hfilter[tile_idx][i] = 0;
-      process_tile[tile_idx] = 0;
+      wiener_level[tile_idx] = 0;
       continue;
     }
     quantize_sym_filter(vfilterd, rsi.vfilter[tile_idx]);
     quantize_sym_filter(hfilterd, rsi.hfilter[tile_idx]);
-    process_tile[tile_idx] = 1;
+    wiener_level[tile_idx] = 1;
 
     // Filter score computes the value of the function x'*A*x - x'*b for the
     // learned filter and compares it against identity filer. If there is no
@@ -640,31 +737,41 @@ static int search_wiener_filter(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
     if (score > 0.0) {
       for (i = 0; i < RESTORATION_HALFWIN; ++i)
         rsi.vfilter[tile_idx][i] = rsi.hfilter[tile_idx][i] = 0;
-      process_tile[tile_idx] = 0;
+      wiener_level[tile_idx] = 0;
       continue;
     }
 
     for (j = 0; j < ntiles; ++j) rsi.wiener_level[j] = 0;
     rsi.wiener_level[tile_idx] = 1;
 
-    err = try_restoration_frame(src, cpi, &rsi, partial_frame);
-    bits = 1 + WIENER_FILT_BITS;
-    cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv,
-                             (bits << (AV1_PROB_COST_SHIFT - 4)), err);
-    if (cost_wiener >= cost_norestore) process_tile[tile_idx] = 0;
+    err = try_restoration_tile(src, cpi, &rsi, partial_frame, tile_idx, 0, 0);
+    bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
+    bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, 1);
+    cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+    if (cost_wiener >= cost_norestore_tile) wiener_level[tile_idx] = 0;
+    if (best_tile_cost) {
+      bits = WIENER_FILT_BITS << AV1_PROB_COST_SHIFT;
+      best_tile_cost[tile_idx] = RDCOST_DBL(
+          x->rdmult, x->rddiv,
+          (bits + cpi->switchable_restore_cost[RESTORE_WIENER]) >> 4, err);
+    }
   }
   // Cost for Wiener filtering
-  bits = 0;
+  bits = frame_level_restore_bits[rsi.frame_restoration_type]
+         << AV1_PROB_COST_SHIFT;
   for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
-    bits += (process_tile[tile_idx] ? (WIENER_FILT_BITS + 1) : 1);
-    rsi.wiener_level[tile_idx] = process_tile[tile_idx];
+    bits += av1_cost_bit(RESTORE_NONE_WIENER_PROB, wiener_level[tile_idx]);
+    if (wiener_level[tile_idx])
+      bits += (WIENER_FILT_BITS << AV1_PROB_COST_SHIFT);
+    rsi.wiener_level[tile_idx] = wiener_level[tile_idx];
   }
+  // TODO(debargha): This is a pretty inefficient way to find the error
+  // for the whole frame. Specialize for a specific tile.
   err = try_restoration_frame(src, cpi, &rsi, partial_frame);
-  cost_wiener =
-      RDCOST_DBL(x->rdmult, x->rddiv, (bits << (AV1_PROB_COST_SHIFT - 4)), err);
+  cost_wiener = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
 
   for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
-    if (process_tile[tile_idx] == 0) continue;
+    if (wiener_level[tile_idx] == 0) continue;
     for (i = 0; i < RESTORATION_HALFWIN; ++i) {
       vfilter[tile_idx][i] = rsi.vfilter[tile_idx][i];
       hfilter[tile_idx][i] = rsi.hfilter[tile_idx][i];
@@ -685,40 +792,125 @@ static int search_wiener_filter(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
   }
 }
 
-void av1_pick_filter_restoration(const YV12_BUFFER_CONFIG *sd, AV1_COMP *cpi,
+static int search_switchable_restoration(
+    const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi, int filter_level,
+    int partial_frame, RestorationInfo *rsi, double *tile_cost_bilateral,
+    double *tile_cost_wiener, double *best_cost_ret) {
+  AV1_COMMON *const cm = &cpi->common;
+  const int bilateral_level_bits = av1_bilateral_level_bits(&cpi->common);
+  MACROBLOCK *x = &cpi->td.mb;
+  double err, cost_norestore, cost_norestore_tile, cost_switchable;
+  int bits, tile_idx;
+  int tile_width, tile_height, nhtiles, nvtiles;
+  int h_start, h_end, v_start, v_end;
+  const int ntiles = av1_get_rest_ntiles(cm->width, cm->height, &tile_width,
+                                         &tile_height, &nhtiles, &nvtiles);
+
+  //  Make a copy of the unfiltered / processed recon buffer
+  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_uf);
+  av1_loop_filter_frame(cm->frame_to_show, cm, &cpi->td.mb.e_mbd, filter_level,
+                        1, partial_frame);
+  aom_yv12_copy_y(cm->frame_to_show, &cpi->last_frame_db);
+
+  // RD cost associated with no restoration
+  rsi->frame_restoration_type = RESTORE_NONE;
+  err = sse_restoration_tile(src, cm, 0, cm->width, 0, cm->height);
+  bits = frame_level_restore_bits[rsi->frame_restoration_type]
+         << AV1_PROB_COST_SHIFT;
+  cost_norestore = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+
+  rsi->frame_restoration_type = RESTORE_SWITCHABLE;
+  bits = frame_level_restore_bits[rsi->frame_restoration_type]
+         << AV1_PROB_COST_SHIFT;
+  for (tile_idx = 0; tile_idx < ntiles; ++tile_idx) {
+    av1_get_rest_tile_limits(tile_idx, 0, 0, nhtiles, nvtiles, tile_width,
+                             tile_height, cm->width, cm->height, 0, 0, &h_start,
+                             &h_end, &v_start, &v_end);
+    err = sse_restoration_tile(src, cm, h_start, h_end - h_start, v_start,
+                               v_end - v_start);
+    cost_norestore_tile =
+        RDCOST_DBL(x->rdmult, x->rddiv,
+                   (cpi->switchable_restore_cost[RESTORE_NONE] >> 4), err);
+    if (tile_cost_wiener[tile_idx] > cost_norestore_tile &&
+        tile_cost_bilateral[tile_idx] > cost_norestore_tile) {
+      rsi->restoration_type[tile_idx] = RESTORE_NONE;
+    } else {
+      rsi->restoration_type[tile_idx] =
+          tile_cost_wiener[tile_idx] < tile_cost_bilateral[tile_idx]
+              ? RESTORE_WIENER
+              : RESTORE_BILATERAL;
+      if (rsi->restoration_type[tile_idx] == RESTORE_WIENER) {
+        if (rsi->wiener_level[tile_idx]) {
+          bits += (WIENER_FILT_BITS << AV1_PROB_COST_SHIFT);
+        } else {
+          rsi->restoration_type[tile_idx] = RESTORE_NONE;
+        }
+      } else {
+        int s;
+        for (s = 0; s < BILATERAL_SUBTILES; ++s) {
+#if BILATERAL_SUBTILES
+          bits += av1_cost_bit(
+              RESTORE_NONE_BILATERAL_PROB,
+              rsi->bilateral_level[tile_idx * BILATERAL_SUBTILES + s] >= 0);
+#endif
+          if (rsi->bilateral_level[tile_idx * BILATERAL_SUBTILES + s] >= 0)
+            bits += bilateral_level_bits << AV1_PROB_COST_SHIFT;
+        }
+      }
+    }
+    bits += cpi->switchable_restore_cost[rsi->restoration_type[tile_idx]];
+  }
+  err = try_restoration_frame(src, cpi, rsi, partial_frame);
+  cost_switchable = RDCOST_DBL(x->rdmult, x->rddiv, (bits >> 4), err);
+  aom_yv12_copy_y(&cpi->last_frame_uf, cm->frame_to_show);
+  if (cost_switchable < cost_norestore) {
+    *best_cost_ret = cost_switchable;
+    return 1;
+  } else {
+    *best_cost_ret = cost_norestore;
+    return 0;
+  }
+}
+
+void av1_pick_filter_restoration(const YV12_BUFFER_CONFIG *src, AV1_COMP *cpi,
                                  LPF_PICK_METHOD method) {
   AV1_COMMON *const cm = &cpi->common;
   struct loopfilter *const lf = &cm->lf;
   int wiener_success = 0;
   int bilateral_success = 0;
+  int switchable_success = 0;
   double cost_bilateral = DBL_MAX;
   double cost_wiener = DBL_MAX;
-  double cost_norestore = DBL_MAX;
-  int ntiles;
-
-  ntiles =
-      av1_get_restoration_ntiles(BILATERAL_TILESIZE, cm->width, cm->height);
-  cm->rst_info.bilateral_level =
-      (int *)aom_realloc(cm->rst_info.bilateral_level,
-                         sizeof(*cm->rst_info.bilateral_level) * ntiles);
+  // double cost_norestore = DBL_MAX;
+  double cost_switchable = DBL_MAX;
+  double *tile_cost_bilateral, *tile_cost_wiener;
+  const int ntiles =
+      av1_get_rest_ntiles(cm->width, cm->height, NULL, NULL, NULL, NULL);
+  cm->rst_info.restoration_type = (RestorationType *)aom_realloc(
+      cm->rst_info.restoration_type,
+      sizeof(*cm->rst_info.restoration_type) * ntiles);
+  cm->rst_info.bilateral_level = (int *)aom_realloc(
+      cm->rst_info.bilateral_level,
+      sizeof(*cm->rst_info.bilateral_level) * ntiles * BILATERAL_SUBTILES);
   assert(cm->rst_info.bilateral_level != NULL);
 
-  ntiles = av1_get_restoration_ntiles(WIENER_TILESIZE, cm->width, cm->height);
   cm->rst_info.wiener_level = (int *)aom_realloc(
       cm->rst_info.wiener_level, sizeof(*cm->rst_info.wiener_level) * ntiles);
   assert(cm->rst_info.wiener_level != NULL);
-  cm->rst_info.vfilter = (int(*)[RESTORATION_HALFWIN])aom_realloc(
+  cm->rst_info.vfilter = (int(*)[RESTORATION_WIN])aom_realloc(
       cm->rst_info.vfilter, sizeof(*cm->rst_info.vfilter) * ntiles);
   assert(cm->rst_info.vfilter != NULL);
-  cm->rst_info.hfilter = (int(*)[RESTORATION_HALFWIN])aom_realloc(
+  cm->rst_info.hfilter = (int(*)[RESTORATION_WIN])aom_realloc(
       cm->rst_info.hfilter, sizeof(*cm->rst_info.hfilter) * ntiles);
   assert(cm->rst_info.hfilter != NULL);
+  tile_cost_wiener = (double *)aom_malloc(sizeof(cost_wiener) * ntiles);
+  tile_cost_bilateral = (double *)aom_malloc(sizeof(cost_bilateral) * ntiles);
 
   lf->sharpness_level = cm->frame_type == KEY_FRAME ? 0 : cpi->oxcf.sharpness;
 
   if (method == LPF_PICK_MINIMAL_LPF && lf->filter_level) {
     lf->filter_level = 0;
-    cm->rst_info.restoration_type = RESTORE_NONE;
+    cm->rst_info.frame_restoration_type = RESTORE_NONE;
   } else if (method >= LPF_PICK_FROM_Q) {
     const int min_filter_level = 0;
     const int max_filter_level = av1_get_max_filter_level(cpi);
@@ -749,60 +941,51 @@ void av1_pick_filter_restoration(const YV12_BUFFER_CONFIG *sd, AV1_COMP *cpi,
     if (cm->frame_type == KEY_FRAME) filt_guess -= 4;
     lf->filter_level = clamp(filt_guess, min_filter_level, max_filter_level);
     bilateral_success = search_bilateral_level(
-        sd, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
-        cm->rst_info.bilateral_level, &cost_bilateral);
+        src, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
+        cm->rst_info.bilateral_level, &cost_bilateral, tile_cost_bilateral);
     wiener_success = search_wiener_filter(
-        sd, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
+        src, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
         cm->rst_info.vfilter, cm->rst_info.hfilter, cm->rst_info.wiener_level,
-        &cost_wiener);
-    if (cost_bilateral < cost_wiener) {
-      if (bilateral_success)
-        cm->rst_info.restoration_type = RESTORE_BILATERAL;
-      else
-        cm->rst_info.restoration_type = RESTORE_NONE;
-    } else {
-      if (wiener_success)
-        cm->rst_info.restoration_type = RESTORE_WIENER;
-      else
-        cm->rst_info.restoration_type = RESTORE_NONE;
-    }
+        &cost_wiener, tile_cost_wiener);
+    switchable_success = search_switchable_restoration(
+        src, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
+        &cm->rst_info, tile_cost_bilateral, tile_cost_wiener, &cost_switchable);
   } else {
-    int blf_filter_level = -1;
+    // lf->filter_level = av1_search_filter_level(
+    //     src, cpi, method == LPF_PICK_FROM_SUBIMAGE, &cost_norestore);
+    // bilateral_success = search_bilateral_level(
+    //     src, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
+    //     cm->rst_info.bilateral_level, &cost_bilateral, tile_cost_bilateral);
     bilateral_success = search_filter_bilateral_level(
-        sd, cpi, method == LPF_PICK_FROM_SUBIMAGE, &blf_filter_level,
-        cm->rst_info.bilateral_level, &cost_bilateral);
-    lf->filter_level = av1_search_filter_level(
-        sd, cpi, method == LPF_PICK_FROM_SUBIMAGE, &cost_norestore);
+        src, cpi, method == LPF_PICK_FROM_SUBIMAGE, &lf->filter_level,
+        cm->rst_info.bilateral_level, &cost_bilateral, tile_cost_bilateral);
     wiener_success = search_wiener_filter(
-        sd, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
+        src, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
         cm->rst_info.vfilter, cm->rst_info.hfilter, cm->rst_info.wiener_level,
-        &cost_wiener);
-    if (cost_bilateral < cost_wiener) {
-      lf->filter_level = blf_filter_level;
-      if (bilateral_success)
-        cm->rst_info.restoration_type = RESTORE_BILATERAL;
-      else
-        cm->rst_info.restoration_type = RESTORE_NONE;
-    } else {
-      if (wiener_success)
-        cm->rst_info.restoration_type = RESTORE_WIENER;
-      else
-        cm->rst_info.restoration_type = RESTORE_NONE;
-    }
-    // printf("[%d] Costs %g %g (%d) %g (%d)\n", cm->rst_info.restoration_type,
-    //        cost_norestore, cost_bilateral, lf->filter_level, cost_wiener,
-    //        wiener_success);
-  }
-  if (cm->rst_info.restoration_type != RESTORE_BILATERAL) {
-    aom_free(cm->rst_info.bilateral_level);
-    cm->rst_info.bilateral_level = NULL;
-  }
-  if (cm->rst_info.restoration_type != RESTORE_WIENER) {
-    aom_free(cm->rst_info.vfilter);
-    cm->rst_info.vfilter = NULL;
-    aom_free(cm->rst_info.hfilter);
-    cm->rst_info.hfilter = NULL;
-    aom_free(cm->rst_info.wiener_level);
-    cm->rst_info.wiener_level = NULL;
+        &cost_wiener, tile_cost_wiener);
+    switchable_success = search_switchable_restoration(
+        src, cpi, lf->filter_level, method == LPF_PICK_FROM_SUBIMAGE,
+        &cm->rst_info, tile_cost_bilateral, tile_cost_wiener, &cost_switchable);
+  }
+  if (cost_bilateral < AOMMIN(cost_wiener, cost_switchable)) {
+    if (bilateral_success)
+      cm->rst_info.frame_restoration_type = RESTORE_BILATERAL;
+    else
+      cm->rst_info.frame_restoration_type = RESTORE_NONE;
+  } else if (cost_wiener < AOMMIN(cost_bilateral, cost_switchable)) {
+    if (wiener_success)
+      cm->rst_info.frame_restoration_type = RESTORE_WIENER;
+    else
+      cm->rst_info.frame_restoration_type = RESTORE_NONE;
+  } else {
+    if (switchable_success)
+      cm->rst_info.frame_restoration_type = RESTORE_SWITCHABLE;
+    else
+      cm->rst_info.frame_restoration_type = RESTORE_NONE;
   }
+  printf("Frame %d frame_restore_type %d [%d]: %f %f %f\n",
+         cm->current_video_frame, cm->rst_info.frame_restoration_type, ntiles,
+         cost_bilateral, cost_wiener, cost_switchable);
+  aom_free(tile_cost_bilateral);
+  aom_free(tile_cost_wiener);
 }
index 566036972359973ffb9698956cc66a60305e55b1..b56e3c14de59dd289790ff93bd32e3beb47c070c 100644 (file)
@@ -146,6 +146,10 @@ static void fill_mode_costs(AV1_COMP *cpi) {
     av1_cost_tokens(cpi->intra_filter_cost[i], fc->intra_filter_probs[i],
                     av1_intra_filter_tree);
 #endif  // CONFIG_EXT_INTRA
+#if CONFIG_LOOP_RESTORATION
+  av1_cost_tokens(cpi->switchable_restore_cost, fc->switchable_restore_prob,
+                  av1_switchable_restore_tree);
+#endif  // CONFIG_LOOP_RESTORATION
 }
 
 void av1_fill_token_costs(av1_coeff_cost *c,