shithub: libvpx

Download patch

ref: f1d090e2f56e28b3218fb99f648bd5d59ae2f4b1
parent: 697a8e6fe606d67e84164de46e0dbb198531f613
parent: d76e5b3652fc7022141171cb7bb123607559104f
author: Hui Su <[email protected]>
date: Mon Aug 24 15:57:22 EDT 2015

Merge "Refactoring on transform types"

--- a/vp10/common/idct.c
+++ b/vp10/common/idct.c
@@ -178,33 +178,76 @@
     vpx_idct32x32_1024_add(input, dest, stride);
 }
 
-// iht
-void vp10_iht4x4_add(TX_TYPE tx_type, const tran_low_t *input, uint8_t *dest,
-                    int stride, int eob) {
-  if (tx_type == DCT_DCT)
-    vp10_idct4x4_add(input, dest, stride, eob);
-  else
-    vp10_iht4x4_16_add(input, dest, stride, tx_type);
+void vp10_inv_txfm_add_4x4(const tran_low_t *input, uint8_t *dest,
+                           int stride, int eob, TX_TYPE tx_type,
+                           void (*itxm_add_4x4)(const tran_low_t *input,
+                               uint8_t *dest, int stride, int eob)) {
+  switch (tx_type) {  // select the 4x4 inverse transform by tx_type
+    case DCT_DCT:  // decoder passes vp10_iwht4x4_add or vp10_idct4x4_add
+      itxm_add_4x4(input, dest, stride, eob);
+      break;
+    case ADST_DCT:  // hybrid types bypass itxm_add_4x4 and ignore eob.
+    case DCT_ADST:  // NOTE(review): presumably not reached when the
+    case ADST_ADST:  // caller passed the lossless iwht -- confirm.
+      vp10_iht4x4_16_add(input, dest, stride, tx_type);
+      break;
+    default:  // under NDEBUG nothing is written to dest
+      assert(0 && "Invalid transform type");
+      break;
+  }
 }
 
-void vp10_iht8x8_add(TX_TYPE tx_type, const tran_low_t *input, uint8_t *dest,
-                    int stride, int eob) {
-  if (tx_type == DCT_DCT) {
-    vp10_idct8x8_add(input, dest, stride, eob);
-  } else {
-    vp10_iht8x8_64_add(input, dest, stride, tx_type);
+void vp10_inv_txfm_add_8x8(const tran_low_t *input, uint8_t *dest,
+                           int stride, int eob, TX_TYPE tx_type) {
+  switch (tx_type) {  // select the 8x8 inverse transform by tx_type
+    case DCT_DCT:
+      vp10_idct8x8_add(input, dest, stride, eob);
+      break;
+    case ADST_DCT:
+    case DCT_ADST:
+    case ADST_ADST:  // hybrid types: eob is not forwarded
+      vp10_iht8x8_64_add(input, dest, stride, tx_type);
+      break;
+    default:  // under NDEBUG nothing is written to dest
+      assert(0 && "Invalid transform type");
+      break;
   }
 }
 
-void vp10_iht16x16_add(TX_TYPE tx_type, const tran_low_t *input, uint8_t *dest,
-                      int stride, int eob) {
-  if (tx_type == DCT_DCT) {
-    vp10_idct16x16_add(input, dest, stride, eob);
-  } else {
-    vp10_iht16x16_256_add(input, dest, stride, tx_type);
+void vp10_inv_txfm_add_16x16(const tran_low_t *input, uint8_t *dest,
+                             int stride, int eob, TX_TYPE tx_type) {
+  switch (tx_type) {  // select the 16x16 inverse transform by tx_type
+    case DCT_DCT:
+      vp10_idct16x16_add(input, dest, stride, eob);
+      break;
+    case ADST_DCT:
+    case DCT_ADST:
+    case ADST_ADST:  // hybrid types: eob is not forwarded
+      vp10_iht16x16_256_add(input, dest, stride, tx_type);
+      break;
+    default:  // under NDEBUG nothing is written to dest
+      assert(0 && "Invalid transform type");
+      break;
   }
 }
 
+void vp10_inv_txfm_add_32x32(const tran_low_t *input, uint8_t *dest,
+                             int stride, int eob, TX_TYPE tx_type) {
+  switch (tx_type) {  // 32x32 inverse transform: only DCT_DCT is handled
+    case DCT_DCT:
+      vp10_idct32x32_add(input, dest, stride, eob);
+      break;
+    case ADST_DCT:
+    case DCT_ADST:
+    case ADST_ADST:  // known types, but no 32x32 routine is called for them
+      assert(0 && "32x32 supports only DCT_DCT");
+      break;
+    default:  // under NDEBUG nothing is written to dest
+      assert(0 && "Invalid transform type");
+      break;
+  }
+}
+
 #if CONFIG_VP9_HIGHBITDEPTH
 void vp10_highbd_iht4x4_16_add_c(const tran_low_t *input, uint8_t *dest8,
                                 int stride, int tx_type, int bd) {
@@ -373,30 +416,77 @@
   }
 }
 
-// iht
-void vp10_highbd_iht4x4_add(TX_TYPE tx_type, const tran_low_t *input,
-                           uint8_t *dest, int stride, int eob, int bd) {
-  if (tx_type == DCT_DCT)
-    vp10_highbd_idct4x4_add(input, dest, stride, eob, bd);
-  else
-    vp10_highbd_iht4x4_16_add(input, dest, stride, tx_type, bd);
+void vp10_highbd_inv_txfm_add_4x4(const tran_low_t *input, uint8_t *dest,
+                                  int stride, int eob, int bd, TX_TYPE tx_type,
+                                  void (*highbd_itxm_add_4x4)
+                                  (const tran_low_t *input, uint8_t *dest,
+                                      int stride, int eob, int bd)) {
+  switch (tx_type) {  // select the high-bitdepth 4x4 inverse transform
+    case DCT_DCT:  // caller-selected routine, e.g. iwht when xd->lossless
+      highbd_itxm_add_4x4(input, dest, stride, eob, bd);
+      break;
+    case ADST_DCT:  // hybrid types bypass highbd_itxm_add_4x4, ignore eob.
+    case DCT_ADST:  // NOTE(review): presumably not reached when the
+    case ADST_ADST:  // caller passed the lossless iwht -- confirm.
+      vp10_highbd_iht4x4_16_add(input, dest, stride, tx_type, bd);
+      break;
+    default:  // under NDEBUG nothing is written to dest
+      assert(0 && "Invalid transform type");
+      break;
+  }
 }
 
-void vp10_highbd_iht8x8_add(TX_TYPE tx_type, const tran_low_t *input,
-                           uint8_t *dest, int stride, int eob, int bd) {
-  if (tx_type == DCT_DCT) {
-    vp10_highbd_idct8x8_add(input, dest, stride, eob, bd);
-  } else {
-    vp10_highbd_iht8x8_64_add(input, dest, stride, tx_type, bd);
+void vp10_highbd_inv_txfm_add_8x8(const tran_low_t *input, uint8_t *dest,
+                                  int stride, int eob, int bd,
+                                  TX_TYPE tx_type) {
+  switch (tx_type) {  // select the high-bitdepth 8x8 inverse transform
+    case DCT_DCT:
+      vp10_highbd_idct8x8_add(input, dest, stride, eob, bd);
+      break;
+    case ADST_DCT:
+    case DCT_ADST:
+    case ADST_ADST:  // hybrid types: eob is not forwarded
+      vp10_highbd_iht8x8_64_add(input, dest, stride, tx_type, bd);
+      break;
+    default:  // under NDEBUG nothing is written to dest
+      assert(0 && "Invalid transform type");
+      break;
   }
 }
 
-void vp10_highbd_iht16x16_add(TX_TYPE tx_type, const tran_low_t *input,
-                           uint8_t *dest, int stride, int eob, int bd) {
-  if (tx_type == DCT_DCT) {
-    vp10_highbd_idct16x16_add(input, dest, stride, eob, bd);
-  } else {
-    vp10_highbd_iht16x16_256_add(input, dest, stride, tx_type, bd);
+void vp10_highbd_inv_txfm_add_16x16(const tran_low_t *input, uint8_t *dest,
+                                    int stride, int eob, int bd,
+                                    TX_TYPE tx_type) {
+  switch (tx_type) {  // select the high-bitdepth 16x16 inverse transform
+    case DCT_DCT:
+      vp10_highbd_idct16x16_add(input, dest, stride, eob, bd);
+      break;
+    case ADST_DCT:
+    case DCT_ADST:
+    case ADST_ADST:  // hybrid types: eob is not forwarded
+      vp10_highbd_iht16x16_256_add(input, dest, stride, tx_type, bd);
+      break;
+    default:  // under NDEBUG nothing is written to dest
+      assert(0 && "Invalid transform type");
+      break;
+  }
+}
+
+void vp10_highbd_inv_txfm_add_32x32(const tran_low_t *input, uint8_t *dest,
+                                    int stride, int eob, int bd,
+                                    TX_TYPE tx_type) {
+  switch (tx_type) {  // high-bitdepth 32x32: only DCT_DCT is handled
+    case DCT_DCT:
+      vp10_highbd_idct32x32_add(input, dest, stride, eob, bd);
+      break;
+    case ADST_DCT:
+    case DCT_ADST:
+    case ADST_ADST:  // known types, but no 32x32 routine is called for them
+      assert(0 && "32x32 supports only DCT_DCT");
+      break;
+    default:  // under NDEBUG nothing is written to dest
+      assert(0 && "Invalid transform type");
+      break;
   }
 }
 #endif  // CONFIG_VP9_HIGHBITDEPTH
--- a/vp10/common/idct.h
+++ b/vp10/common/idct.h
@@ -42,19 +42,17 @@
                      int eob);
 void vp10_idct4x4_add(const tran_low_t *input, uint8_t *dest, int stride,
                      int eob);
-void vp10_idct8x8_add(const tran_low_t *input, uint8_t *dest, int stride,
-                     int eob);
-void vp10_idct16x16_add(const tran_low_t *input, uint8_t *dest, int stride,
-                       int eob);
-void vp10_idct32x32_add(const tran_low_t *input, uint8_t *dest, int stride,
-                       int eob);
 
-void vp10_iht4x4_add(TX_TYPE tx_type, const tran_low_t *input, uint8_t *dest,
-                    int stride, int eob);
-void vp10_iht8x8_add(TX_TYPE tx_type, const tran_low_t *input, uint8_t *dest,
-                    int stride, int eob);
-void vp10_iht16x16_add(TX_TYPE tx_type, const tran_low_t *input, uint8_t *dest,
-                      int stride, int eob);
+void vp10_inv_txfm_add_4x4(const tran_low_t *input, uint8_t *dest,
+                           int stride, int eob, TX_TYPE tx_type,
+                           void (*itxm_add_4x4)(const tran_low_t *input,
+                               uint8_t *dest, int stride, int eob));
+void vp10_inv_txfm_add_8x8(const tran_low_t *input, uint8_t *dest,
+                           int stride, int eob, TX_TYPE tx_type);
+void vp10_inv_txfm_add_16x16(const tran_low_t *input, uint8_t *dest,
+                             int stride, int eob, TX_TYPE tx_type);
+void vp10_inv_txfm_add_32x32(const tran_low_t *input, uint8_t *dest,
+                             int stride, int eob, TX_TYPE tx_type);
 
 #if CONFIG_VP9_HIGHBITDEPTH
 void vp10_highbd_iwht4x4_add(const tran_low_t *input, uint8_t *dest, int stride,
@@ -67,12 +65,19 @@
                               int stride, int eob, int bd);
 void vp10_highbd_idct32x32_add(const tran_low_t *input, uint8_t *dest,
                               int stride, int eob, int bd);
-void vp10_highbd_iht4x4_add(TX_TYPE tx_type, const tran_low_t *input,
-                           uint8_t *dest, int stride, int eob, int bd);
-void vp10_highbd_iht8x8_add(TX_TYPE tx_type, const tran_low_t *input,
-                           uint8_t *dest, int stride, int eob, int bd);
-void vp10_highbd_iht16x16_add(TX_TYPE tx_type, const tran_low_t *input,
-                             uint8_t *dest, int stride, int eob, int bd);
+void vp10_highbd_inv_txfm_add_4x4(const tran_low_t *input, uint8_t *dest,
+                                  int stride, int eob, int bd, TX_TYPE tx_type,
+                                  void (*highbd_itxm_add_4x4)
+                                  (const tran_low_t *input, uint8_t *dest,
+                                      int stride, int eob, int bd));
+void vp10_highbd_inv_txfm_add_8x8(const tran_low_t *input, uint8_t *dest,
+                                  int stride, int eob, int bd, TX_TYPE tx_type);
+void vp10_highbd_inv_txfm_add_16x16(const tran_low_t *input, uint8_t *dest,
+                                    int stride, int eob, int bd,
+                                    TX_TYPE tx_type);
+void vp10_highbd_inv_txfm_add_32x32(const tran_low_t *input, uint8_t *dest,
+                                    int stride, int eob, int bd,
+                                    TX_TYPE tx_type);
 #endif  // CONFIG_VP9_HIGHBITDEPTH
 #ifdef __cplusplus
 }  // extern "C"
--- a/vp10/decoder/decodeframe.c
+++ b/vp10/decoder/decodeframe.c
@@ -186,76 +186,59 @@
 static void inverse_transform_block_inter(MACROBLOCKD* xd, int plane,
                                           const TX_SIZE tx_size,
                                           uint8_t *dst, int stride,
-                                          int eob) {
+                                          int eob, int block) {
   struct macroblockd_plane *const pd = &xd->plane[plane];
+  TX_TYPE tx_type = get_tx_type(pd->plane_type, xd, block);
   if (eob > 0) {
     tran_low_t *const dqcoeff = pd->dqcoeff;
 #if CONFIG_VP9_HIGHBITDEPTH
     if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
-      if (xd->lossless) {
-        vp10_highbd_iwht4x4_add(dqcoeff, dst, stride, eob, xd->bd);
-      } else {
-        switch (tx_size) {
-          case TX_4X4:
-            vp10_highbd_idct4x4_add(dqcoeff, dst, stride, eob, xd->bd);
-            break;
-          case TX_8X8:
-            vp10_highbd_idct8x8_add(dqcoeff, dst, stride, eob, xd->bd);
-            break;
-          case TX_16X16:
-            vp10_highbd_idct16x16_add(dqcoeff, dst, stride, eob, xd->bd);
-            break;
-          case TX_32X32:
-            vp10_highbd_idct32x32_add(dqcoeff, dst, stride, eob, xd->bd);
-            break;
-          default:
-            assert(0 && "Invalid transform size");
-        }
+      switch (tx_size) {
+        case TX_4X4:
+          vp10_highbd_inv_txfm_add_4x4(dqcoeff, dst, stride, eob, xd->bd,
+                                       tx_type, xd->lossless ?
+                                           vp10_highbd_iwht4x4_add :
+                                           vp10_highbd_idct4x4_add);
+          break;
+        case TX_8X8:
+          vp10_highbd_inv_txfm_add_8x8(dqcoeff, dst, stride, eob, xd->bd,
+                                       tx_type);
+          break;
+        case TX_16X16:
+          vp10_highbd_inv_txfm_add_16x16(dqcoeff, dst, stride, eob, xd->bd,
+                                         tx_type);
+          break;
+        case TX_32X32:
+          vp10_highbd_inv_txfm_add_32x32(dqcoeff, dst, stride, eob, xd->bd,
+                                         tx_type);
+          break;
+        default:
+          assert(0 && "Invalid transform size");
+          return;
       }
     } else {
-      if (xd->lossless) {
-        vp10_iwht4x4_add(dqcoeff, dst, stride, eob);
-      } else {
-        switch (tx_size) {
-          case TX_4X4:
-            vp10_idct4x4_add(dqcoeff, dst, stride, eob);
-            break;
-          case TX_8X8:
-            vp10_idct8x8_add(dqcoeff, dst, stride, eob);
-            break;
-          case TX_16X16:
-            vp10_idct16x16_add(dqcoeff, dst, stride, eob);
-            break;
-          case TX_32X32:
-            vp10_idct32x32_add(dqcoeff, dst, stride, eob);
-            break;
-          default:
-            assert(0 && "Invalid transform size");
-            return;
-        }
-      }
-    }
-#else
-    if (xd->lossless) {
-      vp10_iwht4x4_add(dqcoeff, dst, stride, eob);
-    } else {
+#else  // CONFIG_VP9_HIGHBITDEPTH
       switch (tx_size) {
         case TX_4X4:
-          vp10_idct4x4_add(dqcoeff, dst, stride, eob);
+          vp10_inv_txfm_add_4x4(dqcoeff, dst, stride, eob, tx_type,
+                                xd->lossless ? vp10_iwht4x4_add :
+                                    vp10_idct4x4_add);
           break;
         case TX_8X8:
-          vp10_idct8x8_add(dqcoeff, dst, stride, eob);
+          vp10_inv_txfm_add_8x8(dqcoeff, dst, stride, eob, tx_type);
           break;
         case TX_16X16:
-          vp10_idct16x16_add(dqcoeff, dst, stride, eob);
+          vp10_inv_txfm_add_16x16(dqcoeff, dst, stride, eob, tx_type);
           break;
         case TX_32X32:
-          vp10_idct32x32_add(dqcoeff, dst, stride, eob);
+          vp10_inv_txfm_add_32x32(dqcoeff, dst, stride, eob, tx_type);
           break;
         default:
           assert(0 && "Invalid transform size");
           return;
       }
+#endif  // CONFIG_VP9_HIGHBITDEPTH
+#if CONFIG_VP9_HIGHBITDEPTH
     }
 #endif  // CONFIG_VP9_HIGHBITDEPTH
 
@@ -282,70 +265,52 @@
     tran_low_t *const dqcoeff = pd->dqcoeff;
 #if CONFIG_VP9_HIGHBITDEPTH
     if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
-      if (xd->lossless) {
-        vp10_highbd_iwht4x4_add(dqcoeff, dst, stride, eob, xd->bd);
-      } else {
-        switch (tx_size) {
-          case TX_4X4:
-            vp10_highbd_iht4x4_add(tx_type, dqcoeff, dst, stride, eob, xd->bd);
-            break;
-          case TX_8X8:
-            vp10_highbd_iht8x8_add(tx_type, dqcoeff, dst, stride, eob, xd->bd);
-            break;
-          case TX_16X16:
-            vp10_highbd_iht16x16_add(tx_type, dqcoeff, dst, stride, eob, xd->bd);
-            break;
-          case TX_32X32:
-            vp10_highbd_idct32x32_add(dqcoeff, dst, stride, eob, xd->bd);
-            break;
-          default:
-            assert(0 && "Invalid transform size");
-        }
+      switch (tx_size) {
+        case TX_4X4:
+          vp10_highbd_inv_txfm_add_4x4(dqcoeff, dst, stride, eob, xd->bd,
+                                       tx_type, xd->lossless ?
+                                           vp10_highbd_iwht4x4_add :
+                                           vp10_highbd_idct4x4_add);
+          break;
+        case TX_8X8:
+          vp10_highbd_inv_txfm_add_8x8(dqcoeff, dst, stride, eob, xd->bd,
+                                       tx_type);
+          break;
+        case TX_16X16:
+          vp10_highbd_inv_txfm_add_16x16(dqcoeff, dst, stride, eob, xd->bd,
+                                         tx_type);
+          break;
+        case TX_32X32:
+          vp10_highbd_inv_txfm_add_32x32(dqcoeff, dst, stride, eob, xd->bd,
+                                         tx_type);
+          break;
+        default:
+          assert(0 && "Invalid transform size");
+          return;
       }
     } else {
-      if (xd->lossless) {
-        vp10_iwht4x4_add(dqcoeff, dst, stride, eob);
-      } else {
-        switch (tx_size) {
-          case TX_4X4:
-            vp10_iht4x4_add(tx_type, dqcoeff, dst, stride, eob);
-            break;
-          case TX_8X8:
-            vp10_iht8x8_add(tx_type, dqcoeff, dst, stride, eob);
-            break;
-          case TX_16X16:
-            vp10_iht16x16_add(tx_type, dqcoeff, dst, stride, eob);
-            break;
-          case TX_32X32:
-            vp10_idct32x32_add(dqcoeff, dst, stride, eob);
-            break;
-          default:
-            assert(0 && "Invalid transform size");
-            return;
-        }
-      }
-    }
-#else
-    if (xd->lossless) {
-      vp10_iwht4x4_add(dqcoeff, dst, stride, eob);
-    } else {
+#else  // CONFIG_VP9_HIGHBITDEPTH
       switch (tx_size) {
         case TX_4X4:
-          vp10_iht4x4_add(tx_type, dqcoeff, dst, stride, eob);
+          vp10_inv_txfm_add_4x4(dqcoeff, dst, stride, eob, tx_type,
+                                xd->lossless ? vp10_iwht4x4_add :
+                                    vp10_idct4x4_add);
           break;
         case TX_8X8:
-          vp10_iht8x8_add(tx_type, dqcoeff, dst, stride, eob);
+          vp10_inv_txfm_add_8x8(dqcoeff, dst, stride, eob, tx_type);
           break;
         case TX_16X16:
-          vp10_iht16x16_add(tx_type, dqcoeff, dst, stride, eob);
+          vp10_inv_txfm_add_16x16(dqcoeff, dst, stride, eob, tx_type);
           break;
         case TX_32X32:
-          vp10_idct32x32_add(dqcoeff, dst, stride, eob);
+          vp10_inv_txfm_add_32x32(dqcoeff, dst, stride, eob, tx_type);
           break;
         default:
           assert(0 && "Invalid transform size");
           return;
       }
+#endif  // CONFIG_VP9_HIGHBITDEPTH
+#if CONFIG_VP9_HIGHBITDEPTH
     }
 #endif  // CONFIG_VP9_HIGHBITDEPTH
 
@@ -406,7 +371,7 @@
 
   inverse_transform_block_inter(xd, plane, tx_size,
                             &pd->dst.buf[4 * row * pd->dst.stride + 4 * col],
-                            pd->dst.stride, eob);
+                            pd->dst.stride, eob, block_idx);
   return eob;
 }
 
--- a/vp10/encoder/encodemb.c
+++ b/vp10/encoder/encodemb.c
@@ -496,6 +496,146 @@
   }
 }
 
+void vp10_fwd_txfm_4x4(const int16_t *src_diff,
+                       tran_low_t *coeff, int diff_stride, TX_TYPE tx_type,
+                       void (*fwd_txm4x4)(const int16_t *input,
+                           tran_low_t *output, int stride)) {
+  switch (tx_type) {  // select the 4x4 forward transform by tx_type
+    case DCT_DCT:  // caller-supplied routine (callers pass x->fwd_txm4x4)
+      fwd_txm4x4(src_diff, coeff, diff_stride);
+      break;
+    case ADST_DCT:
+    case DCT_ADST:
+    case ADST_ADST:  // hybrid types bypass fwd_txm4x4
+      vp10_fht4x4(src_diff, coeff, diff_stride, tx_type);
+      break;
+    default:  // under NDEBUG coeff is left unwritten
+      assert(0 && "Invalid transform type");
+      break;
+  }
+}
+
+static void fwd_txfm_8x8(const int16_t *src_diff, tran_low_t *coeff,
+                         int diff_stride, TX_TYPE tx_type) {
+  switch (tx_type) {  // 8x8 forward transform; validates tx_type first
+    case DCT_DCT:  // DCT_DCT is also routed through vp10_fht8x8
+    case ADST_DCT:
+    case DCT_ADST:
+    case ADST_ADST:
+      vp10_fht8x8(src_diff, coeff, diff_stride, tx_type);
+      break;
+    default:  // under NDEBUG coeff is left unwritten
+      assert(0 && "Invalid transform type");
+      break;
+  }
+}
+
+static void fwd_txfm_16x16(const int16_t *src_diff, tran_low_t *coeff,
+                           int diff_stride, TX_TYPE tx_type) {
+  switch (tx_type) {  // 16x16 forward transform; validates tx_type first
+    case DCT_DCT:  // DCT_DCT is also routed through vp10_fht16x16
+    case ADST_DCT:
+    case DCT_ADST:
+    case ADST_ADST:
+      vp10_fht16x16(src_diff, coeff, diff_stride, tx_type);
+      break;
+    default:  // under NDEBUG coeff is left unwritten
+      assert(0 && "Invalid transform type");
+      break;
+  }
+}
+
+static void fwd_txfm_32x32(int rd_transform, const int16_t *src_diff,
+                           tran_low_t *coeff, int diff_stride,
+                           TX_TYPE tx_type) {
+  switch (tx_type) {  // 32x32 forward transform: only DCT_DCT is handled
+    case DCT_DCT:  // rd_transform is passed through to fdct32x32
+      fdct32x32(rd_transform, src_diff, coeff, diff_stride);
+      break;
+    case ADST_DCT:
+    case DCT_ADST:
+    case ADST_ADST:  // known types, but no 32x32 routine is called for them
+      assert(0 && "32x32 supports only DCT_DCT");
+      break;
+    default:  // under NDEBUG coeff is left unwritten
+      assert(0 && "Invalid transform type");
+      break;
+  }
+}
+
+#if CONFIG_VP9_HIGHBITDEPTH
+void vp10_highbd_fwd_txfm_4x4(const int16_t *src_diff, tran_low_t *coeff,
+                              int diff_stride, TX_TYPE tx_type,
+                              void (*highbd_fwd_txm4x4)(const int16_t *input,
+                                  tran_low_t *output, int stride)) {
+  switch (tx_type) {  // select the high-bitdepth 4x4 forward transform
+    case DCT_DCT:  // caller-supplied routine (callers pass x->fwd_txm4x4)
+      highbd_fwd_txm4x4(src_diff, coeff, diff_stride);
+      break;
+    case ADST_DCT:
+    case DCT_ADST:
+    case ADST_ADST:  // hybrid types bypass highbd_fwd_txm4x4
+      vp10_highbd_fht4x4(src_diff, coeff, diff_stride, tx_type);
+      break;
+    default:  // under NDEBUG coeff is left unwritten
+      assert(0 && "Invalid transform type");
+      break;
+  }
+}
+
+static void highbd_fwd_txfm_8x8(const int16_t *src_diff, tran_low_t *coeff,
+                                int diff_stride, TX_TYPE tx_type) {
+  switch (tx_type) {  // select the high-bitdepth 8x8 forward transform
+    case DCT_DCT:
+      vpx_highbd_fdct8x8(src_diff, coeff, diff_stride);
+      break;
+    case ADST_DCT:
+    case DCT_ADST:
+    case ADST_ADST:
+      vp10_highbd_fht8x8(src_diff, coeff, diff_stride, tx_type);
+      break;
+    default:  // under NDEBUG coeff is left unwritten
+      assert(0 && "Invalid transform type");
+      break;
+  }
+}
+
+static void highbd_fwd_txfm_16x16(const int16_t *src_diff, tran_low_t *coeff,
+                                  int diff_stride, TX_TYPE tx_type) {
+  switch (tx_type) {  // select the high-bitdepth 16x16 forward transform
+    case DCT_DCT:
+      vpx_highbd_fdct16x16(src_diff, coeff, diff_stride);
+      break;
+    case ADST_DCT:
+    case DCT_ADST:
+    case ADST_ADST:
+      vp10_highbd_fht16x16(src_diff, coeff, diff_stride, tx_type);
+      break;
+    default:  // under NDEBUG coeff is left unwritten
+      assert(0 && "Invalid transform type");
+      break;
+  }
+}
+
+static void highbd_fwd_txfm_32x32(int rd_transform, const int16_t *src_diff,
+                                  tran_low_t *coeff, int diff_stride,
+                                  TX_TYPE tx_type) {
+  switch (tx_type) {  // high-bitdepth 32x32: only DCT_DCT is handled
+    case DCT_DCT:  // rd_transform is passed through to highbd_fdct32x32
+      highbd_fdct32x32(rd_transform, src_diff, coeff, diff_stride);
+      break;
+    case ADST_DCT:
+    case DCT_ADST:
+    case ADST_ADST:  // known types, but no 32x32 routine is called for them
+      assert(0 && "32x32 supports only DCT_DCT");
+      break;
+    default:  // under NDEBUG coeff is left unwritten
+      assert(0 && "Invalid transform type");
+      break;
+  }
+}
+#endif  // CONFIG_VP9_HIGHBITDEPTH
+
 void vp10_xform_quant(MACROBLOCK *x, int plane, int block,
                      BLOCK_SIZE plane_bsize, TX_SIZE tx_size) {
   MACROBLOCKD *const xd = &x->e_mbd;
@@ -518,7 +658,8 @@
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
      switch (tx_size) {
       case TX_32X32:
-        highbd_fdct32x32(x->use_lp32x32fdct, src_diff, coeff, diff_stride);
+        highbd_fwd_txfm_32x32(x->use_lp32x32fdct, src_diff, coeff, diff_stride,
+                         tx_type);
         vpx_highbd_quantize_b_32x32(coeff, 1024, x->skip_block, p->zbin,
                                     p->round, p->quant, p->quant_shift, qcoeff,
                                     dqcoeff, pd->dequant, eob,
@@ -525,7 +666,7 @@
                                     scan_order->scan, scan_order->iscan);
         break;
       case TX_16X16:
-        vpx_highbd_fdct16x16(src_diff, coeff, diff_stride);
+        highbd_fwd_txfm_16x16(src_diff, coeff, diff_stride, tx_type);
         vpx_highbd_quantize_b(coeff, 256, x->skip_block, p->zbin, p->round,
                               p->quant, p->quant_shift, qcoeff, dqcoeff,
                               pd->dequant, eob,
@@ -532,7 +673,7 @@
                               scan_order->scan, scan_order->iscan);
         break;
       case TX_8X8:
-        vpx_highbd_fdct8x8(src_diff, coeff, diff_stride);
+        highbd_fwd_txfm_8x8(src_diff, coeff, diff_stride, tx_type);
         vpx_highbd_quantize_b(coeff, 64, x->skip_block, p->zbin, p->round,
                               p->quant, p->quant_shift, qcoeff, dqcoeff,
                               pd->dequant, eob,
@@ -539,7 +680,8 @@
                               scan_order->scan, scan_order->iscan);
         break;
       case TX_4X4:
-        x->fwd_txm4x4(src_diff, coeff, diff_stride);
+        vp10_highbd_fwd_txfm_4x4(src_diff, coeff, diff_stride, tx_type,
+                                 x->fwd_txm4x4);
         vpx_highbd_quantize_b(coeff, 16, x->skip_block, p->zbin, p->round,
                               p->quant, p->quant_shift, qcoeff, dqcoeff,
                               pd->dequant, eob,
@@ -554,7 +696,7 @@
 
   switch (tx_size) {
     case TX_32X32:
-      fdct32x32(x->use_lp32x32fdct, src_diff, coeff, diff_stride);
+      fwd_txfm_32x32(x->use_lp32x32fdct, src_diff, coeff, diff_stride, tx_type);
       vpx_quantize_b_32x32(coeff, 1024, x->skip_block, p->zbin, p->round,
                            p->quant, p->quant_shift, qcoeff, dqcoeff,
                            pd->dequant, eob, scan_order->scan,
@@ -561,7 +703,7 @@
                            scan_order->iscan);
       break;
     case TX_16X16:
-      vpx_fdct16x16(src_diff, coeff, diff_stride);
+      fwd_txfm_16x16(src_diff, coeff, diff_stride, tx_type);
       vpx_quantize_b(coeff, 256, x->skip_block, p->zbin, p->round,
                      p->quant, p->quant_shift, qcoeff, dqcoeff,
                      pd->dequant, eob,
@@ -568,7 +710,7 @@
                      scan_order->scan, scan_order->iscan);
       break;
     case TX_8X8:
-      vpx_fdct8x8(src_diff, coeff, diff_stride);
+      fwd_txfm_8x8(src_diff, coeff, diff_stride, tx_type);
       vpx_quantize_b(coeff, 64, x->skip_block, p->zbin, p->round,
                      p->quant, p->quant_shift, qcoeff, dqcoeff,
                      pd->dequant, eob,
@@ -575,7 +717,7 @@
                      scan_order->scan, scan_order->iscan);
       break;
     case TX_4X4:
-      x->fwd_txm4x4(src_diff, coeff, diff_stride);
+      vp10_fwd_txfm_4x4(src_diff, coeff, diff_stride, tx_type, x->fwd_txm4x4);
       vpx_quantize_b(coeff, 16, x->skip_block, p->zbin, p->round,
                      p->quant, p->quant_shift, qcoeff, dqcoeff,
                      pd->dequant, eob,
@@ -599,6 +741,7 @@
   int i, j;
   uint8_t *dst;
   ENTROPY_CONTEXT *a, *l;
+  TX_TYPE tx_type = get_tx_type(pd->plane_type, xd, block);
   txfrm_block_to_raster_xy(plane_bsize, tx_size, block, &i, &j);
   dst = &pd->dst.buf[4 * j * pd->dst.stride + 4 * i];
   a = &ctx->ta[plane][i];
@@ -660,27 +803,30 @@
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
     switch (tx_size) {
       case TX_32X32:
-        vp10_highbd_idct32x32_add(dqcoeff, dst, pd->dst.stride,
-                                 p->eobs[block], xd->bd);
+        vp10_highbd_inv_txfm_add_32x32(dqcoeff, dst, pd->dst.stride,
+                                       p->eobs[block], xd->bd, tx_type);
         break;
       case TX_16X16:
-        vp10_highbd_idct16x16_add(dqcoeff, dst, pd->dst.stride,
-                                 p->eobs[block], xd->bd);
+        vp10_highbd_inv_txfm_add_16x16(dqcoeff, dst, pd->dst.stride,
+                                       p->eobs[block], xd->bd, tx_type);
         break;
       case TX_8X8:
-        vp10_highbd_idct8x8_add(dqcoeff, dst, pd->dst.stride,
-                               p->eobs[block], xd->bd);
+        vp10_highbd_inv_txfm_add_8x8(dqcoeff, dst, pd->dst.stride,
+                                     p->eobs[block], xd->bd, tx_type);
         break;
       case TX_4X4:
         // this is like vp10_short_idct4x4 but has a special case around eob<=1
         // which is significant (not just an optimization) for the lossless
         // case.
-        x->highbd_itxm_add(dqcoeff, dst, pd->dst.stride,
-                           p->eobs[block], xd->bd);
+        vp10_highbd_inv_txfm_add_4x4(dqcoeff, dst, pd->dst.stride,
+                                     p->eobs[block], xd->bd, tx_type,
+                                     x->highbd_itxm_add);
         break;
       default:
         assert(0 && "Invalid transform size");
+        break;
     }
+
     return;
   }
 #endif  // CONFIG_VP9_HIGHBITDEPTH
@@ -687,19 +833,23 @@
 
   switch (tx_size) {
     case TX_32X32:
-      vp10_idct32x32_add(dqcoeff, dst, pd->dst.stride, p->eobs[block]);
+      vp10_inv_txfm_add_32x32(dqcoeff, dst, pd->dst.stride, p->eobs[block],
+                              tx_type);
       break;
     case TX_16X16:
-      vp10_idct16x16_add(dqcoeff, dst, pd->dst.stride, p->eobs[block]);
+      vp10_inv_txfm_add_16x16(dqcoeff, dst, pd->dst.stride, p->eobs[block],
+                              tx_type);
       break;
     case TX_8X8:
-      vp10_idct8x8_add(dqcoeff, dst, pd->dst.stride, p->eobs[block]);
+      vp10_inv_txfm_add_8x8(dqcoeff, dst, pd->dst.stride, p->eobs[block],
+                            tx_type);
       break;
     case TX_4X4:
       // this is like vp10_short_idct4x4 but has a special case around eob<=1
       // which is significant (not just an optimization) for the lossless
       // case.
-      x->itxm_add(dqcoeff, dst, pd->dst.stride, p->eobs[block]);
+      vp10_inv_txfm_add_4x4(dqcoeff, dst, pd->dst.stride, p->eobs[block],
+                            tx_type, x->itxm_add);
       break;
     default:
       assert(0 && "Invalid transform size");
@@ -806,60 +956,51 @@
         if (!x->skip_recode) {
           vpx_highbd_subtract_block(32, 32, src_diff, diff_stride,
                                     src, src_stride, dst, dst_stride, xd->bd);
-          highbd_fdct32x32(x->use_lp32x32fdct, src_diff, coeff, diff_stride);
+          highbd_fwd_txfm_32x32(x->use_lp32x32fdct, src_diff, coeff,
+                                diff_stride, tx_type);
           vpx_highbd_quantize_b_32x32(coeff, 1024, x->skip_block, p->zbin,
                                       p->round, p->quant, p->quant_shift,
                                       qcoeff, dqcoeff, pd->dequant, eob,
                                       scan_order->scan, scan_order->iscan);
         }
-        if (!x->skip_encode && *eob) {
-          vp10_highbd_idct32x32_add(dqcoeff, dst, dst_stride, *eob, xd->bd);
-        }
+        if (!x->skip_encode && *eob)
+          vp10_highbd_inv_txfm_add_32x32(dqcoeff, dst, dst_stride, *eob, xd->bd,
+                                         tx_type);
         break;
       case TX_16X16:
         if (!x->skip_recode) {
           vpx_highbd_subtract_block(16, 16, src_diff, diff_stride,
                                     src, src_stride, dst, dst_stride, xd->bd);
-          if (tx_type == DCT_DCT)
-            vpx_highbd_fdct16x16(src_diff, coeff, diff_stride);
-          else
-            vp10_highbd_fht16x16(src_diff, coeff, diff_stride, tx_type);
+          highbd_fwd_txfm_16x16(src_diff, coeff, diff_stride, tx_type);
           vpx_highbd_quantize_b(coeff, 256, x->skip_block, p->zbin, p->round,
                                 p->quant, p->quant_shift, qcoeff, dqcoeff,
                                 pd->dequant, eob,
                                 scan_order->scan, scan_order->iscan);
         }
-        if (!x->skip_encode && *eob) {
-          vp10_highbd_iht16x16_add(tx_type, dqcoeff, dst, dst_stride,
-                                  *eob, xd->bd);
-        }
+        if (!x->skip_encode && *eob)
+          vp10_highbd_inv_txfm_add_16x16(dqcoeff, dst, dst_stride, *eob, xd->bd,
+                                         tx_type);
         break;
       case TX_8X8:
         if (!x->skip_recode) {
           vpx_highbd_subtract_block(8, 8, src_diff, diff_stride,
                                     src, src_stride, dst, dst_stride, xd->bd);
-          if (tx_type == DCT_DCT)
-            vpx_highbd_fdct8x8(src_diff, coeff, diff_stride);
-          else
-            vp10_highbd_fht8x8(src_diff, coeff, diff_stride, tx_type);
+          highbd_fwd_txfm_8x8(src_diff, coeff, diff_stride, tx_type);
           vpx_highbd_quantize_b(coeff, 64, x->skip_block, p->zbin, p->round,
                                 p->quant, p->quant_shift, qcoeff, dqcoeff,
                                 pd->dequant, eob,
                                 scan_order->scan, scan_order->iscan);
         }
-        if (!x->skip_encode && *eob) {
-          vp10_highbd_iht8x8_add(tx_type, dqcoeff, dst, dst_stride, *eob,
-                                xd->bd);
-        }
+        if (!x->skip_encode && *eob)
+          vp10_highbd_inv_txfm_add_8x8(dqcoeff, dst, dst_stride, *eob, xd->bd,
+                                       tx_type);
         break;
       case TX_4X4:
         if (!x->skip_recode) {
           vpx_highbd_subtract_block(4, 4, src_diff, diff_stride,
                                     src, src_stride, dst, dst_stride, xd->bd);
-          if (tx_type != DCT_DCT)
-            vp10_highbd_fht4x4(src_diff, coeff, diff_stride, tx_type);
-          else
-            x->fwd_txm4x4(src_diff, coeff, diff_stride);
+          vp10_highbd_fwd_txfm_4x4(src_diff, coeff, diff_stride, tx_type,
+                                   x->fwd_txm4x4);
           vpx_highbd_quantize_b(coeff, 16, x->skip_block, p->zbin, p->round,
                                 p->quant, p->quant_shift, qcoeff, dqcoeff,
                                 pd->dequant, eob,
@@ -866,16 +1007,12 @@
                                 scan_order->scan, scan_order->iscan);
         }
 
-        if (!x->skip_encode && *eob) {
-          if (tx_type == DCT_DCT) {
-            // this is like vp10_short_idct4x4 but has a special case around
-            // eob<=1 which is significant (not just an optimization) for the
-            // lossless case.
-            x->highbd_itxm_add(dqcoeff, dst, dst_stride, *eob, xd->bd);
-          } else {
-            vp10_highbd_iht4x4_16_add(dqcoeff, dst, dst_stride, tx_type, xd->bd);
-          }
-        }
+        if (!x->skip_encode && *eob)
+          // this is like vp10_short_idct4x4 but has a special case around
+          // eob<=1 which is significant (not just an optimization) for the
+          // lossless case.
+          vp10_highbd_inv_txfm_add_4x4(dqcoeff, dst, dst_stride, *eob, xd->bd,
+                                       tx_type, x->highbd_itxm_add);
         break;
       default:
         assert(0);
@@ -892,7 +1029,8 @@
       if (!x->skip_recode) {
         vpx_subtract_block(32, 32, src_diff, diff_stride,
                            src, src_stride, dst, dst_stride);
-        fdct32x32(x->use_lp32x32fdct, src_diff, coeff, diff_stride);
+        fwd_txfm_32x32(x->use_lp32x32fdct, src_diff, coeff, diff_stride,
+                       tx_type);
         vpx_quantize_b_32x32(coeff, 1024, x->skip_block, p->zbin, p->round,
                              p->quant, p->quant_shift, qcoeff, dqcoeff,
                              pd->dequant, eob, scan_order->scan,
@@ -899,13 +1037,13 @@
                              scan_order->iscan);
       }
       if (!x->skip_encode && *eob)
-        vp10_idct32x32_add(dqcoeff, dst, dst_stride, *eob);
+        vp10_inv_txfm_add_32x32(dqcoeff, dst, dst_stride, *eob, tx_type);
       break;
     case TX_16X16:
       if (!x->skip_recode) {
         vpx_subtract_block(16, 16, src_diff, diff_stride,
                            src, src_stride, dst, dst_stride);
-        vp10_fht16x16(src_diff, coeff, diff_stride, tx_type);
+        fwd_txfm_16x16(src_diff, coeff, diff_stride, tx_type);
         vpx_quantize_b(coeff, 256, x->skip_block, p->zbin, p->round,
                        p->quant, p->quant_shift, qcoeff, dqcoeff,
                        pd->dequant, eob, scan_order->scan,
@@ -912,13 +1050,13 @@
                        scan_order->iscan);
       }
       if (!x->skip_encode && *eob)
-        vp10_iht16x16_add(tx_type, dqcoeff, dst, dst_stride, *eob);
+        vp10_inv_txfm_add_16x16(dqcoeff, dst, dst_stride, *eob, tx_type);
       break;
     case TX_8X8:
       if (!x->skip_recode) {
         vpx_subtract_block(8, 8, src_diff, diff_stride,
                            src, src_stride, dst, dst_stride);
-        vp10_fht8x8(src_diff, coeff, diff_stride, tx_type);
+        fwd_txfm_8x8(src_diff, coeff, diff_stride, tx_type);
         vpx_quantize_b(coeff, 64, x->skip_block, p->zbin, p->round, p->quant,
                        p->quant_shift, qcoeff, dqcoeff,
                        pd->dequant, eob, scan_order->scan,
@@ -925,16 +1063,13 @@
                        scan_order->iscan);
       }
       if (!x->skip_encode && *eob)
-        vp10_iht8x8_add(tx_type, dqcoeff, dst, dst_stride, *eob);
+        vp10_inv_txfm_add_8x8(dqcoeff, dst, dst_stride, *eob, tx_type);
       break;
     case TX_4X4:
       if (!x->skip_recode) {
         vpx_subtract_block(4, 4, src_diff, diff_stride,
                            src, src_stride, dst, dst_stride);
-        if (tx_type != DCT_DCT)
-          vp10_fht4x4(src_diff, coeff, diff_stride, tx_type);
-        else
-          x->fwd_txm4x4(src_diff, coeff, diff_stride);
+        vp10_fwd_txfm_4x4(src_diff, coeff, diff_stride, tx_type, x->fwd_txm4x4);
         vpx_quantize_b(coeff, 16, x->skip_block, p->zbin, p->round, p->quant,
                        p->quant_shift, qcoeff, dqcoeff,
                        pd->dequant, eob, scan_order->scan,
@@ -942,13 +1077,11 @@
       }
 
       if (!x->skip_encode && *eob) {
-        if (tx_type == DCT_DCT)
-          // this is like vp10_short_idct4x4 but has a special case around eob<=1
-          // which is significant (not just an optimization) for the lossless
-          // case.
-          x->itxm_add(dqcoeff, dst, dst_stride, *eob);
-        else
-          vp10_iht4x4_16_add(dqcoeff, dst, dst_stride, tx_type);
+        // this is like vp10_short_idct4x4 but has a special case around eob<=1
+        // which is significant (not just an optimization) for the lossless
+        // case.
+        vp10_inv_txfm_add_4x4(dqcoeff, dst, dst_stride, *eob, tx_type,
+                              x->itxm_add);
       }
       break;
     default:
--- a/vp10/encoder/encodemb.h
+++ b/vp10/encoder/encodemb.h
@@ -39,6 +39,18 @@
 
 void vp10_encode_intra_block_plane(MACROBLOCK *x, BLOCK_SIZE bsize, int plane);
 
+void vp10_fwd_txfm_4x4(const int16_t *src_diff,
+                       tran_low_t *coeff, int diff_stride, TX_TYPE tx_type,
+                       void (*fwd_txm4x4)(const int16_t *input,
+                           tran_low_t *output, int stride));
+
+#if CONFIG_VP9_HIGHBITDEPTH
+void vp10_highbd_fwd_txfm_4x4(const int16_t *src_diff, tran_low_t *coeff,
+                              int diff_stride, TX_TYPE tx_type,
+                              void (*highbd_fwd_txm4x4)(const int16_t *input,
+                                  tran_low_t *output, int stride));
+#endif  // CONFIG_VP9_HIGHBITDEPTH
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
--- a/vp10/encoder/rdopt.c
+++ b/vp10/encoder/rdopt.c
@@ -799,7 +799,8 @@
           if (xd->lossless) {
             TX_TYPE tx_type = get_tx_type(PLANE_TYPE_Y, xd, block);
             const scan_order *so = get_scan(TX_4X4, tx_type);
-            vp10_highbd_fwht4x4(src_diff, coeff, 8);
+            vp10_highbd_fwd_txfm_4x4(src_diff, coeff, 8, DCT_DCT,
+                                     vp10_highbd_fwht4x4);
             vp10_regular_quantize_b_4x4(x, 0, block, so->scan, so->iscan);
             ratey += cost_coeffs(x, 0, block, tempa + idx, templ + idy, TX_4X4,
                                  so->scan, so->neighbors,
@@ -806,17 +807,16 @@
                                  cpi->sf.use_fast_coef_costing);
             if (RDCOST(x->rdmult, x->rddiv, ratey, distortion) >= best_rd)
               goto next_highbd;
-            vp10_highbd_iwht4x4_add(BLOCK_OFFSET(pd->dqcoeff, block),
-                                   dst, dst_stride,
-                                   p->eobs[block], xd->bd);
+            vp10_highbd_inv_txfm_add_4x4(BLOCK_OFFSET(pd->dqcoeff, block),
+                                         dst, dst_stride, p->eobs[block],
+                                         xd->bd, DCT_DCT,
+                                         vp10_highbd_iwht4x4_add);
           } else {
             int64_t unused;
             TX_TYPE tx_type = get_tx_type(PLANE_TYPE_Y, xd, block);
             const scan_order *so = get_scan(TX_4X4, tx_type);
-            if (tx_type == DCT_DCT)
-              vpx_highbd_fdct4x4(src_diff, coeff, 8);
-            else
-              vp10_highbd_fht4x4(src_diff, coeff, 8, tx_type);
+            vp10_highbd_fwd_txfm_4x4(src_diff, coeff, 8, tx_type,
+                                     vpx_highbd_fdct4x4);
             vp10_regular_quantize_b_4x4(x, 0, block, so->scan, so->iscan);
             ratey += cost_coeffs(x, 0, block, tempa + idx, templ + idy, TX_4X4,
                                  so->scan, so->neighbors,
@@ -826,8 +826,10 @@
                 16, &unused, xd->bd) >> 2;
             if (RDCOST(x->rdmult, x->rddiv, ratey, distortion) >= best_rd)
               goto next_highbd;
-            vp10_highbd_iht4x4_add(tx_type, BLOCK_OFFSET(pd->dqcoeff, block),
-                                  dst, dst_stride, p->eobs[block], xd->bd);
+            vp10_highbd_inv_txfm_add_4x4(BLOCK_OFFSET(pd->dqcoeff, block),
+                                         dst, dst_stride, p->eobs[block],
+                                         xd->bd, tx_type,
+                                         vp10_highbd_idct4x4_add);
           }
         }
       }
@@ -902,7 +904,7 @@
         if (xd->lossless) {
           TX_TYPE tx_type = get_tx_type(PLANE_TYPE_Y, xd, block);
           const scan_order *so = get_scan(TX_4X4, tx_type);
-          vp10_fwht4x4(src_diff, coeff, 8);
+          vp10_fwd_txfm_4x4(src_diff, coeff, 8, DCT_DCT, vp10_fwht4x4);
           vp10_regular_quantize_b_4x4(x, 0, block, so->scan, so->iscan);
           ratey += cost_coeffs(x, 0, block, tempa + idx, templ + idy, TX_4X4,
                                so->scan, so->neighbors,
@@ -909,13 +911,14 @@
                                cpi->sf.use_fast_coef_costing);
           if (RDCOST(x->rdmult, x->rddiv, ratey, distortion) >= best_rd)
             goto next;
-          vp10_iwht4x4_add(BLOCK_OFFSET(pd->dqcoeff, block), dst, dst_stride,
-                          p->eobs[block]);
+          vp10_inv_txfm_add_4x4(BLOCK_OFFSET(pd->dqcoeff, block),
+                                dst, dst_stride, p->eobs[block], DCT_DCT,
+                                vp10_iwht4x4_add);
         } else {
           int64_t unused;
           TX_TYPE tx_type = get_tx_type(PLANE_TYPE_Y, xd, block);
           const scan_order *so = get_scan(TX_4X4, tx_type);
-          vp10_fht4x4(src_diff, coeff, 8, tx_type);
+          vp10_fwd_txfm_4x4(src_diff, coeff, 8, tx_type, vpx_fdct4x4);
           vp10_regular_quantize_b_4x4(x, 0, block, so->scan, so->iscan);
           ratey += cost_coeffs(x, 0, block, tempa + idx, templ + idy, TX_4X4,
                              so->scan, so->neighbors,
@@ -924,8 +927,9 @@
                                         16, &unused) >> 2;
           if (RDCOST(x->rdmult, x->rddiv, ratey, distortion) >= best_rd)
             goto next;
-          vp10_iht4x4_add(tx_type, BLOCK_OFFSET(pd->dqcoeff, block),
-                         dst, dst_stride, p->eobs[block]);
+          vp10_inv_txfm_add_4x4(BLOCK_OFFSET(pd->dqcoeff, block),
+                                dst, dst_stride, p->eobs[block], tx_type,
+                                vp10_idct4x4_add);
         }
       }
     }