Coverage for flair/flair/models/relation_extractor_model.py: 13%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

136 statements  

1import logging 

2from typing import List, Union, Tuple, Optional 

3 

4import torch 

5import torch.nn as nn 

6 

7import flair.embeddings 

8import flair.nn 

9from flair.data import DataPoint, RelationLabel, Span, Sentence 

10 

# Logger registered under the package-wide "flair" name (rather than this module's __name__).
log = logging.getLogger("flair")

12 

13 

class RelationExtractor(flair.nn.DefaultClassifier):
    """Classifies the relation between ordered pairs of entity spans in a sentence.

    For every sentence, all ordered pairs of distinct spans annotated with
    ``entity_label_type`` are considered. Each pair is represented by concatenating
    token embeddings of its two spans, and that vector is passed through a linear
    or two-layer decoder. Pairs without a gold relation are labeled ``'O'``
    (no relation).
    """

    def __init__(
        self,
        embeddings: flair.embeddings.TokenEmbeddings,
        label_type: str,
        entity_label_type: str,
        train_on_gold_pairs_only: bool = False,
        entity_pair_filters: Optional[List[Tuple[str, str]]] = None,
        pooling_operation: str = "first_last",
        dropout_value: float = 0.0,
        locked_dropout_value: float = 0.0,
        word_dropout_value: float = 0.0,
        non_linear_decoder: Optional[int] = None,
        **classifierargs,
    ):
        """
        Initializes a RelationExtractor.

        :param embeddings: token embeddings used to embed the sentences; each entity pair
            is represented by embeddings of tokens of its two spans
        :param label_type: label type under which relations are annotated and predicted
        :param entity_label_type: label type of the entity spans that form candidate pairs
        :param train_on_gold_pairs_only: if True, entity pairs without a gold relation are
            skipped while the model is in training mode
        :param entity_pair_filters: optional list of allowed (head tag, tail tag) pairs; pairs
            whose entity tags are not listed are ignored
        :param pooling_operation: 'first_last' concatenates the first and last token embedding
            of each span; any other value uses only the first token of each span
        :param dropout_value: rate of the standard dropout applied to pair representations
        :param locked_dropout_value: rate of flair's LockedDropout
        :param word_dropout_value: rate of flair's WordDropout
        :param non_linear_decoder: if set (non-zero), hidden size of a two-layer ReLU decoder;
            otherwise a single linear decoder is used
        :param classifierargs: forwarded to DefaultClassifier (e.g. label_dictionary, loss_weights)
        """
        super().__init__(**classifierargs)

        # set embeddings
        self.embeddings: flair.embeddings.TokenEmbeddings = embeddings

        # set relation and entity label types
        self._label_type = label_type
        self.entity_label_type = entity_label_type

        # whether to use gold entity pairs, and whether to filter entity pairs by type
        self.train_on_gold_pairs_only = train_on_gold_pairs_only
        self.entity_pair_filters = set(entity_pair_filters) if entity_pair_filters is not None else None

        # init dropouts
        self.dropout_value = dropout_value
        self.dropout = torch.nn.Dropout(dropout_value)
        self.locked_dropout_value = locked_dropout_value
        self.locked_dropout = flair.nn.LockedDropout(locked_dropout_value)
        self.word_dropout_value = word_dropout_value
        self.word_dropout = flair.nn.WordDropout(word_dropout_value)

        # pooling operation to get embeddings for entities: two spans, one or two tokens each
        self.pooling_operation = pooling_operation
        relation_representation_length = 2 * embeddings.embedding_length
        if self.pooling_operation == "first_last":
            relation_representation_length *= 2
        # NOTE(review): forward_pass always concatenates *token* embeddings, so this
        # document-embedding sizing looks inconsistent with it — confirm intent.
        if isinstance(self.embeddings, flair.embeddings.TransformerDocumentEmbeddings):
            relation_representation_length = embeddings.embedding_length

        # entity pairs could also be no relation at all, add default value for this case to dictionary
        self.label_dictionary.add_item("O")

        # decoder can be linear or nonlinear
        self.non_linear_decoder = non_linear_decoder
        if self.non_linear_decoder:
            self.decoder_1 = nn.Linear(relation_representation_length, non_linear_decoder)
            self.nonlinearity = torch.nn.ReLU()
            self.decoder_2 = nn.Linear(non_linear_decoder, len(self.label_dictionary))
            nn.init.xavier_uniform_(self.decoder_1.weight)
            nn.init.xavier_uniform_(self.decoder_2.weight)
        else:
            self.decoder = nn.Linear(relation_representation_length, len(self.label_dictionary))
            nn.init.xavier_uniform_(self.decoder.weight)

        self.to(flair.device)

    def add_entity_markers(self, sentence, span_1, span_2):
        """Build a whitespace-tokenized copy of ``sentence`` with entity marker tokens inserted.

        ``<e1>``/``</e1>`` wrap ``span_1`` and ``<e2>``/``</e2>`` wrap ``span_2``. When both
        spans start at the same token, ``<e2>`` is placed before ``<e1>``.

        :param sentence: the sentence containing both spans
        :param span_1: first entity span
        :param span_2: second entity span
        :return: ``(expanded_sentence, (first_marker, second_marker))``. Each marker is a
            single-token Span over an ``<e1>`` or ``<e2>`` start-marker token of the expanded
            sentence. The two are ordered by position in the text, not by argument order.
        :raises ValueError: if the start token of either span is not found in ``sentence``
        """
        words = []
        span_1_start = None
        span_2_start = None

        for token in sentence:
            # the index of a start marker in the expanded sentence is the number of words
            # emitted before it
            if token == span_2[0]:
                span_2_start = len(words)
                words.append("<e2>")
            if token == span_1[0]:
                span_1_start = len(words)
                words.append("<e1>")

            words.append(token.text)

            if token == span_1[-1]:
                words.append("</e1>")
            if token == span_2[-1]:
                words.append("</e2>")

        if span_1_start is None or span_2_start is None:
            raise ValueError("Both entity spans must start at a token of the given sentence")

        # leading space per word is kept to reproduce the original text layout
        text = "".join(" " + word for word in words)
        expanded_sentence = Sentence(text, use_tokenizer=False)

        expanded_span_1 = Span([expanded_sentence[span_1_start]])
        expanded_span_2 = Span([expanded_sentence[span_2_start]])

        if span_1_start < span_2_start:
            marker_spans = (expanded_span_1, expanded_span_2)
        else:
            marker_spans = (expanded_span_2, expanded_span_1)
        return expanded_sentence, marker_spans

    def _embed_entity_pair(self, span_1, span_2):
        """Concatenate token embeddings of a (head, tail) span pair according to ``pooling_operation``."""
        if self.pooling_operation == "first_last":
            tokens = [span_1.tokens[0], span_1.tokens[-1], span_2.tokens[0], span_2.tokens[-1]]
        else:
            tokens = [span_1.tokens[0], span_2.tokens[0]]
        return torch.cat([token.get_embedding() for token in tokens])

    def _score_entity_pairs(self, sentences, entity_pairs):
        """Embed ``sentences`` and return decoder scores with one row per (head, tail) pair."""
        self.embeddings.embed(sentences)

        # stack per-pair vectors into one tensor (presumably shaped (num_pairs, dim))
        all_relations = torch.stack([self._embed_entity_pair(span_1, span_2) for span_1, span_2 in entity_pairs])

        # NOTE(review): LockedDropout/WordDropout are typically applied to (batch, seq, dim)
        # inputs; confirm they behave as intended on this stacked tensor when their rate > 0.
        all_relations = self.dropout(all_relations)
        all_relations = self.locked_dropout(all_relations)
        all_relations = self.word_dropout(all_relations)

        # send through decoder
        if self.non_linear_decoder:
            return self.decoder_2(self.nonlinearity(self.decoder_1(all_relations)))
        return self.decoder(all_relations)

    def forward_pass(
        self,
        sentences: Union[List[DataPoint], DataPoint],
        return_label_candidates: bool = False,
    ):
        """Build candidate entity pairs for all sentences and score them.

        :param sentences: a sentence or a list of sentences
        :param return_label_candidates: if True, also return the sentence of each pair and an
            empty RelationLabel candidate per pair, for prediction
        :return: ``(scores, labels)``, or ``(scores, labels, sentences_to_label, label_candidates)``
            when ``return_label_candidates`` is set. ``scores`` is None if no pair was found in
            the batch. ``labels`` holds one single-element list per pair: the gold relation
            or ``'O'``.
        """
        # a single data point is accepted too; iterating it directly would walk its tokens
        if isinstance(sentences, DataPoint):
            sentences = [sentences]

        empty_label_candidates = []
        entity_pairs = []
        labels = []
        sentences_to_label = []

        for sentence in sentences:

            # index gold relations by "head -> tail" so each candidate pair is a single lookup
            relation_dict = {
                create_position_string(relation_label.head, relation_label.tail): relation_label
                for relation_label in sentence.get_labels(self.label_type)
            }

            # get all entity spans
            span_labels = sentence.get_labels(self.entity_label_type)

            # go through the cross product of entities (ordered pairs, excluding self-pairs)
            for span_label in span_labels:
                span_1 = span_label.span

                for span_label_2 in span_labels:
                    span_2 = span_label_2.span

                    if span_1 == span_2:
                        continue

                    # filter entity pairs according to their tags if set
                    if (
                        self.entity_pair_filters is not None
                        and (span_label.value, span_label_2.value) not in self.entity_pair_filters
                    ):
                        continue

                    # get gold label for this relation (if one exists)
                    gold_relation = relation_dict.get(create_position_string(span_1, span_2))
                    if gold_relation is not None:
                        label = gold_relation.value
                    elif self.train_on_gold_pairs_only and self.training:
                        # skip non-gold pairs only while training; at evaluation/prediction
                        # time every candidate pair must be scored
                        continue
                    else:
                        # no gold label for this entity pair: 'O' (no relation)
                        label = "O"

                    entity_pairs.append((span_1, span_2))
                    labels.append([label])

                    # if predicting, also remember sentences and label candidates
                    if return_label_candidates:
                        candidate_label = RelationLabel(head=span_1, tail=span_2, value=None, score=None)
                        empty_label_candidates.append(candidate_label)
                        sentences_to_label.append(span_1[0].sentence)

        # only embed and decode if there is at least one entity pair in the batch
        sentence_relation_scores = self._score_entity_pairs(sentences, entity_pairs) if entity_pairs else None

        # return either scores and gold labels (for loss calculation), or include label candidates for prediction
        result_tuple = (sentence_relation_scores, labels)

        if return_label_candidates:
            result_tuple += (sentences_to_label, empty_label_candidates)

        return result_tuple

    def _get_state_dict(self):
        """Collect the weights and constructor arguments needed to rebuild this model."""
        model_state = {
            "state_dict": self.state_dict(),
            "embeddings": self.embeddings,
            "label_dictionary": self.label_dictionary,
            "label_type": self.label_type,
            "entity_label_type": self.entity_label_type,
            "loss_weights": self.loss_weights,
            "pooling_operation": self.pooling_operation,
            "dropout_value": self.dropout_value,
            "locked_dropout_value": self.locked_dropout_value,
            "word_dropout_value": self.word_dropout_value,
            "entity_pair_filters": self.entity_pair_filters,
            "non_linear_decoder": self.non_linear_decoder,
            "train_on_gold_pairs_only": self.train_on_gold_pairs_only,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):
        """Rebuild a RelationExtractor from the dict produced by ``_get_state_dict``."""
        model = RelationExtractor(
            embeddings=state["embeddings"],
            label_dictionary=state["label_dictionary"],
            label_type=state["label_type"],
            entity_label_type=state["entity_label_type"],
            loss_weights=state["loss_weights"],
            pooling_operation=state["pooling_operation"],
            dropout_value=state["dropout_value"],
            locked_dropout_value=state["locked_dropout_value"],
            word_dropout_value=state["word_dropout_value"],
            entity_pair_filters=state["entity_pair_filters"],
            non_linear_decoder=state["non_linear_decoder"],
            # older saved models do not contain this key; fall back to the constructor default
            train_on_gold_pairs_only=state.get("train_on_gold_pairs_only", False),
        )
        model.load_state_dict(state["state_dict"])
        return model

    @property
    def label_type(self):
        """Label type under which relations are annotated and predicted."""
        return self._label_type

268 

269 

def create_position_string(head: Span, tail: Span) -> str:
    """Return the key ``"<head id_text> -> <tail id_text>"`` for an ordered span pair.

    Gold relation annotations and candidate entity pairs are both keyed this way,
    so the two can be matched with a dictionary lookup.
    """
    return "{} -> {}".format(head.id_text, tail.id_text)