Word2Vec: Obtain word embeddings

0. Introduction

Word2vec is a tool for generating distributed representations of words, proposed by Mikolov et al. [1]. The tool assigns a real-valued vector to each word such that the closer the meanings of two words, the greater the similarity of their vectors.

A distributed representation assigns a real-valued vector to each word and represents the word by that vector. When words are represented this way, the vectors are called word embeddings. In this notebook, we explain how to obtain word embeddings from the Penn Tree Bank dataset.

Let’s think about what the meaning of a word is. As humans, we understand that the words “animal” and “dog” are closely related to each other. But what information does Word2vec use to learn vectors for words? The words “animal” and “dog” should have similar vectors, whereas the vectors of “food” and “dog” should be far from each other. How can such features of words be learned automatically?

1. Basic Idea

Word2vec learns the similarity of word meanings from simple information: it learns the representations of words from sentences. The core idea is the assumption that the meaning of a word is affected by the words around it. This idea follows the distributional hypothesis [2].

The word whose representation we are learning is called the “center word”, and the words around it are called “context words”. The window size C determines how many context words are considered.

Here, let’s look at the algorithm using an example sentence: “The cute cat jumps over the lazy dog.”

  • All of the following figures consider “cat” as the center word.
  • You can see that the number of context words changes according to the window size C.
center_context_word.png
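To make the role of the window size concrete, here is a small plain-Python sketch that lists the context words of “cat” in the example sentence. The helper name context_words is hypothetical (it is not part of Word2vec or Chainer), and it simply truncates the window at the sentence boundaries.

```python
def context_words(tokens, center_index, C):
    """Return up to C words on each side of tokens[center_index].

    Near the start or end of the sentence fewer than 2C words exist,
    so the window is truncated instead of padded.
    """
    left = tokens[max(0, center_index - C):center_index]
    right = tokens[center_index + 1:center_index + 1 + C]
    return left + right


sentence = "The cute cat jumps over the lazy dog".split()
center_index = sentence.index("cat")  # "cat" is the center word, as in the figures

for C in (1, 2):
    print(C, context_words(sentence, center_index, C))
# 1 ['cute', 'jumps']
# 2 ['The', 'cute', 'jumps', 'over']
```

For C = 3 only two words precede “cat”, so this sketch returns five context words rather than six.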

2. Main Algorithm

Word2vec, the tool for creating word embeddings, actually consists of two models: Skip-gram and CBoW.

To explain the models with the figures below, we will use the following symbols.

  • \(|\mathcal{V}|\) : The size of the vocabulary
  • \(D\) : The size of an embedding vector
  • \({\bf v}_t\) : A one-hot center word vector
  • \(V_{t \pm C}\) : The set of \(2C\) context vectors around \({\bf v}_t\), namely, \(\{{\bf v}_{t+c}\}_{c=-C}^{C} \setminus \{{\bf v}_t\}\)
  • \({\bf l}_H\) : An embedding vector of an input word vector
  • \({\bf l}_O\) : An output vector of the network
  • \({\bf W}_H\) : The embedding matrix for inputs
  • \({\bf W}_O\) : The embedding matrix for outputs

Note

Using negative sampling or hierarchical softmax for the loss function is very common; however, in this notebook, we use the softmax over all words and skip the other variants for the sake of simplicity.

2.1 Skip-gram

This model learns to predict the context words \(V_{t \pm C}\) when a center word \({\bf v}_t\) is given. In this model, each row of the input embedding matrix \({\bf W}_H\) is the word embedding of one word.

When you input a center word \({\bf v}_t\) into the network, you can predict a probability vector \(\hat{\bf v}_{t+i}\) for a context word \({\bf v}_{t+i} \in V_{t \pm C}\) as follows:

  1. Calculate an embedding vector of the input center word vector: \({\bf l}_H = {\bf W}_H^T {\bf v}_t\) (this selects the row of \({\bf W}_H\) that belongs to the center word)
  2. Calculate an output vector from the embedding vector: \({\bf l}_O = {\bf W}_O {\bf l}_H\)
  3. Calculate a probability vector of a context word: \(\hat{\bf v}_{t+i} = \text{softmax}({\bf l}_O)\)

Each element of the \(|\mathcal{V}|\)-dimensional vector \(\hat{\bf v}_{t+i}\) is the probability that the corresponding word in the vocabulary is the context word at position \(i\). So, the probability \(p({\bf v}_{t+i} \mid {\bf v}_t)\) can be estimated by the dot product of the one-hot vector \({\bf v}_{t+i}\), which represents the actual word at position \(i\), and the output vector \(\hat{\bf v}_{t+i}\):

\(p({\bf v}_{t+i} \mid {\bf v}_t) = {\bf v}_{t+i}^T \hat{\bf v}_{t+i}\)
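The three steps above can be sketched in NumPy with toy sizes. This is an illustration under stated assumptions, not the Chainer implementation: the sizes, the random initialization, and the word IDs are arbitrary, and \({\bf W}_H\) stores one embedding per row, so step 1 is written as a transpose, which amounts to picking the row of the center word.

```python
import numpy as np

rng = np.random.default_rng(0)
n_vocab, D = 10, 2                               # toy |V| and D (assumed for illustration)
W_H = rng.normal(scale=0.1, size=(n_vocab, D))   # input embeddings, one row per word
W_O = rng.normal(scale=0.1, size=(n_vocab, D))   # output weights, so l_O = W_O @ l_H


def softmax(z):
    e = np.exp(z - z.max())                      # subtract the max for numerical stability
    return e / e.sum()


def one_hot(index, size):
    v = np.zeros(size)
    v[index] = 1.0
    return v


v_t = one_hot(2, n_vocab)        # center word (hypothetical ID 2)
l_H = W_H.T @ v_t                # step 1: embedding of the center word
l_O = W_O @ l_H                  # step 2: output vector, shape (|V|,)
v_hat = softmax(l_O)             # step 3: probability vector over the vocabulary

v_context = one_hot(0, n_vocab)  # actual context word (hypothetical ID 0)
p = v_context @ v_hat            # p(v_{t+i} | v_t) = v_{t+i}^T v_hat
print(p)
```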

The loss function for all the context words \(V_{t \pm C}\) given a center word \({\bf v}_t\) is defined as follows:

\[
\begin{aligned}
L(V_{t \pm C} \mid {\bf v}_t; {\bf W}_H, {\bf W}_O) &= \sum_{{\bf v}_{t+i} \in V_{t \pm C}} -\log\left(p({\bf v}_{t+i} \mid {\bf v}_t)\right) \\
&= \sum_{{\bf v}_{t+i} \in V_{t \pm C}} -\log\left({\bf v}_{t+i}^T \hat{\bf v}_{t+i}\right)
\end{aligned}
\]
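Because a one-hot vector has a single nonzero entry, \({\bf v}_{t+i}^T \hat{\bf v}_{t+i}\) just picks out one element of \(\hat{\bf v}_{t+i}\), so the loss can be computed directly from word IDs. Also, the network output does not depend on the position \(i\), so the same probability vector serves every context word. The NumPy sketch below, with a made-up output vector and made-up context IDs, checks that the one-hot form and the ID form of the loss agree. The ID form is what Chainer's F.softmax_cross_entropy takes as input (it averages over the batch by default rather than summing).

```python
import numpy as np


def log_softmax(z):
    """Numerically stable log of softmax(z)."""
    z = z - z.max()
    return z - np.log(np.exp(z).sum())


# Made-up output vector l_O for one center word (|V| = 5 for brevity).
l_O = np.array([1.0, 0.5, -0.3, 2.0, 0.0])
v_hat = np.exp(log_softmax(l_O))             # softmax(l_O), shared by all positions i

# Made-up IDs of the 2C = 4 context words observed around the center word.
context_ids = [3, 0, 3, 1]

# Loss with one-hot vectors: sum over context words of -log(v_{t+i}^T v_hat).
one_hots = np.eye(len(l_O))[context_ids]     # shape (4, |V|)
loss_onehot = -np.log(one_hots @ v_hat).sum()

# Same loss with word IDs: the dot product with a one-hot vector is an index lookup.
loss_ids = -log_softmax(l_O)[context_ids].sum()

print(loss_onehot, loss_ids)
```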

2.2 Continuous Bag of Words (CBoW)

This model learns to predict the center word \({\bf v}_t\) when the context words \(V_{t \pm C}\) are given.

When you give a set of context words \(V_{t \pm C}\) to the network, you can estimate the probability vector \(\hat{\bf v}_t\) of the center word as follows:

  1. Calculate the mean embedding vector over all context words: \({\bf l}_H = \frac{1}{2C} \sum_{{\bf v}_{t+i} \in V_{t \pm C}} {\bf W}_H^T {\bf v}_{t+i}\)
  2. Calculate an output vector: \({\bf l}_O = {\bf W}_O {\bf l}_H\)
  3. Calculate a probability vector: \(\hat{\bf v}_t = \text{softmax}({\bf l}_O)\)

Each element of \(\hat{\bf v}_t\) is the probability that the corresponding word in the vocabulary is the center word. So, the probability \(p({\bf v}_t \mid V_{t \pm C})\) can be calculated by \({\bf v}_t^T \hat{\bf v}_t\), where \({\bf v}_t\) denotes the one-hot vector of the actual center word in the sentence from the dataset.

The loss function for the center word prediction is defined as follows:

\[
\begin{aligned}
L({\bf v}_t \mid V_{t \pm C}; {\bf W}_H, {\bf W}_O) &= -\log\left(p({\bf v}_t \mid V_{t \pm C})\right) \\
&= -\log\left({\bf v}_t^T \hat{\bf v}_t\right)
\end{aligned}
\]
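The CBoW steps can be sketched in the same way. Again this is only an illustration: the sizes, the random weights, and the word IDs are made up, and W_H stores one embedding per row, so averaging the embeddings of the context words becomes a mean over the rows of W_H selected by the context IDs.

```python
import numpy as np

rng = np.random.default_rng(0)
n_vocab, D, C = 10, 2, 2                          # toy sizes (assumed for illustration)
W_H = rng.normal(scale=0.1, size=(n_vocab, D))    # input embeddings, one row per word
W_O = rng.normal(scale=0.1, size=(n_vocab, D))    # output weights


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


context_ids = np.array([4, 7, 1, 9])   # hypothetical IDs of the 2C = 4 context words
center_id = 3                          # hypothetical ID of the actual center word

l_H = W_H[context_ids].mean(axis=0)    # step 1: mean of the 2C context embeddings
l_O = W_O @ l_H                        # step 2: output vector
v_hat = softmax(l_O)                   # step 3: probability of each word being the center
loss = -np.log(v_hat[center_id])       # -log(v_t^T v_hat)
print(loss)
```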

3. Details of skip-gram

In this notebook, we mainly explain the skip-gram model because

  1. its algorithm is easier to understand than that of CBoW, and
  2. its accuracy is largely maintained even as the number of words increases, so it is more scalable.

So, let’s work through a concrete example of the skip-gram calculation under this setup:

  • The size of vocabulary \(|\mathcal{V}|\) is 10.
  • The size of embedding vector \(D\) is 2.
  • Center word is “dog”.
  • Context word is “animal”.

Since there is usually more than one context word, the following process is repeated for each context word.

  1. The one-hot vector of “dog” is [0 0 1 0 0 0 0 0 0 0] and you input it as the center word.
  2. The third row of the embedding matrix \({\bf W}_H\) is used as the word embedding \({\bf l}_H\) of “dog”.
  3. Then multiply \({\bf W}_O\) by \({\bf l}_H\) to obtain the output vector \({\bf l}_O\).
  4. Give \({\bf l}_O\) to the softmax function to make it a predicted probability vector \(\hat{\bf v}_{t+c}\) for a context word at position \(c\).
  5. Calculate the error between \(\hat{\bf v}_{t+c}\) and the one-hot vector of “animal”: [1 0 0 0 0 0 0 0 0 0].
  6. Propagate the error back to the network to update the parameters.
skipgram_detail.png
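The six steps of this worked example can be sketched in NumPy, including one manual gradient step. The learning rate, the random initialization, and the plain SGD update are assumptions made for illustration (the Chainer code below uses Adam and automatic differentiation); the word IDs 2 (“dog”) and 0 (“animal”) come from the one-hot vectors above. For softmax followed by cross entropy, the gradient with respect to \({\bf l}_O\) is \(\hat{\bf v}_{t+c} - {\bf y}\), where \({\bf y}\) is the one-hot target.

```python
import numpy as np

rng = np.random.default_rng(42)
n_vocab, D = 10, 2
DOG, ANIMAL = 2, 0          # word IDs from the one-hot vectors in the example
lr = 0.1                    # learning rate (an arbitrary choice)

W_H = rng.normal(scale=0.1, size=(n_vocab, D))
W_O = rng.normal(scale=0.1, size=(n_vocab, D))


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


def forward(W_H, W_O):
    l_H = W_H[DOG]                  # step 2: row of "dog" in W_H
    l_O = W_O @ l_H                 # step 3: output vector
    v_hat = softmax(l_O)            # step 4: predicted probability vector
    loss = -np.log(v_hat[ANIMAL])   # step 5: cross entropy with the one-hot of "animal"
    return l_H, v_hat, loss


l_H, v_hat, loss_before = forward(W_H, W_O)

# Step 6: backpropagation. For softmax + cross entropy, dL/dl_O = v_hat - y.
y = np.zeros(n_vocab)
y[ANIMAL] = 1.0
g_O = v_hat - y
grad_W_O = np.outer(g_O, l_H)       # dL/dW_O, shape (|V|, D)
grad_l_H = W_O.T @ g_O              # dL/dl_H, shape (D,), uses the old W_O

W_O = W_O - lr * grad_W_O
W_H[DOG] -= lr * grad_l_H           # only the row of "dog" receives a gradient

_, _, loss_after = forward(W_H, W_O)
print(loss_before, loss_after)
```

Only the row of “dog” in W_H receives a nonzero gradient, which is one reason embedding layers are usually implemented as a row lookup rather than as a product with a one-hot vector.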

4. Implementation of skip-gram in Chainer

The official Chainer repository contains a Word2vec example, so we explain how to implement skip-gram based on it: chainer/examples/word2vec

First, execute the following cell to install “Chainer” and its GPU backend “CuPy”. If the “runtime type” of Colaboratory is GPU, Chainer runs with the GPU as its backend.

In [1]:
!curl https://colab.chainer.org/install | sh -
Reading package lists... Done
Building dependency tree
Reading state information... Done
libcusparse8.0 is already the newest version (8.0.61-1).
libnvrtc8.0 is already the newest version (8.0.61-1).
libnvtoolsext1 is already the newest version (8.0.61-1).
0 upgraded, 0 newly installed, 0 to remove and 0 not upgraded.

4.1 Preparation

First, let’s import the necessary packages:

In [ ]:
import argparse
import collections

import numpy as np
import six

import chainer
from chainer import cuda
import chainer.functions as F
import chainer.initializers as I
import chainer.links as L
import chainer.optimizers as O
from chainer import reporter
from chainer import training
from chainer.training import extensions

4.2 Define a skip-gram model

Next, let’s define a network for skip-gram.

In [ ]:
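class SkipGram(chainer.Chain):

    def __init__(self, n_vocab, n_units):
        super().__init__()
        with self.init_scope():
            # Input embedding matrix W_H: one row of size n_units per word
            self.embed = L.EmbedID(
                n_vocab, n_units, initialW=I.Uniform(1. / n_units))
            # Output layer W_O: maps an embedding to n_vocab scores
            self.out = L.Linear(n_units, n_vocab, initialW=0)

    def __call__(self, x, context):
        # x: center word IDs, shape (batch_size,)
        # context: context word IDs, shape (batch_size, n_context)
        e = self.embed(context)  # (batch_size, n_context, n_units)
        shape = e.shape
        # Repeat each center ID n_context times: (batch_size, n_context)
        x = F.broadcast_to(x[:, None], (shape[0], shape[1]))
        # Flatten so that each row pairs one context embedding with its center ID
        e = F.reshape(e, (shape[0] * shape[1], shape[2]))
        x = F.reshape(x, (shape[0] * shape[1],))
        # Predict the center word from each context word
        center_predictions = self.out(e)
        loss = F.softmax_cross_entropy(center_predictions, x)
        reporter.report({'loss': loss}, self)
        return loss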

Note

  • The weight matrix self.embed.W is the input embedding matrix \({\bf W}_H\); each row is the embedding of one word.
  • __call__ takes the word IDs of the center words x and the word IDs of their context words context as inputs, and outputs the error calculated by the loss function softmax_cross_entropy.
  • Note that the initial shapes of x and context are (batch_size,) and (batch_size, n_context), respectively.
  • batch_size is the mini-batch size, and n_context is the number of context words.

First, we obtain the embedding vectors of the context words with e = self.embed(context); e has the shape (batch_size, n_context, n_units).

Then F.broadcast_to(x[:, None], (shape[0], shape[1])) broadcasts x from (batch_size,) to (batch_size, n_context) by copying each value n_context times along the second axis. The broadcast x is then reshaped into a 1-D vector of shape (batch_size * n_context,), while e is reshaped to (batch_size * n_context, n_units).
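The following NumPy sketch mirrors these shape manipulations outside Chainer, with small made-up sizes: fancy indexing W[context] stands in for L.EmbedID, and np.broadcast_to and ndarray.reshape stand in for F.broadcast_to and F.reshape.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, n_context, n_units, n_vocab = 3, 4, 5, 20   # toy sizes

x = np.array([7, 2, 9], dtype=np.int32)                 # center word IDs, (batch_size,)
context = rng.integers(n_vocab, size=(batch_size, n_context)).astype(np.int32)
W = rng.random((n_vocab, n_units)).astype(np.float32)   # stands in for self.embed.W

e = W[context]                                           # embedding lookup, (3, 4, 5)
shape = e.shape
x_b = np.broadcast_to(x[:, None], (shape[0], shape[1]))  # (3, 4): each row repeats one ID
e_flat = e.reshape(shape[0] * shape[1], shape[2])        # (12, 5)
x_flat = x_b.reshape(shape[0] * shape[1])                # (12,)

print(x_flat)  # [7 7 7 7 2 2 2 2 9 9 9 9]
```

Row k of e_flat is the embedding of a context word and x_flat[k] is the ID of its center word, which is exactly the pairing that softmax_cross_entropy receives.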

In the skip-gram model, predicting a context word from the center word is equivalent to predicting the center word from a context word: if a word appears in the window of a center word, then that center word also appears in the window of the word when it is treated as the center word. Over the whole corpus, both directions therefore use the same training pairs. So, we create batch_size * n_context center word predictions by applying the self.out linear layer to the embedding vectors of the context words. Then, we calculate the softmax cross entropy between the broadcast center word IDs x and the predictions.
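The symmetry argument can be checked on the example sentence. The sketch below counts the (input word, target word) training pairs in both directions with a fixed window C = 2; it ignores the random window size and the edge handling of WindowIterator, so it illustrates the idea rather than reproducing the Chainer pipeline.

```python
from collections import Counter

tokens = "the cute cat jumps over the lazy dog".split()
C = 2


def window(t):
    """Positions of the context words of the center word at position t."""
    return [c for c in range(t - C, t + C + 1) if c != t and 0 <= c < len(tokens)]


# Skip-gram as described in section 2.1: input = center word, target = context word.
center_predicts_context = Counter(
    (tokens[t], tokens[c]) for t in range(len(tokens)) for c in window(t))

# What the Chainer model does: input = context word, target = center word.
context_predicts_center = Counter(
    (tokens[c], tokens[t]) for t in range(len(tokens)) for c in window(t))

# Both directions yield the same multiset of (input, target) pairs.
print(center_predicts_context == context_predicts_center)  # True
```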

4.3 Prepare dataset and iterator

Let’s retrieve the Penn Tree Bank (PTB) dataset using Chainer’s dataset utility function get_ptb_words().

In [ ]:
train, val, _ = chainer.datasets.get_ptb_words()
n_vocab = max(train) + 1  # The minimum word ID is 0

Then define an iterator to make mini-batches that contain a set of center words with their context words.

In [ ]:
class WindowIterator(chainer.dataset.Iterator):

    def __init__(self, dataset, window, batch_size, repeat=True):
        self.dataset = np.array(dataset, np.int32)
        self.window = window
        self.batch_size = batch_size
        self._repeat = repeat

        # Shuffled center positions; the first and last `window` positions are
        # excluded so that every context index stays inside the dataset
        self.order = np.random.permutation(
            len(dataset) - window * 2).astype(np.int32)
        self.order += window
        self.current_position = 0
        self.epoch = 0
        self.is_new_epoch = False

    def __next__(self):
        if not self._repeat and self.epoch > 0:
            raise StopIteration

        i = self.current_position
        i_end = i + self.batch_size
        position = self.order[i: i_end]
        # Draw a window size w in [1, window - 1] for this mini-batch
        w = np.random.randint(self.window - 1) + 1
        offset = np.concatenate([np.arange(-w, 0), np.arange(1, w + 1)])
        pos = position[:, None] + offset[None, :]
        context = self.dataset.take(pos)
        center = self.dataset.take(position)

        if i_end >= len(self.order):
            np.random.shuffle(self.order)
            self.epoch += 1
            self.is_new_epoch = True
            self.current_position = 0
        else:
            self.is_new_epoch = False
            self.current_position = i_end

        return center, context

    @property
    def epoch_detail(self):
        return self.epoch + float(self.current_position) / len(self.order)

    def serialize(self, serializer):
        self.current_position = serializer('current_position',
                                           self.current_position)
        self.epoch = serializer('epoch', self.epoch)
        self.is_new_epoch = serializer('is_new_epoch', self.is_new_epoch)
        # The attribute is self.order (self._order does not exist)
        if self.order is not None:
            serializer('order', self.order)


def convert(batch, device):
    # Move the (center, context) arrays to the GPU when a device ID >= 0 is given
    center, context = batch
    if device >= 0:
        center = cuda.to_gpu(center)
        context = cuda.to_gpu(context)
    return center, context

  • In the constructor, we create an array self.order that holds shuffled indices of [window, window + 1, ..., len(dataset) - window - 1], so that center words are chosen randomly from the dataset for each mini-batch.
  • The iterator method __next__ returns batch_size sets of a center word and its context words.
  • The code self.order[i:i_end] returns the indices of a set of center words from the randomly ordered array self.order. The center word IDs center at these indices are retrieved by self.dataset.take.
  • np.concatenate([np.arange(-w, 0), np.arange(1, w + 1)]) creates the offsets used to retrieve context words from the dataset, where w is a window size drawn at random for each mini-batch and never larger than window.
  • The code position[:, None] + offset[None, :] generates the indices of context words for each center word index in position. The context word IDs context are retrieved by self.dataset.take.
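The index arithmetic of __next__ can be reproduced on a toy corpus. In this NumPy sketch the dataset, the center positions, and the window size w = 2 are fixed by hand instead of being drawn at random, so the output is easy to read.

```python
import numpy as np

dataset = np.arange(100, 120, dtype=np.int32)  # toy corpus: word IDs 100..119
window = 3
position = np.array([3, 10, 16])               # center indices, all in [window, len - window)

w = 2                                          # window size for this mini-batch (fixed here)
offset = np.concatenate([np.arange(-w, 0), np.arange(1, w + 1)])  # [-2 -1  1  2]
pos = position[:, None] + offset[None, :]      # context indices, shape (batch_size, 2w)
context = dataset.take(pos)
center = dataset.take(position)

print(center)   # [103 110 116]
print(context)
```

Each row of context holds the w words to the left and the w words to the right of the corresponding center word.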

4.4 Prepare model, optimizer, and updater

In [ ]:
unit = 100  # number of hidden units
window = 5
batchsize = 1000
gpu = 0  # GPU device ID; set to -1 to run on the CPU

# Instantiate model
model = SkipGram(n_vocab, unit)

if gpu >= 0:
    model.to_gpu(gpu)

# Create optimizer
optimizer = O.Adam()
optimizer.setup(model)

# Create iterators for both train and val datasets
train_iter = WindowIterator(train, window, batchsize)
val_iter = WindowIterator(val, window, batchsize, repeat=False)

# Create updater
updater = training.StandardUpdater(
    train_iter, optimizer, converter=convert, device=gpu)

4.5 Start training

In [7]:
epoch = 100

trainer = training.Trainer(updater, (epoch, 'epoch'), out='word2vec_result')
trainer.extend(extensions.Evaluator(val_iter, model, converter=convert, device=gpu))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss', 'elapsed_time']))
trainer.run()
epoch       main/loss   validation/main/loss  elapsed_time
1           6.87314     6.48688               54.154
2           6.44018     6.40645               107.352
3           6.35021     6.3558                159.544
4           6.28615     6.31679               212.612
5           6.23762     6.28779               266.059
6           6.19942     6.22658               319.874
7           6.15986     6.20715               372.798
8           6.13787     6.21461               426.456
9           6.10637     6.24927               479.725
10          6.08759     6.23192               532.966
11          6.06768     6.19332               586.339
12          6.04607     6.17291               639.295
13          6.0321      6.21226               692.67
14          6.02178     6.18489               746.599
15          6.00098     6.17341               799.408
16          5.99099     6.19581               852.966
17          5.97425     6.22275               905.819
18          5.95974     6.20495               958.404
19          5.96579     6.16532               1012.49
20          5.95292     6.21457               1066.24
21          5.93696     6.18441               1119.45
22          5.91804     6.20695               1171.98
23          5.93265     6.15757               1225.99
24          5.92238     6.17064               1279.85
25          5.9154      6.21545               1334.01
26          5.90538     6.1812                1387.68
27          5.8807      6.18523               1439.72
28          5.89009     6.19992               1492.67
29          5.8773      6.24146               1545.48
30          5.89217     6.21846               1599.79
31          5.88493     6.21654               1653.95
32          5.87784     6.18502               1707.45
33          5.88031     6.14161               1761.75
34          5.86278     6.22893               1815.29
35          5.83335     6.18966               1866.56
36          5.85978     6.24276               1920.18
37          5.85921     6.23888               1974.2
38          5.85195     6.19231               2027.92
39          5.8396      6.20542               2080.78
40          5.83745     6.27583               2133.37
41          5.85996     6.23596               2188
42          5.85743     6.17438               2242.4
43          5.84051     6.25449               2295.84
44          5.83023     6.30226               2348.84
45          5.84677     6.23473               2403.11
46          5.82406     6.27398               2456.11
47          5.82827     6.21509               2509.17
48          5.8253      6.23009               2562.15
49          5.83697     6.2564                2616.35
50          5.81998     6.29104               2669.38
51          5.82926     6.26068               2723.47
52          5.81457     6.30152               2776.36
53          5.82587     6.29581               2830.24
54          5.80614     6.30994               2882.85
55          5.8161      6.23224               2935.73
56          5.80867     6.26867               2988.48
57          5.79467     6.24508               3040.2
58          5.81687     6.24676               3093.57
59          5.82064     6.30236               3147.68
60          5.80855     6.30184               3200.75
61          5.81298     6.25173               3254.06
62          5.80753     6.32951               3307.42
63          5.82505     6.2472                3361.68
64          5.78396     6.28168               3413.14
65          5.80209     6.24962               3465.96
66          5.80107     6.326                 3518.83
67          5.83765     6.28848               3574.57
68          5.7864      6.3506                3626.88
69          5.80329     6.30671               3679.82
70          5.80032     6.29277               3732.69
71          5.80647     6.30722               3786.21
72          5.8176      6.30046               3840.51
73          5.79912     6.35945               3893.81
74          5.80484     6.32439               3947.35
75          5.82065     6.29674               4002.03
76          5.80872     6.27921               4056.05
77          5.80891     6.28952               4110.1
78          5.79121     6.35363               4163.39
79          5.79161     6.32894               4216.34
80          5.78601     6.3255                4268.95
81          5.79062     6.29608               4321.73
82          5.7959      6.37235               4375.25
83          5.77828     6.31001               4427.44
84          5.7879      6.25628               4480.09
85          5.79297     6.29321               4533.27
86          5.79286     6.2725                4586.44
87          5.79388     6.36764               4639.82
88          5.79062     6.33841               4692.89
89          5.7879      6.31828               4745.68
90          5.81015     6.33247               4800.19
91          5.78858     6.37569               4853.31
92          5.7966      6.35733               4907.27
93          5.79814     6.34506               4961.09
94          5.81956     6.322                 5016.65
95          5.81565     6.35974               5071.69
96          5.78953     6.37451               5125.02
97          5.7993      6.42065               5179.34
98          5.79129     6.37995               5232.89
99          5.76834     6.36254               5284.7
100         5.79829     6.3785                5338.93
In [ ]:
vocab = chainer.datasets.get_ptb_words_vocabulary()
index2word = {wid: word for word, wid in six.iteritems(vocab)}

# Save the word2vec model
with open('word2vec.model', 'w') as f:
    f.write('%d %d\n' % (len(index2word), unit))
    w = cuda.to_cpu(model.embed.W.data)
    for i, wi in enumerate(w):
        v = ' '.join(map(str, wi))
        f.write('%s %s\n' % (index2word[i], v))
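The saved file uses a plain-text layout that section 4.6 reads back: a header line with the vocabulary size and the embedding dimension, then one line per word containing the word followed by its vector components separated by spaces. A minimal round trip through an in-memory buffer, with a made-up three-word vocabulary, looks like this:

```python
import io

import numpy as np

index2word = {0: "cat", 1: "dog", 2: "apple"}        # toy vocabulary
W = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, -0.6]], dtype=np.float32)

# Write: header "n_vocab n_units", then "word v1 v2 ..." per line
buf = io.StringIO()
buf.write('%d %d\n' % (len(index2word), W.shape[1]))
for i, wi in enumerate(W):
    buf.write('%s %s\n' % (index2word[i], ' '.join(map(str, wi))))
text = buf.getvalue()

# Read the same text back into a word list and a float32 matrix
f = io.StringIO(text)
n_vocab, n_units = map(int, f.readline().split())
loaded = np.empty((n_vocab, n_units), dtype=np.float32)
words = []
for i, line in enumerate(f):
    ss = line.split()
    words.append(ss[0])
    loaded[i] = [float(s) for s in ss[1:]]

print(text)
```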

4.6 Search for similar words

In [ ]:
import numpy
import six

n_result = 5  # number of search result to show


with open('word2vec.model', 'r') as f:
    ss = f.readline().split()
    n_vocab, n_units = int(ss[0]), int(ss[1])
    word2index = {}
    index2word = {}
    w = numpy.empty((n_vocab, n_units), dtype=numpy.float32)
    for i, line in enumerate(f):
        ss = line.split()
        assert len(ss) == n_units + 1
        word = ss[0]
        word2index[word] = i
        index2word[i] = word
        w[i] = numpy.array([float(s) for s in ss[1:]], dtype=numpy.float32)


s = numpy.sqrt((w * w).sum(1))
w /= s.reshape((s.shape[0], 1))  # normalize
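After every row is divided by its L2 norm, the dot product computed by w.dot(v) in search is the cosine similarity. A small check with three hand-picked 2-D vectors (not taken from the trained model):

```python
import numpy as np

w = np.array([[3.0, 4.0], [1.0, 0.0], [0.0, 2.0]], dtype=np.float32)
s = np.sqrt((w * w).sum(1))
w_norm = w / s.reshape((s.shape[0], 1))     # same row normalization as above

similarity = w_norm.dot(w_norm[0])          # cosine similarity of each row with row 0
print(similarity)                           # approximately [1.0, 0.6, 0.8]
```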
In [ ]:
def search(query):
    if query not in word2index:
        print('"{0}" is not found'.format(query))
        return

    v = w[word2index[query]]
    similarity = w.dot(v)
    print('query: {}'.format(query))

    count = 0
    for i in (-similarity).argsort():
        if numpy.isnan(similarity[i]):
            continue
        if index2word[i] == query:
            continue
        print('{0}: {1}'.format(index2word[i], similarity[i]))
        count += 1
        if count == n_result:
            return

Search for the words most similar to “apple”.

In [23]:
query = "apple"
search(query)
query: apple
computer: 0.5457335710525513
compaq: 0.5068206191062927
microsoft: 0.4654524028301239
network: 0.42985647916793823
trotter: 0.42716777324676514

5. Reference
