shithub: libvpx

Download patch

ref: 35cae7f1b30ec6299127b28c6ff865dce52dae11
parent: e34c7e3f59ad7bf398c457865ca06af958c89f87
parent: aa8f85223b7ed3568914c10dba0cd76d530d3369
author: Debargha Mukherjee <[email protected]>
date: Mon Oct 26 14:03:46 EDT 2015

Merge "Optimize vp9_highbd_block_error_8bit assembly."

--- a/test/vp9_error_block_test.cc
+++ b/test/vp9_error_block_test.cc
@@ -67,12 +67,22 @@
   int64_t ret;
   int64_t ref_ssz;
   int64_t ref_ret;
+  const int msb = bit_depth_ + 8 - 1;
   for (int i = 0; i < kNumIterations; ++i) {
     int err_count = 0;
     block_size = 16 << (i % 9);  // All block sizes from 4x4, 8x4 ..64x64
     for (int j = 0; j < block_size; j++) {
-      coeff[j]   = rnd(2 << 20) - (1 << 20);
-      dqcoeff[j] = rnd(2 << 20) - (1 << 20);
+      // coeff and dqcoeff always have the same sign (or are zero), and the
+      // optimized code may rely on this, so generate test input accordingly.
+      if (rnd(2)) {
+        // Positive number
+        coeff[j]   = rnd(1 << msb);
+        dqcoeff[j] = rnd(1 << msb);
+      } else {
+        // Negative number
+        coeff[j]   = -rnd(1 << msb);
+        dqcoeff[j] = -rnd(1 << msb);
+      }
     }
     ref_ret = ref_error_block_op_(coeff, dqcoeff, block_size, &ref_ssz,
                                   bit_depth_);
@@ -85,7 +95,7 @@
     err_count_total += err_count;
   }
   EXPECT_EQ(0, err_count_total)
-      << "Error: Error Block Test, C output doesn't match SSE2 output. "
+      << "Error: Error Block Test, C output doesn't match optimized output. "
       << "First failed at test case " << first_failure;
 }
 
@@ -100,23 +110,36 @@
   int64_t ret;
   int64_t ref_ssz;
   int64_t ref_ret;
-  int max_val = ((1 << 20) - 1);
+  const int msb = bit_depth_ + 8 - 1;
+  int max_val = ((1 << msb) - 1);
   for (int i = 0; i < kNumIterations; ++i) {
     int err_count = 0;
-    int k = (i / 9) % 5;
+    int k = (i / 9) % 9;
 
     // Change the maximum coeff value, to test different bit boundaries
-    if ( k == 4 && (i % 9) == 0 ) {
+    if ( k == 8 && (i % 9) == 0 ) {
       max_val >>= 1;
     }
     block_size = 16 << (i % 9);  // All block sizes from 4x4, 8x4 ..64x64
     for (int j = 0; j < block_size; j++) {
-      if (k < 4) {  // Test at maximum values
-        coeff[j]   = k % 2 ? max_val : -max_val;
-        dqcoeff[j] = (k >> 1) % 2 ? max_val : -max_val;
+      if (k < 4) {
+        // Test at positive maximum values
+        coeff[j]   = k % 2 ? max_val : 0;
+        dqcoeff[j] = (k >> 1) % 2 ? max_val : 0;
+      } else if (k < 8) {
+        // Test at negative maximum values
+        coeff[j]   = k % 2 ? -max_val : 0;
+        dqcoeff[j] = (k >> 1) % 2 ? -max_val : 0;
       } else {
-        coeff[j]   = rnd(2 << 14) - (1 << 14);
-        dqcoeff[j] = rnd(2 << 14) - (1 << 14);
+        if (rnd(2)) {
+          // Positive number
+          coeff[j]   = rnd(1 << 14);
+          dqcoeff[j] = rnd(1 << 14);
+        } else {
+          // Negative number
+          coeff[j]   = -rnd(1 << 14);
+          dqcoeff[j] = -rnd(1 << 14);
+        }
       }
     }
     ref_ret = ref_error_block_op_(coeff, dqcoeff, block_size, &ref_ssz,
@@ -130,21 +153,13 @@
     err_count_total += err_count;
   }
   EXPECT_EQ(0, err_count_total)
-      << "Error: Error Block Test, C output doesn't match SSE2 output. "
+      << "Error: Error Block Test, C output doesn't match optimized output. "
       << "First failed at test case " << first_failure;
 }
 
 using std::tr1::make_tuple;
 
-#if CONFIG_USE_X86INC && HAVE_SSE2
-int64_t wrap_vp9_highbd_block_error_8bit_sse2(const tran_low_t *coeff,
-                                              const tran_low_t *dqcoeff,
-                                              intptr_t block_size,
-                                              int64_t *ssz, int bps) {
-  assert(bps == 8);
-  return vp9_highbd_block_error_8bit_sse2(coeff, dqcoeff, block_size, ssz);
-}
-
+#if CONFIG_USE_X86INC
 int64_t wrap_vp9_highbd_block_error_8bit_c(const tran_low_t *coeff,
                                            const tran_low_t *dqcoeff,
                                            intptr_t block_size,
@@ -153,6 +168,15 @@
   return vp9_highbd_block_error_8bit_c(coeff, dqcoeff, block_size, ssz);
 }
 
+#if HAVE_SSE2
+int64_t wrap_vp9_highbd_block_error_8bit_sse2(const tran_low_t *coeff,
+                                              const tran_low_t *dqcoeff,
+                                              intptr_t block_size,
+                                              int64_t *ssz, int bps) {
+  assert(bps == 8);
+  return vp9_highbd_block_error_8bit_sse2(coeff, dqcoeff, block_size, ssz);
+}
+
 INSTANTIATE_TEST_CASE_P(
     SSE2, ErrorBlockTest,
     ::testing::Values(
@@ -165,5 +189,23 @@
         make_tuple(&wrap_vp9_highbd_block_error_8bit_sse2,
                    &wrap_vp9_highbd_block_error_8bit_c, VPX_BITS_8)));
 #endif  // HAVE_SSE2
+
+#if HAVE_AVX
+int64_t wrap_vp9_highbd_block_error_8bit_avx(const tran_low_t *coeff,
+                                              const tran_low_t *dqcoeff,
+                                              intptr_t block_size,
+                                              int64_t *ssz, int bps) {
+  assert(bps == 8);
+  return vp9_highbd_block_error_8bit_avx(coeff, dqcoeff, block_size, ssz);
+}
+
+INSTANTIATE_TEST_CASE_P(
+    AVX, ErrorBlockTest,
+    ::testing::Values(
+        make_tuple(&wrap_vp9_highbd_block_error_8bit_avx,
+                   &wrap_vp9_highbd_block_error_8bit_c, VPX_BITS_8)));
+#endif  // HAVE_AVX
+
+#endif  // CONFIG_USE_X86INC
 #endif  // CONFIG_VP9_HIGHBITDEPTH
 }  // namespace
--- a/vp9/common/vp9_rtcd_defs.pl
+++ b/vp9/common/vp9_rtcd_defs.pl
@@ -248,7 +248,7 @@
   specialize qw/vp9_highbd_block_error/, "$sse2_x86inc";
 
   add_proto qw/int64_t vp9_highbd_block_error_8bit/, "const tran_low_t *coeff, const tran_low_t *dqcoeff, intptr_t block_size, int64_t *ssz";
-  specialize qw/vp9_highbd_block_error_8bit/, "$sse2_x86inc";
+  specialize qw/vp9_highbd_block_error_8bit/, "$sse2_x86inc", "$avx_x86inc";
 
   add_proto qw/void vp9_quantize_fp/, "const tran_low_t *coeff_ptr, intptr_t n_coeffs, int skip_block, const int16_t *zbin_ptr, const int16_t *round_ptr, const int16_t *quant_ptr, const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *scan, const int16_t *iscan";
   specialize qw/vp9_quantize_fp/;
--- a/vp9/encoder/vp9_rdopt.c
+++ b/vp9/encoder/vp9_rdopt.c
@@ -296,30 +296,11 @@
                                       const tran_low_t *dqcoeff,
                                       intptr_t block_size,
                                       int64_t *ssz) {
-  int i;
-  int32_t c, d;
-  int64_t error = 0, sqcoeff = 0;
-  int16_t diff;
-
-  const int32_t hi = 0x00007fff;
-  const int32_t lo = 0xffff8000;
-
-  for (i = 0; i < block_size; i++) {
-    c = coeff[i];
-    d = dqcoeff[i];
-
-    // Saturate to 16 bits
-    c = (c > hi) ? hi : ((c < lo) ? lo : c);
-    d = (d > hi) ? hi : ((d < lo) ? lo : d);
-
-    diff = d - c;
-    error +=  diff * diff;
-    sqcoeff += c * c;
-  }
-  assert(error >= 0 && sqcoeff >= 0);
-
-  *ssz = sqcoeff;
-  return error;
+  // Note that the C versions of these 2 functions (vp9_block_error and
+  // vp9_highbd_block_error_8bit) are the same, but the optimized assembly
+  // routines are not compatible in the non high bitdepth configuration, so
+  // they still cannot share the same name.
+  return vp9_block_error_c(coeff, dqcoeff, block_size, ssz);
 }
 
 static int64_t vp9_highbd_block_error_dispatch(const tran_low_t *coeff,
--- /dev/null
+++ b/vp9/encoder/x86/vp9_highbd_error_avx.asm
@@ -1,0 +1,261 @@
+;
+;  Copyright (c) 2015 The WebM project authors. All Rights Reserved.
+;
+;  Use of this source code is governed by a BSD-style license
+;  that can be found in the LICENSE file in the root of the source
+;  tree. An additional intellectual property rights grant can be found
+;  in the file PATENTS.  All contributing project authors may
+;  be found in the AUTHORS file in the root of the source tree.
+;
+
+%define private_prefix vp9
+
+%include "third_party/x86inc/x86inc.asm"
+
+SECTION .text
+ALIGN 16
+
+;
+; int64_t vp9_highbd_block_error_8bit(int32_t *coeff, int32_t *dqcoeff,
+;                                     intptr_t block_size, int64_t *ssz)
+;
+
+INIT_XMM avx
+cglobal highbd_block_error_8bit, 4, 5, 8, uqc, dqc, size, ssz
+  vzeroupper
+
+  ; If only one iteration is required, then handle this as a special case.
+  ; It is the most frequent case, so we can have a significant gain here
+  ; by not setting up a loop and accumulators.
+  cmp    sizeq, 16
+  jne   .generic
+
+  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+  ;; Common case of size == 16
+  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+
+  ; Load input vectors
+  mova      xm0, [dqcq]
+  packssdw  xm0, [dqcq+16]
+  mova      xm2, [uqcq]
+  packssdw  xm2, [uqcq+16]
+
+  mova      xm1, [dqcq+32]
+  packssdw  xm1, [dqcq+48]
+  mova      xm3, [uqcq+32]
+  packssdw  xm3, [uqcq+48]
+
+  ; Compute the errors.
+  psubw     xm0, xm2
+  psubw     xm1, xm3
+
+  ; Individual errors are max 15bit+sign, so squares are 30bit, and
+  ; thus the sum of 2 should fit in a 31bit integer (+ unused sign bit).
+  pmaddwd   xm2, xm2
+  pmaddwd   xm3, xm3
+
+  pmaddwd   xm0, xm0
+  pmaddwd   xm1, xm1
+
+  ; Squares are always positive, so we can use unsigned arithmetic after
+  ; squaring. As mentioned earlier 2 sums fit in 31 bits, so 4 sums will
+  ; fit in 32bits
+  paddd     xm2, xm3
+  paddd     xm0, xm1
+
+  ; Accumulate horizontally in 64 bits, there is no chance of overflow here
+  pxor      xm5, xm5
+
+  pblendw   xm3, xm5, xm2, 0x33 ; Zero extended  low of a pair of 32 bits
+  psrlq     xm2, 32             ; Zero extended high of a pair of 32 bits
+
+  pblendw   xm1, xm5, xm0, 0x33 ; Zero extended  low of a pair of 32 bits
+  psrlq     xm0, 32             ; Zero extended high of a pair of 32 bits
+
+  paddq     xm2, xm3
+  paddq     xm0, xm1
+
+  psrldq    xm3, xm2, 8
+  psrldq    xm1, xm0, 8
+
+  paddq     xm2, xm3
+  paddq     xm0, xm1
+
+  ; Store the return value
+%if ARCH_X86_64
+  movq      rax, xm0
+  movq   [sszq], xm2
+%else
+  movd      eax, xm0
+  pextrd    edx, xm0, 1
+  movq   [sszd], xm2
+%endif
+  RET
+
+  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+  ;; Generic case of size != 16, speculative low precision
+  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+  ALIGN 16
+.generic:
+  pxor      xm4, xm4                ; sse accumulator
+  pxor      xm5, xm5                ; overflow detection register for xm4
+  pxor      xm6, xm6                ; ssz accumulator
+  pxor      xm7, xm7                ; overflow detection register for xm6
+  lea      uqcq, [uqcq+sizeq*4]
+  lea      dqcq, [dqcq+sizeq*4]
+  neg     sizeq
+
+  ; Push the negative size as the high precision code might need it
+  push    sizeq
+
+.loop:
+  ; Load input vectors
+  mova      xm0, [dqcq+sizeq*4]
+  packssdw  xm0, [dqcq+sizeq*4+16]
+  mova      xm2, [uqcq+sizeq*4]
+  packssdw  xm2, [uqcq+sizeq*4+16]
+
+  mova      xm1, [dqcq+sizeq*4+32]
+  packssdw  xm1, [dqcq+sizeq*4+48]
+  mova      xm3, [uqcq+sizeq*4+32]
+  packssdw  xm3, [uqcq+sizeq*4+48]
+
+  add     sizeq, 16
+
+  ; Compute the squared errors.
+  ; Individual errors are max 15bit+sign, so squares are 30bit, and
+  ; thus the sum of 2 should fit in a 31bit integer (+ unused sign bit).
+  psubw     xm0, xm2
+  pmaddwd   xm2, xm2
+  pmaddwd   xm0, xm0
+
+  psubw     xm1, xm3
+  pmaddwd   xm3, xm3
+  pmaddwd   xm1, xm1
+
+  ; Squares are always positive, so we can use unsigned arithmetic after
+  ; squaring. As mentioned earlier 2 sums fit in 31 bits, so 4 sums will
+  ; fit in 32bits
+  paddd     xm2, xm3
+  paddd     xm0, xm1
+
+  ; We accumulate using 32 bit arithmetic, but detect potential overflow
+  ; by checking whether the MSB of any accumulator lane has ever been set.
+  ; If so, we redo the whole computation at the end in higher precision, but
+  ; this happens extremely rarely, so we still achieve a net gain.
+  paddd     xm4, xm0
+  paddd     xm6, xm2
+  por       xm5, xm4  ; OR in the accumulator for overflow detection
+  por       xm7, xm6  ; OR in the accumulator for overflow detection
+
+  jnz .loop
+
+  ; Add pairs horizontally (still only on 32 bits)
+  phaddd    xm4, xm4
+  por       xm5, xm4  ; OR in the accumulator for overflow detection
+  phaddd    xm6, xm6
+  por       xm7, xm6  ; OR in the accumulator for overflow detection
+
+  ; Check for possibility of overflow by testing if bit 31 (the MSB) of any
+  ; dword lane has ever been set. If it was not, then there was no overflow
+  ; and the final sum will fit in 32 bits. If overflow may have happened,
+  ; then we redo the whole computation in higher precision.
+  por       xm7, xm5
+  pmovmskb   r4, xm7
+  test       r4, 0x8888
+  jnz .highprec
+
+  phaddd    xm4, xm4
+  phaddd    xm6, xm6
+  pmovzxdq  xm4, xm4
+  pmovzxdq  xm6, xm6
+
+  ; Restore stack
+  pop     sizeq
+
+  ; Store the return value
+%if ARCH_X86_64
+  movq      rax, xm4
+  movq   [sszq], xm6
+%else
+  movd      eax, xm4
+  pextrd    edx, xm4, 1
+  movq   [sszd], xm6
+%endif
+  RET
+
+  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+  ;; Generic case of size != 16, high precision case
+  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+.highprec:
+  pxor      xm4, xm4                 ; sse accumulator
+  pxor      xm5, xm5                 ; dedicated zero register
+  pxor      xm6, xm6                 ; ssz accumulator
+  pop     sizeq
+
+.loophp:
+  mova      xm0, [dqcq+sizeq*4]
+  packssdw  xm0, [dqcq+sizeq*4+16]
+  mova      xm2, [uqcq+sizeq*4]
+  packssdw  xm2, [uqcq+sizeq*4+16]
+
+  mova      xm1, [dqcq+sizeq*4+32]
+  packssdw  xm1, [dqcq+sizeq*4+48]
+  mova      xm3, [uqcq+sizeq*4+32]
+  packssdw  xm3, [uqcq+sizeq*4+48]
+
+  add     sizeq, 16
+
+  ; individual errors are max. 15bit+sign, so squares are 30bit, and
+  ; thus the sum of 2 should fit in a 31bit integer (+ unused sign bit)
+
+  psubw     xm0, xm2
+  pmaddwd   xm2, xm2
+  pmaddwd   xm0, xm0
+
+  psubw     xm1, xm3
+  pmaddwd   xm3, xm3
+  pmaddwd   xm1, xm1
+
+  ; accumulate in 64bit
+  punpckldq xm7, xm0, xm5
+  punpckhdq xm0, xm5
+  paddq     xm4, xm7
+
+  punpckldq xm7, xm2, xm5
+  punpckhdq xm2, xm5
+  paddq     xm6, xm7
+
+  punpckldq xm7, xm1, xm5
+  punpckhdq xm1, xm5
+  paddq     xm4, xm7
+
+  punpckldq xm7, xm3, xm5
+  punpckhdq xm3, xm5
+  paddq     xm6, xm7
+
+  paddq     xm4, xm0
+  paddq     xm4, xm1
+  paddq     xm6, xm2
+  paddq     xm6, xm3
+
+  jnz .loophp
+
+  ; Accumulate horizontally
+  movhlps   xm5, xm4
+  movhlps   xm7, xm6
+  paddq     xm4, xm5
+  paddq     xm6, xm7
+
+  ; Store the return value
+%if ARCH_X86_64
+  movq      rax, xm4
+  movq   [sszq], xm6
+%else
+  movd      eax, xm4
+  pextrd    edx, xm4, 1
+  movq   [sszd], xm6
+%endif
+  RET
+
+END
--- a/vp9/vp9cx.mk
+++ b/vp9/vp9cx.mk
@@ -102,6 +102,7 @@
 VP9_CX_SRCS-$(HAVE_MMX) += encoder/x86/vp9_dct_mmx.asm
 ifeq ($(CONFIG_VP9_HIGHBITDEPTH),yes)
 VP9_CX_SRCS-$(HAVE_SSE2) += encoder/x86/vp9_highbd_error_sse2.asm
+VP9_CX_SRCS-$(HAVE_AVX) += encoder/x86/vp9_highbd_error_avx.asm
 else
 VP9_CX_SRCS-$(HAVE_SSE2) += encoder/x86/vp9_error_sse2.asm
 endif