Coverage for flair/flair/embeddings/document.py: 72%
import logging
from abc import abstractmethod
from typing import List, Union

import torch
from sklearn.feature_extraction.text import TfidfVectorizer
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from transformers import AutoTokenizer, AutoConfig, AutoModel, CONFIG_MAPPING, PreTrainedTokenizer

import flair
from flair.data import Sentence
from flair.embeddings.base import Embeddings, ScalarMix
from flair.embeddings.token import TokenEmbeddings, StackedEmbeddings, FlairEmbeddings
from flair.nn import LockedDropout, WordDropout

log = logging.getLogger("flair")


class DocumentEmbeddings(Embeddings):
    """Abstract base class for all document-level embeddings. Every new type of document embedding must implement these methods."""

    @property
    @abstractmethod
    def embedding_length(self) -> int:
        """Returns the length of the embedding vector."""
        pass

    @property
    def embedding_type(self) -> str:
        return "sentence-level"
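

# --- Illustrative sketch (not part of the original module) ---
# A minimal custom DocumentEmbeddings subclass, assuming the base Embeddings class drives
# embedding via _add_embeddings_internal() and stores vectors under self.name. The class name
# and the constant length of 2 are hypothetical choices made only for illustration.
def _example_custom_document_embedding():
    class SentenceLengthEmbeddings(DocumentEmbeddings):
        def __init__(self):
            super().__init__()
            self.name = "sentence-length"
            self.static_embeddings = True

        @property
        def embedding_length(self) -> int:
            # two features: token count and character count
            return 2

        def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
            for sentence in sentences:
                features = torch.tensor(
                    [float(len(sentence)), float(len(sentence.to_original_text()))],
                    device=flair.device,
                )
                sentence.set_embedding(self.name, features)
            return sentences

    sentence = Sentence("The grass is green .")
    SentenceLengthEmbeddings().embed(sentence)
    print(sentence.get_embedding())  # tensor with 2 values
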

class TransformerDocumentEmbeddings(DocumentEmbeddings):
    def __init__(
        self,
        model: str = "bert-base-uncased",
        fine_tune: bool = True,
        layers: str = "-1",
        layer_mean: bool = False,
        pooling: str = "cls",
        **kwargs
    ):
        """
        Bidirectional transformer embeddings of words from various transformer architectures.
        :param model: name of transformer model (see https://huggingface.co/transformers/pretrained_models.html for
        options)
        :param fine_tune: If True, allows transformers to be fine-tuned during training
        :param layers: string indicating which layers to take for embedding (-1 is the topmost layer)
        :param layer_mean: If True, uses a scalar mix of layers as embedding
        :param pooling: pooling strategy for combining token-level embeddings; options are 'cls', 'max' and 'mean'
        """
        super().__init__()

        if pooling not in ['cls', 'max', 'mean']:
            raise ValueError(f"Pooling operation `{pooling}` is not defined for TransformerDocumentEmbeddings")

        # temporary fix to disable tokenizer parallelism warning
        # (see https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning)
        import os
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        # do not print transformer warnings as these are confusing in this case
        from transformers import logging
        logging.set_verbosity_error()

        # load tokenizer and transformer model
        self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model, **kwargs)
        if 'config' not in kwargs:
            config = AutoConfig.from_pretrained(model, output_hidden_states=True, **kwargs)
            self.model = AutoModel.from_pretrained(model, config=config)
        else:
            self.model = AutoModel.from_pretrained(None, **kwargs)

        logging.set_verbosity_warning()

        # model name
        self.name = 'transformer-document-' + str(model)
        self.base_model_name = str(model)

        # when initializing, embeddings are in eval mode by default
        self.model.eval()
        self.model.to(flair.device)

        # embedding parameters
        if layers == 'all':
            # send mini-token through to check how many layers the model has
            hidden_states = self.model(torch.tensor([1], device=flair.device).unsqueeze(0))[-1]
            self.layer_indexes = [int(x) for x in range(len(hidden_states))]
        else:
            self.layer_indexes = [int(x) for x in layers.split(",")]

        self.layer_mean = layer_mean
        self.fine_tune = fine_tune
        self.static_embeddings = not self.fine_tune
        self.pooling = pooling

        # check whether CLS is at beginning or end
        self.initial_cls_token: bool = self._has_initial_cls_token(tokenizer=self.tokenizer)

    @staticmethod
    def _has_initial_cls_token(tokenizer: PreTrainedTokenizer) -> bool:
        # most models have the CLS token as the last token (GPT-1, GPT-2, TransfoXL, XLNet, XLM), but BERT puts it first
        tokens = tokenizer.encode('a')
        initial_cls_token: bool = False
        if tokens[0] == tokenizer.cls_token_id:
            initial_cls_token = True
        return initial_cls_token
    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        """Add embeddings to all sentences in a list of sentences."""

        # gradients are enabled if fine-tuning is enabled
        gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad()

        with gradient_context:

            # first, subtokenize each sentence and find out into how many subtokens each token was divided
            subtokenized_sentences = []

            # subtokenize sentences
            for sentence in sentences:
                # tokenize and truncate to max subtokens (TODO: check better truncation strategies)
                subtokenized_sentence = self.tokenizer.encode(sentence.to_tokenized_string(),
                                                              add_special_tokens=True,
                                                              max_length=self.tokenizer.model_max_length,
                                                              truncation=True,
                                                              )

                subtokenized_sentences.append(
                    torch.tensor(subtokenized_sentence, dtype=torch.long, device=flair.device))

            # find longest sentence in batch
            longest_sequence_in_batch: int = len(max(subtokenized_sentences, key=len))

            # initialize batch tensors and mask
            input_ids = torch.zeros(
                [len(sentences), longest_sequence_in_batch],
                dtype=torch.long,
                device=flair.device,
            )
            mask = torch.zeros(
                [len(sentences), longest_sequence_in_batch],
                dtype=torch.long,
                device=flair.device,
            )
            for s_id, sentence in enumerate(subtokenized_sentences):
                sequence_length = len(sentence)
                input_ids[s_id][:sequence_length] = sentence
                mask[s_id][:sequence_length] = torch.ones(sequence_length)

            # put encoded batch through transformer model to get all hidden states of all encoder layers
            hidden_states = self.model(input_ids, attention_mask=mask)[-1] if len(sentences) > 1 \
                else self.model(input_ids)[-1]

            # iterate over all subtokenized sentences
            for sentence_idx, (sentence, subtokens) in enumerate(zip(sentences, subtokenized_sentences)):

                if self.pooling == "cls":
                    index_of_CLS_token = 0 if self.initial_cls_token else len(subtokens) - 1

                    cls_embeddings_all_layers: List[torch.FloatTensor] = \
                        [hidden_states[layer][sentence_idx][index_of_CLS_token] for layer in self.layer_indexes]

                    embeddings_all_layers = cls_embeddings_all_layers

                elif self.pooling == "mean":
                    mean_embeddings_all_layers: List[torch.FloatTensor] = \
                        [torch.mean(hidden_states[layer][sentence_idx][:len(subtokens), :], dim=0)
                         for layer in self.layer_indexes]

                    embeddings_all_layers = mean_embeddings_all_layers

                elif self.pooling == "max":
                    max_embeddings_all_layers: List[torch.FloatTensor] = \
                        [torch.max(hidden_states[layer][sentence_idx][:len(subtokens), :], dim=0)[0]
                         for layer in self.layer_indexes]

                    embeddings_all_layers = max_embeddings_all_layers

                # use scalar mix of embeddings if so selected
                if self.layer_mean:
                    sm = ScalarMix(mixture_size=len(embeddings_all_layers))
                    sm_embeddings = sm(embeddings_all_layers)

                    embeddings_all_layers = [sm_embeddings]

                # set the extracted embedding for the sentence
                sentence.set_embedding(self.name, torch.cat(embeddings_all_layers))

        return sentences

    @property
    def embedding_length(self) -> int:
        """Returns the length of the embedding vector: one hidden_size per selected layer,
        or a single hidden_size if the layers are combined via scalar mix."""
        return (
            len(self.layer_indexes) * self.model.config.hidden_size
            if not self.layer_mean
            else self.model.config.hidden_size
        )
    def __getstate__(self):
        # special handling for serializing transformer models
        config_state_dict = self.model.config.__dict__
        model_state_dict = self.model.state_dict()

        if not hasattr(self, "base_model_name"):
            self.base_model_name = self.name.split('transformer-document-')[-1]

        # serialize the transformer model and the constructor arguments (but nothing else)
        model_state = {
            "config_state_dict": config_state_dict,
            "model_state_dict": model_state_dict,
            "embedding_length_internal": self.embedding_length,

            "base_model_name": self.base_model_name,
            "fine_tune": self.fine_tune,
            "layer_indexes": self.layer_indexes,
            "layer_mean": self.layer_mean,
            "pooling": self.pooling,
        }

        return model_state

    def __setstate__(self, d):
        self.__dict__ = d

        # necessary for reverse compatibility with Flair <= 0.7
        if 'use_scalar_mix' in self.__dict__.keys():
            self.__dict__['layer_mean'] = d['use_scalar_mix']

        # special handling for deserializing transformer models
        if "config_state_dict" in d:

            # load transformer model
            model_type = d["config_state_dict"]["model_type"] if "model_type" in d["config_state_dict"] else "bert"
            config_class = CONFIG_MAPPING[model_type]
            loaded_config = config_class.from_dict(d["config_state_dict"])

            # constructor arguments
            layers = ','.join([str(idx) for idx in self.__dict__['layer_indexes']])

            # re-initialize transformer word embeddings with constructor arguments
            embedding = TransformerDocumentEmbeddings(
                model=self.__dict__['base_model_name'],
                fine_tune=self.__dict__['fine_tune'],
                layers=layers,
                layer_mean=self.__dict__['layer_mean'],
                config=loaded_config,
                state_dict=d["model_state_dict"],
                # fall back to 'cls' pooling for backward compatibility with previous models
                pooling=self.__dict__['pooling'] if 'pooling' in self.__dict__ else 'cls',
            )

            # copy the state of the freshly initialized embedding over to self
            for key in embedding.__dict__.keys():
                self.__dict__[key] = embedding.__dict__[key]

        else:
            model_name = self.__dict__['name'].split('transformer-document-')[-1]
            # reload tokenizer to get around serialization issues
            try:
                tokenizer = AutoTokenizer.from_pretrained(model_name)
            except Exception:
                # keep a defined value if the tokenizer cannot be reloaded
                tokenizer = None
            self.tokenizer = tokenizer
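

# --- Illustrative usage sketch (not part of the original module) ---
# Shows how TransformerDocumentEmbeddings is typically used: embed a Sentence and read back the
# pooled vector. The model name is the class default. The torch.save / torch.load round trip is
# an assumption about how surrounding code persists embeddings; it exercises the
# __getstate__/__setstate__ handling above.
def _example_transformer_document_embeddings(tmp_path: str = "transformer_doc_emb.pt"):
    embeddings = TransformerDocumentEmbeddings("bert-base-uncased", fine_tune=False, pooling="mean")

    sentence = Sentence("The grass is green .")
    embeddings.embed(sentence)

    vector = sentence.get_embedding()
    # one hidden_size per selected layer (here a single layer, i.e. 768 for BERT base)
    assert vector.shape[0] == embeddings.embedding_length

    # pickle-based round trip; recent torch versions may require weights_only=False in torch.load
    torch.save(embeddings, tmp_path)
    reloaded = torch.load(tmp_path)
    assert reloaded.embedding_length == embeddings.embedding_length
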

class DocumentPoolEmbeddings(DocumentEmbeddings):
    def __init__(
        self,
        embeddings: List[TokenEmbeddings],
        fine_tune_mode: str = "none",
        pooling: str = "mean",
    ):
        """The constructor takes a list of embeddings to be combined.
        :param embeddings: a list of token embeddings
        :param fine_tune_mode: if set to "linear" a trainable layer is added, if set to
        "nonlinear", a nonlinearity is added as well. Set this to make the pooling trainable.
        :param pooling: a string which can take any value from ['mean', 'max', 'min']
        """
        super().__init__()

        self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embeddings)
        self.__embedding_length = self.embeddings.embedding_length

        # optional fine-tuning on top of embedding layer
        self.fine_tune_mode = fine_tune_mode
        if self.fine_tune_mode in ["nonlinear", "linear"]:
            self.embedding_flex = torch.nn.Linear(
                self.embedding_length, self.embedding_length, bias=False
            )
            self.embedding_flex.weight.data.copy_(torch.eye(self.embedding_length))

        if self.fine_tune_mode in ["nonlinear"]:
            self.embedding_flex_nonlinear = torch.nn.ReLU(self.embedding_length)
            self.embedding_flex_nonlinear_map = torch.nn.Linear(
                self.embedding_length, self.embedding_length
            )

        self.__embedding_length: int = self.embeddings.embedding_length

        self.to(flair.device)

        if pooling not in ['min', 'max', 'mean']:
            raise ValueError(f"Pooling operation `{pooling}` is not defined")

        self.pooling = pooling
        self.name: str = f"document_{self.pooling}"

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def embed(self, sentences: Union[List[Sentence], Sentence]):
        """Add embeddings to every sentence in the given list of sentences. If embeddings are already added, updates
        only if embeddings are non-static."""

        # if only one sentence is passed, convert to list of sentences
        if isinstance(sentences, Sentence):
            sentences = [sentences]

        self.embeddings.embed(sentences)

        for sentence in sentences:
            word_embeddings = []
            for token in sentence.tokens:
                word_embeddings.append(token.get_embedding().unsqueeze(0))

            word_embeddings = torch.cat(word_embeddings, dim=0).to(flair.device)

            if self.fine_tune_mode in ["nonlinear", "linear"]:
                word_embeddings = self.embedding_flex(word_embeddings)

            if self.fine_tune_mode in ["nonlinear"]:
                word_embeddings = self.embedding_flex_nonlinear(word_embeddings)
                word_embeddings = self.embedding_flex_nonlinear_map(word_embeddings)

            if self.pooling == "mean":
                pooled_embedding = torch.mean(word_embeddings, 0)
            elif self.pooling == "max":
                pooled_embedding, _ = torch.max(word_embeddings, 0)
            elif self.pooling == "min":
                pooled_embedding, _ = torch.min(word_embeddings, 0)

            sentence.set_embedding(self.name, pooled_embedding)

    def _add_embeddings_internal(self, sentences: List[Sentence]):
        pass

    def extra_repr(self):
        return f"fine_tune_mode={self.fine_tune_mode}, pooling={self.pooling}"
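

# --- Illustrative usage sketch (not part of the original module) ---
# Mean-pools GloVe token embeddings into one sentence vector. WordEmbeddings('glove') comes
# from the surrounding flair library (not defined in this file) and downloads pretrained
# vectors on first use; it is used here only as an assumed example input.
def _example_document_pool_embeddings():
    from flair.embeddings import WordEmbeddings

    glove = WordEmbeddings("glove")
    pool = DocumentPoolEmbeddings([glove], pooling="mean")

    sentence = Sentence("The grass is green .")
    pool.embed(sentence)

    # the pooled vector has the same length as the stacked token embeddings (100 for GloVe)
    print(sentence.get_embedding().shape)
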

class DocumentTFIDFEmbeddings(DocumentEmbeddings):
    def __init__(
        self,
        train_dataset,
        **vectorizer_params,
    ):
        """The constructor for DocumentTFIDFEmbeddings.
        :param train_dataset: the train dataset which is used to fit the vectorizer
        :param vectorizer_params: parameters given to scikit-learn's TfidfVectorizer constructor
        """
        super().__init__()

        import numpy as np
        self.vectorizer = TfidfVectorizer(dtype=np.float32, **vectorizer_params)
        self.vectorizer.fit([s.to_original_text() for s in train_dataset])

        self.__embedding_length: int = len(self.vectorizer.vocabulary_)

        self.to(flair.device)

        self.name: str = "document_tfidf"

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def embed(self, sentences: Union[List[Sentence], Sentence]):
        """Add embeddings to every sentence in the given list of sentences."""

        # if only one sentence is passed, convert to list of sentences
        if isinstance(sentences, Sentence):
            sentences = [sentences]

        raw_sentences = [s.to_original_text() for s in sentences]
        tfidf_vectors = torch.from_numpy(self.vectorizer.transform(raw_sentences).A)

        for sentence_id, sentence in enumerate(sentences):
            sentence.set_embedding(self.name, tfidf_vectors[sentence_id])

    def _add_embeddings_internal(self, sentences: List[Sentence]):
        pass
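

# --- Illustrative usage sketch (not part of the original module) ---
# Fits the TF-IDF vectorizer on a tiny in-memory "training set" of Sentence objects and embeds
# a new sentence. min_df is a regular scikit-learn TfidfVectorizer parameter passed through
# **vectorizer_params; the example sentences are made up for illustration.
def _example_document_tfidf_embeddings():
    train_sentences = [
        Sentence("The grass is green ."),
        Sentence("The sky is blue ."),
    ]
    tfidf = DocumentTFIDFEmbeddings(train_dataset=train_sentences, min_df=1)

    sentence = Sentence("The grass is blue .")
    tfidf.embed(sentence)

    # the embedding length equals the size of the fitted vocabulary
    assert sentence.get_embedding().shape[0] == tfidf.embedding_length
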

class DocumentRNNEmbeddings(DocumentEmbeddings):
    def __init__(
        self,
        embeddings: List[TokenEmbeddings],
        hidden_size=128,
        rnn_layers=1,
        reproject_words: bool = True,
        reproject_words_dimension: int = None,
        bidirectional: bool = False,
        dropout: float = 0.5,
        word_dropout: float = 0.0,
        locked_dropout: float = 0.0,
        rnn_type="GRU",
        fine_tune: bool = True,
    ):
        """The constructor takes a list of embeddings to be combined.
        :param embeddings: a list of token embeddings
        :param hidden_size: the number of hidden states in the rnn
        :param rnn_layers: the number of layers for the rnn
        :param reproject_words: boolean value, indicating whether to reproject the token embeddings in a separate linear
        layer before putting them into the rnn or not
        :param reproject_words_dimension: output dimension of reprojecting token embeddings. If None the same output
        dimension as before will be taken.
        :param bidirectional: boolean value, indicating whether to use a bidirectional rnn or not
        :param dropout: the dropout value to be used
        :param word_dropout: the word dropout value to be used, if 0.0 word dropout is not used
        :param locked_dropout: the locked dropout value to be used, if 0.0 locked dropout is not used
        :param rnn_type: 'GRU' or 'LSTM'
        :param fine_tune: if True (default), the document embeddings stay attached to the graph during training;
        if False, they are detached and treated as static
        """
        super().__init__()

        self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embeddings)

        self.rnn_type = rnn_type

        self.reproject_words = reproject_words
        self.bidirectional = bidirectional

        self.length_of_all_token_embeddings: int = self.embeddings.embedding_length

        self.static_embeddings = False if fine_tune else True

        self.__embedding_length: int = hidden_size
        if self.bidirectional:
            self.__embedding_length *= 4

        self.embeddings_dimension: int = self.length_of_all_token_embeddings
        if self.reproject_words and reproject_words_dimension is not None:
            self.embeddings_dimension = reproject_words_dimension

        self.word_reprojection_map = torch.nn.Linear(
            self.length_of_all_token_embeddings, self.embeddings_dimension
        )

        # RNN on top of embedding layer (optionally bidirectional)
        if rnn_type == "LSTM":
            self.rnn = torch.nn.LSTM(
                self.embeddings_dimension,
                hidden_size,
                num_layers=rnn_layers,
                bidirectional=self.bidirectional,
                batch_first=True,
            )
        else:
            self.rnn = torch.nn.GRU(
                self.embeddings_dimension,
                hidden_size,
                num_layers=rnn_layers,
                bidirectional=self.bidirectional,
                batch_first=True,
            )

        self.name = "document_" + self.rnn._get_name()

        # dropouts
        self.dropout = torch.nn.Dropout(dropout) if dropout > 0.0 else None
        self.locked_dropout = (
            LockedDropout(locked_dropout) if locked_dropout > 0.0 else None
        )
        self.word_dropout = WordDropout(word_dropout) if word_dropout > 0.0 else None

        torch.nn.init.xavier_uniform_(self.word_reprojection_map.weight)

        self.to(flair.device)

        self.eval()

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length
    def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
        """Add embeddings to all sentences in the given list of sentences. If embeddings are already added, update
        only if embeddings are non-static."""

        # TODO: remove in future versions
        if not hasattr(self, "locked_dropout"):
            self.locked_dropout = None
        if not hasattr(self, "word_dropout"):
            self.word_dropout = None

        if type(sentences) is Sentence:
            sentences = [sentences]

        self.rnn.zero_grad()

        # embed words in the sentence
        self.embeddings.embed(sentences)

        lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
        longest_token_sequence_in_batch: int = max(lengths)

        pre_allocated_zero_tensor = torch.zeros(
            self.embeddings.embedding_length * longest_token_sequence_in_batch,
            dtype=torch.float,
            device=flair.device,
        )

        all_embs: List[torch.Tensor] = list()
        for sentence in sentences:
            all_embs += [
                emb for token in sentence for emb in token.get_each_embedding()
            ]
            nb_padding_tokens = longest_token_sequence_in_batch - len(sentence)

            if nb_padding_tokens > 0:
                t = pre_allocated_zero_tensor[
                    : self.embeddings.embedding_length * nb_padding_tokens
                ]
                all_embs.append(t)

        sentence_tensor = torch.cat(all_embs).view(
            [
                len(sentences),
                longest_token_sequence_in_batch,
                self.embeddings.embedding_length,
            ]
        )

        # before-RNN dropout
        if self.dropout:
            sentence_tensor = self.dropout(sentence_tensor)
        if self.locked_dropout:
            sentence_tensor = self.locked_dropout(sentence_tensor)
        if self.word_dropout:
            sentence_tensor = self.word_dropout(sentence_tensor)

        # reproject if set
        if self.reproject_words:
            sentence_tensor = self.word_reprojection_map(sentence_tensor)

        # push through RNN
        packed = pack_padded_sequence(
            sentence_tensor, lengths, enforce_sorted=False, batch_first=True
        )
        rnn_out, hidden = self.rnn(packed)
        outputs, output_lengths = pad_packed_sequence(rnn_out, batch_first=True)

        # after-RNN dropout
        if self.dropout:
            outputs = self.dropout(outputs)
        if self.locked_dropout:
            outputs = self.locked_dropout(outputs)

        # extract embeddings from RNN
        for sentence_no, length in enumerate(lengths):
            last_rep = outputs[sentence_no, length - 1]

            embedding = last_rep
            if self.bidirectional:
                first_rep = outputs[sentence_no, 0]
                embedding = torch.cat([first_rep, last_rep], 0)

            if self.static_embeddings:
                embedding = embedding.detach()

            sentence = sentences[sentence_no]
            sentence.set_embedding(self.name, embedding)
    def _apply(self, fn):

        # models that were serialized using torch versions older than 1.4.0 lack the _flat_weights_names attribute
        # check if this is the case and if so, set it
        for child_module in self.children():
            if isinstance(child_module, torch.nn.RNNBase) and not hasattr(child_module, "_flat_weights_names"):
                _flat_weights_names = []

                if child_module.__dict__["bidirectional"]:
                    num_direction = 2
                else:
                    num_direction = 1
                for layer in range(child_module.__dict__["num_layers"]):
                    for direction in range(num_direction):
                        suffix = "_reverse" if direction == 1 else ""
                        param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"]
                        if child_module.__dict__["bias"]:
                            param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"]
                        param_names = [
                            x.format(layer, suffix) for x in param_names
                        ]
                        _flat_weights_names.extend(param_names)

                setattr(child_module, "_flat_weights_names", _flat_weights_names)

            child_module._apply(fn)
    def __getstate__(self):

        # serialize the language models and the constructor arguments (but nothing else)
        model_state = {
            "state_dict": self.state_dict(),

            "embeddings": self.embeddings.embeddings,
            "hidden_size": self.rnn.hidden_size,
            "rnn_layers": self.rnn.num_layers,
            "reproject_words": self.reproject_words,
            "reproject_words_dimension": self.embeddings_dimension,
            "bidirectional": self.bidirectional,
            "dropout": self.dropout.p if self.dropout is not None else 0.,
            "word_dropout": self.word_dropout.p if self.word_dropout is not None else 0.,
            "locked_dropout": self.locked_dropout.p if self.locked_dropout is not None else 0.,
            "rnn_type": self.rnn_type,
            "fine_tune": not self.static_embeddings,
        }

        return model_state

    def __setstate__(self, d):

        # special handling for deserializing language models
        if "state_dict" in d:

            # re-initialize language model with constructor arguments
            language_model = DocumentRNNEmbeddings(
                embeddings=d['embeddings'],
                hidden_size=d['hidden_size'],
                rnn_layers=d['rnn_layers'],
                reproject_words=d['reproject_words'],
                reproject_words_dimension=d['reproject_words_dimension'],
                bidirectional=d['bidirectional'],
                dropout=d['dropout'],
                word_dropout=d['word_dropout'],
                locked_dropout=d['locked_dropout'],
                rnn_type=d['rnn_type'],
                fine_tune=d['fine_tune'],
            )

            language_model.load_state_dict(d['state_dict'])

            # copy over state dictionary to self
            for key in language_model.__dict__.keys():
                self.__dict__[key] = language_model.__dict__[key]

            # set the language model to eval() by default (this is necessary since FlairEmbeddings "protect" the LM
            # in their "self.train()" method)
            self.eval()

        else:
            self.__dict__ = d
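

# --- Illustrative usage sketch (not part of the original module) ---
# Runs GloVe token embeddings through a bidirectional GRU and reads the concatenated
# first/last hidden states as the document vector. WordEmbeddings('glove') is an assumption
# about the surrounding flair API; hidden_size=128 mirrors the constructor default.
def _example_document_rnn_embeddings():
    from flair.embeddings import WordEmbeddings

    rnn_embeddings = DocumentRNNEmbeddings(
        [WordEmbeddings("glove")],
        hidden_size=128,
        bidirectional=True,
        rnn_type="GRU",
    )

    sentence = Sentence("The grass is green .")
    rnn_embeddings.embed(sentence)

    # bidirectionality doubles the hidden state and the first/last concatenation doubles it
    # again, hence embedding_length == 4 * hidden_size == 512 here
    assert sentence.get_embedding().shape[0] == rnn_embeddings.embedding_length == 512
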

class DocumentLMEmbeddings(DocumentEmbeddings):
    def __init__(self, flair_embeddings: List[FlairEmbeddings]):
        super().__init__()

        self.embeddings = flair_embeddings
        self.name = "document_lm"

        # IMPORTANT: add embeddings as torch modules
        for i, embedding in enumerate(flair_embeddings):
            self.add_module("lm_embedding_{}".format(i), embedding)
            if not embedding.static_embeddings:
                self.static_embeddings = False

        self._embedding_length: int = sum(
            embedding.embedding_length for embedding in flair_embeddings
        )

    @property
    def embedding_length(self) -> int:
        return self._embedding_length

    def _add_embeddings_internal(self, sentences: List[Sentence]):
        if type(sentences) is Sentence:
            sentences = [sentences]

        for embedding in self.embeddings:
            embedding.embed(sentences)

            # iterate over sentences
            for sentence in sentences:
                sentence: Sentence = sentence

                # if it's a forward LM, take the last state; otherwise take the first
                if embedding.is_forward_lm:
                    sentence.set_embedding(
                        embedding.name,
                        sentence[len(sentence) - 1]._embeddings[embedding.name],
                    )
                else:
                    sentence.set_embedding(
                        embedding.name, sentence[0]._embeddings[embedding.name]
                    )

        return sentences
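

# --- Illustrative usage sketch (not part of the original module) ---
# Combines a forward and a backward character language model: the forward LM contributes the
# state at the last token, the backward LM the state at the first token. The identifiers
# 'news-forward' and 'news-backward' are standard flair FlairEmbeddings names and are an
# assumption about the surrounding library.
def _example_document_lm_embeddings():
    lm_embeddings = DocumentLMEmbeddings(
        [FlairEmbeddings("news-forward"), FlairEmbeddings("news-backward")]
    )

    sentence = Sentence("The grass is green .")
    lm_embeddings.embed(sentence)

    # the document embedding is the concatenation of both LM states
    assert sentence.get_embedding().shape[0] == lm_embeddings.embedding_length
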

class SentenceTransformerDocumentEmbeddings(DocumentEmbeddings):
    def __init__(
        self,
        model: str = "bert-base-nli-mean-tokens",
        batch_size: int = 1,
        convert_to_numpy: bool = False,
    ):
        """
        :param model: string name of a model from the SentenceTransformer class
        :param batch_size: number of sentences to process in one batch
        :param convert_to_numpy: whether encode() returns a numpy array or a PyTorch tensor
        """
        super().__init__()

        try:
            from sentence_transformers import SentenceTransformer
        except ModuleNotFoundError:
            log.warning("-" * 100)
            log.warning('ATTENTION! The library "sentence-transformers" is not installed!')
            log.warning(
                'To use Sentence Transformers, please first install with "pip install sentence-transformers"'
            )
            log.warning("-" * 100)
            raise

        self.model = SentenceTransformer(model)
        self.name = 'sentence-transformers-' + str(model)
        self.batch_size = batch_size
        self.convert_to_numpy = convert_to_numpy
        self.static_embeddings = True

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

        sentence_batches = [sentences[i * self.batch_size:(i + 1) * self.batch_size]
                            for i in range((len(sentences) + self.batch_size - 1) // self.batch_size)]

        for batch in sentence_batches:
            self._add_embeddings_to_sentences(batch)

        return sentences

    def _add_embeddings_to_sentences(self, sentences: List[Sentence]):

        # convert to plain strings, embedded in a list for the encode function
        sentences_plain_text = [sentence.to_plain_string() for sentence in sentences]

        embeddings = self.model.encode(sentences_plain_text, convert_to_numpy=self.convert_to_numpy)
        for sentence, embedding in zip(sentences, embeddings):
            sentence.set_embedding(self.name, embedding)

    @property
    def embedding_length(self) -> int:
        """Returns the length of the embedding vector."""
        return self.model.get_sentence_embedding_dimension()
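

# --- Illustrative usage sketch (not part of the original module) ---
# Requires the optional sentence-transformers package. 'bert-base-nli-mean-tokens' is the
# class default; the reported dimension comes from the wrapped SentenceTransformer model.
def _example_sentence_transformer_document_embeddings():
    st_embeddings = SentenceTransformerDocumentEmbeddings("bert-base-nli-mean-tokens")

    sentence = Sentence("The grass is green .")
    st_embeddings.embed(sentence)

    print(st_embeddings.embedding_length)  # e.g. 768 for BERT-base based models
    print(sentence.get_embedding().shape)
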

class DocumentCNNEmbeddings(DocumentEmbeddings):
    def __init__(
        self,
        embeddings: List[TokenEmbeddings],
        kernels=((100, 3), (100, 4), (100, 5)),
        reproject_words: bool = True,
        reproject_words_dimension: int = None,
        dropout: float = 0.5,
        word_dropout: float = 0.0,
        locked_dropout: float = 0.0,
        fine_tune: bool = True,
    ):
        """The constructor takes a list of embeddings to be combined.
        :param embeddings: a list of token embeddings
        :param kernels: list of (number of kernels, kernel size)
        :param reproject_words: boolean value, indicating whether to reproject the token embeddings in a separate linear
        layer before putting them into the cnn or not
        :param reproject_words_dimension: output dimension of reprojecting token embeddings. If None the same output
        dimension as before will be taken.
        :param dropout: the dropout value to be used
        :param word_dropout: the word dropout value to be used, if 0.0 word dropout is not used
        :param locked_dropout: the locked dropout value to be used, if 0.0 locked dropout is not used
        :param fine_tune: if True (default), the document embeddings stay attached to the graph during training;
        if False, they are detached and treated as static
        """
        super().__init__()

        self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embeddings)
        self.length_of_all_token_embeddings: int = self.embeddings.embedding_length

        self.kernels = kernels
        self.reproject_words = reproject_words

        self.static_embeddings = False if fine_tune else True

        self.embeddings_dimension: int = self.length_of_all_token_embeddings
        if self.reproject_words and reproject_words_dimension is not None:
            self.embeddings_dimension = reproject_words_dimension

        self.word_reprojection_map = torch.nn.Linear(
            self.length_of_all_token_embeddings, self.embeddings_dimension
        )

        # CNN
        self.__embedding_length: int = sum([kernel_num for kernel_num, kernel_size in self.kernels])
        self.convs = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(self.embeddings_dimension, kernel_num, kernel_size)
                for kernel_num, kernel_size in self.kernels
            ]
        )
        self.pool = torch.nn.AdaptiveMaxPool1d(1)

        self.name = "document_cnn"

        # dropouts
        self.dropout = torch.nn.Dropout(dropout) if dropout > 0.0 else None
        self.locked_dropout = (
            LockedDropout(locked_dropout) if locked_dropout > 0.0 else None
        )
        self.word_dropout = WordDropout(word_dropout) if word_dropout > 0.0 else None

        torch.nn.init.xavier_uniform_(self.word_reprojection_map.weight)

        self.to(flair.device)

        self.eval()

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length
    def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
        """Add embeddings to all sentences in the given list of sentences. If embeddings are already added, update
        only if embeddings are non-static."""

        # TODO: remove in future versions
        if not hasattr(self, "locked_dropout"):
            self.locked_dropout = None
        if not hasattr(self, "word_dropout"):
            self.word_dropout = None

        if type(sentences) is Sentence:
            sentences = [sentences]

        self.zero_grad()  # is it necessary?

        # embed words in the sentence
        self.embeddings.embed(sentences)

        lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
        longest_token_sequence_in_batch: int = max(lengths)

        pre_allocated_zero_tensor = torch.zeros(
            self.embeddings.embedding_length * longest_token_sequence_in_batch,
            dtype=torch.float,
            device=flair.device,
        )

        all_embs: List[torch.Tensor] = list()
        for sentence in sentences:
            all_embs += [
                emb for token in sentence for emb in token.get_each_embedding()
            ]
            nb_padding_tokens = longest_token_sequence_in_batch - len(sentence)

            if nb_padding_tokens > 0:
                t = pre_allocated_zero_tensor[
                    : self.embeddings.embedding_length * nb_padding_tokens
                ]
                all_embs.append(t)

        sentence_tensor = torch.cat(all_embs).view(
            [
                len(sentences),
                longest_token_sequence_in_batch,
                self.embeddings.embedding_length,
            ]
        )

        # before-CNN dropout
        if self.dropout:
            sentence_tensor = self.dropout(sentence_tensor)
        if self.locked_dropout:
            sentence_tensor = self.locked_dropout(sentence_tensor)
        if self.word_dropout:
            sentence_tensor = self.word_dropout(sentence_tensor)

        # reproject if set
        if self.reproject_words:
            sentence_tensor = self.word_reprojection_map(sentence_tensor)

        # push through CNN
        x = sentence_tensor
        x = x.permute(0, 2, 1)

        rep = [self.pool(torch.nn.functional.relu(conv(x))) for conv in self.convs]
        outputs = torch.cat(rep, 1)

        outputs = outputs.reshape(outputs.size(0), -1)

        # after-CNN dropout
        if self.dropout:
            outputs = self.dropout(outputs)
        if self.locked_dropout:
            outputs = self.locked_dropout(outputs)

        # extract embeddings from CNN
        for sentence_no, length in enumerate(lengths):
            embedding = outputs[sentence_no]

            if self.static_embeddings:
                embedding = embedding.detach()

            sentence = sentences[sentence_no]
            sentence.set_embedding(self.name, embedding)
    def _apply(self, fn):
        for child_module in self.children():
            child_module._apply(fn)
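

# --- Illustrative usage sketch (not part of the original module) ---
# Applies the three default 1D convolutions (100 filters each with widths 3, 4 and 5) over
# GloVe token embeddings and max-pools over time, yielding a 300-dimensional document vector.
# WordEmbeddings('glove') is an assumption about the surrounding flair API.
def _example_document_cnn_embeddings():
    from flair.embeddings import WordEmbeddings

    cnn_embeddings = DocumentCNNEmbeddings([WordEmbeddings("glove")])

    sentence = Sentence("The quick brown fox jumps over the lazy dog .")
    cnn_embeddings.embed(sentence)

    # embedding_length is the total number of kernels: 100 + 100 + 100
    assert sentence.get_embedding().shape[0] == cnn_embeddings.embedding_length == 300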