Exploring attention
Can we solve a seq2seq task by adding attention to our model? And will it beat an LSTM?
- Finding a proper case study for attention
- Basic cnn model
- Attention
- CNN + attention
- LSTM
- Final thoughts
from fastai2.vision.all import *
import altair as alt
from itertools import product
I have been looking into attention and transformers lately, and wanted to explore them with an example data set - solving a problem is the best way to learn, in my opinion. Most use cases I come across, however, seem to be huge language models, which are just too complex to be a good starting point for learning. One day this popped up in my twitter feed, and it seemed like just the right thing for further examination:
To illustrate attention mechanisms, I made a toy task seq2seq task and implemented an attention layer from scratch. It worked beautifully (thread)
— François Fleuret (@francoisfleuret) May 19, 2020
The code below creates the actual data which we will investigate in this post. The code was written by François Fleuret, who was kind enough to let me reuse it. I've modified it slightly so that it fits the notebook format. You can find the original code here.
#collapse
# Author: François Fleuret
seq_height_min, seq_height_max = 1.0, 25.0
seq_width_min, seq_width_max = 5.0, 11.0
seq_length = 100
def positions_to_sequences(tr = None, bx = None, noise_level = 0.3):
st = torch.arange(seq_length).float()
st = st[None, :, None]
tr = tr[:, None, :, :]
bx = bx[:, None, :, :]
xtr = torch.relu(tr[..., 1] - torch.relu(torch.abs(st - tr[..., 0]) - 0.5) * 2 * tr[..., 1] / tr[..., 2])
xbx = torch.sign(torch.relu(bx[..., 1] - torch.abs((st - bx[..., 0]) * 2 * bx[..., 1] / bx[..., 2]))) * bx[..., 1]
x = torch.cat((xtr, xbx), 2)
# u = x.sign()
u = F.max_pool1d(x.sign().permute(0, 2, 1), kernel_size = 2, stride = 1).permute(0, 2, 1)
collisions = (u.sum(2) > 1).max(1).values
y = x.max(2).values
return y + torch.rand_like(y) * noise_level - noise_level / 2, collisions
def generate_sequences(nb, group_by_locations=False):
# Position / height / width
tr = torch.empty(nb, 2, 3)
tr[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
tr[:, :, 1].uniform_(seq_height_min, seq_height_max)
tr[:, :, 2].uniform_(seq_width_min, seq_width_max)
bx = torch.empty(nb, 2, 3)
bx[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
bx[:, :, 1].uniform_(seq_height_min, seq_height_max)
bx[:, :, 2].uniform_(seq_width_min, seq_width_max)
if group_by_locations:
a = torch.cat((tr, bx), 1)
v = a[:, :, 0].sort(1).values[:, 2:3]
mask_left = (a[:, :, 0] < v).float()
h_left = (a[:, :, 1] * mask_left).sum(1) / 2
h_right = (a[:, :, 1] * (1 - mask_left)).sum(1) / 2
valid = (h_left - h_right).abs() > 4
else:
        valid = (torch.abs(tr[:, 0, 1] - tr[:, 1, 1]) > 4) & (torch.abs(bx[:, 0, 1] - bx[:, 1, 1]) > 4)
input, collisions = positions_to_sequences(tr, bx)
if group_by_locations:
a = torch.cat((tr, bx), 1)
v = a[:, :, 0].sort(1).values[:, 2:3]
mask_left = (a[:, :, 0] < v).float()
h_left = (a[:, :, 1] * mask_left).sum(1, keepdim = True) / 2
h_right = (a[:, :, 1] * (1 - mask_left)).sum(1, keepdim = True) / 2
a[:, :, 1] = mask_left * h_left + (1 - mask_left) * h_right
tr, bx = a.split(2, 1)
else:
tr[:, :, 1:2] = tr[:, :, 1:2].mean(1, keepdim = True)
bx[:, :, 1:2] = bx[:, :, 1:2].mean(1, keepdim = True)
targets, _ = positions_to_sequences(tr, bx)
valid = valid & ~collisions
tr = tr[valid]
bx = bx[valid]
input = input[valid][:, None, :]
targets = targets[valid][:, None, :]
if input.size(0) < nb:
input2, targets2, tr2, bx2 = generate_sequences(nb - input.size(0))
input = torch.cat((input, input2), 0)
targets = torch.cat((targets, targets2), 0)
tr = torch.cat((tr, tr2), 0)
bx = torch.cat((bx, bx2), 0)
return input, targets, tr, bx
We use the above code to create our data set: 25,000 training samples and 1,000 test samples of 1D sequences with a length of 100:
torch.manual_seed(42)
train_input, train_targets, train_tr, train_bx = generate_sequences(25000)
test_input, test_targets, test_tr, test_bx = generate_sequences(1000)
train_input.shape, train_targets.shape, test_input.shape, test_targets.shape
Our input data is simply a series of floats:
print(train_input[0,0,:5].tolist())
Let's also plot a pattern to better understand what our task is going to be:
#collapse
def plot_pattern(inp, target, input_type='input', ax=None, include_targ = True, legend=True):
if not ax: fix, ax = plt.subplots()
inp = to_np(tensor(inp).squeeze())
target = to_np(tensor(target).squeeze())
ax.plot(inp, label=input_type)
if include_targ: ax.plot(target, label='target')
if legend: ax.legend();
plot_pattern(train_input[123], train_targets[123])
The target output should be the average height of the input for the corresponding geometric shapes: the two yellow squares represent the average height of the blue squares, and similarly for the triangles.
Our first attempt at solving this task will be a regular (1d) convnet. We will make it with the fastai library. The architecture and training will be identical to the one François used in his example (see tweet above).
In fastai the dataloader holds both the input and the target data. So let's combine our input and target data into a single training dataset. This step makes the creation of the dataloader a bit easier.
train = torch.cat((train_input, train_targets), dim=-1)
train.shape
We also want to normalize our data, so we'll create a normalize helper function.
mean, std = train.mean(), train.std()
def norm(x, m, s): return (x-m)/s
normalize = partial(norm, m=mean, s=std)
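Predictions will come out in these normalized units, so if you ever want to plot results back in the original scale, a matching denorm helper would do the trick. This is a small convenience sketch of my own and isn't needed for the rest of the post:
# Small convenience helper (not used further in this post): invert the normalization
def denorm(x, m=mean, s=std): return x * s + m

torch.allclose(denorm(normalize(train)), train, atol=1e-5)  # True: denorm undoes normalize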
Finally we make our dataloader with the datablock api. Think of this as an assembly line which we pass our data source through:
dls = DataBlock(
get_x = lambda row: row[:, :100], # the first 100 cols of a row hold our training data
get_y = lambda row: row[:, 100:], # and the final 100 cols in the row holds our target data
splitter=RandomSplitter(valid_pct=0.2),
batch_tfms=normalize
).dataloaders(source=train, bs=128)
Next we will grab a batch from the dataloader, and expect to find the data normalized:
xb, yb = next(iter(dls.train))
xb.shape, yb.shape, xb.mean().item(), xb.std().item()
We get a batch of 128 items both for our input and target, with zero mean and unit variance. We can also plot an item from our batch to verify that everything is ok:
plot_pattern(xb[0], yb[0])
This is the same model that François used in his example:
#collapse
def get_model_cnn(ks, stride=1):
return sequential(
nn.Conv1d(1, 64, kernel_size=ks, stride=stride, padding=ks//2),
nn.ReLU(),
nn.Conv1d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),
nn.ReLU(),
nn.Conv1d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),
nn.ReLU(),
nn.Conv1d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),
nn.ReLU(),
nn.Conv1d(64, 1, kernel_size=ks, stride=stride, padding=ks//2),
)
basic_conv = get_model_cnn(ks=5)
learn_cnn = Learner(dls, model=basic_conv, loss_func=MSELossFlat())
Our cnn model has about 62 000 parameters:
learn_cnn.summary()
We'll train our model for 50 epochs with a flat learning rate of 1e-3. It's a small model, so training won't take very long. We could definitely train for longer without waiting all day. Fastai prints out the metrics from every epoch by default, so we'll use the no_logging() context manager to make the output a bit more compact:
with learn_cnn.no_logging():
learn_cnn.fit(50, 1e-3)
learn_cnn.recorder.plot_loss(skip_start=100)
Our model seems to be learning something, but how good are the predictions?
Let's check a few predictions from the test set. We'll add the test set to our data loader first. Note that the test_dl() method expects a numpy ndarray, so we transform our input tensors with to_np() before passing them:
test_dl = learn_cnn.dls.test_dl(to_np(test_input))
preds_cnn, _ = learn_cnn.get_preds(dl=test_dl)
preds_cnn.shape
Then we grab a few random indexes to plot. We'll be reusing these for other models later:
idxs = [84, 701, 27, 493]
The predictions aren't particularly good at this point:
fig, axs = plt.subplots(2,2, figsize=(12,9))
for ax, idx in zip(axs.flatten(), idxs):
plot_pattern(preds_cnn[idx], normalize(test_targets[idx]), input_type='prediction', ax=ax)
We didn't get great results from the above convnet. Is the task too hard, or is our model broken in some way? A standard way of checking the health of a model is to try to overfit a single batch. If the model can't do this, something funky is going on.
Let's take 20 items and see if we're able to overfit them. We'll use the standard PyTorch training loop:
model = get_model_cnn(ks=5)
optimizer = Adam(model.parameters(), lr=1e-3)
mse_loss = MSELossFlat()
inp, target = normalize(train[:20,:,:100]), normalize(train[:20,:,100:])
for epoch in range(5001):
optimizer.zero_grad()
output = model(inp)
loss = mse_loss(output, target)
if epoch%1000==0: print(f'epoch: {epoch}, loss: {loss.item()}')
loss.backward()
optimizer.step()
The loss is slowly approaching zero, so it seems that our model is working. But it needs a crazy number of epochs to overfit just 20 items, so doing well on the entire data set would be very difficult. There are many things we could do to improve the model - make it deeper or wider, or perhaps increasing the kernel size would help? But there is also this thing called attention!
Attention has become a central concept in deep learning with the rise of the transformer architecture. There are several good resources that explain the concept. I particularly like the following:
- The Rasa whiteboard video series on attention. Highly recommended!
- Jay Alammar's blog has several posts on attention and transformers
- Yannic Kilcher has lots of videos on deep learning papers, including a playlist for NLP
- Lilian Weng has a nice blog with a few posts on attention and transformers
- Peter Bloem has a nice from-scratch implementation of the transformer in PyTorch
- The annotated Transformer by Harvard NLP, and the Attention is All You Need paper
- The annotated GPT-2 by Aman Arora
The rest of the post will assume a basic understanding of the concept of attention (specifically self-attention). Below is the implementation of an attention layer and the proposed model architecture - very similar to the above cnn architecture except for a new AttentionLayer. This code is once again written by François, but I've added a few comments on intermediate tensor shapes. Let's go through it step by step.
#collapse
class AttentionLayer(nn.Module):
def __init__(self, in_channels, out_channels, key_channels):
super(AttentionLayer, self).__init__()
self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size = 1, bias = False)
def forward(self, x): # x.shape = [bs x in_channels x seq_len]
Q = self.conv_Q(x) # [bs x key_channels x seq_len]
K = self.conv_K(x) # [bs x key_channels x seq_len]
V = self.conv_V(x) # [bs x out_channels x seq_len]
A = Q.permute(0, 2, 1).matmul(K).softmax(2) # [bs x seq_len x seq_len]
x = A.matmul(V.permute(0, 2, 1)).permute(0, 2, 1) # [bs x out_channels x seq_len]
return x
def __repr__(self):
return self._get_name() + \
'(in_channels={}, out_channels={}, key_channels={})'.format(
self.conv_Q.in_channels,
self.conv_V.out_channels,
self.conv_K.out_channels
)
def attention(self, x):
Q = self.conv_Q(x)
K = self.conv_K(x)
return Q.permute(0, 2, 1).matmul(K).softmax(2)
def get_model_attention(ks, stride=1):
return sequential(
nn.Conv1d(1, 64, kernel_size=ks, stride=stride, padding=ks//2),
nn.ReLU(),
nn.Conv1d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),
nn.ReLU(),
AttentionLayer(in_channels=64, out_channels=64, key_channels=96), ## New attention layer
nn.ReLU(),
nn.Conv1d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),
nn.ReLU(),
nn.Conv1d(64, 1, kernel_size=ks, stride=stride, padding=ks//2)
)
Our class subclasses nn.Module, so first we initialize the superclass properly.
We proceed to initialize our Q, K and V matrices. Note that conv_K and conv_Q have identical dimensions: in_channels x key_channels. We'll look into why this is a requirement later in the post. conv_V, however, is different: in_channels x out_channels. Q, K and V often have the same shape in other implementations, see e.g. the fast.ai nlp course implementation, but this depends on the overall architecture that the attention layer is part of. In this particular case our attention layer is part of a specific architecture where the number of out channels needs to conform to the expected number of input channels in the next layer. Finally, if Q, K and V do have identical shapes, we can also stack them in a single matrix and thereby improve performance.
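Just to illustrate that last point (a side note of mine, not something the model below uses): when the three projections share a shape, a single 1x1 conv can produce Q, K and V in one go, and we simply chunk its output:
import torch
import torch.nn as nn

channels, seq_len, bs = 64, 100, 8
fused_qkv = nn.Conv1d(channels, 3 * channels, kernel_size=1, bias=False)  # one projection instead of three

x_demo = torch.rand(bs, channels, seq_len)
Q, K, V = fused_qkv(x_demo).chunk(3, dim=1)  # split the 3*channels output back into Q, K and V
Q.shape, K.shape, V.shape                    # each [bs, channels, seq_len]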
Note that we use Conv1d layers with kernel_size=1. This means that a conv-layer is identical to a linear layer. Let's take some dummy data to compare a Conv1d(ks=1) and a regular nn.Linear() and see if we can change the number of channels of our input from 8 to 12:
# bs x n_channels x seq_len
x = torch.rand(16,8,50)
x.shape
If we pass x through a conv layer with kernel_size=1 and no bias, we get our desired output:
nn.Conv1d(in_channels=8, out_channels=12, kernel_size=1, bias=False)(x).shape
We can achieve the same with a linear layer. But the linear layer expects input where the final dimension represents the input's channel dimension. And similar for the output. To compare with our conv layer we need to permute the data in each step:
out = nn.Linear(in_features=8, out_features=12, bias=False)(x.permute(0,2,1))
out.permute(0,2,1).shape
So using a Conv1d with ks=1 is the same as a linear layer, but seems to avoid a few annoying permutations. Huggingface's implementation of attention in GPT-2 uses a conv1D for Q, K and V, for example.
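To double-check that the two layers really compute the same values (a quick sanity check I added, reusing the x from above), we can copy the conv weights into a linear layer and compare the outputs:
conv = nn.Conv1d(in_channels=8, out_channels=12, kernel_size=1, bias=False)
lin = nn.Linear(in_features=8, out_features=12, bias=False)

# conv.weight is [out, in, 1]; dropping the kernel dimension gives the [out, in] weight of a linear layer
with torch.no_grad():
    lin.weight.copy_(conv.weight.squeeze(-1))

out_conv = conv(x)                                  # [16, 12, 50]
out_lin = lin(x.permute(0, 2, 1)).permute(0, 2, 1)  # same computation via the linear layer
torch.allclose(out_conv, out_lin, atol=1e-6)        # True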
There are many variations of attention, see Lilian Weng's post for an overview. Our implementation is a so-called dot product attention. But the Attention is All You Need paper uses a scaled dot product attention instead. This seems to have become the norm in subsequent transformers, and the argument is that scaling improves the gradient signal, especially for long input sequences. In this particular example it probably doesn't matter much anyway, since we have a fairly simple architecture and small input data.
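For reference, scaling only changes one line of the computation: divide the Q·K scores by the square root of the key dimension before the softmax. Here is a minimal sketch of the scaled variant (my own, not the layer we actually train in this post):
import math

def scaled_attention(Q, K, V):
    # Q, K: [bs, key_channels, seq_len]; V: [bs, out_channels, seq_len]
    d_k = K.shape[1]
    scores = Q.permute(0, 2, 1).matmul(K) / math.sqrt(d_k)  # scale before the softmax
    A = scores.softmax(dim=2)                               # [bs, seq_len, seq_len]
    return A.matmul(V.permute(0, 2, 1)).permute(0, 2, 1)    # [bs, out_channels, seq_len]

# shape check with random tensors matching the AttentionLayer above (key_channels=96, out_channels=64)
Qs, Ks, Vs = torch.rand(2, 96, 100), torch.rand(2, 96, 100), torch.rand(2, 64, 100)
scaled_attention(Qs, Ks, Vs).shape  # torch.Size([2, 64, 100])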
In the forward method we calculate the matrix product of our conv_Q, conv_K and conv_V matrices with our input x. Since these are conv-layers we can just feed x to them, and the layer's forward method will do the matrix multiplication automatically. We proceed to calculate the attention matrix A by taking the dot product of every Q and K pair in the sequence, and finally doing a softmax over the final dimension (2 in this case) to normalize the attention scores. But why do we need to permute Q before the matrix multiplication? Let's walk through an example step by step.
First we create sample Q, K and V conv layers:
conv_Q = nn.Conv1d(4, 8, kernel_size=1, bias=False)
conv_K = nn.Conv1d(4, 8, kernel_size=1, bias=False)
conv_V = nn.Conv1d(4, 16, kernel_size=1, bias=False)
conv_K.weight.shape, conv_Q.weight.shape, conv_V.weight.shape
x is our mock input data. It has a batch size of 1 (to keep it simple), 4 channels, and a sequence length of 50:
x = torch.rand((1, 4, 50))
x.shape
Then we calculate our Q, K and V matrices by feeding x to them:
qx = conv_Q(x)
kx = conv_K(x)
vx = conv_V(x)
qx.shape, kx.shape, vx.shape
Now we want to get the dot product of each query and key pair for every item in the sequence, for the entire batch. If we ignore the batch dimension, qx and kx are shaped like 8x50 @ 8x50. If we transpose the former (permute its last two dimensions) we instead get 50x8 @ 8x50, which are valid dimensions for a matrix multiplication. As expected, the output is seq_len x seq_len:
qx.permute(0, 2, 1).matmul(kx).shape
Note that if we instead permute kx, the shapes also line up, but we get an 8x8 (key_channels x key_channels) matrix - the wrong result!
qx.matmul(kx.permute(0, 2, 1)).shape
Finally we calculate our attention scores by normalizing with a softmax along the seq_len dimension, in this case dimension 2 (the last one):
A = qx.permute(0, 2, 1).matmul(kx).softmax(2)
A.shape
We can verify that the softmax normalized our attention scores to 1 for each item in the sequence:
A.sum(-1).mean().item()
Next we want to multiply A with our learned values vx, but once again our matrices don't align:
A.shape, vx.shape
When we permute vx and do the matrix multiply with A, we get the contextualized values for our input sequence. But the seq_len and channel dimensions end up in the wrong order:
A.matmul(vx.permute(0, 2, 1)).shape
We have to permute once more to get our desired result. We also see that the resulting shape is independent of the 8 key_channels in Q and K:
A.matmul(vx.permute(0, 2, 1)).permute(0, 2, 1).shape
Note that the shape of our contextualized x is different from that of our input x:
x.shape
This is simply due to the fact that our attention layer is the middle layer of a particular architecture, and the number of in and out channels has to match the surrounding layers. In many self-attention implementations, however, we need to return a contextualized output with the same shape as the original input.
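As an aside, if we did want a drop-in self-attention layer that keeps the input shape, we could simply construct the AttentionLayer with out_channels equal to in_channels. A quick illustration of mine, not something the model in this post needs:
# out_channels == in_channels, so the output has exactly the same shape as the input
same_shape_attention = AttentionLayer(in_channels=4, out_channels=4, key_channels=8)
x_small = torch.rand(1, 4, 50)
x_small.shape, same_shape_attention(x_small).shape  # both torch.Size([1, 4, 50])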
The code permutes the matrices several times. This is necessary to align the matrices properly for matrix multiplication, and when we swap two dimensions it amounts to a transpose. Note that we can list the dimensions explicitly or with negative (relative) indices. The latter is perhaps a bit more robust to inputs with a different number of dimensions, but the former is maybe easier to read:
tmp = torch.rand(3,4,5)
tmp.permute(0,2,1).shape, tmp.permute(0, -1,-2).shape
Transpose gives us the same result:
tmp.transpose(1,2).shape, tmp.transpose(-1, -2).shape
This method simply gives the module a string representation for printing. We can see the result in the output of learn_att.model below, which prints the layers of the model.
The attention method returns the same A as computed in the forward method. This gives us a convenient way to inspect the attention scores for a particular input x.
We make a new model and train it in a similar manner to our original cnn implementation above. Note that it has slightly fewer parameters than our original cnn, 58k vs 62k.
attention_model = get_model_attention(ks=5)
learn_att = Learner(dls, model=attention_model, loss_func=MSELossFlat())
learn_att.summary()
learn_att.model
Let's train our new attention model for a similar number of epochs and the same learning rate as our basic_conv model:
with learn_att.no_logging():
learn_att.fit(50, 1e-3)
learn_att.recorder.plot_loss(skip_start=100)
Wow! The loss is much lower this time around!
We'll reuse the idxs and the test_dl dataloader to inspect a few predictions. The predictions from the attention model are much better than those from our previous cnn model:
preds_att, _ = learn_att.get_preds(dl=test_dl)
preds_att.shape
fig, axs = plt.subplots(2,2, figsize=(12,9))
for ax, idx in zip(axs.flatten(), idxs):
plot_pattern(preds_cnn[idx], normalize(test_targets[idx]),
input_type='prediction cnn', ax=ax, include_targ=False, legend=False)
plot_pattern(preds_att[idx], normalize(test_targets[idx]),
input_type='prediction cnn+attention', ax=ax, legend=False)
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=3);
Another fascinating thing about attention is that we can inspect the attention matrix. This means we can see which parts of the input sequence the model considers most important for its contextualized output sequence. In order to get it we have to:
- Run a batch from the test set through the part of the model up to the attention layer
- Grab the attention layer and run its attention method to get the attention matrix
We can inspect the individual parts of our model with the .children() method, and also slice the model into separate parts:
list(attention_model.children())
Let's grab the first part and the attention layer:
base_model = attention_model[0:4]
attention_layer = attention_model[4]
base_model, attention_layer
Now we pass our normalized test_input through the first part of the model, ensuring it's on the same device as our model:
base_output = base_model(normalize(test_input).to('cuda'))
base_output.shape
We then take the output and run it through the attention() method from our attention layer:
attention = attention_layer.attention(base_output)
attention.shape
As expected, we get a 100x100 attention matrix for each item in the test set. Let's check the attention scores of the first id in idxs from our predictions above:
sample = attention[idxs[0]]
idxs[0], sample.shape
In order to plot this we'll convert the data into so-called long format. That means reshaping the 100x100 matrix into a 10,000 x 3 table with one observation (input position, output position, attention score) per row. We'll add this to a dataframe. This step makes plotting a bit easier.
df = pd.DataFrame(list(product(range(100), range(100))), columns=['input', 'output'])
df['attention'] = to_np(sample.reshape(-1,1).squeeze())
df.head()
Let's first have a look at the particular task we try to solve:
plot_pattern(test_input[idxs[0]], test_targets[idxs[0]])
And then the attention-matrix from the model's prediction of that sample:
alt.data_transformers.disable_max_rows()
alt.Chart(df).mark_rect().encode(
x = alt.X('input:O', axis=alt.Axis(values=list(range(0, 100,10)))),
y = alt.Y('output:O', axis=alt.Axis(values=list(range(0, 100,10))), sort='descending'),
color=alt.Color('attention:Q', scale=alt.Scale(scheme='viridis'))
).properties(height=500, width=500)
Attention is high (yellowish) when the output sequence pays the most attention to the input sequence. There is a solid pattern of rectangles to rectangles and triangles to triangles!
After the arrival of the fancier transformer, LSTMs seem kind of old school. But on certain tasks they do almost as well as transformer based models, IMDb being one such case. Anyway, let's write a custom module for a fairly vanilla LSTM. I really recommend the chapter on LSTMs in Deep Learning for Coders with fastai and PyTorch, also available on github, for understanding LSTMs. The implementation below is similar to the one in the book. Also, check out the LSTM documentation from PyTorch.
We will go for a 1-layer LSTM with a single linear layer to produce the output. We also need standard LSTM particulars such as model resetting and gradient truncation - check the fastai book for details. Finally, note that the fast.ai Module is very similar to nn.Module, but without the need for the boilerplate super().__init__():
class LSTM(Module):
def __init__(self, dim_in, n_out, n_hidden, n_layers):
a_in, b_in = dim_in # input = [a_in, b_in], in our case [1,100]
self.lstm = nn.LSTM(b_in, n_hidden, n_layers) # n_layered lstm
self.h_o = nn.Linear(n_hidden, n_out) # hidden to output
self.h = [torch.zeros(n_layers, a_in, # initialize hidden and cell state in a list
n_hidden, device='cuda')
for _ in range(2)]
def forward(self, x): # x=[bs,1,100], e.g. 128 elements of [1,100] data
res,h = self.lstm(x, self.h) # the resulting output and (hn, cn) in list form. Res=[bs,1,n_hidden]
self.h = [h_.detach() for h_ in h] # truncate the gradients
return self.h_o(res) # run res thru self.h_o, return predictions [128, 1, 100]
def reset(self): # reset hidden and cell before training/validation, and after epoch
for h in self.h: h.zero_()
Let's test our model with a single batch to see that everything is working:
xb, yb = dls.train.one_batch()
model = LSTM(dim_in=(1,100), n_out=100, n_hidden=128, n_layers=1)
xb.shape, model.to('cuda')(xb).shape
We'll create a learner in the usual way, but notice that we pass a ModelReseter callback (ModelReseter??) to the learner. It will call our LSTM-module's reset() automatically.
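Conceptually the callback just calls reset() at the right points in the training loop. Here is a rough sketch of the idea - not fastai's actual source, and I'm assuming the before_train/before_validate event names, which older fastai2 releases spelled begin_train/begin_validate:
class SimpleModelReset(Callback):
    "Reset the model's hidden state before training and validation, and after fitting."
    def before_train(self):    self.model.reset()
    def before_validate(self): self.model.reset()
    def after_fit(self):       self.model.reset()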
learn_lstm = Learner(dls, model, loss_func=MSELossFlat(), cbs=ModelReseter)
The model has more than twice the number of parameters (130 k) compared to our previous models:
learn_lstm.summary()
We'll train the model for the same number of epochs and with the same learning rate as above. The loss is much worse than our attention model, but is clearly improving:
with learn_lstm.no_logging():
learn_lstm.fit(50, 1e-3)
learn_lstm.recorder.plot_loss(skip_start=1000)
Let's check our usual suspects (idxs) and compare the LSTM with our other models:
preds_lstm, _ = learn_lstm.get_preds(dl=test_dl)
preds_lstm.shape
The plot is getting a bit busy, but it's clear that the attention model is the winner. Also, note that the LSTM predictions oscillate much more than those of the other models. I'm not sure why that is.
fig, axs = plt.subplots(2,2, figsize=(12,9))
for ax, idx in zip(axs.flatten(), idxs):
plot_pattern(preds_lstm[idx], normalize(test_targets[idx]),
input_type='prediction LSTM', ax=ax, legend=False)
plot_pattern(preds_att[idx], normalize(test_targets[idx]),
input_type='prediction attention', ax=ax, include_targ=False, legend=False)
plot_pattern(preds_cnn[idx], normalize(test_targets[idx]),
input_type='prediction cnn', ax=ax, include_targ=False, legend=False)
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=4);
Attention worked really well in this example. After adding an attention layer we outclassed our base cnn model, even though our attention model had fewer parameters! The trusty old LSTM fared a bit better than the base cnn, but was not nearly as good as our attention model. I have to admit, though, that the LSTM model may need a bit more love during training, and the simple setup we used might not have allowed it to shine.
Self attention is really the heart of most transformer based models. The transformer uses multi-headed attention, but that is essentially several copies of our self attention layer run in parallel, with an extra linear layer to transform the output to the appropriate shape. The full transformer features additional linear layers, skip connections and normalization layers, but nothing too fancy.
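To make that concrete, here is a rough sketch of how multi-headed attention could be pieced together from the AttentionLayer we defined earlier: run a few heads in parallel, concatenate their outputs along the channel dimension, and mix them with a final 1x1 convolution. This is my own illustration, not part of any model trained in this post:
class MultiHeadAttention(nn.Module):
    "Several AttentionLayer heads in parallel, concatenated and mixed by a 1x1 conv."
    def __init__(self, in_channels, out_channels, key_channels, n_heads=4):
        super().__init__()
        self.heads = nn.ModuleList([
            AttentionLayer(in_channels, out_channels, key_channels) for _ in range(n_heads)
        ])
        self.mix = nn.Conv1d(n_heads * out_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):                                     # x: [bs, in_channels, seq_len]
        heads = torch.cat([h(x) for h in self.heads], dim=1)  # [bs, n_heads*out_channels, seq_len]
        return self.mix(heads)                                # [bs, out_channels, seq_len]

MultiHeadAttention(in_channels=64, out_channels=64, key_channels=96)(torch.rand(2, 64, 100)).shape
Note that the actual transformer splits the model dimension across smaller heads rather than running full-width heads side by side like this, but the idea is the same.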
There is however one concept which we haven't covered, namely positional encoding. Transformers have mainly been used for nlp tasks so far, and when our input data is text, order matters (omg I'm clever). Since self attention is fundamentally permutation invariant, it will struggle to cope if our input data is ordered. In the next post we'll modify our example data to include ordering, and see if we can make a model that solves this harder task too - spoiler: we can!
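As a final aside, we can verify that permutation point directly with the AttentionLayer from earlier: shuffling the positions of the input simply shuffles the output in the same way, so the layer itself has no notion of order. A quick check I added:
layer = AttentionLayer(in_channels=4, out_channels=4, key_channels=8)
seq = torch.rand(1, 4, 50)
perm = torch.randperm(50)                # a random reordering of the 50 positions

out_then_perm = layer(seq)[:, :, perm]   # run the layer, then permute the output...
perm_then_out = layer(seq[:, :, perm])   # ...or permute the input first, then run the layer
torch.allclose(out_then_perm, perm_then_out, atol=1e-5)  # True: the layer can't tell the difference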