shithub: libvpx

Download patch

ref: 94bedd013ea6200dbbcf21ecf70928d77bf0ffa7
parent: 66bf68697591047c46543ecdb72da61e1d3dd62a
parent: 0134764fa6be4c46bd24b7bcf2fd092321335341
author: Debargha Mukherjee <[email protected]>
date: Fri Oct 9 09:36:47 EDT 2015

Merge "Optimization of 8bit block error for high bitdepth"

--- a/test/vp9_error_block_test.cc
+++ b/test/vp9_error_block_test.cc
@@ -136,7 +136,23 @@
 
 using std::tr1::make_tuple;
 
-#if HAVE_SSE2
+#if CONFIG_USE_X86INC && HAVE_SSE2
+int64_t wrap_vp9_highbd_block_error_8bit_sse2(const tran_low_t *coeff,
+                                              const tran_low_t *dqcoeff,
+                                              intptr_t block_size,
+                                              int64_t *ssz, int bps) {
+  assert(bps == 8);  // 5-arg adapter: bps is checked here, then dropped
+  return vp9_highbd_block_error_8bit_sse2(coeff, dqcoeff, block_size, ssz);
+}
+
+int64_t wrap_vp9_highbd_block_error_8bit_c(const tran_low_t *coeff,
+                                           const tran_low_t *dqcoeff,
+                                           intptr_t block_size,
+                                           int64_t *ssz, int bps) {
+  assert(bps == 8);  // 5-arg adapter: bps is checked here, then dropped
+  return vp9_highbd_block_error_8bit_c(coeff, dqcoeff, block_size, ssz);
+}
+
 INSTANTIATE_TEST_CASE_P(
     SSE2, ErrorBlockTest,
     ::testing::Values(
@@ -145,7 +161,9 @@
         make_tuple(&vp9_highbd_block_error_sse2,
                    &vp9_highbd_block_error_c, VPX_BITS_12),
         make_tuple(&vp9_highbd_block_error_sse2,
-                   &vp9_highbd_block_error_c, VPX_BITS_8)));
+                   &vp9_highbd_block_error_c, VPX_BITS_8),
+        make_tuple(&wrap_vp9_highbd_block_error_8bit_sse2,
+                   &wrap_vp9_highbd_block_error_8bit_c, VPX_BITS_8)));
 #endif  // HAVE_SSE2
 #endif  // CONFIG_VP9_HIGHBITDEPTH
 }  // namespace
--- a/vp9/common/vp9_rtcd_defs.pl
+++ b/vp9/common/vp9_rtcd_defs.pl
@@ -241,11 +241,15 @@
 }
 
 if (vpx_config("CONFIG_VP9_HIGHBITDEPTH") eq "yes") {
-# the transform coefficients are held in 32-bit
-# values, so the assembler code for  vp9_block_error can no longer be used.
   add_proto qw/int64_t vp9_block_error/, "const tran_low_t *coeff, const tran_low_t *dqcoeff, intptr_t block_size, int64_t *ssz";
   specialize qw/vp9_block_error/;
 
+  add_proto qw/int64_t vp9_highbd_block_error/, "const tran_low_t *coeff, const tran_low_t *dqcoeff, intptr_t block_size, int64_t *ssz, int bd";
+  specialize qw/vp9_highbd_block_error/, "$sse2_x86inc";
+
+  add_proto qw/int64_t vp9_highbd_block_error_8bit/, "const tran_low_t *coeff, const tran_low_t *dqcoeff, intptr_t block_size, int64_t *ssz";
+  specialize qw/vp9_highbd_block_error_8bit/, "$sse2_x86inc";
+
   add_proto qw/void vp9_quantize_fp/, "const tran_low_t *coeff_ptr, intptr_t n_coeffs, int skip_block, const int16_t *zbin_ptr, const int16_t *round_ptr, const int16_t *quant_ptr, const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *scan, const int16_t *iscan";
   specialize qw/vp9_quantize_fp/;
 
@@ -319,9 +323,6 @@
 if (vpx_config("CONFIG_VP9_HIGHBITDEPTH") eq "yes") {
 
   # ENCODEMB INVOKE
-
-  add_proto qw/int64_t vp9_highbd_block_error/, "const tran_low_t *coeff, const tran_low_t *dqcoeff, intptr_t block_size, int64_t *ssz, int bd";
-  specialize qw/vp9_highbd_block_error sse2/;
 
   add_proto qw/void vp9_highbd_quantize_fp/, "const tran_low_t *coeff_ptr, intptr_t n_coeffs, int skip_block, const int16_t *zbin_ptr, const int16_t *round_ptr, const int16_t *quant_ptr, const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *scan, const int16_t *iscan";
   specialize qw/vp9_highbd_quantize_fp/;
--- a/vp9/encoder/vp9_rdopt.c
+++ b/vp9/encoder/vp9_rdopt.c
@@ -269,58 +269,99 @@
   *out_dist_sum = dist_sum << 4;
 }
 
-int64_t vp9_block_error_c(const tran_low_t *coeff, const tran_low_t *dqcoeff,
-                          intptr_t block_size, int64_t *ssz) {
+#if CONFIG_VP9_HIGHBITDEPTH
+int64_t vp9_highbd_block_error_c(const tran_low_t *coeff,
+                                 const tran_low_t *dqcoeff,
+                                 intptr_t block_size,
+                                 int64_t *ssz, int bd) {
   int i;
   int64_t error = 0, sqcoeff = 0;
+  int shift = 2 * (bd - 8);  // 0 when bd == 8
+  int rounding = shift > 0 ? 1 << (shift - 1) : 0;  // half of (1 << shift)
 
   for (i = 0; i < block_size; i++) {
-    const int diff = coeff[i] - dqcoeff[i];
+    const int64_t diff = (int64_t)coeff[i] - dqcoeff[i];  // widen, then subtract
     error +=  diff * diff;
-    sqcoeff += coeff[i] * coeff[i];
+    sqcoeff += (int64_t)coeff[i] * (int64_t)coeff[i];
   }
+  assert(error >= 0 && sqcoeff >= 0);  // negative => the sums overflowed
+  error = (error + rounding) >> shift;  // add half, then drop `shift` bits
+  sqcoeff = (sqcoeff + rounding) >> shift;  // same rounding for *ssz
 
   *ssz = sqcoeff;
   return error;
 }
 
-int64_t vp9_block_error_fp_c(const int16_t *coeff, const int16_t *dqcoeff,
-                             int block_size) {
+int64_t vp9_highbd_block_error_8bit_c(const tran_low_t *coeff,
+                                      const tran_low_t *dqcoeff,
+                                      intptr_t block_size,
+                                      int64_t *ssz) {
   int i;
-  int64_t error = 0;
+  int32_t c, d;
+  int64_t error = 0, sqcoeff = 0;
+  int16_t diff;  // deliberately 16-bit? see the note on `diff = d - c` below
 
+  const int32_t hi = 0x00007fff;  // largest int16_t value (32767)
+  const int32_t lo = 0xffff8000;  // -32768 after conversion (two's complement)
+
   for (i = 0; i < block_size; i++) {
-    const int diff = coeff[i] - dqcoeff[i];
+    c = coeff[i];
+    d = dqcoeff[i];
+
+    // Saturate to 16 bits (vp9_highbd_error_sse2.asm packs with packssdw)
+    c = (c > hi) ? hi : ((c < lo) ? lo : c);
+    d = (d > hi) ? hi : ((d < lo) ? lo : d);
+
+    diff = d - c;  // NOTE(review): range is +/-65535, narrowed to int16_t
     error +=  diff * diff;
+    sqcoeff += c * c;  // c is clamped above, so c * c is at most 2^30
   }
+  assert(error >= 0 && sqcoeff >= 0);  // both are sums of squares
 
+  *ssz = sqcoeff;  // second result: sum of squared (clamped) coeff values
   return error;
 }
 
-#if CONFIG_VP9_HIGHBITDEPTH
-int64_t vp9_highbd_block_error_c(const tran_low_t *coeff,
-                                 const tran_low_t *dqcoeff,
-                                 intptr_t block_size,
-                                 int64_t *ssz, int bd) {
+static int64_t vp9_highbd_block_error_dispatch(const tran_low_t *coeff,
+                                               const tran_low_t *dqcoeff,
+                                               intptr_t block_size,
+                                               int64_t *ssz, int bd) {
+  // Only bd == 8 takes the 8-bit kernel; every other depth forwards bd.
+  if (bd != 8) {
+    return vp9_highbd_block_error(coeff, dqcoeff, block_size, ssz, bd);
+  }
+  return vp9_highbd_block_error_8bit(coeff, dqcoeff, block_size, ssz);
+}
+#endif  // CONFIG_VP9_HIGHBITDEPTH
+
+int64_t vp9_block_error_c(const tran_low_t *coeff, const tran_low_t *dqcoeff,
+                          intptr_t block_size, int64_t *ssz) {
   int i;
   int64_t error = 0, sqcoeff = 0;
-  int shift = 2 * (bd - 8);
-  int rounding = shift > 0 ? 1 << (shift - 1) : 0;
 
   for (i = 0; i < block_size; i++) {
-    const int64_t diff = coeff[i] - dqcoeff[i];
+    const int diff = coeff[i] - dqcoeff[i];  // int here; diff * diff is int too
     error +=  diff * diff;
-    sqcoeff += (int64_t)coeff[i] * (int64_t)coeff[i];
+    sqcoeff += coeff[i] * coeff[i];  // NOTE(review): product is not widened
   }
-  assert(error >= 0 && sqcoeff >= 0);
-  error = (error + rounding) >> shift;
-  sqcoeff = (sqcoeff + rounding) >> shift;
 
   *ssz = sqcoeff;
   return error;
 }
-#endif  // CONFIG_VP9_HIGHBITDEPTH
 
+int64_t vp9_block_error_fp_c(const int16_t *coeff, const int16_t *dqcoeff,
+                             int block_size) {
+  // Each squared difference is formed in int and summed into a 64-bit total.
+  int64_t sse = 0;
+  int idx = 0;
+  while (idx < block_size) {
+    const int delta = coeff[idx] - dqcoeff[idx];
+    sse += delta * delta;
+    ++idx;
+  }
+  return sse;
+}
+
 /* The trailing '0' is a terminator which is used inside cost_coeffs() to
  * decide whether to include cost of a trailing EOB node or not (i.e. we
  * can skip this if the last coefficient in this transform block, e.g. the
@@ -430,8 +471,9 @@
   tran_low_t *const dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
 #if CONFIG_VP9_HIGHBITDEPTH
   const int bd = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd : 8;
-  *out_dist = vp9_highbd_block_error(coeff, dqcoeff, 16 << ss_txfrm_size,
-                                     &this_sse, bd) >> shift;
+  *out_dist = vp9_highbd_block_error_dispatch(coeff, dqcoeff,
+                                              16 << ss_txfrm_size,
+                                              &this_sse, bd) >> shift;
 #else
   *out_dist = vp9_block_error(coeff, dqcoeff, 16 << ss_txfrm_size,
                               &this_sse) >> shift;
@@ -831,7 +873,7 @@
             ratey += cost_coeffs(x, 0, block, tempa + idx, templ + idy, TX_4X4,
                                  so->scan, so->neighbors,
                                  cpi->sf.use_fast_coef_costing);
-            distortion += vp9_highbd_block_error(
+            distortion += vp9_highbd_block_error_dispatch(
                 coeff, BLOCK_OFFSET(pd->dqcoeff, block),
                 16, &unused, xd->bd) >> 2;
             if (RDCOST(x->rdmult, x->rddiv, ratey, distortion) >= best_rd)
@@ -929,8 +971,13 @@
           ratey += cost_coeffs(x, 0, block, tempa + idx, templ + idy, TX_4X4,
                              so->scan, so->neighbors,
                              cpi->sf.use_fast_coef_costing);
+#if CONFIG_VP9_HIGHBITDEPTH
+          distortion += vp9_highbd_block_error_8bit(
+              coeff, BLOCK_OFFSET(pd->dqcoeff, block), 16, &unused) >> 2;
+#else
           distortion += vp9_block_error(coeff, BLOCK_OFFSET(pd->dqcoeff, block),
                                         16, &unused) >> 2;
+#endif
           if (RDCOST(x->rdmult, x->rddiv, ratey, distortion) >= best_rd)
             goto next;
           vp9_iht4x4_add(tx_type, BLOCK_OFFSET(pd->dqcoeff, block),
@@ -1368,6 +1415,9 @@
   k = i;
   for (idy = 0; idy < height / 4; ++idy) {
     for (idx = 0; idx < width / 4; ++idx) {
+#if CONFIG_VP9_HIGHBITDEPTH
+      const int bd = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd : 8;
+#endif
       int64_t ssz, rd, rd1, rd2;
       tran_low_t* coeff;
 
@@ -1377,14 +1427,8 @@
                     coeff, 8);
       vp9_regular_quantize_b_4x4(x, 0, k, so->scan, so->iscan);
 #if CONFIG_VP9_HIGHBITDEPTH
-      if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
-        thisdistortion += vp9_highbd_block_error(coeff,
-                                                 BLOCK_OFFSET(pd->dqcoeff, k),
-                                                 16, &ssz, xd->bd);
-      } else {
-        thisdistortion += vp9_block_error(coeff, BLOCK_OFFSET(pd->dqcoeff, k),
-                                          16, &ssz);
-      }
+      thisdistortion += vp9_highbd_block_error_dispatch(
+          coeff, BLOCK_OFFSET(pd->dqcoeff, k), 16, &ssz, bd);
 #else
       thisdistortion += vp9_block_error(coeff, BLOCK_OFFSET(pd->dqcoeff, k),
                                         16, &ssz);
--- /dev/null
+++ b/vp9/encoder/x86/vp9_highbd_error_sse2.asm
@@ -1,0 +1,98 @@
+;
+;  Copyright (c) 2010 The WebM project authors. All Rights Reserved.
+;
+;  Use of this source code is governed by a BSD-style license
+;  that can be found in the LICENSE file in the root of the source
+;  tree. An additional intellectual property rights grant can be found
+;  in the file PATENTS.  All contributing project authors may
+;  be found in the AUTHORS file in the root of the source tree.
+;
+
+%define private_prefix vp9
+
+%include "third_party/x86inc/x86inc.asm"
+
+SECTION .text
+ALIGN 16
+
+;
+; int64_t vp9_highbd_block_error_8bit(int32_t *coeff, int32_t *dqcoeff,
+;                                     intptr_t block_size, int64_t *ssz)
+;
+
+INIT_XMM sse2
+cglobal highbd_block_error_8bit, 3, 3, 8, uqc, dqc, size, ssz
+  pxor      m4, m4                 ; sse accumulator
+  pxor      m6, m6                 ; ssz accumulator
+  pxor      m5, m5                 ; dedicated zero register
+  lea     uqcq, [uqcq+sizeq*4]     ; uqc -> one past the last 32-bit coeff
+  lea     dqcq, [dqcq+sizeq*4]     ; dqc -> one past the last 32-bit dqcoeff
+  neg    sizeq                     ; index now runs from -size up to 0
+
+  ALIGN 16
+
+.loop:
+  mova      m0, [dqcq+sizeq*4]     ; aligned load assumed - TODO confirm
+  packssdw  m0, [dqcq+sizeq*4+mmsize]    ; dqc dwords -> saturated words
+  mova      m2, [uqcq+sizeq*4]
+  packssdw  m2, [uqcq+sizeq*4+mmsize]    ; uqc dwords -> saturated words
+
+  mova      m1, [dqcq+sizeq*4+mmsize*2]
+  packssdw  m1, [dqcq+sizeq*4+mmsize*3]  ; second dqc group
+  mova      m3, [uqcq+sizeq*4+mmsize*2]
+  packssdw  m3, [uqcq+sizeq*4+mmsize*3]  ; second uqc group
+
+  add    sizeq, mmsize             ; 4 vectors = mmsize dwords done; sets ZF
+
+  ; individual errors are max. 15bit+sign, so squares are 30bit, and
+  ; thus the sum of 2 should fit in a 31bit integer (+ unused sign bit)
+
+  psubw     m0, m2                 ; dqc - uqc as 16-bit words (no saturation)
+  pmaddwd   m2, m2                 ; uqc^2, adjacent pairs summed to dwords
+  pmaddwd   m0, m0                 ; diff^2, adjacent pairs summed to dwords
+
+  psubw     m1, m3                 ; same three steps for the second group
+  pmaddwd   m3, m3
+  pmaddwd   m1, m1
+
+  ; accumulate in 64bit
+  punpckldq m7, m0, m5             ; low two dwords zero-extended to qwords
+  punpckhdq m0, m5                 ; high two dwords zero-extended to qwords
+  paddq     m4, m7
+
+  punpckldq m7, m2, m5
+  punpckhdq m2, m5
+  paddq     m6, m7
+
+  punpckldq m7, m1, m5
+  punpckhdq m1, m5
+  paddq     m4, m7
+
+  punpckldq m7, m3, m5
+  punpckhdq m3, m5
+  paddq     m6, m7
+
+  paddq     m4, m0                 ; fold in the high halves widened above
+  paddq     m4, m1
+  paddq     m6, m2
+  paddq     m6, m3
+
+  jnz .loop                        ; tests ZF left by 'add sizeq' above
+
+  ; accumulate horizontally and store in return value
+  movhlps   m5, m4                 ; high qword of m4 -> low qword of m5
+  movhlps   m7, m6
+  paddq     m4, m5                 ; low qword of m4 = total sse
+  paddq     m6, m7                 ; low qword of m6 = total ssz
+
+%if ARCH_X86_64
+  movq    rax, m4                  ; 64-bit return value
+  movq [sszq], m6                  ; NOTE(review): ssz is not in the 3 loaded args
+%else
+  mov     eax, sszm                ; presumably loads the ssz pointer argument
+  pshufd   m5, m4, 0x1             ; dword 1 of m4 -> dword 0 of m5
+  movq  [eax], m6                  ; *ssz = total ssz
+  movd    eax, m4                  ; low 32 bits of the result
+  movd    edx, m5                  ; high 32 bits of the result
+%endif
+  RET
--- a/vp9/vp9cx.mk
+++ b/vp9/vp9cx.mk
@@ -100,7 +100,11 @@
 
 ifeq ($(CONFIG_USE_X86INC),yes)
 VP9_CX_SRCS-$(HAVE_MMX) += encoder/x86/vp9_dct_mmx.asm
+ifeq ($(CONFIG_VP9_HIGHBITDEPTH),yes)
+VP9_CX_SRCS-$(HAVE_SSE2) += encoder/x86/vp9_highbd_error_sse2.asm
+else
 VP9_CX_SRCS-$(HAVE_SSE2) += encoder/x86/vp9_error_sse2.asm
+endif
 endif
 
 ifeq ($(ARCH_X86_64),yes)