Understanding Attention Mechanism

Shashank Yadav
Feb 5, 2019


The attention mechanism for sequence modelling was first introduced in the paper Neural Machine Translation by Jointly Learning to Align and Translate (Bahdanau, Cho and Bengio, ICLR 2015). Even though the paper itself scarcely mentions the word “attention” (3 times in total, in 2 consecutive lines!!), the term has caught on. A lot of prominent later work uses the same naming convention (well, I for one think it’s more of a “soft memory” than “attention”).

This post focuses on Bahdanau et al. (2015) and tries to give a step-by-step explanation of the (attention) model described in their paper. Probably it’s just me, but the explanation given in the paper and the diagrams that came with it left a lot to the imagination. This post tries to make their great work a little easier to understand. So here we go:

Core Idea behind Attention:

The main assumption in sequence modelling networks such as RNNs, LSTMs and GRUs is that the current state holds information about the whole of the input seen so far. Hence the final state of an RNN after reading the whole input sequence should contain complete information about that sequence. This seems to be too strong a condition and too much to ask.

Fig 1: Difference between the attention and the vanilla encoder-decoder architectures. 1(a) shows a vanilla architecture in which the decoder looks only at the final state of the encoder before it starts making predictions. In 1(b) the decoder “attends” to every hidden state of the encoder at each time step while making a prediction.

The attention mechanism relaxes this assumption and proposes that we look at the hidden states corresponding to the whole input sequence in order to make any prediction. But how do we decide which states to look at? Well, try learning that!!
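To make the contrast concrete, here is a toy sketch in plain Python. It is an illustration of the idea, not code from the paper: the encoder states and the per-step scores are invented numbers, whereas in the real model the scores are learned. The vanilla decoder gets one fixed summary (the last state), while the attention decoder gets a different weighted mix of all states at every step.

```python
import math

def softmax(scores):
    """Turn raw scores into positive weights that sum to 1."""
    m = max(scores)  # subtracting the max keeps exp() from overflowing
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def weighted_sum(weights, states):
    """Add up the state vectors, each scaled by its weight."""
    return [sum(w * h[k] for w, h in zip(weights, states)) for k in range(len(states[0]))]

# Toy encoder hidden states h_1..h_4, one 2-d vector per input token.
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]

# Vanilla encoder-decoder: every decoder step sees the same final state.
vanilla_summary = encoder_states[-1]

# Attention: each decoder step builds its own mix of *all* encoder states.
# These scores are invented for illustration; the real model learns them.
scores_step_1 = [2.0, 0.1, 0.1, 0.1]  # mostly "looks at" token 1
scores_step_2 = [0.1, 0.1, 2.0, 0.1]  # mostly "looks at" token 3
context_1 = weighted_sum(softmax(scores_step_1), encoder_states)
context_2 = weighted_sum(softmax(scores_step_2), encoder_states)
print(vanilla_summary, context_1, context_2)
```

Note that the last token's state [0.5, -0.5] still contributes to both attention contexts, but only as one ingredient among several.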

Step by Step Walk-through

As introduced in the previous section, the task at hand is:

To “learn” how much we need to “attend” to each hidden state of the encoder.

The complete architecture is as follows:

Fig 2: Attention Mechanism

The network is shown at the point where the encoder (lower part of Fig 2) has computed the hidden states hⱼ corresponding to each input Xⱼ, and the decoder (top part of Fig 2) has run for t-1 steps and is about to produce the output for time step t.

Don’t get too nervous looking at this seemingly difficult figure. We’ll be going through each component one by one. Broadly, the whole process can be divided into four steps:

  1. Encoding
  2. Computing Attention weights / Alignment
  3. Creating context vector
  4. Decoding / Translation

Let’s have a look!!

Encoding:

Fig 3: Looking at Bidirectional Encoder

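The encoder is a bidirectional RNN: a forward RNN reads the input sequence from left to right, and a backward RNN reads it from right to left. Hence the hidden state hⱼ for the jᵗʰ input is the concatenation of the jᵗʰ hidden states of the forward and backward RNNs:

$$h_j = \left[\overrightarrow{h}_j \, ; \, \overleftarrow{h}_j\right]$$

We’ll be using a weighted linear combination of all of these hⱼ to make predictions at each step of the decoder. The decoder output length may be the same as or different from that of the encoder.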
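Here is a minimal sketch of the bidirectional encoder in plain Python. To keep it short it assumes a vanilla tanh RNN rather than the gated units used in the paper, with random untrained weights; names such as `run_rnn` and `bidirectional_encode` are hypothetical. The one detail worth copying is the re-reversal of the backward outputs, so that position j of the backward pass lines up with input Xⱼ before concatenation.

```python
import math
import random

def rnn_step(W, U, b, h_prev, x):
    """One vanilla tanh RNN step: h = tanh(W h_prev + U x + b)."""
    return [
        math.tanh(
            sum(w * h for w, h in zip(W[i], h_prev))
            + sum(u * xi for u, xi in zip(U[i], x))
            + b[i]
        )
        for i in range(len(b))
    ]

def run_rnn(params, inputs):
    """Run the RNN over `inputs` in the order given; return every hidden state."""
    W, U, b = params
    h = [0.0] * len(b)
    states = []
    for x in inputs:
        h = rnn_step(W, U, b, h, x)
        states.append(h)
    return states

def random_params(hidden, input_dim, rng):
    """Random, untrained weights (placeholders for learned parameters)."""
    W = [[rng.uniform(-0.5, 0.5) for _ in range(hidden)] for _ in range(hidden)]
    U = [[rng.uniform(-0.5, 0.5) for _ in range(input_dim)] for _ in range(hidden)]
    b = [0.0] * hidden
    return W, U, b

def bidirectional_encode(fwd_params, bwd_params, inputs):
    """h_j = [forward state j ; backward state j] for every input position j."""
    forward = run_rnn(fwd_params, inputs)
    # The backward RNN reads right to left; reverse its outputs so that
    # backward[j] lines up with input j.
    backward = run_rnn(bwd_params, inputs[::-1])[::-1]
    return [fw + bw for fw, bw in zip(forward, backward)]

rng = random.Random(0)
xs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 0.0]]  # toy X_1..X_4
hs = bidirectional_encode(random_params(4, 3, rng), random_params(4, 3, rng), xs)
print(len(hs), len(hs[0]))  # 4 annotations, each 4 + 4 = 8 numbers
```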

Computing Attention Weights / Alignment:

Fig 4: Looking at how attention weights are computed

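At each time step t of the decoder, the amount of attention to be paid to the encoder hidden state hⱼ is denoted by αₜⱼ. It is calculated as a function of both hⱼ and the previous hidden state of the decoder, sₜ₋₁:

$$e_{tj} = a(s_{t-1}, h_j)$$

$$\alpha_{tj} = \frac{\exp(e_{tj})}{\sum_{k} \exp(e_{tk})}$$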

In the paper, a is parametrized as a feedforward neural network (the alignment model) that is evaluated for every j at decoding time step t. Note that 0 ≤ αₜⱼ ≤ 1 and ∑ⱼ αₜⱼ = 1 because of the softmax over eₜⱼ. These αₜⱼ can be visualized as the attention paid by the decoder at time step t to the encoder hidden state hⱼ.
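The alignment step can be sketched in plain Python as follows. The scoring function uses the additive form given in the paper's appendix, eₜⱼ = vₐᵀ tanh(Wₐ sₜ₋₁ + Uₐ hⱼ), followed by a softmax over j. The weights are random stand-ins for trained parameters, and the dimensions and helper names (`alignment_scores`, `attention_weights`) are illustrative assumptions.

```python
import math
import random

def matvec(M, v):
    """Matrix-vector product for lists of lists."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def softmax(scores):
    m = max(scores)  # shift by the max so exp() cannot overflow
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def alignment_scores(s_prev, hs, W_a, U_a, v_a):
    """e_tj = v_a^T tanh(W_a s_{t-1} + U_a h_j), one score per encoder state h_j."""
    decoder_term = matvec(W_a, s_prev)  # identical for every j, so compute it once
    scores = []
    for h in hs:
        hidden = [math.tanh(d + e) for d, e in zip(decoder_term, matvec(U_a, h))]
        scores.append(sum(v * z for v, z in zip(v_a, hidden)))
    return scores

def attention_weights(s_prev, hs, W_a, U_a, v_a):
    """alpha_tj = softmax over j of e_tj."""
    return softmax(alignment_scores(s_prev, hs, W_a, U_a, v_a))

rng = random.Random(0)

def rand_matrix(rows, cols):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

dec_dim, enc_dim, attn_dim = 3, 4, 5
W_a = rand_matrix(attn_dim, dec_dim)
U_a = rand_matrix(attn_dim, enc_dim)
v_a = rand_matrix(1, attn_dim)[0]
s_prev = [0.1, -0.2, 0.3]         # previous decoder state s_{t-1} (toy values)
hs = rand_matrix(6, enc_dim)      # encoder states h_1..h_6 (toy values)
alphas = attention_weights(s_prev, hs, W_a, U_a, v_a)
print([round(a, 3) for a in alphas])
```

Whatever the weights, the output is a list of six numbers between 0 and 1 that sum to 1, which is exactly the property noted above.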

Computing context vector:

Time to make use of the attention weights we’ve computed in the preceding step!! Here’s another figure to help understand:

Fig 5: Looking at how to calculate context vector

The context vector is simply a linear combination of the hidden states hⱼ, weighted by the attention values αₜⱼ that we computed:

$$c_t = \sum_{j} \alpha_{tj} h_j$$

From the equation we can see that αₜⱼ determines how much hⱼ affects the context cₜ: the higher the value, the greater the impact of hⱼ on the context for time t.
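A small worked example shows the weighting at work. The numbers are toy values chosen for illustration: suppose the decoder at step t attends to three 2-dimensional encoder states with weights αₜ = (0.7, 0.2, 0.1).

```latex
h_1 = (1, 0), \quad h_2 = (0, 1), \quad h_3 = (1, 1)

c_t = 0.7\,(1, 0) + 0.2\,(0, 1) + 0.1\,(1, 1)
    = (0.7 + 0.1,\; 0.2 + 0.1)
    = (0.8,\; 0.3)
```

Because h₁ carries the largest weight, cₜ ends up closest to h₁.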

Decoding / Translation:

Fig 6: Looking at how decoding happens

We are nearly there! All that remains is to take the context vector cₜ we worked so hard to compute, together with the previous hidden state of the decoder sₜ₋₁ and the previous output yₜ₋₁, and use them to compute the new hidden state and output of the decoder, sₜ and yₜ respectively:

$$s_t = f(s_{t-1}, y_{t-1}, c_t)$$

$$y_t = g(y_{t-1}, s_t, c_t)$$

In the paper, the authors use a gated recurrent unit (GRU) for f, and for g a feedforward output layer with a single maxout hidden layer followed by a softmax. These are finer details; if you’re interested, I’d suggest having a look at Appendix A of the paper. For details about training, see Appendix B of the paper.
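One decoding step can be sketched in plain Python as below. This is a sketch under stated assumptions: f is a GRU cell in the formulation used in the paper (sₜ = (1 − z) ⊙ sₜ₋₁ + z ⊙ s̃ₜ, biases omitted), while g is simplified to a single linear layer plus softmax instead of the paper's maxout layer. All weights are random placeholders, yₜ₋₁ is a toy embedding vector, the greedy pick at the end is only for illustration, and names such as `gru_f` and `output_g` are hypothetical.

```python
import math
import random

rng = random.Random(0)

def rand_matrix(rows, cols):
    """Random, untrained weights (placeholders for learned parameters)."""
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def add(*vectors):
    """Element-wise sum of equally long vectors."""
    return [sum(parts) for parts in zip(*vectors)]

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def gru_f(p, s_prev, y_prev, c):
    """s_t = f(s_{t-1}, y_{t-1}, c_t) as a GRU cell (biases omitted)."""
    z = [sigmoid(a) for a in add(matvec(p["Wz"], y_prev), matvec(p["Uz"], s_prev), matvec(p["Cz"], c))]
    r = [sigmoid(a) for a in add(matvec(p["Wr"], y_prev), matvec(p["Ur"], s_prev), matvec(p["Cr"], c))]
    reset_s = [ri * si for ri, si in zip(r, s_prev)]  # reset gate scales the old state
    s_tilde = [math.tanh(a) for a in add(matvec(p["W"], y_prev), matvec(p["U"], reset_s), matvec(p["C"], c))]
    # Update gate z interpolates between the old state and the candidate state.
    return [(1 - zi) * si + zi * ti for zi, si, ti in zip(z, s_prev, s_tilde)]

def output_g(p, y_prev, s, c):
    """Distribution over the target vocabulary for y_t. A linear layer plus softmax
    stands in for the paper's maxout output layer."""
    logits = add(matvec(p["Oy"], y_prev), matvec(p["Os"], s), matvec(p["Oc"], c))
    return softmax(logits)

vocab, emb, state, ctx = 5, 3, 4, 6
shapes = {
    "Wz": (state, emb), "Uz": (state, state), "Cz": (state, ctx),
    "Wr": (state, emb), "Ur": (state, state), "Cr": (state, ctx),
    "W": (state, emb), "U": (state, state), "C": (state, ctx),
    "Oy": (vocab, emb), "Os": (vocab, state), "Oc": (vocab, ctx),
}
params = {name: rand_matrix(*shape) for name, shape in shapes.items()}

s_prev = [0.0] * state      # s_{t-1}
y_prev = [0.2, -0.1, 0.4]   # embedding of the previous output y_{t-1} (toy values)
c_t = [0.3] * ctx           # context vector c_t (toy values)
s_t = gru_f(params, s_prev, y_prev, c_t)
p_y = output_g(params, y_prev, s_t, c_t)
y_t = max(range(vocab), key=lambda k: p_y[k])  # greedy pick of the next token
print(s_t, y_t)
```

In a full decoder this step runs in a loop: the new sₜ becomes sₜ₋₁ for the next step's attention weights, and the chosen token's embedding becomes the next yₜ₋₁.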

So yeah, that’s it. Attention mastered! (Probably, at some level of understanding, we did.) Thanks a lot for sticking with it this far.

Conclusion

This was the first paper to introduce the concept of attention, and many works since then have built on top of it. The idea has been most influential in NLP, where state-of-the-art Transformer networks built on self-attention have taken the field by storm. This tutorial is meant as a first step before taking a deep dive into the field. I hope it was helpful. Looking forward to your comments and suggestions.

Note: All the diagrams used here have been created by the author. Feel free to use them along with a note of acknowledgement :-)
