shithub: libvpx

Download patch

ref: aa8f85223b7ed3568914c10dba0cd76d530d3369
parent: 9897e1c27c664b45b81e60a277df3e8186e03c4d
author: Geza Lore <[email protected]>
date: Thu Oct 15 14:28:31 EDT 2015

Optimize vp9_highbd_block_error_8bit assembly.

A new version of vp9_highbd_block_error_8bit is now available which is
optimized with AVX assembly. AVX itself does not buy us too much, but
the non-destructive 3 operand format encoding of the 128bit SSEn integer
instructions helps to eliminate move instructions. The Sandy Bridge
micro-architecture cannot eliminate move instructions in the processor
front end, so AVX will help on these machines.

Two further optimizations are applied:

1. The common case of computing block error on 4x4 blocks is optimized
as a special case.
2. All arithmetic is speculatively done on 32 bits only. At the end of
the loop, the code detects if overflow might have happened and if so,
the whole computation is re-executed using higher precision arithmetic.
This case however is extremely rare in real use, so we can achieve a
large net gain here.

The optimizations rely on the fact that the coefficients are in the
range [-(2^15-1), 2^15-1], and that the quantized coefficients always
have the same sign as the input coefficients (in the worst case they are
0). These are the same assumptions that the old SSE2 assembly code for
the non high bitdepth configuration relied on. The unit tests have been
updated to take this constraint into consideration when generating test
input data.

Change-Id: I57d9888a74715e7145a5d9987d67891ef68f39b7

--- a/test/vp9_error_block_test.cc
+++ b/test/vp9_error_block_test.cc
@@ -67,12 +67,22 @@
   int64_t ret;
   int64_t ref_ssz;
   int64_t ref_ret;
+  const int msb = bit_depth_ + 8 - 1;
   for (int i = 0; i < kNumIterations; ++i) {
     int err_count = 0;
     block_size = 16 << (i % 9);  // All block sizes from 4x4, 8x4 ..64x64
     for (int j = 0; j < block_size; j++) {
-      coeff[j]   = rnd(2 << 20) - (1 << 20);
-      dqcoeff[j] = rnd(2 << 20) - (1 << 20);
+      // coeff and dqcoeff always have the same sign (or are 0), which the
+      // optimized code relies on, so generate test input accordingly.
+      if (rnd(2)) {
+        // Positive number
+        coeff[j]   = rnd(1 << msb);
+        dqcoeff[j] = rnd(1 << msb);
+      } else {
+        // Negative number
+        coeff[j]   = -rnd(1 << msb);
+        dqcoeff[j] = -rnd(1 << msb);
+      }
     }
     ref_ret = ref_error_block_op_(coeff, dqcoeff, block_size, &ref_ssz,
                                   bit_depth_);
@@ -85,7 +95,7 @@
     err_count_total += err_count;
   }
   EXPECT_EQ(0, err_count_total)
-      << "Error: Error Block Test, C output doesn't match SSE2 output. "
+      << "Error: Error Block Test, C output doesn't match optimized output. "
       << "First failed at test case " << first_failure;
 }
 
@@ -100,23 +110,36 @@
   int64_t ret;
   int64_t ref_ssz;
   int64_t ref_ret;
-  int max_val = ((1 << 20) - 1);
+  const int msb = bit_depth_ + 8 - 1;
+  int max_val = ((1 << msb) - 1);
   for (int i = 0; i < kNumIterations; ++i) {
     int err_count = 0;
-    int k = (i / 9) % 5;
+    int k = (i / 9) % 9;
 
     // Change the maximum coeff value, to test different bit boundaries
-    if ( k == 4 && (i % 9) == 0 ) {
+    if ( k == 8 && (i % 9) == 0 ) {
       max_val >>= 1;
     }
     block_size = 16 << (i % 9);  // All block sizes from 4x4, 8x4 ..64x64
     for (int j = 0; j < block_size; j++) {
-      if (k < 4) {  // Test at maximum values
-        coeff[j]   = k % 2 ? max_val : -max_val;
-        dqcoeff[j] = (k >> 1) % 2 ? max_val : -max_val;
+      if (k < 4) {
+        // Test at positive maximum values
+        coeff[j]   = k % 2 ? max_val : 0;
+        dqcoeff[j] = (k >> 1) % 2 ? max_val : 0;
+      } else if (k < 8) {
+        // Test at negative maximum values
+        coeff[j]   = k % 2 ? -max_val : 0;
+        dqcoeff[j] = (k >> 1) % 2 ? -max_val : 0;
       } else {
-        coeff[j]   = rnd(2 << 14) - (1 << 14);
-        dqcoeff[j] = rnd(2 << 14) - (1 << 14);
+        if (rnd(2)) {
+          // Positive number
+          coeff[j]   = rnd(1 << 14);
+          dqcoeff[j] = rnd(1 << 14);
+        } else {
+          // Negative number
+          coeff[j]   = -rnd(1 << 14);
+          dqcoeff[j] = -rnd(1 << 14);
+        }
       }
     }
     ref_ret = ref_error_block_op_(coeff, dqcoeff, block_size, &ref_ssz,
@@ -130,21 +153,13 @@
     err_count_total += err_count;
   }
   EXPECT_EQ(0, err_count_total)
-      << "Error: Error Block Test, C output doesn't match SSE2 output. "
+      << "Error: Error Block Test, C output doesn't match optimized output. "
       << "First failed at test case " << first_failure;
 }
 
 using std::tr1::make_tuple;
 
-#if CONFIG_USE_X86INC && HAVE_SSE2
-int64_t wrap_vp9_highbd_block_error_8bit_sse2(const tran_low_t *coeff,
-                                              const tran_low_t *dqcoeff,
-                                              intptr_t block_size,
-                                              int64_t *ssz, int bps) {
-  assert(bps == 8);
-  return vp9_highbd_block_error_8bit_sse2(coeff, dqcoeff, block_size, ssz);
-}
-
+#if CONFIG_USE_X86INC
 int64_t wrap_vp9_highbd_block_error_8bit_c(const tran_low_t *coeff,
                                            const tran_low_t *dqcoeff,
                                            intptr_t block_size,
@@ -153,6 +168,15 @@
   return vp9_highbd_block_error_8bit_c(coeff, dqcoeff, block_size, ssz);
 }
 
+#if HAVE_SSE2
+int64_t wrap_vp9_highbd_block_error_8bit_sse2(const tran_low_t *coeff,
+                                              const tran_low_t *dqcoeff,
+                                              intptr_t block_size,
+                                              int64_t *ssz, int bps) {
+  assert(bps == 8);  // Adapter to the test signature; 8-bit only.
+  return vp9_highbd_block_error_8bit_sse2(coeff, dqcoeff, block_size, ssz);
+}
+
 INSTANTIATE_TEST_CASE_P(
     SSE2, ErrorBlockTest,
     ::testing::Values(
@@ -165,5 +189,23 @@
         make_tuple(&wrap_vp9_highbd_block_error_8bit_sse2,
                    &wrap_vp9_highbd_block_error_8bit_c, VPX_BITS_8)));
 #endif  // HAVE_SSE2
+
+#if HAVE_AVX
+int64_t wrap_vp9_highbd_block_error_8bit_avx(const tran_low_t *coeff,
+                                             const tran_low_t *dqcoeff,
+                                             intptr_t block_size,
+                                             int64_t *ssz, int bps) {
+  assert(bps == 8);  // Adapter to the test signature; 8-bit only.
+  return vp9_highbd_block_error_8bit_avx(coeff, dqcoeff, block_size, ssz);
+}
+
+INSTANTIATE_TEST_CASE_P(
+    AVX, ErrorBlockTest,  // AVX kernel vs. the C reference, 8-bit input.
+    ::testing::Values(
+        make_tuple(&wrap_vp9_highbd_block_error_8bit_avx,
+                   &wrap_vp9_highbd_block_error_8bit_c, VPX_BITS_8)));
+#endif  // HAVE_AVX
+
+#endif  // CONFIG_USE_X86INC
 #endif  // CONFIG_VP9_HIGHBITDEPTH
 }  // namespace
--- a/vp9/common/vp9_rtcd_defs.pl
+++ b/vp9/common/vp9_rtcd_defs.pl
@@ -248,7 +248,7 @@
   specialize qw/vp9_highbd_block_error/, "$sse2_x86inc";
 
   add_proto qw/int64_t vp9_highbd_block_error_8bit/, "const tran_low_t *coeff, const tran_low_t *dqcoeff, intptr_t block_size, int64_t *ssz";
-  specialize qw/vp9_highbd_block_error_8bit/, "$sse2_x86inc";
+  specialize qw/vp9_highbd_block_error_8bit/, "$sse2_x86inc", "$avx_x86inc";
 
   add_proto qw/void vp9_quantize_fp/, "const tran_low_t *coeff_ptr, intptr_t n_coeffs, int skip_block, const int16_t *zbin_ptr, const int16_t *round_ptr, const int16_t *quant_ptr, const int16_t *quant_shift_ptr, tran_low_t *qcoeff_ptr, tran_low_t *dqcoeff_ptr, const int16_t *dequant_ptr, uint16_t *eob_ptr, const int16_t *scan, const int16_t *iscan";
   specialize qw/vp9_quantize_fp/;
--- a/vp9/encoder/vp9_rdopt.c
+++ b/vp9/encoder/vp9_rdopt.c
@@ -296,30 +296,11 @@
                                       const tran_low_t *dqcoeff,
                                       intptr_t block_size,
                                       int64_t *ssz) {
-  int i;
-  int32_t c, d;
-  int64_t error = 0, sqcoeff = 0;
-  int16_t diff;
-
-  const int32_t hi = 0x00007fff;
-  const int32_t lo = 0xffff8000;
-
-  for (i = 0; i < block_size; i++) {
-    c = coeff[i];
-    d = dqcoeff[i];
-
-    // Saturate to 16 bits
-    c = (c > hi) ? hi : ((c < lo) ? lo : c);
-    d = (d > hi) ? hi : ((d < lo) ? lo : d);
-
-    diff = d - c;
-    error +=  diff * diff;
-    sqcoeff += c * c;
-  }
-  assert(error >= 0 && sqcoeff >= 0);
-
-  *ssz = sqcoeff;
-  return error;
+  // Note that the C versions of these 2 functions (vp9_block_error and
+  // vp9_highbd_block_error_8bit) are the same, but the optimized assembly
+  // routines are not compatible in the non high bitdepth configuration, so
+  // they still cannot share the same name.
+  return vp9_block_error_c(coeff, dqcoeff, block_size, ssz);
 }
 
 static int64_t vp9_highbd_block_error_dispatch(const tran_low_t *coeff,
--- /dev/null
+++ b/vp9/encoder/x86/vp9_highbd_error_avx.asm
@@ -1,0 +1,280 @@
+;
+;  Copyright (c) 2015 The WebM project authors. All Rights Reserved.
+;
+;  Use of this source code is governed by a BSD-style license
+;  that can be found in the LICENSE file in the root of the source
+;  tree. An additional intellectual property rights grant can be found
+;  in the file PATENTS.  All contributing project authors may
+;  be found in the AUTHORS file in the root of the source tree.
+;
+
+%define private_prefix vp9
+
+%include "third_party/x86inc/x86inc.asm"
+
+SECTION .text
+ALIGN 16
+
+;
+; int64_t vp9_highbd_block_error_8bit(int32_t *coeff, int32_t *dqcoeff,
+;                                     intptr_t block_size, int64_t *ssz)
+;
+; Returns the sum of squared differences between coeff and dqcoeff, and
+; stores the sum of squares of coeff in *ssz.
+;
+; Assumptions (not checked here):
+;  - block_size is a positive multiple of 16 (16 coefficients are consumed
+;    per loop iteration).
+;  - coeff and dqcoeff are 16-byte aligned (mova is used for the loads).
+;  - All coefficients are in [-(2^15-1), 2^15-1] (packssdw saturates to
+;    16 bits), and each dqcoeff[i] is either 0 or has the same sign as
+;    coeff[i], so every individual error fits in 15 bits + sign.
+;
+
+INIT_XMM avx
+cglobal highbd_block_error_8bit, 4, 5, 8, uqc, dqc, size, ssz
+  vzeroupper
+
+  ; If only one iteration is required, then handle this as a special case.
+  ; It is the most frequent case, so we can have a significant gain here
+  ; by not setting up a loop and accumulators.
+  cmp    sizeq, 16
+  jne   .generic
+
+  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+  ;; Common case of size == 16
+  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+
+  ; Load input vectors
+  mova      xm0, [dqcq]
+  packssdw  xm0, [dqcq+16]
+  mova      xm2, [uqcq]
+  packssdw  xm2, [uqcq+16]
+
+  mova      xm1, [dqcq+32]
+  packssdw  xm1, [dqcq+48]
+  mova      xm3, [uqcq+32]
+  packssdw  xm3, [uqcq+48]
+
+  ; Compute the errors.
+  psubw     xm0, xm2
+  psubw     xm1, xm3
+
+  ; Individual errors are max 15bit+sign, so squares are 30bit, and
+  ; thus the sum of 2 should fit in a 31bit integer (+ unused sign bit).
+  pmaddwd   xm2, xm2
+  pmaddwd   xm3, xm3
+
+  pmaddwd   xm0, xm0
+  pmaddwd   xm1, xm1
+
+  ; Squares are always positive, so we can use unsigned arithmetic after
+  ; squaring. As mentioned earlier 2 sums fit in 31 bits, so 4 sums will
+  ; fit in 32bits
+  paddd     xm2, xm3
+  paddd     xm0, xm1
+
+  ; Accumulate horizontally in 64 bits, there is no chance of overflow here
+  pxor      xm5, xm5
+
+  pblendw   xm3, xm5, xm2, 0x33 ; Zero extended  low of a pair of 32 bits
+  psrlq     xm2, 32             ; Zero extended high of a pair of 32 bits
+
+  pblendw   xm1, xm5, xm0, 0x33 ; Zero extended  low of a pair of 32 bits
+  psrlq     xm0, 32             ; Zero extended high of a pair of 32 bits
+
+  paddq     xm2, xm3
+  paddq     xm0, xm1
+
+  psrldq    xm3, xm2, 8
+  psrldq    xm1, xm0, 8
+
+  paddq     xm2, xm3
+  paddq     xm0, xm1
+
+  ; Store the return value
+%if ARCH_X86_64
+  movq      rax, xm0
+  movq   [sszq], xm2
+%else
+  movd      eax, xm0
+  pextrd    edx, xm0, 1
+  movq   [sszd], xm2
+%endif
+  RET
+
+  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+  ;; Generic case of size != 16, speculative low precision
+  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+  ALIGN 16
+.generic:
+  pxor      xm4, xm4                ; sse accumulator
+  pxor      xm5, xm5                ; overflow detection register for xm4
+  pxor      xm6, xm6                ; ssz accumulator
+  pxor      xm7, xm7                ; overflow detection register for xm6
+  lea      uqcq, [uqcq+sizeq*4]
+  lea      dqcq, [dqcq+sizeq*4]
+  neg     sizeq
+
+  ; Push the negative size as the high precision code might need it
+  push    sizeq
+
+.loop:
+  ; Load input vectors
+  mova      xm0, [dqcq+sizeq*4]
+  packssdw  xm0, [dqcq+sizeq*4+16]
+  mova      xm2, [uqcq+sizeq*4]
+  packssdw  xm2, [uqcq+sizeq*4+16]
+
+  mova      xm1, [dqcq+sizeq*4+32]
+  packssdw  xm1, [dqcq+sizeq*4+48]
+  mova      xm3, [uqcq+sizeq*4+32]
+  packssdw  xm3, [uqcq+sizeq*4+48]
+
+  add     sizeq, 16
+
+  ; Compute the squared errors.
+  ; Individual errors are max 15bit+sign, so squares are 30bit, and
+  ; thus the sum of 2 should fit in a 31bit integer (+ unused sign bit).
+  psubw     xm0, xm2
+  pmaddwd   xm2, xm2
+  pmaddwd   xm0, xm0
+
+  psubw     xm1, xm3
+  pmaddwd   xm3, xm3
+  pmaddwd   xm1, xm1
+
+  ; Squares are always positive, so we can use unsigned arithmetic after
+  ; squaring. As mentioned earlier 2 sums fit in 31 bits, so 4 sums will
+  ; fit in 32bits
+  paddd     xm2, xm3
+  paddd     xm0, xm1
+
+  ; We accumulate using 32 bit arithmetic, but detect potential overflow
+  ; by checking if the MSB of the accumulators or of the per-iteration
+  ; addends has ever been a set bit. A single addend is a sum of 4 squares
+  ; and can be as large as ~2^32, so an accumulator could wrap around and
+  ; end up with its MSB clear; checking the addends too guarantees that,
+  ; while no MSB has been seen, every addition is of two values below 2^31
+  ; and so cannot wrap. If an MSB was seen, we redo the whole compute at
+  ; the end on higher precision, but this happens extremely rarely, so we
+  ; still achieve a net gain.
+  por       xm5, xm0  ; OR in the addend for overflow detection
+  por       xm7, xm2  ; OR in the addend for overflow detection
+  paddd     xm4, xm0
+  paddd     xm6, xm2
+  por       xm5, xm4  ; OR in the accumulator for overflow detection
+  por       xm7, xm6  ; OR in the accumulator for overflow detection
+
+  jnz .loop
+
+  ; Add pairs horizontally (still only on 32 bits)
+  phaddd    xm4, xm4
+  por       xm5, xm4  ; OR in the accumulator for overflow detection
+  phaddd    xm6, xm6
+  por       xm7, xm6  ; OR in the accumulator for overflow detection
+
+  ; Check for possibility of overflow by testing if bit 31 (the MSB) of any
+  ; dword lane has ever been set: pmovmskb gathers the MSB of every byte,
+  ; and bits 3, 7, 11 and 15 of that mask (0x8888) are the dword MSBs.
+  ; If none was set, there was no overflow and the final sum fits in 32
+  ; bits. Otherwise we redo the whole computation on higher precision.
+  por       xm7, xm5
+  pmovmskb   r4, xm7
+  test       r4, 0x8888
+  jnz .highprec
+
+  phaddd    xm4, xm4
+  phaddd    xm6, xm6
+  pmovzxdq  xm4, xm4
+  pmovzxdq  xm6, xm6
+
+  ; Restore stack
+  pop     sizeq
+
+  ; Store the return value
+%if ARCH_X86_64
+  movq      rax, xm4
+  movq   [sszq], xm6
+%else
+  movd      eax, xm4
+  pextrd    edx, xm4, 1
+  movq   [sszd], xm6
+%endif
+  RET
+
+  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+  ;; Generic case of size != 16, high precision case
+  ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
+.highprec:
+  pxor      xm4, xm4                 ; sse accumulator
+  pxor      xm5, xm5                 ; dedicated zero register
+  pxor      xm6, xm6                 ; ssz accumulator
+  pop     sizeq
+
+.loophp:
+  mova      xm0, [dqcq+sizeq*4]
+  packssdw  xm0, [dqcq+sizeq*4+16]
+  mova      xm2, [uqcq+sizeq*4]
+  packssdw  xm2, [uqcq+sizeq*4+16]
+
+  mova      xm1, [dqcq+sizeq*4+32]
+  packssdw  xm1, [dqcq+sizeq*4+48]
+  mova      xm3, [uqcq+sizeq*4+32]
+  packssdw  xm3, [uqcq+sizeq*4+48]
+
+  add     sizeq, 16
+
+  ; individual errors are max. 15bit+sign, so squares are 30bit, and
+  ; thus the sum of 2 should fit in a 31bit integer (+ unused sign bit)
+
+  psubw     xm0, xm2
+  pmaddwd   xm2, xm2
+  pmaddwd   xm0, xm0
+
+  psubw     xm1, xm3
+  pmaddwd   xm3, xm3
+  pmaddwd   xm1, xm1
+
+  ; accumulate in 64bit
+  punpckldq xm7, xm0, xm5
+  punpckhdq xm0, xm5
+  paddq     xm4, xm7
+
+  punpckldq xm7, xm2, xm5
+  punpckhdq xm2, xm5
+  paddq     xm6, xm7
+
+  punpckldq xm7, xm1, xm5
+  punpckhdq xm1, xm5
+  paddq     xm4, xm7
+
+  punpckldq xm7, xm3, xm5
+  punpckhdq xm3, xm5
+  paddq     xm6, xm7
+
+  paddq     xm4, xm0
+  paddq     xm4, xm1
+  paddq     xm6, xm2
+  paddq     xm6, xm3
+
+  jnz .loophp
+
+  ; Accumulate horizontally
+  movhlps   xm5, xm4
+  movhlps   xm7, xm6
+  paddq     xm4, xm5
+  paddq     xm6, xm7
+
+  ; Store the return value
+%if ARCH_X86_64
+  movq      rax, xm4
+  movq   [sszq], xm6
+%else
+  movd      eax, xm4
+  pextrd    edx, xm4, 1
+  movq   [sszd], xm6
+%endif
+  RET
+
+END
--- a/vp9/vp9cx.mk
+++ b/vp9/vp9cx.mk
@@ -102,6 +102,7 @@
 VP9_CX_SRCS-$(HAVE_MMX) += encoder/x86/vp9_dct_mmx.asm
 ifeq ($(CONFIG_VP9_HIGHBITDEPTH),yes)
 VP9_CX_SRCS-$(HAVE_SSE2) += encoder/x86/vp9_highbd_error_sse2.asm
+VP9_CX_SRCS-$(HAVE_AVX) += encoder/x86/vp9_highbd_error_avx.asm
 else
 VP9_CX_SRCS-$(HAVE_SSE2) += encoder/x86/vp9_error_sse2.asm
 endif