Coverage for flair/flair/models/relation_extractor_model.py: 15%
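"""Relation extraction model for flair: classifies the relation between pairs
of entity spans by pooling their token embeddings and scoring the concatenated
pair representation with a linear or nonlinear decoder."""
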
import logging
from pathlib import Path
from typing import List, Union, Tuple, Optional

import torch
import torch.nn as nn

import flair.embeddings
import flair.nn
from flair.data import DataPoint, RelationLabel, Span, Sentence
from flair.file_utils import cached_path

log = logging.getLogger("flair")


class RelationExtractor(flair.nn.DefaultClassifier):

    def __init__(
            self,
            embeddings: flair.embeddings.TokenEmbeddings,
            label_type: str,
            entity_label_type: str,
            train_on_gold_pairs_only: bool = False,
            entity_pair_filters: Optional[List[Tuple[str, str]]] = None,
            pooling_operation: str = "first_last",
            dropout_value: float = 0.0,
            locked_dropout_value: float = 0.1,
            word_dropout_value: float = 0.0,
            non_linear_decoder: Optional[int] = 2048,
            **classifierargs,
    ):
32 """
33 Initializes a RelationClassifier
34 :param document_embeddings: embeddings used to embed each data point
35 :param label_dictionary: dictionary of labels you want to predict
36 :param beta: Parameter for F-beta score for evaluation and training annealing
37 :param loss_weights: Dictionary of weights for labels for the loss function
38 (if any label's weight is unspecified it will default to 1.0)
39 """
        super(RelationExtractor, self).__init__(**classifierargs)

        # set embeddings
        self.embeddings: flair.embeddings.TokenEmbeddings = embeddings

        # set relation and entity label types
        self._label_type = label_type
        self.entity_label_type = entity_label_type

        # whether to use gold entity pairs, and whether to filter entity pairs by type
        self.train_on_gold_pairs_only = train_on_gold_pairs_only
        if entity_pair_filters is not None:
            self.entity_pair_filters = set(entity_pair_filters)
        else:
            self.entity_pair_filters = None

        # init dropouts
        self.dropout_value = dropout_value
        self.dropout = torch.nn.Dropout(dropout_value)
        self.locked_dropout_value = locked_dropout_value
        self.locked_dropout = flair.nn.LockedDropout(locked_dropout_value)
        self.word_dropout_value = word_dropout_value
        self.word_dropout = flair.nn.WordDropout(word_dropout_value)

        # pooling operation to get embeddings for entities
        self.pooling_operation = pooling_operation
        relation_representation_length = 2 * embeddings.embedding_length
        if self.pooling_operation == 'first_last':
            relation_representation_length *= 2
        if type(self.embeddings) == flair.embeddings.TransformerDocumentEmbeddings:
            relation_representation_length = embeddings.embedding_length

        # entity pairs could also be no relation at all, add default value for this case to dictionary
        self.label_dictionary.add_item('O')

        # decoder can be linear or nonlinear
        self.non_linear_decoder = non_linear_decoder
        if self.non_linear_decoder:
            self.decoder_1 = nn.Linear(relation_representation_length, non_linear_decoder)
            self.nonlinearity = torch.nn.ReLU()
            self.decoder_2 = nn.Linear(non_linear_decoder, len(self.label_dictionary))
            nn.init.xavier_uniform_(self.decoder_1.weight)
            nn.init.xavier_uniform_(self.decoder_2.weight)
        else:
            self.decoder = nn.Linear(relation_representation_length, len(self.label_dictionary))
            nn.init.xavier_uniform_(self.decoder.weight)

        self.to(flair.device)

    def add_entity_markers(self, sentence, span_1, span_2):
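        """
        Rebuilds the sentence text with <e1>...</e1> and <e2>...</e2> markers
        inserted around the two entity spans and re-tokenizes it. Returns the
        expanded sentence plus the two marker spans (pointing at the <e1> and
        <e2> tokens), ordered by which entity occurs first in the text.
        """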

        text = ""

        entity_one_is_first = None
        offset = 0
        for token in sentence:
            if token == span_2[0]:
                if entity_one_is_first is None:
                    entity_one_is_first = False
                offset += 1
                text += " <e2>"
                span_2_startid = offset
            if token == span_1[0]:
                offset += 1
                text += " <e1>"
                if entity_one_is_first is None:
                    entity_one_is_first = True
                span_1_startid = offset

            text += " " + token.text

            if token == span_1[-1]:
                offset += 1
                text += " </e1>"
                span_1_stopid = offset
            if token == span_2[-1]:
                offset += 1
                text += " </e2>"
                span_2_stopid = offset

            offset += 1

        expanded_sentence = Sentence(text, use_tokenizer=False)

        expanded_span_1 = Span([expanded_sentence[span_1_startid - 1]])
        expanded_span_2 = Span([expanded_sentence[span_2_startid - 1]])

        return expanded_sentence, (
            (expanded_span_1, expanded_span_2) if entity_one_is_first
            else (expanded_span_2, expanded_span_1)
        )

    def forward_pass(self,
                     sentences: Union[List[DataPoint], DataPoint],
                     return_label_candidates: bool = False,
                     ):
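        """
        Builds one candidate per ordered pair of distinct entity spans in each
        sentence, labels it with the gold relation where one is annotated
        (otherwise 'O', or skipped entirely if train_on_gold_pairs_only is set),
        then embeds the batch, pools the span embeddings into pair
        representations and scores them with the decoder. Returns a tuple of
        (scores, gold labels), extended by the sentences and empty label
        candidates when return_label_candidates is True.
        """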

        empty_label_candidates = []
        entity_pairs = []
        labels = []
        sentences_to_label = []

        for sentence in sentences:

            # build a dictionary to look up the relation annotation for a given entity pair
            relation_dict = {}
            for relation_label in sentence.get_labels(self.label_type):
                relation_label: RelationLabel = relation_label
                relation_dict[create_position_string(relation_label.head, relation_label.tail)] = relation_label

            # get all entity spans
            span_labels = sentence.get_labels(self.entity_label_type)

            # go through cross product of entities, for each pair concat embeddings
            for span_label in span_labels:
                span_1 = span_label.span

                for span_label_2 in span_labels:
                    span_2 = span_label_2.span

                    if span_1 == span_2:
                        continue

                    # filter entity pairs according to their tags if set
                    if (self.entity_pair_filters is not None
                            and (span_label.value, span_label_2.value) not in self.entity_pair_filters):
                        continue

                    position_string = create_position_string(span_1, span_2)

                    # get gold label for this relation (if one exists)
                    if position_string in relation_dict:
                        relation_label: RelationLabel = relation_dict[position_string]
                        label = relation_label.value

                    # if there is no gold label for this entity pair, set to 'O' (no relation)
                    else:
                        # skip unannotated pairs if training on gold pairs only
                        if self.train_on_gold_pairs_only:
                            continue
                        label = 'O'

                    entity_pairs.append((span_1, span_2))

                    labels.append([label])

                    # if predicting, also remember sentences and label candidates
                    if return_label_candidates:
                        candidate_label = RelationLabel(head=span_1, tail=span_2, value=None, score=None)
                        empty_label_candidates.append(candidate_label)
                        sentences_to_label.append(span_1[0].sentence)

        # if there's at least one entity pair in the batch
        if len(entity_pairs) > 0:

            # embed sentences and get embeddings for each entity pair
            self.embeddings.embed(sentences)
            relation_embeddings = []

            # get embeddings
            for entity_pair in entity_pairs:
                span_1 = entity_pair[0]
                span_2 = entity_pair[1]

                if self.pooling_operation == "first_last":
                    embedding = torch.cat([span_1.tokens[0].get_embedding(),
                                           span_1.tokens[-1].get_embedding(),
                                           span_2.tokens[0].get_embedding(),
                                           span_2.tokens[-1].get_embedding()])
                else:
                    embedding = torch.cat([span_1.tokens[0].get_embedding(), span_2.tokens[0].get_embedding()])

                relation_embeddings.append(embedding)

            # stack and drop out (squeeze and unsqueeze)
            all_relations = torch.stack(relation_embeddings).unsqueeze(1)

            all_relations = self.dropout(all_relations)
            all_relations = self.locked_dropout(all_relations)
            all_relations = self.word_dropout(all_relations)

            all_relations = all_relations.squeeze(1)

            # send through decoder
            if self.non_linear_decoder:
                sentence_relation_scores = self.decoder_2(self.nonlinearity(self.decoder_1(all_relations)))
            else:
                sentence_relation_scores = self.decoder(all_relations)

        else:
            sentence_relation_scores = None

        # return either scores and gold labels (for loss calculation), or include label candidates for prediction
        result_tuple = (sentence_relation_scores, labels)

        if return_label_candidates:
            result_tuple += (sentences_to_label, empty_label_candidates)

        return result_tuple

    def _get_state_dict(self):
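        """Collects everything needed to re-create the model in _init_model_with_state_dict."""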
        model_state = {
            "state_dict": self.state_dict(),
            "embeddings": self.embeddings,
            "label_dictionary": self.label_dictionary,
            "label_type": self.label_type,
            "entity_label_type": self.entity_label_type,
            "loss_weights": self.loss_weights,
            "pooling_operation": self.pooling_operation,
            "dropout_value": self.dropout_value,
            "locked_dropout_value": self.locked_dropout_value,
            "word_dropout_value": self.word_dropout_value,
            "entity_pair_filters": self.entity_pair_filters,
            "non_linear_decoder": self.non_linear_decoder,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):
        model = RelationExtractor(
            embeddings=state["embeddings"],
            label_dictionary=state["label_dictionary"],
            label_type=state["label_type"],
            entity_label_type=state["entity_label_type"],
            loss_weights=state["loss_weights"],
            pooling_operation=state["pooling_operation"],
            dropout_value=state["dropout_value"],
            locked_dropout_value=state["locked_dropout_value"],
            word_dropout_value=state["word_dropout_value"],
            entity_pair_filters=state["entity_pair_filters"],
            non_linear_decoder=state["non_linear_decoder"],
        )
        model.load_state_dict(state["state_dict"])
        return model

    @property
    def label_type(self):
        return self._label_type

    @staticmethod
    def _fetch_model(model_name) -> str:
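        """Resolves a shorthand model name to its download URL and returns the path of the cached file."""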

        model_map = {}

        hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models"

        model_map["relations-fast"] = "/".join([hu_path, "relations-fast", "relations-fast.pt"])
        model_map["relations"] = "/".join([hu_path, "relations", "relations.pt"])

        cache_dir = Path("models")
        if model_name in model_map:
            model_name = cached_path(model_map[model_name], cache_dir=cache_dir)

        return model_name


def create_position_string(head: Span, tail: Span) -> str:
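    """Builds a dictionary key identifying an ordered (head, tail) span pair by position."""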
    return f"{head.id_text} -> {tail.id_text}"
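
# ----------------------------------------------------------------------
# Minimal usage sketch (illustrative comments only, not executed). It
# loads one of the pretrained models registered in _fetch_model above and
# predicts relations for a sentence; the sentence text and the prior
# entity-span tagging step are assumptions for the example:
#
#     from flair.data import Sentence
#     from flair.models import RelationExtractor
#
#     extractor = RelationExtractor.load("relations")
#     sentence = Sentence("...")  # must already carry entity span annotations
#     extractor.predict(sentence)
#     print(sentence.get_labels(extractor.label_type))
# ----------------------------------------------------------------------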