from fastai2.text.all import *

ULMFiT

In the previous post we explored the NoReC Norwegian language corpus. We grabbed the reviews of films and TV shows, parsed the HTML text and created labels based on the ratings. In the next few posts I want to use ULMFiT and other methods to predict the sentiment of the reviews based on their text.
ULMFiT has three main steps (sketched in code after the list):

  1. Train a language model on a large general purpose corpus such as Wikipedia
  2. Fine-tune the language model on the text you are working with - its style is most likely different from that of a Wikipedia article
  3. Combine the encoder of the fine-tuned language model with a linear classifier to predict the class of your text
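
To make these steps concrete, below is a minimal sketch of the full workflow in fastai. It is only meant as an illustration: it assumes a dataframe df with 'text' and 'label' columns, and the epoch counts and encoder file name are arbitrary; step 1 is covered by the pretrained weights that fastai downloads for us.

# Sketch of the ULMFiT workflow (illustrative only).
# Assumes a dataframe `df` with 'text' and 'label' columns.
dls_lm = TextDataLoaders.from_df(df, text_col='text', is_lm=True, valid_pct=0.1)
learn_lm = language_model_learner(dls_lm, AWD_LSTM, metrics=accuracy)
learn_lm.fine_tune(4)                        # step 2: fine-tune the LM on our reviews
learn_lm.save_encoder('finetuned_encoder')   # keep the RNN part (the "encoder")

dls_clf = TextDataLoaders.from_df(df, text_col='text', label_col='label',
                                  text_vocab=dls_lm.vocab, valid_pct=0.1)
learn_clf = text_classifier_learner(dls_clf, AWD_LSTM, metrics=accuracy)
learn_clf.load_encoder('finetuned_encoder')  # step 3: reuse the encoder in a classifier
learn_clf.fine_tune(4)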

The core of the ULMFiT method is a type of recurrent neural network (RNN) called AWD-LSTM, an LSTM with carefully tuned dropout (and other regularization) settings. We need to look into this architecture before we continue with our modeling. For an explanation of what an LSTM actually is I suggest checking out this blog post by Chris Olah. In general, most of Chris' posts and papers are worth reading!

How to set up an AWD-LSTM with fastai

Let's start by inspecting fastai's language_model_learner. It's a Learner class designed for language models, and it holds both the dataloaders and the architecture, along with various hyperparameters. We can use the doc() function to show us the documentation:

doc(language_model_learner)

The documentation tells us that we can pass arch=AWD_LSTM and modify the awd_lstm_lm_config dictionary to customize the architecture. The config dictionary specifies various hyperparameters and settings inspired by the AWD-LSTM paper. By changing this dictionary we can customize our AWD-LSTM to fit our specific needs:

awd_lstm_lm_config
{'emb_sz': 400,
 'n_hid': 1152,
 'n_layers': 3,
 'pad_token': 1,
 'bidir': False,
 'output_p': 0.1,
 'hidden_p': 0.15,
 'input_p': 0.25,
 'embed_p': 0.02,
 'weight_p': 0.2,
 'tie_weights': True,
 'out_bias': True}
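
If we want a different architecture, one option (a sketch, with made-up values) is to copy this dictionary, change a few entries and pass it to language_model_learner via its config argument:

my_config = awd_lstm_lm_config.copy()
my_config.update({'n_hid': 575, 'n_layers': 2})  # hypothetical values, just to show the mechanism
# learn = language_model_learner(dls, AWD_LSTM, config=my_config)  # dls: a text DataLoaders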

Let's check the documentation and source code of the AWD_LSTM class. You can inspect the source code directly in the notebook by appending ?? to the name:

AWD_LSTM??

The source code shows us a few interesting lines we'll look more into in the next few sections:

  1. self.encoder = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_token)
  2. self.rnns = nn.ModuleList([self._one_rnn(emb_sz if l == 0 else n_hid, (n_hid if l != n_layers - 1 else emb_sz)//self.n_dir, bidir, weight_p, l) for l in range(n_layers)])
  3. self.input_dp = RNNDropout(input_p)
  4. self.hidden_dps = nn.ModuleList([RNNDropout(hidden_p) for l in range(n_layers)])

Note: the embedding is called encoder in the code above. The name encoder is also fastai lingo for the entire RNN part of the architecture, while the linear layers added on top for the classifier are called the decoder. Neither the ULMFiT nor the AWD-LSTM paper uses the terms encoder and decoder, though.

But what is an embedding?

Once again I'll be lazy and rather refer to another blog that explains embeddings in detail. The blog is by Jay Alammar and has explanations of many deep learning and NLP concepts. The essence is that we'll turn each token in our vocabulary into a vector of some size that represents various aspects of that token. The weights of this vector are trainable and give our neural network a lot of flexibility in assigning various properties to each token.

The embedding is created by self.encoder = nn.Embedding(vocab_sz, emb_sz, padding_idx=pad_token). Here we see that fastai is built on top of PyTorch and relies on PyTorch's fundamental modules in its own code. The encoder layer is a call to nn.Embedding (see the documentation). Let's create an embedding of size 10x3 with padding_idx=0:

embedding = nn.Embedding(num_embeddings=10, embedding_dim=3, padding_idx=0)
embedding.weight
Parameter containing:
tensor([[ 0.0000,  0.0000,  0.0000],
        [ 1.6721, -1.3130,  0.6414],
        [ 1.1675,  0.1174,  1.8511],
        [-0.3341, -1.0047, -0.8467],
        [-0.7737, -0.3947, -1.5273],
        [-1.1472, -0.0429, -0.0994],
        [-1.0594,  1.3725,  0.3796],
        [ 0.1682,  0.7212,  0.9494],
        [ 1.2791,  0.1334, -0.5075],
        [ 0.4486,  0.4936,  0.2588]], requires_grad=True)

The embedding now has 10 vectors of length 3 with randomly initialized weights. Note that the first one (index 0) is all 0s. This is because 0 is our padding index. Next we'll pass some sample data inp and inspect the result. We can think of our input as the indices of words in a dictionary, e.g. 1='this', 7='is', 4='not' and 3='easy'. 0 will be our padding token. The padding token is a special token used to pad text to a certain length, which is useful when stacking several pieces of text into a batch where sizes need to match.

inp = torch.LongTensor([1,7,4,3,0,0])
emb = embedding(inp)
emb
tensor([[ 1.6721, -1.3130,  0.6414],
        [ 0.1682,  0.7212,  0.9494],
        [-0.7737, -0.3947, -1.5273],
        [-0.3341, -1.0047, -0.8467],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000]], grad_fn=<EmbeddingBackward>)

We see that the embedding produced by feeding the input corresponds to the rows of our original embedding's weight matrix. That is, index 1 of inp holds the value 7, so emb[1] is basically a lookup of embedding.weight[7].

embedding.weight[7], emb[1]
(tensor([0.1682, 0.7212, 0.9494], grad_fn=<SelectBackward>),
 tensor([0.1682, 0.7212, 0.9494], grad_fn=<SelectBackward>))

To summarize: we'll need an embedding with num_embeddings equal to our vocabulary size, an embedding size of 400, and a padding token id corresponding to whichever token is used for padding in our vocabulary.
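
In code, that boils down to something like this (a sketch; the vocabulary size here is made up, and fastai builds this layer for us):

vocab_sz = 7080  # hypothetical vocabulary size
emb_layer = nn.Embedding(num_embeddings=vocab_sz, embedding_dim=400, padding_idx=1)  # 1 = 'xxpad' in fastai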

Composition of the RNN-layers

Second, we create a list of RNN-layers with varying dimensions:

self.rnns = nn.ModuleList([self._one_rnn(emb_sz if l == 0 else n_hid, 
                         (n_hid if l != n_layers - 1 else emb_sz)//self.n_dir, bidir, weight_p, l) for l in range(n_layers)])

The code stacks RNN-layers with dimensions embedding size x hidden size for the first layer, hidden size x hidden size for the middle layers, and hidden size x embedding size for the final layer. Let's verify that this works for various numbers of layers:

AWD_LSTM(vocab_sz=10_000, emb_sz=400, n_hid=1152, n_layers=2)
AWD_LSTM(
  (encoder): Embedding(10000, 400, padding_idx=1)
  (encoder_dp): EmbeddingDropout(
    (emb): Embedding(10000, 400, padding_idx=1)
  )
  (rnns): ModuleList(
    (0): WeightDropout(
      (module): LSTM(400, 1152, batch_first=True)
    )
    (1): WeightDropout(
      (module): LSTM(1152, 400, batch_first=True)
    )
  )
  (input_dp): RNNDropout()
  (hidden_dps): ModuleList(
    (0): RNNDropout()
    (1): RNNDropout()
  )
)
AWD_LSTM(vocab_sz=10_000, emb_sz=400, n_hid=1152, n_layers=5)
AWD_LSTM(
  (encoder): Embedding(10000, 400, padding_idx=1)
  (encoder_dp): EmbeddingDropout(
    (emb): Embedding(10000, 400, padding_idx=1)
  )
  (rnns): ModuleList(
    (0): WeightDropout(
      (module): LSTM(400, 1152, batch_first=True)
    )
    (1): WeightDropout(
      (module): LSTM(1152, 1152, batch_first=True)
    )
    (2): WeightDropout(
      (module): LSTM(1152, 1152, batch_first=True)
    )
    (3): WeightDropout(
      (module): LSTM(1152, 1152, batch_first=True)
    )
    (4): WeightDropout(
      (module): LSTM(1152, 400, batch_first=True)
    )
  )
  (input_dp): RNNDropout()
  (hidden_dps): ModuleList(
    (0): RNNDropout()
    (1): RNNDropout()
    (2): RNNDropout()
    (3): RNNDropout()
    (4): RNNDropout()
  )
)

We see that the first and final layers have the same dimensions in both examples.

To summarize: we'll use a 3-layer network with (input, output) dimensions of (400, 1152), (1152, 1152) and (1152, 400), as in the AWD-LSTM paper. This is handled automatically by the library.
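
A quick check of the default 3-layer setup confirms these dimensions (WeightDropout stores the wrapped LSTM in its .module attribute, as we'll see below):

rnns = AWD_LSTM(vocab_sz=10_000, emb_sz=400, n_hid=1152, n_layers=3).rnns
[(rnn.module.input_size, rnn.module.hidden_size) for rnn in rnns]
# expected: [(400, 1152), (1152, 1152), (1152, 400)]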

nn.LSTM

In the module list above, the layers are actually WeightDropout layers. We can verify this from the helper method _one_rnn that is called when the RNNs are created: a regular nn.LSTM layer is created first and then wrapped in a WeightDropout module.

def _one_rnn(self, n_in, n_out, bidir, weight_p, l):
    "Return one of the inner rnn"
    rnn = nn.LSTM(n_in, n_out, 1, batch_first=True, bidirectional=bidir)
    return WeightDropout(rnn, weight_p)

Let's have a look at an example from the nn.LSTM documentation (also see the source code here). We'll make a 1-layer LSTM with an input size of 10 and a hidden size of 20. Note that in the AWD-LSTM case the input size of the first layer equals the embedding size (400 by default).

inp_s = 10 # input size
hid_s = 20 # hidden size
lstm = nn.LSTM(input_size = inp_s, hidden_size = hid_s, num_layers=1)

The documentation states that the LSTM expects input of shape (seq_len, batch, input_size). seq_len is the length of the chunk of text the model sees in each iteration (seq_len=72 by default in fastai's language_model_learner, i.e. 72 tokens). The batch size is the number of sequences the model sees in each iteration.

h0 and c0 are the initial hidden and cell states (set to 0 if not provided). The documentation specifies their shapes as (num_layers * num_directions, batch, hidden_size).

bs = 16
n_l, n_d = 1, 1 # we are testing a 1 layer and 1 direction lstm
seq_len = 5
inp = torch.randn(seq_len, bs, inp_s)
h0 = torch.randn(n_l*n_d, bs, hid_s)
c0 = torch.randn(n_l*n_d, bs, hid_s)
inp.shape, h0.shape, c0.shape
(torch.Size([5, 16, 10]), torch.Size([1, 16, 20]), torch.Size([1, 16, 20]))

The output from the LSTM should be a tuple output, (h_n, c_n), where output has shape (seq_len, batch, num_directions * hidden_size):

out, (hn, cn) = lstm(inp)
out.shape
torch.Size([5, 16, 20])

Let's also check the actual shape of our weights by looping through the state_dict():

[(key, lstm.state_dict()[key].shape) for key in lstm.state_dict().keys()]
[('weight_ih_l0', torch.Size([80, 10])),
 ('weight_hh_l0', torch.Size([80, 20])),
 ('bias_ih_l0', torch.Size([80])),
 ('bias_hh_l0', torch.Size([80]))]

Here we recognize the input size of 10 and the hidden size of 20, but where does the 80 come from? The documentation specifies that the weights have dimensions (4*hidden_size, input_size). The factor of 4 corresponds to the LSTM's four gates, and the resulting dimension is called gate_size in the source code of the base RNN module:

if mode == 'LSTM':
    gate_size = 4 * hidden_size

To summarize: we expect two sets of weights and biases per LSTM layer (checked in code after the list):

  • weight_ih_l0 with a shape of (4*hidden_size, input_size)
  • weight_hh_l0 with a shape of (4*hidden_size, hidden_size)
  • bias_ih_l0 with a shape of (4*hidden_size)
  • bias_hh_l0 with a shape of (4*hidden_size)
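
A quick sanity check on the lstm from the previous section (the shapes are the ones we just listed):

assert lstm.weight_ih_l0.shape == (4 * hid_s, inp_s)  # (80, 10)
assert lstm.weight_hh_l0.shape == (4 * hid_s, hid_s)  # (80, 20)
assert lstm.bias_ih_l0.shape == (4 * hid_s,)          # (80,)
assert lstm.bias_hh_l0.shape == (4 * hid_s,)          # (80,)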

WeightDropout

We can see from _one_rnn that the nn.LSTM is wrapped in a WeightDropout module. The documentation describes the module as one that wraps another layer and replaces some of its weights with 0 during training. From the source code we can see that it's the weight_hh_l0 weights that get modified, and that these weights are duplicated with the suffix '_raw' in the WeightDropout module: self.register_parameter(f'{layer}_raw', nn.Parameter(w.data)).

Let's verify this by first checking the weights of the lstm from the section above. Of the 80*20 = 1600 weights in the weight_hh_l0 layer, none are 0:

orig_wts = getattr(lstm, 'weight_hh_l0')
orig_wts.shape, (orig_wts == 0.).sum()
(torch.Size([80, 20]), tensor(0))

But if we pass the lstm through the WeightDropout module, approximately half of the 1600 weights are set to 0. Note that we have to call the model on some input, since the weights are only zeroed during the forward pass. The weights are also only zeroed for WeightDropout's internal LSTM module, while a copy with the suffix '_raw' retains the original weights.

wd = WeightDropout(lstm, weight_p=0.5)
_,_ = wd(inp) # we don't need the output in this case

The .module attribute of the wd object is our original LSTM:

wd.module
LSTM(10, 20)

And about half its weights have been set to 0:

(getattr(wd.module, 'weight_hh_l0')==0.0).sum()
tensor(826)

The original weights from the lstm match the '_raw' weights of the WeightDropout module:

test_eq(orig_wts, getattr(wd, 'weight_hh_l0_raw'))

The new parameter names are as expected:

wd.state_dict().keys()
odict_keys(['weight_hh_l0_raw', 'module.weight_ih_l0', 'module.bias_ih_l0', 'module.bias_hh_l0'])

RNN dropout

Finally, several RNNDropout layers are created - one for the input and one for each LSTM layer. This dropout is applied to the input embedding and to the outputs of the LSTM layers. We can test the functionality with inp from the section above.

dp = RNNDropout(0.5)
dp_out = dp(inp)
inp.shape, dp_out.shape
(torch.Size([5, 16, 10]), torch.Size([5, 16, 10]))

The documentation also says: 'Dropout with probability p that is consistent on the seq_len dimension.' In our input from the section above, seq_len is the first dimension (index 0). If we check for elements equal to 0 and sum along the second dimension (index 1), we see that each position is either kept or dropped for the entire batch (our sample batch size is 16), with roughly half of the positions dropped.

(dp_out == 0).sum((1))
tensor([[ 0,  0, 16, 16, 16, 16,  0,  0, 16,  0],
        [ 0, 16, 16, 16, 16,  0,  0, 16, 16,  0],
        [ 0,  0,  0, 16,  0,  0,  0,  0, 16, 16],
        [ 0,  0, 16, 16,  0, 16,  0, 16,  0, 16],
        [ 0, 16,  0, 16,  0, 16, 16,  0, 16, 16]])
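
As with any dropout, this only happens during training. A small check (a sketch) shows that in eval mode RNNDropout passes its input through unchanged:

dp.eval()              # switch to evaluation mode
test_eq(dp(inp), inp)  # nothing is zeroed out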

IMDb inspection

Let's take a look at a minimal IMDB example from the fastai documentation to verify our understanding of the AWD-LSTM architecture.

imdb_path = untar_data(URLs.IMDB_SAMPLE)
df = pd.read_csv(imdb_path/'texts.csv')
dls = TextDataLoaders.from_df(df, path=imdb_path, text_col='text', is_lm=True, valid_col='is_valid')
learn = language_model_learner(dls, AWD_LSTM)

The vocab is of length 7080 and vocab index 1 is 'xxpad':

dls.vocab[:5], len(dls.vocab)
(['xxunk', 'xxpad', 'xxbos', 'xxeos', 'xxfld'], 7080)

In our model we recognize Embedding(7080, 400, padding_idx=1) as vocabulary size x embedding size with the correct padding token. We also see that the (input, output) dimensions of our LSTM layers are as expected, and that the expected dropout layers have been added.

learn.model
SequentialRNN(
  (0): AWD_LSTM(
    (encoder): Embedding(7080, 400, padding_idx=1)
    (encoder_dp): EmbeddingDropout(
      (emb): Embedding(7080, 400, padding_idx=1)
    )
    (rnns): ModuleList(
      (0): WeightDropout(
        (module): LSTM(400, 1152, batch_first=True)
      )
      (1): WeightDropout(
        (module): LSTM(1152, 1152, batch_first=True)
      )
      (2): WeightDropout(
        (module): LSTM(1152, 400, batch_first=True)
      )
    )
    (input_dp): RNNDropout()
    (hidden_dps): ModuleList(
      (0): RNNDropout()
      (1): RNNDropout()
      (2): RNNDropout()
    )
  )
  (1): LinearDecoder(
    (decoder): Linear(in_features=400, out_features=7080, bias=True)
    (output_dp): RNNDropout()
  )
)

The model summary shows us the default batch size of 64 and seq_len of 72.

learn.summary()
SequentialRNN (Input shape: ['64 x 72'])
================================================================
Layer (type)         Output Shape         Param #    Trainable 
================================================================
RNNDropout           64 x 72 x 400        0          False     
________________________________________________________________
RNNDropout           64 x 72 x 1152       0          False     
________________________________________________________________
RNNDropout           64 x 72 x 1152       0          False     
________________________________________________________________
Linear               64 x 72 x 7080       2,839,080  True      
________________________________________________________________
RNNDropout           64 x 72 x 400        0          False     
________________________________________________________________

Total params: 2,839,080
Total trainable params: 2,839,080
Total non-trainable params: 0

Optimizer used: <function Adam at 0x7fa03d7cbdd0>
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group number 3

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
  - ModelReseter
  - RNNRegularizer

And finally, the layer names and shapes are also consistent with a gate factor of 4 (1152*4 = 4608, and 400*4 = 1600 for the last LSTM). Note the enumeration of the layers: 0 is the encoder part of the architecture (including the embedding, called encoder) and 1 is the decoder.

for key in learn.model.state_dict().keys():
    print(key, '\t', learn.model.state_dict()[key].shape)
0.encoder.weight 	 torch.Size([7080, 400])
0.encoder_dp.emb.weight 	 torch.Size([7080, 400])
0.rnns.0.weight_hh_l0_raw 	 torch.Size([4608, 1152])
0.rnns.0.module.weight_ih_l0 	 torch.Size([4608, 400])
0.rnns.0.module.bias_ih_l0 	 torch.Size([4608])
0.rnns.0.module.bias_hh_l0 	 torch.Size([4608])
0.rnns.1.weight_hh_l0_raw 	 torch.Size([4608, 1152])
0.rnns.1.module.weight_ih_l0 	 torch.Size([4608, 1152])
0.rnns.1.module.bias_ih_l0 	 torch.Size([4608])
0.rnns.1.module.bias_hh_l0 	 torch.Size([4608])
0.rnns.2.weight_hh_l0_raw 	 torch.Size([1600, 400])
0.rnns.2.module.weight_ih_l0 	 torch.Size([1600, 1152])
0.rnns.2.module.bias_ih_l0 	 torch.Size([1600])
0.rnns.2.module.bias_hh_l0 	 torch.Size([1600])
1.decoder.weight 	 torch.Size([7080, 400])
1.decoder.bias 	 torch.Size([7080])