shithub: libvpx

ref: d7a2451d48ca3b6a01afb88d775f8d0614211b88
dir: /tools/3D-Reconstruction/MotionEST/MotionEST.py/

View raw version
#!/usr/bin/env python
# coding: utf-8
import numpy as np
import numpy.linalg as LA
import matplotlib.pyplot as plt
from Util import drawMF, MSE
"""The Base Class of Estimators"""


class MotionEST(object):
  """The base class of block-based motion estimators.

  Holds the current/reference frame pair, their YCbCr pixel arrays and a
  per-block motion field `mf` of shape (num_row, num_col, 2).  Each motion
  vector is stored as (row offset, column offset): `block_dist` adds mv[0]
  to the block's y coordinate and mv[1] to its x coordinate.

  Child classes override `motion_field_estimation` to fill in `self.mf`.

  constructor:
      cur_f: current frame
      ref_f: reference frame
      blk_sz: block size
  Both frames must provide `.convert('YCbCr')` and `.size` == (width, height)
  (presumably PIL images -- confirm against callers).
  """

  def __init__(self, cur_f, ref_f, blk_sz):
    self.cur_f = cur_f
    self.ref_f = ref_f
    self.blk_sz = blk_sz
    #convert RGB to YUV
    self.cur_yuv = np.array(self.cur_f.convert('YCbCr'))
    self.ref_yuv = np.array(self.ref_f.convert('YCbCr'))
    #frame size
    self.width = self.cur_f.size[0]
    self.height = self.cur_f.size[1]
    #motion field size (only whole blocks are counted)
    self.num_row = self.height // self.blk_sz
    self.num_col = self.width // self.blk_sz
    #initialize motion field (float array of zero vectors)
    self.mf = np.zeros((self.num_row, self.num_col, 2))

  def motion_field_estimation(self):
    """
    estimation function
        Override by child classes
    """
    pass

  def block_dist(self, cur_r, cur_c, mv, metric=MSE):
    """
    distortion of a block:
        cur_r: current row (block index)
        cur_c: current column (block index)
        mv: motion vector, (row offset, column offset) in pixels
        metric: distortion metric, called as metric(cur_blk, ref_blk)
    returns whatever `metric` returns for the current block and the
    motion-compensated reference block.  When the reference block does
    not lie entirely inside the frame, an all-zero block is used instead.
    """
    cur_x = cur_c * self.blk_sz
    cur_y = cur_r * self.blk_sz
    h = min(self.blk_sz, self.height - cur_y)
    w = min(self.blk_sz, self.width - cur_x)
    cur_blk = self.cur_yuv[cur_y:cur_y + h, cur_x:cur_x + w, :]
    # self.mf is a float array, so vectors read from it are floats; slice
    # bounds must be integers.
    ref_x = int(cur_x + mv[1])
    ref_y = int(cur_y + mv[0])
    # Require the whole h x w block to be inside the frame; checking only the
    # top-left corner would let the slice be truncated at the frame edge and
    # no longer match cur_blk's shape.
    inside_x = 0 <= ref_x and ref_x + w <= self.width
    inside_y = 0 <= ref_y and ref_y + h <= self.height
    if inside_x and inside_y:
      ref_blk = self.ref_yuv[ref_y:ref_y + h, ref_x:ref_x + w, :]
    else:
      ref_blk = np.zeros((h, w, 3))
    # Use the metric passed by the caller (there is no self.metric attribute).
    return metric(cur_blk, ref_blk)

  def distortion(self, metric=MSE):
    """
    distortion of motion field
        metric: distortion metric forwarded to block_dist
    returns the block distortion averaged over all blocks.
    """
    loss = 0
    for i in range(self.num_row):
      for j in range(self.num_col):
        loss += self.block_dist(i, j, self.mf[i, j], metric)
    return loss / self.num_row / self.num_col

  def motion_field_evaluation(self, ground_truth):
    """
    evaluation
        compare the difference with ground truth
        ground_truth: object exposing `.mf` (motion field indexed like
            self.mf) and `.mask` (None, or a per-block table whose truthy
            entries are skipped)
    returns the mean Euclidean distance between ground-truth and estimated
    motion vectors over the blocks that are not skipped, or 0.0 when every
    block is skipped.
    """
    loss = 0
    count = 0
    gt = ground_truth.mf
    mask = ground_truth.mask
    for i in range(self.num_row):
      for j in range(self.num_col):
        if mask is not None and mask[i][j]:
          continue
        loss += LA.norm(gt[i, j] - self.mf[i, j])
        count += 1
    # Nothing was compared (empty field or fully masked): avoid dividing by 0.
    if count == 0:
      return 0.0
    return loss / count

  def show(self, ground_truth=None):
    """
    render the motion field
        ground_truth: optional; handed to drawMF unchanged in place of a
            motion field (presumably an array shaped like self.mf --
            confirm against callers) and drawn in a second panel
    """
    cur_mf = drawMF(self.cur_f, self.blk_sz, self.mf)
    if ground_truth is None:
      n_row = 1
    else:
      gt_mf = drawMF(self.cur_f, self.blk_sz, ground_truth)
      n_row = 2
    # n_row is really the number of side-by-side panels (1 row, n_row columns).
    plt.figure(figsize=(n_row * 10, 10))
    plt.subplot(1, n_row, 1)
    plt.imshow(cur_mf)
    plt.title('Estimated Motion Field')
    if ground_truth is not None:
      plt.subplot(1, n_row, 2)
      plt.imshow(gt_mf)
      plt.title('Ground Truth')
    plt.tight_layout()
    plt.show()