Inspecting the embedding of an AWD-LSTM with UMAP
Can we use dimensionality reduction algorithms to make sense of the embedding weights of an AWD-LSTM language model? tl;dr: yes.
- Load vocabulary and model weights
- Inspect the weights
- Good weights, bad weights
- UMAP
- Visualizing with altair
from fastai2.text.all import *
In the previous post we took a pretrained language model and fine-tuned it on our movie review dataset. In this post we'll try to see if we can make sense of the weights that the model learned. Specifically, we'll look at the weights of the initial embedding layer. This is the first layer of the model, and we would expect its weights to reflect patterns in the language.
In order to inspect the weights, we don't need to load the dataset or a learner object. We can simply load the saved weights directly. We will also need the vocabulary of the model, the itos (short for "integer to string"), to map each row of weights to its token.
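As a minimal sketch of how itos ties tokens to weights, assuming (as in fastai) that itos is a plain Python list where position i holds the token for row i of the embedding matrix, here is a toy version with a made-up five-token vocabulary and random weights. The names toy_itos, toy_weights, stoi and vector_for are hypothetical and only used for illustration.

```python
import numpy as np

# Hypothetical toy vocabulary: position i in the list is the token for row i
toy_itos = ['xxunk', 'xxpad', 'god', 'dårlig', 'film']
# Hypothetical embedding matrix: one row per token, 4 dimensions instead of 400
rng = np.random.default_rng(0)
toy_weights = rng.normal(size=(len(toy_itos), 4))

# Invert itos to get a string-to-index lookup (often called stoi)
stoi = {tok: i for i, tok in enumerate(toy_itos)}

def vector_for(token):
    """Return the embedding row for a token, falling back to the unknown token."""
    return toy_weights[stoi.get(token, stoi['xxunk'])]

print(stoi['god'], vector_for('god').shape)  # 2 (4,)
```

The fallback to 'xxunk' mirrors how a tokenizer maps out-of-vocabulary words to an unknown token, but the exact special-token names are an assumption here.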
path = Path('/data/hgi/.fastai/data/norec/')
Path.BASE_PATH = path
(path/'models').ls()
First we'll load the vocabulary of our model, norwegian_itos.pkl:
with open(path/'models/norwegian_itos.pkl', 'rb') as f:
    itos = pickle.load(f)
len(itos)
And then the weights of the finetuned model:
mod = torch.load(path/'models/finetuned_model.pth')
mod.keys()
Let's check the 'model' part of the dictionary:
[f'{k:30}{v.shape}' for k,v in mod['model'].items()]
We want the 0.encoder.weight layer. Its shape, 30002 x 400, is vocabulary size x embedding size. For the sake of simplicity we will combine the weights and itos into a pandas dataframe with the token in the first column.
wts = pd.DataFrame(mod['model']['0.encoder.weight'].cpu().numpy())  # convert the tensor to a numpy array first
wts.insert(0, 'token', itos)  # create itos column as first column
print(wts.shape)
wts.head(3)
We now have a dataframe with one row per token in the vocab, and 400 columns with the embedding weights for that particular token.
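The reason each row can be read as the vector for one token is that an embedding layer is just a lookup: multiplying a one-hot vector for token i with the weight matrix picks out row i. A small sketch with a random stand-in matrix (W, one_hot and the sizes are illustrative, not taken from the model):

```python
import numpy as np

vocab_size, emb_size = 6, 4          # stand-ins for 30002 and 400
rng = np.random.default_rng(42)
W = rng.normal(size=(vocab_size, emb_size))

token_id = 3
one_hot = np.zeros(vocab_size)
one_hot[token_id] = 1.0

# Matrix product with a one-hot vector...
via_matmul = one_hot @ W
# ...gives the same numbers as plain row indexing, which is what an
# embedding layer does in practice (indexing is much cheaper)
via_lookup = W[token_id]

print(np.allclose(via_matmul, via_lookup))  # True
```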
We can see that our weights range from approximately -3.8 to 2.8. Note that .values returns the underlying numpy ndarray of our data frame.
wts.iloc[:, 1:].values.min(), wts.iloc[:, 1:].values.max()
We can also plot all the weights as a histogram. Calling .flatten() on the array returned by .values gives us a single histogram of all the 30002*400 weights.
fig, ax = plt.subplots(1,1)
ax.hist(wts.iloc[:, 1:].values.flatten(), bins=100);
ax.set_xlabel('Embedding weight')
As expected the weights are centered around 0, with a few extreme values.
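"Centered around 0, with a few extreme values" can be made a little more concrete with a few summary numbers. This is a minimal sketch; the helper weight_summary and its default threshold of 3 standard deviations are my own choices, not part of the original analysis. On the real data it could be called as weight_summary(wts.iloc[:, 1:].values).

```python
import numpy as np

def weight_summary(values, n_std=3.0):
    """Summarize a weight array: mean, standard deviation, and the share of
    weights lying more than n_std standard deviations from the mean."""
    flat = np.asarray(values, dtype=np.float64).ravel()
    mean, std = flat.mean(), flat.std()
    extreme = np.abs(flat - mean) > n_std * std
    return {'mean': mean, 'std': std, 'frac_extreme': extreme.mean()}

# Tiny hand-checkable example: mean 0, std 2, two of five values beyond 1 std
print(weight_summary(np.array([-3., -1., 0., 1., 3.]), n_std=1.0))
```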
Let's take a closer look at a few select tokens and see if we can make sense of the corresponding weights. We'll choose the words good ('god') and bad ('dårlig'). We would expect the two tokens to have some similarities, but also to be opposites in some respects.
sample_tokens = ['god', 'dårlig'] # good, bad
sample = wts.loc[wts['token'].isin(sample_tokens),:]
sample
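One detail worth knowing: isin filters the rows but keeps them in vocabulary order, not in the order of sample_tokens. The "God"/"Dårlig" axis labels on the scatter plot therefore assume that 'god' comes before 'dårlig' in itos. A sketch with a made-up mini frame, showing how set_index(...).loc[...] pins the order explicitly:

```python
import pandas as pd

# Hypothetical mini version of wts, where 'dårlig' happens to come first
toy = pd.DataFrame({'token': ['film', 'dårlig', 'god'],
                    0: [0.1, -0.5, 0.4],
                    1: [0.2, 0.3, 0.6]})
sample_tokens = ['god', 'dårlig']

# isin keeps the frame's own row order
by_isin = toy.loc[toy['token'].isin(sample_tokens), :]
print(list(by_isin['token']))    # ['dårlig', 'god']

# Indexing by token returns rows in exactly the requested order
by_order = toy.set_index('token').loc[sample_tokens].reset_index()
print(list(by_order['token']))   # ['god', 'dårlig']
```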
Let's plot the first 30 weights sequentially. They seem to mostly follow each other:
fig, ax = plt.subplots(1,1, figsize = (12, 6))
for row in sample.values:
    ax.plot(row[1:31], label = row[0])  # column 0 is the token, columns 1-30 are the first 30 weights
ax.legend()
We can also make a scatter plot to compare the weights of the two tokens. It looks like a linear relationship, but with some variation:
fig, ax = plt.subplots(1,1, figsize = (6, 6))
ax.scatter(x = sample.iloc[0,1:], y=sample.iloc[1,1:], alpha=0.5)
ax.set_xlabel('Weights for "God"');
ax.set_ylabel('Weights for "Dårlig"');
Finally, we can quantify the relationship by checking that the correlation coefficient is greater than 0:
np.corrcoef(sample.iloc[:, 1:])[0,1]
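np.corrcoef returns the full Pearson correlation matrix, and [0, 1] picks the off-diagonal entry. For intuition, here is the same number computed by hand, next to cosine similarity, the metric passed to UMAP further down. Pearson correlation is simply cosine similarity after subtracting each vector's mean. A sketch on two made-up vectors:

```python
import numpy as np

def pearson(a, b):
    """Pearson correlation: covariance divided by the product of std devs."""
    a, b = np.asarray(a, float), np.asarray(b, float)
    da, db = a - a.mean(), b - b.mean()
    return (da @ db) / np.sqrt((da @ da) * (db @ db))

def cosine(a, b):
    """Cosine similarity: dot product divided by the product of the norms."""
    a, b = np.asarray(a, float), np.asarray(b, float)
    return (a @ b) / (np.linalg.norm(a) * np.linalg.norm(b))

x = np.array([0.5, -1.2, 0.3, 2.0, -0.7])
y = np.array([0.4, -0.9, 0.8, 1.5, -1.1])

print(np.isclose(pearson(x, y), np.corrcoef(x, y)[0, 1]))           # True
print(np.isclose(pearson(x, y), cosine(x - x.mean(), y - y.mean())))  # True
```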
But this approach is kind of unwieldy. It's fine for comparing pairs of tokens, but if we want to somehow compare all the tokens, we'll need another method.
UMAP is a dimensionality reduction algorithm which can be helpful for visualizing high dimensional data. It was introduced in a 2018 paper, and the authors also published a Python library, umap-learn, which makes the algorithm easy to use. I installed umap-learn along with the suggested pynndescent package.
In brief, UMAP can take a high dimensional data structure and turn it into fewer dimensions, while retaining some of the characteristics of the original data structure, such as clusters. We can thus take our 400 dimensional weights, and turn them into a new data structure of only two dimensions. This will be much easier to illustrate.
Note that UMAP calls this low-dimensional result an embedding, which is kind of confusing in this case where we are inspecting an embedding!
import umap
reduced = umap.UMAP(n_components=2,
                    n_neighbors=15,
                    min_dist=0.3,
                    random_state=42,
                    metric='cosine').fit_transform(wts.iloc[:, 1:].values)  # skip the token column
The 'embedding' produced by the UMAP algorithm, reduced, only has two dimensions:
wts.shape, reduced.shape
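Counting values rather than shapes (the token column of wts is left out, so 400 rather than 401 columns), a quick check of the arithmetic:

```python
n_tokens, n_dims, n_reduced = 30002, 400, 2

before = n_tokens * n_dims      # values in the original embedding
after = n_tokens * n_reduced    # values in the UMAP output
reduction = 1 - after / before  # equals 1 - 2/400

print(f'{before:,} -> {after:,} ({reduction:.1%} fewer)')
# 12,000,800 -> 60,004 (99.5% fewer)
```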
That means we have reduced the number of "weights" from approximately 12 000 000 (30002x400) to around 60 000 (30002x2), a reduction of 99.5 %. But does the result make any sense? Let's look at a scatter plot:
fig, ax = plt.subplots(1,1, figsize = (6, 6))
ax.scatter(reduced[:, 0], reduced[:, 1], alpha=0.01);
ax.set_xlabel('x')
ax.set_ylabel('y');
Note that results from the UMAP algorithm can vary a lot with different hyperparameters. I simply tried a couple of settings within the range of the recommended defaults. The above result seemed good enough, in the sense that it has several interesting clusters and shapes that we can investigate further. But there might well be other embeddings that are 'better'.
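To build some intuition for what n_neighbors and metric='cosine' control: UMAP starts from a graph that connects each point to its n_neighbors nearest neighbours under the chosen metric, and for data sets of this size umap-learn finds those neighbours approximately rather than exactly. The sketch below computes the neighbour lists exactly, by brute force, on a handful of made-up 2-D vectors. It is only that first ingredient, not UMAP itself, and knn_cosine is a hypothetical helper. On the real 30002 x 400 matrix the full pairwise distance matrix would have roughly 900 million entries, which is why approximate search is used.

```python
import numpy as np

def knn_cosine(X, k):
    """Return, for each row of X, the indices of its k nearest other rows
    under cosine distance (1 - cosine similarity). Brute force: O(n^2)."""
    X = np.asarray(X, dtype=float)
    unit = X / np.linalg.norm(X, axis=1, keepdims=True)  # normalize each row
    dist = 1.0 - unit @ unit.T                          # pairwise cosine distances
    np.fill_diagonal(dist, np.inf)                      # a point is not its own neighbour
    return np.argsort(dist, axis=1)[:, :k]

# Two made-up groups of vectors pointing in different directions
X = np.array([[1.0, 0.1], [2.0, 0.3], [1.5, 0.0],   # roughly along the x-axis
              [0.1, 1.0], [0.0, 3.0], [0.3, 2.0]])  # roughly along the y-axis
print(knn_cosine(X, k=2))
```

Because cosine distance ignores vector length, [1.0, 0.1] and [2.0, 0.3] count as close even though their magnitudes differ; a larger k makes each point's neighbourhood reach further, which is roughly the more global view that a larger n_neighbors gives UMAP.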
We would like to inspect the actual tokens that correspond to each point above. As far as I know, making an interactive chart with matplotlib isn't straightforward, but luckily there are other plotting libraries we can try!
df = pd.DataFrame(reduced, columns=['x', 'y'])
df.insert(0, 'token', itos)
df.head()
Note that at the time of writing, Altair by default refuses to plot data sets with more than 5000 rows, so we simply grab a random sample of 5000 rows from our data frame:
import altair as alt
alt.Chart(df.sample(5000, random_state=42)).mark_circle(size=50, fillOpacity=0.2).encode(
    x='x',
    y='y',
    tooltip=['token']
).interactive()
We recognize the patterns from the matplotlib scatter plot above. But with this plot we can inspect each token by hovering over a point. There are several interesting clusters of tokens:
- x=-2, y=2: 4-digit numbers, probably years
- x=9, y=15: 3-digit numbers (note the 2-digit numbers directly to the right)
- x=17, y=0: verbs in the infinitive (present tense directly above)
- x=10, y=-2: place names
- x=6, y=6: names of people
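Reading cluster positions off the tooltip is a bit imprecise. One way to double-check a claimed cluster, sketched here under my own assumptions, is to select the tokens with a regular expression and look at their mean position and spread in the UMAP output, here for 4-digit tokens. The frame and its coordinates are made up; on the real data the same selection would run on df.

```python
import pandas as pd

# Hypothetical stand-in for df: tokens with made-up 2-D UMAP coordinates
toy_df = pd.DataFrame({
    'token': ['1999', '2004', '2018', 'oslo', 'bergen', 'god', '42'],
    'x':     [-2.1,   -1.9,   -2.0,   10.2,   9.8,     6.0,   9.5],
    'y':     [ 2.0,    2.2,    1.9,   -2.1,  -1.8,     6.1,  15.0],
})

# Select tokens consisting of exactly four digits
is_year_like = toy_df['token'].str.fullmatch(r'\d{4}')
years = toy_df.loc[is_year_like, ['x', 'y']]

print(len(years))     # 3
print(years.mean())   # centre of the selected points
print(years.std())    # a small spread suggests a tight cluster
```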
It's kind of remarkable how meaningful and easy to interpret the clustering is. I'm sure there are many other relationships that can be discovered given a closer inspection. It certainly seems like our model has learned a meaningful representation of the language.
But it's difficult to take the above visualization and diagnose our model in any specific way, and it's not clear that we get any actionable insights from it. Still, seeing that things 'make sense' definitely gives us some confidence in our model!