from fastai2.text.all import *

Load vocabulary and model weights

In the previous post we took a pretrained language model and fine-tuned it on our movie review dataset. In this post we'll try to make sense of the weights the model learned. Specifically, we'll look at the weights of the initial embedding layer. This is the first layer of the model, and we would expect its weights to reflect patterns in the language.

To inspect the weights we don't need to load the dataset or create a learner object; we can simply load the saved weights directly. We will also need the model's vocabulary, the itos (index-to-string) list, to map each row of weights to the token it belongs to.

path = Path('/data/hgi/.fastai/data/norec/')
Path.BASE_PATH = path
(path/'models').ls()
(#7) [Path('models/finetuned_model.pth'),Path('models/norwegian_wgts.h5'),Path('models/norwegian_enc.pth'),Path('models/finetuned_encoder.pth'),Path('models/norwegian.zip'),Path('models/norwegian_enc.h5'),Path('models/norwegian_itos.pkl')]

First we'll load the vocabulary of our model, the norwegian_itos.pkl:

with open(path/'models/norwegian_itos.pkl', 'rb') as f:
    itos = pickle.load(f)
len(itos)
30002
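The itos list maps an integer index to a token. For lookups in the other direction it is handy to build the reverse mapping, often called stoi. A minimal sketch with a made-up toy vocabulary, assuming (as in this vocabulary) that index 0 is the unknown token `_unk_`; the `numericalize` helper is hypothetical, written only for illustration:

```python
# Toy vocabulary standing in for the real 30002-token itos list (hypothetical).
itos = ['_unk_', '_pad_', '.', 'i', ',', 'god', 'dårlig', 'film']

# Reverse mapping: token -> index.
stoi = {tok: idx for idx, tok in enumerate(itos)}

def numericalize(tokens, stoi, unk_idx=0):
    """Map tokens to ids; out-of-vocabulary tokens go to unk_idx (assumed to be '_unk_')."""
    return [stoi.get(tok, unk_idx) for tok in tokens]

ids = numericalize(['god', 'film', 'kino'], stoi)
print(ids)  # [5, 7, 0]
```

With the real vocabulary, `stoi['god']` gives the row of the embedding matrix that belongs to 'god'.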

And then the weights of the finetuned model:

mod = torch.load(path/'models/finetuned_model.pth')
mod.keys()
dict_keys(['model', 'opt'])

Let's check the model part of the dictionary:

[f'{k:30}{v.shape}' for k,v in mod['model'].items()]
['0.encoder.weight              torch.Size([30002, 400])',
 '0.encoder_dp.emb.weight       torch.Size([30002, 400])',
 '0.rnns.0.weight_hh_l0_raw     torch.Size([4600, 1150])',
 '0.rnns.0.module.weight_ih_l0  torch.Size([4600, 400])',
 '0.rnns.0.module.bias_ih_l0    torch.Size([4600])',
 '0.rnns.0.module.bias_hh_l0    torch.Size([4600])',
 '0.rnns.1.weight_hh_l0_raw     torch.Size([4600, 1150])',
 '0.rnns.1.module.weight_ih_l0  torch.Size([4600, 1150])',
 '0.rnns.1.module.bias_ih_l0    torch.Size([4600])',
 '0.rnns.1.module.bias_hh_l0    torch.Size([4600])',
 '0.rnns.2.weight_hh_l0_raw     torch.Size([1600, 400])',
 '0.rnns.2.module.weight_ih_l0  torch.Size([1600, 1150])',
 '0.rnns.2.module.bias_ih_l0    torch.Size([1600])',
 '0.rnns.2.module.bias_hh_l0    torch.Size([1600])',
 '1.decoder.weight              torch.Size([30002, 400])',
 '1.decoder.bias                torch.Size([30002])']

We want the 0.encoder.weight layer. Its shape of 30002 x 400 is vocabulary size x embedding size, so row i holds the embedding of token itos[i]. For the sake of simplicity we will combine the weights and the itos into a pandas dataframe with the token in the first column.

wts = pd.DataFrame(mod['model']['0.encoder.weight'])
wts.insert(0, 'token', itos) # create itos column as first column
print(wts.shape)
wts.head(3)
(30002, 401)
token 0 1 2 3 4 5 6 7 8 ... 390 391 392 393 394 395 396 397 398 399
0 _unk_ 0.532715 0.289551 0.000947 0.555176 0.645508 -0.158447 -0.063538 0.475830 -0.143799 ... 0.303711 -0.447510 0.070312 0.091370 0.530273 -0.200806 -0.341064 0.335205 0.305908 0.035522
1 _pad_ -0.155029 -0.045685 0.090393 -0.244507 -0.101379 -0.004860 -0.033386 -0.083130 0.025513 ... -0.103577 0.100159 -0.036377 0.010338 -0.089478 0.059662 0.187744 -0.114014 -0.085083 0.013466
2 . 0.362061 -0.982910 -0.106384 0.452637 0.752930 -0.117310 0.069397 0.145386 0.092346 ... 0.346191 -0.233398 0.073853 0.242065 0.050323 -0.436035 -0.372314 0.303223 0.360596 0.108032

3 rows × 401 columns

We now have a dataframe with one row per token in the vocabulary: a token column followed by 400 columns holding that token's embedding weights.
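Since row i of the encoder matrix belongs to itos[i], looking up a token's embedding is nothing more than indexing a row, which is also what an embedding layer does in its forward pass. A small NumPy sketch with a random stand-in matrix (hypothetical sizes); on the real data the matrix would be `wts.iloc[:, 1:].values`:

```python
import numpy as np

# Random stand-in for the 30002 x 400 matrix (hypothetical small sizes).
rng = np.random.default_rng(0)
vocab_size, emb_size = 8, 4
weights = rng.normal(size=(vocab_size, emb_size)).astype(np.float32)

token_ids = np.array([5, 7, 0])        # three arbitrary token ids
vectors = weights[token_ids]           # embedding lookup == row indexing
print(vectors.shape)                   # (3, 4)

# The same lookup written as one-hot vectors times the weight matrix.
one_hot = np.eye(vocab_size, dtype=np.float32)[token_ids]
print(np.allclose(one_hot @ weights, vectors))  # True
```

The one-hot formulation is mathematically equivalent but wasteful, which is why embedding layers index rows directly.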

Inspect the weights

The weights range from approximately -3.8 to 2.8. Note that .values returns the underlying numpy ndarray of our dataframe.

wts.iloc[:, 1:].values.min(), wts.iloc[:, 1:].values.max()
(-3.8261719, 2.7851562)

We can also plot all the weights as a histogram. Calling .flatten() on the array returned by .values turns the 30002 x 400 matrix into one long 1-D array, so we get a single histogram over all 30002*400 weights.

fig, ax = plt.subplots(1,1)
ax.hist(wts.iloc[:, 1:].values.flatten(), bins=100);
ax.set_xlabel('Embedding weight');

As expected, the weights are centered around 0, with a few extreme values.
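To put numbers on "centered around 0, with a few extreme values", one could compute a few summary statistics over all entries. A sketch on a random stand-in matrix; the threshold of 1.0 for an "extreme" weight is an arbitrary assumption, and on the real data you would pass `wts.iloc[:, 1:].values`:

```python
import numpy as np

def weight_summary(w, threshold=1.0):
    """Summary statistics over all entries of a 2-D weight matrix."""
    flat = np.asarray(w, dtype=np.float64).ravel()
    return {
        'mean': flat.mean(),
        'std': flat.std(),
        'p01': np.percentile(flat, 1),    # 1st percentile
        'p99': np.percentile(flat, 99),   # 99th percentile
        'frac_abs_gt_threshold': np.mean(np.abs(flat) > threshold),
    }

# Random stand-in data (hypothetical); not the model's weights.
rng = np.random.default_rng(42)
stats = weight_summary(rng.normal(0, 0.3, size=(1000, 400)))
for k, v in stats.items():
    print(f'{k:>22}: {v:.4f}')
```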

Good weights, bad weights

Let's look closer at a few selected tokens and see if we can make sense of their weights. We'll choose the words good ('god') and bad ('dårlig'). We would expect the two tokens to have some similarities, since both are adjectives used to judge something, but also to be opposites in some respects.

sample_tokens = ['god', 'dårlig'] # good, bad
sample = wts.loc[wts['token'].isin(sample_tokens),:]
sample
token 0 1 2 3 4 5 6 7 8 ... 390 391 392 393 394 395 396 397 398 399
590 god 0.410400 0.414062 -0.260498 0.235474 -0.162964 -0.142700 -0.017609 0.064209 0.533203 ... 0.276367 0.239380 -0.054108 0.024933 -0.006481 0.557617 -0.384033 0.313965 -0.033844 -0.141357
1015 dårlig 0.404541 0.183838 0.001801 0.097351 -0.022003 -0.049194 -0.019836 -0.060455 0.001980 ... -0.227661 0.582031 -0.130127 0.232788 0.145996 0.483887 0.069214 -0.012825 -0.185181 -0.145630

2 rows × 401 columns

Let's plot the first 30 weights of each token as lines. They seem to mostly follow each other:

fig, ax = plt.subplots(1,1, figsize = (12, 6))
for row in sample.values:
    # row[0] is the token, row[1:31] are its first 30 embedding weights
    ax.plot(row[1:31], label = row[0])
ax.legend();
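One way to put a number on "mostly follow each other" is to count how often the two vectors have the same sign in each dimension; for two independent random vectors with symmetric distributions you would expect about 0.5. A sketch using only the first four weights of each token, copied from the table above; on the real data you would pass `sample.iloc[0, 1:]` and `sample.iloc[1, 1:]` to use all 400 dimensions:

```python
import numpy as np

def sign_agreement(a, b):
    """Fraction of dimensions in which a and b have the same sign."""
    a, b = np.asarray(a, dtype=np.float64), np.asarray(b, dtype=np.float64)
    return float(np.mean(np.sign(a) == np.sign(b)))

# First four weights of 'god' and 'dårlig' from the table above.
god = [0.410400, 0.414062, -0.260498, 0.235474]
darlig = [0.404541, 0.183838, 0.001801, 0.097351]
print(sign_agreement(god, darlig))  # 0.75
```

Four dimensions are of course far too few to conclude anything; the function is only meant to be run on the full vectors.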

We can also make a scatter plot to compare the weights of the two tokens. It looks like a linear relationship, but with some variation:

fig, ax = plt.subplots(1,1, figsize = (6, 6))
ax.scatter(x = sample.iloc[0,1:], y=sample.iloc[1,1:], alpha=0.5)
ax.set_xlabel('Weights for "God"');
ax.set_ylabel('Weights for "Dårlig"');

Note: We are using the object-oriented syntax for matplotlib in the above examples. I particularly like Chris Moffitt's tutorial on matplotlib.

Finally, we can quantify the relationship by computing the correlation coefficient between the two weight vectors and checking that it is greater than 0:

np.corrcoef(sample.iloc[:, 1:])[0,1]
0.5042995297412233
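The Pearson correlation returned by np.corrcoef is the cosine similarity of the two vectors after subtracting each vector's mean. Cosine similarity without centering is a common way to compare embedding vectors, and it is the metric we pass to UMAP later on. A sketch with correlated toy vectors (hypothetical data) showing that the two formulations agree:

```python
import numpy as np

def cosine_similarity(a, b):
    a, b = np.asarray(a, dtype=np.float64), np.asarray(b, dtype=np.float64)
    return a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

def pearson(a, b):
    """Pearson correlation = cosine similarity of the mean-centred vectors."""
    a, b = np.asarray(a, dtype=np.float64), np.asarray(b, dtype=np.float64)
    return cosine_similarity(a - a.mean(), b - b.mean())

rng = np.random.default_rng(1)
x = rng.normal(size=400)
y = 0.5 * x + rng.normal(size=400)   # toy vectors with a positive correlation

print(pearson(x, y), np.corrcoef(x, y)[0, 1])  # the two values agree
print(cosine_similarity(x, y))
```

On the real data, `cosine_similarity(sample.iloc[0, 1:], sample.iloc[1, 1:])` would give the cosine similarity of 'god' and 'dårlig'.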

But this approach is kind of unwieldy. It's fine for comparing pairs of tokens, but if we want to somehow compare all the tokens, we'll need another method.
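Before reaching for dimensionality reduction, one simple way to compare a token against the whole vocabulary is a nearest-neighbour search: normalize every row to unit length, and a single matrix-vector product gives the cosine similarity to all tokens at once. A sketch with hypothetical 3-dimensional toy embeddings (the helper name `most_similar` is made up for this example); on the real data you might call `most_similar(wts.iloc[:, 1:].values, itos, 'god')`:

```python
import numpy as np

def most_similar(weights, itos, token, k=5):
    """Return the k tokens whose rows have the highest cosine similarity to `token`'s row.

    Assumes no row of `weights` is all zeros (that would divide by zero).
    """
    stoi = {t: i for i, t in enumerate(itos)}
    idx = stoi[token]
    w = np.asarray(weights, dtype=np.float64)
    normed = w / np.linalg.norm(w, axis=1, keepdims=True)  # unit-length rows
    sims = normed @ normed[idx]                             # cosine similarity to every token
    order = np.argsort(-sims)                               # most similar first
    return [(itos[i], float(sims[i])) for i in order if i != idx][:k]

# Hypothetical toy embeddings.
toy_itos = ['god', 'bra', 'dårlig', 'film']
toy_weights = [[1.0, 0.0, 0.0],
               [0.9, 0.1, 0.0],
               [-1.0, 0.2, 0.0],
               [0.0, 0.0, 1.0]]
print(most_similar(toy_weights, toy_itos, 'god', k=2))
```

This answers "what is close to X?" for one token at a time, but it still doesn't give an overview of the whole vocabulary.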

UMAP

UMAP is a dimensionality reduction algorithm that can help visualize high-dimensional data. It was introduced in a 2018 paper, and the authors also published a Python library that makes the algorithm easy to use. I installed that library, umap-learn, along with the suggested pynndescent package.

In brief, UMAP takes high-dimensional data and maps it to fewer dimensions while retaining some of the structure of the original data, such as clusters. We can thus take our 400-dimensional weights and turn them into a new dataset with only two dimensions, which is much easier to plot.

Note: The new data structure is called an embedding which is kind of confusing in this case where we are inspecting an embedding!
import umap
# Reduce the 400-dimensional token embeddings to 2 dimensions, using cosine distance
reduced = umap.UMAP(n_components=2,
                    n_neighbors=15,
                    min_dist=0.3,
                    random_state=42,
                    metric='cosine').fit_transform(wts.iloc[:, 1:].values)

The 'embedding' produced by UMAP, which we stored in reduced, has only two dimensions:

wts.shape, reduced.shape
((30002, 401), (30002, 2))

That means we have reduced the number of "weights" from approximately 12 000 000 (30002x400) to around 60 000 (30002x2), a reduction of 99.5 %. But does the result make any sense? Let's look at a scatter plot:
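The arithmetic behind those figures: the vocabulary size is the same on both sides, so the reduction depends only on going from 400 to 2 numbers per token, i.e. 1 - 2/400. The variable names are chosen so as not to overwrite `reduced` from above:

```python
vocab_size, emb_dim, reduced_dim = 30002, 400, 2

n_original = vocab_size * emb_dim       # 12,000,800 weights
n_reduced = vocab_size * reduced_dim    # 60,004 coordinates
reduction = 1 - n_reduced / n_original  # = 1 - 2/400

print(f'{n_original:,} -> {n_reduced:,} ({reduction:.1%} fewer numbers)')
```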

fig, ax = plt.subplots(1,1, figsize = (6, 6))
ax.scatter(reduced[:, 0], reduced[:, 1], alpha=0.01);
ax.set_xlabel('x')
ax.set_ylabel('y');

Note that the results from UMAP can vary a lot with different hyperparameters. I simply tried a couple of settings within the range of the recommended defaults. The result above seemed good enough, in that it has several interesting clusters and shapes that we can investigate further, but there might be other embeddings that are 'better'.
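Because UMAP's output depends on its hyperparameters, a deterministic linear projection can serve as a point of comparison. PCA is a different technique from UMAP: it keeps the two directions of largest variance rather than trying to preserve local neighbourhoods, so clusters may look less distinct. A sketch of a 2-D PCA projection via NumPy's SVD, run on random stand-in data; on the real data you would pass `wts.iloc[:, 1:].values`:

```python
import numpy as np

def pca_2d(x):
    """Project rows of x onto their top-2 principal components (computed via SVD)."""
    x = np.asarray(x, dtype=np.float64)
    centred = x - x.mean(axis=0)
    # Rows of vt are the principal directions, sorted by decreasing singular value.
    _, s, vt = np.linalg.svd(centred, full_matrices=False)
    explained = s**2 / np.sum(s**2)          # fraction of variance per component
    return centred @ vt[:2].T, explained[:2]

rng = np.random.default_rng(0)
toy = rng.normal(size=(500, 400))            # hypothetical stand-in data
coords, ratio = pca_2d(toy)
print(coords.shape, ratio)
```

The returned `coords` can be plotted exactly like `reduced` above.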

We would like to inspect the actual tokens that correspond to each point above, but to my knowledge making an interactive chart with matplotlib isn't straightforward. Luckily, there are other plotting libraries we can try!

Visualizing with altair

I haven't used altair before, but I've heard a lot about it. It's also supposed to play well with fastpages, the platform used to write this blog. About time to take it for a spin!

Altair prefers input data in the form of a data frame, so let's combine our reduced embedding with the vocabulary itos:

df = pd.DataFrame(reduced, columns=['x', 'y'])
df.insert(0, 'token', itos)
df.head()
token x y
0 _unk_ 12.963406 9.573566
1 _pad_ 5.633428 0.549199
2 . 3.142599 1.704447
3 i 14.942327 9.356459
4 , 3.190900 1.649735

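Note that altair by default refuses to plot data frames with more than 5000 rows (the check can be turned off with alt.data_transformers.disable_max_rows(), but then every row is embedded in the chart), so we simply grab a random sample of 5000 rows from our data frame: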

import altair as alt
alt.Chart(df.sample(5000, random_state=42)).mark_circle(size=50, fillOpacity=0.2).encode(
    x='x',
    y='y',
    tooltip=['token']
).interactive()
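A random sample may leave out tokens you specifically want to find in the chart. A sketch of a sampling helper that always keeps a given list of tokens and fills the rest at random; the function name and the toy data frame are hypothetical:

```python
import pandas as pd

def sample_with_tokens(df, n, keep_tokens, random_state=42):
    """Random sample of about n rows that always includes rows whose token is in keep_tokens."""
    keep = df[df['token'].isin(keep_tokens)]
    rest = df.drop(index=keep.index)
    n_rest = min(max(n - len(keep), 0), len(rest))
    return pd.concat([keep, rest.sample(n_rest, random_state=random_state)])

# Toy stand-in for df: 100 filler tokens plus two tokens we want in the chart.
toy = pd.DataFrame({'token': [f'tok{i}' for i in range(100)] + ['god', 'dårlig'],
                    'x': range(102), 'y': range(102)})
subset = sample_with_tokens(toy, 10, ['god', 'dårlig'])
print(len(subset), subset['token'].isin(['god', 'dårlig']).sum())  # 10 2
```

With the real data, `alt.Chart(sample_with_tokens(df, 5000, ['god', 'dårlig']))` could replace `alt.Chart(df.sample(5000, random_state=42))` in the chart above.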

We recognize the patterns from the matplotlib scatter plot above, but with this plot we can inspect each token by hovering over a point. There are several interesting clusters of tokens:

  • x=-2, y=2: 4-digit numbers, probably years
  • x=9, y=15: 3-digit numbers (note the 2-digit numbers directly to the right)
  • x=17, y=0: infinitive-form verbs (present-tense forms directly above)
  • x=10, y=-2: place names
  • x=6, y=6: names of people
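Hovering is good for discovery, but a hunch such as "the points near x=-2, y=2 are years" can also be checked in code: filter the data frame to a small box around a point and test the tokens with a regular expression. A sketch on a hypothetical toy data frame standing in for df; the box half-width of 1.0 is an arbitrary assumption, and the cluster positions are specific to this UMAP run (they change with random_state and the hyperparameters):

```python
import pandas as pd

def tokens_near(df, x, y, radius=1.0):
    """Rows of df whose (x, y) lies within a square of half-width `radius` around (x, y)."""
    mask = (df['x'] - x).abs().le(radius) & (df['y'] - y).abs().le(radius)
    return df[mask]

# Hypothetical toy coordinates.
toy = pd.DataFrame({
    'token': ['1999', '2005', 'oslo', '123', 'god'],
    'x': [-2.1, -1.8, 10.2, 9.0, 6.1],
    'y': [2.2, 1.7, -2.0, 15.1, 6.3],
})
near = tokens_near(toy, x=-2, y=2)
frac_years = near['token'].str.fullmatch(r'\d{4}').mean()  # share of 4-digit tokens
print(near['token'].tolist(), frac_years)  # ['1999', '2005'] 1.0
```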

It's kind of remarkable how meaningful and easy to interpret the clustering is. I'm sure there are many other relationships that can be discovered given a closer inspection. It certainly seems like our model has learned a meaningful representation of the language.

But it's difficult to use the above visualization to diagnose our model in a specific way, and it's not clear that we get any actionable insights from it. Still, seeing that things 'make sense' definitely gives us some confidence in our model!