Word2Vec: Obtain word embeddings

0. Introduction

word2vecは単語の分散表現を生成するツールで、Mikolov et al[1]によって提案されました。単語の意味が近いほど類似度が大きくなるように、word2vecは各単語に実ベクトルを割り当てます。

ここで、分散表現とは各単語に対して実ベクトルを割り当て、そのベクトルで単語を表現することです。分散表現で単語を表現する場合、そのベクトルをword embeddings(単語埋め込み) と呼びます。このNotebookでは、Penn Tree Bankのデータセットからword embeddingsを獲得する方法を説明します。

さて、そもそも単語の意味とはなんでしょうか。人であれば「動物」と「犬」という単語が似ているというのはなんとなく分かります。しかし、word2vecは何の情報を元に、「動物」と「犬」は似ているとか、「食べ物」と「犬」は似ていないといった意味の類似度を学習すれば良いのでしょうか。

1. 基本的なアイデア

word2vecは単語の意味の類似度を単純な情報から学習します。それは文章における単語の並びです、つまりある単語の意味は、その単語の周囲の単語で決まるというアイデアです。 このアイデアはdistributional hypothesis(分布仮設)[2]に基づいています。

学習対象の単語をcenter word、その周囲の単語をcontext wordsと呼びます。ウィンドウサイズCに応じてcontex wordの数は変わります。

例として、The cute cat jumps over the lazy dog.という文で説明を行います。 以下の図は全てcenter wordをcatとした場合のものです。 ウィンドウサイズCに応じて、catを学習する際に使用するcontex wordが変わることがわかると思います。

|center\_context\_word.png|

2. 主なアルゴリズム

word2vecと呼ばれる手法は実はSkip-gramCBoWという2つの手法の総称です。

To explain the models with the figures below, we will use the following symbols.

  • \(|\mathcal{V}|\) : ボキャブラリ数
  • \(D\) : 埋め込みベクトルのサイズ
  • \({\bf v}_t\) : center wordのone-hotベクトル
  • \(V_{\pm C}\) : \({\bf v}_t\)の周囲のcontext wordのone-hotベクトルの集合、つまり\(\{{\bf v}_{t+c}\}_{c=-C}^C \backslash {\bf v}_t\)
  • \({\bf l}_H\) : 入力単語に対する埋め込みベクトル
  • \({\bf l}_O\) : ネットワークの出力ベクトル
  • \({\bf W}_H\) : 入力に対する埋め込み行列
  • \({\bf W}_O\) : 出力に対する埋め込み行列

Note

negative samplinghierarchical softmaxをロス関数に使うことが一般的だが、すべての単語に対するsoftmax関数を使い、説明を簡略化するため他の説明は省略します。

2.1 Skip-gram

このモデルは、 center wordが与えられたときにその周囲のcontext words \(V_{t \pm C}\)を予測するように学習します。この時、入力に対する埋め込み行列\(W_H\)の各行が各単語の分散表現になります。

center word \({\bf v}_t\)をネットワークに入力したとき、以下のようにしてcontext words \(\hat{\bf v}_{t+i} \in V_{t \pm C}\)を予測することができます

  1. 入力されたcenter wordに対する埋め込みベクトルを計算する: \({\bf l}_H = {\bf W}_H {\bf v}_t\)
  2. 埋め込みベクトルを使って出力ベクトルを計算する: \({\bf l}_O = {\bf W}_O {\bf l}_H\)
  3. context wordの確率ベクトルを計算する: \(\hat{\bf v}_{t+i} = \text{softmax}({\bf l}_O)\)

\(|\mathcal{V}|\)次元のベクトル\(\hat{\bf v}_{t+i}\)の各要素は、各単語がcontext wordである確率です。そのため、確率\(p({\bf v}_{t+i} \mid {\bf v}_t)\)は、context wordのone-hotベクトル\({\bf v}_{t+i}\)と確率ベクトル\(\hat{\bf v}_{t+i}\)の内積で計算することができます。

\begin{eqnarray} p({\bf v}_{t+i} \mid {\bf v}_t) = {\bf v}_{t+i}^T \hat{\bf v}_{t+i} \end{eqnarray}

そして、center word \({\bf v}_t\)に対するすべてのcontext word\(V_{t \pm C}\)のロス関数は以下で計算することができます。

\begin{eqnarray} L(V_{t \pm C} | {\bf v}_t; {\bf W}_H, {\bf W}_O) &=& \sum_{V_{t \pm C}} -\log\left(p({\bf v}_{t+i} \mid {\bf v}_t)\right) \\ &=& \sum_{V_{t \pm C}} -\log({\bf v}_{t+i}^T \hat{\bf v}_{t+i}) \end{eqnarray}

2.2 Continuous Bag of Words (CBoW)

このモデルは、context word \(V_{t \pm C}\) が与えられたときにcenter word \({\bf v}_t\)を予測するように学習します。

context words \(V_{t \pm C}\)をネットワークに与えたとき、以下のようにcenter word \(\hat{v}_t\)の確率を計算することができます。

  1. すべてのcontext wordに対する埋め込みベクトルの平均を計算します: \({\bf l}_H = \frac{1}{2C} \sum_{V_{t \pm C}} {\bf W}_H {\bf v}_{t+i}\)
  2. 埋め込みベクトルを使って出力ベクトルを計算します: \({\bf l}_O = {\bf W}_O {\bf l}_H\)
  3. center wordの確率ベクトルを計算する: \(\hat{\bf v}_t = \text{softmax}({\bf l}_O)\)

\(|\mathcal{V}|\)次元のベクトル\(\hat{\bf v}_t\)の各要素は、各単語がcenter wordである確率です。そのため、確率\(p({\bf v}_t \mid V_{t \pm C})\)は、center wordのone-hotベクトル\({\bf v}_{t}\)と確率ベクトル\(\hat{\bf v}_{t}\)の内積で計算することができます。

\begin{eqnarray} p({\bf v}_{t} \mid V_{t \pm C}) = {\bf v}_{t}^T \hat{\bf v}_{t} \end{eqnarray}

The loss function for the center word prediction is defined as follows:

そして、context word\(V_{t \pm C}\)対するcenter word \({\bf v}_t\)のロス関数は以下で計算することができます。

\begin{eqnarray} L({\bf v}_t|V_{t \pm C}; W_H, W_O) &=& -\log(p({\bf v}_t|V_{t \pm C})) \\ &=& -\log({\bf v}_t^T \hat{\bf v}_t) \end{eqnarray}

3. Skip-gramの詳細

本チュートリアルでは、以下の観点からSkip-gramをメインで扱います。

  1. 学習アルゴリズムがCBoWに比べて理解しやすい
  2. 単語数が増えても精度が落ちにくく、スケールしやすい

skip-gramのアルゴリズムを理解するために、以下の設定で具体的な例から考えてみましょう:

  • ボキャブラリ数 \(|\mathcal{V}|\) は10。
  • 埋め込みベクトルのサイズ\(D\)は2。
  • Center wordはdog。
  • Context wordはanimal。

そして、以下の工程をcontext word数回繰り返します。

  1. dogのone-hotベクトルは[0 0 1 0 0 0 0 0 0 0]で、 これをモデルに入力する。
  2. このとき、埋め込み行列\({\bf W}_H\)の3番目の行\({\bf l}_H\)がdogの埋め込みベクトルとなる。
  3. そして、出力ベクトル\({\bf l}_O\)を計算するため、\({\bf W}_O\)\({\bf l}_H\)の積を計算する。
  4. \(c\)の位置にあるcontext wordの確率ベクトル\(\hat{\bf v}_{t+c}\)を予測するため\({\bf l}_O\)をsoftmax関数に入力する。
  5. \(\hat{\bf v}_{t+c}\)と animalのone-hotベクトル[1 0 0 0 0 0 0 0 0 0 0]の誤差を計算する。
  6. 誤差を伝播させてネットワークのパラメータを更新する。

|skipgram\_detail.png|

4. Chainerによるskip-gram実装方法

GitHubレポジトリ上のexamples内にword2vecに関するコードがあるので、それに基づいて説明をしていきます。chainer/examples/word2vec

まずは、以下のセルを実行して、ChainerとそのGPUバックエンドであるCuPyをインストールします。Colaboratoryの「ランタイムのタイプ」がGPUであれば、GPUをバックエンドとしてChainerを動かすことができます。

[ ]:
!curl https://colab.chainer.org/install | sh -
Reading package lists... Done
Building dependency tree
Reading state information... Done
libcusparse8.0 is already the newest version (8.0.61-1).
libnvrtc8.0 is already the newest version (8.0.61-1).
libnvtoolsext1 is already the newest version (8.0.61-1).
0 upgraded, 0 newly installed, 0 to remove and 0 not upgraded.

4.1 準備

必要なパッケージをimportしましょう。

[ ]:
import argparse
import collections

import numpy as np
import six

import chainer
from chainer import cuda
import chainer.functions as F
import chainer.initializers as I
import chainer.links as L
import chainer.optimizers as O
from chainer import reporter
from chainer import training
from chainer.training import extensions

4.2 skip-gramモデルの定義

次にskip-gramのネットワーク構造を定義しましょう。

[ ]:
class SkipGram(chainer.Chain):

    def __init__(self, n_vocab, n_units):
        super().__init__()
        with self.init_scope():
            self.embed = L.EmbedID(
                n_vocab, n_units, initialW=I.Uniform(1. / n_units))
            self.out = L.Linear(n_units, n_vocab, initialW=0)

    def __call__(self, x, context):
        e = self.embed(context)
        shape = e.shape
        x = F.broadcast_to(x[:, None], (shape[0], shape[1]))
        e = F.reshape(e, (shape[0] * shape[1], shape[2]))
        x = F.reshape(x, (shape[0] * shape[1],))
        center_predictions = self.out(e)
        loss = F.softmax_cross_entropy(center_predictions, x)
        reporter.report({'loss': loss}, self)
        return loss

Note

  • 重み行列self.embed.Wは入力xに対する埋め込み行列です。
  • __call__は center wordの単語ID xとcontext wordの単語ID contextsを入力として取ります。そして、ロス関数softmax_cross_entropyで計算された誤差を出力します。
  • 注意してもらいたいのが、 xcontextsの形がそれぞれ(batch_size,)(batch_size, n_context)になっていることです。
  • batch_sizeはミニバッチサイズを意味し、 n_contextはcontext word数を意味します。

まず、e = self.embed(contexts)contextsに対応する分散表現を取得しています。

そして、 F.broadcast_to(x[:, None], (shape[0], shape[1]))とすることで、x((batch_size,)) を(batch_size, n_context)の形にブロードキャストします。このとき、 列方向にn_context回だけ同じ値がコピーされます。そして、ブロードキャストされたxは1次元ベクトルにreshapeされ、(batchsize * n_context,)になります。一方で、e(batch_size * n_context, n_units)の形にreshapeされます。

注意してもらいたいのが、skip-gramの場合、center wordとcontext wordは1対1で対応するため、center wordとcontext wordを入れ替えてモデル化しても問題がないです。そのため、上記ではcenter wordとcontext wordを入れ替えて学習させているように見えますが、問題はありません。なぜこのようなことをするかと言うと、CBoWモデルとコードの整合性が取りやすいからです。

4.3 datasetとiteratorの準備

Chainer’が用意するユーティリティ関数get_ptb_words()を使って、Penn Tree Bank (PTB)のデータセットをダウンロードしましょう。

[ ]:
train, val, _ = chainer.datasets.get_ptb_words()
n_vocab = max(train) + 1  # The minimum word ID is 0

center wordと、そのcontext wordを含むミニバッチを生成するIteratorを定義しましょう。

[ ]:
class WindowIterator(chainer.dataset.Iterator):

    def __init__(self, dataset, window, batch_size, repeat=True):
        self.dataset = np.array(dataset, np.int32)
        self.window = window
        self.batch_size = batch_size
        self._repeat = repeat

        self.order = np.random.permutation(
            len(dataset) - window * 2).astype(np.int32)
        self.order += window
        self.current_position = 0
        self.epoch = 0
        self.is_new_epoch = False

    def __next__(self):
        if not self._repeat and self.epoch > 0:
            raise StopIteration

        i = self.current_position
        i_end = i + self.batch_size
        position = self.order[i: i_end]
        w = np.random.randint(self.window - 1) + 1
        offset = np.concatenate([np.arange(-w, 0), np.arange(1, w + 1)])
        pos = position[:, None] + offset[None, :]
        context = self.dataset.take(pos)
        center = self.dataset.take(position)

        if i_end >= len(self.order):
            np.random.shuffle(self.order)
            self.epoch += 1
            self.is_new_epoch = True
            self.current_position = 0
        else:
            self.is_new_epoch = False
            self.current_position = i_end

        return center, context

    @property
    def epoch_detail(self):
        return self.epoch + float(self.current_position) / len(self.order)

    def serialize(self, serializer):
        self.current_position = serializer('current_position',
                                           self.current_position)
        self.epoch = serializer('epoch', self.epoch)
        self.is_new_epoch = serializer('is_new_epoch', self.is_new_epoch)
        if self._order is not None:
            serializer('_order', self._order)

def convert(batch, device):
    center, context = batch
    if device >= 0:
        center = cuda.to_gpu(center)
        context = cuda.to_gpu(context)
    return center, context
  • コンストラクタの中で、文書中の単語の位置をシャッフルしたリストself.orderを作成しています。文書からランダムに単語を選択し学習するようにするためです。ウィンドウサイズ分だけ最初と最後を切り取った単語の位置がシャッフルされて入っています。
  • イテレータの定義__next__は、コンストラクタのパラメータに従ってミニバッチサイズ個のcenter word centerとcontext word contextを返します。
  • self.order[i:i_end]で、単語の位置をシャッフルしたリストself.orderからbatch_size分のcenter wordのインデックスpositionを生成します。(positionは後でself.dataset.takeによってcenter word centerに変換されます。)
  • np.concatenate([np.arange(-w, 0), np.arange(1, w + 1)])で、ウインドウを表現するオフセットoffsetを作成しています。
  • position[:, None] + offset[None, :]によって、それぞれのcenter wordに対するcontext word のインデックスposを生成します。posは後でself.dataset.takeによってcontext word contextに変換されます。

4.4 model, optimizer, updaterの準備

[ ]:
unit = 100  # number of hidden units
window = 5
batchsize = 1000
gpu = 0

# Instantiate model
model = SkipGram(n_vocab, unit)

if gpu >= 0:
    model.to_gpu(gpu)

# Create optimizer
optimizer = O.Adam()
optimizer.setup(model)

# Create iterators for both train and val datasets
train_iter = WindowIterator(train, window, batchsize)
val_iter = WindowIterator(val, window, batchsize, repeat=False)

# Create updater
updater = training.StandardUpdater(
    train_iter, optimizer, converter=convert, device=gpu)

4.5 trainingの開始

[ ]:
epoch = 100

trainer = training.Trainer(updater, (epoch, 'epoch'), out='word2vec_result')
trainer.extend(extensions.Evaluator(val_iter, model, converter=convert, device=gpu))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss', 'elapsed_time']))
trainer.run()
epoch       main/loss   validation/main/loss  elapsed_time
1           6.87314     6.48688               54.154
2           6.44018     6.40645               107.352
3           6.35021     6.3558                159.544
4           6.28615     6.31679               212.612
5           6.23762     6.28779               266.059
6           6.19942     6.22658               319.874
7           6.15986     6.20715               372.798
8           6.13787     6.21461               426.456
9           6.10637     6.24927               479.725
10          6.08759     6.23192               532.966
11          6.06768     6.19332               586.339
12          6.04607     6.17291               639.295
13          6.0321      6.21226               692.67
14          6.02178     6.18489               746.599
15          6.00098     6.17341               799.408
16          5.99099     6.19581               852.966
17          5.97425     6.22275               905.819
18          5.95974     6.20495               958.404
19          5.96579     6.16532               1012.49
20          5.95292     6.21457               1066.24
21          5.93696     6.18441               1119.45
22          5.91804     6.20695               1171.98
23          5.93265     6.15757               1225.99
24          5.92238     6.17064               1279.85
25          5.9154      6.21545               1334.01
26          5.90538     6.1812                1387.68
27          5.8807      6.18523               1439.72
28          5.89009     6.19992               1492.67
29          5.8773      6.24146               1545.48
30          5.89217     6.21846               1599.79
31          5.88493     6.21654               1653.95
32          5.87784     6.18502               1707.45
33          5.88031     6.14161               1761.75
34          5.86278     6.22893               1815.29
35          5.83335     6.18966               1866.56
36          5.85978     6.24276               1920.18
37          5.85921     6.23888               1974.2
38          5.85195     6.19231               2027.92
39          5.8396      6.20542               2080.78
40          5.83745     6.27583               2133.37
41          5.85996     6.23596               2188
42          5.85743     6.17438               2242.4
43          5.84051     6.25449               2295.84
44          5.83023     6.30226               2348.84
45          5.84677     6.23473               2403.11
46          5.82406     6.27398               2456.11
47          5.82827     6.21509               2509.17
48          5.8253      6.23009               2562.15
49          5.83697     6.2564                2616.35
50          5.81998     6.29104               2669.38
51          5.82926     6.26068               2723.47
52          5.81457     6.30152               2776.36
53          5.82587     6.29581               2830.24
54          5.80614     6.30994               2882.85
55          5.8161      6.23224               2935.73
56          5.80867     6.26867               2988.48
57          5.79467     6.24508               3040.2
58          5.81687     6.24676               3093.57
59          5.82064     6.30236               3147.68
60          5.80855     6.30184               3200.75
61          5.81298     6.25173               3254.06
62          5.80753     6.32951               3307.42
63          5.82505     6.2472                3361.68
64          5.78396     6.28168               3413.14
65          5.80209     6.24962               3465.96
66          5.80107     6.326                 3518.83
67          5.83765     6.28848               3574.57
68          5.7864      6.3506                3626.88
69          5.80329     6.30671               3679.82
70          5.80032     6.29277               3732.69
71          5.80647     6.30722               3786.21
72          5.8176      6.30046               3840.51
73          5.79912     6.35945               3893.81
74          5.80484     6.32439               3947.35
75          5.82065     6.29674               4002.03
76          5.80872     6.27921               4056.05
77          5.80891     6.28952               4110.1
78          5.79121     6.35363               4163.39
79          5.79161     6.32894               4216.34
80          5.78601     6.3255                4268.95
81          5.79062     6.29608               4321.73
82          5.7959      6.37235               4375.25
83          5.77828     6.31001               4427.44
84          5.7879      6.25628               4480.09
85          5.79297     6.29321               4533.27
86          5.79286     6.2725                4586.44
87          5.79388     6.36764               4639.82
88          5.79062     6.33841               4692.89
89          5.7879      6.31828               4745.68
90          5.81015     6.33247               4800.19
91          5.78858     6.37569               4853.31
92          5.7966      6.35733               4907.27
93          5.79814     6.34506               4961.09
94          5.81956     6.322                 5016.65
95          5.81565     6.35974               5071.69
96          5.78953     6.37451               5125.02
97          5.7993      6.42065               5179.34
98          5.79129     6.37995               5232.89
99          5.76834     6.36254               5284.7
100         5.79829     6.3785                5338.93
[ ]:
vocab = chainer.datasets.get_ptb_words_vocabulary()
index2word = {wid: word for word, wid in six.iteritems(vocab)}

# Save the word2vec model
with open('word2vec.model', 'w') as f:
    f.write('%d %d\n' % (len(index2word), unit))
    w = cuda.to_cpu(model.embed.W.data)
    for i, wi in enumerate(w):
        v = ' '.join(map(str, wi))
        f.write('%s %s\n' % (index2word[i], v))

4.6 似た単語の検索

[ ]:
import numpy
import six

n_result = 5  # number of search result to show


with open('word2vec.model', 'r') as f:
    ss = f.readline().split()
    n_vocab, n_units = int(ss[0]), int(ss[1])
    word2index = {}
    index2word = {}
    w = numpy.empty((n_vocab, n_units), dtype=numpy.float32)
    for i, line in enumerate(f):
        ss = line.split()
        assert len(ss) == n_units + 1
        word = ss[0]
        word2index[word] = i
        index2word[i] = word
        w[i] = numpy.array([float(s) for s in ss[1:]], dtype=numpy.float32)


s = numpy.sqrt((w * w).sum(1))
w /= s.reshape((s.shape[0], 1))  # normalize
[ ]:
def search(query):
  if query not in word2index:
    print('"{0}" is not found'.format(query))
    return

  v = w[word2index[query]]
  similarity = w.dot(v)
  print('query: {}'.format(query))

  count = 0
  for i in (-similarity).argsort():
      if numpy.isnan(similarity[i]):
          continue
      if index2word[i] == query:
          continue
      print('{0}: {1}'.format(index2word[i], similarity[i]))
      count += 1
      if count == n_result:
          return

appleで検索してみましょう。

[ ]:
query = "apple"
search(query)
query: apple
computer: 0.5457335710525513
compaq: 0.5068206191062927
microsoft: 0.4654524028301239
network: 0.42985647916793823
trotter: 0.42716777324676514

5. Reference

[ ]: