In this post, I implement Multi-Head Attention in 3 stages, following Sebastian Raschka's excellent instruction in Chapter 16 of his text on Machine Learning and Deep Learning.
The 3 stages are:
- Basic Form of Attention - really easy to understand, but not used in practice. A good starting point for seeing that attention is just a weighted average of the input vectors, where the weights are determined by the similarity of the input vectors. In linear algebra terms, this is a linear transformation of the input vectors, where the transformation matrix is the row-normalized (softmaxed) similarity matrix.
- Parameterized Attention - this introduces trainable parameters in the form of query, key, and value projection matrices. This is the form of self-attention used in the Transformer architecture.
- Multi-Headed Attention - this introduces "heads", which are parallel attention mechanisms, analogous to channels in a convolutional neural network.
In all cases, this is also a good exercise in matrix multiplication and broadcasting.
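As a warm-up for the multi-head stage, here is a minimal sketch of the broadcasting behaviour of `torch.matmul` that everything below relies on; the shapes are made up purely for illustration:

```python
import torch

# hypothetical shapes, just to illustrate broadcasting in torch.matmul
A = torch.rand(8, 10, 14)   # e.g. one (10, 14) matrix per head
B = torch.rand(14, 12)      # a single matrix shared by all heads

# B is broadcast across the leading head dimension: result is (8, 10, 12)
print(A.matmul(B).shape)    # torch.Size([8, 10, 12])
```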
```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
```
Create the matrix of word embeddings to be used as input to the attention mechanism. This has dimensions (sequence_length, embedding_dimension).
```python
d_sentence = 12   # number of words in the sentence
d_embedding = 14  # embedding length
sentence = torch.randperm(d_sentence)
embeddings = torch.nn.Embedding(d_sentence, d_embedding)
embedding_sentence = embeddings(sentence).detach()
embedding_sentence.shape
```
```python
# omega_ij is the dot-product similarity between word i and word j
# (the unnormalized attention score)
omega = embedding_sentence.matmul(embedding_sentence.T)
```
```python
# attention weights - softmax normalization of the similarity scores
attention_weights = F.softmax(omega, dim=1)
attention_weights.shape
```
```python
# context vector - weighted sum of the embedding vectors
context_vector = attention_weights.matmul(embedding_sentence)
context_vector.shape
```
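To make the "weighted average" interpretation concrete, here is a small sanity check (reusing the tensors defined above) that each row of the attention weights sums to 1 and that a context vector really is the corresponding weighted sum of the embeddings:

```python
# each row of the attention weights is a probability distribution over words
print(torch.allclose(attention_weights.sum(dim=1), torch.ones(d_sentence)))  # True

# context vector for word 0, computed explicitly as a weighted sum
manual_context_0 = sum(attention_weights[0, j] * embedding_sentence[j]
                       for j in range(d_sentence))
print(torch.allclose(manual_context_0, context_vector[0], atol=1e-5))  # True
```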
```python
# initialize projection matrices
d_query = d_key = d_value = 10  # dimension of the query, key, and value projections
U_query = torch.rand(d_query, d_embedding)
U_key = torch.rand(d_key, d_embedding)
U_value = torch.rand(d_value, d_embedding)
```
```python
query = U_query.matmul(embedding_sentence.T).T
key = U_key.matmul(embedding_sentence.T).T
values = U_value.matmul(embedding_sentence.T).T
query.shape, key.shape, values.shape  # d_sentence x d_query
```
(torch.Size([12, 10]), torch.Size([12, 10]), torch.Size([12, 10]))
```python
# similarity between projected query and key vectors
omega = query.matmul(key.T)  # d_sentence x d_sentence
```
```python
# normalization, scaled by sqrt(d_query) to keep the softmax inputs
# from growing with the projection dimension
attention_weights = F.softmax(omega / d_query**0.5, dim=1)
```
```python
# context vector - weighted sum of the projected value vectors
context_vector = attention_weights.matmul(values)
context_vector.shape  # d_sentence x d_value
```
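As a sanity check, this should agree with PyTorch's built-in scaled dot-product attention (available since PyTorch 2.0), which applies the same softmax(QKᵀ/√d)V computation; a minimal sketch reusing the `query`, `key`, and `values` tensors from above:

```python
# compare against PyTorch's built-in implementation (PyTorch >= 2.0)
builtin = F.scaled_dot_product_attention(
    query.unsqueeze(0), key.unsqueeze(0), values.unsqueeze(0)
).squeeze(0)
print(torch.allclose(builtin, context_vector, atol=1e-5))  # expected: True
```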
The primary challenge is to arrange the matrix multiplications so that the computation runs across the head dimension in parallel, while everything else stays the same as above. The other difference is the use of a linear layer to collapse the head dimension back into the embedding dimension.
```python
# initialize projection matrices, one set per head
h = 8  # number of heads
multihead_U_query = torch.rand(h, d_query, d_embedding)
multihead_U_key = torch.rand(h, d_key, d_embedding)
multihead_U_value = torch.rand(h, d_value, d_embedding)
```
```python
multihead_query = multihead_U_query.matmul(embedding_sentence.T).transpose(2, 1)
multihead_key = multihead_U_key.matmul(embedding_sentence.T).transpose(2, 1)
multihead_values = multihead_U_value.matmul(embedding_sentence.T).transpose(2, 1)
multihead_query.shape, multihead_key.shape, multihead_values.shape  # h x d_sentence x d_query
```
(torch.Size([8, 12, 10]), torch.Size([8, 12, 10]), torch.Size([8, 12, 10]))
```python
# similarity between projected vectors, computed for each of the 8 heads
omega = multihead_query.matmul(multihead_key.transpose(2, 1))
omega.shape  # h x d_sentence x d_sentence
```
torch.Size([8, 12, 12])
```python
# normalization
attention_weights = F.softmax(omega / d_query**0.5, dim=2)
attention_weights.shape  # h x d_sentence x d_sentence
```
torch.Size([8, 12, 12])
```python
# context vector - with each head kept separate
context_vector = attention_weights.matmul(multihead_values)
context_vector.shape  # h x d_sentence x d_value
```
torch.Size([8, 12, 10])
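To confirm that the batched version really runs h independent attention computations, here is a quick check (my own addition, reusing the tensors above) that recomputes head 0 with the single-head recipe and compares it to the first slice of the batched result:

```python
# recompute head 0 by itself with the single-head recipe from above
q0 = multihead_U_query[0].matmul(embedding_sentence.T).T   # d_sentence x d_query
k0 = multihead_U_key[0].matmul(embedding_sentence.T).T
v0 = multihead_U_value[0].matmul(embedding_sentence.T).T
w0 = F.softmax(q0.matmul(k0.T) / d_query**0.5, dim=1)
print(torch.allclose(w0.matmul(v0), context_vector[0], atol=1e-5))  # expected: True
```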
```python
# add a linear layer to combine the heads for each word in the sentence
linear = torch.nn.Linear(h * d_value, d_embedding)
```
```python
# move the head dimension next to the value dimension, then flatten so that
# each row holds the concatenated per-head context vectors for one word
context_vector = context_vector.transpose(0, 1).reshape(-1, h * d_value)
context_vector.shape  # d_sentence x (h*d_value)
```
```python
context_vector_linear = linear(context_vector)
context_vector_linear.shape  # d_sentence x d_embedding
```
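Putting the pieces together, here is a hedged sketch of how the same computation could be packaged as a module; the class name `MultiHeadSelfAttention` and the choice of `nn.Parameter` for the projection tensors are my own, not from the chapter:

```python
import torch
import torch.nn.functional as F


class MultiHeadSelfAttention(torch.nn.Module):
    """Sketch of the multi-head computation above, packaged as a module (my naming)."""

    def __init__(self, d_embedding, d_qkv, n_heads):
        super().__init__()
        self.d_qkv = d_qkv
        self.U_query = torch.nn.Parameter(torch.rand(n_heads, d_qkv, d_embedding))
        self.U_key = torch.nn.Parameter(torch.rand(n_heads, d_qkv, d_embedding))
        self.U_value = torch.nn.Parameter(torch.rand(n_heads, d_qkv, d_embedding))
        self.out = torch.nn.Linear(n_heads * d_qkv, d_embedding)

    def forward(self, x):  # x: (seq_len, d_embedding)
        q = self.U_query.matmul(x.T).transpose(2, 1)  # (h, seq_len, d_qkv)
        k = self.U_key.matmul(x.T).transpose(2, 1)
        v = self.U_value.matmul(x.T).transpose(2, 1)
        weights = F.softmax(q.matmul(k.transpose(2, 1)) / self.d_qkv**0.5, dim=2)
        context = weights.matmul(v)                   # (h, seq_len, d_qkv)
        context = context.transpose(0, 1).reshape(x.shape[0], -1)  # (seq_len, h*d_qkv)
        return self.out(context)                      # (seq_len, d_embedding)


# quick usage check with the tensors defined above
mha = MultiHeadSelfAttention(d_embedding, d_query, h)
print(mha(embedding_sentence).shape)  # torch.Size([12, 14])
```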