shithub: libvpx

Download patch

ref: 109f58acfd8d46deec1e0bd4d0f82daa36cd6b8e
parent: 025969d910c39ad43b95447efa2fdd46a335276a
parent: 18c08607e0a216962fd3df4f7926ce288ee81b9b
author: Joey Parrish <[email protected]>
date: Thu Apr 24 03:45:20 EDT 2014

Merge "Add VPXD_SET_DECRYPTOR support to the VP9 decoder."

--- a/configure
+++ b/configure
@@ -317,7 +317,6 @@
     multi_res_encoding
     temporal_denoising
     experimental
-    decrypt
     ${EXPERIMENT_LIST}
 "
 CMDLINE_SELECT="
@@ -371,7 +370,6 @@
     multi_res_encoding
     temporal_denoising
     experimental
-    decrypt
 "
 
 process_cmdline() {
--- a/test/test.mk
+++ b/test/test.mk
@@ -110,6 +110,7 @@
 
 LIBVPX_TEST_SRCS-$(CONFIG_VP9)         += convolve_test.cc
 LIBVPX_TEST_SRCS-$(CONFIG_VP9_DECODER) += vp9_thread_test.cc
+LIBVPX_TEST_SRCS-$(CONFIG_VP9_DECODER) += vp9_decrypt_test.cc
 LIBVPX_TEST_SRCS-$(CONFIG_VP9_ENCODER) += dct16x16_test.cc
 LIBVPX_TEST_SRCS-$(CONFIG_VP9_ENCODER) += dct32x32_test.cc
 LIBVPX_TEST_SRCS-$(CONFIG_VP9_ENCODER) += fdct4x4_test.cc
--- a/test/vp8_boolcoder_test.cc
+++ b/test/vp8_boolcoder_test.cc
@@ -94,14 +94,10 @@
         vp8_stop_encode(&bw);
 
         BOOL_DECODER br;
-#if CONFIG_DECRYPT
-        encrypt_buffer(bw_buffer, buffer_size);
-        vp8dx_start_decode(&br, bw_buffer, buffer_size,
+        encrypt_buffer(bw_buffer, kBufferSize);
+        vp8dx_start_decode(&br, bw_buffer, kBufferSize,
                            test_decrypt_cb,
                            reinterpret_cast<void *>(bw_buffer));
-#else
-        vp8dx_start_decode(&br, bw_buffer, kBufferSize, NULL, NULL);
-#endif
         bit_rnd.Reset(random_seed);
         for (int i = 0; i < kBitsToTest; ++i) {
           if (bit_method == 2) {
--- a/test/vp8_decrypt_test.cc
+++ b/test/vp8_decrypt_test.cc
@@ -43,7 +43,7 @@
 
 namespace libvpx_test {
 
-TEST(TestDecrypt, DecryptWorks) {
+TEST(TestDecrypt, DecryptWorksVp8) {
   libvpx_test::IVFVideoSource video("vp80-00-comprehensive-001.ivf");
   video.Init();
 
@@ -59,14 +59,12 @@
   // decrypt frame
   video.Next();
 
-#if CONFIG_DECRYPT
   std::vector<uint8_t> encrypted(video.frame_size());
   encrypt_buffer(video.cxdata(), &encrypted[0], video.frame_size(), 0);
-  vp8_decrypt_init di = { test_decrypt_cb, &encrypted[0] };
-  decoder.Control(VP8D_SET_DECRYPTOR, &di);
-#endif  // CONFIG_DECRYPT
+  vpx_decrypt_init di = { test_decrypt_cb, &encrypted[0] };
+  decoder.Control(VPXD_SET_DECRYPTOR, &di);
 
-  res = decoder.DecodeFrame(video.cxdata(), video.frame_size());
+  res = decoder.DecodeFrame(&encrypted[0], encrypted.size());
   ASSERT_EQ(VPX_CODEC_OK, res) << decoder.DecodeError();
 }
 
--- a/test/vp9_boolcoder_test.cc
+++ b/test/vp9_boolcoder_test.cc
@@ -70,7 +70,7 @@
         GTEST_ASSERT_EQ(bw_buffer[0] & 0x80, 0);
 
         vp9_reader br;
-        vp9_reader_init(&br, bw_buffer, kBufferSize);
+        vp9_reader_init(&br, bw_buffer, kBufferSize, NULL, NULL);
         bit_rnd.Reset(random_seed);
         for (int i = 0; i < kBitsToTest; ++i) {
           if (bit_method == 2) {
--- /dev/null
+++ b/test/vp9_decrypt_test.cc
@@ -1,0 +1,71 @@
+/*
+ *  Copyright (c) 2013 The WebM project authors. All Rights Reserved.
+ *
+ *  Use of this source code is governed by a BSD-style license
+ *  that can be found in the LICENSE file in the root of the source
+ *  tree. An additional intellectual property rights grant can be found
+ *  in the file PATENTS.  All contributing project authors may
+ *  be found in the AUTHORS file in the root of the source tree.
+ */
+
+#include <cstdio>
+#include <cstdlib>
+#include <string>
+#include <vector>
+#include "third_party/googletest/src/include/gtest/gtest.h"
+#include "test/codec_factory.h"
+#include "test/ivf_video_source.h"
+
+namespace {
+// In a real use the 'decrypt_state' parameter will be a pointer to a struct
+// with whatever internal state the decryptor uses. For testing we'll just
+// xor with a constant key, and decrypt_state will point to the start of
+// the original buffer.
+const uint8_t test_key[16] = {
+  0x01, 0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78,
+  0x89, 0x9a, 0xab, 0xbc, 0xcd, 0xde, 0xef, 0xf0
+};
+
+void encrypt_buffer(const uint8_t *src, uint8_t *dst, size_t size,
+                    ptrdiff_t offset) {
+  for (size_t i = 0; i < size; ++i) {
+    dst[i] = src[i] ^ test_key[(offset + i) & 15];
+  }
+}
+
+void test_decrypt_cb(void *decrypt_state, const uint8_t *input,
+                     uint8_t *output, int count) {
+  encrypt_buffer(input, output, count,
+                 input - static_cast<const uint8_t *>(decrypt_state));
+}
+
+}  // namespace
+
+namespace libvpx_test {
+
+TEST(TestDecrypt, DecryptWorksVp9) {
+  libvpx_test::IVFVideoSource video("vp90-2-05-resize.ivf");
+  video.Init();
+
+  vpx_codec_dec_cfg_t dec_cfg = {0};
+  VP9Decoder decoder(dec_cfg, 0);
+
+  video.Begin();
+
+  // no decryption
+  vpx_codec_err_t res = decoder.DecodeFrame(video.cxdata(), video.frame_size());
+  ASSERT_EQ(VPX_CODEC_OK, res) << decoder.DecodeError();
+
+  // decrypt frame
+  video.Next();
+
+  std::vector<uint8_t> encrypted(video.frame_size());
+  encrypt_buffer(video.cxdata(), &encrypted[0], video.frame_size(), 0);
+  vpx_decrypt_init di = { test_decrypt_cb, &encrypted[0] };
+  decoder.Control(VPXD_SET_DECRYPTOR, &di);
+
+  res = decoder.DecodeFrame(&encrypted[0], encrypted.size());
+  ASSERT_EQ(VPX_CODEC_OK, res) << decoder.DecodeError();
+}
+
+}  // namespace libvpx_test
--- a/vp8/common/common.h
+++ b/vp8/common/common.h
@@ -22,6 +22,9 @@
 extern "C" {
 #endif
 
+#define MIN(x, y) (((x) < (y)) ? (x) : (y))
+#define MAX(x, y) (((x) > (y)) ? (x) : (y))
+
 /* Only need this for fixed-size arrays, for structs just assign. */
 
 #define vp8_copy( Dest, Src) { \
--- a/vp8/decoder/dboolhuff.c
+++ b/vp8/decoder/dboolhuff.c
@@ -10,11 +10,12 @@
 
 
 #include "dboolhuff.h"
+#include "vp8/common/common.h"
 
 int vp8dx_start_decode(BOOL_DECODER *br,
                        const unsigned char *source,
                        unsigned int source_sz,
-                       vp8_decrypt_cb *decrypt_cb,
+                       vpx_decrypt_cb decrypt_cb,
                        void *decrypt_state)
 {
     br->user_buffer_end = source+source_sz;
@@ -39,7 +40,7 @@
     const unsigned char *bufptr = br->user_buffer;
     VP8_BD_VALUE value = br->value;
     int count = br->count;
-    int shift = VP8_BD_VALUE_SIZE - 8 - (count + 8);
+    int shift = VP8_BD_VALUE_SIZE - CHAR_BIT - (count + CHAR_BIT);
     size_t bytes_left = br->user_buffer_end - bufptr;
     size_t bits_left = bytes_left * CHAR_BIT;
     int x = (int)(shift + CHAR_BIT - bits_left);
@@ -47,7 +48,7 @@
     unsigned char decrypted[sizeof(VP8_BD_VALUE) + 1];
 
     if (br->decrypt_cb) {
-        size_t n = bytes_left > sizeof(decrypted) ? sizeof(decrypted) : bytes_left;
+        size_t n = MIN(sizeof(decrypted), bytes_left);
         br->decrypt_cb(br->decrypt_state, bufptr, decrypted, (int)n);
         bufptr = decrypted;
     }
--- a/vp8/decoder/dboolhuff.h
+++ b/vp8/decoder/dboolhuff.h
@@ -17,6 +17,7 @@
 
 #include "vpx_config.h"
 #include "vpx_ports/mem.h"
+#include "vpx/vp8dx.h"
 #include "vpx/vpx_integer.h"
 
 #ifdef __cplusplus
@@ -32,12 +33,6 @@
   Even relatively modest values like 100 would work fine.*/
 #define VP8_LOTS_OF_BITS (0x40000000)
 
-/*Decrypt n bytes of data from input -> output, using the decrypt_state
-   passed in VP8D_SET_DECRYPTOR.
-*/
-typedef void (vp8_decrypt_cb)(void *decrypt_state, const unsigned char *input,
-                              unsigned char *output, int count);
-
 typedef struct
 {
     const unsigned char *user_buffer_end;
@@ -45,7 +40,7 @@
     VP8_BD_VALUE         value;
     int                  count;
     unsigned int         range;
-    vp8_decrypt_cb      *decrypt_cb;
+    vpx_decrypt_cb       decrypt_cb;
     void                *decrypt_state;
 } BOOL_DECODER;
 
@@ -54,7 +49,7 @@
 int vp8dx_start_decode(BOOL_DECODER *br,
                        const unsigned char *source,
                        unsigned int source_sz,
-                       vp8_decrypt_cb *decrypt_cb,
+                       vpx_decrypt_cb decrypt_cb,
                        void *decrypt_state);
 
 void vp8dx_bool_decoder_fill(BOOL_DECODER *br);
--- a/vp8/decoder/decodeframe.c
+++ b/vp8/decoder/decodeframe.c
@@ -17,6 +17,7 @@
 #include "vp8/common/reconintra4x4.h"
 #include "vp8/common/reconinter.h"
 #include "detokenize.h"
+#include "vp8/common/common.h"
 #include "vp8/common/invtrans.h"
 #include "vp8/common/alloccommon.h"
 #include "vp8/common/entropymode.h"
@@ -1018,8 +1019,7 @@
         const unsigned char *clear = data;
         if (pbi->decrypt_cb)
         {
-            int n = (int)(data_end - data);
-            if (n > 10) n = 10;
+            int n = (int)MIN(sizeof(clear_buffer), data_end - data);
             pbi->decrypt_cb(pbi->decrypt_state, data, clear_buffer, n);
             clear = clear_buffer;
         }
--- a/vp8/decoder/error_concealment.c
+++ b/vp8/decoder/error_concealment.c
@@ -15,9 +15,7 @@
 #include "decodemv.h"
 #include "vpx_mem/vpx_mem.h"
 #include "vp8/common/findnearmv.h"
-
-#define MIN(x,y) (((x)<(y))?(x):(y))
-#define MAX(x,y) (((x)>(y))?(x):(y))
+#include "vp8/common/common.h"
 
 #define FLOOR(x,q) ((x) & -(1 << (q)))
 
--- a/vp8/decoder/onyxd_int.h
+++ b/vp8/decoder/onyxd_int.h
@@ -126,7 +126,7 @@
     int independent_partitions;
     int frame_corrupt_residual;
 
-    vp8_decrypt_cb *decrypt_cb;
+    vpx_decrypt_cb decrypt_cb;
     void *decrypt_state;
 } VP8D_COMP;
 
--- a/vp8/encoder/mcomp.c
+++ b/vp8/encoder/mcomp.c
@@ -17,6 +17,7 @@
 #include <limits.h>
 #include <math.h>
 #include "vp8/common/findnearmv.h"
+#include "vp8/common/common.h"
 
 #ifdef VP8_ENTROPY_STATS
 static int mv_ref_ct [31] [4] [2];
--- a/vp8/encoder/mr_dissim.c
+++ b/vp8/encoder/mr_dissim.c
@@ -15,6 +15,7 @@
 #include "mr_dissim.h"
 #include "vpx_mem/vpx_mem.h"
 #include "rdopt.h"
+#include "vp8/common/common.h"
 
 void vp8_cal_low_res_mb_cols(VP8_COMP *cpi)
 {
--- a/vp8/encoder/onyx_int.h
+++ b/vp8/encoder/onyx_int.h
@@ -61,9 +61,6 @@
 #define VP8_TEMPORAL_ALT_REF 1
 #endif
 
-#define MAX(x,y) (((x)>(y))?(x):(y))
-#define MIN(x,y) (((x)<(y))?(x):(y))
-
 typedef struct
 {
     int kf_indicated;
--- a/vp8/encoder/pickinter.c
+++ b/vp8/encoder/pickinter.c
@@ -14,6 +14,7 @@
 #include "onyx_int.h"
 #include "modecosts.h"
 #include "encodeintra.h"
+#include "vp8/common/common.h"
 #include "vp8/common/entropymode.h"
 #include "pickinter.h"
 #include "vp8/common/findnearmv.h"
--- a/vp8/vp8_dx_iface.c
+++ b/vp8/vp8_dx_iface.c
@@ -16,9 +16,10 @@
 #include "vpx/vp8dx.h"
 #include "vpx/internal/vpx_codec_internal.h"
 #include "vpx_version.h"
+#include "common/alloccommon.h"
+#include "common/common.h"
 #include "common/onyxd.h"
 #include "decoder/onyxd_int.h"
-#include "common/alloccommon.h"
 #include "vpx_mem/vpx_mem.h"
 #if CONFIG_ERROR_CONCEALMENT
 #include "decoder/error_concealment.h"
@@ -56,7 +57,7 @@
     int                     dbg_color_b_modes_flag;
     int                     dbg_display_mv_flag;
 #endif
-    vp8_decrypt_cb          *decrypt_cb;
+    vpx_decrypt_cb          decrypt_cb;
     void                    *decrypt_state;
     vpx_image_t             img;
     int                     img_setup;
@@ -156,7 +157,7 @@
 static vpx_codec_err_t vp8_peek_si_internal(const uint8_t *data,
                                             unsigned int data_sz,
                                             vpx_codec_stream_info_t *si,
-                                            vp8_decrypt_cb *decrypt_cb,
+                                            vpx_decrypt_cb decrypt_cb,
                                             void *decrypt_state)
 {
     vpx_codec_err_t res = VPX_CODEC_OK;
@@ -177,7 +178,7 @@
         const uint8_t *clear = data;
         if (decrypt_cb)
         {
-            int n = data_sz > 10 ? 10 : data_sz;
+            int n = (int)MIN(sizeof(clear_buffer), data_sz);
             decrypt_cb(decrypt_state, data, clear_buffer, n);
             clear = clear_buffer;
         }
@@ -379,12 +380,15 @@
        }
 
        res = vp8_create_decoder_instances(&ctx->yv12_frame_buffers, &oxcf);
-       ctx->yv12_frame_buffers.pbi[0]->decrypt_cb = ctx->decrypt_cb;
-       ctx->yv12_frame_buffers.pbi[0]->decrypt_state = ctx->decrypt_state;
-
        ctx->decoder_init = 1;
     }
 
+    /* Refresh each frame (decrypt config may change); needs a live pbi[0]. */
+    if (!res)
+    {
+        ctx->yv12_frame_buffers.pbi[0]->decrypt_cb = ctx->decrypt_cb;
+        ctx->yv12_frame_buffers.pbi[0]->decrypt_state = ctx->decrypt_state;
+    }
     if (!res)
     {
         VP8D_COMP *pbi = ctx->yv12_frame_buffers.pbi[0];
@@ -722,7 +726,7 @@
                                          int ctrl_id,
                                          va_list args)
 {
-    vp8_decrypt_init *init = va_arg(args, vp8_decrypt_init *);
+    vpx_decrypt_init *init = va_arg(args, vpx_decrypt_init *);
 
     if (init)
     {
@@ -749,7 +753,7 @@
     {VP8D_GET_LAST_REF_UPDATES,     vp8_get_last_ref_updates},
     {VP8D_GET_FRAME_CORRUPTED,      vp8_get_frame_corrupted},
     {VP8D_GET_LAST_REF_USED,        vp8_get_last_ref_frame},
-    {VP8D_SET_DECRYPTOR,            vp8_set_decryptor},
+    {VPXD_SET_DECRYPTOR,            vp8_set_decryptor},
     { -1, NULL},
 };
 
--- a/vp9/decoder/vp9_decodeframe.c
+++ b/vp9/decoder/vp9_decodeframe.c
@@ -40,6 +40,8 @@
 #include "vp9/decoder/vp9_reader.h"
 #include "vp9/decoder/vp9_thread.h"
 
+#define MAX_VP9_HEADER_SIZE 80
+
 static int is_compound_reference_allowed(const VP9_COMMON *cm) {
   int i;
   for (i = 1; i < REFS_PER_FRAME; ++i)
@@ -451,7 +453,9 @@
                                 const uint8_t *data_end,
                                 size_t read_size,
                                 struct vpx_internal_error_info *error_info,
-                                vp9_reader *r) {
+                                vp9_reader *r,
+                                vpx_decrypt_cb decrypt_cb,
+                                void *decrypt_state) {
   // Validate the calculated partition length. If the buffer
   // described by the partition can't be fully read, then restrict
   // it to the portion that can be (for EC mode) or throw an error.
@@ -459,7 +463,7 @@
     vpx_internal_error(error_info, VPX_CODEC_CORRUPT_FRAME,
                        "Truncated packet or corrupt tile length");
 
-  if (vp9_reader_init(r, data, read_size))
+  if (vp9_reader_init(r, data, read_size, decrypt_cb, decrypt_state))
     vpx_internal_error(error_info, VPX_CODEC_MEM_ERROR,
                        "Failed to allocate bool decoder %d", 1);
 }
@@ -750,7 +754,9 @@
 static size_t get_tile(const uint8_t *const data_end,
                        int is_last,
                        struct vpx_internal_error_info *error_info,
-                       const uint8_t **data) {
+                       const uint8_t **data,
+                       vpx_decrypt_cb decrypt_cb,
+                       void *decrypt_state) {
   size_t size;
 
   if (!is_last) {
@@ -758,7 +764,13 @@
       vpx_internal_error(error_info, VPX_CODEC_CORRUPT_FRAME,
                          "Truncated packet or corrupt tile length");
 
-    size = mem_get_be32(*data);
+    if (decrypt_cb) {
+      uint8_t be_data[4];
+      decrypt_cb(decrypt_state, *data, be_data, 4);
+      size = mem_get_be32(be_data);
+    } else {
+      size = mem_get_be32(*data);
+    }
     *data += 4;
 
     if (size > (size_t)(data_end - *data))
@@ -804,7 +816,8 @@
     for (tile_col = 0; tile_col < tile_cols; ++tile_col) {
       const int last_tile = tile_row == tile_rows - 1 &&
                             tile_col == tile_cols - 1;
-      const size_t size = get_tile(data_end, last_tile, &cm->error, &data);
+      const size_t size = get_tile(data_end, last_tile, &cm->error, &data,
+                                   pbi->decrypt_cb, pbi->decrypt_state);
       TileBuffer *const buf = &tile_buffers[tile_row][tile_col];
       buf->data = data;
       buf->size = size;
@@ -823,7 +836,8 @@
       TileInfo tile;
 
       vp9_tile_init(&tile, cm, tile_row, col);
-      setup_token_decoder(buf->data, data_end, buf->size, &cm->error, &r);
+      setup_token_decoder(buf->data, data_end, buf->size, &cm->error, &r,
+                          pbi->decrypt_cb, pbi->decrypt_state);
       decode_tile(pbi, &tile, &r);
 
       if (last_tile)
@@ -921,7 +935,8 @@
   // Load tile data into tile_buffers
   for (n = 0; n < tile_cols; ++n) {
     const size_t size =
-        get_tile(data_end, n == tile_cols - 1, &cm->error, &data);
+        get_tile(data_end, n == tile_cols - 1, &cm->error, &data,
+                 pbi->decrypt_cb, pbi->decrypt_state);
     TileBuffer *const buf = &tile_buffers[n];
     buf->data = data;
     buf->size = size;
@@ -962,7 +977,8 @@
       tile_data->xd.corrupted = 0;
       vp9_tile_init(tile, tile_data->cm, 0, buf->col);
       setup_token_decoder(buf->data, data_end, buf->size, &cm->error,
-                          &tile_data->bit_reader);
+                          &tile_data->bit_reader, pbi->decrypt_cb,
+                          pbi->decrypt_state);
       init_macroblockd(cm, &tile_data->xd);
       vp9_zero(tile_data->xd.dqcoeff);
 
@@ -1163,7 +1179,8 @@
   vp9_reader r;
   int k;
 
-  if (vp9_reader_init(&r, data, partition_size))
+  if (vp9_reader_init(&r, data, partition_size, pbi->decrypt_cb,
+                      pbi->decrypt_state))
     vpx_internal_error(&cm->error, VPX_CODEC_MEM_ERROR,
                        "Failed to allocate bool decoder 0");
 
@@ -1255,14 +1272,36 @@
 }
 #endif  // NDEBUG
 
+static struct vp9_read_bit_buffer* init_read_bit_buffer(
+    VP9Decoder *pbi,
+    struct vp9_read_bit_buffer *rb,
+    const uint8_t *data,
+    const uint8_t *data_end,
+    uint8_t *clear_data /* buffer size MAX_VP9_HEADER_SIZE */) {
+  rb->bit_offset = 0;
+  rb->error_handler = error_handler;
+  rb->error_handler_data = &pbi->common;
+  if (pbi->decrypt_cb) {
+    const int n = (int)MIN(MAX_VP9_HEADER_SIZE, data_end - data);
+    pbi->decrypt_cb(pbi->decrypt_state, data, clear_data, n);
+    rb->bit_buffer = clear_data;
+    rb->bit_buffer_end = clear_data + n;
+  } else {
+    rb->bit_buffer = data;
+    rb->bit_buffer_end = data_end;
+  }
+  return rb;
+}
+
 int vp9_decode_frame(VP9Decoder *pbi,
                      const uint8_t *data, const uint8_t *data_end,
                      const uint8_t **p_data_end) {
   VP9_COMMON *const cm = &pbi->common;
   MACROBLOCKD *const xd = &pbi->mb;
-
-  struct vp9_read_bit_buffer rb = { data, data_end, 0, cm, error_handler };
-  const size_t first_partition_size = read_uncompressed_header(pbi, &rb);
+  struct vp9_read_bit_buffer rb = { 0 };
+  uint8_t clear_data[MAX_VP9_HEADER_SIZE];
+  const size_t first_partition_size = read_uncompressed_header(pbi,
+      init_read_bit_buffer(pbi, &rb, data, data_end, clear_data));
   const int keyframe = cm->frame_type == KEY_FRAME;
   const int tile_rows = 1 << cm->log2_tile_rows;
   const int tile_cols = 1 << cm->log2_tile_cols;
@@ -1270,9 +1309,9 @@
   xd->cur_buf = new_fb;
 
   if (!first_partition_size) {
-      // showing a frame directly
-      *p_data_end = data + 1;
-      return 0;
+    // showing a frame directly
+    *p_data_end = data + 1;
+    return 0;
   }
 
   if (!pbi->decoded_key_frame && !keyframe)
--- a/vp9/decoder/vp9_decoder.h
+++ b/vp9/decoder/vp9_decoder.h
@@ -56,6 +56,9 @@
   int num_tile_workers;
 
   VP9LfSync lf_row_sync;
+
+  vpx_decrypt_cb decrypt_cb;
+  void *decrypt_state;
 } VP9Decoder;
 
 void vp9_initialize_dec();
--- a/vp9/decoder/vp9_reader.c
+++ b/vp9/decoder/vp9_reader.c
@@ -18,7 +18,11 @@
 // Even relatively modest values like 100 would work fine.
 #define LOTS_OF_BITS 0x40000000
 
-int vp9_reader_init(vp9_reader *r, const uint8_t *buffer, size_t size) {
+int vp9_reader_init(vp9_reader *r,
+                    const uint8_t *buffer,
+                    size_t size,
+                    vpx_decrypt_cb decrypt_cb,
+                    void *decrypt_state) {
   if (size && !buffer) {
     return 1;
   } else {
@@ -27,6 +31,8 @@
     r->value = 0;
     r->count = -8;
     r->range = 255;
+    r->decrypt_cb = decrypt_cb;
+    r->decrypt_state = decrypt_state;
     vp9_reader_fill(r);
     return vp9_read_bit(r) != 0;  // marker bit
   }
@@ -35,13 +41,22 @@
 void vp9_reader_fill(vp9_reader *r) {
   const uint8_t *const buffer_end = r->buffer_end;
   const uint8_t *buffer = r->buffer;
+  const uint8_t *buffer_start = buffer;
   BD_VALUE value = r->value;
   int count = r->count;
   int shift = BD_VALUE_SIZE - CHAR_BIT - (count + CHAR_BIT);
   int loop_end = 0;
-  const int bits_left = (int)((buffer_end - buffer) * CHAR_BIT);
-  const int x = shift + CHAR_BIT - bits_left;
+  const size_t bytes_left = buffer_end - buffer;
+  const size_t bits_left = bytes_left * CHAR_BIT;
+  const int x = (int)(shift + CHAR_BIT - bits_left);
 
+  if (r->decrypt_cb) {
+    size_t n = MIN(sizeof(r->clear_buffer), bytes_left);
+    r->decrypt_cb(r->decrypt_state, buffer, r->clear_buffer, (int)n);
+    buffer = r->clear_buffer;
+    buffer_start = r->clear_buffer;
+  }
+
   if (x >= 0) {
     count += LOTS_OF_BITS;
     loop_end = x;
@@ -55,7 +70,10 @@
     }
   }
 
-  r->buffer = buffer;
+  // NOTE: Variable 'buffer' may not relate to 'r->buffer' after decryption,
+  // so we increase 'r->buffer' by the amount that 'buffer' moved, rather than
+  // assign 'buffer' to 'r->buffer'.
+  r->buffer += buffer - buffer_start;
   r->value = value;
   r->count = count;
 }
--- a/vp9/decoder/vp9_reader.h
+++ b/vp9/decoder/vp9_reader.h
@@ -16,6 +16,7 @@
 
 #include "./vpx_config.h"
 #include "vpx_ports/mem.h"
+#include "vpx/vp8dx.h"
 #include "vpx/vpx_integer.h"
 
 #include "vp9/common/vp9_prob.h"
@@ -31,12 +32,19 @@
 typedef struct {
   const uint8_t *buffer_end;
   const uint8_t *buffer;
+  uint8_t clear_buffer[sizeof(BD_VALUE) + 1];
   BD_VALUE value;
   int count;
   unsigned int range;
+  vpx_decrypt_cb decrypt_cb;
+  void *decrypt_state;
 } vp9_reader;
 
-int vp9_reader_init(vp9_reader *r, const uint8_t *buffer, size_t size);
+int vp9_reader_init(vp9_reader *r,
+                    const uint8_t *buffer,
+                    size_t size,
+                    vpx_decrypt_cb decrypt_cb,
+                    void *decrypt_state);
 
 void vp9_reader_fill(vp9_reader *r);
 
--- a/vp9/vp9_dx_iface.c
+++ b/vp9/vp9_dx_iface.c
@@ -43,6 +43,8 @@
   int                     dbg_color_b_modes_flag;
   int                     dbg_display_mv_flag;
 #endif
+  vpx_decrypt_cb          decrypt_cb;
+  void                   *decrypt_state;
   vpx_image_t             img;
   int                     img_setup;
   int                     img_avail;
@@ -94,9 +96,13 @@
   return VPX_CODEC_OK;
 }
 
-static vpx_codec_err_t decoder_peek_si(const uint8_t *data,
-                                       unsigned int data_sz,
-                                       vpx_codec_stream_info_t *si) {
+static vpx_codec_err_t decoder_peek_si_internal(const uint8_t *data,
+                                                unsigned int data_sz,
+                                                vpx_codec_stream_info_t *si,
+                                                vpx_decrypt_cb decrypt_cb,
+                                                void *decrypt_state) {
+  uint8_t clear_buffer[9];
+
   if (data_sz <= 8)
     return VPX_CODEC_UNSUP_BITSTREAM;
 
@@ -106,6 +112,12 @@
   si->is_kf = 0;
   si->w = si->h = 0;
 
+  if (decrypt_cb) {
+    data_sz = MIN(sizeof(clear_buffer), data_sz);
+    decrypt_cb(decrypt_state, data, clear_buffer, data_sz);
+    data = clear_buffer;
+  }
+
   {
     struct vp9_read_bit_buffer rb = { data, data + data_sz, 0, NULL, NULL };
     const int frame_marker = vp9_rb_read_literal(&rb, 2);
@@ -159,6 +171,12 @@
   return VPX_CODEC_OK;
 }
 
+static vpx_codec_err_t decoder_peek_si(const uint8_t *data,
+                                       unsigned int data_sz,
+                                       vpx_codec_stream_info_t *si) {
+  return decoder_peek_si_internal(data, data_sz, si, NULL, NULL);
+}
+
 static vpx_codec_err_t decoder_get_si(vpx_codec_alg_priv_t *ctx,
                                       vpx_codec_stream_info_t *si) {
   const size_t sz = (si->sz >= sizeof(vp9_stream_info_t))
@@ -264,7 +282,8 @@
   // of the heap.
   if (!ctx->si.h) {
     const vpx_codec_err_t res =
-        ctx->base.iface->dec.peek_si(*data, data_sz, &ctx->si);
+        decoder_peek_si_internal(*data, data_sz, &ctx->si, ctx->decrypt_cb,
+                                 ctx->decrypt_state);
     if (res != VPX_CODEC_OK)
       return res;
   }
@@ -278,6 +297,11 @@
     ctx->decoder_init = 1;
   }
 
+  // Set these even if already initialized.  The caller may have changed the
+  // decrypt config between frames.
+  ctx->pbi->decrypt_cb = ctx->decrypt_cb;
+  ctx->pbi->decrypt_state = ctx->decrypt_state;
+
   cm = &ctx->pbi->common;
 
   if (vp9_receive_compressed_data(ctx->pbi, data_sz, data, deadline))
@@ -296,12 +320,25 @@
   return VPX_CODEC_OK;
 }
 
+static INLINE uint8_t read_marker(vpx_decrypt_cb decrypt_cb,
+                                  void *decrypt_state,
+                                  const uint8_t *data) {
+  if (decrypt_cb) {
+    uint8_t marker;
+    decrypt_cb(decrypt_state, data, &marker, 1);
+    return marker;
+  }
+  return *data;
+}
+
 static void parse_superframe_index(const uint8_t *data, size_t data_sz,
-                                   uint32_t sizes[8], int *count) {
+                                   uint32_t sizes[8], int *count,
+                                   vpx_decrypt_cb decrypt_cb,
+                                   void *decrypt_state) {
   uint8_t marker;
 
   assert(data_sz);
-  marker = data[data_sz - 1];
+  marker = read_marker(decrypt_cb, decrypt_state, data + data_sz - 1);
   *count = 0;
 
   if ((marker & 0xe0) == 0xc0) {
@@ -309,11 +346,22 @@
     const uint32_t mag = ((marker >> 3) & 0x3) + 1;
     const size_t index_sz = 2 + mag * frames;
 
-    if (data_sz >= index_sz && data[data_sz - index_sz] == marker) {
+    // Read the index's leading marker only if it lies inside the buffer.
+    const uint8_t marker2 = data_sz >= index_sz ?
+        read_marker(decrypt_cb, decrypt_state, data + data_sz - index_sz) : 0;
+    if (data_sz >= index_sz && marker2 == marker) {
       // found a valid superframe index
       uint32_t i, j;
       const uint8_t *x = &data[data_sz - index_sz + 1];
 
+      // frames has a maximum of 8 and mag has a maximum of 4.
+      uint8_t clear_buffer[32];
+      assert(sizeof(clear_buffer) >= frames * mag);
+      if (decrypt_cb) {
+        decrypt_cb(decrypt_state, x, clear_buffer, frames * mag);
+        x = clear_buffer;
+      }
+
       for (i = 0; i < frames; i++) {
         uint32_t this_sz = 0;
 
@@ -339,23 +387,31 @@
   if (data == NULL || data_sz == 0)
     return VPX_CODEC_INVALID_PARAM;
 
-  parse_superframe_index(data, data_sz, sizes, &frames_this_pts);
+  parse_superframe_index(data, data_sz, sizes, &frames_this_pts,
+                         ctx->decrypt_cb, ctx->decrypt_state);
 
   do {
-    // Skip over the superframe index, if present
-    if (data_sz && (*data_start & 0xe0) == 0xc0) {
-      const uint8_t marker = *data_start;
-      const uint32_t frames = (marker & 0x7) + 1;
-      const uint32_t mag = ((marker >> 3) & 0x3) + 1;
-      const uint32_t index_sz = 2 + mag * frames;
+    if (data_sz) {
+      uint8_t marker = read_marker(ctx->decrypt_cb, ctx->decrypt_state,
+                                   data_start);
+      // Skip over the superframe index, if present
+      if ((marker & 0xe0) == 0xc0) {
+        const uint32_t frames = (marker & 0x7) + 1;
+        const uint32_t mag = ((marker >> 3) & 0x3) + 1;
+        const uint32_t index_sz = 2 + mag * frames;
 
-      if (data_sz >= index_sz && data_start[index_sz - 1] == marker) {
-        data_start += index_sz;
-        data_sz -= index_sz;
-        if (data_start < data_end)
-          continue;
-        else
-          break;
+        if (data_sz >= index_sz) {
+          uint8_t marker2 = read_marker(ctx->decrypt_cb, ctx->decrypt_state,
+                                        data_start + index_sz - 1);
+          if (marker2 == marker) {
+            data_start += index_sz;
+            data_sz -= index_sz;
+            if (data_start < data_end)
+              continue;
+            else
+              break;
+          }
+        }
       }
     }
 
@@ -381,8 +437,13 @@
       break;
 
     // Account for suboptimal termination by the encoder.
-    while (data_start < data_end && *data_start == 0)
+    while (data_start < data_end) {
+      uint8_t marker3 = read_marker(ctx->decrypt_cb, ctx->decrypt_state,
+                                    data_start);
+      if (marker3)
+        break;
       data_start++;
+    }
 
     data_sz = (unsigned int)(data_end - data_start);
   } while (data_start < data_end);
@@ -565,6 +626,15 @@
   return VPX_CODEC_OK;
 }
 
+static vpx_codec_err_t ctrl_set_decryptor(vpx_codec_alg_priv_t *ctx,
+                                          int ctrl_id,
+                                          va_list args) {
+  vpx_decrypt_init *init = va_arg(args, vpx_decrypt_init *);
+  ctx->decrypt_cb = init ? init->decrypt_cb : NULL;
+  ctx->decrypt_state = init ? init->decrypt_state : NULL;
+  return VPX_CODEC_OK;
+}
+
 static vpx_codec_ctrl_fn_map_t decoder_ctrl_maps[] = {
   {VP8_COPY_REFERENCE,            ctrl_copy_reference},
 
@@ -576,6 +646,7 @@
   {VP8_SET_DBG_COLOR_B_MODES,     ctrl_set_dbg_options},
   {VP8_SET_DBG_DISPLAY_MV,        ctrl_set_dbg_options},
   {VP9_INVERT_TILE_DECODE_ORDER,  ctrl_set_invert_tile_order},
+  {VPXD_SET_DECRYPTOR,            ctrl_set_decryptor},
 
   // Getters
   {VP8D_GET_LAST_REF_UPDATES,     ctrl_get_last_ref_updates},
--- a/vpx/vp8dx.h
+++ b/vpx/vp8dx.h
@@ -66,10 +66,11 @@
   VP8D_GET_LAST_REF_USED,
 
   /** decryption function to decrypt encoded buffer data immediately
-   * before decoding. Takes a vp8_decrypt_init, which contains
+   * before decoding. Takes a vpx_decrypt_init, which contains
    * a callback function and opaque context pointer.
    */
-  VP8D_SET_DECRYPTOR,
+  VPXD_SET_DECRYPTOR,
+  VP8D_SET_DECRYPTOR = VPXD_SET_DECRYPTOR,
 
   /** control function to get the display dimensions for the current frame. */
   VP9D_GET_DISPLAY_SIZE,
@@ -80,20 +81,29 @@
   VP8_DECODER_CTRL_ID_MAX
 };
 
+/** Decrypt n bytes of data from input -> output, using the decrypt_state
+ *  passed in VPXD_SET_DECRYPTOR.
+ */
+typedef void (*vpx_decrypt_cb)(void *decrypt_state, const unsigned char *input,
+                               unsigned char *output, int count);
+
 /*!\brief Structure to hold decryption state
  *
  * Defines a structure to hold the decryption state and access function.
  */
-typedef struct vp8_decrypt_init {
-    /** Decrypt n bytes of data from input -> output, using the decrypt_state
-     *  passed in VP8D_SET_DECRYPTOR.
-     */
-    void (*decrypt_cb)(void *decrypt_state, const unsigned char *input,
-                       unsigned char *output, int count);
+typedef struct vpx_decrypt_init {
+    /*! Decrypt callback. */
+    vpx_decrypt_cb decrypt_cb;
+
     /*! Decryption state. */
     void *decrypt_state;
-} vp8_decrypt_init;
+} vpx_decrypt_init;
 
+/*!\brief A deprecated alias for vpx_decrypt_init.
+ */
+typedef vpx_decrypt_init vp8_decrypt_init;
+
+
 /*!\brief VP8 decoder control function parameter type
  *
  * Defines the data types that VP8D control functions take. Note that
@@ -102,11 +112,12 @@
  */
 
 
-VPX_CTRL_USE_TYPE(VP8D_GET_LAST_REF_UPDATES,   int *)
-VPX_CTRL_USE_TYPE(VP8D_GET_FRAME_CORRUPTED,    int *)
-VPX_CTRL_USE_TYPE(VP8D_GET_LAST_REF_USED,      int *)
-VPX_CTRL_USE_TYPE(VP8D_SET_DECRYPTOR,          vp8_decrypt_init *)
-VPX_CTRL_USE_TYPE(VP9D_GET_DISPLAY_SIZE,       int *)
+VPX_CTRL_USE_TYPE(VP8D_GET_LAST_REF_UPDATES,    int *)
+VPX_CTRL_USE_TYPE(VP8D_GET_FRAME_CORRUPTED,     int *)
+VPX_CTRL_USE_TYPE(VP8D_GET_LAST_REF_USED,       int *)
+VPX_CTRL_USE_TYPE(VPXD_SET_DECRYPTOR,           vpx_decrypt_init *)
+VPX_CTRL_USE_TYPE(VP8D_SET_DECRYPTOR,           vpx_decrypt_init *)
+VPX_CTRL_USE_TYPE(VP9D_GET_DISPLAY_SIZE,        int *)
 VPX_CTRL_USE_TYPE(VP9_INVERT_TILE_DECODE_ORDER, int)
 
 /*! @} - end defgroup vp8_decoder */