Coverage for flair/flair/models/relation_extractor_model.py: 71%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import logging
2from typing import List, Union, Tuple, Optional
4import torch
5import torch.nn as nn
7import flair.embeddings
8import flair.nn
9from flair.data import DataPoint, RelationLabel, Span, Sentence
11log = logging.getLogger("flair")
class RelationExtractor(flair.nn.DefaultClassifier):
    """
    Classifier that predicts a relation label for ordered pairs of entity spans.

    For every sentence, all ordered pairs of distinct spans labeled with
    ``entity_label_type`` are considered. Each pair is represented by
    concatenated token embeddings of the two spans and decoded (linearly or
    through one hidden ReLU layer) into scores over the label dictionary.
    Pairs without a gold relation are labeled ``'O'``.
    """

    def __init__(
        self,
        embeddings: Union[flair.embeddings.TokenEmbeddings],
        label_type: str,
        entity_label_type: str,
        train_on_gold_pairs_only: bool = False,
        entity_pair_filters: Optional[List[Tuple[str, str]]] = None,
        pooling_operation: str = "first_last",
        dropout_value: float = 0.0,
        locked_dropout_value: float = 0.0,
        word_dropout_value: float = 0.0,
        non_linear_decoder: Optional[int] = None,
        **classifierargs,
    ):
        """
        Initializes a RelationExtractor.

        :param embeddings: embeddings used to embed the tokens of each sentence
        :param label_type: label type of the relation annotations to predict
        :param entity_label_type: label type of the entity span annotations
            from which candidate pairs are built
        :param train_on_gold_pairs_only: if True, entity pairs without a gold
            relation are skipped instead of being labeled 'O'
        :param entity_pair_filters: optional collection of
            (head entity tag, tail entity tag) tuples; if given, only pairs
            whose tags appear in it are considered
        :param pooling_operation: "first_last" concatenates the first and last
            token embedding of each span; any other value uses only the first
            token embedding of each span
        :param dropout_value: probability passed to torch.nn.Dropout
        :param locked_dropout_value: value passed to flair.nn.LockedDropout
        :param word_dropout_value: value passed to flair.nn.WordDropout
        :param non_linear_decoder: if set (and non-zero), size of the hidden
            layer of a two-layer ReLU decoder; otherwise a single linear
            decoder is used
        :param classifierargs: remaining keyword arguments, forwarded unchanged
            to flair.nn.DefaultClassifier (presumably including
            label_dictionary and loss_weights, which this class reads back
            from self -- confirm against the base class)
        """
        super().__init__(**classifierargs)

        # set embeddings
        self.embeddings: flair.embeddings.TokenEmbeddings = embeddings

        # set relation and entity label types
        self._label_type = label_type
        self.entity_label_type = entity_label_type

        # whether to use gold entity pairs, and whether to filter entity pairs by type
        self.train_on_gold_pairs_only = train_on_gold_pairs_only
        # stored as a set for O(1) membership tests in forward_pass
        self.entity_pair_filters = (
            set(entity_pair_filters) if entity_pair_filters is not None else None
        )

        # init dropouts
        self.dropout_value = dropout_value
        self.dropout = torch.nn.Dropout(dropout_value)
        self.locked_dropout_value = locked_dropout_value
        self.locked_dropout = flair.nn.LockedDropout(locked_dropout_value)
        self.word_dropout_value = word_dropout_value
        self.word_dropout = flair.nn.WordDropout(word_dropout_value)

        # pooling operation to get embeddings for entities:
        # one embedding per span, or two (first + last token) for 'first_last'
        self.pooling_operation = pooling_operation
        relation_representation_length = 2 * embeddings.embedding_length
        if self.pooling_operation == 'first_last':
            relation_representation_length *= 2
        if isinstance(self.embeddings, flair.embeddings.TransformerDocumentEmbeddings):
            relation_representation_length = embeddings.embedding_length

        # entity pairs could also be no relation at all, add default value for this case to dictionary
        self.label_dictionary.add_item('O')

        # decoder can be linear or nonlinear
        self.non_linear_decoder = non_linear_decoder
        if self.non_linear_decoder:
            self.decoder_1 = nn.Linear(relation_representation_length, non_linear_decoder)
            self.nonlinearity = torch.nn.ReLU()
            self.decoder_2 = nn.Linear(non_linear_decoder, len(self.label_dictionary))
            nn.init.xavier_uniform_(self.decoder_1.weight)
            nn.init.xavier_uniform_(self.decoder_2.weight)
        else:
            self.decoder = nn.Linear(relation_representation_length, len(self.label_dictionary))
            nn.init.xavier_uniform_(self.decoder.weight)

        self.to(flair.device)

    def add_entity_markers(self, sentence, span_1, span_2):
        """
        Builds a copy of the sentence with marker tokens around both spans.

        ``<e1>``/``</e1>`` are inserted around ``span_1`` and ``<e2>``/``</e2>``
        around ``span_2``. The result is re-created as a whitespace-tokenized
        Sentence.

        :param sentence: iterable of tokens (each with a ``text`` attribute)
        :param span_1: first entity span; its first/last tokens are located in
            ``sentence`` by ``==`` comparison
        :param span_2: second entity span, located the same way
        :return: tuple ``(expanded_sentence, (span_a, span_b))`` where each
            returned span wraps the opening marker token of one entity. The
            pair is ``(e1 marker, e2 marker)`` if entity one's opening marker
            was emitted first, otherwise ``(e2 marker, e1 marker)``
        :raises ValueError: if the first token of either span is not found in
            ``sentence``
        """
        pieces = []
        entity_one_is_first = None
        span_1_marker_index = None
        span_2_marker_index = None

        for token in sentence:
            # opening markers: <e2> is checked before <e1>, so if both spans
            # start on the same token, entity two counts as first
            if token == span_2[0]:
                if entity_one_is_first is None:
                    entity_one_is_first = False
                span_2_marker_index = len(pieces)
                pieces.append("<e2>")
            if token == span_1[0]:
                if entity_one_is_first is None:
                    entity_one_is_first = True
                span_1_marker_index = len(pieces)
                pieces.append("<e1>")

            pieces.append(token.text)

            # closing markers
            if token == span_1[-1]:
                pieces.append("</e1>")
            if token == span_2[-1]:
                pieces.append("</e2>")

        if span_1_marker_index is None or span_2_marker_index is None:
            raise ValueError("Both spans must start at a token of the given sentence.")

        # every piece is preceded by a single space (including the first one),
        # which reproduces the text layout handed to the whitespace tokenizer
        text = "".join(" " + piece for piece in pieces)
        expanded_sentence = Sentence(text, use_tokenizer=False)

        expanded_span_1 = Span([expanded_sentence[span_1_marker_index]])
        expanded_span_2 = Span([expanded_sentence[span_2_marker_index]])

        if entity_one_is_first:
            expanded_spans = (expanded_span_1, expanded_span_2)
        else:
            expanded_spans = (expanded_span_2, expanded_span_1)
        return expanded_sentence, expanded_spans

    def forward_pass(self,
                     sentences: Union[List[DataPoint], DataPoint],
                     return_label_candidates: bool = False,
                     ):
        """
        Scores every candidate entity pair found in the given sentences.

        :param sentences: a sentence or list of sentences carrying entity span
            labels (``entity_label_type``) and, for training, relation labels
            (``label_type``)
        :param return_label_candidates: if True, additionally return the
            sentence of each candidate pair and an empty RelationLabel per pair
            for the caller to fill with predictions
        :return: ``(scores, labels)`` where ``scores`` is the decoder output
            with one row per candidate pair (``None`` if there are no pairs)
            and ``labels`` is a list of single-element lists with the gold
            label (or 'O') per pair. If ``return_label_candidates`` is True,
            the tuple is extended by ``(sentences_to_label,
            empty_label_candidates)``
        """
        # the signature allows a single data point; iterate over a list in any case
        if isinstance(sentences, DataPoint):
            sentences = [sentences]

        empty_label_candidates = []
        entity_pairs = []
        labels = []
        sentences_to_label = []

        for sentence in sentences:

            # make dictionary to find relation annotations for a given entity pair
            relation_dict = {}
            for relation_label in sentence.get_labels(self.label_type):
                relation_dict[create_position_string(relation_label.head, relation_label.tail)] = relation_label

            # get all entity spans
            span_labels = sentence.get_labels(self.entity_label_type)

            # go through cross product of entities, for each pair concat embeddings
            for span_label in span_labels:
                span_1 = span_label.span

                for span_label_2 in span_labels:
                    span_2 = span_label_2.span

                    if span_1 == span_2:
                        continue

                    # filter entity pairs according to their tags if set
                    if (self.entity_pair_filters is not None
                            and (span_label.value, span_label_2.value) not in self.entity_pair_filters):
                        continue

                    position_string = create_position_string(span_1, span_2)

                    # get gold label for this relation (if one exists)
                    if position_string in relation_dict:
                        label = relation_dict[position_string].value

                    # if there is no gold label for this entity pair, set to 'O' (no relation)
                    else:
                        # skip 'O' labels if training on gold pairs only
                        if self.train_on_gold_pairs_only:
                            continue
                        label = 'O'

                    entity_pairs.append((span_1, span_2))

                    labels.append([label])

                    # if predicting, also remember sentences and label candidates
                    if return_label_candidates:
                        candidate_label = RelationLabel(head=span_1, tail=span_2, value=None, score=None)
                        empty_label_candidates.append(candidate_label)
                        sentences_to_label.append(span_1[0].sentence)

        # if there's at least one entity pair in the batch
        if entity_pairs:

            # embed sentences and get embeddings for each entity pair
            self.embeddings.embed(sentences)
            relation_embeddings = []

            # get embeddings
            for span_1, span_2 in entity_pairs:

                if self.pooling_operation == "first_last":
                    embedding = torch.cat([span_1.tokens[0].get_embedding(),
                                           span_1.tokens[-1].get_embedding(),
                                           span_2.tokens[0].get_embedding(),
                                           span_2.tokens[-1].get_embedding()])
                else:
                    embedding = torch.cat([span_1.tokens[0].get_embedding(), span_2.tokens[0].get_embedding()])

                relation_embeddings.append(embedding)

            # stack and drop out
            all_relations = torch.stack(relation_embeddings)

            all_relations = self.dropout(all_relations)
            all_relations = self.locked_dropout(all_relations)
            all_relations = self.word_dropout(all_relations)

            # send through decoder
            if self.non_linear_decoder:
                sentence_relation_scores = self.decoder_2(self.nonlinearity(self.decoder_1(all_relations)))
            else:
                sentence_relation_scores = self.decoder(all_relations)

        else:
            sentence_relation_scores = None

        # return either scores and gold labels (for loss calculation), or include label candidates for prediction
        result_tuple = (sentence_relation_scores, labels)

        if return_label_candidates:
            result_tuple += (sentences_to_label, empty_label_candidates)

        return result_tuple

    def _get_state_dict(self):
        """
        Collects everything needed to re-create this model.

        :return: dictionary with the module weights ("state_dict") and the
            constructor arguments consumed by _init_model_with_state_dict
        """
        model_state = {
            "state_dict": self.state_dict(),
            "embeddings": self.embeddings,
            "label_dictionary": self.label_dictionary,
            "label_type": self.label_type,
            "entity_label_type": self.entity_label_type,
            "loss_weights": self.loss_weights,
            "pooling_operation": self.pooling_operation,
            "dropout_value": self.dropout_value,
            "locked_dropout_value": self.locked_dropout_value,
            "word_dropout_value": self.word_dropout_value,
            "entity_pair_filters": self.entity_pair_filters,
            "non_linear_decoder": self.non_linear_decoder,
            # persisted so that a reloaded model keeps its pair-selection behavior
            "train_on_gold_pairs_only": self.train_on_gold_pairs_only,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):
        """
        Re-creates a RelationExtractor from a dictionary produced by _get_state_dict.

        :param state: state dictionary; "train_on_gold_pairs_only" is optional
            so that states saved without this key still load (default False)
        :return: the restored model with its weights loaded
        """
        model = RelationExtractor(
            embeddings=state["embeddings"],
            label_dictionary=state["label_dictionary"],
            label_type=state["label_type"],
            entity_label_type=state["entity_label_type"],
            loss_weights=state["loss_weights"],
            pooling_operation=state["pooling_operation"],
            dropout_value=state["dropout_value"],
            locked_dropout_value=state["locked_dropout_value"],
            word_dropout_value=state["word_dropout_value"],
            entity_pair_filters=state["entity_pair_filters"],
            non_linear_decoder=state["non_linear_decoder"],
            train_on_gold_pairs_only=state.get("train_on_gold_pairs_only", False),
        )
        model.load_state_dict(state["state_dict"])
        return model

    @property
    def label_type(self):
        """Label type of the relation annotations this model predicts."""
        return self._label_type
def create_position_string(head: Span, tail: Span) -> str:
    """
    Builds the string key identifying an ordered (head, tail) span pair.

    :param head: head span; its ``id_text`` attribute is used
    :param tail: tail span; its ``id_text`` attribute is used
    :return: the two ``id_text`` values joined as ``"<head> -> <tail>"``
    """
    head_id = head.id_text
    tail_id = tail.id_text
    return "{} -> {}".format(head_id, tail_id)