Coverage for flair/flair/embeddings/document.py: 17%
import logging
from abc import abstractmethod
from typing import List, Union

import torch
from sklearn.feature_extraction.text import TfidfVectorizer
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from transformers import AutoTokenizer, AutoConfig, AutoModel, CONFIG_MAPPING, PreTrainedTokenizer

import flair
from flair.data import Sentence
from flair.embeddings.base import Embeddings, ScalarMix
from flair.embeddings.token import TokenEmbeddings, StackedEmbeddings, FlairEmbeddings
from flair.nn import LockedDropout, WordDropout

log = logging.getLogger("flair")
class DocumentEmbeddings(Embeddings):
    """Abstract base class for all document-level embeddings. Every new type of document embedding must implement these methods."""

    @property
    @abstractmethod
    def embedding_length(self) -> int:
        """Returns the length of the embedding vector."""
        pass

    @property
    def embedding_type(self) -> str:
        return "sentence-level"
class TransformerDocumentEmbeddings(DocumentEmbeddings):
    def __init__(
        self,
        model: str = "bert-base-uncased",
        fine_tune: bool = True,
        layers: str = "-1",
        layer_mean: bool = False,
        pooling: str = "cls",
        **kwargs
    ):
        """
        Bidirectional transformer embeddings of documents from various transformer architectures.
        :param model: name of the transformer model (see https://huggingface.co/transformers/pretrained_models.html
        for options)
        :param fine_tune: if True, allows the transformer to be fine-tuned during training
        :param layers: string indicating which layers to take for the embedding (-1 is the topmost layer)
        :param layer_mean: if True, uses a scalar mix of the selected layers as the embedding
        :param pooling: pooling strategy for combining token-level embeddings; options are 'cls', 'max' and 'mean'
        """
        super().__init__()

        if pooling not in ['cls', 'max', 'mean']:
            raise ValueError(f"Pooling operation `{pooling}` is not defined for TransformerDocumentEmbeddings")
        # temporary fix to disable the tokenizer parallelism warning
        # (see https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning)
        import os
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        # do not print transformer warnings as these are confusing in this case
        from transformers import logging
        logging.set_verbosity_error()

        # load tokenizer and transformer model
        self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model, **kwargs)
        if self.tokenizer.model_max_length > 1000000000:
            self.tokenizer.model_max_length = 512
            log.info("No model_max_length in the Tokenizer's config.json - setting it to 512. "
                     "Specify the desired model_max_length by passing it as an attribute to the embedding instance.")
        if 'config' not in kwargs:
            config = AutoConfig.from_pretrained(model, output_hidden_states=True, **kwargs)
            self.model = AutoModel.from_pretrained(model, config=config)
        else:
            self.model = AutoModel.from_pretrained(None, **kwargs)

        logging.set_verbosity_warning()

        # model name
        self.name = 'transformer-document-' + str(model)
        self.base_model_name = str(model)

        # when initializing, embeddings are in eval mode by default
        self.model.eval()
        self.model.to(flair.device)

        # embedding parameters
        if layers == 'all':
            # send a mini-token through to check how many layers the model has
            hidden_states = self.model(torch.tensor([1], device=flair.device).unsqueeze(0))[-1]
            self.layer_indexes = [int(x) for x in range(len(hidden_states))]
        else:
            self.layer_indexes = [int(x) for x in layers.split(",")]

        self.layer_mean = layer_mean
        self.fine_tune = fine_tune
        self.static_embeddings = not self.fine_tune
        self.pooling = pooling

        # check whether the CLS token is at the beginning or the end of the sequence
        self.initial_cls_token: bool = self._has_initial_cls_token(tokenizer=self.tokenizer)
    @staticmethod
    def _has_initial_cls_token(tokenizer: PreTrainedTokenizer) -> bool:
        # most models have the CLS token as the last token (GPT-1, GPT-2, TransfoXL, XLNet, XLM), but BERT puts it first
        tokens = tokenizer.encode('a')
        initial_cls_token: bool = False
        if tokens[0] == tokenizer.cls_token_id:
            initial_cls_token = True
        return initial_cls_token
    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        """Add a document embedding to each sentence in the given list of sentences."""

        # gradients are enabled if fine-tuning is enabled
        gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad()

        with gradient_context:

            # first, subtokenize each sentence and find out into how many subtokens each token was divided
            subtokenized_sentences = []

            # subtokenize sentences
            for sentence in sentences:
                # tokenize and truncate to max subtokens (TODO: check better truncation strategies)
                subtokenized_sentence = self.tokenizer.encode(sentence.to_tokenized_string(),
                                                              add_special_tokens=True,
                                                              max_length=self.tokenizer.model_max_length,
                                                              truncation=True,
                                                              )

                subtokenized_sentences.append(
                    torch.tensor(subtokenized_sentence, dtype=torch.long, device=flair.device))

            # find the longest sentence in the batch
            longest_sequence_in_batch: int = len(max(subtokenized_sentences, key=len))

            # initialize batch tensors and mask
            input_ids = torch.zeros(
                [len(sentences), longest_sequence_in_batch],
                dtype=torch.long,
                device=flair.device,
            )
            mask = torch.zeros(
                [len(sentences), longest_sequence_in_batch],
                dtype=torch.long,
                device=flair.device,
            )
            for s_id, sentence in enumerate(subtokenized_sentences):
                sequence_length = len(sentence)
                input_ids[s_id][:sequence_length] = sentence
                mask[s_id][:sequence_length] = torch.ones(sequence_length)

            # put the encoded batch through the transformer model to get all hidden states of all encoder layers
            hidden_states = self.model(input_ids, attention_mask=mask)[-1] if len(sentences) > 1 \
                else self.model(input_ids)[-1]

            # iterate over all subtokenized sentences
            for sentence_idx, (sentence, subtokens) in enumerate(zip(sentences, subtokenized_sentences)):

                if self.pooling == "cls":
                    index_of_CLS_token = 0 if self.initial_cls_token else len(subtokens) - 1

                    cls_embeddings_all_layers: List[torch.FloatTensor] = \
                        [hidden_states[layer][sentence_idx][index_of_CLS_token] for layer in self.layer_indexes]

                    embeddings_all_layers = cls_embeddings_all_layers

                elif self.pooling == "mean":
                    mean_embeddings_all_layers: List[torch.FloatTensor] = \
                        [torch.mean(hidden_states[layer][sentence_idx][:len(subtokens), :], dim=0)
                         for layer in self.layer_indexes]

                    embeddings_all_layers = mean_embeddings_all_layers

                elif self.pooling == "max":
                    max_embeddings_all_layers: List[torch.FloatTensor] = \
                        [torch.max(hidden_states[layer][sentence_idx][:len(subtokens), :], dim=0)[0]
                         for layer in self.layer_indexes]

                    embeddings_all_layers = max_embeddings_all_layers

                # use a scalar mix of layers if so selected
                if self.layer_mean:
                    sm = ScalarMix(mixture_size=len(embeddings_all_layers))
                    sm_embeddings = sm(embeddings_all_layers)

                    embeddings_all_layers = [sm_embeddings]

                # set the extracted embedding for the sentence
                sentence.set_embedding(self.name, torch.cat(embeddings_all_layers))

        return sentences
    @property
    def embedding_length(self) -> int:
        """Returns the length of the embedding vector."""
        return (
            len(self.layer_indexes) * self.model.config.hidden_size
            if not self.layer_mean
            else self.model.config.hidden_size
        )
    def __getstate__(self):
        # special handling for serializing transformer models
        config_state_dict = self.model.config.__dict__
        model_state_dict = self.model.state_dict()

        if not hasattr(self, "base_model_name"):
            self.base_model_name = self.name.split('transformer-document-')[-1]

        # serialize the transformer model and the constructor arguments (but nothing else)
        model_state = {
            "config_state_dict": config_state_dict,
            "model_state_dict": model_state_dict,
            "embedding_length_internal": self.embedding_length,

            "base_model_name": self.base_model_name,
            "fine_tune": self.fine_tune,
            "layer_indexes": self.layer_indexes,
            "layer_mean": self.layer_mean,
            "pooling": self.pooling,
        }

        return model_state
    def __setstate__(self, d):
        self.__dict__ = d

        # necessary for backward compatibility with Flair <= 0.7
        if 'use_scalar_mix' in self.__dict__.keys():
            self.__dict__['layer_mean'] = d['use_scalar_mix']

        # special handling for deserializing transformer models
        if "config_state_dict" in d:

            # load transformer model
            model_type = d["config_state_dict"]["model_type"] if "model_type" in d["config_state_dict"] else "bert"
            config_class = CONFIG_MAPPING[model_type]
            loaded_config = config_class.from_dict(d["config_state_dict"])

            # constructor arguments
            layers = ','.join([str(idx) for idx in self.__dict__['layer_indexes']])

            # re-initialize transformer document embeddings with the constructor arguments
            embedding = TransformerDocumentEmbeddings(
                model=self.__dict__['base_model_name'],
                fine_tune=self.__dict__['fine_tune'],
                layers=layers,
                layer_mean=self.__dict__['layer_mean'],

                config=loaded_config,
                state_dict=d["model_state_dict"],
                # default to 'cls' for backward compatibility with previous models
                pooling=self.__dict__['pooling'] if 'pooling' in self.__dict__ else 'cls',
            )

            # copy the state of the freshly initialized embedding over to this instance
            for key in embedding.__dict__.keys():
                self.__dict__[key] = embedding.__dict__[key]

        else:
            model_name = self.__dict__['name'].split('transformer-document-')[-1]
            # reload the tokenizer to get around serialization issues
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            except Exception:
                log.warning(f"Could not reload tokenizer for model '{model_name}'")
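
# Example usage (illustrative sketch, not part of the original module; the model name and
# pooling choice below are assumptions):
#
#     from flair.data import Sentence
#     from flair.embeddings import TransformerDocumentEmbeddings
#
#     document_embeddings = TransformerDocumentEmbeddings("bert-base-uncased", fine_tune=False, pooling="cls")
#     sentence = Sentence("Berlin is a beautiful city .")
#     document_embeddings.embed(sentence)
#     # one layer (-1) of bert-base yields a 768-dimensional document vector
#     print(sentence.get_embedding().shape)
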
class DocumentPoolEmbeddings(DocumentEmbeddings):
    def __init__(
        self,
        embeddings: List[TokenEmbeddings],
        fine_tune_mode: str = "none",
        pooling: str = "mean",
    ):
        """The constructor takes a list of embeddings to be combined.
        :param embeddings: a list of token embeddings
        :param fine_tune_mode: if set to "linear", a trainable layer is added; if set to "nonlinear",
        a nonlinearity is added as well. Set this to make the pooling trainable.
        :param pooling: a string which can be any value from ['mean', 'max', 'min']
        """
        super().__init__()

        self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embeddings)
        self.__embedding_length: int = self.embeddings.embedding_length

        # optional fine-tuning on top of the embedding layer
        self.fine_tune_mode = fine_tune_mode
        if self.fine_tune_mode in ["nonlinear", "linear"]:
            self.embedding_flex = torch.nn.Linear(
                self.embedding_length, self.embedding_length, bias=False
            )
            self.embedding_flex.weight.data.copy_(torch.eye(self.embedding_length))

            if self.fine_tune_mode in ["nonlinear"]:
                self.embedding_flex_nonlinear = torch.nn.ReLU()
                self.embedding_flex_nonlinear_map = torch.nn.Linear(
                    self.embedding_length, self.embedding_length
                )

        self.to(flair.device)

        if pooling not in ['min', 'max', 'mean']:
            raise ValueError(f"Pooling operation `{pooling}` is not defined")

        self.pooling = pooling
        self.name: str = f"document_{self.pooling}"
    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def embed(self, sentences: Union[List[Sentence], Sentence]):
        """Add embeddings to every sentence in the given list of sentences. If embeddings are already added,
        update only if embeddings are non-static."""

        # if only one sentence is passed, convert it to a list of sentences
        if isinstance(sentences, Sentence):
            sentences = [sentences]

        self.embeddings.embed(sentences)

        for sentence in sentences:
            word_embeddings = []
            for token in sentence.tokens:
                word_embeddings.append(token.get_embedding().unsqueeze(0))

            word_embeddings = torch.cat(word_embeddings, dim=0).to(flair.device)

            if self.fine_tune_mode in ["nonlinear", "linear"]:
                word_embeddings = self.embedding_flex(word_embeddings)

            if self.fine_tune_mode in ["nonlinear"]:
                word_embeddings = self.embedding_flex_nonlinear(word_embeddings)
                word_embeddings = self.embedding_flex_nonlinear_map(word_embeddings)

            if self.pooling == "mean":
                pooled_embedding = torch.mean(word_embeddings, 0)
            elif self.pooling == "max":
                pooled_embedding, _ = torch.max(word_embeddings, 0)
            elif self.pooling == "min":
                pooled_embedding, _ = torch.min(word_embeddings, 0)

            sentence.set_embedding(self.name, pooled_embedding)

    def _add_embeddings_internal(self, sentences: List[Sentence]):
        pass

    def extra_repr(self):
        return f"fine_tune_mode={self.fine_tune_mode}, pooling={self.pooling}"
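
# Example usage (illustrative sketch, not part of the original module; "glove" is an
# assumed choice of token embedding):
#
#     from flair.data import Sentence
#     from flair.embeddings import WordEmbeddings, DocumentPoolEmbeddings
#
#     glove = WordEmbeddings("glove")
#     document_embeddings = DocumentPoolEmbeddings([glove], pooling="mean")
#     sentence = Sentence("The grass is green .")
#     document_embeddings.embed(sentence)
#     print(sentence.get_embedding().shape)  # 100-dimensional for GloVe
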
class DocumentTFIDFEmbeddings(DocumentEmbeddings):
    def __init__(
        self,
        train_dataset,
        **vectorizer_params,
    ):
        """The constructor for DocumentTFIDFEmbeddings.
        :param train_dataset: the train dataset which will be used to fit the vectorizer
        :param vectorizer_params: parameters given to scikit-learn's TfidfVectorizer constructor
        """
        super().__init__()

        import numpy as np
        self.vectorizer = TfidfVectorizer(dtype=np.float32, **vectorizer_params)
        self.vectorizer.fit([s.to_original_text() for s in train_dataset])

        self.__embedding_length: int = len(self.vectorizer.vocabulary_)

        self.to(flair.device)

        self.name: str = "document_tfidf"

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length

    def embed(self, sentences: Union[List[Sentence], Sentence]):
        """Add embeddings to every sentence in the given list of sentences."""

        # if only one sentence is passed, convert it to a list of sentences
        if isinstance(sentences, Sentence):
            sentences = [sentences]

        raw_sentences = [s.to_original_text() for s in sentences]
        tfidf_vectors = torch.from_numpy(self.vectorizer.transform(raw_sentences).A)

        for sentence_id, sentence in enumerate(sentences):
            sentence.set_embedding(self.name, tfidf_vectors[sentence_id])

    def _add_embeddings_internal(self, sentences: List[Sentence]):
        pass
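
# Example usage (illustrative sketch, not part of the original module; the tiny training
# corpus and the min_df setting passed through to TfidfVectorizer are assumptions):
#
#     from flair.data import Sentence
#     from flair.embeddings import DocumentTFIDFEmbeddings
#
#     train = [Sentence("the grass is green"), Sentence("the sky is blue")]
#     tfidf_embeddings = DocumentTFIDFEmbeddings(train, min_df=1)
#     sentence = Sentence("the grass is blue")
#     tfidf_embeddings.embed(sentence)
#     print(sentence.get_embedding().shape)  # one dimension per vocabulary term
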
class DocumentRNNEmbeddings(DocumentEmbeddings):
    def __init__(
        self,
        embeddings: List[TokenEmbeddings],
        hidden_size=128,
        rnn_layers=1,
        reproject_words: bool = True,
        reproject_words_dimension: int = None,
        bidirectional: bool = False,
        dropout: float = 0.5,
        word_dropout: float = 0.0,
        locked_dropout: float = 0.0,
        rnn_type="GRU",
        fine_tune: bool = True,
    ):
        """The constructor takes a list of embeddings to be combined.
        :param embeddings: a list of token embeddings
        :param hidden_size: the number of hidden states in the rnn
        :param rnn_layers: the number of layers for the rnn
        :param reproject_words: boolean value, indicating whether to reproject the token embeddings in a separate
        linear layer before putting them into the rnn or not
        :param reproject_words_dimension: output dimension of reprojecting token embeddings. If None, the same
        output dimension as before will be used.
        :param bidirectional: boolean value, indicating whether to use a bidirectional rnn or not
        :param dropout: the dropout value to be used
        :param word_dropout: the word dropout value to be used; if 0.0, word dropout is not used
        :param locked_dropout: the locked dropout value to be used; if 0.0, locked dropout is not used
        :param rnn_type: 'GRU' or 'LSTM'
        :param fine_tune: if True, the resulting document embeddings stay attached to the computation graph;
        if False, they are detached (static embeddings)
        """
        super().__init__()

        self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embeddings)

        self.rnn_type = rnn_type

        self.reproject_words = reproject_words
        self.bidirectional = bidirectional

        self.length_of_all_token_embeddings: int = self.embeddings.embedding_length

        self.static_embeddings = not fine_tune

        self.__embedding_length: int = hidden_size
        if self.bidirectional:
            self.__embedding_length *= 4

        self.embeddings_dimension: int = self.length_of_all_token_embeddings
        if self.reproject_words and reproject_words_dimension is not None:
            self.embeddings_dimension = reproject_words_dimension

        self.word_reprojection_map = torch.nn.Linear(
            self.length_of_all_token_embeddings, self.embeddings_dimension
        )

        # RNN on top of the embedding layer
        if rnn_type == "LSTM":
            self.rnn = torch.nn.LSTM(
                self.embeddings_dimension,
                hidden_size,
                num_layers=rnn_layers,
                bidirectional=self.bidirectional,
                batch_first=True,
            )
        else:
            self.rnn = torch.nn.GRU(
                self.embeddings_dimension,
                hidden_size,
                num_layers=rnn_layers,
                bidirectional=self.bidirectional,
                batch_first=True,
            )

        self.name = "document_" + self.rnn._get_name()

        # dropouts
        self.dropout = torch.nn.Dropout(dropout) if dropout > 0.0 else None
        self.locked_dropout = (
            LockedDropout(locked_dropout) if locked_dropout > 0.0 else None
        )
        self.word_dropout = WordDropout(word_dropout) if word_dropout > 0.0 else None

        torch.nn.init.xavier_uniform_(self.word_reprojection_map.weight)

        self.to(flair.device)

        self.eval()

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length
    def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
        """Add embeddings to all sentences in the given list of sentences. If embeddings are already added,
        update only if embeddings are non-static."""

        # TODO: remove in future versions
        if not hasattr(self, "locked_dropout"):
            self.locked_dropout = None
        if not hasattr(self, "word_dropout"):
            self.word_dropout = None

        if type(sentences) is Sentence:
            sentences = [sentences]

        self.rnn.zero_grad()

        # embed words in the sentences
        self.embeddings.embed(sentences)

        lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
        longest_token_sequence_in_batch: int = max(lengths)

        pre_allocated_zero_tensor = torch.zeros(
            self.embeddings.embedding_length * longest_token_sequence_in_batch,
            dtype=torch.float,
            device=flair.device,
        )

        all_embs: List[torch.Tensor] = list()
        for sentence in sentences:
            all_embs += [
                emb for token in sentence for emb in token.get_each_embedding()
            ]
            nb_padding_tokens = longest_token_sequence_in_batch - len(sentence)

            if nb_padding_tokens > 0:
                t = pre_allocated_zero_tensor[
                    : self.embeddings.embedding_length * nb_padding_tokens
                ]
                all_embs.append(t)

        sentence_tensor = torch.cat(all_embs).view(
            [
                len(sentences),
                longest_token_sequence_in_batch,
                self.embeddings.embedding_length,
            ]
        )

        # before-RNN dropout
        if self.dropout:
            sentence_tensor = self.dropout(sentence_tensor)
        if self.locked_dropout:
            sentence_tensor = self.locked_dropout(sentence_tensor)
        if self.word_dropout:
            sentence_tensor = self.word_dropout(sentence_tensor)

        # reproject if set
        if self.reproject_words:
            sentence_tensor = self.word_reprojection_map(sentence_tensor)

        # push through RNN
        packed = pack_padded_sequence(
            sentence_tensor, lengths, enforce_sorted=False, batch_first=True
        )
        rnn_out, hidden = self.rnn(packed)
        outputs, output_lengths = pad_packed_sequence(rnn_out, batch_first=True)

        # after-RNN dropout
        if self.dropout:
            outputs = self.dropout(outputs)
        if self.locked_dropout:
            outputs = self.locked_dropout(outputs)

        # extract embeddings from the RNN
        for sentence_no, length in enumerate(lengths):
            last_rep = outputs[sentence_no, length - 1]

            embedding = last_rep
            if self.bidirectional:
                first_rep = outputs[sentence_no, 0]
                embedding = torch.cat([first_rep, last_rep], 0)

            if self.static_embeddings:
                embedding = embedding.detach()

            sentence = sentences[sentence_no]
            sentence.set_embedding(self.name, embedding)
    def _apply(self, fn):

        # models that were serialized with torch versions older than 1.4.0 lack the _flat_weights_names attribute
        # check if this is the case and, if so, set it
        for child_module in self.children():
            if isinstance(child_module, torch.nn.RNNBase) and not hasattr(child_module, "_flat_weights_names"):
                _flat_weights_names = []

                if child_module.__dict__["bidirectional"]:
                    num_direction = 2
                else:
                    num_direction = 1
                for layer in range(child_module.__dict__["num_layers"]):
                    for direction in range(num_direction):
                        suffix = "_reverse" if direction == 1 else ""
                        param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"]
                        if child_module.__dict__["bias"]:
                            param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"]
                        param_names = [
                            x.format(layer, suffix) for x in param_names
                        ]
                        _flat_weights_names.extend(param_names)

                setattr(child_module, "_flat_weights_names", _flat_weights_names)

            child_module._apply(fn)
    def __getstate__(self):

        # serialize the language models and the constructor arguments (but nothing else)
        model_state = {
            "state_dict": self.state_dict(),

            "embeddings": self.embeddings.embeddings,
            "hidden_size": self.rnn.hidden_size,
            "rnn_layers": self.rnn.num_layers,
            "reproject_words": self.reproject_words,
            "reproject_words_dimension": self.embeddings_dimension,
            "bidirectional": self.bidirectional,
            "dropout": self.dropout.p if self.dropout is not None else 0.,
            "word_dropout": self.word_dropout.p if self.word_dropout is not None else 0.,
            "locked_dropout": self.locked_dropout.p if self.locked_dropout is not None else 0.,
            "rnn_type": self.rnn_type,
            "fine_tune": not self.static_embeddings,
        }

        return model_state

    def __setstate__(self, d):

        # special handling for deserializing language models
        if "state_dict" in d:

            # re-initialize the language model with the constructor arguments
            language_model = DocumentRNNEmbeddings(
                embeddings=d['embeddings'],
                hidden_size=d['hidden_size'],
                rnn_layers=d['rnn_layers'],
                reproject_words=d['reproject_words'],
                reproject_words_dimension=d['reproject_words_dimension'],
                bidirectional=d['bidirectional'],
                dropout=d['dropout'],
                word_dropout=d['word_dropout'],
                locked_dropout=d['locked_dropout'],
                rnn_type=d['rnn_type'],
                fine_tune=d['fine_tune'],
            )

            language_model.load_state_dict(d['state_dict'])

            # copy over the state dictionary to self
            for key in language_model.__dict__.keys():
                self.__dict__[key] = language_model.__dict__[key]

            # set the language model to eval() by default (this is necessary since FlairEmbeddings "protect" the LM
            # in their "self.train()" method)
            self.eval()

        else:
            self.__dict__ = d
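
# Example usage (illustrative sketch, not part of the original module; "glove" and the
# hidden size are assumptions):
#
#     from flair.data import Sentence
#     from flair.embeddings import WordEmbeddings, DocumentRNNEmbeddings
#
#     glove = WordEmbeddings("glove")
#     document_embeddings = DocumentRNNEmbeddings([glove], hidden_size=128, rnn_type="GRU")
#     sentence = Sentence("The grass is green .")
#     document_embeddings.embed(sentence)
#     print(sentence.get_embedding().shape)  # torch.Size([128]) for a unidirectional GRU
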
class DocumentLMEmbeddings(DocumentEmbeddings):
    def __init__(self, flair_embeddings: List[FlairEmbeddings]):
        super().__init__()

        self.embeddings = flair_embeddings
        self.name = "document_lm"

        # IMPORTANT: add embeddings as torch modules
        for i, embedding in enumerate(flair_embeddings):
            self.add_module("lm_embedding_{}".format(i), embedding)
            if not embedding.static_embeddings:
                self.static_embeddings = False

        self._embedding_length: int = sum(
            embedding.embedding_length for embedding in flair_embeddings
        )

    @property
    def embedding_length(self) -> int:
        return self._embedding_length

    def _add_embeddings_internal(self, sentences: List[Sentence]):
        if type(sentences) is Sentence:
            sentences = [sentences]

        for embedding in self.embeddings:
            embedding.embed(sentences)

            # iterate over sentences
            for sentence in sentences:
                sentence: Sentence = sentence

                # if it is a forward LM, take the last state
                if embedding.is_forward_lm:
                    sentence.set_embedding(
                        embedding.name,
                        sentence[len(sentence) - 1]._embeddings[embedding.name],
                    )
                else:
                    sentence.set_embedding(
                        embedding.name, sentence[0]._embeddings[embedding.name]
                    )

        return sentences
class SentenceTransformerDocumentEmbeddings(DocumentEmbeddings):
    def __init__(
        self,
        model: str = "bert-base-nli-mean-tokens",
        batch_size: int = 1,
        convert_to_numpy: bool = False,
    ):
        """
        :param model: string name of a model from the SentenceTransformer class
        :param batch_size: number of sentences to process in one batch
        :param convert_to_numpy: whether encode() returns a numpy array or a PyTorch tensor
        """
        super().__init__()

        try:
            from sentence_transformers import SentenceTransformer
        except ModuleNotFoundError:
            log.warning("-" * 100)
            log.warning('ATTENTION! The library "sentence-transformers" is not installed!')
            log.warning(
                'To use Sentence Transformers, please first install with "pip install sentence-transformers"'
            )
            log.warning("-" * 100)
            raise

        self.model = SentenceTransformer(model)
        self.name = 'sentence-transformers-' + str(model)
        self.batch_size = batch_size
        self.convert_to_numpy = convert_to_numpy
        self.static_embeddings = True
    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

        sentence_batches = [sentences[i * self.batch_size:(i + 1) * self.batch_size]
                            for i in range((len(sentences) + self.batch_size - 1) // self.batch_size)]

        for batch in sentence_batches:
            self._add_embeddings_to_sentences(batch)

        return sentences

    def _add_embeddings_to_sentences(self, sentences: List[Sentence]):

        # convert to plain strings, embedded in a list for the encode function
        sentences_plain_text = [sentence.to_plain_string() for sentence in sentences]

        embeddings = self.model.encode(sentences_plain_text, convert_to_numpy=self.convert_to_numpy)
        for sentence, embedding in zip(sentences, embeddings):
            sentence.set_embedding(self.name, embedding)

    @property
    def embedding_length(self) -> int:
        """Returns the length of the embedding vector."""
        return self.model.get_sentence_embedding_dimension()
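
# Example usage (illustrative sketch, not part of the original module; requires the optional
# "sentence-transformers" package, and the model name is simply the class default):
#
#     from flair.data import Sentence
#     from flair.embeddings import SentenceTransformerDocumentEmbeddings
#
#     document_embeddings = SentenceTransformerDocumentEmbeddings("bert-base-nli-mean-tokens")
#     sentence = Sentence("The grass is green .")
#     document_embeddings.embed(sentence)
#     print(sentence.get_embedding().shape)
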
class DocumentCNNEmbeddings(DocumentEmbeddings):
    def __init__(
        self,
        embeddings: List[TokenEmbeddings],
        kernels=((100, 3), (100, 4), (100, 5)),
        reproject_words: bool = True,
        reproject_words_dimension: int = None,
        dropout: float = 0.5,
        word_dropout: float = 0.0,
        locked_dropout: float = 0.0,
        fine_tune: bool = True,
    ):
        """The constructor takes a list of embeddings to be combined.
        :param embeddings: a list of token embeddings
        :param kernels: list of (number of kernels, kernel size) tuples
        :param reproject_words: boolean value, indicating whether to reproject the token embeddings in a separate
        linear layer before putting them into the cnn or not
        :param reproject_words_dimension: output dimension of reprojecting token embeddings. If None, the same
        output dimension as before will be used.
        :param dropout: the dropout value to be used
        :param word_dropout: the word dropout value to be used; if 0.0, word dropout is not used
        :param locked_dropout: the locked dropout value to be used; if 0.0, locked dropout is not used
        :param fine_tune: if True, the resulting document embeddings stay attached to the computation graph;
        if False, they are detached (static embeddings)
        """
        super().__init__()

        self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embeddings)
        self.length_of_all_token_embeddings: int = self.embeddings.embedding_length

        self.kernels = kernels
        self.reproject_words = reproject_words

        self.static_embeddings = not fine_tune

        self.embeddings_dimension: int = self.length_of_all_token_embeddings
        if self.reproject_words and reproject_words_dimension is not None:
            self.embeddings_dimension = reproject_words_dimension

        self.word_reprojection_map = torch.nn.Linear(
            self.length_of_all_token_embeddings, self.embeddings_dimension
        )

        # CNN
        self.__embedding_length: int = sum([kernel_num for kernel_num, kernel_size in self.kernels])
        self.convs = torch.nn.ModuleList(
            [
                torch.nn.Conv1d(self.embeddings_dimension, kernel_num, kernel_size)
                for kernel_num, kernel_size in self.kernels
            ]
        )
        self.pool = torch.nn.AdaptiveMaxPool1d(1)

        self.name = "document_cnn"

        # dropouts
        self.dropout = torch.nn.Dropout(dropout) if dropout > 0.0 else None
        self.locked_dropout = (
            LockedDropout(locked_dropout) if locked_dropout > 0.0 else None
        )
        self.word_dropout = WordDropout(word_dropout) if word_dropout > 0.0 else None

        torch.nn.init.xavier_uniform_(self.word_reprojection_map.weight)

        self.to(flair.device)

        self.eval()

    @property
    def embedding_length(self) -> int:
        return self.__embedding_length
    def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]):
        """Add embeddings to all sentences in the given list of sentences. If embeddings are already added,
        update only if embeddings are non-static."""

        # TODO: remove in future versions
        if not hasattr(self, "locked_dropout"):
            self.locked_dropout = None
        if not hasattr(self, "word_dropout"):
            self.word_dropout = None

        if type(sentences) is Sentence:
            sentences = [sentences]

        self.zero_grad()  # is it necessary?

        # embed words in the sentences
        self.embeddings.embed(sentences)

        lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
        longest_token_sequence_in_batch: int = max(lengths)

        pre_allocated_zero_tensor = torch.zeros(
            self.embeddings.embedding_length * longest_token_sequence_in_batch,
            dtype=torch.float,
            device=flair.device,
        )

        all_embs: List[torch.Tensor] = list()
        for sentence in sentences:
            all_embs += [
                emb for token in sentence for emb in token.get_each_embedding()
            ]
            nb_padding_tokens = longest_token_sequence_in_batch - len(sentence)

            if nb_padding_tokens > 0:
                t = pre_allocated_zero_tensor[
                    : self.embeddings.embedding_length * nb_padding_tokens
                ]
                all_embs.append(t)

        sentence_tensor = torch.cat(all_embs).view(
            [
                len(sentences),
                longest_token_sequence_in_batch,
                self.embeddings.embedding_length,
            ]
        )

        # before-CNN dropout
        if self.dropout:
            sentence_tensor = self.dropout(sentence_tensor)
        if self.locked_dropout:
            sentence_tensor = self.locked_dropout(sentence_tensor)
        if self.word_dropout:
            sentence_tensor = self.word_dropout(sentence_tensor)

        # reproject if set
        if self.reproject_words:
            sentence_tensor = self.word_reprojection_map(sentence_tensor)

        # push through CNN
        x = sentence_tensor
        x = x.permute(0, 2, 1)

        rep = [self.pool(torch.nn.functional.relu(conv(x))) for conv in self.convs]
        outputs = torch.cat(rep, 1)

        outputs = outputs.reshape(outputs.size(0), -1)

        # after-CNN dropout
        if self.dropout:
            outputs = self.dropout(outputs)
        if self.locked_dropout:
            outputs = self.locked_dropout(outputs)

        # extract embeddings from the CNN
        for sentence_no, length in enumerate(lengths):
            embedding = outputs[sentence_no]

            if self.static_embeddings:
                embedding = embedding.detach()

            sentence = sentences[sentence_no]
            sentence.set_embedding(self.name, embedding)
    def _apply(self, fn):
        for child_module in self.children():
            child_module._apply(fn)
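
# Example usage (illustrative sketch, not part of the original module; "glove" and the
# default kernel configuration are assumptions):
#
#     from flair.data import Sentence
#     from flair.embeddings import WordEmbeddings, DocumentCNNEmbeddings
#
#     glove = WordEmbeddings("glove")
#     document_embeddings = DocumentCNNEmbeddings([glove], kernels=((100, 3), (100, 4), (100, 5)))
#     sentence = Sentence("The grass is green .")
#     document_embeddings.embed(sentence)
#     print(sentence.get_embedding().shape)  # torch.Size([300]): 100 filters for each of the 3 kernel widths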