Word2Vec: Obtain word embeddings¶
0. Introduction¶
word2vecは単語の分散表現を生成するツールで、Mikolov et al[1]によって提案されました。単語の意味が近いほど類似度が大きくなるように、word2vecは各単語に実ベクトルを割り当てます。
ここで、分散表現とは各単語に対して実ベクトルを割り当て、そのベクトルで単語を表現することです。分散表現で単語を表現する場合、そのベクトルをword embeddings(単語埋め込み) と呼びます。このNotebookでは、Penn Tree Bankのデータセットからword embeddingsを獲得する方法を説明します。
さて、そもそも単語の意味とはなんでしょうか。人であれば「動物」と「犬」という単語が似ているというのはなんとなく分かります。しかし、word2vecは何の情報を元に、「動物」と「犬」は似ているとか、「食べ物」と「犬」は似ていないといった意味の類似度を学習すれば良いのでしょうか。
1. 基本的なアイデア¶
word2vecは単語の意味の類似度を単純な情報から学習します。それは文章における単語の並びです、つまりある単語の意味は、その単語の周囲の単語で決まるというアイデアです。 このアイデアはdistributional hypothesis(分布仮設)[2]に基づいています。
学習対象の単語をcenter word、その周囲の単語をcontext wordsと呼びます。ウィンドウサイズC
に応じてcontex wordの数は変わります。
例として、The cute cat jumps over the lazy dog.という文で説明を行います。 以下の図は全てcenter wordをcatとした場合のものです。 ウィンドウサイズC
に応じて、catを学習する際に使用するcontex wordが変わることがわかると思います。
2. 主なアルゴリズム¶
word2vecと呼ばれる手法は実はSkip-gramとCBoWという2つの手法の総称です。
To explain the models with the figures below, we will use the following symbols.
- \(|\mathcal{V}|\) : ボキャブラリ数
- \(D\) : 埋め込みベクトルのサイズ
- \({\bf v}_t\) : center wordのone-hotベクトル
- \(V_{\pm C}\) : \({\bf v}_t\)の周囲のcontext wordのone-hotベクトルの集合、つまり\(\{{\bf v}_{t+c}\}_{c=-C}^C \backslash {\bf v}_t\)
- \({\bf l}_H\) : 入力単語に対する埋め込みベクトル
- \({\bf l}_O\) : ネットワークの出力ベクトル
- \({\bf W}_H\) : 入力に対する埋め込み行列
- \({\bf W}_O\) : 出力に対する埋め込み行列
Note
negative samplingやhierarchical softmaxをロス関数に使うことが一般的だが、すべての単語に対するsoftmax関数を使い、説明を簡略化するため他の説明は省略します。
2.1 Skip-gram¶
このモデルは、 center wordが与えられたときにその周囲のcontext words \(V_{t \pm C}\)を予測するように学習します。この時、入力に対する埋め込み行列\(W_H\)の各行が各単語の分散表現になります。
center word \({\bf v}_t\)をネットワークに入力したとき、以下のようにしてcontext words \(\hat{\bf v}_{t+i} \in V_{t \pm C}\)を予測することができます
- 入力されたcenter wordに対する埋め込みベクトルを計算する: \({\bf l}_H = {\bf W}_H {\bf v}_t\)
- 埋め込みベクトルを使って出力ベクトルを計算する: \({\bf l}_O = {\bf W}_O {\bf l}_H\)
- context wordの確率ベクトルを計算する: \(\hat{\bf v}_{t+i} = \text{softmax}({\bf l}_O)\)
\(|\mathcal{V}|\)次元のベクトル\(\hat{\bf v}_{t+i}\)の各要素は、各単語がcontext wordである確率です。そのため、確率\(p({\bf v}_{t+i} \mid {\bf v}_t)\)は、context wordのone-hotベクトル\({\bf v}_{t+i}\)と確率ベクトル\(\hat{\bf v}_{t+i}\)の内積で計算することができます。
そして、center word \({\bf v}_t\)に対するすべてのcontext word\(V_{t \pm C}\)のロス関数は以下で計算することができます。
2.2 Continuous Bag of Words (CBoW)¶
このモデルは、context word \(V_{t \pm C}\) が与えられたときにcenter word \({\bf v}_t\)を予測するように学習します。
context words \(V_{t \pm C}\)をネットワークに与えたとき、以下のようにcenter word \(\hat{v}_t\)の確率を計算することができます。
- すべてのcontext wordに対する埋め込みベクトルの平均を計算します: \({\bf l}_H = \frac{1}{2C} \sum_{V_{t \pm C}} {\bf W}_H {\bf v}_{t+i}\)
- 埋め込みベクトルを使って出力ベクトルを計算します: \({\bf l}_O = {\bf W}_O {\bf l}_H\)
- center wordの確率ベクトルを計算する: \(\hat{\bf v}_t = \text{softmax}({\bf l}_O)\)
\(|\mathcal{V}|\)次元のベクトル\(\hat{\bf v}_t\)の各要素は、各単語がcenter wordである確率です。そのため、確率\(p({\bf v}_t \mid V_{t \pm C})\)は、center wordのone-hotベクトル\({\bf v}_{t}\)と確率ベクトル\(\hat{\bf v}_{t}\)の内積で計算することができます。
The loss function for the center word prediction is defined as follows:
そして、context word\(V_{t \pm C}\)対するcenter word \({\bf v}_t\)のロス関数は以下で計算することができます。
3. Skip-gramの詳細¶
本チュートリアルでは、以下の観点からSkip-gramをメインで扱います。
- 学習アルゴリズムがCBoWに比べて理解しやすい
- 単語数が増えても精度が落ちにくく、スケールしやすい
skip-gramのアルゴリズムを理解するために、以下の設定で具体的な例から考えてみましょう:
- ボキャブラリ数 \(|\mathcal{V}|\) は10。
- 埋め込みベクトルのサイズ\(D\)は2。
- Center wordはdog。
- Context wordはanimal。
そして、以下の工程をcontext word数回繰り返します。
- dogのone-hotベクトルは
[0 0 1 0 0 0 0 0 0 0]
で、 これをモデルに入力する。 - このとき、埋め込み行列\({\bf W}_H\)の3番目の行\({\bf l}_H\)がdogの埋め込みベクトルとなる。
- そして、出力ベクトル\({\bf l}_O\)を計算するため、\({\bf W}_O\)と\({\bf l}_H\)の積を計算する。
- \(c\)の位置にあるcontext wordの確率ベクトル\(\hat{\bf v}_{t+c}\)を予測するため\({\bf l}_O\)をsoftmax関数に入力する。
- \(\hat{\bf v}_{t+c}\)と animalのone-hotベクトル
[1 0 0 0 0 0 0 0 0 0 0]
の誤差を計算する。 - 誤差を伝播させてネットワークのパラメータを更新する。
4. Chainerによるskip-gram実装方法¶
GitHubレポジトリ上のexamples内にword2vecに関するコードがあるので、それに基づいて説明をしていきます。chainer/examples/word2vec
まずは、以下のセルを実行して、ChainerとそのGPUバックエンドであるCuPyをインストールします。Colaboratoryの「ランタイムのタイプ」がGPUであれば、GPUをバックエンドとしてChainerを動かすことができます。
[ ]:
!curl https://colab.chainer.org/install | sh -
Reading package lists... Done
Building dependency tree
Reading state information... Done
libcusparse8.0 is already the newest version (8.0.61-1).
libnvrtc8.0 is already the newest version (8.0.61-1).
libnvtoolsext1 is already the newest version (8.0.61-1).
0 upgraded, 0 newly installed, 0 to remove and 0 not upgraded.
4.1 準備¶
必要なパッケージをimport
しましょう。
[ ]:
import argparse
import collections
import numpy as np
import six
import chainer
from chainer import cuda
import chainer.functions as F
import chainer.initializers as I
import chainer.links as L
import chainer.optimizers as O
from chainer import reporter
from chainer import training
from chainer.training import extensions
4.2 skip-gramモデルの定義¶
次にskip-gramのネットワーク構造を定義しましょう。
[ ]:
class SkipGram(chainer.Chain):
def __init__(self, n_vocab, n_units):
super().__init__()
with self.init_scope():
self.embed = L.EmbedID(
n_vocab, n_units, initialW=I.Uniform(1. / n_units))
self.out = L.Linear(n_units, n_vocab, initialW=0)
def __call__(self, x, context):
e = self.embed(context)
shape = e.shape
x = F.broadcast_to(x[:, None], (shape[0], shape[1]))
e = F.reshape(e, (shape[0] * shape[1], shape[2]))
x = F.reshape(x, (shape[0] * shape[1],))
center_predictions = self.out(e)
loss = F.softmax_cross_entropy(center_predictions, x)
reporter.report({'loss': loss}, self)
return loss
Note
- 重み行列
self.embed.W
は入力x
に対する埋め込み行列です。 __call__
は center wordの単語IDx
とcontext wordの単語IDcontexts
を入力として取ります。そして、ロス関数softmax_cross_entropy
で計算された誤差を出力します。- 注意してもらいたいのが、
x
とcontexts
の形がそれぞれ(batch_size,)
と(batch_size, n_context)
になっていることです。 batch_size
はミニバッチサイズを意味し、n_context
はcontext word数を意味します。
まず、e = self.embed(contexts)
でcontexts
に対応する分散表現を取得しています。
そして、 F.broadcast_to(x[:, None], (shape[0], shape[1]))
とすることで、x
((batch_size,)
) を(batch_size, n_context)
の形にブロードキャストします。このとき、 列方向にn_context
回だけ同じ値がコピーされます。そして、ブロードキャストされたx
は1次元ベクトルにreshapeされ、(batchsize * n_context,)
になります。一方で、e
は(batch_size * n_context, n_units)
の形にreshapeされます。
注意してもらいたいのが、skip-gramの場合、center wordとcontext wordは1対1で対応するため、center wordとcontext wordを入れ替えてモデル化しても問題がないです。そのため、上記ではcenter wordとcontext wordを入れ替えて学習させているように見えますが、問題はありません。なぜこのようなことをするかと言うと、CBoWモデルとコードの整合性が取りやすいからです。
4.3 datasetとiteratorの準備¶
Chainer’が用意するユーティリティ関数get_ptb_words()
を使って、Penn Tree Bank (PTB)のデータセットをダウンロードしましょう。
[ ]:
train, val, _ = chainer.datasets.get_ptb_words()
n_vocab = max(train) + 1 # The minimum word ID is 0
center wordと、そのcontext wordを含むミニバッチを生成するIteratorを定義しましょう。
[ ]:
class WindowIterator(chainer.dataset.Iterator):
def __init__(self, dataset, window, batch_size, repeat=True):
self.dataset = np.array(dataset, np.int32)
self.window = window
self.batch_size = batch_size
self._repeat = repeat
self.order = np.random.permutation(
len(dataset) - window * 2).astype(np.int32)
self.order += window
self.current_position = 0
self.epoch = 0
self.is_new_epoch = False
def __next__(self):
if not self._repeat and self.epoch > 0:
raise StopIteration
i = self.current_position
i_end = i + self.batch_size
position = self.order[i: i_end]
w = np.random.randint(self.window - 1) + 1
offset = np.concatenate([np.arange(-w, 0), np.arange(1, w + 1)])
pos = position[:, None] + offset[None, :]
context = self.dataset.take(pos)
center = self.dataset.take(position)
if i_end >= len(self.order):
np.random.shuffle(self.order)
self.epoch += 1
self.is_new_epoch = True
self.current_position = 0
else:
self.is_new_epoch = False
self.current_position = i_end
return center, context
@property
def epoch_detail(self):
return self.epoch + float(self.current_position) / len(self.order)
def serialize(self, serializer):
self.current_position = serializer('current_position',
self.current_position)
self.epoch = serializer('epoch', self.epoch)
self.is_new_epoch = serializer('is_new_epoch', self.is_new_epoch)
if self._order is not None:
serializer('_order', self._order)
def convert(batch, device):
center, context = batch
if device >= 0:
center = cuda.to_gpu(center)
context = cuda.to_gpu(context)
return center, context
- コンストラクタの中で、文書中の単語の位置をシャッフルしたリスト
self.order
を作成しています。文書からランダムに単語を選択し学習するようにするためです。ウィンドウサイズ分だけ最初と最後を切り取った単語の位置がシャッフルされて入っています。 - イテレータの定義
__next__
は、コンストラクタのパラメータに従ってミニバッチサイズ個のcenter wordcenter
とcontext wordcontext
を返します。 self.order[i:i_end]
で、単語の位置をシャッフルしたリストself.order
からbatch_size
分のcenter wordのインデックスposition
を生成します。(position
は後でself.dataset.take
によってcenter wordcenter
に変換されます。)np.concatenate([np.arange(-w, 0), np.arange(1, w + 1)])
で、ウインドウを表現するオフセットoffset
を作成しています。position[:, None] + offset[None, :]
によって、それぞれのcenter wordに対するcontext word のインデックスpos
を生成します。pos
は後でself.dataset.take
によってcontext wordcontext
に変換されます。
4.4 model, optimizer, updaterの準備¶
[ ]:
unit = 100 # number of hidden units
window = 5
batchsize = 1000
gpu = 0
# Instantiate model
model = SkipGram(n_vocab, unit)
if gpu >= 0:
model.to_gpu(gpu)
# Create optimizer
optimizer = O.Adam()
optimizer.setup(model)
# Create iterators for both train and val datasets
train_iter = WindowIterator(train, window, batchsize)
val_iter = WindowIterator(val, window, batchsize, repeat=False)
# Create updater
updater = training.StandardUpdater(
train_iter, optimizer, converter=convert, device=gpu)
4.5 trainingの開始¶
[ ]:
epoch = 100
trainer = training.Trainer(updater, (epoch, 'epoch'), out='word2vec_result')
trainer.extend(extensions.Evaluator(val_iter, model, converter=convert, device=gpu))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'validation/main/loss', 'elapsed_time']))
trainer.run()
epoch main/loss validation/main/loss elapsed_time
1 6.87314 6.48688 54.154
2 6.44018 6.40645 107.352
3 6.35021 6.3558 159.544
4 6.28615 6.31679 212.612
5 6.23762 6.28779 266.059
6 6.19942 6.22658 319.874
7 6.15986 6.20715 372.798
8 6.13787 6.21461 426.456
9 6.10637 6.24927 479.725
10 6.08759 6.23192 532.966
11 6.06768 6.19332 586.339
12 6.04607 6.17291 639.295
13 6.0321 6.21226 692.67
14 6.02178 6.18489 746.599
15 6.00098 6.17341 799.408
16 5.99099 6.19581 852.966
17 5.97425 6.22275 905.819
18 5.95974 6.20495 958.404
19 5.96579 6.16532 1012.49
20 5.95292 6.21457 1066.24
21 5.93696 6.18441 1119.45
22 5.91804 6.20695 1171.98
23 5.93265 6.15757 1225.99
24 5.92238 6.17064 1279.85
25 5.9154 6.21545 1334.01
26 5.90538 6.1812 1387.68
27 5.8807 6.18523 1439.72
28 5.89009 6.19992 1492.67
29 5.8773 6.24146 1545.48
30 5.89217 6.21846 1599.79
31 5.88493 6.21654 1653.95
32 5.87784 6.18502 1707.45
33 5.88031 6.14161 1761.75
34 5.86278 6.22893 1815.29
35 5.83335 6.18966 1866.56
36 5.85978 6.24276 1920.18
37 5.85921 6.23888 1974.2
38 5.85195 6.19231 2027.92
39 5.8396 6.20542 2080.78
40 5.83745 6.27583 2133.37
41 5.85996 6.23596 2188
42 5.85743 6.17438 2242.4
43 5.84051 6.25449 2295.84
44 5.83023 6.30226 2348.84
45 5.84677 6.23473 2403.11
46 5.82406 6.27398 2456.11
47 5.82827 6.21509 2509.17
48 5.8253 6.23009 2562.15
49 5.83697 6.2564 2616.35
50 5.81998 6.29104 2669.38
51 5.82926 6.26068 2723.47
52 5.81457 6.30152 2776.36
53 5.82587 6.29581 2830.24
54 5.80614 6.30994 2882.85
55 5.8161 6.23224 2935.73
56 5.80867 6.26867 2988.48
57 5.79467 6.24508 3040.2
58 5.81687 6.24676 3093.57
59 5.82064 6.30236 3147.68
60 5.80855 6.30184 3200.75
61 5.81298 6.25173 3254.06
62 5.80753 6.32951 3307.42
63 5.82505 6.2472 3361.68
64 5.78396 6.28168 3413.14
65 5.80209 6.24962 3465.96
66 5.80107 6.326 3518.83
67 5.83765 6.28848 3574.57
68 5.7864 6.3506 3626.88
69 5.80329 6.30671 3679.82
70 5.80032 6.29277 3732.69
71 5.80647 6.30722 3786.21
72 5.8176 6.30046 3840.51
73 5.79912 6.35945 3893.81
74 5.80484 6.32439 3947.35
75 5.82065 6.29674 4002.03
76 5.80872 6.27921 4056.05
77 5.80891 6.28952 4110.1
78 5.79121 6.35363 4163.39
79 5.79161 6.32894 4216.34
80 5.78601 6.3255 4268.95
81 5.79062 6.29608 4321.73
82 5.7959 6.37235 4375.25
83 5.77828 6.31001 4427.44
84 5.7879 6.25628 4480.09
85 5.79297 6.29321 4533.27
86 5.79286 6.2725 4586.44
87 5.79388 6.36764 4639.82
88 5.79062 6.33841 4692.89
89 5.7879 6.31828 4745.68
90 5.81015 6.33247 4800.19
91 5.78858 6.37569 4853.31
92 5.7966 6.35733 4907.27
93 5.79814 6.34506 4961.09
94 5.81956 6.322 5016.65
95 5.81565 6.35974 5071.69
96 5.78953 6.37451 5125.02
97 5.7993 6.42065 5179.34
98 5.79129 6.37995 5232.89
99 5.76834 6.36254 5284.7
100 5.79829 6.3785 5338.93
[ ]:
vocab = chainer.datasets.get_ptb_words_vocabulary()
index2word = {wid: word for word, wid in six.iteritems(vocab)}
# Save the word2vec model
with open('word2vec.model', 'w') as f:
f.write('%d %d\n' % (len(index2word), unit))
w = cuda.to_cpu(model.embed.W.data)
for i, wi in enumerate(w):
v = ' '.join(map(str, wi))
f.write('%s %s\n' % (index2word[i], v))
4.6 似た単語の検索¶
[ ]:
import numpy
import six
n_result = 5 # number of search result to show
with open('word2vec.model', 'r') as f:
ss = f.readline().split()
n_vocab, n_units = int(ss[0]), int(ss[1])
word2index = {}
index2word = {}
w = numpy.empty((n_vocab, n_units), dtype=numpy.float32)
for i, line in enumerate(f):
ss = line.split()
assert len(ss) == n_units + 1
word = ss[0]
word2index[word] = i
index2word[i] = word
w[i] = numpy.array([float(s) for s in ss[1:]], dtype=numpy.float32)
s = numpy.sqrt((w * w).sum(1))
w /= s.reshape((s.shape[0], 1)) # normalize
[ ]:
def search(query):
if query not in word2index:
print('"{0}" is not found'.format(query))
return
v = w[word2index[query]]
similarity = w.dot(v)
print('query: {}'.format(query))
count = 0
for i in (-similarity).argsort():
if numpy.isnan(similarity[i]):
continue
if index2word[i] == query:
continue
print('{0}: {1}'.format(index2word[i], similarity[i]))
count += 1
if count == n_result:
return
appleで検索してみましょう。
[ ]:
query = "apple"
search(query)
query: apple
computer: 0.5457335710525513
compaq: 0.5068206191062927
microsoft: 0.4654524028301239
network: 0.42985647916793823
trotter: 0.42716777324676514
5. Reference¶
- [1] [Mikolov, Tomas; et al. “Efficient Estimation of Word Representations in Vector Space”. arXiv:1301.3781](https://arxiv.org/abs/1301.3781)
- [2] [Distributional Hypothesis](https://aclweb.org/aclwiki/Distributional_Hypothesis)
[ ]: