Coverage for flair/flair/visual/manifold.py: 27%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

78 statements  

1import numpy 

2import tqdm 

3from sklearn.manifold import TSNE 

4 

5 

6class _Transform: 

7 def __init__(self): 

8 pass 

9 

10 def fit(self, X): 

11 return self.transform.fit_transform(X) 

12 

13 

class tSNE(_Transform):
    """2-D t-SNE projection with fixed hyper-parameters."""

    def __init__(self):
        super().__init__()

        # verbose=1 prints fitting progress to stdout; perplexity=40 assumes a
        # reasonably large sample count (TSNE raises if perplexity >= n_samples).
        # NOTE(review): ``n_iter`` was renamed ``max_iter`` in scikit-learn 1.5
        # and removed in 1.7 — confirm the pinned sklearn version before upgrading.
        self.transform = TSNE(
            n_components=2,
            verbose=1,
            perplexity=40,
            n_iter=300,
        )

19 

20 

class Visualizer(object):
    """Projects word or character embeddings to 2-D with t-SNE and renders an
    interactive HTML scatter plot where hovering a point reveals its textual
    context (via matplotlib + mpld3)."""

    def visualize_word_emeddings(self, embeddings, sentences, output_file):
        """Embed ``sentences`` with ``embeddings``, reduce each token embedding
        to 2-D with t-SNE, and write an interactive plot to ``output_file``.

        NOTE: the method name keeps its historical 'emeddings' typo so existing
        callers keep working; a correctly spelled alias is defined on the class.
        """
        X = self.prepare_word_embeddings(embeddings, sentences)
        contexts = self.word_contexts(sentences)

        trans_ = tSNE()
        reduced = trans_.fit(X)

        self.visualize(reduced, contexts, output_file)

    def visualize_char_emeddings(self, embeddings, sentences, output_file):
        """Like ``visualize_word_emeddings`` but at the character level, using
        the embeddings' underlying language model representation.

        NOTE: name keeps its historical typo; see the correctly spelled alias.
        """
        X = self.prepare_char_embeddings(embeddings, sentences)
        contexts = self.char_contexts(sentences)

        trans_ = tSNE()
        reduced = trans_.fit(X)

        self.visualize(reduced, contexts, output_file)

    @staticmethod
    def prepare_word_embeddings(embeddings, sentences):
        """Embed every sentence and stack all token embeddings into a single
        (n_tokens, dim) numpy array."""
        X = []

        for sentence in tqdm.tqdm(sentences):
            embeddings.embed(sentence)

            # ``[None, :]`` adds a leading axis so the vectors concatenate
            # row-wise below. (The loop index from the original enumerate was
            # unused and has been dropped.)
            for token in sentence:
                X.append(token.embedding.detach().numpy()[None, :])

        return numpy.concatenate(X, 0)

    @staticmethod
    def word_contexts(sentences):
        """Return one HTML snippet per token: the token highlighted in red with
        up to 4 words of context on each side."""
        contexts = []

        for sentence in sentences:

            strs = [x.text for x in sentence.tokens]

            for i, token in enumerate(strs):
                prop = '<b><font color="red"> {token} </font></b>'.format(token=token)

                # Up to 4 words before and 4 after; slices clamp automatically.
                prop = " ".join(strs[max(i - 4, 0) : i]) + prop
                prop = prop + " ".join(strs[i + 1 : i + 5])

                contexts.append("<p>" + prop + "</p>")

        return contexts

    @staticmethod
    def prepare_char_embeddings(embeddings, sentences):
        """Run the embeddings' language model over each whitespace-joined
        sentence and stack the per-character hidden states into one array."""
        X = []

        for sentence in tqdm.tqdm(sentences):
            sentence = " ".join([x.text for x in sentence])

            hidden = embeddings.lm.get_representation([sentence], "", "")
            X.append(hidden.squeeze().detach().numpy())

        return numpy.concatenate(X, 0)

    @staticmethod
    def char_contexts(sentences):
        """Return one HTML snippet per character: the character highlighted in
        yellow with up to 30 characters of context on each side."""
        contexts = []

        for sentence in sentences:
            sentence = " ".join([token.text for token in sentence])

            for i, char in enumerate(sentence):
                context = '<span style="background-color: yellow"><b>{}</b></span>'.format(
                    char
                )
                # Symmetric 30-character window; ``i + 31`` (not ``i + 30``)
                # fixes an off-by-one that gave only 29 characters on the
                # right. String slices clamp, so no explicit bounds checks
                # (or the previous no-op ``"".join``) are needed.
                context = sentence[max(i - 30, 0) : i] + context
                context = context + sentence[i + 1 : i + 31]

                contexts.append(context)

        return contexts

    @staticmethod
    def visualize(X, contexts, file):
        """Scatter-plot the 2-D points ``X`` and save an interactive HTML plot
        to ``file``; hovering point ``k`` shows ``contexts[k]``."""
        # Imported lazily so the (heavy, display-related) dependencies are only
        # required when actually rendering.
        import matplotlib.pyplot
        import mpld3

        fig, ax = matplotlib.pyplot.subplots()

        ax.grid(True, alpha=0.3)

        points = ax.plot(
            X[:, 0], X[:, 1], "o", color="b", mec="k", ms=5, mew=1, alpha=0.6
        )

        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_title("Hover mouse to reveal context", size=20)

        tooltip = mpld3.plugins.PointHTMLTooltip(
            points[0], contexts, voffset=10, hoffset=10
        )

        mpld3.plugins.connect(fig, tooltip)

        mpld3.save_html(fig, file)

    # Correctly spelled aliases; the originals keep their historical typo so
    # existing callers continue to work.
    visualize_word_embeddings = visualize_word_emeddings
    visualize_char_embeddings = visualize_char_emeddings