Coverage for /home/ubuntu/Documents/Research/mut_p1/flair/flair/models/entity_linker_model.py: 27%

75 statements  

import logging
from typing import List, Union

import torch
import torch.nn as nn

import flair.embeddings
import flair.nn
from flair.data import DataPoint, Dictionary, SpanLabel

log = logging.getLogger("flair")


class EntityLinker(flair.nn.DefaultClassifier):
    """
    Entity Linking Model.
    The model expects text/sentences with annotated entity mentions and predicts entities for these mentions.
    To this end, a word embedding is used to embed each sentence, and the embedding of an entity mention is passed through a linear layer to obtain the class label.
    The model can predict '<unk>' for entity mentions that it cannot confidently match to any of the known labels.
    """

    def __init__(
        self,
        word_embeddings: flair.embeddings.TokenEmbeddings,
        label_dictionary: Dictionary,
        pooling_operation: str = 'average',
        label_type: str = 'nel',
        **classifierargs,
    ):
        """
        Initializes an EntityLinker.
        :param word_embeddings: embeddings used to embed the words/sentences
        :param label_dictionary: dictionary that assigns ids to all classes; should contain <unk>
        :param pooling_operation: either 'average', 'first', 'last' or 'first&last'. Specifies how the text representation
            of an entity mention spanning more than one word is built. E.g. 'average' takes the average of the embeddings
            of the words in the mention as the text representation, while 'first&last' concatenates the embedding of the
            first word with the embedding of the last word.
        :param label_type: name of the label you use.
        """

        super(EntityLinker, self).__init__(label_dictionary, **classifierargs)

        self.word_embeddings = word_embeddings
        self.pooling_operation = pooling_operation
        self._label_type = label_type

        # if we concatenate the first and last embedding, the linear layer needs twice the input size
        if self.pooling_operation == 'first&last':
            self.decoder = nn.Linear(
                2 * self.word_embeddings.embedding_length, len(self.label_dictionary)
            ).to(flair.device)
        else:
            self.decoder = nn.Linear(
                self.word_embeddings.embedding_length, len(self.label_dictionary)
            ).to(flair.device)

        nn.init.xavier_uniform_(self.decoder.weight)

        cases = {
            'average': self.emb_mean,
            'first': self.emb_first,
            'last': self.emb_last,
            'first&last': self.emb_firstAndLast,
        }
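        # the chosen pooling name is bound to one of the emb_* helpers below, so forward_pass
        # can simply call self.aggregated_embedding without re-branching on the string each time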

        if pooling_operation not in cases:
            raise KeyError('pooling_operation has to be one of "average", "first", "last" or "first&last"')

        self.aggregated_embedding = cases.get(pooling_operation)

        self.to(flair.device)

    def emb_first(self, arg):
        return arg[0]

    def emb_last(self, arg):
        return arg[-1]

    def emb_firstAndLast(self, arg):
        return torch.cat((arg[0], arg[-1]), 0)

    def emb_mean(self, arg):
        return torch.mean(arg, 0)
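    # shape sketch (illustrative numbers): for a 3-token mention with embedding_length 768, the
    # stacked mention tensor passed to these helpers is (3, 768); 'average', 'first' and 'last'
    # pool it to (768,), while 'first&last' concatenates first and last rows to (1536,)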

    def forward_pass(self,
                     sentences: Union[List[DataPoint], DataPoint],
                     return_label_candidates: bool = False,
                     ):

        if isinstance(sentences, DataPoint):
            sentences = [sentences]

        # filter sentences with no candidates (no candidates means nothing can be linked anyway)
        filtered_sentences = []
        for sentence in sentences:
            if sentence.get_labels(self.label_type):
                filtered_sentences.append(sentence)

        # fields to return
        span_labels = []
        sentences_to_spans = []
        empty_label_candidates = []

        # if the entire batch has no sentence with candidates, return empty
        if len(filtered_sentences) == 0:
            scores = None

        # otherwise, embed the sentences and send them through the prediction head
        else:
            # embed all tokens
            self.word_embeddings.embed(filtered_sentences)

            embedding_names = self.word_embeddings.get_names()

            embedding_list = []
            # get the embeddings of the entity mentions
            for sentence in filtered_sentences:
                spans = sentence.get_spans(self.label_type)

                for span in spans:
                    # start from an empty (0, embedding_length) tensor and stack the mention's token embeddings
                    mention_emb = torch.Tensor(0, self.word_embeddings.embedding_length).to(flair.device)

                    for token in span.tokens:
                        mention_emb = torch.cat((mention_emb, token.get_embedding(embedding_names).unsqueeze(0)), 0)

                    embedding_list.append(self.aggregated_embedding(mention_emb).unsqueeze(0))

                    span_labels.append([label.value for label in span.get_labels(typename=self.label_type)])

                    if return_label_candidates:
                        sentences_to_spans.append(sentence)
                        candidate = SpanLabel(span=span, value=None, score=None)
                        empty_label_candidates.append(candidate)

            embedding_tensor = torch.cat(embedding_list, 0).to(flair.device)
            scores = self.decoder(embedding_tensor)

        # minimal return is scores and labels
        return_tuple = (scores, span_labels)

        if return_label_candidates:
            return_tuple += (sentences_to_spans, empty_label_candidates)

        return return_tuple
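    # return contract: (scores, span_labels), where scores is a (num_spans, num_labels) tensor or
    # None for a batch without candidates; with return_label_candidates=True the tuple additionally
    # carries each span's source sentence and an empty SpanLabel placeholder per span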

    def _get_state_dict(self):
        model_state = {
            "state_dict": self.state_dict(),
            "word_embeddings": self.word_embeddings,
            "label_type": self.label_type,
            "label_dictionary": self.label_dictionary,
            "pooling_operation": self.pooling_operation,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):
        model = EntityLinker(
            word_embeddings=state["word_embeddings"],
            label_dictionary=state["label_dictionary"],
            label_type=state["label_type"],
            pooling_operation=state["pooling_operation"],
        )

        model.load_state_dict(state["state_dict"])
        return model
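    # together these two hooks give the save/load round-trip: the constructor arguments are stored
    # alongside the weights, so the model can be rebuilt first and the weights loaded afterwards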

    @property
    def label_type(self):
        return self._label_type
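
For orientation, a minimal usage sketch of the class above (not part of the covered file). It is a sketch, not a definitive recipe: it assumes a flair release contemporary with this code (the SpanLabel era), network access to fetch the transformer weights, and mention spans annotated under the 'nel' label type; the label names are invented for illustration and the exact span-annotation call differs between flair versions.

from flair.data import Dictionary, Sentence
from flair.embeddings import TransformerWordEmbeddings
from flair.models.entity_linker_model import EntityLinker

# toy label dictionary; <unk> is used for mentions the model cannot match
label_dictionary = Dictionary(add_unk=True)
label_dictionary.add_item('Berlin_(city)')
label_dictionary.add_item('Berlin_(band)')

linker = EntityLinker(
    word_embeddings=TransformerWordEmbeddings('distilbert-base-uncased'),
    label_dictionary=label_dictionary,
    pooling_operation='first&last',  # doubles the decoder input size, see __init__
)

# a sentence whose mention span carries a 'nel' annotation; the span-annotation
# API shown here is an assumption and varies across flair versions
sentence = Sentence('Berlin released a new album .')
sentence[0:1].add_label('nel', 'Berlin_(band)')

# predict() is inherited from flair.nn.DefaultClassifier and overwrites the
# span labels with the model's predictions
linker.predict(sentence)
print(sentence.get_labels('nel'))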