Coverage for flair/flair/embeddings/token.py: 14%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import hashlib
2import inspect
3import logging
4import os
5import re
6from abc import abstractmethod
7from collections import Counter
8from pathlib import Path
9from typing import List, Union, Dict, Optional
11import gensim
12import numpy as np
13import torch
14from bpemb import BPEmb
15from gensim.models import KeyedVectors
16from torch import nn
17from transformers import AutoTokenizer, AutoConfig, AutoModel, CONFIG_MAPPING, PreTrainedTokenizer, XLNetModel, \
18 TransfoXLModel
20import flair
21from flair.data import Sentence, Token, Corpus, Dictionary
22from flair.embeddings.base import Embeddings
23from flair.file_utils import cached_path, open_inside_zip, instance_lru_cache
25log = logging.getLogger("flair")
28class TokenEmbeddings(Embeddings):
29 """Abstract base class for all token-level embeddings. Ever new type of word embedding must implement these methods."""
31 @property
32 @abstractmethod
33 def embedding_length(self) -> int:
34 """Returns the length of the embedding vector."""
35 pass
37 @property
38 def embedding_type(self) -> str:
39 return "word-level"
41 @staticmethod
42 def get_instance_parameters(locals: dict) -> dict:
43 class_definition = locals.get("__class__")
44 instance_parameters = set(inspect.getfullargspec(class_definition.__init__).args)
45 instance_parameters.difference_update(set(["self"]))
46 instance_parameters.update(set(["__class__"]))
47 instance_parameters = {class_attribute: attribute_value for class_attribute, attribute_value in locals.items()
48 if class_attribute in instance_parameters}
49 return instance_parameters
52class StackedEmbeddings(TokenEmbeddings):
53 """A stack of embeddings, used if you need to combine several different embedding types."""
55 def __init__(self, embeddings: List[TokenEmbeddings]):
56 """The constructor takes a list of embeddings to be combined."""
57 super().__init__()
59 self.embeddings = embeddings
61 # IMPORTANT: add embeddings as torch modules
62 for i, embedding in enumerate(embeddings):
63 embedding.name = f"{str(i)}-{embedding.name}"
64 self.add_module(f"list_embedding_{str(i)}", embedding)
66 self.name: str = "Stack"
67 self.static_embeddings: bool = True
69 self.__embedding_type: str = embeddings[0].embedding_type
71 self.__embedding_length: int = 0
72 for embedding in embeddings:
73 self.__embedding_length += embedding.embedding_length
75 def embed(
76 self, sentences: Union[Sentence, List[Sentence]], static_embeddings: bool = True
77 ):
78 # if only one sentence is passed, convert to list of sentence
79 if type(sentences) is Sentence:
80 sentences = [sentences]
82 for embedding in self.embeddings:
83 embedding.embed(sentences)
85 @property
86 def embedding_type(self) -> str:
87 return self.__embedding_type
89 @property
90 def embedding_length(self) -> int:
91 return self.__embedding_length
93 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
95 for embedding in self.embeddings:
96 embedding._add_embeddings_internal(sentences)
98 return sentences
100 def __str__(self):
101 return f'StackedEmbeddings [{",".join([str(e) for e in self.embeddings])}]'
103 def get_names(self) -> List[str]:
104 """Returns a list of embedding names. In most cases, it is just a list with one item, namely the name of
105 this embedding. But in some cases, the embedding is made up by different embeddings (StackedEmbedding).
106 Then, the list contains the names of all embeddings in the stack."""
107 names = []
108 for embedding in self.embeddings:
109 names.extend(embedding.get_names())
110 return names
112 def get_named_embeddings_dict(self) -> Dict:
114 named_embeddings_dict = {}
115 for embedding in self.embeddings:
116 named_embeddings_dict.update(embedding.get_named_embeddings_dict())
118 return named_embeddings_dict
121class WordEmbeddings(TokenEmbeddings):
122 """Standard static word embeddings, such as GloVe or FastText."""
124 def __init__(self, embeddings: str, field: str = None, fine_tune: bool = False, force_cpu: bool = True,
125 stable: bool = False):
126 """
127 Initializes classic word embeddings. Constructor downloads required files if not there.
128 :param embeddings: one of: 'glove', 'extvec', 'crawl' or two-letter language code or custom
129 If you want to use a custom embedding file, just pass the path to the embeddings as embeddings variable.
130 set stable=True to use the stable embeddings as described in https://arxiv.org/abs/2110.02861
131 """
132 self.embeddings = embeddings
134 self.instance_parameters = self.get_instance_parameters(locals=locals())
136 if fine_tune and force_cpu and flair.device.type != "cpu":
137 raise ValueError("Cannot train WordEmbeddings on cpu if the model is trained on gpu, set force_cpu=False")
139 hu_path: str = "https://flair.informatik.hu-berlin.de/resources/embeddings/token"
141 cache_dir = Path("embeddings")
143 # GLOVE embeddings
144 if embeddings.lower() == "glove" or embeddings.lower() == "en-glove":
145 cached_path(f"{hu_path}/glove.gensim.vectors.npy", cache_dir=cache_dir)
146 embeddings = cached_path(f"{hu_path}/glove.gensim", cache_dir=cache_dir)
148 # TURIAN embeddings
149 elif embeddings.lower() == "turian" or embeddings.lower() == "en-turian":
150 cached_path(f"{hu_path}/turian.vectors.npy", cache_dir=cache_dir)
151 embeddings = cached_path(f"{hu_path}/turian", cache_dir=cache_dir)
153 # KOMNINOS embeddings
154 elif embeddings.lower() == "extvec" or embeddings.lower() == "en-extvec":
155 cached_path(f"{hu_path}/extvec.gensim.vectors.npy", cache_dir=cache_dir)
156 embeddings = cached_path(f"{hu_path}/extvec.gensim", cache_dir=cache_dir)
158 # pubmed embeddings
159 elif embeddings.lower() == "pubmed" or embeddings.lower() == "en-pubmed":
160 cached_path(f"{hu_path}/pubmed_pmc_wiki_sg_1M.gensim.vectors.npy", cache_dir=cache_dir)
161 embeddings = cached_path(f"{hu_path}/pubmed_pmc_wiki_sg_1M.gensim", cache_dir=cache_dir)
163 # FT-CRAWL embeddings
164 elif embeddings.lower() == "crawl" or embeddings.lower() == "en-crawl":
165 cached_path(f"{hu_path}/en-fasttext-crawl-300d-1M.vectors.npy", cache_dir=cache_dir)
166 embeddings = cached_path(f"{hu_path}/en-fasttext-crawl-300d-1M", cache_dir=cache_dir)
168 # FT-CRAWL embeddings
169 elif embeddings.lower() in ["news", "en-news", "en"]:
170 cached_path(f"{hu_path}/en-fasttext-news-300d-1M.vectors.npy", cache_dir=cache_dir)
171 embeddings = cached_path(f"{hu_path}/en-fasttext-news-300d-1M", cache_dir=cache_dir)
173 # twitter embeddings
174 elif embeddings.lower() in ["twitter", "en-twitter"]:
175 cached_path(f"{hu_path}/twitter.gensim.vectors.npy", cache_dir=cache_dir)
176 embeddings = cached_path(f"{hu_path}/twitter.gensim", cache_dir=cache_dir)
178 # two-letter language code wiki embeddings
179 elif len(embeddings.lower()) == 2:
180 cached_path(f"{hu_path}/{embeddings}-wiki-fasttext-300d-1M.vectors.npy", cache_dir=cache_dir)
181 embeddings = cached_path(f"{hu_path}/{embeddings}-wiki-fasttext-300d-1M", cache_dir=cache_dir)
183 # two-letter language code wiki embeddings
184 elif len(embeddings.lower()) == 7 and embeddings.endswith("-wiki"):
185 cached_path(f"{hu_path}/{embeddings[:2]}-wiki-fasttext-300d-1M.vectors.npy", cache_dir=cache_dir)
186 embeddings = cached_path(f"{hu_path}/{embeddings[:2]}-wiki-fasttext-300d-1M", cache_dir=cache_dir)
188 # two-letter language code crawl embeddings
189 elif len(embeddings.lower()) == 8 and embeddings.endswith("-crawl"):
190 cached_path(f"{hu_path}/{embeddings[:2]}-crawl-fasttext-300d-1M.vectors.npy", cache_dir=cache_dir)
191 embeddings = cached_path(f"{hu_path}/{embeddings[:2]}-crawl-fasttext-300d-1M", cache_dir=cache_dir)
193 elif not Path(embeddings).exists():
194 raise ValueError(
195 f'The given embeddings "{embeddings}" is not available or is not a valid path.'
196 )
198 self.name: str = str(embeddings)
199 self.static_embeddings = not fine_tune
200 self.fine_tune = fine_tune
201 self.force_cpu = force_cpu
202 self.field = field
203 self.stable = stable
204 super().__init__()
206 if str(embeddings).endswith(".bin"):
207 precomputed_word_embeddings = gensim.models.KeyedVectors.load_word2vec_format(
208 str(embeddings), binary=True
209 )
210 else:
211 precomputed_word_embeddings = gensim.models.KeyedVectors.load(
212 str(embeddings)
213 )
215 self.__embedding_length: int = precomputed_word_embeddings.vector_size
217 vectors = np.row_stack(
218 (precomputed_word_embeddings.vectors, np.zeros(self.__embedding_length, dtype="float"))
219 )
220 self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(vectors), freeze=not fine_tune)
222 try:
223 # gensim version 4
224 self.vocab = precomputed_word_embeddings.key_to_index
225 except:
226 # gensim version 3
227 self.vocab = {k: v.index for k, v in precomputed_word_embeddings.vocab.items()}
229 if stable:
230 self.layer_norm = nn.LayerNorm(self.__embedding_length, elementwise_affine=fine_tune)
231 else:
232 self.layer_norm = None
234 self.device = None
235 self.to(flair.device)
237 @property
238 def embedding_length(self) -> int:
239 return self.__embedding_length
241 @instance_lru_cache(maxsize=100000, typed=False)
242 def get_cached_token_index(self, word: str) -> int:
243 if word in self.vocab:
244 return self.vocab[word]
245 elif word.lower() in self.vocab:
246 return self.vocab[word.lower()]
247 elif re.sub(r"\d", "#", word.lower()) in self.vocab:
248 return self.vocab[
249 re.sub(r"\d", "#", word.lower())
250 ]
251 elif re.sub(r"\d", "0", word.lower()) in self.vocab:
252 return self.vocab[
253 re.sub(r"\d", "0", word.lower())
254 ]
255 else:
256 return len(self.vocab) # <unk> token
258 def get_vec(self, word: str) -> torch.Tensor:
259 word_embedding = self.vectors[self.get_cached_token_index(word)]
261 word_embedding = torch.tensor(
262 word_embedding.tolist(), device=flair.device, dtype=torch.float
263 )
264 return word_embedding
266 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
268 tokens = [token for sentence in sentences for token in sentence.tokens]
270 word_indices: List[int] = []
271 for token in tokens:
272 if "field" not in self.__dict__ or self.field is None:
273 word = token.text
274 else:
275 word = token.get_tag(self.field).value
276 word_indices.append(self.get_cached_token_index(word))
278 embeddings = self.embedding(torch.tensor(word_indices, dtype=torch.long, device=self.device))
279 if self.stable:
280 embeddings = self.layer_norm(embeddings)
282 if self.force_cpu:
283 embeddings = embeddings.to(flair.device)
285 for emb, token in zip(embeddings, tokens):
286 token.set_embedding(self.name, emb)
288 return sentences
290 def __str__(self):
291 return self.name
293 def extra_repr(self):
294 # fix serialized models
295 if "embeddings" not in self.__dict__:
296 self.embeddings = self.name
298 return f"'{self.embeddings}'"
300 def train(self, mode=True):
301 if not self.fine_tune:
302 pass
303 else:
304 super(WordEmbeddings, self).train(mode)
306 def to(self, device):
307 if self.force_cpu:
308 device = torch.device("cpu")
309 self.device = device
310 super(WordEmbeddings, self).to(device)
312 def _apply(self, fn):
313 if fn.__name__ == "convert" and self.force_cpu:
314 # this is required to force the module on the cpu,
315 # if a parent module is put to gpu, the _apply is called to each sub_module
316 # self.to(..) actually sets the device properly
317 if not hasattr(self, "device"):
318 self.to(flair.device)
319 return
320 super(WordEmbeddings, self)._apply(fn)
322 def __getattribute__(self, item):
323 # this ignores the get_cached_vec method when loading older versions
324 # it is needed for compatibility reasons
325 if "get_cached_vec" == item:
326 return None
327 return super().__getattribute__(item)
329 def __setstate__(self, state):
330 if "get_cached_vec" in state:
331 del state["get_cached_vec"]
332 if "force_cpu" not in state:
333 state["force_cpu"] = True
334 if "fine_tune" not in state:
335 state["fine_tune"] = False
336 if "precomputed_word_embeddings" in state:
337 precomputed_word_embeddings: KeyedVectors = state.pop("precomputed_word_embeddings")
338 vectors = np.row_stack(
339 (precomputed_word_embeddings.vectors, np.zeros(precomputed_word_embeddings.vector_size, dtype="float"))
340 )
341 embedding = nn.Embedding.from_pretrained(torch.FloatTensor(vectors), freeze=not state["fine_tune"])
343 try:
344 # gensim version 4
345 vocab = precomputed_word_embeddings.key_to_index
346 except:
347 # gensim version 3
348 vocab = {k: v.index for k, v in precomputed_word_embeddings.__dict__["vocab"].items()}
349 state["embedding"] = embedding
350 state["vocab"] = vocab
351 if "stable" not in state:
352 state["stable"] = False
353 state["layer_norm"] = None
355 super().__setstate__(state)
358class CharacterEmbeddings(TokenEmbeddings):
359 """Character embeddings of words, as proposed in Lample et al., 2016."""
361 def __init__(
362 self,
363 path_to_char_dict: str = None,
364 char_embedding_dim: int = 25,
365 hidden_size_char: int = 25,
366 ):
367 """Uses the default character dictionary if none provided."""
369 super().__init__()
370 self.name = "Char"
371 self.static_embeddings = False
372 self.instance_parameters = self.get_instance_parameters(locals=locals())
374 # use list of common characters if none provided
375 if path_to_char_dict is None:
376 self.char_dictionary: Dictionary = Dictionary.load("common-chars")
377 else:
378 self.char_dictionary: Dictionary = Dictionary.load_from_file(path_to_char_dict)
380 self.char_embedding_dim: int = char_embedding_dim
381 self.hidden_size_char: int = hidden_size_char
382 self.char_embedding = torch.nn.Embedding(
383 len(self.char_dictionary.item2idx), self.char_embedding_dim
384 )
385 self.char_rnn = torch.nn.LSTM(
386 self.char_embedding_dim,
387 self.hidden_size_char,
388 num_layers=1,
389 bidirectional=True,
390 )
392 self.__embedding_length = self.hidden_size_char * 2
394 self.to(flair.device)
396 @property
397 def embedding_length(self) -> int:
398 return self.__embedding_length
400 def _add_embeddings_internal(self, sentences: List[Sentence]):
402 for sentence in sentences:
404 tokens_char_indices = []
406 # translate words in sentence into ints using dictionary
407 for token in sentence.tokens:
408 char_indices = [
409 self.char_dictionary.get_idx_for_item(char) for char in token.text
410 ]
411 tokens_char_indices.append(char_indices)
413 # sort words by length, for batching and masking
414 tokens_sorted_by_length = sorted(
415 tokens_char_indices, key=lambda p: len(p), reverse=True
416 )
417 d = {}
418 for i, ci in enumerate(tokens_char_indices):
419 for j, cj in enumerate(tokens_sorted_by_length):
420 if ci == cj:
421 d[j] = i
422 continue
423 chars2_length = [len(c) for c in tokens_sorted_by_length]
424 longest_token_in_sentence = max(chars2_length)
425 tokens_mask = torch.zeros(
426 (len(tokens_sorted_by_length), longest_token_in_sentence),
427 dtype=torch.long,
428 device=flair.device,
429 )
431 for i, c in enumerate(tokens_sorted_by_length):
432 tokens_mask[i, : chars2_length[i]] = torch.tensor(
433 c, dtype=torch.long, device=flair.device
434 )
436 # chars for rnn processing
437 chars = tokens_mask
439 character_embeddings = self.char_embedding(chars).transpose(0, 1)
441 packed = torch.nn.utils.rnn.pack_padded_sequence(
442 character_embeddings, chars2_length
443 )
445 lstm_out, self.hidden = self.char_rnn(packed)
447 outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_out)
448 outputs = outputs.transpose(0, 1)
449 chars_embeds_temp = torch.zeros(
450 (outputs.size(0), outputs.size(2)),
451 dtype=torch.float,
452 device=flair.device,
453 )
454 for i, index in enumerate(output_lengths):
455 chars_embeds_temp[i] = outputs[i, index - 1]
456 character_embeddings = chars_embeds_temp.clone()
457 for i in range(character_embeddings.size(0)):
458 character_embeddings[d[i]] = chars_embeds_temp[i]
460 for token_number, token in enumerate(sentence.tokens):
461 token.set_embedding(self.name, character_embeddings[token_number])
463 def __str__(self):
464 return self.name
467class FlairEmbeddings(TokenEmbeddings):
468 """Contextual string embeddings of words, as proposed in Akbik et al., 2018."""
470 def __init__(self,
471 model,
472 fine_tune: bool = False,
473 chars_per_chunk: int = 512,
474 with_whitespace: bool = True,
475 tokenized_lm: bool = True,
476 is_lower: bool = False,
477 ):
478 """
479 initializes contextual string embeddings using a character-level language model.
480 :param model: model string, one of 'news-forward', 'news-backward', 'news-forward-fast', 'news-backward-fast',
481 'mix-forward', 'mix-backward', 'german-forward', 'german-backward', 'polish-backward', 'polish-forward',
482 etc (see https://github.com/flairNLP/flair/blob/master/resources/docs/embeddings/FLAIR_EMBEDDINGS.md)
483 depending on which character language model is desired.
484 :param fine_tune: if set to True, the gradient will propagate into the language model. This dramatically slows
485 down training and often leads to overfitting, so use with caution.
486 :param chars_per_chunk: max number of chars per rnn pass to control speed/memory tradeoff. Higher means faster
487 but requires more memory. Lower means slower but less memory.
488 :param with_whitespace: If True, use hidden state after whitespace after word. If False, use hidden
489 state at last character of word.
490 :param tokenized_lm: Whether this lm is tokenized. Default is True, but for LMs trained over unprocessed text
491 False might be better.
492 """
493 super().__init__()
494 self.instance_parameters = self.get_instance_parameters(locals=locals())
496 cache_dir = Path("embeddings")
498 hu_path: str = "https://flair.informatik.hu-berlin.de/resources/embeddings/flair"
499 clef_hipe_path: str = "https://files.ifi.uzh.ch/cl/siclemat/impresso/clef-hipe-2020/flair"
500 am_path : str = "http://ltdata1.informatik.uni-hamburg.de/amharic/models/flair/"
502 self.is_lower: bool = is_lower
504 self.PRETRAINED_MODEL_ARCHIVE_MAP = {
505 # multilingual models
506 "multi-forward": f"{hu_path}/lm-jw300-forward-v0.1.pt",
507 "multi-backward": f"{hu_path}/lm-jw300-backward-v0.1.pt",
508 "multi-v0-forward": f"{hu_path}/lm-multi-forward-v0.1.pt",
509 "multi-v0-backward": f"{hu_path}/lm-multi-backward-v0.1.pt",
510 "multi-forward-fast": f"{hu_path}/lm-multi-forward-fast-v0.1.pt",
511 "multi-backward-fast": f"{hu_path}/lm-multi-backward-fast-v0.1.pt",
512 # English models
513 "en-forward": f"{hu_path}/news-forward-0.4.1.pt",
514 "en-backward": f"{hu_path}/news-backward-0.4.1.pt",
515 "en-forward-fast": f"{hu_path}/lm-news-english-forward-1024-v0.2rc.pt",
516 "en-backward-fast": f"{hu_path}/lm-news-english-backward-1024-v0.2rc.pt",
517 "news-forward": f"{hu_path}/news-forward-0.4.1.pt",
518 "news-backward": f"{hu_path}/news-backward-0.4.1.pt",
519 "news-forward-fast": f"{hu_path}/lm-news-english-forward-1024-v0.2rc.pt",
520 "news-backward-fast": f"{hu_path}/lm-news-english-backward-1024-v0.2rc.pt",
521 "mix-forward": f"{hu_path}/lm-mix-english-forward-v0.2rc.pt",
522 "mix-backward": f"{hu_path}/lm-mix-english-backward-v0.2rc.pt",
523 # Arabic
524 "ar-forward": f"{hu_path}/lm-ar-opus-large-forward-v0.1.pt",
525 "ar-backward": f"{hu_path}/lm-ar-opus-large-backward-v0.1.pt",
526 # Bulgarian
527 "bg-forward-fast": f"{hu_path}/lm-bg-small-forward-v0.1.pt",
528 "bg-backward-fast": f"{hu_path}/lm-bg-small-backward-v0.1.pt",
529 "bg-forward": f"{hu_path}/lm-bg-opus-large-forward-v0.1.pt",
530 "bg-backward": f"{hu_path}/lm-bg-opus-large-backward-v0.1.pt",
531 # Czech
532 "cs-forward": f"{hu_path}/lm-cs-opus-large-forward-v0.1.pt",
533 "cs-backward": f"{hu_path}/lm-cs-opus-large-backward-v0.1.pt",
534 "cs-v0-forward": f"{hu_path}/lm-cs-large-forward-v0.1.pt",
535 "cs-v0-backward": f"{hu_path}/lm-cs-large-backward-v0.1.pt",
536 # Danish
537 "da-forward": f"{hu_path}/lm-da-opus-large-forward-v0.1.pt",
538 "da-backward": f"{hu_path}/lm-da-opus-large-backward-v0.1.pt",
539 # German
540 "de-forward": f"{hu_path}/lm-mix-german-forward-v0.2rc.pt",
541 "de-backward": f"{hu_path}/lm-mix-german-backward-v0.2rc.pt",
542 "de-historic-ha-forward": f"{hu_path}/lm-historic-hamburger-anzeiger-forward-v0.1.pt",
543 "de-historic-ha-backward": f"{hu_path}/lm-historic-hamburger-anzeiger-backward-v0.1.pt",
544 "de-historic-wz-forward": f"{hu_path}/lm-historic-wiener-zeitung-forward-v0.1.pt",
545 "de-historic-wz-backward": f"{hu_path}/lm-historic-wiener-zeitung-backward-v0.1.pt",
546 "de-historic-rw-forward": f"{hu_path}/redewiedergabe_lm_forward.pt",
547 "de-historic-rw-backward": f"{hu_path}/redewiedergabe_lm_backward.pt",
548 # Spanish
549 "es-forward": f"{hu_path}/lm-es-forward.pt",
550 "es-backward": f"{hu_path}/lm-es-backward.pt",
551 "es-forward-fast": f"{hu_path}/lm-es-forward-fast.pt",
552 "es-backward-fast": f"{hu_path}/lm-es-backward-fast.pt",
553 # Basque
554 "eu-forward": f"{hu_path}/lm-eu-opus-large-forward-v0.2.pt",
555 "eu-backward": f"{hu_path}/lm-eu-opus-large-backward-v0.2.pt",
556 "eu-v1-forward": f"{hu_path}/lm-eu-opus-large-forward-v0.1.pt",
557 "eu-v1-backward": f"{hu_path}/lm-eu-opus-large-backward-v0.1.pt",
558 "eu-v0-forward": f"{hu_path}/lm-eu-large-forward-v0.1.pt",
559 "eu-v0-backward": f"{hu_path}/lm-eu-large-backward-v0.1.pt",
560 # Persian
561 "fa-forward": f"{hu_path}/lm-fa-opus-large-forward-v0.1.pt",
562 "fa-backward": f"{hu_path}/lm-fa-opus-large-backward-v0.1.pt",
563 # Finnish
564 "fi-forward": f"{hu_path}/lm-fi-opus-large-forward-v0.1.pt",
565 "fi-backward": f"{hu_path}/lm-fi-opus-large-backward-v0.1.pt",
566 # French
567 "fr-forward": f"{hu_path}/lm-fr-charlm-forward.pt",
568 "fr-backward": f"{hu_path}/lm-fr-charlm-backward.pt",
569 # Hebrew
570 "he-forward": f"{hu_path}/lm-he-opus-large-forward-v0.1.pt",
571 "he-backward": f"{hu_path}/lm-he-opus-large-backward-v0.1.pt",
572 # Hindi
573 "hi-forward": f"{hu_path}/lm-hi-opus-large-forward-v0.1.pt",
574 "hi-backward": f"{hu_path}/lm-hi-opus-large-backward-v0.1.pt",
575 # Croatian
576 "hr-forward": f"{hu_path}/lm-hr-opus-large-forward-v0.1.pt",
577 "hr-backward": f"{hu_path}/lm-hr-opus-large-backward-v0.1.pt",
578 # Indonesian
579 "id-forward": f"{hu_path}/lm-id-opus-large-forward-v0.1.pt",
580 "id-backward": f"{hu_path}/lm-id-opus-large-backward-v0.1.pt",
581 # Italian
582 "it-forward": f"{hu_path}/lm-it-opus-large-forward-v0.1.pt",
583 "it-backward": f"{hu_path}/lm-it-opus-large-backward-v0.1.pt",
584 # Japanese
585 "ja-forward": f"{hu_path}/japanese-forward.pt",
586 "ja-backward": f"{hu_path}/japanese-backward.pt",
587 # Malayalam
588 "ml-forward": f"https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/ml-forward.pt",
589 "ml-backward": f"https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/ml-backward.pt",
590 # Dutch
591 "nl-forward": f"{hu_path}/lm-nl-opus-large-forward-v0.1.pt",
592 "nl-backward": f"{hu_path}/lm-nl-opus-large-backward-v0.1.pt",
593 "nl-v0-forward": f"{hu_path}/lm-nl-large-forward-v0.1.pt",
594 "nl-v0-backward": f"{hu_path}/lm-nl-large-backward-v0.1.pt",
595 # Norwegian
596 "no-forward": f"{hu_path}/lm-no-opus-large-forward-v0.1.pt",
597 "no-backward": f"{hu_path}/lm-no-opus-large-backward-v0.1.pt",
598 # Polish
599 "pl-forward": f"{hu_path}/lm-polish-forward-v0.2.pt",
600 "pl-backward": f"{hu_path}/lm-polish-backward-v0.2.pt",
601 "pl-opus-forward": f"{hu_path}/lm-pl-opus-large-forward-v0.1.pt",
602 "pl-opus-backward": f"{hu_path}/lm-pl-opus-large-backward-v0.1.pt",
603 # Portuguese
604 "pt-forward": f"{hu_path}/lm-pt-forward.pt",
605 "pt-backward": f"{hu_path}/lm-pt-backward.pt",
606 # Pubmed
607 "pubmed-forward": f"{hu_path}/pubmed-forward.pt",
608 "pubmed-backward": f"{hu_path}/pubmed-backward.pt",
609 "pubmed-2015-forward": f"{hu_path}/pubmed-2015-fw-lm.pt",
610 "pubmed-2015-backward": f"{hu_path}/pubmed-2015-bw-lm.pt",
611 # Slovenian
612 "sl-forward": f"{hu_path}/lm-sl-opus-large-forward-v0.1.pt",
613 "sl-backward": f"{hu_path}/lm-sl-opus-large-backward-v0.1.pt",
614 "sl-v0-forward": f"{hu_path}/lm-sl-large-forward-v0.1.pt",
615 "sl-v0-backward": f"{hu_path}/lm-sl-large-backward-v0.1.pt",
616 # Swedish
617 "sv-forward": f"{hu_path}/lm-sv-opus-large-forward-v0.1.pt",
618 "sv-backward": f"{hu_path}/lm-sv-opus-large-backward-v0.1.pt",
619 "sv-v0-forward": f"{hu_path}/lm-sv-large-forward-v0.1.pt",
620 "sv-v0-backward": f"{hu_path}/lm-sv-large-backward-v0.1.pt",
621 # Tamil
622 "ta-forward": f"{hu_path}/lm-ta-opus-large-forward-v0.1.pt",
623 "ta-backward": f"{hu_path}/lm-ta-opus-large-backward-v0.1.pt",
624 # Spanish clinical
625 "es-clinical-forward": f"{hu_path}/es-clinical-forward.pt",
626 "es-clinical-backward": f"{hu_path}/es-clinical-backward.pt",
627 # CLEF HIPE Shared task
628 "de-impresso-hipe-v1-forward": f"{clef_hipe_path}/de-hipe-flair-v1-forward/best-lm.pt",
629 "de-impresso-hipe-v1-backward": f"{clef_hipe_path}/de-hipe-flair-v1-backward/best-lm.pt",
630 "en-impresso-hipe-v1-forward": f"{clef_hipe_path}/en-flair-v1-forward/best-lm.pt",
631 "en-impresso-hipe-v1-backward": f"{clef_hipe_path}/en-flair-v1-backward/best-lm.pt",
632 "fr-impresso-hipe-v1-forward": f"{clef_hipe_path}/fr-hipe-flair-v1-forward/best-lm.pt",
633 "fr-impresso-hipe-v1-backward": f"{clef_hipe_path}/fr-hipe-flair-v1-backward/best-lm.pt",
634 # Amharic
635 "am-forward": f"{am_path}/best-lm.pt",
636 }
638 if type(model) == str:
640 # load model if in pretrained model map
641 if model.lower() in self.PRETRAINED_MODEL_ARCHIVE_MAP:
642 base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[model.lower()]
644 # Fix for CLEF HIPE models (avoid overwriting best-lm.pt in cache_dir)
645 if "impresso-hipe" in model.lower():
646 cache_dir = cache_dir / model.lower()
647 # CLEF HIPE models are lowercased
648 self.is_lower = True
649 model = cached_path(base_path, cache_dir=cache_dir)
651 elif replace_with_language_code(model) in self.PRETRAINED_MODEL_ARCHIVE_MAP:
652 base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[
653 replace_with_language_code(model)
654 ]
655 model = cached_path(base_path, cache_dir=cache_dir)
657 elif not Path(model).exists():
658 raise ValueError(
659 f'The given model "{model}" is not available or is not a valid path.'
660 )
662 from flair.models import LanguageModel
664 if type(model) == LanguageModel:
665 self.lm: LanguageModel = model
666 self.name = f"Task-LSTM-{self.lm.hidden_size}-{self.lm.nlayers}-{self.lm.is_forward_lm}"
667 else:
668 self.lm: LanguageModel = LanguageModel.load_language_model(model)
669 self.name = str(model)
671 # embeddings are static if we don't do finetuning
672 self.fine_tune = fine_tune
673 self.static_embeddings = not fine_tune
675 self.is_forward_lm: bool = self.lm.is_forward_lm
676 self.with_whitespace: bool = with_whitespace
677 self.tokenized_lm: bool = tokenized_lm
678 self.chars_per_chunk: int = chars_per_chunk
680 # embed a dummy sentence to determine embedding_length
681 dummy_sentence: Sentence = Sentence()
682 dummy_sentence.add_token(Token("hello"))
683 embedded_dummy = self.embed(dummy_sentence)
684 self.__embedding_length: int = len(
685 embedded_dummy[0].get_token(1).get_embedding()
686 )
688 # set to eval mode
689 self.eval()
691 def train(self, mode=True):
693 # make compatible with serialized models (TODO: remove)
694 if "fine_tune" not in self.__dict__:
695 self.fine_tune = False
696 if "chars_per_chunk" not in self.__dict__:
697 self.chars_per_chunk = 512
699 # unless fine-tuning is set, do not set language model to train() in order to disallow language model dropout
700 if not self.fine_tune:
701 pass
702 else:
703 super(FlairEmbeddings, self).train(mode)
705 @property
706 def embedding_length(self) -> int:
707 return self.__embedding_length
709 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
711 # make compatible with serialized models (TODO: remove)
712 if "with_whitespace" not in self.__dict__:
713 self.with_whitespace = True
714 if "tokenized_lm" not in self.__dict__:
715 self.tokenized_lm = True
716 if "is_lower" not in self.__dict__:
717 self.is_lower = False
719 # gradients are enable if fine-tuning is enabled
720 gradient_context = torch.enable_grad() if self.fine_tune else torch.no_grad()
722 with gradient_context:
724 # if this is not possible, use LM to generate embedding. First, get text sentences
725 text_sentences = [sentence.to_tokenized_string() for sentence in sentences] if self.tokenized_lm \
726 else [sentence.to_plain_string() for sentence in sentences]
728 if self.is_lower:
729 text_sentences = [sentence.lower() for sentence in text_sentences]
731 start_marker = self.lm.document_delimiter if "document_delimiter" in self.lm.__dict__ else '\n'
732 end_marker = " "
734 # get hidden states from language model
735 all_hidden_states_in_lm = self.lm.get_representation(
736 text_sentences, start_marker, end_marker, self.chars_per_chunk
737 )
739 if not self.fine_tune:
740 all_hidden_states_in_lm = all_hidden_states_in_lm.detach()
742 # take first or last hidden states from language model as word representation
743 for i, sentence in enumerate(sentences):
744 sentence_text = sentence.to_tokenized_string() if self.tokenized_lm else sentence.to_plain_string()
746 offset_forward: int = len(start_marker)
747 offset_backward: int = len(sentence_text) + len(start_marker)
749 for token in sentence.tokens:
751 offset_forward += len(token.text)
752 if self.is_forward_lm:
753 offset_with_whitespace = offset_forward
754 offset_without_whitespace = offset_forward - 1
755 else:
756 offset_with_whitespace = offset_backward
757 offset_without_whitespace = offset_backward - 1
759 # offset mode that extracts at whitespace after last character
760 if self.with_whitespace:
761 embedding = all_hidden_states_in_lm[offset_with_whitespace, i, :]
762 # offset mode that extracts at last character
763 else:
764 embedding = all_hidden_states_in_lm[offset_without_whitespace, i, :]
766 if self.tokenized_lm or token.whitespace_after:
767 offset_forward += 1
768 offset_backward -= 1
770 offset_backward -= len(token.text)
772 # only clone if optimization mode is 'gpu'
773 if flair.embedding_storage_mode == "gpu":
774 embedding = embedding.clone()
776 token.set_embedding(self.name, embedding)
778 del all_hidden_states_in_lm
780 return sentences
782 def __str__(self):
783 return self.name
786class PooledFlairEmbeddings(TokenEmbeddings):
787 def __init__(
788 self,
789 contextual_embeddings: Union[str, FlairEmbeddings],
790 pooling: str = "min",
791 only_capitalized: bool = False,
792 **kwargs,
793 ):
795 super().__init__()
796 self.instance_parameters = self.get_instance_parameters(locals=locals())
798 # use the character language model embeddings as basis
799 if type(contextual_embeddings) is str:
800 self.context_embeddings: FlairEmbeddings = FlairEmbeddings(
801 contextual_embeddings, **kwargs
802 )
803 else:
804 self.context_embeddings: FlairEmbeddings = contextual_embeddings
806 # length is twice the original character LM embedding length
807 self.embedding_length = self.context_embeddings.embedding_length * 2
808 self.name = self.context_embeddings.name + "-context"
810 # these fields are for the embedding memory
811 self.word_embeddings = {}
812 self.word_count = {}
814 # whether to add only capitalized words to memory (faster runtime and lower memory consumption)
815 self.only_capitalized = only_capitalized
817 # we re-compute embeddings dynamically at each epoch
818 self.static_embeddings = False
820 # set the memory method
821 self.pooling = pooling
823 def train(self, mode=True):
824 super().train(mode=mode)
825 if mode:
826 # memory is wiped each time we do a training run
827 print("train mode resetting embeddings")
828 self.word_embeddings = {}
829 self.word_count = {}
831 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
833 self.context_embeddings.embed(sentences)
835 # if we keep a pooling, it needs to be updated continuously
836 for sentence in sentences:
837 for token in sentence.tokens:
839 # update embedding
840 local_embedding = token._embeddings[self.context_embeddings.name].cpu()
842 # check token.text is empty or not
843 if token.text:
844 if token.text[0].isupper() or not self.only_capitalized:
846 if token.text not in self.word_embeddings:
847 self.word_embeddings[token.text] = local_embedding
848 self.word_count[token.text] = 1
849 else:
851 # set aggregation operation
852 if self.pooling == "mean":
853 aggregated_embedding = torch.add(self.word_embeddings[token.text], local_embedding)
854 elif self.pooling == "fade":
855 aggregated_embedding = torch.add(self.word_embeddings[token.text], local_embedding)
856 aggregated_embedding /= 2
857 elif self.pooling == "max":
858 aggregated_embedding = torch.max(self.word_embeddings[token.text], local_embedding)
859 elif self.pooling == "min":
860 aggregated_embedding = torch.min(self.word_embeddings[token.text], local_embedding)
862 self.word_embeddings[token.text] = aggregated_embedding
863 self.word_count[token.text] += 1
865 # add embeddings after updating
866 for sentence in sentences:
867 for token in sentence.tokens:
868 if token.text in self.word_embeddings:
869 base = (
870 self.word_embeddings[token.text] / self.word_count[token.text]
871 if self.pooling == "mean"
872 else self.word_embeddings[token.text]
873 )
874 else:
875 base = token._embeddings[self.context_embeddings.name]
877 token.set_embedding(self.name, base)
879 return sentences
881 def embedding_length(self) -> int:
882 return self.embedding_length
884 def get_names(self) -> List[str]:
885 return [self.name, self.context_embeddings.name]
887 def __setstate__(self, d):
888 self.__dict__ = d
890 if flair.device != 'cpu':
891 for key in self.word_embeddings:
892 self.word_embeddings[key] = self.word_embeddings[key].cpu()
895class TransformerWordEmbeddings(TokenEmbeddings):
896 NO_MAX_SEQ_LENGTH_MODELS = [XLNetModel, TransfoXLModel]
898 def __init__(
899 self,
900 model: str = "bert-base-uncased",
901 layers: str = "all",
902 subtoken_pooling: str = "first",
903 layer_mean: bool = True,
904 fine_tune: bool = False,
905 allow_long_sentences: bool = True,
906 use_context: Union[bool, int] = False,
907 memory_effective_training: bool = True,
908 respect_document_boundaries: bool = True,
909 context_dropout: float = 0.5,
910 **kwargs
911 ):
912 """
913 Bidirectional transformer embeddings of words from various transformer architectures.
914 :param model: name of transformer model (see https://huggingface.co/transformers/pretrained_models.html for
915 options)
916 :param layers: string indicating which layers to take for embedding (-1 is topmost layer)
917 :param subtoken_pooling: how to get from token piece embeddings to token embedding. Either take the first
918 subtoken ('first'), the last subtoken ('last'), both first and last ('first_last') or a mean over all ('mean')
919 :param layer_mean: If True, uses a scalar mix of layers as embedding
920 :param fine_tune: If True, allows transformers to be fine-tuned during training
921 """
922 super().__init__()
923 self.instance_parameters = self.get_instance_parameters(locals=locals())
925 # temporary fix to disable tokenizer parallelism warning
926 # (see https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning)
927 import os
928 os.environ["TOKENIZERS_PARALLELISM"] = "false"
930 # do not print transformer warnings as these are confusing in this case
931 from transformers import logging
932 logging.set_verbosity_error()
934 # load tokenizer and transformer model
935 self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model, **kwargs)
936 if self.tokenizer.model_max_length > 1000000000:
937 self.tokenizer.model_max_length = 512
938 log.info("No model_max_length in Tokenizer's config.json - setting it to 512. "
939 "Specify desired model_max_length by passing it as attribute to embedding instance.")
940 if not 'config' in kwargs:
941 config = AutoConfig.from_pretrained(model, output_hidden_states=True, **kwargs)
942 self.model = AutoModel.from_pretrained(model, config=config)
943 else:
944 self.model = AutoModel.from_pretrained(None, **kwargs)
946 logging.set_verbosity_warning()
948 if type(self.model) not in self.NO_MAX_SEQ_LENGTH_MODELS:
949 self.allow_long_sentences = allow_long_sentences
950 self.truncate = True
951 self.max_subtokens_sequence_length = self.tokenizer.model_max_length
952 self.stride = self.tokenizer.model_max_length // 2 if allow_long_sentences else 0
953 else:
954 # in the end, these models don't need this configuration
955 self.allow_long_sentences = False
956 self.truncate = False
957 self.max_subtokens_sequence_length = None
958 self.stride = 0
960 self.use_lang_emb = hasattr(self.model, "use_lang_emb") and self.model.use_lang_emb
962 # model name
963 self.name = 'transformer-word-' + str(model)
964 self.base_model = str(model)
966 # whether to detach gradients on overlong sentences
967 self.memory_effective_training = memory_effective_training
969 # store whether to use context (and how much)
970 if type(use_context) == bool:
971 self.context_length: int = 64 if use_context else 0
972 if type(use_context) == int:
973 self.context_length: int = use_context
975 # dropout contexts
976 self.context_dropout = context_dropout
978 # if using context, can we cross document boundaries?
979 self.respect_document_boundaries = respect_document_boundaries
981 # send self to flair-device
982 self.to(flair.device)
984 # embedding parameters
985 if layers == 'all':
986 # send mini-token through to check how many layers the model has
987 hidden_states = self.model(torch.tensor([1], device=flair.device).unsqueeze(0))[-1]
988 self.layer_indexes = [int(x) for x in range(len(hidden_states))]
989 else:
990 self.layer_indexes = [int(x) for x in layers.split(",")]
992 self.pooling_operation = subtoken_pooling
993 self.layer_mean = layer_mean
994 self.fine_tune = fine_tune
995 self.static_embeddings = not self.fine_tune
997 # calculate embedding length
998 if not self.layer_mean:
999 length = len(self.layer_indexes) * self.model.config.hidden_size
1000 else:
1001 length = self.model.config.hidden_size
1002 if self.pooling_operation == 'first_last': length *= 2
1004 # return length
1005 self.embedding_length_internal = length
1007 self.special_tokens = []
1008 # check if special tokens exist to circumvent error message
1009 if self.tokenizer._bos_token:
1010 self.special_tokens.append(self.tokenizer.bos_token)
1011 if self.tokenizer._cls_token:
1012 self.special_tokens.append(self.tokenizer.cls_token)
1014 # most models have an intial BOS token, except for XLNet, T5 and GPT2
1015 self.begin_offset = self._get_begin_offset_of_tokenizer(tokenizer=self.tokenizer)
1017 # when initializing, embeddings are in eval mode by default
1018 self.eval()
1020 @staticmethod
1021 def _get_begin_offset_of_tokenizer(tokenizer: PreTrainedTokenizer) -> int:
1022 test_string = 'a'
1023 tokens = tokenizer.encode(test_string)
1025 for begin_offset, token in enumerate(tokens):
1026 if tokenizer.decode([token]) == test_string or tokenizer.decode([token]) == tokenizer.unk_token:
1027 break
1028 return begin_offset
1030 @staticmethod
1031 def _remove_special_markup(text: str):
1032 # remove special markup
1033 text = re.sub('^Ġ', '', text) # RoBERTa models
1034 text = re.sub('^##', '', text) # BERT models
1035 text = re.sub('^▁', '', text) # XLNet models
1036 text = re.sub('</w>$', '', text) # XLM models
1037 return text
1039 def _get_processed_token_text(self, token: Token) -> str:
1040 pieces = self.tokenizer.tokenize(token.text)
1041 token_text = ''
1042 for piece in pieces:
1043 token_text += self._remove_special_markup(piece)
1044 token_text = token_text.lower()
1045 return token_text
1047 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
1049 # we require encoded subtokenized sentences, the mapping to original tokens and the number of
1050 # parts that each sentence produces
1051 subtokenized_sentences = []
1052 all_token_subtoken_lengths = []
1054 # if we also use context, first expand sentence to include context
1055 if self.context_length > 0:
1057 # set context if not set already
1058 previous_sentence = None
1059 for sentence in sentences:
1060 if sentence.is_context_set(): continue
1061 sentence._previous_sentence = previous_sentence
1062 sentence._next_sentence = None
1063 if previous_sentence: previous_sentence._next_sentence = sentence
1064 previous_sentence = sentence
1066 original_sentences = []
1067 expanded_sentences = []
1068 context_offsets = []
1070 for sentence in sentences:
1071 # in case of contextualization, we must remember non-expanded sentence
1072 original_sentence = sentence
1073 original_sentences.append(original_sentence)
1075 # create expanded sentence and remember context offsets
1076 expanded_sentence, context_offset = self._expand_sentence_with_context(sentence)
1077 expanded_sentences.append(expanded_sentence)
1078 context_offsets.append(context_offset)
1080 # overwrite sentence with expanded sentence
1081 sentence = expanded_sentence
1083 sentences = expanded_sentences
1085 tokenized_sentences = []
1086 for sentence in sentences:
1088 # subtokenize the sentence
1089 tokenized_string = sentence.to_tokenized_string()
1091 # transformer specific tokenization
1092 subtokenized_sentence = self.tokenizer.tokenize(tokenized_string)
1094 # set zero embeddings for empty sentences and exclude
1095 if len(subtokenized_sentence) == 0:
1096 for token in sentence:
1097 token.set_embedding(self.name, torch.zeros(self.embedding_length))
1098 continue
1100 # determine into how many subtokens each token is split
1101 token_subtoken_lengths = self.reconstruct_tokens_from_subtokens(sentence, subtokenized_sentence)
1103 # remember tokenized sentences and their subtokenization
1104 tokenized_sentences.append(tokenized_string)
1105 all_token_subtoken_lengths.append(token_subtoken_lengths)
1107 # encode inputs
1108 batch_encoding = self.tokenizer(tokenized_sentences,
1109 max_length=self.max_subtokens_sequence_length,
1110 stride=self.stride,
1111 return_overflowing_tokens=self.allow_long_sentences,
1112 truncation=self.truncate,
1113 padding=True,
1114 return_tensors='pt',
1115 )
1117 model_kwargs = {}
1118 input_ids = batch_encoding['input_ids'].to(flair.device)
1120 # Models such as FNet do not have an attention_mask
1121 if 'attention_mask' in batch_encoding:
1122 model_kwargs['attention_mask'] = batch_encoding['attention_mask'].to(flair.device)
1124 # determine which sentence was split into how many parts
1125 sentence_parts_lengths = torch.ones(len(tokenized_sentences), dtype=torch.int) if not self.allow_long_sentences \
1126 else torch.unique(batch_encoding['overflow_to_sample_mapping'], return_counts=True, sorted=True)[1].tolist()
1128 # set language IDs for XLM-style transformers
1129 if self.use_lang_emb:
1130 model_kwargs["langs"] = torch.zeros_like(input_ids, dtype=input_ids.dtype)
1132 for s_id, sentence in enumerate(tokenized_sentences):
1133 sequence_length = len(sentence)
1134 lang_id = self.tokenizer.lang2id.get(sentences[s_id].get_language_code(), 0)
1135 model_kwargs["langs"][s_id][:sequence_length] = lang_id
1137 # put encoded batch through transformer model to get all hidden states of all encoder layers
1138 hidden_states = self.model(input_ids, **model_kwargs)[-1]
1139 # make the tuple a tensor; makes working with it easier.
1140 hidden_states = torch.stack(hidden_states)
1142 sentence_idx_offset = 0
1144 # gradients are enabled if fine-tuning is enabled
1145 gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad()
1147 with gradient_context:
1149 # iterate over all subtokenized sentences
1150 for sentence_idx, (sentence, subtoken_lengths, nr_sentence_parts) in enumerate(
1151 zip(sentences, all_token_subtoken_lengths, sentence_parts_lengths)):
1153 sentence_hidden_state = hidden_states[:, sentence_idx + sentence_idx_offset, ...]
1155 for i in range(1, nr_sentence_parts):
1156 sentence_idx_offset += 1
1157 remainder_sentence_hidden_state = hidden_states[:, sentence_idx + sentence_idx_offset, ...]
1158 # remove stride_size//2 at end of sentence_hidden_state, and half at beginning of remainder,
1159 # in order to get some context into the embeddings of these words.
1160 # also don't include the embedding of the extra [CLS] and [SEP] tokens.
1161 sentence_hidden_state = torch.cat((sentence_hidden_state[:, :-1 - self.stride // 2, :],
1162 remainder_sentence_hidden_state[:, 1 + self.stride // 2:,
1163 :]), 1)
1165 subword_start_idx = self.begin_offset
1167 # for each token, get embedding
1168 for token_idx, (token, number_of_subtokens) in enumerate(zip(sentence, subtoken_lengths)):
1170 # some tokens have no subtokens at all (if omitted by BERT tokenizer) so return zero vector
1171 if number_of_subtokens == 0:
1172 token.set_embedding(self.name, torch.zeros(self.embedding_length))
1173 continue
1175 subword_end_idx = subword_start_idx + number_of_subtokens
1177 subtoken_embeddings: List[torch.FloatTensor] = []
1179 # get states from all selected layers, aggregate with pooling operation
1180 for layer in self.layer_indexes:
1181 current_embeddings = sentence_hidden_state[layer][subword_start_idx:subword_end_idx]
1183 if self.pooling_operation == "first":
1184 final_embedding: torch.FloatTensor = current_embeddings[0]
1186 if self.pooling_operation == "last":
1187 final_embedding: torch.FloatTensor = current_embeddings[-1]
1189 if self.pooling_operation == "first_last":
1190 final_embedding: torch.Tensor = torch.cat(
1191 [current_embeddings[0], current_embeddings[-1]])
1193 if self.pooling_operation == "mean":
1194 all_embeddings: List[torch.FloatTensor] = [
1195 embedding.unsqueeze(0) for embedding in current_embeddings
1196 ]
1197 final_embedding: torch.Tensor = torch.mean(torch.cat(all_embeddings, dim=0), dim=0)
1199 subtoken_embeddings.append(final_embedding)
1201 # use layer mean of embeddings if so selected
1202 if self.layer_mean and len(self.layer_indexes) > 1:
1203 sm_embeddings = torch.mean(torch.stack(subtoken_embeddings, dim=1), dim=1)
1204 subtoken_embeddings = [sm_embeddings]
1206 # set the extracted embedding for the token
1207 token.set_embedding(self.name, torch.cat(subtoken_embeddings))
1209 subword_start_idx += number_of_subtokens
1211 # move embeddings from context back to original sentence (if using context)
1212 if self.context_length > 0:
1213 for original_sentence, expanded_sentence, context_offset in zip(original_sentences,
1214 sentences,
1215 context_offsets):
1216 for token_idx, token in enumerate(original_sentence):
1217 token.set_embedding(self.name,
1218 expanded_sentence[token_idx + context_offset].get_embedding(self.name))
1219 sentence = original_sentence
1221 def _expand_sentence_with_context(self, sentence):
1223 # remember original sentence
1224 original_sentence = sentence
1226 import random
1227 expand_context = False if self.training and random.randint(1, 100) <= (self.context_dropout * 100) else True
1229 left_context = ''
1230 right_context = ''
1232 if expand_context:
1234 # get left context
1235 while True:
1236 sentence = sentence.previous_sentence()
1237 if sentence is None: break
1239 if self.respect_document_boundaries and sentence.is_document_boundary: break
1241 left_context = sentence.to_tokenized_string() + ' ' + left_context
1242 left_context = left_context.strip()
1243 if len(left_context.split(" ")) > self.context_length:
1244 left_context = " ".join(left_context.split(" ")[-self.context_length:])
1245 break
1246 original_sentence.left_context = left_context
1248 sentence = original_sentence
1250 # get right context
1251 while True:
1252 sentence = sentence.next_sentence()
1253 if sentence is None: break
1254 if self.respect_document_boundaries and sentence.is_document_boundary: break
1256 right_context += ' ' + sentence.to_tokenized_string()
1257 right_context = right_context.strip()
1258 if len(right_context.split(" ")) > self.context_length:
1259 right_context = " ".join(right_context.split(" ")[:self.context_length])
1260 break
1262 original_sentence.right_context = right_context
1264 left_context_split = left_context.split(" ")
1265 right_context_split = right_context.split(" ")
1267 # empty contexts should not introduce whitespace tokens
1268 if left_context_split == [""]: left_context_split = []
1269 if right_context_split == [""]: right_context_split = []
1271 # make expanded sentence
1272 expanded_sentence = Sentence()
1273 expanded_sentence.tokens = [Token(token) for token in left_context_split +
1274 original_sentence.to_tokenized_string().split(" ") +
1275 right_context_split]
1277 context_length = len(left_context_split)
1278 return expanded_sentence, context_length
1280 def reconstruct_tokens_from_subtokens(self, sentence, subtokens):
1281 word_iterator = iter(sentence)
1282 token = next(word_iterator)
1283 token_text = self._get_processed_token_text(token)
1284 token_subtoken_lengths = []
1285 reconstructed_token = ''
1286 subtoken_count = 0
1287 # iterate over subtokens and reconstruct tokens
1288 for subtoken_id, subtoken in enumerate(subtokens):
1290 # remove special markup
1291 subtoken = self._remove_special_markup(subtoken)
1293 # TODO check if this is necessary is this method is called before prepare_for_model
1294 # check if reconstructed token is special begin token ([CLS] or similar)
1295 if subtoken in self.special_tokens and subtoken_id == 0:
1296 continue
1298 # some BERT tokenizers somehow omit words - in such cases skip to next token
1299 if subtoken_count == 0 and not token_text.startswith(subtoken.lower()):
1301 while True:
1302 token_subtoken_lengths.append(0)
1303 token = next(word_iterator)
1304 token_text = self._get_processed_token_text(token)
1305 if token_text.startswith(subtoken.lower()): break
1307 subtoken_count += 1
1309 # append subtoken to reconstruct token
1310 reconstructed_token = reconstructed_token + subtoken
1312 # check if reconstructed token is the same as current token
1313 if reconstructed_token.lower() == token_text:
1315 # if so, add subtoken count
1316 token_subtoken_lengths.append(subtoken_count)
1318 # reset subtoken count and reconstructed token
1319 reconstructed_token = ''
1320 subtoken_count = 0
1322 # break from loop if all tokens are accounted for
1323 if len(token_subtoken_lengths) < len(sentence):
1324 token = next(word_iterator)
1325 token_text = self._get_processed_token_text(token)
1326 else:
1327 break
1329 # if tokens are unaccounted for
1330 while len(token_subtoken_lengths) < len(sentence) and len(token.text) == 1:
1331 token_subtoken_lengths.append(0)
1332 if len(token_subtoken_lengths) == len(sentence): break
1333 token = next(word_iterator)
1335 # check if all tokens were matched to subtokens
1336 if token != sentence[-1]:
1337 log.error(f"Tokenization MISMATCH in sentence '{sentence.to_tokenized_string()}'")
1338 log.error(f"Last matched: '{token}'")
1339 log.error(f"Last sentence: '{sentence[-1]}'")
1340 log.error(f"subtokenized: '{subtokens}'")
1341 return token_subtoken_lengths
1343 @property
1344 def embedding_length(self) -> int:
1346 if "embedding_length_internal" in self.__dict__.keys():
1347 return self.embedding_length_internal
1349 # """Returns the length of the embedding vector."""
1350 if not self.layer_mean:
1351 length = len(self.layer_indexes) * self.model.config.hidden_size
1352 else:
1353 length = self.model.config.hidden_size
1355 if self.pooling_operation == 'first_last': length *= 2
1357 self.__embedding_length = length
1359 return length
1361 def __getstate__(self):
1362 # special handling for serializing transformer models
1363 config_state_dict = self.model.config.__dict__
1364 model_state_dict = self.model.state_dict()
1366 if not hasattr(self, "base_model_name"): self.base_model_name = self.name.split('transformer-word-')[-1]
1368 # serialize the transformer models and the constructor arguments (but nothing else)
1369 model_state = {
1370 "config_state_dict": config_state_dict,
1371 "model_state_dict": model_state_dict,
1372 "embedding_length_internal": self.embedding_length,
1374 "base_model_name": self.base_model_name,
1375 "name": self.name,
1376 "layer_indexes": self.layer_indexes,
1377 "subtoken_pooling": self.pooling_operation,
1378 "context_length": self.context_length,
1379 "layer_mean": self.layer_mean,
1380 "fine_tune": self.fine_tune,
1381 "allow_long_sentences": self.allow_long_sentences,
1382 "memory_effective_training": self.memory_effective_training,
1383 "respect_document_boundaries": self.respect_document_boundaries,
1384 "context_dropout": self.context_dropout,
1385 }
1387 return model_state
1389 def __setstate__(self, d):
1390 self.__dict__ = d
1392 # necessary for reverse compatibility with Flair <= 0.7
1393 if 'use_scalar_mix' in self.__dict__.keys():
1394 self.__dict__['layer_mean'] = d['use_scalar_mix']
1395 if not 'memory_effective_training' in self.__dict__.keys():
1396 self.__dict__['memory_effective_training'] = True
1397 if 'pooling_operation' in self.__dict__.keys():
1398 self.__dict__['subtoken_pooling'] = d['pooling_operation']
1399 if not 'context_length' in self.__dict__.keys():
1400 self.__dict__['context_length'] = 0
1401 if 'use_context' in self.__dict__.keys():
1402 self.__dict__['context_length'] = 64 if self.__dict__['use_context'] == True else 0
1404 if not 'context_dropout' in self.__dict__.keys():
1405 self.__dict__['context_dropout'] = 0.5
1406 if not 'respect_document_boundaries' in self.__dict__.keys():
1407 self.__dict__['respect_document_boundaries'] = True
1408 if not 'memory_effective_training' in self.__dict__.keys():
1409 self.__dict__['memory_effective_training'] = True
1410 if not 'base_model_name' in self.__dict__.keys():
1411 self.__dict__['base_model_name'] = self.__dict__['name'].split('transformer-word-')[-1]
1413 # special handling for deserializing transformer models
1414 if "config_state_dict" in d:
1416 # load transformer model
1417 model_type = d["config_state_dict"]["model_type"] if "model_type" in d["config_state_dict"] else "bert"
1418 config_class = CONFIG_MAPPING[model_type]
1419 loaded_config = config_class.from_dict(d["config_state_dict"])
1421 # constructor arguments
1422 layers = ','.join([str(idx) for idx in self.__dict__['layer_indexes']])
1424 # re-initialize transformer word embeddings with constructor arguments
1425 embedding = TransformerWordEmbeddings(
1426 model=self.__dict__['base_model_name'],
1427 layers=layers,
1428 subtoken_pooling=self.__dict__['subtoken_pooling'],
1429 use_context=self.__dict__['context_length'],
1430 layer_mean=self.__dict__['layer_mean'],
1431 fine_tune=self.__dict__['fine_tune'],
1432 allow_long_sentences=self.__dict__['allow_long_sentences'],
1433 respect_document_boundaries=self.__dict__['respect_document_boundaries'],
1434 memory_effective_training=self.__dict__['memory_effective_training'],
1435 context_dropout=self.__dict__['context_dropout'],
1437 config=loaded_config,
1438 state_dict=d["model_state_dict"],
1439 )
1441 # I have no idea why this is necessary, but otherwise it doesn't work
1442 for key in embedding.__dict__.keys():
1443 self.__dict__[key] = embedding.__dict__[key]
1445 else:
1447 # reload tokenizer to get around serialization issues
1448 model_name = self.__dict__['name'].split('transformer-word-')[-1]
1449 try:
1450 tokenizer = AutoTokenizer.from_pretrained(model_name)
1451 except:
1452 pass
1454 self.tokenizer = tokenizer
1457class FastTextEmbeddings(TokenEmbeddings):
1458 """FastText Embeddings with oov functionality"""
1460 def __init__(self, embeddings: str, use_local: bool = True, field: str = None):
1461 """
1462 Initializes fasttext word embeddings. Constructor downloads required embedding file and stores in cache
1463 if use_local is False.
1465 :param embeddings: path to your embeddings '.bin' file
1466 :param use_local: set this to False if you are using embeddings from a remote source
1467 """
1468 self.instance_parameters = self.get_instance_parameters(locals=locals())
1470 cache_dir = Path("embeddings")
1472 if use_local:
1473 if not Path(embeddings).exists():
1474 raise ValueError(
1475 f'The given embeddings "{embeddings}" is not available or is not a valid path.'
1476 )
1477 else:
1478 embeddings = cached_path(f"{embeddings}", cache_dir=cache_dir)
1480 self.embeddings = embeddings
1482 self.name: str = str(embeddings)
1484 self.static_embeddings = True
1486 self.precomputed_word_embeddings: gensim.models.FastText = gensim.models.FastText.load_fasttext_format(
1487 str(embeddings)
1488 )
1489 print(self.precomputed_word_embeddings)
1491 self.__embedding_length: int = self.precomputed_word_embeddings.vector_size
1493 self.field = field
1494 super().__init__()
1496 @property
1497 def embedding_length(self) -> int:
1498 return self.__embedding_length
1500 @instance_lru_cache(maxsize=10000, typed=False)
1501 def get_cached_vec(self, word: str) -> torch.Tensor:
1502 try:
1503 word_embedding = self.precomputed_word_embeddings.wv[word]
1504 except:
1505 word_embedding = np.zeros(self.embedding_length, dtype="float")
1507 word_embedding = torch.tensor(
1508 word_embedding.tolist(), device=flair.device, dtype=torch.float
1509 )
1510 return word_embedding
1512 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
1514 for i, sentence in enumerate(sentences):
1516 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
1518 if "field" not in self.__dict__ or self.field is None:
1519 word = token.text
1520 else:
1521 word = token.get_tag(self.field).value
1523 word_embedding = self.get_cached_vec(word)
1525 token.set_embedding(self.name, word_embedding)
1527 return sentences
1529 def __str__(self):
1530 return self.name
1532 def extra_repr(self):
1533 return f"'{self.embeddings}'"
1536class OneHotEmbeddings(TokenEmbeddings):
1537 """One-hot encoded embeddings. """
1539 def __init__(
1540 self,
1541 vocab_dictionary: Dictionary,
1542 field: str = "text",
1543 embedding_length: int = 300,
1544 stable: bool = False,
1545 ):
1546 """
1547 Initializes one-hot encoded word embeddings and a trainable embedding layer
1548 :param vocab_dictionary: the vocabulary that will be encoded
1549 :param field: by default, the 'text' of tokens is embedded, but you can also embed tags such as 'pos'
1550 :param embedding_length: dimensionality of the trainable embedding layer
1551 :param stable: set stable=True to use the stable embeddings as described in https://arxiv.org/abs/2110.02861
1552 """
1553 super().__init__()
1554 self.name = f"one-hot-{field}"
1555 self.static_embeddings = False
1556 self.field = field
1557 self.instance_parameters = self.get_instance_parameters(locals=locals())
1558 self.__embedding_length = embedding_length
1559 self.vocab_dictionary = vocab_dictionary
1561 print(self.vocab_dictionary.idx2item)
1562 print(f"vocabulary size of {len(self.vocab_dictionary)}")
1564 # model architecture
1565 self.embedding_layer = torch.nn.Embedding(
1566 len(self.vocab_dictionary), self.__embedding_length
1567 )
1568 torch.nn.init.xavier_uniform_(self.embedding_layer.weight)
1569 if stable:
1570 self.layer_norm = torch.nn.LayerNorm(embedding_length)
1571 else:
1572 self.layer_norm = None
1574 self.to(flair.device)
1576 @property
1577 def embedding_length(self) -> int:
1578 return self.__embedding_length
1580 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
1582 tokens = [
1583 t
1584 for sentence in sentences
1585 for t in sentence.tokens
1586 ]
1588 if self.field == "text":
1589 one_hot_sentences = [
1590 self.vocab_dictionary.get_idx_for_item(t.text)
1591 for t in tokens
1592 ]
1593 else:
1594 one_hot_sentences = [
1595 self.vocab_dictionary.get_idx_for_item(t.get_tag(self.field).value)
1596 for t in tokens
1597 ]
1599 one_hot_sentences = torch.tensor(one_hot_sentences, dtype=torch.long).to(
1600 flair.device
1601 )
1603 embedded = self.embedding_layer.forward(one_hot_sentences)
1604 if self.layer_norm:
1605 embedded = self.layer_norm(embedded)
1607 for emb, token in zip(embedded, tokens):
1608 token.set_embedding(self.name, emb)
1610 return sentences
1612 def __str__(self):
1613 return self.name
1615 @classmethod
1616 def from_corpus(
1617 cls,
1618 corpus: Corpus,
1619 field: str = "text",
1620 min_freq: int = 3,
1621 **kwargs
1622 ):
1623 vocab_dictionary = Dictionary()
1625 tokens = list(map((lambda s: s.tokens), corpus.train))
1626 tokens = [token for sublist in tokens for token in sublist]
1628 if field == "text":
1629 most_common = Counter(list(map((lambda t: t.text), tokens))).most_common()
1630 else:
1631 most_common = Counter(
1632 list(map((lambda t: t.get_tag(field).value), tokens))
1633 ).most_common()
1635 tokens = []
1636 for token, freq in most_common:
1637 if freq < min_freq:
1638 break
1639 tokens.append(token)
1641 for token in tokens:
1642 vocab_dictionary.add_item(token)
1644 return cls(vocab_dictionary, field=field, **kwargs)
1647class HashEmbeddings(TokenEmbeddings):
1648 """Standard embeddings with Hashing Trick."""
1650 def __init__(
1651 self, num_embeddings: int = 1000, embedding_length: int = 300, hash_method="md5"
1652 ):
1654 super().__init__()
1655 self.name = "hash"
1656 self.static_embeddings = False
1657 self.instance_parameters = self.get_instance_parameters(locals=locals())
1659 self.__num_embeddings = num_embeddings
1660 self.__embedding_length = embedding_length
1662 self.__hash_method = hash_method
1664 # model architecture
1665 self.embedding_layer = torch.nn.Embedding(
1666 self.__num_embeddings, self.__embedding_length
1667 )
1668 torch.nn.init.xavier_uniform_(self.embedding_layer.weight)
1670 self.to(flair.device)
1672 @property
1673 def num_embeddings(self) -> int:
1674 return self.__num_embeddings
1676 @property
1677 def embedding_length(self) -> int:
1678 return self.__embedding_length
1680 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
1681 def get_idx_for_item(text):
1682 hash_function = hashlib.new(self.__hash_method)
1683 hash_function.update(bytes(str(text), "utf-8"))
1684 return int(hash_function.hexdigest(), 16) % self.__num_embeddings
1686 hash_sentences = []
1687 for i, sentence in enumerate(sentences):
1688 context_idxs = [get_idx_for_item(t.text) for t in sentence.tokens]
1690 hash_sentences.extend(context_idxs)
1692 hash_sentences = torch.tensor(hash_sentences, dtype=torch.long).to(flair.device)
1694 embedded = self.embedding_layer.forward(hash_sentences)
1696 index = 0
1697 for sentence in sentences:
1698 for token in sentence:
1699 embedding = embedded[index]
1700 token.set_embedding(self.name, embedding)
1701 index += 1
1703 return sentences
1705 def __str__(self):
1706 return self.name
1709class MuseCrosslingualEmbeddings(TokenEmbeddings):
1710 def __init__(self, ):
1711 self.name: str = f"muse-crosslingual"
1712 self.static_embeddings = True
1713 self.__embedding_length: int = 300
1714 self.language_embeddings = {}
1715 super().__init__()
1717 @instance_lru_cache(maxsize=10000, typed=False)
1718 def get_cached_vec(self, language_code: str, word: str) -> torch.Tensor:
1719 current_embedding_model = self.language_embeddings[language_code]
1720 if word in current_embedding_model:
1721 word_embedding = current_embedding_model[word]
1722 elif word.lower() in current_embedding_model:
1723 word_embedding = current_embedding_model[word.lower()]
1724 elif re.sub(r"\d", "#", word.lower()) in current_embedding_model:
1725 word_embedding = current_embedding_model[re.sub(r"\d", "#", word.lower())]
1726 elif re.sub(r"\d", "0", word.lower()) in current_embedding_model:
1727 word_embedding = current_embedding_model[re.sub(r"\d", "0", word.lower())]
1728 else:
1729 word_embedding = np.zeros(self.embedding_length, dtype="float")
1730 word_embedding = torch.tensor(
1731 word_embedding, device=flair.device, dtype=torch.float
1732 )
1733 return word_embedding
1735 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
1737 for i, sentence in enumerate(sentences):
1739 language_code = sentence.get_language_code()
1740 supported = [
1741 "en",
1742 "de",
1743 "bg",
1744 "ca",
1745 "hr",
1746 "cs",
1747 "da",
1748 "nl",
1749 "et",
1750 "fi",
1751 "fr",
1752 "el",
1753 "he",
1754 "hu",
1755 "id",
1756 "it",
1757 "mk",
1758 "no",
1759 # "pl",
1760 "pt",
1761 "ro",
1762 "ru",
1763 "sk",
1764 ]
1765 if language_code not in supported:
1766 language_code = "en"
1768 if language_code not in self.language_embeddings:
1769 log.info(f"Loading up MUSE embeddings for '{language_code}'!")
1770 # download if necessary
1771 hu_path: str = "https://flair.informatik.hu-berlin.de/resources/embeddings/muse"
1772 cache_dir = Path("embeddings") / "MUSE"
1773 cached_path(
1774 f"{hu_path}/muse.{language_code}.vec.gensim.vectors.npy",
1775 cache_dir=cache_dir,
1776 )
1777 embeddings_file = cached_path(
1778 f"{hu_path}/muse.{language_code}.vec.gensim", cache_dir=cache_dir
1779 )
1781 # load the model
1782 self.language_embeddings[
1783 language_code
1784 ] = gensim.models.KeyedVectors.load(str(embeddings_file))
1786 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
1788 if "field" not in self.__dict__ or self.field is None:
1789 word = token.text
1790 else:
1791 word = token.get_tag(self.field).value
1793 word_embedding = self.get_cached_vec(
1794 language_code=language_code, word=word
1795 )
1797 token.set_embedding(self.name, word_embedding)
1799 return sentences
1801 @property
1802 def embedding_length(self) -> int:
1803 return self.__embedding_length
1805 def __str__(self):
1806 return self.name
1809# TODO: keep for backwards compatibility, but remove in future
1810class BPEmbSerializable(BPEmb):
1811 def __getstate__(self):
1812 state = self.__dict__.copy()
1813 # save the sentence piece model as binary file (not as path which may change)
1814 state["spm_model_binary"] = open(self.model_file, mode="rb").read()
1815 state["spm"] = None
1816 return state
1818 def __setstate__(self, state):
1819 from bpemb.util import sentencepiece_load
1821 model_file = self.model_tpl.format(lang=state["lang"], vs=state["vs"])
1822 self.__dict__ = state
1824 # write out the binary sentence piece model into the expected directory
1825 self.cache_dir: Path = flair.cache_root / "embeddings"
1826 if "spm_model_binary" in self.__dict__:
1827 # if the model was saved as binary and it is not found on disk, write to appropriate path
1828 if not os.path.exists(self.cache_dir / state["lang"]):
1829 os.makedirs(self.cache_dir / state["lang"])
1830 self.model_file = self.cache_dir / model_file
1831 with open(self.model_file, "wb") as out:
1832 out.write(self.__dict__["spm_model_binary"])
1833 else:
1834 # otherwise, use normal process and potentially trigger another download
1835 self.model_file = self._load_file(model_file)
1837 # once the modes if there, load it with sentence piece
1838 state["spm"] = sentencepiece_load(self.model_file)
1841class BytePairEmbeddings(TokenEmbeddings):
1842 def __init__(
1843 self,
1844 language: str = None,
1845 dim: int = 50,
1846 syllables: int = 100000,
1847 cache_dir=None,
1848 model_file_path: Path = None,
1849 embedding_file_path: Path = None,
1850 **kwargs,
1851 ):
1852 """
1853 Initializes BP embeddings. Constructor downloads required files if not there.
1854 """
1855 self.instance_parameters = self.get_instance_parameters(locals=locals())
1857 if not cache_dir:
1858 cache_dir = flair.cache_root / "embeddings"
1859 if language:
1860 self.name: str = f"bpe-{language}-{syllables}-{dim}"
1861 else:
1862 assert (
1863 model_file_path is not None and embedding_file_path is not None
1864 ), "Need to specify model_file_path and embedding_file_path if no language is given in BytePairEmbeddings(...)"
1865 dim = None
1867 self.embedder = BPEmbSerializable(
1868 lang=language,
1869 vs=syllables,
1870 dim=dim,
1871 cache_dir=cache_dir,
1872 model_file=model_file_path,
1873 emb_file=embedding_file_path,
1874 **kwargs,
1875 )
1877 if not language:
1878 self.name: str = f"bpe-custom-{self.embedder.vs}-{self.embedder.dim}"
1879 self.static_embeddings = True
1881 self.__embedding_length: int = self.embedder.emb.vector_size * 2
1882 super().__init__()
1884 @property
1885 def embedding_length(self) -> int:
1886 return self.__embedding_length
1888 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
1890 for i, sentence in enumerate(sentences):
1892 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
1894 if "field" not in self.__dict__ or self.field is None:
1895 word = token.text
1896 else:
1897 word = token.get_tag(self.field).value
1899 if word.strip() == "":
1900 # empty words get no embedding
1901 token.set_embedding(
1902 self.name, torch.zeros(self.embedding_length, dtype=torch.float)
1903 )
1904 else:
1905 # all other words get embedded
1906 embeddings = self.embedder.embed(word.lower())
1907 embedding = np.concatenate(
1908 (embeddings[0], embeddings[len(embeddings) - 1])
1909 )
1910 token.set_embedding(
1911 self.name, torch.tensor(embedding, dtype=torch.float)
1912 )
1914 return sentences
1916 def __str__(self):
1917 return self.name
1919 def extra_repr(self):
1920 return "model={}".format(self.name)
1923class ELMoEmbeddings(TokenEmbeddings):
1924 """Contextual word embeddings using word-level LM, as proposed in Peters et al., 2018.
1925 ELMo word vectors can be constructed by combining layers in different ways.
1926 Default is to concatene the top 3 layers in the LM."""
1928 def __init__(
1929 self, model: str = "original", options_file: str = None, weight_file: str = None,
1930 embedding_mode: str = "all"
1931 ):
1932 super().__init__()
1934 self.instance_parameters = self.get_instance_parameters(locals=locals())
1936 try:
1937 import allennlp.commands.elmo
1938 except ModuleNotFoundError:
1939 log.warning("-" * 100)
1940 log.warning('ATTENTION! The library "allennlp" is not installed!')
1941 log.warning(
1942 'To use ELMoEmbeddings, please first install with "pip install allennlp==0.9.0"'
1943 )
1944 log.warning("-" * 100)
1945 pass
1947 assert embedding_mode in ["all", "top", "average"]
1949 self.name = f"elmo-{model}-{embedding_mode}"
1950 self.static_embeddings = True
1952 if not options_file or not weight_file:
1953 # the default model for ELMo is the 'original' model, which is very large
1954 options_file = allennlp.commands.elmo.DEFAULT_OPTIONS_FILE
1955 weight_file = allennlp.commands.elmo.DEFAULT_WEIGHT_FILE
1956 # alternatively, a small, medium or portuguese model can be selected by passing the appropriate mode name
1957 if model == "small":
1958 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json"
1959 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"
1960 if model == "medium":
1961 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_options.json"
1962 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"
1963 if model in ["large", "5.5B"]:
1964 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json"
1965 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5"
1966 if model == "pt" or model == "portuguese":
1967 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_options.json"
1968 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_weights.hdf5"
1969 if model == "pubmed":
1970 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_options.json"
1971 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_weights_PubMed_only.hdf5"
1973 if embedding_mode == "all":
1974 self.embedding_mode_fn = self.use_layers_all
1975 elif embedding_mode == "top":
1976 self.embedding_mode_fn = self.use_layers_top
1977 elif embedding_mode == "average":
1978 self.embedding_mode_fn = self.use_layers_average
1980 # put on Cuda if available
1981 from flair import device
1983 if re.fullmatch(r"cuda:[0-9]+", str(device)):
1984 cuda_device = int(str(device).split(":")[-1])
1985 elif str(device) == "cpu":
1986 cuda_device = -1
1987 else:
1988 cuda_device = 0
1990 self.ee = allennlp.commands.elmo.ElmoEmbedder(
1991 options_file=options_file, weight_file=weight_file, cuda_device=cuda_device
1992 )
1994 # embed a dummy sentence to determine embedding_length
1995 dummy_sentence: Sentence = Sentence()
1996 dummy_sentence.add_token(Token("hello"))
1997 embedded_dummy = self.embed(dummy_sentence)
1998 self.__embedding_length: int = len(
1999 embedded_dummy[0].get_token(1).get_embedding()
2000 )
2002 @property
2003 def embedding_length(self) -> int:
2004 return self.__embedding_length
2006 def use_layers_all(self, x):
2007 return torch.cat(x, 0)
2009 def use_layers_top(self, x):
2010 return x[-1]
2012 def use_layers_average(self, x):
2013 return torch.mean(torch.stack(x), 0)
2015 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
2016 # ELMoEmbeddings before Release 0.5 did not set self.embedding_mode_fn
2017 if not getattr(self, "embedding_mode_fn", None):
2018 self.embedding_mode_fn = self.use_layers_all
2020 sentence_words: List[List[str]] = []
2021 for sentence in sentences:
2022 sentence_words.append([token.text for token in sentence])
2024 embeddings = self.ee.embed_batch(sentence_words)
2026 for i, sentence in enumerate(sentences):
2028 sentence_embeddings = embeddings[i]
2030 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
2031 elmo_embedding_layers = [
2032 torch.FloatTensor(sentence_embeddings[0, token_idx, :]),
2033 torch.FloatTensor(sentence_embeddings[1, token_idx, :]),
2034 torch.FloatTensor(sentence_embeddings[2, token_idx, :])
2035 ]
2036 word_embedding = self.embedding_mode_fn(elmo_embedding_layers)
2037 token.set_embedding(self.name, word_embedding)
2039 return sentences
2041 def extra_repr(self):
2042 return "model={}".format(self.name)
2044 def __str__(self):
2045 return self.name
2047 def __setstate__(self, state):
2048 self.__dict__ = state
2050 if re.fullmatch(r"cuda:[0-9]+", str(flair.device)):
2051 cuda_device = int(str(flair.device).split(":")[-1])
2052 elif str(flair.device) == "cpu":
2053 cuda_device = -1
2054 else:
2055 cuda_device = 0
2057 self.ee.cuda_device = cuda_device
2059 self.ee.elmo_bilm.to(device=flair.device)
2060 self.ee.elmo_bilm._elmo_lstm._states = tuple(
2061 [state.to(flair.device) for state in self.ee.elmo_bilm._elmo_lstm._states])
2064class NILCEmbeddings(WordEmbeddings):
2065 def __init__(self, embeddings: str, model: str = "skip", size: int = 100):
2066 """
2067 Initializes portuguese classic word embeddings trained by NILC Lab (http://www.nilc.icmc.usp.br/embeddings).
2068 Constructor downloads required files if not there.
2069 :param embeddings: one of: 'fasttext', 'glove', 'wang2vec' or 'word2vec'
2070 :param model: one of: 'skip' or 'cbow'. This is not applicable to glove.
2071 :param size: one of: 50, 100, 300, 600 or 1000.
2072 """
2074 self.instance_parameters = self.get_instance_parameters(locals=locals())
2076 base_path = "http://143.107.183.175:22980/download.php?file=embeddings/"
2078 cache_dir = Path("embeddings") / embeddings.lower()
2080 # GLOVE embeddings
2081 if embeddings.lower() == "glove":
2082 cached_path(
2083 f"{base_path}{embeddings}/{embeddings}_s{size}.zip", cache_dir=cache_dir
2084 )
2085 embeddings = cached_path(
2086 f"{base_path}{embeddings}/{embeddings}_s{size}.zip", cache_dir=cache_dir
2087 )
2089 elif embeddings.lower() in ["fasttext", "wang2vec", "word2vec"]:
2090 cached_path(
2091 f"{base_path}{embeddings}/{model}_s{size}.zip", cache_dir=cache_dir
2092 )
2093 embeddings = cached_path(
2094 f"{base_path}{embeddings}/{model}_s{size}.zip", cache_dir=cache_dir
2095 )
2097 elif not Path(embeddings).exists():
2098 raise ValueError(
2099 f'The given embeddings "{embeddings}" is not available or is not a valid path.'
2100 )
2102 self.name: str = str(embeddings)
2103 self.static_embeddings = True
2105 log.info("Reading embeddings from %s" % embeddings)
2106 self.precomputed_word_embeddings = gensim.models.KeyedVectors.load_word2vec_format(
2107 open_inside_zip(str(embeddings), cache_dir=cache_dir)
2108 )
2110 self.__embedding_length: int = self.precomputed_word_embeddings.vector_size
2111 super(TokenEmbeddings, self).__init__()
2113 @property
2114 def embedding_length(self) -> int:
2115 return self.__embedding_length
2117 def __str__(self):
2118 return self.name
2121def replace_with_language_code(string: str):
2122 string = string.replace("arabic-", "ar-")
2123 string = string.replace("basque-", "eu-")
2124 string = string.replace("bulgarian-", "bg-")
2125 string = string.replace("croatian-", "hr-")
2126 string = string.replace("czech-", "cs-")
2127 string = string.replace("danish-", "da-")
2128 string = string.replace("dutch-", "nl-")
2129 string = string.replace("farsi-", "fa-")
2130 string = string.replace("persian-", "fa-")
2131 string = string.replace("finnish-", "fi-")
2132 string = string.replace("french-", "fr-")
2133 string = string.replace("german-", "de-")
2134 string = string.replace("hebrew-", "he-")
2135 string = string.replace("hindi-", "hi-")
2136 string = string.replace("indonesian-", "id-")
2137 string = string.replace("italian-", "it-")
2138 string = string.replace("japanese-", "ja-")
2139 string = string.replace("norwegian-", "no")
2140 string = string.replace("polish-", "pl-")
2141 string = string.replace("portuguese-", "pt-")
2142 string = string.replace("slovenian-", "sl-")
2143 string = string.replace("spanish-", "es-")
2144 string = string.replace("swedish-", "sv-")
2145 return string