Coverage for flair/flair/visual/manifold.py: 27%
import numpy
import tqdm
from sklearn.manifold import TSNE

class _Transform:
    # Base class: subclasses set self.transform to a scikit-learn reducer.
    def __init__(self):
        pass

    def fit(self, X):
        return self.transform.fit_transform(X)


class tSNE(_Transform):
    # Projects high-dimensional embeddings to 2-D with t-SNE for plotting.
    def __init__(self):
        super().__init__()
        self.transform = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)

class Visualizer(object):
    def visualize_word_emeddings(self, embeddings, sentences, output_file):
        # Embed all tokens, reduce to 2-D with t-SNE and write an interactive HTML plot.
        X = self.prepare_word_embeddings(embeddings, sentences)
        contexts = self.word_contexts(sentences)

        trans_ = tSNE()
        reduced = trans_.fit(X)

        self.visualize(reduced, contexts, output_file)

    def visualize_char_emeddings(self, embeddings, sentences, output_file):
        # Same as above, but for per-character language-model states.
        X = self.prepare_char_embeddings(embeddings, sentences)
        contexts = self.char_contexts(sentences)

        trans_ = tSNE()
        reduced = trans_.fit(X)

        self.visualize(reduced, contexts, output_file)

    @staticmethod
    def prepare_word_embeddings(embeddings, sentences):
        # Stack every token embedding into a single (num_tokens, dim) matrix.
        X = []
        for sentence in tqdm.tqdm(sentences):
            embeddings.embed(sentence)
            for i, token in enumerate(sentence):
                X.append(token.embedding.detach().numpy()[None, :])

        X = numpy.concatenate(X, 0)
        return X

    @staticmethod
    def word_contexts(sentences):
        # Build an HTML snippet per token: the token in red, with up to four
        # neighbouring tokens of context on either side.
        contexts = []
        for sentence in sentences:
            strs = [x.text for x in sentence.tokens]
            for i, token in enumerate(strs):
                prop = '<b><font color="red"> {token} </font></b>'.format(token=token)
                prop = " ".join(strs[max(i - 4, 0) : i]) + prop
                prop = prop + " ".join(strs[i + 1 : min(len(strs), i + 5)])
                contexts.append("<p>" + prop + "</p>")

        return contexts

    @staticmethod
    def prepare_char_embeddings(embeddings, sentences):
        # Collect the character-level language-model states for every sentence.
        X = []
        for sentence in tqdm.tqdm(sentences):
            sentence = " ".join([x.text for x in sentence])
            hidden = embeddings.lm.get_representation([sentence], "", "")
            X.append(hidden.squeeze().detach().numpy())

        X = numpy.concatenate(X, 0)
        return X

    @staticmethod
    def char_contexts(sentences):
        # Build an HTML snippet per character: the character highlighted in
        # yellow, with up to 30 characters of context on either side.
        contexts = []
        for sentence in sentences:
            sentence = " ".join([token.text for token in sentence])
            for i, char in enumerate(sentence):
                context = '<span style="background-color: yellow"><b>{}</b></span>'.format(
                    char
                )
                context = "".join(sentence[max(i - 30, 0) : i]) + context
                context = context + "".join(
                    sentence[i + 1 : min(len(sentence), i + 30)]
                )
                contexts.append(context)

        return contexts

    @staticmethod
    def visualize(X, contexts, file):
        # Scatter-plot the 2-D points, attach the HTML contexts as hover
        # tooltips, and save the figure as a standalone HTML file via mpld3.
        import matplotlib.pyplot
        import mpld3

        fig, ax = matplotlib.pyplot.subplots()
        ax.grid(True, alpha=0.3)

        points = ax.plot(
            X[:, 0], X[:, 1], "o", color="b", mec="k", ms=5, mew=1, alpha=0.6
        )

        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_title("Hover mouse to reveal context", size=20)

        tooltip = mpld3.plugins.PointHTMLTooltip(
            points[0], contexts, voffset=10, hoffset=10
        )
        mpld3.plugins.connect(fig, tooltip)

        mpld3.save_html(fig, file)
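
A minimal usage sketch (not part of the covered file) showing how the untested Visualizer path might be exercised. It assumes flair's Sentence and WordEmbeddings classes plus an installed scikit-learn/matplotlib/mpld3 stack; the "glove" embedding name, the sample texts, and the output filename are illustrative only:

    from flair.data import Sentence
    from flair.embeddings import WordEmbeddings
    from flair.visual.manifold import Visualizer

    # Repeat a few sentences so the token count exceeds the t-SNE perplexity
    # of 40 hard-coded in manifold.py (recent scikit-learn enforces this).
    corpus = ["The grass is green .", "The sky is blue .", "A dog chased the ball ."]
    sentences = [Sentence(text) for text in corpus * 20]

    embeddings = WordEmbeddings("glove")

    # Writes an interactive 2-D t-SNE plot with per-token hover contexts.
    visualizer = Visualizer()
    visualizer.visualize_word_emeddings(embeddings, sentences, "word_embeddings.html")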