Understanding Vector Quantized Variational Autoencoders (VQ-VAE)

Shashank Yadav
Sep 1, 2019

From my most recent escapade into the deep learning literature, I present to you this paper by van den Oord et al., which introduces the idea of using discrete latent embeddings for variational autoencoders. The proposed model is called the Vector Quantized Variational Autoencoder (VQ-VAE). I really liked the idea and the results that came with it, but found surprisingly few resources for developing an understanding of it. Here’s an attempt to help others who might venture into this domain after me.

Like numerous other people, I count Variational Autoencoders (VAEs) as my generative model of choice. Compared to GANs, they are easier to train and reason about (no offence intended, dear GANs). Going forward, I assume you have some understanding of VAEs. If you don’t, I suggest going through this post; I found it to be one of the simpler ones.

Basic Idea

So what is the big deal here? As you might recall, VAEs consist of 3 parts:

  1. An encoder network that parametrizes the posterior q(z|x) over latents
  2. A prior distribution p(z)
  3. A decoder with distribution p(x|z) over input data

Typically we assume the prior and the posterior to be normally distributed with diagonal covariance. The encoder then predicts the mean and variance of the posterior for each input.
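To make the continuous case concrete, here is a minimal NumPy sketch of sampling from a diagonal Gaussian posterior and computing its KL divergence to a standard normal prior. The function names, the shapes, and the choice of a standard normal prior are assumptions for illustration; they are not taken from the paper.

```python
import numpy as np

def sample_diag_gaussian(mu, log_var, rng):
    """Draw z ~ N(mu, diag(exp(log_var))) using the reparameterization trick."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

def kl_to_standard_normal(mu, log_var):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ), summed over the latent dimensions."""
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var, axis=-1)

rng = np.random.default_rng(0)
mu = np.zeros((4, 8))       # hypothetical encoder output: batch of 4, 8 latent dims
log_var = np.zeros((4, 8))  # log-variance 0 means unit variance
z = sample_diag_gaussian(mu, log_var, rng)
kl = kl_to_standard_normal(mu, log_var)  # zero here, since posterior equals prior
```

Predicting the log-variance instead of the variance is a common convention that keeps the variance positive without any extra constraint.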

In the proposed work, however, the authors use discrete latent variables (instead of a continuous normal distribution). The posterior and prior distributions are categorical, and the samples drawn from these distributions index an embedding table. In other words:

  1. The encoder models a categorical distribution; sampling from it yields integer values
  2. These integer values are used to index a dictionary of embeddings
  3. The indexed values are then passed on to the decoder
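The lookup in steps 2 and 3 is just array indexing. Here is a small sketch, assuming a toy dictionary of 5 embeddings with 3 dimensions each (both numbers are made up for illustration):

```python
import numpy as np

k, d = 5, 3  # hypothetical sizes: 5 embeddings, each a 3-dimensional vector
codebook = np.arange(k * d, dtype=np.float64).reshape(k, d)  # row i is embedding i

indices = np.array([[0, 4], [2, 2]])  # a 2x2 grid of discrete latents
z_q = codebook[indices]               # shape (2, 2, 3): each index replaced by its row
```

Every integer in the grid is swapped for the corresponding row of the dictionary, and the resulting array is what the decoder receives.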

Why do it?

Many important real-world objects are discrete. For example, in images we might have categories like “Cat”, “Car”, etc., and it might not make sense to interpolate between these categories. Discrete representations are also easier to model, since each category has a single value, whereas with a continuous latent space we would need to normalize the density function and learn the dependencies between the different variables, which could be very complex.

Moreover, the authors claim that their model doesn’t suffer from posterior collapse, where a powerful decoder learns to ignore the latents. This issue plagues VAEs in general and makes it hard to use complex decoders.

Architecture

Fig 1: VQ-VAE Architecture

Fig 1 shows the top-level components of the architecture along with the dimensions at each step. Assuming we run our model over image data, here’s the nomenclature we’ll be using going forward:

n: batch size

h: image height

w: image width

c: number of channels in the input image

d: number of channels in the hidden state

k: number of vectors in the embedding dictionary

Now the working can be explained in the following steps:

  1. The encoder takes in images x: (n, h, w, c) and outputs z_e: (n, h, w, d)
  2. The Vector Quantization layer takes z_e, selects embeddings from a dictionary based on distance, and outputs z_q (don’t worry, we’ll discuss this in more detail below)
  3. The decoder consumes z_q and outputs x’, trying to recreate the input x

Vector Quantization Layer

Fig 2: Vector Quantization Layer

The working of the VQ layer can be explained in six steps, as numbered in Fig 2:

  1. Reshape: all dimensions except the last one are combined into one so that we have n*h*w vectors each of dimensionality d
  2. Calculating distances: for each of the n*h*w vectors we calculate the distance to each of the k vectors of the embedding dictionary to obtain a matrix of shape (n*h*w, k)
  3. Argmin: for each of the n*h*w vectors we find the index of the closest of the k vectors in the dictionary
  4. Index from dictionary: look up the closest vector in the dictionary for each of the n*h*w vectors
  5. Reshape: convert back to shape (n, h, w, d)
  6. Copying gradients: if you have followed along so far, you’ll realize that this architecture can’t be trained through plain backpropagation, as the gradient won’t flow through the argmin. Hence we approximate it by copying the gradients from z_q back to z_e (a straight-through estimator). The encoder then receives an approximation rather than the true gradient of the loss, but this still passes useful information back for training.
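Steps 1 to 5 can be sketched in a few lines of NumPy. This is a forward pass only, assuming squared Euclidean distance and randomly initialized inputs; the function name `vector_quantize` and the sizes used below are made up for illustration.

```python
import numpy as np

def vector_quantize(z_e, codebook):
    """Forward pass of a VQ layer. z_e: (n, h, w, d), codebook: (k, d)."""
    n, h, w, d = z_e.shape
    flat = z_e.reshape(-1, d)                 # 1. reshape to (n*h*w, d)
    # 2. squared L2 distances via ||a||^2 - 2 a.b + ||b||^2, shape (n*h*w, k)
    dists = (
        np.sum(flat**2, axis=1, keepdims=True)
        - 2.0 * flat @ codebook.T
        + np.sum(codebook**2, axis=1)
    )
    indices = np.argmin(dists, axis=1)        # 3. index of the closest embedding
    quantized = codebook[indices]             # 4. look up those embeddings
    z_q = quantized.reshape(n, h, w, d)       # 5. reshape back to (n, h, w, d)
    return z_q, indices.reshape(n, h, w)

rng = np.random.default_rng(0)
z_e = rng.standard_normal((2, 4, 4, 8))   # hypothetical encoder output
codebook = rng.standard_normal((16, 8))   # hypothetical dictionary with k = 16
z_q, idx = vector_quantize(z_e, codebook)
```

Step 6 needs an automatic differentiation framework, so it is left out of this sketch. A common way to express it there is z_q = z_e + sg[z_q − z_e], where sg is the stop-gradient operator: the forward value is still z_q, but the gradient reaching z_q is passed unchanged to z_e.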

Loss Function

The total loss is composed of three components:

  1. Reconstruction loss, which optimizes the decoder and the encoder:

reconstruction_loss = -log( p(x|z_q) )

  2. Codebook loss: since the gradients bypass the embedding, we use a dictionary learning algorithm which uses an l2 error to move the embedding vectors e towards the encoder output:

codebook_loss = ‖ sg[z_e(x)] − e ‖^2
// sg represents the stop-gradient operator, meaning no gradient
// flows through whatever it's applied to

  3. Commitment loss: since the volume of the embedding space is dimensionless, it can grow arbitrarily if the embeddings e do not train as fast as the encoder parameters. We therefore add a commitment loss to make sure that the encoder commits to an embedding:

commitment_loss = β ‖ z_e(x) − sg[e] ‖^2
// β is a hyperparameter that controls how much we want to weigh
// the commitment loss compared to the other components

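Important: note that we’re training the dictionary embeddings as well as the encoder and decoder networks. The reconstruction loss updates the encoder and the decoder, the codebook loss updates only the embeddings, and the commitment loss updates only the encoder.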
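Here is a rough NumPy sketch of the three terms. It assumes a Gaussian decoder with fixed variance, so that -log p(x|z_q) reduces to a squared error up to constants, and it uses β = 0.25 purely as an illustrative default. Since NumPy has no autograd, the effect of the stop-gradient operator is shown with hand-derived gradients per latent vector (before averaging); all function names are made up.

```python
import numpy as np

def vq_losses(x, x_recon, z_e, z_q, beta=0.25):
    """Values of the three loss terms, averaged over the batch."""
    recon = np.mean((x - x_recon) ** 2)          # squared-error stand-in for -log p(x|z_q)
    sq_dist = np.sum((z_e - z_q) ** 2, axis=-1)  # ||z_e - e||^2 per latent vector
    codebook = np.mean(sq_dist)                  # sg on z_e: only e would be updated
    commitment = beta * np.mean(sq_dist)         # sg on e: only the encoder would be updated
    return recon, codebook, commitment

def per_vector_gradients(z_e, z_q, beta=0.25):
    """Hand-derived gradients showing what each stop-gradient term updates."""
    grad_e = 2.0 * (z_q - z_e)           # d/de   of ||sg[z_e] - e||^2
    grad_z_e = 2.0 * beta * (z_e - z_q)  # d/dz_e of beta * ||z_e - sg[e]||^2
    return grad_e, grad_z_e

z_e = np.array([[1.0, 2.0]])  # hypothetical encoder output
z_q = np.array([[0.0, 0.0]])  # hypothetical selected embedding e
x = np.ones((1, 4))
x_recon = np.zeros((1, 4))
recon, codebook, commitment = vq_losses(x, x_recon, z_e, z_q)
grad_e, grad_z_e = per_vector_gradients(z_e, z_q)
```

The codebook and commitment terms have the same value up to the factor β. They differ only in which side receives the gradient: a gradient descent step on grad_e moves e towards z_e, while a step on grad_z_e moves z_e towards e.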

Results

The paper presents results on images, audio (speech) as well as videos.

VQ-VAE samples (left) and BigGAN deep samples (right) trained on ImageNet.
VQ-VAE generated facial images.

You can find the results on audio here: https://avdnoord.github.io/homepage/vqvae/

Code

There are several implementations available in TensorFlow, PyTorch and Keras. You can look through them here.

Conclusion

There are two main ideas to be learnt from this paper:

  1. How to train discrete latent embeddings and their importance
  2. How to approximate gradients in the case of non-differentiable functions

For more details, go through the paper; it should be easier to understand after reading this article. You can also play around with this Jupyter Notebook, or open it in Colab here. Happy learning!

Note: Feel free to ask questions or give feedback/suggestions. All the diagrams used here were created by the author. Feel free to use them with a note of acknowledgement :-)
