Coverage for flair/flair/models/entity_linker_model.py: 27%
import logging
from typing import List, Union

import torch
import torch.nn as nn

import flair.embeddings
import flair.nn
from flair.data import DataPoint, Dictionary, SpanLabel

log = logging.getLogger("flair")


class EntityLinker(flair.nn.DefaultClassifier):
    """
    Entity Linking Model.

    The model expects text/sentences with annotated entity mentions and predicts entities for these mentions.
    To this end, a word embedding is used to embed the sentences, and the embedding of the entity mention is
    passed through a linear layer to obtain the actual class label. The model can predict '<unk>' for entity
    mentions that it cannot confidently match to any of the known labels.
    """

    def __init__(
        self,
        word_embeddings: flair.embeddings.TokenEmbeddings,
        label_dictionary: Dictionary,
        pooling_operation: str = 'average',
        label_type: str = 'nel',
        **classifierargs,
    ):
        """
        Initializes an EntityLinker.
        :param word_embeddings: embeddings used to embed the words/sentences
        :param label_dictionary: dictionary that maps all classes to ids; should contain '<unk>'
        :param pooling_operation: one of 'average', 'first', 'last' or 'first&last'; specifies how the text
            representation of an entity mention spanning more than one word is built. E.g. 'average' takes the
            average of the embeddings of the words in the mention, while 'first&last' concatenates the embedding
            of the first word and the embedding of the last word.
        :param label_type: name of the label type used
        """
        super(EntityLinker, self).__init__(label_dictionary, **classifierargs)

        self.word_embeddings = word_embeddings
        self.pooling_operation = pooling_operation
        self._label_type = label_type

        # if we concatenate the embeddings, we need double the input size in our linear layer
        if self.pooling_operation == 'first&last':
            self.decoder = nn.Linear(
                2 * self.word_embeddings.embedding_length, len(self.label_dictionary)
            ).to(flair.device)
        else:
            self.decoder = nn.Linear(
                self.word_embeddings.embedding_length, len(self.label_dictionary)
            ).to(flair.device)

        nn.init.xavier_uniform_(self.decoder.weight)

        # map the chosen pooling operation to its implementation
        cases = {
            'average': self.emb_mean,
            'first': self.emb_first,
            'last': self.emb_last,
            'first&last': self.emb_firstAndLast,
        }

        if pooling_operation not in cases:
            raise KeyError('pooling_operation has to be one of "average", "first", "last" or "first&last"')

        self.aggregated_embedding = cases.get(pooling_operation)

        self.to(flair.device)

    def emb_first(self, arg):
        return arg[0]

    def emb_last(self, arg):
        return arg[-1]

    def emb_firstAndLast(self, arg):
        return torch.cat((arg[0], arg[-1]), 0)

    def emb_mean(self, arg):
        return torch.mean(arg, 0)
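
    # Example (hypothetical shapes): for a 3-token mention embedded with
    # embedding_length 4, `arg` has shape (3, 4); emb_first, emb_last and
    # emb_mean each return a tensor of shape (4,), while emb_firstAndLast
    # concatenates the first and last token embeddings into shape (8,),
    # which is why 'first&last' doubles the decoder input size above.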

    def forward_pass(
        self,
        sentences: Union[List[DataPoint], DataPoint],
        return_label_candidates: bool = False,
    ):
        if isinstance(sentences, DataPoint):
            sentences = [sentences]

        # filter out sentences with no candidates (no candidates means nothing can be linked anyway)
        filtered_sentences = []
        for sentence in sentences:
            if sentence.get_labels(self.label_type):
                filtered_sentences.append(sentence)

        # fields to return
        span_labels = []
        sentences_to_spans = []
        empty_label_candidates = []

        # if no sentence in the entire batch has candidates, return empty scores
        if len(filtered_sentences) == 0:
            scores = None

        # otherwise, embed the sentences and send the mentions through the prediction head
        else:
            # embed all tokens
            self.word_embeddings.embed(filtered_sentences)

            embedding_names = self.word_embeddings.get_names()

            embedding_list = []
            # get the embeddings of the entity mentions
            for sentence in filtered_sentences:
                spans = sentence.get_spans(self.label_type)

                for span in spans:
                    # stack the token embeddings of the mention, then pool them
                    mention_emb = torch.Tensor(0, self.word_embeddings.embedding_length).to(flair.device)

                    for token in span.tokens:
                        mention_emb = torch.cat((mention_emb, token.get_embedding(embedding_names).unsqueeze(0)), 0)

                    embedding_list.append(self.aggregated_embedding(mention_emb).unsqueeze(0))

                    span_labels.append([label.value for label in span.get_labels(typename=self.label_type)])

                    if return_label_candidates:
                        sentences_to_spans.append(sentence)
                        candidate = SpanLabel(span=span, value=None, score=None)
                        empty_label_candidates.append(candidate)

            embedding_tensor = torch.cat(embedding_list, 0).to(flair.device)
            scores = self.decoder(embedding_tensor)

        # minimal return is scores and labels
        return_tuple = (scores, span_labels)

        if return_label_candidates:
            return_tuple += (sentences_to_spans, empty_label_candidates)

        return return_tuple
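
    # Note on the return contract of forward_pass: `scores` is a
    # (num_spans, num_labels) tensor, or None if no sentence in the batch
    # carries candidate spans; `span_labels` lists the gold label values per
    # span; with return_label_candidates=True, the tuple also carries the
    # source sentences and empty SpanLabel candidates to be filled in during
    # prediction.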

    def _get_state_dict(self):
        model_state = {
            "state_dict": self.state_dict(),
            "word_embeddings": self.word_embeddings,
            "label_type": self.label_type,
            "label_dictionary": self.label_dictionary,
            "pooling_operation": self.pooling_operation,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):
        model = EntityLinker(
            word_embeddings=state["word_embeddings"],
            label_dictionary=state["label_dictionary"],
            label_type=state["label_type"],
            pooling_operation=state["pooling_operation"],
        )

        model.load_state_dict(state["state_dict"])
        return model

    @property
    def label_type(self):
        return self._label_type
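

# Usage sketch (illustrative only; the embedding backbone and entity labels
# below are assumptions, not taken from this module):
if __name__ == "__main__":
    from flair.embeddings import TransformerWordEmbeddings

    # candidate entities; '<unk>' is included automatically so the model can abstain
    label_dictionary = Dictionary(add_unk=True)
    label_dictionary.add_item("Berlin")
    label_dictionary.add_item("Paris")

    linker = EntityLinker(
        word_embeddings=TransformerWordEmbeddings("bert-base-uncased"),
        label_dictionary=label_dictionary,
        pooling_operation="first&last",
    )

    # with 'first&last' pooling, the decoder input is twice the embedding length
    print(linker.decoder)

    # linker.predict(...) expects sentences whose entity mentions are already
    # annotated under label_type ('nel' by default); unannotated sentences are
    # filtered out in forward_pass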