from fastai2.vision.all import *
import altair as alt
from itertools import product

Finding a proper case study for attention

I have been looking into attention and transformers lately, and wanted to explore them with an example data set - solving a problem is the best way to learn, in my opinion. Most use cases I come across, however, seem to be huge language models, which are just too complex to be a good starting point for learning. One day this popped up in my twitter feed, and it seemed like just the right thing for further examination:

The code below creates the actual data which we will investigate in this post. It was written by François Fleuret, who was kind enough to let me reuse it. I've modified it slightly so that it fits the notebook format. You can find the original code here.

#collapse
# Author: François Fleuret
seq_height_min, seq_height_max = 1.0, 25.0
seq_width_min, seq_width_max = 5.0, 11.0
seq_length = 100

def positions_to_sequences(tr = None, bx = None, noise_level = 0.3):
    st = torch.arange(seq_length).float()
    st = st[None, :, None]
    tr = tr[:, None, :, :]
    bx = bx[:, None, :, :]

    xtr =            torch.relu(tr[..., 1] - torch.relu(torch.abs(st - tr[..., 0]) - 0.5) * 2 * tr[..., 1] / tr[..., 2])
    xbx = torch.sign(torch.relu(bx[..., 1] - torch.abs((st - bx[..., 0]) * 2 * bx[..., 1] / bx[..., 2]))) * bx[..., 1]

    x = torch.cat((xtr, xbx), 2)

    # u = x.sign()
    u = F.max_pool1d(x.sign().permute(0, 2, 1), kernel_size = 2, stride = 1).permute(0, 2, 1)

    collisions = (u.sum(2) > 1).max(1).values
    y = x.max(2).values

    return y + torch.rand_like(y) * noise_level - noise_level / 2, collisions

def generate_sequences(nb, group_by_locations=False):

    # Position / height / width

    tr = torch.empty(nb, 2, 3)
    tr[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
    tr[:, :, 1].uniform_(seq_height_min, seq_height_max)
    tr[:, :, 2].uniform_(seq_width_min, seq_width_max)

    bx = torch.empty(nb, 2, 3)
    bx[:, :, 0].uniform_(seq_width_max/2, seq_length - seq_width_max/2)
    bx[:, :, 1].uniform_(seq_height_min, seq_height_max)
    bx[:, :, 2].uniform_(seq_width_min, seq_width_max)

    if group_by_locations:
        a = torch.cat((tr, bx), 1)
        v = a[:, :, 0].sort(1).values[:, 2:3]
        mask_left = (a[:, :, 0] < v).float()
        h_left = (a[:, :, 1] * mask_left).sum(1) / 2
        h_right = (a[:, :, 1] * (1 - mask_left)).sum(1) / 2
        valid = (h_left - h_right).abs() > 4
    else:
        valid = (torch.abs(tr[:, 0, 1] - tr[:, 1, 1]) > 4) & (torch.abs(tr[:, 0, 1] - tr[:, 1, 1]) > 4)  # note: both terms check the tr heights

    input, collisions = positions_to_sequences(tr, bx)

    if group_by_locations:
        a = torch.cat((tr, bx), 1)
        v = a[:, :, 0].sort(1).values[:, 2:3]
        mask_left = (a[:, :, 0] < v).float()
        h_left = (a[:, :, 1] * mask_left).sum(1, keepdim = True) / 2
        h_right = (a[:, :, 1] * (1 - mask_left)).sum(1, keepdim = True) / 2
        a[:, :, 1] = mask_left * h_left + (1 - mask_left) * h_right
        tr, bx = a.split(2, 1)
    else:
        tr[:, :, 1:2] = tr[:, :, 1:2].mean(1, keepdim = True)
        bx[:, :, 1:2] = bx[:, :, 1:2].mean(1, keepdim = True)

    targets, _ = positions_to_sequences(tr, bx)

    valid = valid & ~collisions
    tr = tr[valid]
    bx = bx[valid]
    input = input[valid][:, None, :]
    targets = targets[valid][:, None, :]

    if input.size(0) < nb:
        input2, targets2, tr2, bx2 = generate_sequences(nb - input.size(0))
        input = torch.cat((input, input2), 0)
        targets = torch.cat((targets, targets2), 0)
        tr = torch.cat((tr, tr2), 0)
        bx = torch.cat((bx, bx2), 0)

    return input, targets, tr, bx

We use the above code to create our data set: 25,000 training samples and 1,000 test samples of 1D sequences with a length of 100:

torch.manual_seed(42)
train_input, train_targets, train_tr, train_bx = generate_sequences(25000)
test_input, test_targets, test_tr, test_bx = generate_sequences(1000)
train_input.shape, train_targets.shape, test_input.shape, test_targets.shape
(torch.Size([25000, 1, 100]),
 torch.Size([25000, 1, 100]),
 torch.Size([1000, 1, 100]),
 torch.Size([1000, 1, 100]))

Our input data is simply a series of floats:

print(train_input[0,0,:5].tolist())
[0.13880640268325806, 0.0006390661001205444, 0.0356530100107193, -0.10253779590129852, -0.014928251504898071]

Let's also plot a pattern to better understand what our task is going to be:

#collapse
def plot_pattern(inp, target, input_type='input', ax=None, include_targ = True, legend=True):
    if not ax: fig, ax = plt.subplots()
    inp = to_np(tensor(inp).squeeze())
    target = to_np(tensor(target).squeeze())
    ax.plot(inp, label=input_type)
    if include_targ: ax.plot(target, label='target')
    if legend: ax.legend();
plot_pattern(train_input[123], train_targets[123])

The target output should be the average height of the input for corresponding geometric shapes: the two yellow squares represent the average height of the blue squares, and similarly for the triangles.

Basic cnn model

Our first attempt at solving this task will be a regular (1d) convnet. We will build it with the fastai library. The architecture and training will be identical to what François used in his example (see tweet above).

Dataloaders

In fastai the dataloader holds both the input and the target data, so let's concatenate our input and target data into a single training dataset. This step makes the creation of the dataloader a bit easier.

train = torch.cat((train_input, train_targets), dim=-1)
train.shape
torch.Size([25000, 1, 200])

We also want to normalize our data, so we'll create a normalize helper function.

mean, std = train.mean(), train.std()
def norm(x, m, s): return (x-m)/s
normalize = partial(norm, m=mean, s=std)

Finally we make our dataloader with the datablock api. Think of this as an assembly line which we pass our data source through:

dls = DataBlock(
    get_x = lambda row: row[:, :100],    # the first 100 cols of a row hold our training data
    get_y = lambda row: row[:, 100:],    # and the final 100 cols in the row holds our target data
    splitter=RandomSplitter(valid_pct=0.2),
    batch_tfms=normalize
).dataloaders(source=train, bs=128)

Next we will grab a batch from the dataloader, and expect to find the data normalized:

xb, yb = next(iter(dls.train))
xb.shape, yb.shape, xb.mean().item(), xb.std().item()
(torch.Size([128, 1, 100]),
 torch.Size([128, 1, 100]),
 0.004318716004490852,
 1.0494749546051025)

We get a batch of 128 items both for our input and target, with zero mean and unit variance. We can also plot an item from our batch to verify that everything is ok:

plot_pattern(xb[0], yb[0])

Create a learner

This is the same model that François used in his example:

#collapse
def get_model_cnn(ks, stride=1):
    return sequential(
        nn.Conv1d(1, 64, kernel_size=ks, stride=stride, padding=ks//2), 
        nn.ReLU(), 
        nn.Conv1d(64, 64, kernel_size=ks, stride=stride, padding=ks//2), 
        nn.ReLU(), 
        nn.Conv1d(64, 64, kernel_size=ks, stride=stride, padding=ks//2), 
        nn.ReLU(), 
        nn.Conv1d(64, 64, kernel_size=ks, stride=stride, padding=ks//2), 
        nn.ReLU(), 
        nn.Conv1d(64, 1, kernel_size=ks, stride=stride, padding=ks//2),
    )
basic_conv = get_model_cnn(ks=5)
learn_cnn = Learner(dls, model=basic_conv, loss_func=MSELossFlat())

Our cnn model has about 62 000 parameters:

learn_cnn.summary()
Sequential (Input shape: ['128 x 1 x 100'])
================================================================
Layer (type)         Output Shape         Param #    Trainable 
================================================================
Conv1d               128 x 64 x 100       384        True      
________________________________________________________________
ReLU                 128 x 64 x 100       0          False     
________________________________________________________________
Conv1d               128 x 64 x 100       20,544     True      
________________________________________________________________
ReLU                 128 x 64 x 100       0          False     
________________________________________________________________
Conv1d               128 x 64 x 100       20,544     True      
________________________________________________________________
ReLU                 128 x 64 x 100       0          False     
________________________________________________________________
Conv1d               128 x 64 x 100       20,544     True      
________________________________________________________________
ReLU                 128 x 64 x 100       0          False     
________________________________________________________________
Conv1d               128 x 1 x 100        321        True      
________________________________________________________________

Total params: 62,337
Total trainable params: 62,337
Total non-trainable params: 0

Optimizer used: <function Adam at 0x7f47ac2085f0>
Loss function: FlattenedLoss of MSELoss()

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback

Training

We'll train our model for 50 epochs with a flat learning rate of 1e-3. It's a small model, so training won't take very long. We could definitely train for longer without waiting all day. Fastai prints out the metrics from every epoch by default, so we'll use the no_logging() context manager to make the output a bit more compact:

with learn_cnn.no_logging():
    learn_cnn.fit(50, 1e-3)
learn_cnn.recorder.plot_loss(skip_start=100)

Our model seems to be learning something, but how good are the predictions?

Check a few predictions

Let's check a few predictions from the test set. We'll add the test set to our data loader first. Note that the test_dl() method expects a numpy ndarray, so we transform our input tensor with to_np() before passing it:

test_dl = learn_cnn.dls.test_dl(to_np(test_input))
preds_cnn, _ = learn_cnn.get_preds(dl=test_dl)
preds_cnn.shape
torch.Size([1000, 1, 100])

Then we grab a few random indexes to plot. We'll be reusing these for other models later:

idxs = [84, 701,  27, 493]

The predictions aren't particularly good at this point:

fig, axs = plt.subplots(2,2, figsize=(12,9))
for ax, idx in zip(axs.flatten(), idxs):
    plot_pattern(preds_cnn[idx], normalize(test_targets[idx]), input_type='prediction', ax=ax)

Sanity check our model

We didn't get great results from the above convnet. Is the task too hard, or is our model broken in some way? A standard way of checking the health of a model is to try to overfit a single batch. If the model can't do this, something funky is going on.

Let's take 20 items and see if we're able to overfit them. We'll use a standard PyTorch training loop:

model = get_model_cnn(ks=5)
optimizer = Adam(model.parameters(), lr=1e-3)
mse_loss = MSELossFlat()
inp, target = normalize(train[:20,:,:100]), normalize(train[:20,:,100:])

for epoch in range(5001):
    optimizer.zero_grad()
    output = model(inp)
    loss = mse_loss(output, target)
    if epoch%1000==0: print(f'epoch: {epoch}, loss: {loss.item()}')
    loss.backward()
    optimizer.step()
epoch: 0, loss: 0.9717220664024353
epoch: 1000, loss: 0.016162781044840813
epoch: 2000, loss: 0.006243061739951372
epoch: 3000, loss: 0.003895669477060437
epoch: 4000, loss: 0.0020476507488638163
epoch: 5000, loss: 0.0005493065691553056

The loss is slowly approaching zero, so it seems that our model is working. But it needs a crazy number of epochs to overfit just 20 items, so doing well on the entire data set would be very difficult. There are many things we could do to improve the model - make it deeper or wider, or perhaps increase the kernel size? But there is also this thing called attention!

Attention

Attention has become a central concept in deep learning with the rise of the transformer architecture. There are several good resources that explain the concept. I particularly like the following:

  1. The Rasa whiteboard video series on attention. Highly recommended!
  2. Jay Alammar's blog has several posts on attention and transformers
  3. Yannic Kilcher has lots of videos on deep learning papers, including a playlist for NLP
  4. Lilian Weng has a nice blog with a few posts on attention and transformers
  5. Peter Bloem has a nice from-scratch implementation of the transformer in PyTorch
  6. The annotated Transformer by Harvard NLP, and the Attention is All You Need paper
  7. The annotated GPT-2 by Aman Arora

The rest of the post will assume a basic understanding of the concept of attention (specifically self-attention). Below is the implementation of an attention layer and the proposed model architecture - very similar to the above cnn architecture except for a new AttentionLayer. This code is once again written by François, but I added a few comments on intermediate tensor shapes. Let's go through it step by step.

#collapse
class AttentionLayer(nn.Module):
    def __init__(self, in_channels, out_channels, key_channels):
        super(AttentionLayer, self).__init__()
        self.conv_Q = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
        self.conv_K = nn.Conv1d(in_channels, key_channels, kernel_size = 1, bias = False)
        self.conv_V = nn.Conv1d(in_channels, out_channels, kernel_size = 1, bias = False)

    def forward(self, x):                      #  x.shape =  [bs x in_channels x seq_len]
        Q = self.conv_Q(x)                                 # [bs x key_channels x seq_len]
        K = self.conv_K(x)                                 # [bs x key_channels x seq_len]
        V = self.conv_V(x)                                 # [bs x out_channels x seq_len]
        A = Q.permute(0, 2, 1).matmul(K).softmax(2)        # [bs x seq_len x seq_len]
        x = A.matmul(V.permute(0, 2, 1)).permute(0, 2, 1)  # [bs x out_channels x seq_len]
        return x

    def __repr__(self):
        return self._get_name() + \
            '(in_channels={}, out_channels={}, key_channels={})'.format(
                self.conv_Q.in_channels,
                self.conv_V.out_channels,
                self.conv_K.out_channels
            )

    def attention(self, x):
        Q = self.conv_Q(x)
        K = self.conv_K(x)
        return Q.permute(0, 2, 1).matmul(K).softmax(2)

def get_model_attention(ks, stride=1):
    return sequential(
        nn.Conv1d(1, 64, kernel_size=ks, stride=stride, padding=ks//2), 
        nn.ReLU(), 
        nn.Conv1d(64, 64, kernel_size=ks, stride=stride, padding=ks//2), 
        nn.ReLU(), 
        AttentionLayer(in_channels=64, out_channels=64, key_channels=96),  ## New attention layer
        nn.ReLU(), 
        nn.Conv1d(64, 64, kernel_size=ks, stride=stride, padding=ks//2), 
        nn.ReLU(), 
        nn.Conv1d(64, 1, kernel_size=ks, stride=stride, padding=ks//2)
    )

def __init__()

Our class has nn.Module as its superclass, so first we initialize the superclass properly.

We proceed to initialize our Q, K and V matrices. Note that conv_Q and conv_K have identical dimensions: in_channels x key_channels. We'll look into why this is a requirement later in the post. conv_V, however, is different: in_channels x out_channels. Q, K and V often have the same shape in other implementations, see e.g. the fast.ai nlp course implementation, but this depends on the overall architecture that the attention layer is part of. In this particular case our attention layer is part of a specific architecture where the number of out channels needs to conform to the expected number of input channels in the next layer. Finally, if Q, K and V do have identical shapes, we can also stack them in a single matrix and thereby improve performance.
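As an aside, here is a minimal sketch of what that stacking could look like if Q, K and V all had key_channels output channels. This single-convolution variant is just an illustration, not part of the architecture we train in this post:

# hypothetical: compute Q, K and V with one stacked convolution and split the result
in_channels, key_channels, seq_len = 64, 96, 100
x = torch.rand(8, in_channels, seq_len)                        # [bs x in_channels x seq_len]

conv_QKV = nn.Conv1d(in_channels, 3 * key_channels, kernel_size=1, bias=False)
Q, K, V = conv_QKV(x).chunk(3, dim=1)                          # three [bs x key_channels x seq_len] tensors
Q.shape, K.shape, V.shape                                      # each torch.Size([8, 96, 100])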

Why are Q, K and V conv layers?

Note that we use Conv1d layers with kernel_size=1. This means that a conv-layer is identical to a linear layer. Let's take some dummy data to compare a Conv1d(ks=1) and a regular nn.Linear() and see if we can change the number of channels of our input from 8 to 12:

# bs x n_channels x seq_len
x = torch.rand(16,8,50)
x.shape
torch.Size([16, 8, 50])

If we pass x through a conv layer with kernel_size=1 and no bias, we get our desired output:

nn.Conv1d(in_channels=8, out_channels=12, kernel_size=1, bias=False)(x).shape
torch.Size([16, 12, 50])

We can achieve the same with a linear layer. But the linear layer expects input where the final dimension is the channel dimension, and similarly for the output. To compare with our conv layer we need to permute the data in each step:

out = nn.Linear(in_features=8, out_features=12, bias=False)(x.permute(0,2,1))
out.permute(0,2,1).shape
torch.Size([16, 12, 50])

So using a Conv1d with ks=1 is the same as a linear layer, but seems to avoid a few annoying permutations. Hugging Face's implementation of attention in GPT-2, for example, uses a Conv1D for Q, K and V.
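As a quick extra sanity check (not part of the original example), we can copy the conv weights into a linear layer and verify that the two really compute the same thing, not just the same shape:

conv = nn.Conv1d(in_channels=8, out_channels=12, kernel_size=1, bias=False)
lin = nn.Linear(in_features=8, out_features=12, bias=False)
lin.weight.data = conv.weight.data.squeeze(-1)          # conv weight [12, 8, 1] -> linear weight [12, 8]

out_conv = conv(x)                                      # x is the [16, 8, 50] tensor from above
out_lin = lin(x.permute(0, 2, 1)).permute(0, 2, 1)
torch.allclose(out_conv, out_lin, atol=1e-6)            # True - identical up to float precision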

Why don't we scale the attention scores?

There are many variations of attention, see Lilian Weng's post for an overview. Our implementation is a so-called dot-product attention, but the Attention is All You Need paper uses a scaled dot-product attention instead. This seems to have become the norm in subsequent transformers, and the argument is that scaling improves the gradient signal, especially for long input sequences. In this particular example it probably doesn't matter much anyway, since we have a fairly simple architecture and small input data.
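For reference, the only change scaling would require is to divide the raw scores by the square root of the key dimension before the softmax. A sketch (not something we use in the models below):

def scaled_attention_scores(Q, K):               # Q, K: [bs x key_channels x seq_len]
    scores = Q.permute(0, 2, 1).matmul(K)        # [bs x seq_len x seq_len]
    return (scores / Q.shape[1] ** 0.5).softmax(2)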

def forward()

In the forward method we calculate the matrix product of our conv_Q, conv_K and conv_V layers with our input x. Since these are conv layers we can just feed x to them, and their forward methods will do the matrix multiplication for us. We proceed to calculate the attention matrix A by taking the dot product of each Q and K pair in the sequence, and finally applying a softmax over the final dimension (2 in this case) to normalize the attention scores. But why do we need to permute Q before the matrix multiplication? Let's walk through an example step by step.

First we create sample Q, K and V conv layers:

conv_Q = nn.Conv1d(4, 8, kernel_size=1, bias=False)
conv_K = nn.Conv1d(4, 8, kernel_size=1, bias=False)
conv_V = nn.Conv1d(4, 16, kernel_size=1, bias=False)
conv_K.weight.shape, conv_Q.weight.shape, conv_V.weight.shape
(torch.Size([8, 4, 1]), torch.Size([8, 4, 1]), torch.Size([16, 4, 1]))

x is our mock input data. It has a batch size of 1 (to keep it simple), a channel size of 4, and a sequence length of 50:

x = torch.rand((1, 4, 50))
x.shape
torch.Size([1, 4, 50])

Then we calculate our Q, K and V matrices by feeding x to them:

qx = conv_Q(x)
kx = conv_K(x)
vx = conv_V(x)
qx.shape, kx.shape, vx.shape
(torch.Size([1, 8, 50]), torch.Size([1, 8, 50]), torch.Size([1, 16, 50]))

Now we want to get the dot product of each query and key pair for every item in the sequence, for the entire batch. If we ignore the batch dimension, qx and kx are shaped like 8x50 @ 8x50. If we permute the last two dimensions of the former we instead get 50x8 @ 8x50, which are valid dimensions for a matrix multiplication. As expected the output is seq_len x seq_len:

qx.permute(0, 2, 1).matmul(kx).shape
torch.Size([1, 50, 50])

Note that if we instead permute the second matrix, kx, we get the wrong result!

qx.matmul(kx.permute(0, 2, 1)).shape
torch.Size([1, 8, 8])

Finally we calculate our attention scores by normalizing with a softmax along the seq_len dimension, in this case dimension 2:

A = qx.permute(0, 2, 1).matmul(kx).softmax(2)
A.shape
torch.Size([1, 50, 50])

We can verify that the softmax normalized our attention scores to sum to 1 for each item in the sequence:

A.sum(-1).mean().item()
1.0

Finally we want to multiply A with our learned values vx, but once again our matrices don't align:

A.shape, vx.shape
(torch.Size([1, 50, 50]), torch.Size([1, 16, 50]))

When we permute vx and do the matrix multiplication with A we get the contextualized values for our input sequence. But the channel and seq_len dimensions have been switched:

A.matmul(vx.permute(0, 2, 1)).shape
torch.Size([1, 50, 16])

We have to permute once more to get our desired result. We also see that the resulting shape is independent of the 8 key_channels in Q and K:

A.matmul(vx.permute(0, 2, 1)).permute(0, 2, 1).shape
torch.Size([1, 16, 50])

Note that the shape of our contextualized x is different from that of our input x:

x.shape
torch.Size([1, 4, 50])

This is simply because our attention layer is the middle layer of a particular architecture, and the number of in_channels and out_channels have to match the characteristics of that architecture. In many self-attention implementations we instead need to return a contextualized output with the same shape as the original input.
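If we did want the contextualized output to keep the shape of the input, we could simply set out_channels equal to in_channels. A quick sketch using the AttentionLayer defined above and the dummy x:

same_shape_layer = AttentionLayer(in_channels=4, out_channels=4, key_channels=8)
same_shape_layer(x).shape                       # torch.Size([1, 4, 50]) - same shape as x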

permute vs transpose

The code permutes the matrices several times. This is necessary to align the matrices properly for matrix multiplication, and is similar to transposing. Note that we can permute by listing dimensions explicitly or relative to the end (with negative indexes). The latter is perhaps a bit more robust to inputs with a different number of dimensions, but the former is maybe easier to read:

tmp = torch.rand(3,4,5)
tmp.permute(0,2,1).shape, tmp.permute(0, -1,-2).shape
(torch.Size([3, 5, 4]), torch.Size([3, 5, 4]))

Transpose gives us the same result:

tmp.transpose(1,2).shape, tmp.transpose(-1, -2).shape
(torch.Size([3, 5, 4]), torch.Size([3, 5, 4]))

def __repr__()

This method simply gives the module a string representation for printing. We can see the result in learn_att.model below, which prints the layers of the model.

def attention()

This method returns the same A that is computed in forward. It gives us a convenient way to inspect the attention scores for a particular input x.
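As a quick check, we can build a small untrained layer and confirm that attention() returns a seq_len x seq_len matrix for the dummy x from the walkthrough above (just an illustration; we'll use it on the trained model further down):

layer = AttentionLayer(in_channels=4, out_channels=16, key_channels=8)
layer.attention(x).shape                        # torch.Size([1, 50, 50]), same shape as the A we computed by hand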

CNN + attention

We make a new model and train it in the same manner as our original cnn implementation above. Note that it has slightly fewer parameters than our original cnn: 58k vs 62k.

attention_model = get_model_attention(ks=5)
learn_att = Learner(dls, model=attention_model, loss_func=MSELossFlat())
learn_att.summary()
Sequential (Input shape: ['128 x 1 x 100'])
================================================================
Layer (type)         Output Shape         Param #    Trainable 
================================================================
Conv1d               128 x 64 x 100       384        True      
________________________________________________________________
ReLU                 128 x 64 x 100       0          False     
________________________________________________________________
Conv1d               128 x 64 x 100       20,544     True      
________________________________________________________________
ReLU                 128 x 64 x 100       0          False     
________________________________________________________________
Conv1d               128 x 96 x 100       6,144      True      
________________________________________________________________
Conv1d               128 x 96 x 100       6,144      True      
________________________________________________________________
Conv1d               128 x 64 x 100       4,096      True      
________________________________________________________________
ReLU                 128 x 64 x 100       0          False     
________________________________________________________________
Conv1d               128 x 64 x 100       20,544     True      
________________________________________________________________
ReLU                 128 x 64 x 100       0          False     
________________________________________________________________
Conv1d               128 x 1 x 100        321        True      
________________________________________________________________

Total params: 58,177
Total trainable params: 58,177
Total non-trainable params: 0

Optimizer used: <function Adam at 0x7f47ac2085f0>
Loss function: FlattenedLoss of MSELoss()

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
learn_att.model
Sequential(
  (0): Conv1d(1, 64, kernel_size=(5,), stride=(1,), padding=(2,))
  (1): ReLU()
  (2): Conv1d(64, 64, kernel_size=(5,), stride=(1,), padding=(2,))
  (3): ReLU()
  (4): AttentionLayer(in_channels=64, out_channels=64, key_channels=96)
  (5): ReLU()
  (6): Conv1d(64, 64, kernel_size=(5,), stride=(1,), padding=(2,))
  (7): ReLU()
  (8): Conv1d(64, 1, kernel_size=(5,), stride=(1,), padding=(2,))
)

Train model

Let's train our new attention model for the same number of epochs and with the same learning rate as our basic_conv model:

with learn_att.no_logging(): 
    learn_att.fit(50, 1e-3)
learn_att.recorder.plot_loss(skip_start=100)

Wow! The loss is much lower this time around!

Check a few predictions

We'll reuse the idxs and the test_dl dataloader to inspect a few predictions. The predictions from the attention model are much better than those from our previous cnn model:

preds_att, _ = learn_att.get_preds(dl=test_dl)
preds_att.shape

fig, axs = plt.subplots(2,2, figsize=(12,9))
for ax, idx in zip(axs.flatten(), idxs):
    plot_pattern(preds_cnn[idx], normalize(test_targets[idx]), 
                 input_type='prediction cnn', ax=ax, include_targ=False, legend=False)
    plot_pattern(preds_att[idx], normalize(test_targets[idx]), 
                 input_type='prediction cnn+attention', ax=ax, legend=False)
    
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=3);

Check attention matrix

Another fascinating thing about attention is that we can inspect the attention matrix. This means we can see which parts of the input sequence the model thinks are most important for its contextualized output sequence. In order to get it we have to:

  1. Run a batch from the test set through the part of the model up to the attention layer
  2. Grab the attention layer and run its attention() method to get the attention matrix

We can inspect the individual parts of our model with the .children() method, and also slice the model into separate parts:

list(attention_model.children())
[Conv1d(1, 64, kernel_size=(5,), stride=(1,), padding=(2,)),
 ReLU(),
 Conv1d(64, 64, kernel_size=(5,), stride=(1,), padding=(2,)),
 ReLU(),
 AttentionLayer(in_channels=64, out_channels=64, key_channels=96),
 ReLU(),
 Conv1d(64, 64, kernel_size=(5,), stride=(1,), padding=(2,)),
 ReLU(),
 Conv1d(64, 1, kernel_size=(5,), stride=(1,), padding=(2,))]

Let's grab the first part and the attention layer:

base_model = attention_model[0:4]
attention_layer = attention_model[4]
base_model, attention_layer
(Sequential(
   (0): Conv1d(1, 64, kernel_size=(5,), stride=(1,), padding=(2,))
   (1): ReLU()
   (2): Conv1d(64, 64, kernel_size=(5,), stride=(1,), padding=(2,))
   (3): ReLU()
 ),
 AttentionLayer(in_channels=64, out_channels=64, key_channels=96))

Now we pass our normalized test_input through the first part of the model, ensuring it's on the same device as our model:

base_output = base_model(normalize(test_input).to('cuda'))
base_output.shape
torch.Size([1000, 64, 100])

We then take the output and run it through the attention() method from our attention layer:

attention = attention_layer.attention(base_output)
attention.shape
torch.Size([1000, 100, 100])

As expected we get a 100x100 attention matrix for each item in the test set. Let's check the attention scores of the first id in idxs from our predictions above:

sample = attention[idxs[0]]
idxs[0], sample.shape
(84, torch.Size([100, 100]))

In order to plot this we'll convert the data into so-called long format. That means reshaping the data from a 100x100 matrix to a (100*100) x 3 table with one observation per row. We'll put this in a dataframe. This step makes plotting a bit easier.

df = pd.DataFrame(list(product(range(100), range(100))), columns=['input', 'output'])
df['attention'] = to_np(sample.reshape(-1,1).squeeze())
df.head()
   input  output  attention
0      0       0   0.005647
1      0       1   0.017142
2      0       2   0.019656
3      0       3   0.011353
4      0       4   0.010529

Let's first have a look at the particular task we're trying to solve:

plot_pattern(test_input[idxs[0]], test_targets[idxs[0]])

And then the attention-matrix from the model's prediction of that sample:

alt.data_transformers.disable_max_rows()
alt.Chart(df).mark_rect().encode(
    x = alt.X('input:O', axis=alt.Axis(values=list(range(0, 100,10)))),
    y = alt.Y('output:O', axis=alt.Axis(values=list(range(0, 100,10))), sort='descending'),
    color=alt.Color('attention:Q', scale=alt.Scale(scheme='viridis'))
).properties(height=500, width=500)

Attention is high (yellowish) where the output sequence pays the most attention to the input sequence. There is a solid pattern of rectangles attending to rectangles and triangles to triangles!

LSTM

After the arrival of the fancier transformer, LSTMs seem kind of old school. But on certain tasks they do almost as well as transformer-based models, IMDb being one such case. Anyway, let's write a custom module for a fairly vanilla LSTM. I really recommend the chapter on LSTMs in Deep Learning for Coders with fastai and PyTorch, also available on github, for understanding LSTMs. The implementation below is similar to the one in the book. Also, check out the LSTM documentation from PyTorch.

We will go for a 1-layer LSTM with a single linear layer to produce the output. We also need standard LSTM particulars such as model resetting and gradient truncation - check the fastai book for details. Finally, note that the fast.ai Module is very similar to nn.Module, but without the need for the boilerplate super().__init__():

class LSTM(Module):
    def __init__(self, dim_in, n_out, n_hidden, n_layers):
        a_in, b_in = dim_in                              # input = [a_in, b_in], in our case [1,100]
        self.lstm = nn.LSTM(b_in, n_hidden, n_layers)    # n_layered lstm
        self.h_o = nn.Linear(n_hidden, n_out)            # hidden to output
        self.h = [torch.zeros(n_layers, a_in,            # initialize hidden and cell state in a list
                              n_hidden, device='cuda')
                  for _ in range(2)]                     
        
    def forward(self, x):                    # x=[bs,1,100], e.g. 128 elements of [1,100] data
        res,h = self.lstm(x, self.h)         # the resulting output and (hn, cn) in list form. Res=[bs,1,n_hidden]
        self.h = [h_.detach() for h_ in h]   # truncate the gradients
        return self.h_o(res)                 # run res thru self.h_o, return predictions [128, 1, 100]
    
    def reset(self):                  # reset hidden and cell before training/validation, and after epoch
        for h in self.h: h.zero_()

Let's test our model with a single batch to see that everything is working:

xb, yb = dls.train.one_batch()
model = LSTM(dim_in=(1,100), n_out=100, n_hidden=128, n_layers=1)
xb.shape, model.to('cuda')(xb).shape
(torch.Size([128, 1, 100]), torch.Size([128, 1, 100]))

Training the model

We'll create a learner in the usual way, but notice that we pass a ModelReseter callback (ModelReseter??) to the learner. It will call our LSTM-module's reset() automatically.

learn_lstm = Learner(dls, model, loss_func=MSELossFlat(), cbs=ModelReseter)

The model has more than twice the number of parameters (130 k) compared to our previous models:

learn_lstm.summary()
LSTM (Input shape: ['128 x 1 x 100'])
================================================================
Layer (type)         Output Shape         Param #    Trainable 
================================================================
LSTM                 ['128 x 1 x 128', "  117,760    True      
________________________________________________________________
Linear               128 x 1 x 100        12,900     True      
________________________________________________________________

Total params: 130,660
Total trainable params: 130,660
Total non-trainable params: 0

Optimizer used: <function Adam at 0x7f47ac2085f0>
Loss function: FlattenedLoss of MSELoss()

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
  - ModelReseter

We'll train the model for the same number of epochs and with the same learning rate as above. The loss is much worse than our attention model, but is clearly improving:

with learn_lstm.no_logging():
    learn_lstm.fit(50, 1e-3)
learn_lstm.recorder.plot_loss(skip_start=1000)

Check a few predictions

Let's check our usual suspects (idxs) and compare the LSTM with our other models:

preds_lstm, _ = learn_lstm.get_preds(dl=test_dl)
preds_lstm.shape
torch.Size([1000, 1, 100])

The plot is getting a bit busy, but it's clear that the attention model is the winner. Also, note that the LSTM predictions oscillate much more than those of our other models. I'm not sure why this is.

fig, axs = plt.subplots(2,2, figsize=(12,9))
for ax, idx in zip(axs.flatten(), idxs):
    plot_pattern(preds_lstm[idx], normalize(test_targets[idx]), 
                 input_type='prediction LSTM', ax=ax, legend=False)
    
    plot_pattern(preds_att[idx], normalize(test_targets[idx]), 
                 input_type='prediction attention', ax=ax, include_targ=False, legend=False)
    
    plot_pattern(preds_cnn[idx], normalize(test_targets[idx]), 
                 input_type='prediction cnn', ax=ax, include_targ=False, legend=False)

handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc='lower center', ncol=4);

Final thoughts

Attention worked really well in this example. After adding an attention layer we outclassed our base cnn model, even though our attention model had fewer parameters! The trusty old LSTM fared a bit better than the base cnn, but was not nearly as good as our attention model. I have to admit, though, that the LSTM model may need a bit more love during training, and the simple setup we used might not have allowed it to shine.

Self-attention is really the heart of most transformer-based models. The transformer uses multi-headed attention, but that is just a duplication of our self-attention layer, with an extra linear layer to transform the output to the appropriate shape. The full transformer features additional linear layers, skip connections and normalization layers, but nothing too fancy.
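To make that a bit more concrete, here is a rough, untrained sketch of how several of our AttentionLayers could be combined into a multi-headed layer with a final projection back to the expected number of channels. This is just my illustration of the idea; real transformer implementations usually split the channel dimension across heads rather than running every head at full width:

class MultiHeadAttention(Module):
    def __init__(self, n_heads, in_channels, out_channels, key_channels):
        # one full AttentionLayer per head
        self.heads = nn.ModuleList([AttentionLayer(in_channels, out_channels, key_channels)
                                    for _ in range(n_heads)])
        # linear projection (conv with ks=1) back to out_channels
        self.proj = nn.Conv1d(n_heads * out_channels, out_channels, kernel_size=1, bias=False)

    def forward(self, x):                                        # [bs x in_channels x seq_len]
        heads = torch.cat([head(x) for head in self.heads], 1)   # [bs x n_heads*out_channels x seq_len]
        return self.proj(heads)                                  # [bs x out_channels x seq_len]

mha = MultiHeadAttention(n_heads=4, in_channels=64, out_channels=64, key_channels=96)
mha(torch.rand(2, 64, 100)).shape                                # torch.Size([2, 64, 100])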

There is however one concept which we haven't covered, namely positional encoding. Transformers have mainly been used for nlp tasks so far. And when our input data is text, matters order (omg I'm clever). Since self-attention is fundamentally permutation invariant, it will struggle to cope if our input data is ordered. In the next post we'll modify our example data to include ordering, and see if we can make a model that solves this harder task too - spoiler: we can!
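As a quick illustration of what permutation invariance means in practice (a small sketch with an untrained AttentionLayer): if we shuffle the positions of the input sequence, the output is simply the original output shuffled the same way - the layer itself has no notion of where in the sequence a value came from.

layer = AttentionLayer(in_channels=4, out_channels=4, key_channels=8)
x = torch.rand(1, 4, 50)
perm = torch.randperm(50)                               # a random reshuffling of the 50 positions

out = layer(x)
out_perm = layer(x[:, :, perm])                         # same layer, shuffled input
torch.allclose(out_perm, out[:, :, perm], atol=1e-5)    # True - the output is just shuffled along with the input

Giving the model some notion of position is exactly what positional encoding is for.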