shithub: libvpx

Download patch

ref: 406030d1b0278b1373c62e1a70fb0711e2fbde2b
parent: 7d28d12ef34f6cbb6b1e18f3b23b71392fd3ddf5
author: Julia Robson <[email protected]>
date: Mon Sep 28 12:50:39 EDT 2015

Accelerated transform in high bit depth

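When configured with high bitdepth enabled, the 8-bit inverse transforms
stopped using optimised code, which made 8-bit content decode slowly.
Re-enable the SSE2 versions in that configuration (unless
CONFIG_EMULATE_HARDWARE is set) by reading tran_low_t coefficients through
a new load_input_data() helper.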

Change-Id: I67d91f9b212921d5320f949fc0a0d3f32f90c0ea

--- a/vp9/common/vp9_rtcd_defs.pl
+++ b/vp9/common/vp9_rtcd_defs.pl
@@ -85,16 +85,26 @@
 # dct
 #
 if (vpx_config("CONFIG_VP9_HIGHBITDEPTH") eq "yes") {
-  # Note as optimized versions of these functions are added we need to add a check to ensure
-  # that when CONFIG_EMULATE_HARDWARE is on, it defaults to the C versions only.
-  add_proto qw/void vp9_iht4x4_16_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, int tx_type";
-  specialize qw/vp9_iht4x4_16_add/;
+  # Force C versions if CONFIG_EMULATE_HARDWARE is 1
+  if (vpx_config("CONFIG_EMULATE_HARDWARE") eq "yes") {
+    add_proto qw/void vp9_iht4x4_16_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, int tx_type";
+    specialize qw/vp9_iht4x4_16_add/;
 
-  add_proto qw/void vp9_iht8x8_64_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, int tx_type";
-  specialize qw/vp9_iht8x8_64_add/;
+    add_proto qw/void vp9_iht8x8_64_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, int tx_type";
+    specialize qw/vp9_iht8x8_64_add/;
 
-  add_proto qw/void vp9_iht16x16_256_add/, "const tran_low_t *input, uint8_t *output, int pitch, int tx_type";
-  specialize qw/vp9_iht16x16_256_add/;
+    add_proto qw/void vp9_iht16x16_256_add/, "const tran_low_t *input, uint8_t *output, int pitch, int tx_type";
+    specialize qw/vp9_iht16x16_256_add/;
+  } else {
+    add_proto qw/void vp9_iht4x4_16_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, int tx_type";
+    specialize qw/vp9_iht4x4_16_add sse2/;
+
+    add_proto qw/void vp9_iht8x8_64_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, int tx_type";
+    specialize qw/vp9_iht8x8_64_add sse2/;
+
+    add_proto qw/void vp9_iht16x16_256_add/, "const tran_low_t *input, uint8_t *output, int pitch, int tx_type";
+    specialize qw/vp9_iht16x16_256_add sse2/;
+  }
 } else {
   # Force C versions if CONFIG_EMULATE_HARDWARE is 1
   if (vpx_config("CONFIG_EMULATE_HARDWARE") eq "yes") {
--- a/vp9/common/x86/vp9_idct_intrin_sse2.c
+++ b/vp9/common/x86/vp9_idct_intrin_sse2.c
@@ -12,14 +12,14 @@
 #include "vpx_dsp/x86/txfm_common_sse2.h"
 #include "vpx_ports/mem.h"
 
-void vp9_iht4x4_16_add_sse2(const int16_t *input, uint8_t *dest, int stride,
+void vp9_iht4x4_16_add_sse2(const tran_low_t *input, uint8_t *dest, int stride,
                             int tx_type) {
   __m128i in[2];
   const __m128i zero = _mm_setzero_si128();
   const __m128i eight = _mm_set1_epi16(8);
 
-  in[0] = _mm_loadu_si128((const __m128i *)(input));
-  in[1] = _mm_loadu_si128((const __m128i *)(input + 8));
+  in[0] = load_input_data(input);
+  in[1] = load_input_data(input + 8);
 
   switch (tx_type) {
     case 0:  // DCT_DCT
@@ -77,7 +77,7 @@
   }
 }
 
-void vp9_iht8x8_64_add_sse2(const int16_t *input, uint8_t *dest, int stride,
+void vp9_iht8x8_64_add_sse2(const tran_low_t *input, uint8_t *dest, int stride,
                             int tx_type) {
   __m128i in[8];
   const __m128i zero = _mm_setzero_si128();
@@ -84,14 +84,14 @@
   const __m128i final_rounding = _mm_set1_epi16(1 << 4);
 
   // load input data
-  in[0] = _mm_load_si128((const __m128i *)input);
-  in[1] = _mm_load_si128((const __m128i *)(input + 8 * 1));
-  in[2] = _mm_load_si128((const __m128i *)(input + 8 * 2));
-  in[3] = _mm_load_si128((const __m128i *)(input + 8 * 3));
-  in[4] = _mm_load_si128((const __m128i *)(input + 8 * 4));
-  in[5] = _mm_load_si128((const __m128i *)(input + 8 * 5));
-  in[6] = _mm_load_si128((const __m128i *)(input + 8 * 6));
-  in[7] = _mm_load_si128((const __m128i *)(input + 8 * 7));
+  in[0] = load_input_data(input);
+  in[1] = load_input_data(input + 8 * 1);
+  in[2] = load_input_data(input + 8 * 2);
+  in[3] = load_input_data(input + 8 * 3);
+  in[4] = load_input_data(input + 8 * 4);
+  in[5] = load_input_data(input + 8 * 5);
+  in[6] = load_input_data(input + 8 * 6);
+  in[7] = load_input_data(input + 8 * 7);
 
   switch (tx_type) {
     case 0:  // DCT_DCT
@@ -144,8 +144,8 @@
   RECON_AND_STORE(dest + 7 * stride, in[7]);
 }
 
-void vp9_iht16x16_256_add_sse2(const int16_t *input, uint8_t *dest, int stride,
-                               int tx_type) {
+void vp9_iht16x16_256_add_sse2(const tran_low_t *input, uint8_t *dest,
+                               int stride, int tx_type) {
   __m128i in0[16], in1[16];
 
   load_buffer_8x16(input, in0);
--- a/vpx_dsp/vpx_dsp_rtcd_defs.pl
+++ b/vpx_dsp/vpx_dsp_rtcd_defs.pl
@@ -616,39 +616,6 @@
 if (vpx_config("CONFIG_VP9_HIGHBITDEPTH") eq "yes") {
   # Note as optimized versions of these functions are added we need to add a check to ensure
   # that when CONFIG_EMULATE_HARDWARE is on, it defaults to the C versions only.
-  add_proto qw/void vpx_idct4x4_1_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
-  specialize qw/vpx_idct4x4_1_add/;
-
-  add_proto qw/void vpx_idct4x4_16_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
-  specialize qw/vpx_idct4x4_16_add/;
-
-  add_proto qw/void vpx_idct8x8_1_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
-  specialize qw/vpx_idct8x8_1_add/;
-
-  add_proto qw/void vpx_idct8x8_64_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
-  specialize qw/vpx_idct8x8_64_add/;
-
-  add_proto qw/void vpx_idct8x8_12_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
-  specialize qw/vpx_idct8x8_12_add/;
-
-  add_proto qw/void vpx_idct16x16_1_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
-  specialize qw/vpx_idct16x16_1_add/;
-
-  add_proto qw/void vpx_idct16x16_256_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
-  specialize qw/vpx_idct16x16_256_add/;
-
-  add_proto qw/void vpx_idct16x16_10_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
-  specialize qw/vpx_idct16x16_10_add/;
-
-  add_proto qw/void vpx_idct32x32_1024_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
-  specialize qw/vpx_idct32x32_1024_add/;
-
-  add_proto qw/void vpx_idct32x32_34_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
-  specialize qw/vpx_idct32x32_34_add/;
-
-  add_proto qw/void vpx_idct32x32_1_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
-  specialize qw/vpx_idct32x32_1_add/;
-
   add_proto qw/void vpx_iwht4x4_1_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
   specialize qw/vpx_iwht4x4_1_add/;
 
@@ -681,6 +648,39 @@
 
   # Force C versions if CONFIG_EMULATE_HARDWARE is 1
   if (vpx_config("CONFIG_EMULATE_HARDWARE") eq "yes") {
+    add_proto qw/void vpx_idct4x4_16_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct4x4_16_add/;
+
+    add_proto qw/void vpx_idct4x4_1_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct4x4_1_add/;
+
+    add_proto qw/void vpx_idct8x8_64_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct8x8_64_add/;
+
+    add_proto qw/void vpx_idct8x8_12_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct8x8_12_add/;
+
+    add_proto qw/void vpx_idct8x8_1_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct8x8_1_add/;
+
+    add_proto qw/void vpx_idct16x16_256_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct16x16_256_add/;
+
+    add_proto qw/void vpx_idct16x16_10_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct16x16_10_add/;
+
+    add_proto qw/void vpx_idct16x16_1_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct16x16_1_add/;
+
+    add_proto qw/void vpx_idct32x32_1024_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct32x32_1024_add/;
+
+    add_proto qw/void vpx_idct32x32_34_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct32x32_34_add/;
+
+    add_proto qw/void vpx_idct32x32_1_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct32x32_1_add/;
+
     add_proto qw/void vpx_highbd_idct4x4_16_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, int bd";
     specialize qw/vpx_highbd_idct4x4_16_add/;
 
@@ -696,6 +696,39 @@
     add_proto qw/void vpx_highbd_idct16x16_10_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, int bd";
     specialize qw/vpx_highbd_idct16x16_10_add/;
   } else {
+    add_proto qw/void vpx_idct4x4_16_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct4x4_16_add sse2/;
+
+    add_proto qw/void vpx_idct4x4_1_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct4x4_1_add sse2/;
+
+    add_proto qw/void vpx_idct8x8_64_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct8x8_64_add sse2/;
+
+    add_proto qw/void vpx_idct8x8_12_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct8x8_12_add sse2/;
+
+    add_proto qw/void vpx_idct8x8_1_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct8x8_1_add sse2/;
+
+    add_proto qw/void vpx_idct16x16_256_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct16x16_256_add sse2/;
+
+    add_proto qw/void vpx_idct16x16_10_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct16x16_10_add sse2/;
+
+    add_proto qw/void vpx_idct16x16_1_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct16x16_1_add sse2/;
+
+    add_proto qw/void vpx_idct32x32_1024_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct32x32_1024_add sse2/;
+
+    add_proto qw/void vpx_idct32x32_34_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct32x32_34_add sse2/;
+
+    add_proto qw/void vpx_idct32x32_1_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride";
+    specialize qw/vpx_idct32x32_1_add sse2/;
+
     add_proto qw/void vpx_highbd_idct4x4_16_add/, "const tran_low_t *input, uint8_t *dest, int dest_stride, int bd";
     specialize qw/vpx_highbd_idct4x4_16_add sse2/;
 
--- a/vpx_dsp/x86/inv_txfm_sse2.c
+++ b/vpx_dsp/x86/inv_txfm_sse2.c
@@ -21,7 +21,8 @@
   *(int *)(dest) = _mm_cvtsi128_si32(d0); \
 }
 
-void vpx_idct4x4_16_add_sse2(const int16_t *input, uint8_t *dest, int stride) {
+void vpx_idct4x4_16_add_sse2(const tran_low_t *input, uint8_t *dest,
+                             int stride) {
   const __m128i zero = _mm_setzero_si128();
   const __m128i eight = _mm_set1_epi16(8);
   const __m128i cst = _mm_setr_epi16(
@@ -32,8 +33,8 @@
   __m128i input0, input1, input2, input3;
 
   // Rows
-  input0 = _mm_load_si128((const __m128i *)input);
-  input2 = _mm_load_si128((const __m128i *)(input + 8));
+  input0 = load_input_data(input);
+  input2 = load_input_data(input + 8);
 
   // Construct i3, i1, i3, i1, i2, i0, i2, i0
   input0 = _mm_shufflelo_epi16(input0, 0xd8);
@@ -151,7 +152,8 @@
   }
 }
 
-void vpx_idct4x4_1_add_sse2(const int16_t *input, uint8_t *dest, int stride) {
+void vpx_idct4x4_1_add_sse2(const tran_low_t *input, uint8_t *dest,
+                            int stride) {
   __m128i dc_value;
   const __m128i zero = _mm_setzero_si128();
   int a;
@@ -449,7 +451,8 @@
   out7 = _mm_subs_epi16(stp1_0, stp2_7); \
   }
 
-void vpx_idct8x8_64_add_sse2(const int16_t *input, uint8_t *dest, int stride) {
+void vpx_idct8x8_64_add_sse2(const tran_low_t *input, uint8_t *dest,
+                             int stride) {
   const __m128i zero = _mm_setzero_si128();
   const __m128i rounding = _mm_set1_epi32(DCT_CONST_ROUNDING);
   const __m128i final_rounding = _mm_set1_epi16(1 << 4);
@@ -469,14 +472,14 @@
   int i;
 
   // Load input data.
-  in0 = _mm_load_si128((const __m128i *)input);
-  in1 = _mm_load_si128((const __m128i *)(input + 8 * 1));
-  in2 = _mm_load_si128((const __m128i *)(input + 8 * 2));
-  in3 = _mm_load_si128((const __m128i *)(input + 8 * 3));
-  in4 = _mm_load_si128((const __m128i *)(input + 8 * 4));
-  in5 = _mm_load_si128((const __m128i *)(input + 8 * 5));
-  in6 = _mm_load_si128((const __m128i *)(input + 8 * 6));
-  in7 = _mm_load_si128((const __m128i *)(input + 8 * 7));
+  in0 = load_input_data(input);
+  in1 = load_input_data(input + 8 * 1);
+  in2 = load_input_data(input + 8 * 2);
+  in3 = load_input_data(input + 8 * 3);
+  in4 = load_input_data(input + 8 * 4);
+  in5 = load_input_data(input + 8 * 5);
+  in6 = load_input_data(input + 8 * 6);
+  in7 = load_input_data(input + 8 * 7);
 
   // 2-D
   for (i = 0; i < 2; i++) {
@@ -518,7 +521,8 @@
   RECON_AND_STORE(dest + 7 * stride, in7);
 }
 
-void vpx_idct8x8_1_add_sse2(const int16_t *input, uint8_t *dest, int stride) {
+void vpx_idct8x8_1_add_sse2(const tran_low_t *input, uint8_t *dest,
+                            int stride) {
   __m128i dc_value;
   const __m128i zero = _mm_setzero_si128();
   int a;
@@ -792,7 +796,8 @@
   in[7] = _mm_sub_epi16(k__const_0, s1);
 }
 
-void vpx_idct8x8_12_add_sse2(const int16_t *input, uint8_t *dest, int stride) {
+void vpx_idct8x8_12_add_sse2(const tran_low_t *input, uint8_t *dest,
+                             int stride) {
   const __m128i zero = _mm_setzero_si128();
   const __m128i rounding = _mm_set1_epi32(DCT_CONST_ROUNDING);
   const __m128i final_rounding = _mm_set1_epi16(1 << 4);
@@ -812,10 +817,10 @@
   __m128i tmp0, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7;
 
   // Rows. Load 4-row input data.
-  in0 = _mm_load_si128((const __m128i *)input);
-  in1 = _mm_load_si128((const __m128i *)(input + 8 * 1));
-  in2 = _mm_load_si128((const __m128i *)(input + 8 * 2));
-  in3 = _mm_load_si128((const __m128i *)(input + 8 * 3));
+  in0 = load_input_data(input);
+  in1 = load_input_data(input + 8 * 1);
+  in2 = load_input_data(input + 8 * 2);
+  in3 = load_input_data(input + 8 * 3);
 
   // 8x4 Transpose
   TRANSPOSE_8X8_10(in0, in1, in2, in3, in0, in1);
@@ -1169,7 +1174,7 @@
                              stp2_10, stp2_13, stp2_11, stp2_12) \
     }
 
-void vpx_idct16x16_256_add_sse2(const int16_t *input, uint8_t *dest,
+void vpx_idct16x16_256_add_sse2(const tran_low_t *input, uint8_t *dest,
                                 int stride) {
   const __m128i rounding = _mm_set1_epi32(DCT_CONST_ROUNDING);
   const __m128i final_rounding = _mm_set1_epi16(1 << 5);
@@ -1214,22 +1219,22 @@
     // 1-D idct
 
     // Load input data.
-    in[0] = _mm_load_si128((const __m128i *)input);
-    in[8] = _mm_load_si128((const __m128i *)(input + 8 * 1));
-    in[1] = _mm_load_si128((const __m128i *)(input + 8 * 2));
-    in[9] = _mm_load_si128((const __m128i *)(input + 8 * 3));
-    in[2] = _mm_load_si128((const __m128i *)(input + 8 * 4));
-    in[10] = _mm_load_si128((const __m128i *)(input + 8 * 5));
-    in[3] = _mm_load_si128((const __m128i *)(input + 8 * 6));
-    in[11] = _mm_load_si128((const __m128i *)(input + 8 * 7));
-    in[4] = _mm_load_si128((const __m128i *)(input + 8 * 8));
-    in[12] = _mm_load_si128((const __m128i *)(input + 8 * 9));
-    in[5] = _mm_load_si128((const __m128i *)(input + 8 * 10));
-    in[13] = _mm_load_si128((const __m128i *)(input + 8 * 11));
-    in[6] = _mm_load_si128((const __m128i *)(input + 8 * 12));
-    in[14] = _mm_load_si128((const __m128i *)(input + 8 * 13));
-    in[7] = _mm_load_si128((const __m128i *)(input + 8 * 14));
-    in[15] = _mm_load_si128((const __m128i *)(input + 8 * 15));
+    in[0] = load_input_data(input);
+    in[8] = load_input_data(input + 8 * 1);
+    in[1] = load_input_data(input + 8 * 2);
+    in[9] = load_input_data(input + 8 * 3);
+    in[2] = load_input_data(input + 8 * 4);
+    in[10] = load_input_data(input + 8 * 5);
+    in[3] = load_input_data(input + 8 * 6);
+    in[11] = load_input_data(input + 8 * 7);
+    in[4] = load_input_data(input + 8 * 8);
+    in[12] = load_input_data(input + 8 * 9);
+    in[5] = load_input_data(input + 8 * 10);
+    in[13] = load_input_data(input + 8 * 11);
+    in[6] = load_input_data(input + 8 * 12);
+    in[14] = load_input_data(input + 8 * 13);
+    in[7] = load_input_data(input + 8 * 14);
+    in[15] = load_input_data(input + 8 * 15);
 
     array_transpose_8x8(in, in);
     array_transpose_8x8(in + 8, in + 8);
@@ -1294,7 +1299,8 @@
   }
 }
 
-void vpx_idct16x16_1_add_sse2(const int16_t *input, uint8_t *dest, int stride) {
+void vpx_idct16x16_1_add_sse2(const tran_low_t *input, uint8_t *dest,
+                              int stride) {
   __m128i dc_value;
   const __m128i zero = _mm_setzero_si128();
   int a, i;
@@ -2152,7 +2158,7 @@
   iadst16_8col(in1);
 }
 
-void vpx_idct16x16_10_add_sse2(const int16_t *input, uint8_t *dest,
+void vpx_idct16x16_10_add_sse2(const tran_low_t *input, uint8_t *dest,
                                int stride) {
   const __m128i rounding = _mm_set1_epi32(DCT_CONST_ROUNDING);
   const __m128i final_rounding = _mm_set1_epi16(1 << 5);
@@ -2184,10 +2190,10 @@
   int i;
   // First 1-D inverse DCT
   // Load input data.
-  in[0] = _mm_load_si128((const __m128i *)input);
-  in[1] = _mm_load_si128((const __m128i *)(input + 8 * 2));
-  in[2] = _mm_load_si128((const __m128i *)(input + 8 * 4));
-  in[3] = _mm_load_si128((const __m128i *)(input + 8 * 6));
+  in[0] = load_input_data(input);
+  in[1] = load_input_data(input + 8 * 2);
+  in[2] = load_input_data(input + 8 * 4);
+  in[3] = load_input_data(input + 8 * 6);
 
   TRANSPOSE_8X4(in[0], in[1], in[2], in[3], in[0], in[1]);
 
@@ -2391,7 +2397,7 @@
 
 #define LOAD_DQCOEFF(reg, input) \
   {  \
-    reg = _mm_load_si128((const __m128i *) input); \
+    reg = load_input_data(input); \
     input += 8; \
   }  \
 
@@ -3029,7 +3035,7 @@
 }
 
 // Only upper-left 8x8 has non-zero coeff
-void vpx_idct32x32_34_add_sse2(const int16_t *input, uint8_t *dest,
+void vpx_idct32x32_34_add_sse2(const tran_low_t *input, uint8_t *dest,
                                int stride) {
   const __m128i rounding = _mm_set1_epi32(DCT_CONST_ROUNDING);
   const __m128i final_rounding = _mm_set1_epi16(1<<5);
@@ -3081,14 +3087,14 @@
   int i;
 
   // Load input data. Only need to load the top left 8x8 block.
-  in[0] = _mm_load_si128((const __m128i *)input);
-  in[1] = _mm_load_si128((const __m128i *)(input + 32));
-  in[2] = _mm_load_si128((const __m128i *)(input + 64));
-  in[3] = _mm_load_si128((const __m128i *)(input + 96));
-  in[4] = _mm_load_si128((const __m128i *)(input + 128));
-  in[5] = _mm_load_si128((const __m128i *)(input + 160));
-  in[6] = _mm_load_si128((const __m128i *)(input + 192));
-  in[7] = _mm_load_si128((const __m128i *)(input + 224));
+  in[0] = load_input_data(input);
+  in[1] = load_input_data(input + 32);
+  in[2] = load_input_data(input + 64);
+  in[3] = load_input_data(input + 96);
+  in[4] = load_input_data(input + 128);
+  in[5] = load_input_data(input + 160);
+  in[6] = load_input_data(input + 192);
+  in[7] = load_input_data(input + 224);
 
   for (i = 8; i < 32; ++i) {
     in[i] = _mm_setzero_si128();
@@ -3188,7 +3194,7 @@
   }
 }
 
-void vpx_idct32x32_1024_add_sse2(const int16_t *input, uint8_t *dest,
+void vpx_idct32x32_1024_add_sse2(const tran_low_t *input, uint8_t *dest,
                                  int stride) {
   const __m128i rounding = _mm_set1_epi32(DCT_CONST_ROUNDING);
   const __m128i final_rounding = _mm_set1_epi16(1 << 5);
@@ -3464,7 +3470,8 @@
   }
 }
 
-void vpx_idct32x32_1_add_sse2(const int16_t *input, uint8_t *dest, int stride) {
+void vpx_idct32x32_1_add_sse2(const tran_low_t *input, uint8_t *dest,
+                              int stride) {
   __m128i dc_value;
   const __m128i zero = _mm_setzero_si128();
   int a, i;
--- a/vpx_dsp/x86/inv_txfm_sse2.h
+++ b/vpx_dsp/x86/inv_txfm_sse2.h
@@ -15,6 +15,7 @@
 #include "./vpx_config.h"
 #include "vpx/vpx_integer.h"
 #include "vpx_dsp/inv_txfm.h"
+#include "vpx_dsp/x86/txfm_common_sse2.h"
 
 // perform 8x8 transpose
 static INLINE void array_transpose_8x8(__m128i *in, __m128i *res) {
@@ -89,24 +90,35 @@
   res0[15] = tbuf[7];
 }
 
-static INLINE void load_buffer_8x16(const int16_t *input, __m128i *in) {
-  in[0]  = _mm_load_si128((const __m128i *)(input + 0 * 16));
-  in[1]  = _mm_load_si128((const __m128i *)(input + 1 * 16));
-  in[2]  = _mm_load_si128((const __m128i *)(input + 2 * 16));
-  in[3]  = _mm_load_si128((const __m128i *)(input + 3 * 16));
-  in[4]  = _mm_load_si128((const __m128i *)(input + 4 * 16));
-  in[5]  = _mm_load_si128((const __m128i *)(input + 5 * 16));
-  in[6]  = _mm_load_si128((const __m128i *)(input + 6 * 16));
-  in[7]  = _mm_load_si128((const __m128i *)(input + 7 * 16));
+// Load 8 coefficients as 16-bit lanes; lets the 8-bit SSE2 code run for
+// profile 0 in high-bitdepth builds. Non-HBD builds need 16-byte alignment.
+static INLINE __m128i load_input_data(const tran_low_t *data) {
+#if CONFIG_VP9_HIGHBITDEPTH
+  return octa_set_epi16(data[0], data[1], data[2], data[3], data[4], data[5],
+      data[6], data[7]);
+#else
+  return _mm_load_si128((const __m128i *)data);
+#endif
+}
 
-  in[8]  = _mm_load_si128((const __m128i *)(input + 8 * 16));
-  in[9]  = _mm_load_si128((const __m128i *)(input + 9 * 16));
-  in[10]  = _mm_load_si128((const __m128i *)(input + 10 * 16));
-  in[11]  = _mm_load_si128((const __m128i *)(input + 11 * 16));
-  in[12]  = _mm_load_si128((const __m128i *)(input + 12 * 16));
-  in[13]  = _mm_load_si128((const __m128i *)(input + 13 * 16));
-  in[14]  = _mm_load_si128((const __m128i *)(input + 14 * 16));
-  in[15]  = _mm_load_si128((const __m128i *)(input + 15 * 16));
+static INLINE void load_buffer_8x16(const tran_low_t *input, __m128i *in) {
+  in[0]  = load_input_data(input + 0 * 16);
+  in[1]  = load_input_data(input + 1 * 16);
+  in[2]  = load_input_data(input + 2 * 16);
+  in[3]  = load_input_data(input + 3 * 16);
+  in[4]  = load_input_data(input + 4 * 16);
+  in[5]  = load_input_data(input + 5 * 16);
+  in[6]  = load_input_data(input + 6 * 16);
+  in[7]  = load_input_data(input + 7 * 16);
+  // Rows 8-15: 8 coefficients per load, consecutive rows 16 apart.
+  in[8]  = load_input_data(input + 8 * 16);
+  in[9]  = load_input_data(input + 9 * 16);
+  in[10]  = load_input_data(input + 10 * 16);
+  in[11]  = load_input_data(input + 11 * 16);
+  in[12]  = load_input_data(input + 12 * 16);
+  in[13]  = load_input_data(input + 13 * 16);
+  in[14]  = load_input_data(input + 14 * 16);
+  in[15]  = load_input_data(input + 15 * 16);
 }
 
 #define RECON_AND_STORE(dest, in_x) \