shithub: libvpx

Download patch

ref: f7dfa4ece7b4e2aef190923abe4a3f1d3ca3ece8
parent: 2bd4f444092bf1902a1caca66e14e8e75189191d
author: Debargha Mukherjee <[email protected]>
date: Wed Jan 6 06:24:57 EST 2016

Modifies inter/intra coding to allow all tx types

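For intra blocks, the nominal tx_type implied by the prediction mode is
used as the context for coding the actual tx_type. For inter blocks, the
tx_type is coded with one probability set per transform size.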

Results:
derflr: -0.241% BDRATE
hevcmr: -0.366% BDRATE

Change-Id: Icfe7b0a58d79bc6497a06e3441779afec6e01e21

--- a/configure
+++ b/configure
@@ -273,6 +273,7 @@
     fp_mb_stats
     emulate_hardware
     misc_fixes
+    ext_tx
 "
 CONFIG_LIST="
     dependency_tracking
--- a/vp10/common/blockd.h
+++ b/vp10/common/blockd.h
@@ -82,6 +82,9 @@
   // Only for INTER blocks
   INTERP_FILTER interp_filter;
   MV_REFERENCE_FRAME ref_frame[2];
+#if CONFIG_EXT_TX
+  TX_TYPE tx_type;
+#endif  // CONFIG_EXT_TX
 
   // TODO(slavarnway): Delete and use bmi[3].as_mv[] instead.
   int_mv mv[2];
@@ -207,7 +210,7 @@
   return subsize_lookup[partition][bsize];
 }
 
-static const TX_TYPE intra_mode_to_tx_type_lookup[INTRA_MODES] = {
+static const TX_TYPE intra_mode_to_tx_type_context[INTRA_MODES] = {
   DCT_DCT,    // DC
   ADST_DCT,   // V
   DCT_ADST,   // H
@@ -225,11 +228,20 @@
   const MODE_INFO *const mi = xd->mi[0];
   const MB_MODE_INFO *const mbmi = &mi->mbmi;
 
+#if CONFIG_EXT_TX
+  (void) block_idx;
   if (plane_type != PLANE_TYPE_Y || xd->lossless[mbmi->segment_id] ||
+      mbmi->tx_size >= TX_32X32)
+    return DCT_DCT;
+
+  return mbmi->tx_type;
+#else
+  if (plane_type != PLANE_TYPE_Y || xd->lossless[mbmi->segment_id] ||
       is_inter_block(mbmi) || mbmi->tx_size >= TX_32X32)
     return DCT_DCT;
 
-  return intra_mode_to_tx_type_lookup[get_y_mode(mi, block_idx)];
+  return intra_mode_to_tx_type_context[get_y_mode(mi, block_idx)];
+#endif  // CONFIG_EXT_TX
 }
 
 void vp10_setup_block_planes(MACROBLOCKD *xd, int ss_x, int ss_y);
--- a/vp10/common/entropy.h
+++ b/vp10/common/entropy.h
@@ -21,7 +21,8 @@
 extern "C" {
 #endif
 
-#define DIFF_UPDATE_PROB 252
+#define DIFF_UPDATE_PROB        252
+#define GROUP_DIFF_UPDATE_PROB  252
 
 // Coefficient token alphabet
 #define ZERO_TOKEN      0   // 0     Extra Bits 0+0
--- a/vp10/common/entropymode.c
+++ b/vp10/common/entropymode.c
@@ -326,6 +326,28 @@
 };
 #endif
 
+#if CONFIG_EXT_TX
+const vpx_tree_index vp10_ext_tx_tree[TREE_SIZE(TX_TYPES)] = {
+  -DCT_DCT, 2,
+  -ADST_ADST, 4,
+  -ADST_DCT, -DCT_ADST
+};
+
+static const vpx_prob default_intra_ext_tx_prob[EXT_TX_SIZES]
+                                               [TX_TYPES][TX_TYPES - 1] = {
+  {{240, 85, 128}, {4, 1, 248}, {4, 1, 8}, {4, 248, 128}},
+  {{244, 85, 128}, {8, 2, 248}, {8, 2, 8}, {8, 248, 128}},
+  {{248, 85, 128}, {16, 4, 248}, {16, 4, 8}, {16, 248, 128}},
+};
+
+static const vpx_prob default_inter_ext_tx_prob[EXT_TX_SIZES]
+                                               [TX_TYPES - 1] = {
+  {160, 85, 128},
+  {176, 85, 128},
+  {192, 85, 128},
+};
+#endif  // CONFIG_EXT_TX
+
 static void init_mode_probs(FRAME_CONTEXT *fc) {
   vp10_copy(fc->uv_mode_prob, default_uv_probs);
   vp10_copy(fc->y_mode_prob, default_if_y_probs);
@@ -342,6 +364,10 @@
   vp10_copy(fc->seg.tree_probs, default_seg_probs.tree_probs);
   vp10_copy(fc->seg.pred_probs, default_seg_probs.pred_probs);
 #endif
+#if CONFIG_EXT_TX
+  vp10_copy(fc->intra_ext_tx_prob, default_intra_ext_tx_prob);
+  vp10_copy(fc->inter_ext_tx_prob, default_inter_ext_tx_prob);
+#endif  // CONFIG_EXT_TX
 }
 
 const vpx_tree_index vp10_switchable_interp_tree
@@ -430,6 +456,23 @@
   for (i = 0; i < SKIP_CONTEXTS; ++i)
     fc->skip_probs[i] = mode_mv_merge_probs(
         pre_fc->skip_probs[i], counts->skip[i]);
+
+#if CONFIG_EXT_TX
+  for (i = TX_4X4; i < EXT_TX_SIZES; ++i) {
+    int j;
+    for (j = 0; j < TX_TYPES; ++j)
+      vpx_tree_merge_probs(vp10_ext_tx_tree,
+                           pre_fc->intra_ext_tx_prob[i][j],
+                           counts->intra_ext_tx[i][j],
+                           fc->intra_ext_tx_prob[i][j]);
+  }
+  for (i = TX_4X4; i < EXT_TX_SIZES; ++i) {
+    vpx_tree_merge_probs(vp10_ext_tx_tree,
+                         pre_fc->inter_ext_tx_prob[i],
+                         counts->inter_ext_tx[i],
+                         fc->inter_ext_tx_prob[i]);
+  }
+#endif  // CONFIG_EXT_TX
 
 #if CONFIG_MISC_FIXES
   if (cm->seg.temporal_update) {
--- a/vp10/common/entropymode.h
+++ b/vp10/common/entropymode.h
@@ -66,6 +66,10 @@
 #if CONFIG_MISC_FIXES
   struct segmentation_probs seg;
 #endif
+#if CONFIG_EXT_TX
+  vpx_prob intra_ext_tx_prob[EXT_TX_SIZES][TX_TYPES][TX_TYPES - 1];
+  vpx_prob inter_ext_tx_prob[EXT_TX_SIZES][TX_TYPES - 1];
+#endif  // CONFIG_EXT_TX
   int initialized;
 } FRAME_CONTEXT;
 
@@ -90,6 +94,10 @@
 #if CONFIG_MISC_FIXES
   struct seg_counts seg;
 #endif
+#if CONFIG_EXT_TX
+  unsigned int intra_ext_tx[EXT_TX_SIZES][TX_TYPES][TX_TYPES];
+  unsigned int inter_ext_tx[EXT_TX_SIZES][TX_TYPES];
+#endif  // CONFIG_EXT_TX
 } FRAME_COUNTS;
 
 extern const vpx_prob vp10_kf_y_mode_prob[INTRA_MODES][INTRA_MODES]
@@ -118,6 +126,11 @@
                                       unsigned int (*ct_16x16p)[2]);
 void vp10_tx_counts_to_branch_counts_8x8(const unsigned int *tx_count_8x8p,
                                     unsigned int (*ct_8x8p)[2]);
+
+#if CONFIG_EXT_TX
+extern const vpx_tree_index
+    vp10_ext_tx_tree[TREE_SIZE(TX_TYPES)];
+#endif  // CONFIG_EXT_TX
 
 static INLINE int vp10_ceil_log2(int n) {
   int i = 1, p = 2;
--- a/vp10/common/enums.h
+++ b/vp10/common/enums.h
@@ -97,6 +97,10 @@
   TX_TYPES = 4
 } TX_TYPE;
 
+#if CONFIG_EXT_TX
+#define EXT_TX_SIZES       3  // number of sizes that use extended transforms
+#endif  // CONFIG_EXT_TX
+
 typedef enum {
   VP9_LAST_FLAG = 1 << 0,
   VP9_GOLD_FLAG = 1 << 1,
--- a/vp10/common/thread_common.c
+++ b/vp10/common/thread_common.c
@@ -435,6 +435,19 @@
       comps->fp[i] += comps_t->fp[i];
   }
 
+#if CONFIG_EXT_TX
+  for (i = 0; i < EXT_TX_SIZES; i++) {
+    int j;
+    for (j = 0; j < TX_TYPES; ++j)
+      for (k = 0; k < TX_TYPES; k++)
+        cm->counts.intra_ext_tx[i][j][k] += counts->intra_ext_tx[i][j][k];
+  }
+  for (i = 0; i < EXT_TX_SIZES; i++) {
+    for (k = 0; k < TX_TYPES; k++)
+      cm->counts.inter_ext_tx[i][k] += counts->inter_ext_tx[i][k];
+  }
+#endif  // CONFIG_EXT_TX
+
 #if CONFIG_MISC_FIXES
   for (i = 0; i < PREDICTION_PROBS; i++)
     for (j = 0; j < 2; j++)
--- a/vp10/decoder/decodeframe.c
+++ b/vp10/decoder/decodeframe.c
@@ -268,7 +268,11 @@
     if (eob == 1) {
       dqcoeff[0] = 0;
     } else {
+#if CONFIG_EXT_TX
+      if (tx_type == DCT_DCT && tx_size <= TX_16X16 && eob <= 10)
+#else
       if (tx_size <= TX_16X16 && eob <= 10)
+#endif  // CONFIG_EXT_TX
         memset(dqcoeff, 0, 4 * (4 << tx_size) * sizeof(dqcoeff[0]));
       else if (tx_size == TX_32X32 && eob <= 34)
         memset(dqcoeff, 0, 256 * sizeof(dqcoeff[0]));
@@ -2124,6 +2128,25 @@
   return sz;
 }
 
+#if CONFIG_EXT_TX
+static void read_ext_tx_probs(FRAME_CONTEXT *fc, vpx_reader *r) {
+  int i, j, k;  // i: tx size, j: intra context, k: tree-node probability
+  if (vpx_read(r, GROUP_DIFF_UPDATE_PROB)) {  // group flag: intra table
+    for (i = TX_4X4; i < EXT_TX_SIZES; ++i) {
+      for (j = 0; j < TX_TYPES; ++j)
+        for (k = 0; k < TX_TYPES - 1; ++k)
+          vp10_diff_update_prob(r, &fc->intra_ext_tx_prob[i][j][k]);
+    }
+  }
+  if (vpx_read(r, GROUP_DIFF_UPDATE_PROB)) {  // group flag: inter table
+    for (i = TX_4X4; i < EXT_TX_SIZES; ++i) {
+      for (k = 0; k < TX_TYPES - 1; ++k)
+        vp10_diff_update_prob(r, &fc->inter_ext_tx_prob[i][k]);
+    }
+  }
+}  // Intra group is read first, then inter, as update_ext_tx_probs() writes.
+#endif  // CONFIG_EXT_TX
+
 static int read_compressed_header(VP10Decoder *pbi, const uint8_t *data,
                                   size_t partition_size) {
   VP10_COMMON *const cm = &pbi->common;
@@ -2205,6 +2228,9 @@
 #endif
 
     read_mv_probs(nmvc, cm->allow_high_precision_mv, &r);
+#if CONFIG_EXT_TX
+    read_ext_tx_probs(fc, &r);
+#endif
   }
 
   return vpx_reader_has_error(&r);
@@ -2245,6 +2271,12 @@
   assert(!memcmp(&cm->counts.tx, &zero_counts.tx, sizeof(cm->counts.tx)));
   assert(!memcmp(cm->counts.skip, zero_counts.skip, sizeof(cm->counts.skip)));
   assert(!memcmp(&cm->counts.mv, &zero_counts.mv, sizeof(cm->counts.mv)));
+#if CONFIG_EXT_TX
+  assert(!memcmp(cm->counts.intra_ext_tx, zero_counts.intra_ext_tx,
+                 sizeof(cm->counts.intra_ext_tx)));
+  assert(!memcmp(cm->counts.inter_ext_tx, zero_counts.inter_ext_tx,
+                 sizeof(cm->counts.inter_ext_tx)));
+#endif  // CONFIG_EXT_TX
 }
 #endif  // NDEBUG
 
--- a/vp10/decoder/decodemv.c
+++ b/vp10/decoder/decodemv.c
@@ -296,6 +296,22 @@
   }
 
   mbmi->uv_mode = read_intra_mode_uv(cm, xd, r, mbmi->mode);
+
+#if CONFIG_EXT_TX
+    if (mbmi->tx_size < TX_32X32 &&
+        cm->base_qindex > 0 && !mbmi->skip &&
+        !segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP)) {
+      FRAME_COUNTS *counts = xd->counts;
+      TX_TYPE tx_type_nom = intra_mode_to_tx_type_context[mbmi->mode];
+      mbmi->tx_type = vpx_read_tree(
+          r, vp10_ext_tx_tree,
+          cm->fc->intra_ext_tx_prob[mbmi->tx_size][tx_type_nom]);
+      if (counts)
+        ++counts->intra_ext_tx[mbmi->tx_size][tx_type_nom][mbmi->tx_type];
+    } else {
+      mbmi->tx_type = DCT_DCT;
+    }
+#endif  // CONFIG_EXT_TX
 }
 
 static int read_mv_component(vpx_reader *r,
@@ -652,6 +668,30 @@
     read_inter_block_mode_info(pbi, xd, mi, mi_row, mi_col, r);
   else
     read_intra_block_mode_info(cm, xd, mi, r);
+
+#if CONFIG_EXT_TX
+  if (mbmi->tx_size < TX_32X32 &&
+      cm->base_qindex > 0 && !mbmi->skip &&
+      !segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP)) {
+    FRAME_COUNTS *counts = xd->counts;
+    if (inter_block) {
+      mbmi->tx_type = vpx_read_tree(
+          r, vp10_ext_tx_tree,
+          cm->fc->inter_ext_tx_prob[mbmi->tx_size]);
+      if (counts)
+        ++counts->inter_ext_tx[mbmi->tx_size][mbmi->tx_type];
+    } else {
+      const TX_TYPE tx_type_nom = intra_mode_to_tx_type_context[mbmi->mode];
+      mbmi->tx_type = vpx_read_tree(
+          r, vp10_ext_tx_tree,
+          cm->fc->intra_ext_tx_prob[mbmi->tx_size][tx_type_nom]);
+      if (counts)
+        ++counts->intra_ext_tx[mbmi->tx_size][tx_type_nom][mbmi->tx_type];
+    }
+  } else {
+    mbmi->tx_type = DCT_DCT;
+  }
+#endif  // CONFIG_EXT_TX
 }
 
 void vp10_read_mode_info(VP10Decoder *const pbi, MACROBLOCKD *xd,
--- a/vp10/encoder/bitstream.c
+++ b/vp10/encoder/bitstream.c
@@ -58,6 +58,16 @@
   }
 }
 
+#if CONFIG_EXT_TX
+static struct vp10_token ext_tx_encodings[TX_TYPES];
+#endif  // CONFIG_EXT_TX
+
+void vp10_encode_token_init(void) {
+#if CONFIG_EXT_TX
+  vp10_tokens_from_tree(ext_tx_encodings, vp10_ext_tx_tree);  // tx_type tokens
+#endif  // CONFIG_EXT_TX
+}  // No-op unless CONFIG_EXT_TX is enabled.
+
 static void write_intra_mode(vpx_writer *w, PREDICTION_MODE mode,
                              const vpx_prob *probs) {
   vp10_write_token(w, vp10_intra_mode_tree, probs, &intra_mode_encodings[mode]);
@@ -90,6 +100,24 @@
     vp10_cond_prob_diff_update(w, &probs[i], branch_ct[i]);
 }
 
+static int prob_diff_update_savings(const vpx_tree_index *tree,
+                                    vpx_prob probs[/*n - 1*/],
+                                    const unsigned int counts[/*n*/],
+                                    int n) {  // n: number of symbols in tree
+  int i;
+  unsigned int branch_ct[32][2];  // per-node branch counts from counts[]
+  int savings = 0;  // running total over the n - 1 node probabilities
+
+  // branch_ct has room for 32 tree nodes; only n - 1 of them are used below.
+  assert(n <= 32);
+  vp10_tree_probs_from_distribution(tree, branch_ct, counts);
+  for (i = 0; i < n - 1; ++i) {
+    savings += vp10_cond_prob_diff_update_savings(&probs[i],
+                                                  branch_ct[i]);
+  }
+  return savings;  // sum of the per-node savings; nothing is written out
+}
+
 static void write_selected_tx_size(const VP10_COMMON *cm,
                                    const MACROBLOCKD *xd, vpx_writer *w) {
   TX_SIZE tx_size = xd->mi[0]->mbmi.tx_size;
@@ -133,6 +161,51 @@
                      counts->switchable_interp[j], SWITCHABLE_FILTERS, w);
 }
 
+#if CONFIG_EXT_TX
+static void update_ext_tx_probs(VP10_COMMON *cm, vpx_writer *w) {
+  const int savings_thresh = vp10_cost_one(GROUP_DIFF_UPDATE_PROB) -
+                             vp10_cost_zero(GROUP_DIFF_UPDATE_PROB);
+  int i, j;  // i: tx size, j: intra context
+  // A group's summed savings must exceed savings_thresh to set its flag.
+  int savings = 0;
+  int do_update = 0;
+  for (i = TX_4X4; i < EXT_TX_SIZES; ++i) {  // pass 1: total intra savings
+    for (j = 0; j < TX_TYPES; ++j)
+      savings += prob_diff_update_savings(
+          vp10_ext_tx_tree, cm->fc->intra_ext_tx_prob[i][j],
+          cm->counts.intra_ext_tx[i][j], TX_TYPES);
+  }
+  do_update = savings > savings_thresh;
+  vpx_write(w, do_update, GROUP_DIFF_UPDATE_PROB);  // intra group flag
+  if (do_update) {
+    for (i = TX_4X4; i < EXT_TX_SIZES; ++i) {  // pass 2: write intra updates
+      for (j = 0; j < TX_TYPES; ++j)
+        prob_diff_update(vp10_ext_tx_tree,
+                         cm->fc->intra_ext_tx_prob[i][j],
+                         cm->counts.intra_ext_tx[i][j],
+                         TX_TYPES, w);
+    }
+  }
+  savings = 0;  // restart the tally for the inter group
+  do_update = 0;
+  for (i = TX_4X4; i < EXT_TX_SIZES; ++i) {  // pass 1: total inter savings
+    savings += prob_diff_update_savings(
+        vp10_ext_tx_tree, cm->fc->inter_ext_tx_prob[i],
+        cm->counts.inter_ext_tx[i], TX_TYPES);
+  }
+  do_update = savings > savings_thresh;
+  vpx_write(w, do_update, GROUP_DIFF_UPDATE_PROB);  // inter group flag
+  if (do_update) {
+    for (i = TX_4X4; i < EXT_TX_SIZES; ++i) {  // pass 2: write inter updates
+      prob_diff_update(vp10_ext_tx_tree,
+                       cm->fc->inter_ext_tx_prob[i],
+                       cm->counts.inter_ext_tx[i],
+                       TX_TYPES, w);
+    }
+  }
+}  // Intra group is written first, then inter, as read_ext_tx_probs() reads.
+#endif  // CONFIG_EXT_TX
+
 static void pack_mb_tokens(vpx_writer *w,
                            TOKENEXTRA **tp, const TOKENEXTRA *const stop,
                            vpx_bit_depth_t bit_depth, const TX_SIZE tx) {
@@ -370,6 +443,27 @@
       }
     }
   }
+#if CONFIG_EXT_TX
+  if (mbmi->tx_size < TX_32X32 &&
+      cm->base_qindex > 0 && !mbmi->skip &&
+      !segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP)) {
+    if (is_inter) {
+      vp10_write_token(
+          w, vp10_ext_tx_tree,
+          cm->fc->inter_ext_tx_prob[mbmi->tx_size],
+          &ext_tx_encodings[mbmi->tx_type]);
+    } else {
+      vp10_write_token(
+          w, vp10_ext_tx_tree,
+          cm->fc->intra_ext_tx_prob[mbmi->tx_size]
+                                   [intra_mode_to_tx_type_context[mbmi->mode]],
+          &ext_tx_encodings[mbmi->tx_type]);
+    }
+  } else {
+    if (!mbmi->skip)
+      assert(mbmi->tx_type == DCT_DCT);
+  }
+#endif  // CONFIG_EXT_TX
 }
 
 static void write_mb_modes_kf(const VP10_COMMON *cm, const MACROBLOCKD *xd,
@@ -413,6 +507,18 @@
   }
 
   write_intra_mode(w, mbmi->uv_mode, cm->fc->uv_mode_prob[mbmi->mode]);
+
+#if CONFIG_EXT_TX
+  if (mbmi->tx_size < TX_32X32 &&
+      cm->base_qindex > 0 && !mbmi->skip &&
+      !segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP)) {
+    vp10_write_token(
+        w, vp10_ext_tx_tree,
+        cm->fc->intra_ext_tx_prob[mbmi->tx_size]
+                                 [intra_mode_to_tx_type_context[mbmi->mode]],
+        &ext_tx_encodings[mbmi->tx_type]);
+  }
+#endif  // CONFIG_EXT_TX
 }
 
 static void write_modes_b(VP10_COMP *cpi, const TileInfo *const tile,
@@ -1381,6 +1487,9 @@
 
     vp10_write_nmv_probs(cm, cm->allow_high_precision_mv, &header_bc,
                         &counts->mv);
+#if CONFIG_EXT_TX
+    update_ext_tx_probs(cm, &header_bc);
+#endif  // CONFIG_EXT_TX
   }
 
   vpx_stop_encode(&header_bc);
--- a/vp10/encoder/bitstream.h
+++ b/vp10/encoder/bitstream.h
@@ -18,6 +18,7 @@
 
 #include "vp10/encoder/encoder.h"
 
+void vp10_encode_token_init(void);
 void vp10_pack_bitstream(VP10_COMP *const cpi, uint8_t *dest, size_t *size);
 
 static INLINE int vp10_preserve_existing_gf(VP10_COMP *cpi) {
--- a/vp10/encoder/encodeframe.c
+++ b/vp10/encoder/encodeframe.c
@@ -3024,5 +3024,18 @@
     }
     ++td->counts->tx.tx_totals[mbmi->tx_size];
     ++td->counts->tx.tx_totals[get_uv_tx_size(mbmi, &xd->plane[1])];
+#if CONFIG_EXT_TX
+    if (mbmi->tx_size < TX_32X32 &&
+        cm->base_qindex > 0 && !mbmi->skip &&
+        !segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP)) {
+      if (is_inter_block(mbmi)) {
+        ++td->counts->inter_ext_tx[mbmi->tx_size][mbmi->tx_type];
+      } else {
+        ++td->counts->intra_ext_tx[mbmi->tx_size]
+                                  [intra_mode_to_tx_type_context[mbmi->mode]]
+                                  [mbmi->tx_type];
+      }
+    }
+#endif  // CONFIG_EXT_TX
   }
 }
--- a/vp10/encoder/encoder.c
+++ b/vp10/encoder/encoder.c
@@ -328,6 +328,7 @@
     vp10_rc_init_minq_luts();
     vp10_entropy_mv_init();
     vp10_temporal_filter_init();
+    vp10_encode_token_init();
     init_done = 1;
   }
 }
@@ -2654,7 +2655,7 @@
   MACROBLOCKD *xd = &cpi->td.mb.e_mbd;
   struct loopfilter *lf = &cm->lf;
   if (is_lossless_requested(&cpi->oxcf)) {
-      lf->filter_level = 0;
+    lf->filter_level = 0;
   } else {
     struct vpx_usec_timer timer;
 
--- a/vp10/encoder/encoder.h
+++ b/vp10/encoder/encoder.h
@@ -467,6 +467,10 @@
   int multi_arf_enabled;
   int multi_arf_last_grp_enabled;
 
+#if CONFIG_EXT_TX
+  int intra_tx_type_costs[EXT_TX_SIZES][TX_TYPES][TX_TYPES];
+  int inter_tx_type_costs[EXT_TX_SIZES][TX_TYPES];
+#endif  // CONFIG_EXT_TX
 #if CONFIG_VP9_TEMPORAL_DENOISING
   VP9_DENOISER denoiser;
 #endif
--- a/vp10/encoder/rd.c
+++ b/vp10/encoder/rd.c
@@ -83,6 +83,20 @@
   for (i = 0; i < SWITCHABLE_FILTER_CONTEXTS; ++i)
     vp10_cost_tokens(cpi->switchable_interp_costs[i],
                     fc->switchable_interp_prob[i], vp10_switchable_interp_tree);
+
+#if CONFIG_EXT_TX
+  for (i = TX_4X4; i < EXT_TX_SIZES; ++i) {
+    for (j = 0; j < TX_TYPES; ++j)
+      vp10_cost_tokens(cpi->intra_tx_type_costs[i][j],
+                       fc->intra_ext_tx_prob[i][j],
+                       vp10_ext_tx_tree);
+  }
+  for (i = TX_4X4; i < EXT_TX_SIZES; ++i) {
+    vp10_cost_tokens(cpi->inter_tx_type_costs[i],
+                     fc->inter_ext_tx_prob[i],
+                     vp10_ext_tx_tree);
+  }
+#endif  // CONFIG_EXT_TX
 }
 
 static void fill_token_costs(vp10_coeff_cost *c,
--- a/vp10/encoder/rdopt.c
+++ b/vp10/encoder/rdopt.c
@@ -54,6 +54,10 @@
 #define MIN_EARLY_TERM_INDEX    3
 #define NEW_MV_DISCOUNT_FACTOR  8
 
+#if CONFIG_EXT_TX
+const double ext_tx_th = 0.99;
+#endif  // CONFIG_EXT_TX
+
 typedef struct {
   PREDICTION_MODE mode;
   MV_REFERENCE_FRAME ref_frame[2];
@@ -598,11 +602,63 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
 
+#if CONFIG_EXT_TX
+  TX_TYPE tx_type, best_tx_type = DCT_DCT;
+  int r, s;
+  int64_t d, psse, this_rd, best_rd = INT64_MAX;
+  vpx_prob skip_prob = vp10_get_skip_prob(cm, xd);
+  int  s0 = vp10_cost_bit(skip_prob, 0);
+  int  s1 = vp10_cost_bit(skip_prob, 1);
+  const int is_inter = is_inter_block(mbmi);
+
   mbmi->tx_size = VPXMIN(max_tx_size, largest_tx_size);
+  if (mbmi->tx_size < TX_32X32 &&
+      !xd->lossless[mbmi->segment_id]) {
+    for (tx_type = 0; tx_type < TX_TYPES; ++tx_type) {
+      mbmi->tx_type = tx_type;
+      txfm_rd_in_plane(x, &r, &d, &s,
+                       &psse, ref_best_rd, 0, bs, mbmi->tx_size,
+                       cpi->sf.use_fast_coef_costing);
+      if (r == INT_MAX)
+        continue;
+      if (is_inter)
+        r += cpi->inter_tx_type_costs[mbmi->tx_size][mbmi->tx_type];
+      else
+        r += cpi->intra_tx_type_costs[mbmi->tx_size]
+                                     [intra_mode_to_tx_type_context[mbmi->mode]]
+                                     [mbmi->tx_type];
+      if (s)
+        this_rd = RDCOST(x->rdmult, x->rddiv, s1, psse);
+      else
+        this_rd = RDCOST(x->rdmult, x->rddiv, r + s0, d);
+      if (is_inter && !xd->lossless[mbmi->segment_id] && !s)
+        this_rd = VPXMIN(this_rd, RDCOST(x->rdmult, x->rddiv, s1, psse));
 
+      if (this_rd < ((best_tx_type == DCT_DCT) ? ext_tx_th : 1) * best_rd) {
+        best_rd = this_rd;
+        best_tx_type = mbmi->tx_type;
+      }
+    }
+  }
+  mbmi->tx_type = best_tx_type;
   txfm_rd_in_plane(x, rate, distortion, skip,
                    sse, ref_best_rd, 0, bs,
                    mbmi->tx_size, cpi->sf.use_fast_coef_costing);
+  if (mbmi->tx_size < TX_32X32 && !xd->lossless[mbmi->segment_id]) {
+    if (is_inter)
+      *rate += cpi->inter_tx_type_costs[mbmi->tx_size][mbmi->tx_type];
+    else
+      *rate += cpi->intra_tx_type_costs[mbmi->tx_size]
+          [intra_mode_to_tx_type_context[mbmi->mode]]
+          [mbmi->tx_type];
+  }
+#else
+  mbmi->tx_size = VPXMIN(max_tx_size, largest_tx_size);
+
+  txfm_rd_in_plane(x, rate, distortion, skip,
+                   sse, ref_best_rd, 0, bs,
+                   mbmi->tx_size, cpi->sf.use_fast_coef_costing);
+#endif  // CONFIG_EXT_TX
 }
 
 static void choose_smallest_tx_size(VP10_COMP *cpi, MACROBLOCK *x,
@@ -632,17 +688,19 @@
   MACROBLOCKD *const xd = &x->e_mbd;
   MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
   vpx_prob skip_prob = vp10_get_skip_prob(cm, xd);
-  int r[TX_SIZES][2], s[TX_SIZES];
-  int64_t d[TX_SIZES], sse[TX_SIZES];
-  int64_t rd[TX_SIZES][2] = {{INT64_MAX, INT64_MAX},
-                             {INT64_MAX, INT64_MAX},
-                             {INT64_MAX, INT64_MAX},
-                             {INT64_MAX, INT64_MAX}};
+  int r, s;
+  int64_t d, sse;
+  int64_t rd = INT64_MAX;
   int n, m;
   int s0, s1;
-  int64_t best_rd = INT64_MAX;
+  int64_t best_rd = INT64_MAX, last_rd = INT64_MAX;
   TX_SIZE best_tx = max_tx_size;
   int start_tx, end_tx;
+  const int tx_select = cm->tx_mode == TX_MODE_SELECT;
+#if CONFIG_EXT_TX
+  TX_TYPE tx_type, best_tx_type = DCT_DCT;
+#endif  // CONFIG_EXT_TX
+  const int is_inter = is_inter_block(mbmi);
 
   const vpx_prob *tx_probs = get_tx_probs2(max_tx_size, xd, &cm->fc->tx_probs);
   assert(skip_prob > 0);
@@ -649,70 +707,118 @@
   s0 = vp10_cost_bit(skip_prob, 0);
   s1 = vp10_cost_bit(skip_prob, 1);
 
-  if (cm->tx_mode == TX_MODE_SELECT) {
+  if (tx_select) {
     start_tx = max_tx_size;
     end_tx = 0;
   } else {
-    TX_SIZE chosen_tx_size = VPXMIN(max_tx_size,
-                                    tx_mode_to_biggest_tx_size[cm->tx_mode]);
+    const TX_SIZE chosen_tx_size =
+        VPXMIN(max_tx_size, tx_mode_to_biggest_tx_size[cm->tx_mode]);
     start_tx = chosen_tx_size;
     end_tx = chosen_tx_size;
   }
 
-  for (n = start_tx; n >= end_tx; n--) {
-    int r_tx_size = 0;
-    for (m = 0; m <= n - (n == (int) max_tx_size); m++) {
-      if (m == n)
-        r_tx_size += vp10_cost_zero(tx_probs[m]);
-      else
-        r_tx_size += vp10_cost_one(tx_probs[m]);
-    }
-    txfm_rd_in_plane(x, &r[n][0], &d[n], &s[n],
-                     &sse[n], ref_best_rd, 0, bs, n,
-                     cpi->sf.use_fast_coef_costing);
-    r[n][1] = r[n][0];
-    if (r[n][0] < INT_MAX) {
-      r[n][1] += r_tx_size;
-    }
-    if (d[n] == INT64_MAX || r[n][0] == INT_MAX) {
-      rd[n][0] = rd[n][1] = INT64_MAX;
-    } else if (s[n]) {
-      if (is_inter_block(mbmi)) {
-        rd[n][0] = rd[n][1] = RDCOST(x->rdmult, x->rddiv, s1, sse[n]);
-        r[n][1] -= r_tx_size;
+  *distortion = INT64_MAX;
+  *rate       = INT_MAX;
+  *skip       = 0;
+  *psse       = INT64_MAX;
+
+#if CONFIG_EXT_TX
+  for (tx_type = DCT_DCT; tx_type < TX_TYPES; ++tx_type) {
+#endif  // CONFIG_EXT_TX
+    last_rd = INT64_MAX;
+    for (n = start_tx; n >= end_tx; --n) {
+      int r_tx_size = 0;
+      for (m = 0; m <= n - (n == (int) max_tx_size); ++m) {
+        if (m == n)
+          r_tx_size += vp10_cost_zero(tx_probs[m]);
+        else
+          r_tx_size += vp10_cost_one(tx_probs[m]);
+      }
+
+#if CONFIG_EXT_TX
+      if (n >= TX_32X32 && tx_type != DCT_DCT) {
+        continue;
+      }
+      mbmi->tx_type = tx_type;
+      txfm_rd_in_plane(x, &r, &d, &s,
+                       &sse, ref_best_rd, 0, bs, n,
+                       cpi->sf.use_fast_coef_costing);
+      if (n < TX_32X32 &&
+          !xd->lossless[xd->mi[0]->mbmi.segment_id] &&
+          r != INT_MAX) {
+        if (is_inter)
+          r += cpi->inter_tx_type_costs[mbmi->tx_size][mbmi->tx_type];
+        else
+          r += cpi->intra_tx_type_costs[mbmi->tx_size]
+              [intra_mode_to_tx_type_context[mbmi->mode]]
+              [mbmi->tx_type];
+      }
+#else  // CONFIG_EXT_TX
+      txfm_rd_in_plane(x, &r, &d, &s,
+                       &sse, ref_best_rd, 0, bs, n,
+                       cpi->sf.use_fast_coef_costing);
+#endif  // CONFIG_EXT_TX
+
+      if (r == INT_MAX)
+        continue;
+
+      if (s) {
+        if (is_inter) {
+          rd = RDCOST(x->rdmult, x->rddiv, s1, sse);
+        } else {
+          rd =  RDCOST(x->rdmult, x->rddiv, s1 + r_tx_size * tx_select, sse);
+        }
       } else {
-        rd[n][0] = RDCOST(x->rdmult, x->rddiv, s1, sse[n]);
-        rd[n][1] = RDCOST(x->rdmult, x->rddiv, s1 + r_tx_size, sse[n]);
+        rd = RDCOST(x->rdmult, x->rddiv, r + s0 + r_tx_size * tx_select, d);
       }
-    } else {
-      rd[n][0] = RDCOST(x->rdmult, x->rddiv, r[n][0] + s0, d[n]);
-      rd[n][1] = RDCOST(x->rdmult, x->rddiv, r[n][1] + s0, d[n]);
-    }
 
-    if (is_inter_block(mbmi) && !xd->lossless[mbmi->segment_id] &&
-        !s[n] && sse[n] != INT64_MAX) {
-      rd[n][0] = VPXMIN(rd[n][0], RDCOST(x->rdmult, x->rddiv, s1, sse[n]));
-      rd[n][1] = VPXMIN(rd[n][1], RDCOST(x->rdmult, x->rddiv, s1, sse[n]));
-    }
+      if (tx_select && !(s && is_inter))
+        r += r_tx_size;
 
-    // Early termination in transform size search.
-    if (cpi->sf.tx_size_search_breakout &&
-        (rd[n][1] == INT64_MAX ||
-        (n < (int) max_tx_size && rd[n][1] > rd[n + 1][1]) ||
-        s[n] == 1))
-      break;
+      if (is_inter && !xd->lossless[xd->mi[0]->mbmi.segment_id] && !s)
+        rd = VPXMIN(rd, RDCOST(x->rdmult, x->rddiv, s1, sse));
 
-    if (rd[n][1] < best_rd) {
-      best_tx = n;
-      best_rd = rd[n][1];
+      // Early termination in transform size search.
+      if (cpi->sf.tx_size_search_breakout &&
+          (rd == INT64_MAX ||
+#if CONFIG_EXT_TX
+           (s == 1 && tx_type != DCT_DCT && n < start_tx) ||
+#else
+           (s == 1 && n < start_tx) ||
+#endif
+           (n < (int) max_tx_size && rd > last_rd)))
+        break;
+
+      last_rd = rd;
+      if (rd <
+#if CONFIG_EXT_TX
+          (is_inter && best_tx_type == DCT_DCT ? ext_tx_th : 1) *
+#endif  // CONFIG_EXT_TX
+          best_rd) {
+        best_tx = n;
+        best_rd = rd;
+        *distortion = d;
+        *rate       = r;
+        *skip       = s;
+        *psse       = sse;
+#if CONFIG_EXT_TX
+        best_tx_type = mbmi->tx_type;
+#endif  // CONFIG_EXT_TX
+      }
     }
+#if CONFIG_EXT_TX
   }
-  mbmi->tx_size = best_tx;
+#endif  // CONFIG_EXT_TX
 
-  *distortion = d[mbmi->tx_size];
-  *rate       = r[mbmi->tx_size][cm->tx_mode == TX_MODE_SELECT];
-  *skip       = s[mbmi->tx_size];
-  *psse       = sse[mbmi->tx_size];
+  mbmi->tx_size = best_tx;
+#if CONFIG_EXT_TX
+  mbmi->tx_type = best_tx_type;
+  if (mbmi->tx_size >= TX_32X32)
+    assert(mbmi->tx_type == DCT_DCT);
+  txfm_rd_in_plane(x, &r, &d, &s,
+                   &sse, ref_best_rd, 0, bs, best_tx,
+                   cpi->sf.use_fast_coef_costing);
+#endif  // CONFIG_EXT_TX
 }
 
 static void super_block_yrd(VP10_COMP *cpi, MACROBLOCK *x, int *rate,
@@ -1065,6 +1171,9 @@
   int this_rate, this_rate_tokenonly, s;
   int64_t this_distortion, this_rd;
   TX_SIZE best_tx = TX_4X4;
+#if CONFIG_EXT_TX
+  TX_TYPE best_tx_type = DCT_DCT;
+#endif  // CONFIG_EXT_TX
   int *bmode_costs;
   const MODE_INFO *above_mi = xd->above_mi;
   const MODE_INFO *left_mi = xd->left_mi;
@@ -1091,6 +1200,9 @@
       mode_selected   = mode;
       best_rd         = this_rd;
       best_tx         = mic->mbmi.tx_size;
+#if CONFIG_EXT_TX
+      best_tx_type    = mic->mbmi.tx_type;
+#endif  // CONFIG_EXT_TX
       *rate           = this_rate;
       *rate_tokenonly = this_rate_tokenonly;
       *distortion     = this_distortion;
@@ -1100,6 +1212,9 @@
 
   mic->mbmi.mode = mode_selected;
   mic->mbmi.tx_size = best_tx;
+#if CONFIG_EXT_TX
+  mic->mbmi.tx_type = best_tx_type;
+#endif  // CONFIG_EXT_TX
 
   return best_rd;
 }
--- a/vp10/encoder/subexp.c
+++ b/vp10/encoder/subexp.c
@@ -212,3 +212,12 @@
     vpx_write(w, 0, upd);
   }
 }
+
+int vp10_cond_prob_diff_update_savings(vpx_prob *oldp,
+                                       const unsigned int ct[2]) {
+  const vpx_prob upd = DIFF_UPDATE_PROB;
+  vpx_prob newp = get_binary_prob(ct[0], ct[1]);  // candidate from counts
+  const int savings = vp10_prob_diff_update_savings_search(ct, *oldp, &newp,
+                                                           upd);
+  return savings;  // *oldp is only read here; nothing is written to it
+}
--- a/vp10/encoder/subexp.h
+++ b/vp10/encoder/subexp.h
@@ -37,6 +37,8 @@
                                               vpx_prob upd,
                                               int stepsize);
 
+int vp10_cond_prob_diff_update_savings(vpx_prob *oldp,
+                                       const unsigned int ct[2]);
 #ifdef __cplusplus
 }  // extern "C"
 #endif