-
Notifications
You must be signed in to change notification settings - Fork 1
/
word2vec_align.py
108 lines (84 loc) · 3.87 KB
/
word2vec_align.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import numpy as np
import gensim
# taken from https://gist.github.com/quadrismegistus/09a93e219a6ffc4f216fb85235535faf
def smart_align_gensim(base_embed, other_embed, words=None):
    """Procrustes-align ``other_embed`` onto ``base_embed`` (two gensim word2vec models).

    Once aligned, the vector for a given word can be compared across the two
    models.  Code ported from HistWords <https://github.com/williamleif/histwords>
    by William Hamilton <[email protected]> (with help from William -- thank you!).

    Steps:
      1. Call ``init_sims()`` on both models so ``syn0norm`` is available.
      2. Intersect the two vocabularies with ``intersection_align_gensim``
         (optionally also intersecting with ``words``) so the two ``syn0norm``
         matrices can be compared row by row.  This may reorder/truncate the
         vocabulary and vectors of BOTH models in place.
      3. Compute the orthogonal matrix mapping ``other_embed`` onto
         ``base_embed`` and multiply ``other_embed.syn0norm`` by it.

    Args:
        base_embed: reference model.
        other_embed: model to rotate.  Its ``syn0norm`` and ``syn0`` are both
            replaced by the rotated matrix (the same array object).
        words: optional list/set of words to further restrict the shared
            vocabulary (see ``intersection_align_gensim``).

    Returns:
        ``other_embed`` itself, modified in place.
    """
    for model in (base_embed, other_embed):
        model.init_sims()

    # intersection_align_gensim hands back the very objects passed in.
    aligned_base, aligned_other = intersection_align_gensim(
        base_embed, other_embed, words=words)

    rotation = procrustes_align(aligned_base.syn0norm, aligned_other.syn0norm)

    # Rotate every row; syn0norm and syn0 end up sharing one array.
    rotated = other_embed.syn0norm.dot(rotation)
    other_embed.syn0norm = rotated
    other_embed.syn0 = rotated
    return other_embed
def procrustes_align(base_vecs, other_vecs):
    """Return the orthogonal matrix that best maps ``other_vecs`` onto ``base_vecs``.

    This is the orthogonal Procrustes solution: with
    ``U, S, Vt = svd(other_vecs.T @ base_vecs)`` the result is ``U @ Vt``, the
    orthogonal ``R`` minimising ``||other_vecs @ R - base_vecs||_F``.

    Both inputs are expected to have the same shape, with row i of each
    describing the same item (the final product needs square, equal-sized
    SVD factors).
    """
    cross = np.dot(other_vecs.T, base_vecs)
    left, _singular_values, right_t = np.linalg.svd(cross)
    return np.dot(left, right_t)
def intersection_align_gensim(m1, m2, words=None):
    """Intersect two gensim word2vec models, m1 and m2, in place.

    Only the vocabulary shared by both models is kept.  If ``words`` is set
    (as a list or set), the vocabulary is intersected with it as well.  The
    kept words are re-indexed 0..N-1 in descending order of combined
    frequency (m1 count + m2 count); ties are broken by comparing the words
    themselves so the order is reproducible between runs.

    The new indices apply to ``syn0``/``syn0norm`` of both models:
      -- row 0 of ``m1.syn0`` is the same word as row 0 of ``m2.syn0``;
      -- a word's row can be found with ``model.index2word.index(word)``.
    Each model's ``.vocab`` dictionary is rebuilt, keeping the counts but
    updating the indices (see ``align_vocab``).

    Nothing is modified when both models already hold exactly the kept
    vocabulary AND list it in the same order.

    Returns:
        The tuple ``(m1, m2)`` -- the same objects that were passed in.
    """
    vocab_m1 = set(m1.vocab)
    vocab_m2 = set(m2.vocab)

    common_vocab = vocab_m1 & vocab_m2
    if words:
        common_vocab &= set(words)

    # Equal vocabulary sets are not enough to skip the work: two models can
    # rank the same words differently, in which case their rows do not
    # correspond.  Skip only when nothing is dropped and the orders agree.
    nothing_dropped = not (vocab_m1 - common_vocab) and not (vocab_m2 - common_vocab)
    if nothing_dropped and list(m1.index2word) == list(m2.index2word):
        return (m1, m2)

    # Most frequent first (counts summed over both models); the word itself
    # breaks ties so the result does not depend on set iteration order.
    ordered_vocab = sorted(
        common_vocab,
        key=lambda w: (-(m1.vocab[w].count + m2.vocab[w].count), w),
    )

    for m in (m1, m2):
        align_vocab(ordered_vocab, m)
    return (m1, m2)
def align_vocab(i2w, m):
    """Force model ``m``'s vocabulary and vectors into the word order ``i2w``.

    ``i2w`` lists the words to keep; word ``i2w[k]`` gets new index ``k`` and
    words absent from ``i2w`` are dropped.  Every word in ``i2w`` must already
    be present in ``m.vocab`` (the ``m.vocab[w]`` lookup fails otherwise,
    before ``m`` is modified).

    Side effects on ``m``:
      * ``syn0norm`` and ``syn0`` are both set to one new array containing the
        old ``syn0norm`` rows in ``i2w`` order (the previous ``syn0`` is
        discarded);
      * ``index2word`` becomes a fresh list copy of ``i2w``;
      * ``vocab`` is rebuilt with new ``Vocab`` entries keeping each word's
        ``count`` with the new index.  NOTE(review): other attributes the old
        entries may carry are not copied over.
    """
    indices = np.asarray([m.vocab[w].index for w in i2w], dtype=np.intp)
    # One vectorised gather; unlike stacking a list of rows this also keeps
    # the (0, dim) shape when i2w is empty.
    new_arr = m.syn0norm[indices]
    m.syn0norm = m.syn0 = new_arr

    # Copy: callers align several models with the same list, and sharing one
    # list object would make their index2word lists mutate together.
    m.index2word = list(i2w)

    old_vocab = m.vocab
    m.vocab = {
        word: gensim.models.word2vec.Vocab(index=new_index,
                                           count=old_vocab[word].count)
        for new_index, word in enumerate(i2w)
    }