shithub: libvpx

Download patch

ref: 3244f4738d16cfef982ebf94e70dadd95d4c47b8
parent: 6b5db3f9da3f8b21cf582ba0d8e375a70427d0e1
author: Dan Zhu <[email protected]>
date: Wed Jul 17 11:20:27 EDT 2019

Base Class of Motion Field Estimators

Change-Id: Id01ce15273c0cab0cd61d064099d200708360265

--- /dev/null
+++ b/tools/3D-Reconstruction/MotionEST/MotionEST.py
@@ -1,0 +1,113 @@
+#!/usr/bin/env python
+# coding: utf-8
+import numpy as np
+import numpy.linalg as LA
+import matplotlib.pyplot as plt
+from Util import drawMF, MSE
+"""The Base Class of Estimators"""
+
+
class MotionEST(object):
  """Base class of motion field estimators.

  Subclasses fill self.mf by overriding motion_field_estimation(); this class
  provides shared distortion, evaluation and rendering helpers.

  Args:
    cur_f: current frame (presumably a PIL Image: .convert() and .size are
      used).
    ref_f: reference frame (same kind of image as cur_f).
    blk_sz: block size in pixels.
  """

  def __init__(self, cur_f, ref_f, blk_sz):
    self.cur_f = cur_f
    self.ref_f = ref_f
    self.blk_sz = blk_sz
    # Convert RGB to YUV; the arrays are indexed [row, col, channel].
    self.cur_yuv = np.array(self.cur_f.convert('YCbCr'))
    self.ref_yuv = np.array(self.ref_f.convert('YCbCr'))
    # Frame size in pixels.
    self.width = self.cur_f.size[0]
    self.height = self.cur_f.size[1]
    # Motion field size: only whole blocks are counted (floor division).
    self.num_row = self.height // self.blk_sz
    self.num_col = self.width // self.blk_sz
    # One motion vector per block, stored as (row offset, column offset).
    self.mf = np.zeros((self.num_row, self.num_col, 2))

  def motion_field_estimation(self):
    """Estimate self.mf. Meant to be overridden by child classes; no-op here."""
    pass

  def block_dist(self, cur_r, cur_c, mv, metric=MSE):
    """Return the distortion between a block and its motion-compensated match.

    Args:
      cur_r: block row index.
      cur_c: block column index.
      mv: motion vector as (row offset, column offset) in pixels.
      metric: callable(cur_blk, ref_blk) returning the distortion.

    Returns:
      metric applied to the current block and the reference block at the
      displaced position. The displaced position is truncated toward zero to
      an integer pixel. If the displaced block does not lie entirely inside
      the reference frame, an all-zero block is used instead.
    """
    cur_x = cur_c * self.blk_sz
    cur_y = cur_r * self.blk_sz
    h = min(self.blk_sz, self.height - cur_y)
    w = min(self.blk_sz, self.width - cur_x)
    cur_blk = self.cur_yuv[cur_y:cur_y + h, cur_x:cur_x + w, :]
    # self.mf holds floats, but array slice bounds must be integers.
    ref_x = int(cur_x + mv[1])
    ref_y = int(cur_y + mv[0])
    # The whole h x w reference block must fit in the frame. Otherwise the
    # slice would be truncated and would not match cur_blk's shape.
    if 0 <= ref_x <= self.width - w and 0 <= ref_y <= self.height - h:
      ref_blk = self.ref_yuv[ref_y:ref_y + h, ref_x:ref_x + w, :]
    else:
      ref_blk = np.zeros((h, w, 3))
    return metric(cur_blk, ref_blk)

  def distortion(self, metric=MSE):
    """Return the average block distortion of the current motion field.

    Args:
      metric: block distortion metric forwarded to block_dist().
    """
    loss = 0
    for i in range(self.num_row):
      for j in range(self.num_col):
        loss += self.block_dist(i, j, self.mf[i, j], metric)
    return loss / self.num_row / self.num_col

  def motion_field_evaluation(self, ground_truth):
    """Compare the estimated motion field with a ground truth.

    Args:
      ground_truth: object exposing .mf (indexed like self.mf) and .mask
        (None, or per-block flags; blocks whose flag is truthy are skipped).

    Returns:
      The mean Euclidean distance between ground-truth and estimated motion
      vectors over the unmasked blocks. Raises ZeroDivisionError if every
      block is masked.
    """
    loss = 0
    count = 0
    gt = ground_truth.mf
    mask = ground_truth.mask
    for i in range(self.num_row):
      for j in range(self.num_col):
        if mask is not None and mask[i][j]:
          continue
        loss += LA.norm(gt[i, j] - self.mf[i, j])
        count += 1
    return loss / count

  def show(self, ground_truth=None):
    """Render the estimated motion field, optionally beside a ground truth.

    Args:
      ground_truth: optional motion field drawn in a second panel. It is
        passed directly to drawMF, so it must be indexable like self.mf.
        NOTE(review): motion_field_evaluation() instead expects an object
        with .mf/.mask; confirm which form callers pass here.
    """
    cur_mf = drawMF(self.cur_f, self.blk_sz, self.mf)
    if ground_truth is None:
      n_col = 1
    else:
      gt_mf = drawMF(self.cur_f, self.blk_sz, ground_truth)
      n_col = 2
    # Panels are laid out side by side in a single row.
    plt.figure(figsize=(n_col * 10, 10))
    plt.subplot(1, n_col, 1)
    plt.imshow(cur_mf)
    plt.title('Estimated Motion Field')
    if ground_truth is not None:
      plt.subplot(1, n_col, 2)
      plt.imshow(gt_mf)
      plt.title('Ground Truth')
    plt.tight_layout()
    plt.show()
--- /dev/null
+++ b/tools/3D-Reconstruction/MotionEST/Util.py
@@ -1,0 +1,37 @@
+#!/usr/bin/env python
+# coding: utf-8
+import numpy as np
+import numpy.linalg as LA
+import matplotlib.pyplot as plt
+from scipy.ndimage import filters
+from PIL import Image, ImageDraw
+
+
def MSE(blk1, blk2):
  """Return the mean per-pixel Euclidean distance between two blocks.

  Despite the name, this is not a mean *squared* error. For every pixel, the
  L2 norm of the channel-difference vector is taken, and those norms are
  averaged.

  Args:
    blk1: array-like with at least 3 dimensions (channels on axis 2).
    blk2: array-like with the same shape as blk1.

  Returns:
    The mean of the per-pixel norms, as a float.
  """
  # Cast to a signed integer type first so unsigned (e.g. uint8) pixel values
  # do not wrap around when subtracted. The builtin int is what np.int aliased;
  # np.int itself was removed in NumPy 1.24.
  diff = np.array(blk1, dtype=int) - np.array(blk2, dtype=int)
  return np.mean(LA.norm(diff, axis=2))
+
+
def drawMF(img, blk_sz, mf):
  """Overlay a block grid and per-block motion vectors on an image.

  Args:
    img: image to draw on; it is converted to RGBA (presumably a PIL Image).
    blk_sz: block size in pixels.
    mf: motion field where mf[i, j] is (row offset, column offset) for the
      block in block row i, block column j.

  Returns:
    A new RGBA image: img with blue grid lines along the top/left edge of
    each block, and a red line from each block's center to its center
    displaced by the block's motion vector.
  """
  img_rgba = img.convert('RGBA')
  # Transparent layer, so only the drawn lines cover the image when composited.
  mf_layer = Image.new(mode='RGBA', size=img_rgba.size, color=(0, 0, 0, 0))
  draw = ImageDraw.Draw(mf_layer)
  width, height = img_rgba.size
  num_row = height // blk_sz
  num_col = width // blk_sz
  grid_color = (0, 0, 255, 255)
  vector_color = (255, 0, 0, 255)
  # Horizontal grid lines at the top of each block row.
  for i in range(num_row):
    y = i * blk_sz
    draw.line([(0, y), (width, y)], fill=grid_color)
  # Vertical grid lines at the left of each block column.
  for j in range(num_col):
    x = j * blk_sz
    draw.line([(x, 0), (x, height)], fill=grid_color)
  # One vector per block; image x follows the column offset, y the row offset.
  half = 0.5 * blk_sz
  for i in range(num_row):
    for j in range(num_col):
      center = (j * blk_sz + half, i * blk_sz + half)
      head = (center[0] + mf[i, j][1], center[1] + mf[i, j][0])
      draw.line([center, head], fill=vector_color)
  return Image.alpha_composite(img_rgba, mf_layer)