Fine-tuning a Norwegian language model for ULMFiT
In this post we'll load a pretrained language model and fine-tune it on our target corpus of TV and movie reviews.
from fastai2.text.all import *
In a previous post we explored the Norec Norwegian-language corpus of film and TV reviews. In this post I want to use the ULMFiT method to predict the sentiment of the reviews based on the text. ULMFiT has three main steps (sketched in code right after this list):
- Train a language model on a large general purpose corpus such as Wikipedia
- Fine-tune the language model on the text you are working with - its style is most likely different from a Wikipedia article
- Use the encoder of the fine-tuned language model to transform text into feature vectors, and finally add a linear classifier on top to predict the class of the review.
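To make the plan concrete, here is a rough sketch of how these three steps map onto fastai calls, assuming a dataframe df with the review texts. It's illustrative only - the fine_tune schedules and the 'label' column are placeholders, and the actual calls we need are developed step by step below.
# Steps 1 & 2 (this post): start from pretrained weights and fine-tune the language model.
dls_lm = TextDataLoaders.from_df(df, text_col='text', is_lm=True)
learn_lm = language_model_learner(dls_lm, AWD_LSTM)
learn_lm.fine_tune(1)
learn_lm.save_encoder('finetuned_encoder')
# Step 3 (next post): reuse the fine-tuned encoder in a text classifier.
# 'label' is a placeholder for whatever sentiment column we end up using.
dls_clas = TextDataLoaders.from_df(df, text_col='text', label_col='label')
learn_clas = text_classifier_learner(dls_clas, AWD_LSTM)
learn_clas.load_encoder('finetuned_encoder')
learn_clas.fine_tune(1)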
In this post we'll focus on steps 1 and 2, the language model and the fine-tuning. Training a language model from scratch is a bit of work: first you have to get the data to train it, and the training will also take a long time. Luckily, the fast.ai language model zoo already lists a pretrained language model for Norwegian. Note that this is a ULMFiT model zoo, so we expect to find weights for an AWD-LSTM. See this post to better understand how to customize an AWD-LSTM with fastai.
Let's first grab the dataset from a previous post. It's available as a csv from github:
df = pd.read_csv('https://raw.githubusercontent.com/hallvagi/dl-explorer/master/uploads/norec.csv')
df.head(3)
The pretrained weights we want to use are located in this repo. There is some information listed there:
- The weights were trained on 90% of all text in the corresponding language wikipedia as per 3. July 2018. The remaining 10% was used for validation.
- Norwegian: Trained on 80,284,231 tokens, and validated on 8,920,387 tokens. We achieve a perplexity of 26.31
And file descriptions:
- enc.h5 Contains the weights in 'Hierarchical Data Format'
- enc.pth Contains the weights in 'Pytorch model format'
- itos.pkl (Integers to Strings) contains the vocabulary mapping from ids (0 - 30000) to strings
It looks like we will need the enc.pth (fastai is built on top of PyTorch) and the vocabulary (itos.pkl). But how do we actually load the model? The repo doesn't really specify this part, so let's see if we can figure it out. First we'll download and extract the data to a desired location and have a look at the files:
path = Path('~/.fastai/data/norec') # choose a path to your liking!
os.makedirs(path/'models', exist_ok=True)
model_url = 'https://www.dropbox.com/s/lwr5kvbxri1gvv9/norwegian.zip'
!wget {model_url} -O {path/'models/norwegian.zip'} -q
!unzip -q {path/'models/norwegian.zip'} -d {path/'models'}
Path.BASE_PATH = path # paths are printed relative to the BASE_PATH
(path/'models').ls()
The first file we want to check out is norwegian_itos.pkl. This is the vocabulary of the model, that is, the words and tokens it's able to recognize. itos means integer-to-string: the index of a particular token in the list is the integer id of that token.
with open(path/'models/norwegian_itos.pkl', 'rb') as f:
    itos = pickle.load(f)
itos[:10], itos[-5:], len(itos)
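As a quick illustration of how the mapping works, we can build the reverse string-to-integer lookup ourselves. This stoi dict is just a throwaway sketch and isn't used later:
stoi = {tok: i for i, tok in enumerate(itos)}  # reverse mapping: token -> id
idx = stoi['og']  # id of the common word 'og' (and)
idx, itos[idx]    # round trip: the id maps back to the same token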
The very first token, i.e. index 0, is _unk_ or unknown. The other tokens in the first part of the list are common words such as 'i' (in) and 'og' (and). Among the final tokens there are even some English words. This is not really surprising since Norwegian has "borrowed" several words from English. It seems, however, that the special tokens for unknown and padding (_unk_ and _pad_) are different than the fastai defaults:
print(defaults.text_spec_tok)
Will this cause issues later?
Next, let's have a look at the weights. We'll load them with PyTorch.
enc = torch.load(path/'models/norwegian_enc.pth')
enc.keys()
It's a dictionary where the keys are the names of the layers of the model. We can see that it has an embedding layer (named 'encoder') and three RNNs (more precisely, LSTMs) with various descriptions. We recognize the three-layer LSTM from the AWD-LSTM and the ULMFiT paper, so we must make sure that the model we set up matches this. Let's have a look at the dimensions of the weights.
for k,v in enc.items():
    print(k," \t", v.shape)
We notice that the hidden size is different than the fastai default of 1152, but apart from that everything looks fine. Let's save a few weights from the embedding layer to compare with our final model.
sample_weights = enc['encoder.weight'][0][:5]
sample_weights
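The hidden size can also be read directly off the LSTM weight shapes: the hidden-to-hidden matrix of an LSTM has shape (4 * hidden_size, hidden_size). Here is a small sketch, assuming the relevant key names contain the usual 'weight_hh' substring:
for k, v in enc.items():
    if 'weight_hh' in k:  # hidden-to-hidden weights of each LSTM layer
        print(k, '-> hidden size:', v.shape[1])
# The last layer of an AWD-LSTM projects back down to the embedding size,
# so we expect the final value to match emb_sz (400 for these weights).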
First we have to make sure that our data loader uses our custom vocabulary instead of doing tokenization on its own, so we pass text_vocab=itos. We also set is_lm=True since we want a language model and not a classifier. We use the basic factory method, since we have no need for customization at this point.
dls_lm = TextDataLoaders.from_df(df, text_col='text', text_vocab=itos, is_lm=True, valid_pct=0.1)
dls_lm.show_batch(max_n=3)
This looks pretty good. We can recognize our _unk_ token, for example. We also see that the label column, that is, the "text" column, is offset by one token from the input. This makes sense since the goal of a language model is to predict the next word in a sequence.
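To see why this offset makes sense, here is a toy illustration of the language-model objective in plain Python (the example sentence is made up): the target sequence is simply the input shifted by one token.
tokens = ['denne', 'filmen', 'er', 'veldig', 'bra']  # "this film is very good"
inp, targ = tokens[:-1], tokens[1:]
list(zip(inp, targ))  # each input token is paired with the token that follows it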
The next step is to configure the AWD-LSTM architecture. Let's have a look at the default config:
awd_lstm_lm_config
Most of these look good, but we will change n_hid to 1150. Note also pad_token=1: this is the index of the padding token, and from our itos above we see that itos[1] = '_pad_'.
awd_lstm_lm_config['n_hid'] = 1150
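As a quick sanity check (just a sketch), we can confirm that the config's padding index points at the padding entry of our vocabulary:
assert itos[awd_lstm_lm_config['pad_token']] == '_pad_'  # index 1 in this vocab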
Now we can pass the config to our learner object. Notice that we set pretrained=False since we want to load our own weights. The final .to_fp16() means that the model is trained with mixed precision (16-bit floating point), which can often speed up training quite a bit.
learn_lm = language_model_learner(dls_lm,
arch=AWD_LSTM,
metrics=[accuracy, Perplexity()],
path=path,
config=awd_lstm_lm_config,
pretrained=False).to_fp16()
The model summary now looks correct:
learn_lm.model
The weights of our model have been initialized randomly, so they should not match at the moment. Let's compare our sample weights from the above section with those from our language model:
learn_lm.model.state_dict()['0.encoder.weight'][0][:5].cpu(), sample_weights
But now we should be able to load the encoder:
learn_lm.load_encoder(path/'models/norwegian_enc')
It worked! We can also see that the weights match:
learn_lm.model.state_dict()['0.encoder.weight'][0][:5].cpu(), sample_weights
But are we able to predict any useful text?
learn_lm.predict('Hovedstaden i Norge er') # the capital of Norway is
What is the problem now? We see that predict() by default has no_unk=True. The error message tells us that the library tries to get the index of the UNK token. As we noted earlier, the UNK token in our itos (vocabulary) is different from what the library expects:
UNK, itos[0]
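We can reproduce the failing lookup directly with a minimal sketch (not the exact library code): fastai's default UNK token is 'xxunk', which simply isn't present in our vocabulary.
try:
    itos.index(UNK)  # UNK == 'xxunk' in the fastai defaults
except ValueError as e:
    print('lookup failed:', e)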
This is not really a problem for our model. The model only sees the underlying numbers and indexes, and they are still correct. But if we want to use most of the convenience functions of the fastai2 library, we either have to customise the code or, simpler still, change the vocab.
So let's look through our itos and see if we can find any special tokens:
print(defaults.text_spec_tok) # fastai defaults
Let's look for tokens that contain an underscore _:
_toks = [token for token in itos if '_' in token]
_toks[:5]
And then for tokens that begin with an x. We use a simple regex to check for an x at the beginning of the token.
x_toks = [token for token in itos if re.match(r'^x', token) != None]
x_toks[:5]
'unk', 'pad', 'xfld' and 'xbos' seem pretty obvious, but I'm less sure about e.g. 't_up' and 'tk_rep'. So we replace a bit conservatively:
to_replace = _toks[:2]+x_toks[:2]
to_replace
replace_with = defaults.text_spec_tok[:2]+defaults.text_spec_tok[7:8]+defaults.text_spec_tok[2:3]
replace_with
Then we loop through our itos and replace the selected tokens:
for tok_remove, tok_insert in zip(to_replace, replace_with):
    idx = itos.index(tok_remove)
    itos[idx] = tok_insert
To verify that we did things correctly:
idxs = [itos.index(token) for token in replace_with]
idxs
[itos[idx] for idx in idxs]
Let's make yet another version of our dataloader and language learner.
dls_lm = TextDataLoaders.from_df(df, text_col='text', text_vocab=itos, is_lm=True, valid_pct=0.1)
learn_lm = language_model_learner(dls_lm, arch=AWD_LSTM, metrics=[accuracy, Perplexity()],
path=path, config=awd_lstm_lm_config, pretrained=False).to_fp16()
learn_lm.load_encoder(path/'models/norwegian_enc')
TEXT = "Hovedstaden i Norge er" # the capital of norway is
preds = [learn_lm.predict(TEXT, 40, temperature=0.75) for _ in range(3)]
preds
Well, that kind of makes sense. LSTMs are less capable of generating text than the more complex transformer architectures, but our concern in this particular case is how well we eventually do at sentiment classification.
Now we are finally ready to fine-tune our language model. We'll use the "standard" training regime from the documentation and the fast.ai courses: 1 epoch where only the linear layers in the head of the model are trainable, followed by 10 epochs with all layers unfrozen.
learn_lm.lr_find()
A learning rate of 1e-2 seems to be a safe choice:
learn_lm.fit_one_cycle(1, 1e-2)
Next we will unfreeze all layers and train for 10 epochs at a reduced learning rate. The idea is that once we unfreeze the lower layers of our model we should make smaller changes to avoid catastrophic forgetting. We will go for a learning rate of 1e-3. One can always test whether longer training improves results, but in this case we will simply assume that 10 epochs is good enough.
learn_lm.unfreeze()
learn_lm.lr_find()
learn_lm.fit_one_cycle(10, 1e-3)
We will save the model and the encoder for future use. The save() method saves objects to path/'models' by default. The encoder will be used in our classifier in the next post.
learn_lm.save('finetuned_model')
learn_lm.save_encoder('finetuned_encoder')
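A quick check that the files ended up in the models folder as expected:
(path/'models').ls()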
Is the model any good at predicting text?
TEXT = "Denne filmen er et godt eksempel på" # this film is a good example of
preds = [learn_lm.predict(TEXT, 40, temperature=0.75) for _ in range(3)]
preds
This is by no means as impressive as the recent transformer models, but the model certainly understands language fairly well. Also, our particular use case isn't really text generation, but sentiment classification, and according to Papers with Code, transformers do only marginally better than ULMFiT on the comparable IMDB classification task. Classification will be the topic of an upcoming post.