Coverage for flair/flair/models/relation_extractor_model.py: 15%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

149 statements  

1import logging 

2from pathlib import Path 

3from typing import List, Union, Tuple, Optional 

4 

5import torch 

6import torch.nn as nn 

7 

8import flair.embeddings 

9import flair.nn 

10from flair.data import DataPoint, RelationLabel, Span, Sentence 

11from flair.file_utils import cached_path 

12 

# Module-level logger; shares the "flair" logger name with the rest of the library.
log = logging.getLogger(name="flair")

14 

15 

class RelationExtractor(flair.nn.DefaultClassifier):
    """Classifies the relation between ordered pairs of annotated entity spans.

    Every ordered pair of distinct entity spans in a sentence is a candidate. The pair
    is represented by concatenating token embeddings of both spans, passed through
    dropout, and scored by a linear or two-layer non-linear decoder over the relation
    labels plus the special 'O' (no relation) label.
    """

    def __init__(
        self,
        embeddings: flair.embeddings.TokenEmbeddings,
        label_type: str,
        entity_label_type: str,
        train_on_gold_pairs_only: bool = False,
        entity_pair_filters: Optional[List[Tuple[str, str]]] = None,
        pooling_operation: str = "first_last",
        dropout_value: float = 0.0,
        locked_dropout_value: float = 0.1,
        word_dropout_value: float = 0.0,
        non_linear_decoder: Optional[int] = 2048,
        **classifierargs,
    ):
        """
        Initializes a RelationExtractor.

        :param embeddings: token embeddings used to embed each sentence; relation
            representations are built from the embeddings of the entity spans' tokens
        :param label_type: label type under which relation annotations are read and predicted
        :param entity_label_type: label type under which entity spans are annotated
        :param train_on_gold_pairs_only: if True, entity pairs without a gold relation are
            skipped instead of being labeled 'O' (no relation)
        :param entity_pair_filters: optional list of (head entity tag, tail entity tag) tuples;
            if given, only entity pairs whose tags match one of these tuples are classified
        :param pooling_operation: "first_last" concatenates the first and last token embedding
            of each span; any other value uses only the first token of each span
        :param dropout_value: dropout probability applied to relation representations
        :param locked_dropout_value: locked dropout probability applied to relation representations
        :param word_dropout_value: word dropout probability applied to relation representations
        :param non_linear_decoder: if truthy, hidden size of a two-layer ReLU decoder;
            otherwise a single linear decoder is used
        :param classifierargs: further keyword arguments forwarded to flair.nn.DefaultClassifier
            (e.g. label_dictionary, loss_weights)
        """
        super().__init__(**classifierargs)

        # set embeddings
        self.embeddings: flair.embeddings.TokenEmbeddings = embeddings

        # set relation and entity label types
        self._label_type = label_type
        self.entity_label_type = entity_label_type

        # whether to use gold entity pairs, and whether to filter entity pairs by type
        self.train_on_gold_pairs_only = train_on_gold_pairs_only
        # stored as a set for constant-time membership tests in forward_pass
        self.entity_pair_filters = set(entity_pair_filters) if entity_pair_filters is not None else None

        # init dropouts
        self.dropout_value = dropout_value
        self.dropout = torch.nn.Dropout(dropout_value)
        self.locked_dropout_value = locked_dropout_value
        self.locked_dropout = flair.nn.LockedDropout(locked_dropout_value)
        self.word_dropout_value = word_dropout_value
        self.word_dropout = flair.nn.WordDropout(word_dropout_value)

        # pooling operation to get embeddings for entities:
        # two spans per relation, and "first_last" uses two tokens per span instead of one
        self.pooling_operation = pooling_operation
        relation_representation_length = 2 * embeddings.embedding_length
        if self.pooling_operation == "first_last":
            relation_representation_length *= 2
        # NOTE(review): forward_pass always concatenates *token* embeddings, so this size
        # presumably does not match the decoder input for document embeddings — confirm.
        if isinstance(self.embeddings, flair.embeddings.TransformerDocumentEmbeddings):
            relation_representation_length = embeddings.embedding_length

        # entity pairs could also be no relation at all, add default value for this case to dictionary
        self.label_dictionary.add_item("O")

        # decoder can be linear or nonlinear
        self.non_linear_decoder = non_linear_decoder
        if self.non_linear_decoder:
            self.decoder_1 = nn.Linear(relation_representation_length, non_linear_decoder)
            self.nonlinearity = torch.nn.ReLU()
            self.decoder_2 = nn.Linear(non_linear_decoder, len(self.label_dictionary))
            nn.init.xavier_uniform_(self.decoder_1.weight)
            nn.init.xavier_uniform_(self.decoder_2.weight)
        else:
            self.decoder = nn.Linear(relation_representation_length, len(self.label_dictionary))
            nn.init.xavier_uniform_(self.decoder.weight)

        self.to(flair.device)

    def add_entity_markers(self, sentence, span_1, span_2):
        """
        Build a copy of ``sentence`` with entity marker tokens around both spans.

        ``<e1>``/``</e1>`` are inserted around ``span_1`` and ``<e2>``/``</e2>`` around
        ``span_2``; the result is re-tokenized on whitespace.

        :param sentence: the sentence containing both spans
        :param span_1: first entity span
        :param span_2: second entity span
        :return: tuple of the expanded sentence and a pair of single-token spans over the
            opening marker tokens ``<e1>`` and ``<e2>``, ordered by which entity starts first
            in the sentence (``span_2`` wins if both start at the same token)
        :raises ValueError: if either span does not start at a token of ``sentence``
        """
        text = ""

        entity_one_is_first = None
        span_1_startid = None
        span_2_startid = None
        # Running count of tokens (markers included) emitted so far; right after a marker is
        # counted, it equals that marker's 1-based position in the expanded sentence.
        offset = 0
        for token in sentence:
            if token == span_2[0]:
                if entity_one_is_first is None:
                    entity_one_is_first = False
                offset += 1
                text += " <e2>"
                span_2_startid = offset
            if token == span_1[0]:
                offset += 1
                text += " <e1>"
                if entity_one_is_first is None:
                    entity_one_is_first = True
                span_1_startid = offset

            text += " " + token.text

            # closing markers occupy a token position too
            if token == span_1[-1]:
                offset += 1
                text += " </e1>"
            if token == span_2[-1]:
                offset += 1
                text += " </e2>"

            # account for the sentence token itself
            offset += 1

        if span_1_startid is None or span_2_startid is None:
            raise ValueError("Both entity spans must start at a token of the given sentence.")

        expanded_sentence = Sentence(text, use_tokenizer=False)

        expanded_span_1 = Span([expanded_sentence[span_1_startid - 1]])
        expanded_span_2 = Span([expanded_sentence[span_2_startid - 1]])

        if entity_one_is_first:
            return expanded_sentence, (expanded_span_1, expanded_span_2)
        return expanded_sentence, (expanded_span_2, expanded_span_1)

    def forward_pass(
        self,
        sentences: Union[List[DataPoint], DataPoint],
        return_label_candidates: bool = False,
    ):
        """
        Build a representation for every candidate entity pair and score it.

        :param sentences: a single data point or a list of data points whose entity spans
            are annotated under ``entity_label_type``
        :param return_label_candidates: if True, also return the sentence of each candidate
            pair and an empty RelationLabel for it, to be filled in during prediction
        :return: ``(scores, labels)``, where ``scores`` is the decoder output with one row per
            entity pair (None if the batch has no pairs). ``labels`` holds, for each pair, a
            one-element list with its gold relation value or 'O'. If
            ``return_label_candidates`` is True, ``(sentences_to_label,
            empty_label_candidates)`` is appended to the tuple.
        """
        # A single data point (e.g. a Sentence) would otherwise be iterated token by token.
        if isinstance(sentences, DataPoint):
            sentences = [sentences]

        empty_label_candidates = []
        entity_pairs = []
        labels = []
        sentences_to_label = []

        for sentence in sentences:

            # map "head -> tail" position strings to the gold relation labels of this sentence
            relation_dict = {
                create_position_string(relation_label.head, relation_label.tail): relation_label
                for relation_label in sentence.get_labels(self.label_type)
            }

            # get all entity spans
            span_labels = sentence.get_labels(self.entity_label_type)

            # go through the cross product of entities (ordered pairs, no self-pairs)
            for span_label in span_labels:
                span_1 = span_label.span

                for span_label_2 in span_labels:
                    span_2 = span_label_2.span

                    if span_1 == span_2:
                        continue

                    # filter entity pairs according to their tags if set
                    if (
                        self.entity_pair_filters is not None
                        and (span_label.value, span_label_2.value) not in self.entity_pair_filters
                    ):
                        continue

                    position_string = create_position_string(span_1, span_2)

                    # get gold label for this relation (if one exists)
                    if position_string in relation_dict:
                        label = relation_dict[position_string].value
                    # no gold label: skip the pair when training on gold pairs only ...
                    elif self.train_on_gold_pairs_only:
                        continue
                    # ... otherwise label it 'O' (no relation)
                    else:
                        label = "O"

                    entity_pairs.append((span_1, span_2))

                    labels.append([label])

                    # if predicting, also remember sentences and label candidates
                    if return_label_candidates:
                        candidate_label = RelationLabel(head=span_1, tail=span_2, value=None, score=None)
                        empty_label_candidates.append(candidate_label)
                        sentences_to_label.append(span_1[0].sentence)

        # if there's at least one entity pair anywhere in the batch
        if entity_pairs:

            # embed sentences and get embeddings for each entity pair
            self.embeddings.embed(sentences)
            relation_embeddings = []

            for span_1, span_2 in entity_pairs:
                if self.pooling_operation == "first_last":
                    embedding = torch.cat(
                        [
                            span_1.tokens[0].get_embedding(),
                            span_1.tokens[-1].get_embedding(),
                            span_2.tokens[0].get_embedding(),
                            span_2.tokens[-1].get_embedding(),
                        ]
                    )
                else:
                    embedding = torch.cat([span_1.tokens[0].get_embedding(), span_2.tokens[0].get_embedding()])

                relation_embeddings.append(embedding)

            # Stack pairs and add a middle dimension for the dropout modules, then remove it.
            # Presumably LockedDropout/WordDropout expect a 3-D (batch, seq, dim) input — confirm.
            all_relations = torch.stack(relation_embeddings).unsqueeze(1)

            all_relations = self.dropout(all_relations)
            all_relations = self.locked_dropout(all_relations)
            all_relations = self.word_dropout(all_relations)

            all_relations = all_relations.squeeze(1)

            # send through decoder
            if self.non_linear_decoder:
                sentence_relation_scores = self.decoder_2(self.nonlinearity(self.decoder_1(all_relations)))
            else:
                sentence_relation_scores = self.decoder(all_relations)

        else:
            sentence_relation_scores = None

        # return either scores and gold labels (for loss calculation), or include label candidates for prediction
        result_tuple = (sentence_relation_scores, labels)

        if return_label_candidates:
            result_tuple += (sentences_to_label, empty_label_candidates)

        return result_tuple

    def _get_state_dict(self):
        """Collect the parameters and constructor arguments needed to rebuild this model."""
        # NOTE(review): train_on_gold_pairs_only is not saved, so a reloaded model uses the
        # default False. That is what prediction needs: with True, pairs without gold relations
        # are skipped (see forward_pass). Confirm this omission is intended.
        model_state = {
            "state_dict": self.state_dict(),
            "embeddings": self.embeddings,
            "label_dictionary": self.label_dictionary,
            "label_type": self.label_type,
            "entity_label_type": self.entity_label_type,
            "loss_weights": self.loss_weights,
            "pooling_operation": self.pooling_operation,
            "dropout_value": self.dropout_value,
            "locked_dropout_value": self.locked_dropout_value,
            "word_dropout_value": self.word_dropout_value,
            "entity_pair_filters": self.entity_pair_filters,
            "non_linear_decoder": self.non_linear_decoder,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):
        """Rebuild a RelationExtractor from a dict produced by ``_get_state_dict``."""
        model = RelationExtractor(
            embeddings=state["embeddings"],
            label_dictionary=state["label_dictionary"],
            label_type=state["label_type"],
            entity_label_type=state["entity_label_type"],
            loss_weights=state["loss_weights"],
            pooling_operation=state["pooling_operation"],
            dropout_value=state["dropout_value"],
            locked_dropout_value=state["locked_dropout_value"],
            word_dropout_value=state["word_dropout_value"],
            entity_pair_filters=state["entity_pair_filters"],
            non_linear_decoder=state["non_linear_decoder"],
        )
        model.load_state_dict(state["state_dict"])
        return model

    @property
    def label_type(self):
        """Label type under which relations are read and predicted."""
        return self._label_type

    @staticmethod
    def _fetch_model(model_name) -> str:
        """
        Resolve a pretrained model name to a locally cached file.

        Known names ("relations-fast", "relations") are downloaded via ``cached_path`` into
        the "models" cache directory; any other value is returned unchanged.
        """
        hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models"

        model_map = {
            "relations-fast": "/".join([hu_path, "relations-fast", "relations-fast.pt"]),
            "relations": "/".join([hu_path, "relations", "relations.pt"]),
        }

        cache_dir = Path("models")
        if model_name in model_map:
            model_name = cached_path(model_map[model_name], cache_dir=cache_dir)

        return model_name

288 

289 

def create_position_string(head: Span, tail: Span) -> str:
    """Return a key for the ordered span pair, formatted as "<head id_text> -> <tail id_text>"."""
    return "{} -> {}".format(head.id_text, tail.id_text)