Coverage for flair/flair/embeddings/token.py: 21%
1import hashlib
2import inspect
3import logging
4import os
5import re
6from abc import abstractmethod
7from collections import Counter
8from pathlib import Path
9from typing import List, Union, Dict
11import gensim
12import numpy as np
13import torch
14from bpemb import BPEmb
15from transformers import AutoTokenizer, AutoConfig, AutoModel, CONFIG_MAPPING, PreTrainedTokenizer, XLNetModel, \
16 TransfoXLModel
18import flair
19from flair.data import Sentence, Token, Corpus, Dictionary
20from flair.embeddings.base import Embeddings
21from flair.file_utils import cached_path, open_inside_zip, instance_lru_cache
23log = logging.getLogger("flair")
26class TokenEmbeddings(Embeddings):
27 """Abstract base class for all token-level embeddings. Ever new type of word embedding must implement these methods."""
29 @property
30 @abstractmethod
31 def embedding_length(self) -> int:
32 """Returns the length of the embedding vector."""
33 pass
35 @property
36 def embedding_type(self) -> str:
37 return "word-level"
39 @staticmethod
40 def get_instance_parameters(locals: dict) -> dict:
41 class_definition = locals.get("__class__")
42 instance_parameters = set(inspect.getfullargspec(class_definition.__init__).args)
43 instance_parameters.difference_update(set(["self"]))
44 instance_parameters.update(set(["__class__"]))
45 instance_parameters = {class_attribute: attribute_value for class_attribute, attribute_value in locals.items()
46 if class_attribute in instance_parameters}
47 return instance_parameters
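# Illustration (hedged sketch, not part of the original module): inside a subclass
# constructor, calling
#   self.instance_parameters = self.get_instance_parameters(locals=locals())
# keeps only those locals whose names match the __init__ arguments (minus 'self',
# plus '__class__'), e.g. for WordEmbeddings('glove') roughly
#   {'__class__': WordEmbeddings, 'embeddings': 'glove', 'field': None}
# presumably so the embedding can later be re-instantiated with the same arguments.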
50class StackedEmbeddings(TokenEmbeddings):
51 """A stack of embeddings, used if you need to combine several different embedding types."""
53 def __init__(self, embeddings: List[TokenEmbeddings]):
54 """The constructor takes a list of embeddings to be combined."""
55 super().__init__()
57 self.embeddings = embeddings
59 # IMPORTANT: add embeddings as torch modules
60 for i, embedding in enumerate(embeddings):
61 embedding.name = f"{str(i)}-{embedding.name}"
62 self.add_module(f"list_embedding_{str(i)}", embedding)
64 self.name: str = "Stack"
65 self.static_embeddings: bool = True
67 self.__embedding_type: str = embeddings[0].embedding_type
69 self.__embedding_length: int = 0
70 for embedding in embeddings:
71 self.__embedding_length += embedding.embedding_length
73 def embed(
74 self, sentences: Union[Sentence, List[Sentence]], static_embeddings: bool = True
75 ):
76 # if only one sentence is passed, convert it to a list of sentences
77 if type(sentences) is Sentence:
78 sentences = [sentences]
80 for embedding in self.embeddings:
81 embedding.embed(sentences)
83 @property
84 def embedding_type(self) -> str:
85 return self.__embedding_type
87 @property
88 def embedding_length(self) -> int:
89 return self.__embedding_length
91 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
93 for embedding in self.embeddings:
94 embedding._add_embeddings_internal(sentences)
96 return sentences
98 def __str__(self):
99 return f'StackedEmbeddings [{",".join([str(e) for e in self.embeddings])}]'
101 def get_names(self) -> List[str]:
102 """Returns a list of embedding names. In most cases, it is just a list with one item, namely the name of
103 this embedding. But in some cases, the embedding is made up of different embeddings (StackedEmbeddings).
104 Then, the list contains the names of all embeddings in the stack."""
105 names = []
106 for embedding in self.embeddings:
107 names.extend(embedding.get_names())
108 return names
110 def get_named_embeddings_dict(self) -> Dict:
112 named_embeddings_dict = {}
113 for embedding in self.embeddings:
114 named_embeddings_dict.update(embedding.get_named_embeddings_dict())
116 return named_embeddings_dict
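# Usage sketch (illustration only, not part of the module): stacking a classic
# word embedding with a contextual one. Assumes the pretrained 'glove' and
# 'news-forward' models can be downloaded and that the classes are exposed via
# the usual flair.embeddings import path.
#
#   from flair.data import Sentence
#   from flair.embeddings import WordEmbeddings, FlairEmbeddings, StackedEmbeddings
#
#   stacked = StackedEmbeddings([WordEmbeddings("glove"), FlairEmbeddings("news-forward")])
#   sentence = Sentence("Berlin is a city .")
#   stacked.embed(sentence)
#   # each token now carries the concatenation of both embeddings
#   print(sentence[0].get_embedding().shape)   # should equal stacked.embedding_length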
119class WordEmbeddings(TokenEmbeddings):
120 """Standard static word embeddings, such as GloVe or FastText."""
122 def __init__(self, embeddings: str, field: str = None):
123 """
124 Initializes classic word embeddings. Constructor downloads required files if not there.
125 :param embeddings: one of 'glove', 'extvec', 'crawl', a two-letter language code, or a custom path.
126 If you want to use a custom embedding file, just pass the path to the file as the embeddings parameter.
127 """
128 self.embeddings = embeddings
130 self.instance_parameters = self.get_instance_parameters(locals=locals())
132 hu_path: str = "https://flair.informatik.hu-berlin.de/resources/embeddings/token"
134 cache_dir = Path("embeddings")
136 # GLOVE embeddings
137 if embeddings.lower() == "glove" or embeddings.lower() == "en-glove":
138 cached_path(f"{hu_path}/glove.gensim.vectors.npy", cache_dir=cache_dir)
139 embeddings = cached_path(f"{hu_path}/glove.gensim", cache_dir=cache_dir)
141 # TURIAN embeddings
142 elif embeddings.lower() == "turian" or embeddings.lower() == "en-turian":
143 cached_path(f"{hu_path}/turian.vectors.npy", cache_dir=cache_dir)
144 embeddings = cached_path(f"{hu_path}/turian", cache_dir=cache_dir)
146 # KOMNINOS embeddings
147 elif embeddings.lower() == "extvec" or embeddings.lower() == "en-extvec":
148 cached_path(f"{hu_path}/extvec.gensim.vectors.npy", cache_dir=cache_dir)
149 embeddings = cached_path(f"{hu_path}/extvec.gensim", cache_dir=cache_dir)
151 # pubmed embeddings
152 elif embeddings.lower() == "pubmed" or embeddings.lower() == "en-pubmed":
153 cached_path(f"{hu_path}/pubmed_pmc_wiki_sg_1M.gensim.vectors.npy", cache_dir=cache_dir)
154 embeddings = cached_path(f"{hu_path}/pubmed_pmc_wiki_sg_1M.gensim", cache_dir=cache_dir)
156 # FT-CRAWL embeddings
157 elif embeddings.lower() == "crawl" or embeddings.lower() == "en-crawl":
158 cached_path(f"{hu_path}/en-fasttext-crawl-300d-1M.vectors.npy", cache_dir=cache_dir)
159 embeddings = cached_path(f"{hu_path}/en-fasttext-crawl-300d-1M", cache_dir=cache_dir)
161 # FT-NEWS embeddings
162 elif embeddings.lower() in ["news", "en-news", "en"]:
163 cached_path(f"{hu_path}/en-fasttext-news-300d-1M.vectors.npy", cache_dir=cache_dir)
164 embeddings = cached_path(f"{hu_path}/en-fasttext-news-300d-1M", cache_dir=cache_dir)
166 # twitter embeddings
167 elif embeddings.lower() in ["twitter", "en-twitter"]:
168 cached_path(f"{hu_path}/twitter.gensim.vectors.npy", cache_dir=cache_dir)
169 embeddings = cached_path(f"{hu_path}/twitter.gensim", cache_dir=cache_dir)
171 # two-letter language code wiki embeddings
172 elif len(embeddings.lower()) == 2:
173 cached_path(f"{hu_path}/{embeddings}-wiki-fasttext-300d-1M.vectors.npy", cache_dir=cache_dir)
174 embeddings = cached_path(f"{hu_path}/{embeddings}-wiki-fasttext-300d-1M", cache_dir=cache_dir)
176 # two-letter language code wiki embeddings (explicit '-wiki' suffix, e.g. 'de-wiki')
177 elif len(embeddings.lower()) == 7 and embeddings.endswith("-wiki"):
178 cached_path(f"{hu_path}/{embeddings[:2]}-wiki-fasttext-300d-1M.vectors.npy", cache_dir=cache_dir)
179 embeddings = cached_path(f"{hu_path}/{embeddings[:2]}-wiki-fasttext-300d-1M", cache_dir=cache_dir)
181 # two-letter language code crawl embeddings
182 elif len(embeddings.lower()) == 8 and embeddings.endswith("-crawl"):
183 cached_path(f"{hu_path}/{embeddings[:2]}-crawl-fasttext-300d-1M.vectors.npy", cache_dir=cache_dir)
184 embeddings = cached_path(f"{hu_path}/{embeddings[:2]}-crawl-fasttext-300d-1M", cache_dir=cache_dir)
186 elif not Path(embeddings).exists():
187 raise ValueError(
188 f'The given embeddings "{embeddings}" is not available or is not a valid path.'
189 )
191 self.name: str = str(embeddings)
192 self.static_embeddings = True
194 if str(embeddings).endswith(".bin"):
195 self.precomputed_word_embeddings = gensim.models.KeyedVectors.load_word2vec_format(
196 str(embeddings), binary=True
197 )
198 else:
199 self.precomputed_word_embeddings = gensim.models.KeyedVectors.load(
200 str(embeddings)
201 )
203 self.field = field
205 self.__embedding_length: int = self.precomputed_word_embeddings.vector_size
206 super().__init__()
208 @property
209 def embedding_length(self) -> int:
210 return self.__embedding_length
212 @instance_lru_cache(maxsize=10000, typed=False)
213 def get_cached_vec(self, word: str) -> torch.Tensor:
214 if word in self.precomputed_word_embeddings:
215 word_embedding = self.precomputed_word_embeddings[word]
216 elif word.lower() in self.precomputed_word_embeddings:
217 word_embedding = self.precomputed_word_embeddings[word.lower()]
218 elif re.sub(r"\d", "#", word.lower()) in self.precomputed_word_embeddings:
219 word_embedding = self.precomputed_word_embeddings[
220 re.sub(r"\d", "#", word.lower())
221 ]
222 elif re.sub(r"\d", "0", word.lower()) in self.precomputed_word_embeddings:
223 word_embedding = self.precomputed_word_embeddings[
224 re.sub(r"\d", "0", word.lower())
225 ]
226 else:
227 word_embedding = np.zeros(self.embedding_length, dtype="float")
229 word_embedding = torch.tensor(
230 word_embedding.tolist(), device=flair.device, dtype=torch.float
231 )
232 return word_embedding
234 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
236 for i, sentence in enumerate(sentences):
238 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
240 if "field" not in self.__dict__ or self.field is None:
241 word = token.text
242 else:
243 word = token.get_tag(self.field).value
245 word_embedding = self.get_cached_vec(word=word)
247 token.set_embedding(self.name, word_embedding)
249 return sentences
251 def __str__(self):
252 return self.name
254 def extra_repr(self):
255 # fix serialized models
256 if "embeddings" not in self.__dict__:
257 self.embeddings = self.name
259 return f"'{self.embeddings}'"
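# Usage sketch (illustration only): classic static embeddings. Assumes the
# 'glove' files can be downloaded from the HU Berlin mirror listed above.
#
#   from flair.data import Sentence
#   from flair.embeddings import WordEmbeddings
#
#   glove = WordEmbeddings("glove")
#   sentence = Sentence("The grass is green .")
#   glove.embed(sentence)
#   for token in sentence:
#       print(token.text, token.get_embedding().shape)  # one static vector per token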
262class CharacterEmbeddings(TokenEmbeddings):
263 """Character embeddings of words, as proposed in Lample et al., 2016."""
265 def __init__(
266 self,
267 path_to_char_dict: str = None,
268 char_embedding_dim: int = 25,
269 hidden_size_char: int = 25,
270 ):
271 """Uses the default character dictionary if none provided."""
273 super().__init__()
274 self.name = "Char"
275 self.static_embeddings = False
276 self.instance_parameters = self.get_instance_parameters(locals=locals())
278 # use list of common characters if none provided
279 if path_to_char_dict is None:
280 self.char_dictionary: Dictionary = Dictionary.load("common-chars")
281 else:
282 self.char_dictionary: Dictionary = Dictionary.load_from_file(path_to_char_dict)
284 self.char_embedding_dim: int = char_embedding_dim
285 self.hidden_size_char: int = hidden_size_char
286 self.char_embedding = torch.nn.Embedding(
287 len(self.char_dictionary.item2idx), self.char_embedding_dim
288 )
289 self.char_rnn = torch.nn.LSTM(
290 self.char_embedding_dim,
291 self.hidden_size_char,
292 num_layers=1,
293 bidirectional=True,
294 )
296 self.__embedding_length = self.hidden_size_char * 2
298 self.to(flair.device)
300 @property
301 def embedding_length(self) -> int:
302 return self.__embedding_length
304 def _add_embeddings_internal(self, sentences: List[Sentence]):
306 for sentence in sentences:
308 tokens_char_indices = []
310 # translate words in sentence into ints using dictionary
311 for token in sentence.tokens:
312 char_indices = [
313 self.char_dictionary.get_idx_for_item(char) for char in token.text
314 ]
315 tokens_char_indices.append(char_indices)
317 # sort words by length, for batching and masking
318 tokens_sorted_by_length = sorted(
319 tokens_char_indices, key=lambda p: len(p), reverse=True
320 )
321 d = {}
322 for i, ci in enumerate(tokens_char_indices):
323 for j, cj in enumerate(tokens_sorted_by_length):
324 if ci == cj:
325 d[j] = i
326 continue
327 chars2_length = [len(c) for c in tokens_sorted_by_length]
328 longest_token_in_sentence = max(chars2_length)
329 tokens_mask = torch.zeros(
330 (len(tokens_sorted_by_length), longest_token_in_sentence),
331 dtype=torch.long,
332 device=flair.device,
333 )
335 for i, c in enumerate(tokens_sorted_by_length):
336 tokens_mask[i, : chars2_length[i]] = torch.tensor(
337 c, dtype=torch.long, device=flair.device
338 )
340 # chars for rnn processing
341 chars = tokens_mask
343 character_embeddings = self.char_embedding(chars).transpose(0, 1)
345 packed = torch.nn.utils.rnn.pack_padded_sequence(
346 character_embeddings, chars2_length
347 )
349 lstm_out, self.hidden = self.char_rnn(packed)
351 outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_out)
352 outputs = outputs.transpose(0, 1)
353 chars_embeds_temp = torch.zeros(
354 (outputs.size(0), outputs.size(2)),
355 dtype=torch.float,
356 device=flair.device,
357 )
358 for i, index in enumerate(output_lengths):
359 chars_embeds_temp[i] = outputs[i, index - 1]
360 character_embeddings = chars_embeds_temp.clone()
361 for i in range(character_embeddings.size(0)):
362 character_embeddings[d[i]] = chars_embeds_temp[i]
364 for token_number, token in enumerate(sentence.tokens):
365 token.set_embedding(self.name, character_embeddings[token_number])
367 def __str__(self):
368 return self.name
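# Usage sketch (illustration only): character embeddings are randomly initialized
# and only become meaningful when trained as part of a downstream model, but they
# can be applied out of the box:
#
#   from flair.data import Sentence
#   from flair.embeddings import CharacterEmbeddings
#
#   char_emb = CharacterEmbeddings()          # defaults: 25-dim chars, 25-dim BiLSTM
#   sentence = Sentence("I love Berlin .")
#   char_emb.embed(sentence)
#   print(sentence[0].get_embedding().shape)  # 2 * hidden_size_char = 50 by default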
371class FlairEmbeddings(TokenEmbeddings):
372 """Contextual string embeddings of words, as proposed in Akbik et al., 2018."""
374 def __init__(self,
375 model,
376 fine_tune: bool = False,
377 chars_per_chunk: int = 512,
378 with_whitespace: bool = True,
379 tokenized_lm: bool = True,
380 is_lower: bool = False,
381 ):
382 """
383 initializes contextual string embeddings using a character-level language model.
384 :param model: model string, one of 'news-forward', 'news-backward', 'news-forward-fast', 'news-backward-fast',
385 'mix-forward', 'mix-backward', 'german-forward', 'german-backward', 'polish-backward', 'polish-forward',
386 etc (see https://github.com/flairNLP/flair/blob/master/resources/docs/embeddings/FLAIR_EMBEDDINGS.md)
387 depending on which character language model is desired.
388 :param fine_tune: if set to True, the gradient will propagate into the language model. This dramatically slows
389 down training and often leads to overfitting, so use with caution.
390 :param chars_per_chunk: max number of chars per rnn pass to control speed/memory tradeoff. Higher means faster
391 but requires more memory. Lower means slower but less memory.
392 :param with_whitespace: If True, use hidden state after whitespace after word. If False, use hidden
393 state at last character of word.
394 :param tokenized_lm: Whether this lm is tokenized. Default is True, but for LMs trained over unprocessed text
395 False might be better.
396 """
397 super().__init__()
398 self.instance_parameters = self.get_instance_parameters(locals=locals())
400 cache_dir = Path("embeddings")
402 hu_path: str = "https://flair.informatik.hu-berlin.de/resources/embeddings/flair"
403 clef_hipe_path: str = "https://files.ifi.uzh.ch/cl/siclemat/impresso/clef-hipe-2020/flair"
405 self.is_lower: bool = is_lower
407 self.PRETRAINED_MODEL_ARCHIVE_MAP = {
408 # multilingual models
409 "multi-forward": f"{hu_path}/lm-jw300-forward-v0.1.pt",
410 "multi-backward": f"{hu_path}/lm-jw300-backward-v0.1.pt",
411 "multi-v0-forward": f"{hu_path}/lm-multi-forward-v0.1.pt",
412 "multi-v0-backward": f"{hu_path}/lm-multi-backward-v0.1.pt",
413 "multi-forward-fast": f"{hu_path}/lm-multi-forward-fast-v0.1.pt",
414 "multi-backward-fast": f"{hu_path}/lm-multi-backward-fast-v0.1.pt",
415 # English models
416 "en-forward": f"{hu_path}/news-forward-0.4.1.pt",
417 "en-backward": f"{hu_path}/news-backward-0.4.1.pt",
418 "en-forward-fast": f"{hu_path}/lm-news-english-forward-1024-v0.2rc.pt",
419 "en-backward-fast": f"{hu_path}/lm-news-english-backward-1024-v0.2rc.pt",
420 "news-forward": f"{hu_path}/news-forward-0.4.1.pt",
421 "news-backward": f"{hu_path}/news-backward-0.4.1.pt",
422 "news-forward-fast": f"{hu_path}/lm-news-english-forward-1024-v0.2rc.pt",
423 "news-backward-fast": f"{hu_path}/lm-news-english-backward-1024-v0.2rc.pt",
424 "mix-forward": f"{hu_path}/lm-mix-english-forward-v0.2rc.pt",
425 "mix-backward": f"{hu_path}/lm-mix-english-backward-v0.2rc.pt",
426 # Arabic
427 "ar-forward": f"{hu_path}/lm-ar-opus-large-forward-v0.1.pt",
428 "ar-backward": f"{hu_path}/lm-ar-opus-large-backward-v0.1.pt",
429 # Bulgarian
430 "bg-forward-fast": f"{hu_path}/lm-bg-small-forward-v0.1.pt",
431 "bg-backward-fast": f"{hu_path}/lm-bg-small-backward-v0.1.pt",
432 "bg-forward": f"{hu_path}/lm-bg-opus-large-forward-v0.1.pt",
433 "bg-backward": f"{hu_path}/lm-bg-opus-large-backward-v0.1.pt",
434 # Czech
435 "cs-forward": f"{hu_path}/lm-cs-opus-large-forward-v0.1.pt",
436 "cs-backward": f"{hu_path}/lm-cs-opus-large-backward-v0.1.pt",
437 "cs-v0-forward": f"{hu_path}/lm-cs-large-forward-v0.1.pt",
438 "cs-v0-backward": f"{hu_path}/lm-cs-large-backward-v0.1.pt",
439 # Danish
440 "da-forward": f"{hu_path}/lm-da-opus-large-forward-v0.1.pt",
441 "da-backward": f"{hu_path}/lm-da-opus-large-backward-v0.1.pt",
442 # German
443 "de-forward": f"{hu_path}/lm-mix-german-forward-v0.2rc.pt",
444 "de-backward": f"{hu_path}/lm-mix-german-backward-v0.2rc.pt",
445 "de-historic-ha-forward": f"{hu_path}/lm-historic-hamburger-anzeiger-forward-v0.1.pt",
446 "de-historic-ha-backward": f"{hu_path}/lm-historic-hamburger-anzeiger-backward-v0.1.pt",
447 "de-historic-wz-forward": f"{hu_path}/lm-historic-wiener-zeitung-forward-v0.1.pt",
448 "de-historic-wz-backward": f"{hu_path}/lm-historic-wiener-zeitung-backward-v0.1.pt",
449 "de-historic-rw-forward": f"{hu_path}/redewiedergabe_lm_forward.pt",
450 "de-historic-rw-backward": f"{hu_path}/redewiedergabe_lm_backward.pt",
451 # Spanish
452 "es-forward": f"{hu_path}/lm-es-forward.pt",
453 "es-backward": f"{hu_path}/lm-es-backward.pt",
454 "es-forward-fast": f"{hu_path}/lm-es-forward-fast.pt",
455 "es-backward-fast": f"{hu_path}/lm-es-backward-fast.pt",
456 # Basque
457 "eu-forward": f"{hu_path}/lm-eu-opus-large-forward-v0.2.pt",
458 "eu-backward": f"{hu_path}/lm-eu-opus-large-backward-v0.2.pt",
459 "eu-v1-forward": f"{hu_path}/lm-eu-opus-large-forward-v0.1.pt",
460 "eu-v1-backward": f"{hu_path}/lm-eu-opus-large-backward-v0.1.pt",
461 "eu-v0-forward": f"{hu_path}/lm-eu-large-forward-v0.1.pt",
462 "eu-v0-backward": f"{hu_path}/lm-eu-large-backward-v0.1.pt",
463 # Persian
464 "fa-forward": f"{hu_path}/lm-fa-opus-large-forward-v0.1.pt",
465 "fa-backward": f"{hu_path}/lm-fa-opus-large-backward-v0.1.pt",
466 # Finnish
467 "fi-forward": f"{hu_path}/lm-fi-opus-large-forward-v0.1.pt",
468 "fi-backward": f"{hu_path}/lm-fi-opus-large-backward-v0.1.pt",
469 # French
470 "fr-forward": f"{hu_path}/lm-fr-charlm-forward.pt",
471 "fr-backward": f"{hu_path}/lm-fr-charlm-backward.pt",
472 # Hebrew
473 "he-forward": f"{hu_path}/lm-he-opus-large-forward-v0.1.pt",
474 "he-backward": f"{hu_path}/lm-he-opus-large-backward-v0.1.pt",
475 # Hindi
476 "hi-forward": f"{hu_path}/lm-hi-opus-large-forward-v0.1.pt",
477 "hi-backward": f"{hu_path}/lm-hi-opus-large-backward-v0.1.pt",
478 # Croatian
479 "hr-forward": f"{hu_path}/lm-hr-opus-large-forward-v0.1.pt",
480 "hr-backward": f"{hu_path}/lm-hr-opus-large-backward-v0.1.pt",
481 # Indonesian
482 "id-forward": f"{hu_path}/lm-id-opus-large-forward-v0.1.pt",
483 "id-backward": f"{hu_path}/lm-id-opus-large-backward-v0.1.pt",
484 # Italian
485 "it-forward": f"{hu_path}/lm-it-opus-large-forward-v0.1.pt",
486 "it-backward": f"{hu_path}/lm-it-opus-large-backward-v0.1.pt",
487 # Japanese
488 "ja-forward": f"{hu_path}/japanese-forward.pt",
489 "ja-backward": f"{hu_path}/japanese-backward.pt",
490 # Malayalam
491 "ml-forward": f"https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/ml-forward.pt",
492 "ml-backward": f"https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/ml-backward.pt",
493 # Dutch
494 "nl-forward": f"{hu_path}/lm-nl-opus-large-forward-v0.1.pt",
495 "nl-backward": f"{hu_path}/lm-nl-opus-large-backward-v0.1.pt",
496 "nl-v0-forward": f"{hu_path}/lm-nl-large-forward-v0.1.pt",
497 "nl-v0-backward": f"{hu_path}/lm-nl-large-backward-v0.1.pt",
498 # Norwegian
499 "no-forward": f"{hu_path}/lm-no-opus-large-forward-v0.1.pt",
500 "no-backward": f"{hu_path}/lm-no-opus-large-backward-v0.1.pt",
501 # Polish
502 "pl-forward": f"{hu_path}/lm-polish-forward-v0.2.pt",
503 "pl-backward": f"{hu_path}/lm-polish-backward-v0.2.pt",
504 "pl-opus-forward": f"{hu_path}/lm-pl-opus-large-forward-v0.1.pt",
505 "pl-opus-backward": f"{hu_path}/lm-pl-opus-large-backward-v0.1.pt",
506 # Portuguese
507 "pt-forward": f"{hu_path}/lm-pt-forward.pt",
508 "pt-backward": f"{hu_path}/lm-pt-backward.pt",
509 # Pubmed
510 "pubmed-forward": f"{hu_path}/pubmed-forward.pt",
511 "pubmed-backward": f"{hu_path}/pubmed-backward.pt",
512 "pubmed-2015-forward": f"{hu_path}/pubmed-2015-fw-lm.pt",
513 "pubmed-2015-backward": f"{hu_path}/pubmed-2015-bw-lm.pt",
514 # Slovenian
515 "sl-forward": f"{hu_path}/lm-sl-opus-large-forward-v0.1.pt",
516 "sl-backward": f"{hu_path}/lm-sl-opus-large-backward-v0.1.pt",
517 "sl-v0-forward": f"{hu_path}/lm-sl-large-forward-v0.1.pt",
518 "sl-v0-backward": f"{hu_path}/lm-sl-large-backward-v0.1.pt",
519 # Swedish
520 "sv-forward": f"{hu_path}/lm-sv-opus-large-forward-v0.1.pt",
521 "sv-backward": f"{hu_path}/lm-sv-opus-large-backward-v0.1.pt",
522 "sv-v0-forward": f"{hu_path}/lm-sv-large-forward-v0.1.pt",
523 "sv-v0-backward": f"{hu_path}/lm-sv-large-backward-v0.1.pt",
524 # Tamil
525 "ta-forward": f"{hu_path}/lm-ta-opus-large-forward-v0.1.pt",
526 "ta-backward": f"{hu_path}/lm-ta-opus-large-backward-v0.1.pt",
527 # Spanish clinical
528 "es-clinical-forward": f"{hu_path}/es-clinical-forward.pt",
529 "es-clinical-backward": f"{hu_path}/es-clinical-backward.pt",
530 # CLEF HIPE Shared task
531 "de-impresso-hipe-v1-forward": f"{clef_hipe_path}/de-hipe-flair-v1-forward/best-lm.pt",
532 "de-impresso-hipe-v1-backward": f"{clef_hipe_path}/de-hipe-flair-v1-backward/best-lm.pt",
533 "en-impresso-hipe-v1-forward": f"{clef_hipe_path}/en-flair-v1-forward/best-lm.pt",
534 "en-impresso-hipe-v1-backward": f"{clef_hipe_path}/en-flair-v1-backward/best-lm.pt",
535 "fr-impresso-hipe-v1-forward": f"{clef_hipe_path}/fr-hipe-flair-v1-forward/best-lm.pt",
536 "fr-impresso-hipe-v1-backward": f"{clef_hipe_path}/fr-hipe-flair-v1-backward/best-lm.pt",
537 }
539 if type(model) == str:
541 # load model if in pretrained model map
542 if model.lower() in self.PRETRAINED_MODEL_ARCHIVE_MAP:
543 base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[model.lower()]
545 # Fix for CLEF HIPE models (avoid overwriting best-lm.pt in cache_dir)
546 if "impresso-hipe" in model.lower():
547 cache_dir = cache_dir / model.lower()
548 # CLEF HIPE models are lowercased
549 self.is_lower = True
550 model = cached_path(base_path, cache_dir=cache_dir)
552 elif replace_with_language_code(model) in self.PRETRAINED_MODEL_ARCHIVE_MAP:
553 base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[
554 replace_with_language_code(model)
555 ]
556 model = cached_path(base_path, cache_dir=cache_dir)
558 elif not Path(model).exists():
559 raise ValueError(
560 f'The given model "{model}" is not available or is not a valid path.'
561 )
563 from flair.models import LanguageModel
565 if type(model) == LanguageModel:
566 self.lm: LanguageModel = model
567 self.name = f"Task-LSTM-{self.lm.hidden_size}-{self.lm.nlayers}-{self.lm.is_forward_lm}"
568 else:
569 self.lm: LanguageModel = LanguageModel.load_language_model(model)
570 self.name = str(model)
572 # embeddings are static if we don't do fine-tuning
573 self.fine_tune = fine_tune
574 self.static_embeddings = not fine_tune
576 self.is_forward_lm: bool = self.lm.is_forward_lm
577 self.with_whitespace: bool = with_whitespace
578 self.tokenized_lm: bool = tokenized_lm
579 self.chars_per_chunk: int = chars_per_chunk
581 # embed a dummy sentence to determine embedding_length
582 dummy_sentence: Sentence = Sentence()
583 dummy_sentence.add_token(Token("hello"))
584 embedded_dummy = self.embed(dummy_sentence)
585 self.__embedding_length: int = len(
586 embedded_dummy[0].get_token(1).get_embedding()
587 )
589 # set to eval mode
590 self.eval()
592 def train(self, mode=True):
594 # make compatible with serialized models (TODO: remove)
595 if "fine_tune" not in self.__dict__:
596 self.fine_tune = False
597 if "chars_per_chunk" not in self.__dict__:
598 self.chars_per_chunk = 512
600 # unless fine-tuning is set, do not set language model to train() in order to disallow language model dropout
601 if not self.fine_tune:
602 pass
603 else:
604 super(FlairEmbeddings, self).train(mode)
606 @property
607 def embedding_length(self) -> int:
608 return self.__embedding_length
610 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
612 # make compatible with serialized models (TODO: remove)
613 if "with_whitespace" not in self.__dict__:
614 self.with_whitespace = True
615 if "tokenized_lm" not in self.__dict__:
616 self.tokenized_lm = True
617 if "is_lower" not in self.__dict__:
618 self.is_lower = False
620 # gradients are enabled only if fine-tuning is enabled
621 gradient_context = torch.enable_grad() if self.fine_tune else torch.no_grad()
623 with gradient_context:
625 # use the language model to generate embeddings. First, get the text of each sentence
626 text_sentences = [sentence.to_tokenized_string() for sentence in sentences] if self.tokenized_lm \
627 else [sentence.to_plain_string() for sentence in sentences]
629 if self.is_lower:
630 text_sentences = [sentence.lower() for sentence in text_sentences]
632 start_marker = self.lm.document_delimiter if "document_delimiter" in self.lm.__dict__ else '\n'
633 end_marker = " "
635 # get hidden states from language model
636 all_hidden_states_in_lm = self.lm.get_representation(
637 text_sentences, start_marker, end_marker, self.chars_per_chunk
638 )
640 if not self.fine_tune:
641 all_hidden_states_in_lm = all_hidden_states_in_lm.detach()
643 # take first or last hidden states from language model as word representation
644 for i, sentence in enumerate(sentences):
645 sentence_text = sentence.to_tokenized_string() if self.tokenized_lm else sentence.to_plain_string()
647 offset_forward: int = len(start_marker)
648 offset_backward: int = len(sentence_text) + len(start_marker)
650 for token in sentence.tokens:
652 offset_forward += len(token.text)
653 if self.is_forward_lm:
654 offset_with_whitespace = offset_forward
655 offset_without_whitespace = offset_forward - 1
656 else:
657 offset_with_whitespace = offset_backward
658 offset_without_whitespace = offset_backward - 1
660 # offset mode that extracts at whitespace after last character
661 if self.with_whitespace:
662 embedding = all_hidden_states_in_lm[offset_with_whitespace, i, :]
663 # offset mode that extracts at last character
664 else:
665 embedding = all_hidden_states_in_lm[offset_without_whitespace, i, :]
667 if self.tokenized_lm or token.whitespace_after:
668 offset_forward += 1
669 offset_backward -= 1
671 offset_backward -= len(token.text)
673 # only clone if optimization mode is 'gpu'
674 if flair.embedding_storage_mode == "gpu":
675 embedding = embedding.clone()
677 token.set_embedding(self.name, embedding)
679 del all_hidden_states_in_lm
681 return sentences
683 def __str__(self):
684 return self.name
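# Usage sketch (illustration only): contextual string embeddings from a
# character-level language model. Assumes the pretrained 'news-forward' model
# can be downloaded.
#
#   from flair.data import Sentence
#   from flair.embeddings import FlairEmbeddings
#
#   flair_fw = FlairEmbeddings("news-forward")
#   sentence = Sentence("The grass is green .")
#   flair_fw.embed(sentence)
#   # the same surface form gets different vectors in different contexts
#   print(sentence[0].get_embedding()[:5])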
687class PooledFlairEmbeddings(TokenEmbeddings):
688 def __init__(
689 self,
690 contextual_embeddings: Union[str, FlairEmbeddings],
691 pooling: str = "min",
692 only_capitalized: bool = False,
693 **kwargs,
694 ):
696 super().__init__()
697 self.instance_parameters = self.get_instance_parameters(locals=locals())
699 # use the character language model embeddings as basis
700 if type(contextual_embeddings) is str:
701 self.context_embeddings: FlairEmbeddings = FlairEmbeddings(
702 contextual_embeddings, **kwargs
703 )
704 else:
705 self.context_embeddings: FlairEmbeddings = contextual_embeddings
707 # length is twice the original character LM embedding length
708 self.embedding_length = self.context_embeddings.embedding_length * 2
709 self.name = self.context_embeddings.name + "-context"
711 # these fields are for the embedding memory
712 self.word_embeddings = {}
713 self.word_count = {}
715 # whether to add only capitalized words to memory (faster runtime and lower memory consumption)
716 self.only_capitalized = only_capitalized
718 # we re-compute embeddings dynamically at each epoch
719 self.static_embeddings = False
721 # set the memory method
722 self.pooling = pooling
724 def train(self, mode=True):
725 super().train(mode=mode)
726 if mode:
727 # memory is wiped each time we do a training run
728 print("train mode resetting embeddings")
729 self.word_embeddings = {}
730 self.word_count = {}
732 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
734 self.context_embeddings.embed(sentences)
736 # if we keep a pooled memory, it needs to be updated continuously
737 for sentence in sentences:
738 for token in sentence.tokens:
740 # update embedding
741 local_embedding = token._embeddings[self.context_embeddings.name].cpu()
743 # check whether token.text is empty
744 if token.text:
745 if token.text[0].isupper() or not self.only_capitalized:
747 if token.text not in self.word_embeddings:
748 self.word_embeddings[token.text] = local_embedding
749 self.word_count[token.text] = 1
750 else:
752 # set aggregation operation
753 if self.pooling == "mean":
754 aggregated_embedding = torch.add(self.word_embeddings[token.text], local_embedding)
755 elif self.pooling == "fade":
756 aggregated_embedding = torch.add(self.word_embeddings[token.text], local_embedding)
757 aggregated_embedding /= 2
758 elif self.pooling == "max":
759 aggregated_embedding = torch.max(self.word_embeddings[token.text], local_embedding)
760 elif self.pooling == "min":
761 aggregated_embedding = torch.min(self.word_embeddings[token.text], local_embedding)
763 self.word_embeddings[token.text] = aggregated_embedding
764 self.word_count[token.text] += 1
766 # add embeddings after updating
767 for sentence in sentences:
768 for token in sentence.tokens:
769 if token.text in self.word_embeddings:
770 base = (
771 self.word_embeddings[token.text] / self.word_count[token.text]
772 if self.pooling == "mean"
773 else self.word_embeddings[token.text]
774 )
775 else:
776 base = token._embeddings[self.context_embeddings.name]
778 token.set_embedding(self.name, base)
780 return sentences
782 def embedding_length(self) -> int:
783 return self.embedding_length
785 def get_names(self) -> List[str]:
786 return [self.name, self.context_embeddings.name]
788 def __setstate__(self, d):
789 self.__dict__ = d
791 if flair.device != 'cpu':
792 for key in self.word_embeddings:
793 self.word_embeddings[key] = self.word_embeddings[key].cpu()
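# Usage sketch (illustration only): the pooled variant keeps a growing memory of
# word embeddings across the corpus (pooling: 'min', 'max', 'mean' or 'fade'),
# and its embedding length is twice that of the underlying character LM.
# Assumes 'news-forward' can be downloaded.
#
#   from flair.data import Sentence
#   from flair.embeddings import PooledFlairEmbeddings
#
#   pooled = PooledFlairEmbeddings("news-forward", pooling="min")
#   pooled.embed(Sentence("I love Berlin ."))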
796class TransformerWordEmbeddings(TokenEmbeddings):
797 NO_MAX_SEQ_LENGTH_MODELS = [XLNetModel, TransfoXLModel]
799 def __init__(
800 self,
801 model: str = "bert-base-uncased",
802 layers: str = "all",
803 subtoken_pooling: str = "first",
804 layer_mean: bool = True,
805 fine_tune: bool = False,
806 allow_long_sentences: bool = True,
807 use_context: Union[bool, int] = False,
808 memory_effective_training: bool = True,
809 respect_document_boundaries: bool = True,
810 context_dropout: float = 0.5,
811 **kwargs
812 ):
813 """
814 Bidirectional transformer embeddings of words from various transformer architectures.
815 :param model: name of transformer model (see https://huggingface.co/transformers/pretrained_models.html for
816 options)
817 :param layers: string indicating which layers to take for embedding (-1 is topmost layer)
818 :param subtoken_pooling: how to get from token piece embeddings to token embedding. Either take the first
819 subtoken ('first'), the last subtoken ('last'), both first and last ('first_last') or a mean over all ('mean')
820 :param layer_mean: If True, uses a scalar mix of layers as embedding
821 :param fine_tune: If True, allows transformers to be fine-tuned during training
822 """
823 super().__init__()
824 self.instance_parameters = self.get_instance_parameters(locals=locals())
826 # temporary fix to disable tokenizer parallelism warning
827 # (see https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning)
828 import os
829 os.environ["TOKENIZERS_PARALLELISM"] = "false"
831 # do not print transformer warnings as these are confusing in this case
832 from transformers import logging
833 logging.set_verbosity_error()
835 # load tokenizer and transformer model
836 self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model, **kwargs)
837 if not 'config' in kwargs:
838 config = AutoConfig.from_pretrained(model, output_hidden_states=True, **kwargs)
839 self.model = AutoModel.from_pretrained(model, config=config)
840 else:
841 self.model = AutoModel.from_pretrained(None, **kwargs)
843 logging.set_verbosity_warning()
845 if type(self.model) not in self.NO_MAX_SEQ_LENGTH_MODELS:
846 self.allow_long_sentences = allow_long_sentences
847 self.truncate = True
848 self.max_subtokens_sequence_length = self.tokenizer.model_max_length
849 self.stride = self.tokenizer.model_max_length // 2 if allow_long_sentences else 0
850 else:
851 # in the end, these models don't need this configuration
852 self.allow_long_sentences = False
853 self.truncate = False
854 self.max_subtokens_sequence_length = None
855 self.stride = 0
857 self.use_lang_emb = hasattr(self.model, "use_lang_emb") and self.model.use_lang_emb
859 # model name
860 self.name = 'transformer-word-' + str(model)
861 self.base_model = str(model)
863 # whether to detach gradients on overlong sentences
864 self.memory_effective_training = memory_effective_training
866 # store whether to use context (and how much)
867 if type(use_context) == bool:
868 self.context_length: int = 64 if use_context else 0
869 if type(use_context) == int:
870 self.context_length: int = use_context
872 # dropout contexts
873 self.context_dropout = context_dropout
875 # if using context, can we cross document boundaries?
876 self.respect_document_boundaries = respect_document_boundaries
878 # send self to flair-device
879 self.to(flair.device)
881 # embedding parameters
882 if layers == 'all':
883 # send mini-token through to check how many layers the model has
884 hidden_states = self.model(torch.tensor([1], device=flair.device).unsqueeze(0))[-1]
885 self.layer_indexes = [int(x) for x in range(len(hidden_states))]
886 else:
887 self.layer_indexes = [int(x) for x in layers.split(",")]
889 self.pooling_operation = subtoken_pooling
890 self.layer_mean = layer_mean
891 self.fine_tune = fine_tune
892 self.static_embeddings = not self.fine_tune
894 # calculate embedding length
895 if not self.layer_mean:
896 length = len(self.layer_indexes) * self.model.config.hidden_size
897 else:
898 length = self.model.config.hidden_size
899 if self.pooling_operation == 'first_last': length *= 2
901 # return length
902 self.embedding_length_internal = length
904 self.special_tokens = []
905 # check if special tokens exist to circumvent error message
906 if self.tokenizer._bos_token:
907 self.special_tokens.append(self.tokenizer.bos_token)
908 if self.tokenizer._cls_token:
909 self.special_tokens.append(self.tokenizer.cls_token)
911 # most models have an initial BOS token, except for XLNet, T5 and GPT2
912 self.begin_offset = self._get_begin_offset_of_tokenizer(tokenizer=self.tokenizer)
914 # when initializing, embeddings are in eval mode by default
915 self.eval()
917 @staticmethod
918 def _get_begin_offset_of_tokenizer(tokenizer: PreTrainedTokenizer) -> int:
919 test_string = 'a'
920 tokens = tokenizer.encode(test_string)
922 for begin_offset, token in enumerate(tokens):
923 if tokenizer.decode([token]) == test_string or tokenizer.decode([token]) == tokenizer.unk_token:
924 break
925 return begin_offset
927 @staticmethod
928 def _remove_special_markup(text: str):
929 # remove special markup
930 text = re.sub('^Ġ', '', text) # RoBERTa models
931 text = re.sub('^##', '', text) # BERT models
932 text = re.sub('^▁', '', text) # XLNet models
933 text = re.sub('</w>$', '', text) # XLM models
934 return text
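# Examples of the markup stripped above (derived from the regexes):
#   'Ġhouse'    -> 'house'   (RoBERTa / GPT-2 style byte-level BPE)
#   '##ing'     -> 'ing'     (BERT WordPiece continuation)
#   '▁token'    -> 'token'   (XLNet / SentencePiece)
#   'word</w>'  -> 'word'    (XLM BPE end-of-word marker)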
936 def _get_processed_token_text(self, token: Token) -> str:
937 pieces = self.tokenizer.tokenize(token.text)
938 token_text = ''
939 for piece in pieces:
940 token_text += self._remove_special_markup(piece)
941 token_text = token_text.lower()
942 return token_text
944 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
946 # we require encoded subtokenized sentences, the mapping to original tokens and the number of
947 # parts that each sentence produces
948 subtokenized_sentences = []
949 all_token_subtoken_lengths = []
951 # if we also use context, first expand sentence to include context
952 if self.context_length > 0:
954 # set context if not set already
955 previous_sentence = None
956 for sentence in sentences:
957 if sentence.is_context_set(): continue
958 sentence._previous_sentence = previous_sentence
959 sentence._next_sentence = None
960 if previous_sentence: previous_sentence._next_sentence = sentence
961 previous_sentence = sentence
963 original_sentences = []
964 expanded_sentences = []
965 context_offsets = []
967 for sentence in sentences:
968 # in case of contextualization, we must remember non-expanded sentence
969 original_sentence = sentence
970 original_sentences.append(original_sentence)
972 # create expanded sentence and remember context offsets
973 expanded_sentence, context_offset = self._expand_sentence_with_context(sentence)
974 expanded_sentences.append(expanded_sentence)
975 context_offsets.append(context_offset)
977 # overwrite sentence with expanded sentence
978 sentence = expanded_sentence
980 sentences = expanded_sentences
982 tokenized_sentences = []
983 for sentence in sentences:
985 # subtokenize the sentence
986 tokenized_string = sentence.to_tokenized_string()
988 # transformer specific tokenization
989 subtokenized_sentence = self.tokenizer.tokenize(tokenized_string)
991 # set zero embeddings for empty sentences and exclude
992 if len(subtokenized_sentence) == 0:
993 for token in sentence:
994 token.set_embedding(self.name, torch.zeros(self.embedding_length))
995 continue
997 # determine into how many subtokens each token is split
998 token_subtoken_lengths = self.reconstruct_tokens_from_subtokens(sentence, subtokenized_sentence)
1000 # remember tokenized sentences and their subtokenization
1001 tokenized_sentences.append(tokenized_string)
1002 all_token_subtoken_lengths.append(token_subtoken_lengths)
1004 # encode inputs
1005 batch_encoding = self.tokenizer(tokenized_sentences,
1006 max_length=self.max_subtokens_sequence_length,
1007 stride=self.stride,
1008 return_overflowing_tokens=self.allow_long_sentences,
1009 truncation=self.truncate,
1010 padding=True,
1011 return_tensors='pt',
1012 )
1014 input_ids = batch_encoding['input_ids'].to(flair.device)
1015 attention_mask = batch_encoding['attention_mask'].to(flair.device)
1017 # determine which sentence was split into how many parts
1018 sentence_parts_lengths = torch.ones(len(tokenized_sentences), dtype=torch.int) if not self.allow_long_sentences \
1019 else torch.unique(batch_encoding['overflow_to_sample_mapping'], return_counts=True, sorted=True)[1].tolist()
1021 model_kwargs = {}
1022 # set language IDs for XLM-style transformers
1023 if self.use_lang_emb:
1024 model_kwargs["langs"] = torch.zeros_like(input_ids, dtype=input_ids.dtype)
1026 for s_id, sentence in enumerate(tokenized_sentences):
1027 sequence_length = len(sentence)
1028 lang_id = self.tokenizer.lang2id.get(sentences[s_id].get_language_code(), 0)
1029 model_kwargs["langs"][s_id][:sequence_length] = lang_id
1031 # put encoded batch through transformer model to get all hidden states of all encoder layers
1032 hidden_states = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)[-1]
1033 # make the tuple a tensor; makes working with it easier.
1034 hidden_states = torch.stack(hidden_states)
1036 sentence_idx_offset = 0
1038 # gradients are enabled if fine-tuning is enabled
1039 gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad()
1041 with gradient_context:
1043 # iterate over all subtokenized sentences
1044 for sentence_idx, (sentence, subtoken_lengths, nr_sentence_parts) in enumerate(
1045 zip(sentences, all_token_subtoken_lengths, sentence_parts_lengths)):
1047 sentence_hidden_state = hidden_states[:, sentence_idx + sentence_idx_offset, ...]
1049 for i in range(1, nr_sentence_parts):
1050 sentence_idx_offset += 1
1051 remainder_sentence_hidden_state = hidden_states[:, sentence_idx + sentence_idx_offset, ...]
1052 # remove stride_size//2 at end of sentence_hidden_state, and half at beginning of remainder,
1053 # in order to get some context into the embeddings of these words.
1054 # also don't include the embedding of the extra [CLS] and [SEP] tokens.
1055 sentence_hidden_state = torch.cat((sentence_hidden_state[:, :-1 - self.stride // 2, :],
1056 remainder_sentence_hidden_state[:, 1 + self.stride // 2:,
1057 :]), 1)
1059 subword_start_idx = self.begin_offset
1061 # for each token, get embedding
1062 for token_idx, (token, number_of_subtokens) in enumerate(zip(sentence, subtoken_lengths)):
1064 # some tokens have no subtokens at all (if omitted by BERT tokenizer) so return zero vector
1065 if number_of_subtokens == 0:
1066 token.set_embedding(self.name, torch.zeros(self.embedding_length))
1067 continue
1069 subword_end_idx = subword_start_idx + number_of_subtokens
1071 subtoken_embeddings: List[torch.FloatTensor] = []
1073 # get states from all selected layers, aggregate with pooling operation
1074 for layer in self.layer_indexes:
1075 current_embeddings = sentence_hidden_state[layer][subword_start_idx:subword_end_idx]
1077 if self.pooling_operation == "first":
1078 final_embedding: torch.FloatTensor = current_embeddings[0]
1080 if self.pooling_operation == "last":
1081 final_embedding: torch.FloatTensor = current_embeddings[-1]
1083 if self.pooling_operation == "first_last":
1084 final_embedding: torch.Tensor = torch.cat(
1085 [current_embeddings[0], current_embeddings[-1]])
1087 if self.pooling_operation == "mean":
1088 all_embeddings: List[torch.FloatTensor] = [
1089 embedding.unsqueeze(0) for embedding in current_embeddings
1090 ]
1091 final_embedding: torch.Tensor = torch.mean(torch.cat(all_embeddings, dim=0), dim=0)
1093 subtoken_embeddings.append(final_embedding)
1095 # use layer mean of embeddings if so selected
1096 if self.layer_mean and len(self.layer_indexes) > 1:
1097 sm_embeddings = torch.mean(torch.stack(subtoken_embeddings, dim=1), dim=1)
1098 subtoken_embeddings = [sm_embeddings]
1100 # set the extracted embedding for the token
1101 token.set_embedding(self.name, torch.cat(subtoken_embeddings))
1103 subword_start_idx += number_of_subtokens
1105 # move embeddings from context back to original sentence (if using context)
1106 if self.context_length > 0:
1107 for original_sentence, expanded_sentence, context_offset in zip(original_sentences,
1108 sentences,
1109 context_offsets):
1110 for token_idx, token in enumerate(original_sentence):
1111 token.set_embedding(self.name,
1112 expanded_sentence[token_idx + context_offset].get_embedding(self.name))
1113 sentence = original_sentence
1115 def _expand_sentence_with_context(self, sentence):
1117 # remember original sentence
1118 original_sentence = sentence
1120 import random
1121 expand_context = False if self.training and random.randint(1, 100) <= (self.context_dropout * 100) else True
1123 left_context = ''
1124 right_context = ''
1126 if expand_context:
1128 # get left context
1129 while True:
1130 sentence = sentence.previous_sentence()
1131 if sentence is None: break
1133 if self.respect_document_boundaries and sentence.is_document_boundary: break
1135 left_context = sentence.to_tokenized_string() + ' ' + left_context
1136 left_context = left_context.strip()
1137 if len(left_context.split(" ")) > self.context_length:
1138 left_context = " ".join(left_context.split(" ")[-self.context_length:])
1139 break
1140 original_sentence.left_context = left_context
1142 sentence = original_sentence
1144 # get right context
1145 while True:
1146 sentence = sentence.next_sentence()
1147 if sentence is None: break
1148 if self.respect_document_boundaries and sentence.is_document_boundary: break
1150 right_context += ' ' + sentence.to_tokenized_string()
1151 right_context = right_context.strip()
1152 if len(right_context.split(" ")) > self.context_length:
1153 right_context = " ".join(right_context.split(" ")[:self.context_length])
1154 break
1156 original_sentence.right_context = right_context
1158 left_context_split = left_context.split(" ")
1159 right_context_split = right_context.split(" ")
1161 # empty contexts should not introduce whitespace tokens
1162 if left_context_split == [""]: left_context_split = []
1163 if right_context_split == [""]: right_context_split = []
1165 # make expanded sentence
1166 expanded_sentence = Sentence()
1167 expanded_sentence.tokens = [Token(token) for token in left_context_split +
1168 original_sentence.to_tokenized_string().split(" ") +
1169 right_context_split]
1171 context_length = len(left_context_split)
1172 return expanded_sentence, context_length
1174 def reconstruct_tokens_from_subtokens(self, sentence, subtokens):
1175 word_iterator = iter(sentence)
1176 token = next(word_iterator)
1177 token_text = self._get_processed_token_text(token)
1178 token_subtoken_lengths = []
1179 reconstructed_token = ''
1180 subtoken_count = 0
1181 # iterate over subtokens and reconstruct tokens
1182 for subtoken_id, subtoken in enumerate(subtokens):
1184 # remove special markup
1185 subtoken = self._remove_special_markup(subtoken)
1187 # TODO check if this is necessary if this method is called before prepare_for_model
1188 # check if reconstructed token is special begin token ([CLS] or similar)
1189 if subtoken in self.special_tokens and subtoken_id == 0:
1190 continue
1192 # some BERT tokenizers somehow omit words - in such cases skip to next token
1193 if subtoken_count == 0 and not token_text.startswith(subtoken.lower()):
1195 while True:
1196 token_subtoken_lengths.append(0)
1197 token = next(word_iterator)
1198 token_text = self._get_processed_token_text(token)
1199 if token_text.startswith(subtoken.lower()): break
1201 subtoken_count += 1
1203 # append subtoken to reconstruct token
1204 reconstructed_token = reconstructed_token + subtoken
1206 # check if reconstructed token is the same as current token
1207 if reconstructed_token.lower() == token_text:
1209 # if so, add subtoken count
1210 token_subtoken_lengths.append(subtoken_count)
1212 # reset subtoken count and reconstructed token
1213 reconstructed_token = ''
1214 subtoken_count = 0
1216 # break from loop if all tokens are accounted for
1217 if len(token_subtoken_lengths) < len(sentence):
1218 token = next(word_iterator)
1219 token_text = self._get_processed_token_text(token)
1220 else:
1221 break
1223 # if tokens are unaccounted for
1224 while len(token_subtoken_lengths) < len(sentence) and len(token.text) == 1:
1225 token_subtoken_lengths.append(0)
1226 if len(token_subtoken_lengths) == len(sentence): break
1227 token = next(word_iterator)
1229 # check if all tokens were matched to subtokens
1230 if token != sentence[-1]:
1231 log.error(f"Tokenization MISMATCH in sentence '{sentence.to_tokenized_string()}'")
1232 log.error(f"Last matched: '{token}'")
1233 log.error(f"Last sentence: '{sentence[-1]}'")
1234 log.error(f"subtokenized: '{subtokens}'")
1235 return token_subtoken_lengths
1237 @property
1238 def embedding_length(self) -> int:
1240 if "embedding_length_internal" in self.__dict__.keys():
1241 return self.embedding_length_internal
1243 # """Returns the length of the embedding vector."""
1244 if not self.layer_mean:
1245 length = len(self.layer_indexes) * self.model.config.hidden_size
1246 else:
1247 length = self.model.config.hidden_size
1249 if self.pooling_operation == 'first_last': length *= 2
1251 self.__embedding_length = length
1253 return length
1255 def __getstate__(self):
1256 # special handling for serializing transformer models
1257 config_state_dict = self.model.config.__dict__
1258 model_state_dict = self.model.state_dict()
1260 if not hasattr(self, "base_model_name"): self.base_model_name = self.name.split('transformer-word-')[-1]
1262 # serialize the transformer models and the constructor arguments (but nothing else)
1263 model_state = {
1264 "config_state_dict": config_state_dict,
1265 "model_state_dict": model_state_dict,
1266 "embedding_length_internal": self.embedding_length,
1268 "base_model_name": self.base_model_name,
1269 "name": self.name,
1270 "layer_indexes": self.layer_indexes,
1271 "subtoken_pooling": self.pooling_operation,
1272 "context_length": self.context_length,
1273 "layer_mean": self.layer_mean,
1274 "fine_tune": self.fine_tune,
1275 "allow_long_sentences": self.allow_long_sentences,
1276 "memory_effective_training": self.memory_effective_training,
1277 "respect_document_boundaries": self.respect_document_boundaries,
1278 "context_dropout": self.context_dropout,
1279 }
1281 return model_state
1283 def __setstate__(self, d):
1284 self.__dict__ = d
1286 # necessary for backward compatibility with Flair <= 0.7
1287 if 'use_scalar_mix' in self.__dict__.keys():
1288 self.__dict__['layer_mean'] = d['use_scalar_mix']
1289 if not 'memory_effective_training' in self.__dict__.keys():
1290 self.__dict__['memory_effective_training'] = True
1291 if 'pooling_operation' in self.__dict__.keys():
1292 self.__dict__['subtoken_pooling'] = d['pooling_operation']
1293 if not 'context_length' in self.__dict__.keys():
1294 self.__dict__['context_length'] = 0
1295 if 'use_context' in self.__dict__.keys():
1296 self.__dict__['context_length'] = 64 if self.__dict__['use_context'] == True else 0
1298 if not 'context_dropout' in self.__dict__.keys():
1299 self.__dict__['context_dropout'] = 0.5
1300 if not 'respect_document_boundaries' in self.__dict__.keys():
1301 self.__dict__['respect_document_boundaries'] = True
1302 if not 'memory_effective_training' in self.__dict__.keys():
1303 self.__dict__['memory_effective_training'] = True
1304 if not 'base_model_name' in self.__dict__.keys():
1305 self.__dict__['base_model_name'] = self.__dict__['name'].split('transformer-word-')[-1]
1307 # special handling for deserializing transformer models
1308 if "config_state_dict" in d:
1310 # load transformer model
1311 model_type = d["config_state_dict"]["model_type"] if "model_type" in d["config_state_dict"] else "bert"
1312 config_class = CONFIG_MAPPING[model_type]
1313 loaded_config = config_class.from_dict(d["config_state_dict"])
1315 # constructor arguments
1316 layers = ','.join([str(idx) for idx in self.__dict__['layer_indexes']])
1318 # re-initialize transformer word embeddings with constructor arguments
1319 embedding = TransformerWordEmbeddings(
1320 model=self.__dict__['base_model_name'],
1321 layers=layers,
1322 subtoken_pooling=self.__dict__['subtoken_pooling'],
1323 use_context=self.__dict__['context_length'],
1324 layer_mean=self.__dict__['layer_mean'],
1325 fine_tune=self.__dict__['fine_tune'],
1326 allow_long_sentences=self.__dict__['allow_long_sentences'],
1327 respect_document_boundaries=self.__dict__['respect_document_boundaries'],
1328 memory_effective_training=self.__dict__['memory_effective_training'],
1329 context_dropout=self.__dict__['context_dropout'],
1331 config=loaded_config,
1332 state_dict=d["model_state_dict"],
1333 )
1335 # I have no idea why this is necessary, but otherwise it doesn't work
1336 for key in embedding.__dict__.keys():
1337 self.__dict__[key] = embedding.__dict__[key]
1339 else:
1341 # reload tokenizer to get around serialization issues
1342 model_name = self.__dict__['name'].split('transformer-word-')[-1]
1343 try:
1344 tokenizer = AutoTokenizer.from_pretrained(model_name)
1345 except:
1346 pass
1348 self.tokenizer = tokenizer
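# Usage sketch (illustration only): transformer-based word embeddings. Assumes
# 'bert-base-uncased' can be fetched from the Hugging Face hub.
#
#   from flair.data import Sentence
#   from flair.embeddings import TransformerWordEmbeddings
#
#   bert = TransformerWordEmbeddings("bert-base-uncased",
#                                    layers="-1",             # topmost layer only
#                                    subtoken_pooling="first",
#                                    fine_tune=False)
#   sentence = Sentence("The grass is green .")
#   bert.embed(sentence)
#   print(sentence[0].get_embedding().shape)  # hidden_size of the model (768 here)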
1351class FastTextEmbeddings(TokenEmbeddings):
1352 """FastText Embeddings with oov functionality"""
1354 def __init__(self, embeddings: str, use_local: bool = True, field: str = None):
1355 """
1356 Initializes fasttext word embeddings. Constructor downloads required embedding file and stores in cache
1357 if use_local is False.
1359 :param embeddings: path to your embeddings '.bin' file
1360 :param use_local: set this to False if you are using embeddings from a remote source
1361 """
1362 self.instance_parameters = self.get_instance_parameters(locals=locals())
1364 cache_dir = Path("embeddings")
1366 if use_local:
1367 if not Path(embeddings).exists():
1368 raise ValueError(
1369 f'The given embeddings "{embeddings}" is not available or is not a valid path.'
1370 )
1371 else:
1372 embeddings = cached_path(f"{embeddings}", cache_dir=cache_dir)
1374 self.embeddings = embeddings
1376 self.name: str = str(embeddings)
1378 self.static_embeddings = True
1380 self.precomputed_word_embeddings = gensim.models.FastText.load_fasttext_format(
1381 str(embeddings)
1382 )
1384 self.__embedding_length: int = self.precomputed_word_embeddings.vector_size
1386 self.field = field
1387 super().__init__()
1389 @property
1390 def embedding_length(self) -> int:
1391 return self.__embedding_length
1393 @instance_lru_cache(maxsize=10000, typed=False)
1394 def get_cached_vec(self, word: str) -> torch.Tensor:
1395 try:
1396 word_embedding = self.precomputed_word_embeddings[word]
1397 except:
1398 word_embedding = np.zeros(self.embedding_length, dtype="float")
1400 word_embedding = torch.tensor(
1401 word_embedding.tolist(), device=flair.device, dtype=torch.float
1402 )
1403 return word_embedding
1405 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
1407 for i, sentence in enumerate(sentences):
1409 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
1411 if "field" not in self.__dict__ or self.field is None:
1412 word = token.text
1413 else:
1414 word = token.get_tag(self.field).value
1416 word_embedding = self.get_cached_vec(word)
1418 token.set_embedding(self.name, word_embedding)
1420 return sentences
1422 def __str__(self):
1423 return self.name
1425 def extra_repr(self):
1426 return f"'{self.embeddings}'"
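# Usage sketch (illustration only): the path below is a placeholder for a local
# fastText '.bin' model; out-of-vocabulary words still receive vectors via
# fastText's subword information.
#
#   from flair.data import Sentence
#   from flair.embeddings import FastTextEmbeddings
#
#   fasttext = FastTextEmbeddings("/path/to/custom-fasttext-model.bin")
#   sentence = Sentence("A neverseenword appears here .")
#   fasttext.embed(sentence)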
1429class OneHotEmbeddings(TokenEmbeddings):
1430 """One-hot encoded embeddings. """
1432 def __init__(
1433 self,
1434 corpus: Corpus,
1435 field: str = "text",
1436 embedding_length: int = 300,
1437 min_freq: int = 3,
1438 ):
1439 """
1440 Initializes one-hot encoded word embeddings and a trainable embedding layer
1441 :param corpus: you need to pass a Corpus in order to construct the vocabulary
1442 :param field: by default, the 'text' of tokens is embedded, but you can also embed tags such as 'pos'
1443 :param embedding_length: dimensionality of the trainable embedding layer
1444 :param min_freq: minimum frequency of a word to become part of the vocabulary
1445 """
1446 super().__init__()
1447 self.name = "one-hot"
1448 self.static_embeddings = False
1449 self.min_freq = min_freq
1450 self.field = field
1451 self.instance_parameters = self.get_instance_parameters(locals=locals())
1453 tokens = list(map((lambda s: s.tokens), corpus.train))
1454 tokens = [token for sublist in tokens for token in sublist]
1456 if field == "text":
1457 most_common = Counter(list(map((lambda t: t.text), tokens))).most_common()
1458 else:
1459 most_common = Counter(
1460 list(map((lambda t: t.get_tag(field).value), tokens))
1461 ).most_common()
1463 tokens = []
1464 for token, freq in most_common:
1465 if freq < min_freq:
1466 break
1467 tokens.append(token)
1469 self.vocab_dictionary: Dictionary = Dictionary()
1470 for token in tokens:
1471 self.vocab_dictionary.add_item(token)
1473 # max_tokens = 500
1474 self.__embedding_length = embedding_length
1476 print(self.vocab_dictionary.idx2item)
1477 print(f"vocabulary size of {len(self.vocab_dictionary)}")
1479 # model architecture
1480 self.embedding_layer = torch.nn.Embedding(
1481 len(self.vocab_dictionary), self.__embedding_length
1482 )
1483 torch.nn.init.xavier_uniform_(self.embedding_layer.weight)
1485 self.to(flair.device)
1487 @property
1488 def embedding_length(self) -> int:
1489 return self.__embedding_length
1491 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
1493 one_hot_sentences = []
1494 for i, sentence in enumerate(sentences):
1496 if self.field == "text":
1497 context_idxs = [
1498 self.vocab_dictionary.get_idx_for_item(t.text)
1499 for t in sentence.tokens
1500 ]
1501 else:
1502 context_idxs = [
1503 self.vocab_dictionary.get_idx_for_item(t.get_tag(self.field).value)
1504 for t in sentence.tokens
1505 ]
1507 one_hot_sentences.extend(context_idxs)
1509 one_hot_sentences = torch.tensor(one_hot_sentences, dtype=torch.long).to(
1510 flair.device
1511 )
1513 embedded = self.embedding_layer.forward(one_hot_sentences)
1515 index = 0
1516 for sentence in sentences:
1517 for token in sentence:
1518 embedding = embedded[index]
1519 token.set_embedding(self.name, embedding)
1520 index += 1
1522 return sentences
1524 def __str__(self):
1525 return self.name
1527 def extra_repr(self):
1528 return "min_freq={}".format(self.min_freq)
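A minimal sketch of constructing OneHotEmbeddings from a corpus and embedding a sentence; the UD_ENGLISH corpus is only an example, any Corpus works:

from flair.data import Sentence
from flair.datasets import UD_ENGLISH  # any Corpus works; this particular dataset is an assumption
from flair.embeddings import OneHotEmbeddings

corpus = UD_ENGLISH()
embedding = OneHotEmbeddings(corpus, field="text", embedding_length=300, min_freq=3)

sentence = Sentence("The grass is green .")
embedding.embed(sentence)
# each token now carries a 300-dimensional vector from the trainable embedding layer;
# words rarer than min_freq are not in the vocabulary and fall back to the dictionary's default index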
1531class HashEmbeddings(TokenEmbeddings):
1532 """Standard embeddings with Hashing Trick."""
1534 def __init__(
1535 self, num_embeddings: int = 1000, embedding_length: int = 300, hash_method="md5"
1536 ):
1538 super().__init__()
1539 self.name = "hash"
1540 self.static_embeddings = False
1541 self.instance_parameters = self.get_instance_parameters(locals=locals())
1543 self.__num_embeddings = num_embeddings
1544 self.__embedding_length = embedding_length
1546 self.__hash_method = hash_method
1548 # model architecture
1549 self.embedding_layer = torch.nn.Embedding(
1550 self.__num_embeddings, self.__embedding_length
1551 )
1552 torch.nn.init.xavier_uniform_(self.embedding_layer.weight)
1554 self.to(flair.device)
1556 @property
1557 def num_embeddings(self) -> int:
1558 return self.__num_embeddings
1560 @property
1561 def embedding_length(self) -> int:
1562 return self.__embedding_length
1564 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
1565 def get_idx_for_item(text):
1566 hash_function = hashlib.new(self.__hash_method)
1567 hash_function.update(bytes(str(text), "utf-8"))
1568 return int(hash_function.hexdigest(), 16) % self.__num_embeddings
1570 hash_sentences = []
1571 for i, sentence in enumerate(sentences):
1572 context_idxs = [get_idx_for_item(t.text) for t in sentence.tokens]
1574 hash_sentences.extend(context_idxs)
1576 hash_sentences = torch.tensor(hash_sentences, dtype=torch.long).to(flair.device)
1578 embedded = self.embedding_layer.forward(hash_sentences)
1580 index = 0
1581 for sentence in sentences:
1582 for token in sentence:
1583 embedding = embedded[index]
1584 token.set_embedding(self.name, embedding)
1585 index += 1
1587 return sentences
1589 def __str__(self):
1590 return self.name
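The bucket computation is self-contained; a sketch of the hashing step and of end-to-end use (the import path is assumed):

import hashlib

from flair.data import Sentence
from flair.embeddings import HashEmbeddings  # import path is an assumption

def bucket_index(text: str, num_embeddings: int = 1000, hash_method: str = "md5") -> int:
    # mirrors get_idx_for_item above: hash the token text and fold it into a fixed range of buckets
    h = hashlib.new(hash_method)
    h.update(bytes(str(text), "utf-8"))
    return int(h.hexdigest(), 16) % num_embeddings

print(bucket_index("green"))  # deterministic bucket in [0, 1000); colliding words share one vector

embedding = HashEmbeddings(num_embeddings=1000, embedding_length=300)
sentence = Sentence("The grass is green .")
embedding.embed(sentence)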
1593class MuseCrosslingualEmbeddings(TokenEmbeddings):
1594 def __init__(self):
1595 self.name: str = "muse-crosslingual"
1596 self.static_embeddings = True
1597 self.__embedding_length: int = 300
1598 self.language_embeddings = {}
1599 super().__init__()
1601 @instance_lru_cache(maxsize=10000, typed=False)
1602 def get_cached_vec(self, language_code: str, word: str) -> torch.Tensor:
1603 current_embedding_model = self.language_embeddings[language_code]
1604 if word in current_embedding_model:
1605 word_embedding = current_embedding_model[word]
1606 elif word.lower() in current_embedding_model:
1607 word_embedding = current_embedding_model[word.lower()]
1608 elif re.sub(r"\d", "#", word.lower()) in current_embedding_model:
1609 word_embedding = current_embedding_model[re.sub(r"\d", "#", word.lower())]
1610 elif re.sub(r"\d", "0", word.lower()) in current_embedding_model:
1611 word_embedding = current_embedding_model[re.sub(r"\d", "0", word.lower())]
1612 else:
1613 word_embedding = np.zeros(self.embedding_length, dtype="float")
1614 word_embedding = torch.tensor(
1615 word_embedding, device=flair.device, dtype=torch.float
1616 )
1617 return word_embedding
1619 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
1621 for i, sentence in enumerate(sentences):
1623 language_code = sentence.get_language_code()
1624 supported = [
1625 "en",
1626 "de",
1627 "bg",
1628 "ca",
1629 "hr",
1630 "cs",
1631 "da",
1632 "nl",
1633 "et",
1634 "fi",
1635 "fr",
1636 "el",
1637 "he",
1638 "hu",
1639 "id",
1640 "it",
1641 "mk",
1642 "no",
1643 "pl",
1644 "pt",
1645 "ro",
1646 "ru",
1647 "sk",
1648 ]
1649 if language_code not in supported:
1650 language_code = "en"
1652 if language_code not in self.language_embeddings:
1653 log.info(f"Loading up MUSE embeddings for '{language_code}'!")
1654 # download if necessary
1655 hu_path: str = "https://flair.informatik.hu-berlin.de/resources/embeddings/muse"
1656 cache_dir = Path("embeddings") / "MUSE"
1657 cached_path(
1658 f"{hu_path}/muse.{language_code}.vec.gensim.vectors.npy",
1659 cache_dir=cache_dir,
1660 )
1661 embeddings_file = cached_path(
1662 f"{hu_path}/muse.{language_code}.vec.gensim", cache_dir=cache_dir
1663 )
1665 # load the model
1666 self.language_embeddings[
1667 language_code
1668 ] = gensim.models.KeyedVectors.load(str(embeddings_file))
1670 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
1672 if "field" not in self.__dict__ or self.field is None:
1673 word = token.text
1674 else:
1675 word = token.get_tag(self.field).value
1677 word_embedding = self.get_cached_vec(
1678 language_code=language_code, word=word
1679 )
1681 token.set_embedding(self.name, word_embedding)
1683 return sentences
1685 @property
1686 def embedding_length(self) -> int:
1687 return self.__embedding_length
1689 def __str__(self):
1690 return self.name
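A usage sketch: the MUSE vectors are loaded lazily per detected language, and unsupported languages fall back to English. Language detection goes through Sentence.get_language_code(), which may require an extra language-detection dependency (an assumption):

from flair.data import Sentence
from flair.embeddings import MuseCrosslingualEmbeddings

embedding = MuseCrosslingualEmbeddings()

sentence_en = Sentence("This red shirt is expensive .")
sentence_de = Sentence("Dieses rote Hemd ist teuer .")
embedding.embed([sentence_en, sentence_de])
# both sentences are embedded into the same aligned 300-dimensional MUSE space,
# so tokens from different languages share one vector space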
1693# TODO: keep for backwards compatibility, but remove in future
1694class BPEmbSerializable(BPEmb):
1695 def __getstate__(self):
1696 state = self.__dict__.copy()
1697 # save the sentencepiece model as a binary blob (not as a path, which may change)
1698 state["spm_model_binary"] = open(self.model_file, mode="rb").read()
1699 state["spm"] = None
1700 return state
1702 def __setstate__(self, state):
1703 from bpemb.util import sentencepiece_load
1705 model_file = self.model_tpl.format(lang=state["lang"], vs=state["vs"])
1706 self.__dict__ = state
1708 # write out the binary sentence piece model into the expected directory
1709 self.cache_dir: Path = flair.cache_root / "embeddings"
1710 if "spm_model_binary" in self.__dict__:
1711 # if the model was saved as binary and it is not found on disk, write to appropriate path
1712 if not os.path.exists(self.cache_dir / state["lang"]):
1713 os.makedirs(self.cache_dir / state["lang"])
1714 self.model_file = self.cache_dir / model_file
1715 with open(self.model_file, "wb") as out:
1716 out.write(self.__dict__["spm_model_binary"])
1717 else:
1718 # otherwise, use normal process and potentially trigger another download
1719 self.model_file = self._load_file(model_file)
1721 # once the model is there, load it with sentencepiece
1722 state["spm"] = sentencepiece_load(self.model_file)
1725class BytePairEmbeddings(TokenEmbeddings):
1726 def __init__(
1727 self,
1728 language: str = None,
1729 dim: int = 50,
1730 syllables: int = 100000,
1731 cache_dir=None,
1732 model_file_path: Path = None,
1733 embedding_file_path: Path = None,
1734 **kwargs,
1735 ):
1736 """
1737 Initializes byte-pair embeddings (BPEmb). The constructor downloads the required files if they are not present.
1738 """
1739 self.instance_parameters = self.get_instance_parameters(locals=locals())
1741 if not cache_dir:
1742 cache_dir = flair.cache_root / "embeddings"
1743 if language:
1744 self.name: str = f"bpe-{language}-{syllables}-{dim}"
1745 else:
1746 assert (
1747 model_file_path is not None and embedding_file_path is not None
1748 ), "Need to specify model_file_path and embedding_file_path if no language is given in BytePairEmbeddings(...)"
1749 dim = None
1751 self.embedder = BPEmbSerializable(
1752 lang=language,
1753 vs=syllables,
1754 dim=dim,
1755 cache_dir=cache_dir,
1756 model_file=model_file_path,
1757 emb_file=embedding_file_path,
1758 **kwargs,
1759 )
1761 if not language:
1762 self.name: str = f"bpe-custom-{self.embedder.vs}-{self.embedder.dim}"
1763 self.static_embeddings = True
1765 self.__embedding_length: int = self.embedder.emb.vector_size * 2
1766 super().__init__()
1768 @property
1769 def embedding_length(self) -> int:
1770 return self.__embedding_length
1772 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
1774 for i, sentence in enumerate(sentences):
1776 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
1778 if "field" not in self.__dict__ or self.field is None:
1779 word = token.text
1780 else:
1781 word = token.get_tag(self.field).value
1783 if word.strip() == "":
1784 # empty words get a zero embedding
1785 token.set_embedding(
1786 self.name, torch.zeros(self.embedding_length, dtype=torch.float)
1787 )
1788 else:
1789 # all other words get embedded
1790 embeddings = self.embedder.embed(word.lower())
1791 embedding = np.concatenate(
1792 (embeddings[0], embeddings[len(embeddings) - 1])
1793 )
1794 token.set_embedding(
1795 self.name, torch.tensor(embedding, dtype=torch.float)
1796 )
1798 return sentences
1800 def __str__(self):
1801 return self.name
1803 def extra_repr(self):
1804 return "model={}".format(self.name)
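A usage sketch; note that the per-token vector is twice the subword dimension because the first and last subword embeddings are concatenated in _add_embeddings_internal above:

from flair.data import Sentence
from flair.embeddings import BytePairEmbeddings

embedding = BytePairEmbeddings("en", dim=50, syllables=100000)
print(embedding.embedding_length)  # 100, i.e. 2 * dim

sentence = Sentence("The grass is green .")
embedding.embed(sentence)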
1807class ELMoEmbeddings(TokenEmbeddings):
1808 """Contextual word embeddings using word-level LM, as proposed in Peters et al., 2018.
1809 ELMo word vectors can be constructed by combining layers in different ways.
1810 Default is to concatenate the top 3 layers of the LM."""
1812 def __init__(
1813 self, model: str = "original", options_file: str = None, weight_file: str = None,
1814 embedding_mode: str = "all"
1815 ):
1816 super().__init__()
1818 self.instance_parameters = self.get_instance_parameters(locals=locals())
1820 try:
1821 import allennlp.commands.elmo
1822 except ModuleNotFoundError:
1823 log.warning("-" * 100)
1824 log.warning('ATTENTION! The library "allennlp" is not installed!')
1825 log.warning(
1826 'To use ELMoEmbeddings, please first install with "pip install allennlp==0.9.0"'
1827 )
1828 log.warning("-" * 100)
1829 pass
1831 assert embedding_mode in ["all", "top", "average"]
1833 self.name = f"elmo-{model}-{embedding_mode}"
1834 self.static_embeddings = True
1836 if not options_file or not weight_file:
1837 # the default model for ELMo is the 'original' model, which is very large
1838 options_file = allennlp.commands.elmo.DEFAULT_OPTIONS_FILE
1839 weight_file = allennlp.commands.elmo.DEFAULT_WEIGHT_FILE
1840 # alternatively, a small, medium or portuguese model can be selected by passing the appropriate model name
1841 if model == "small":
1842 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json"
1843 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5"
1844 if model == "medium":
1845 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_options.json"
1846 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"
1847 if model in ["large", "5.5B"]:
1848 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json"
1849 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5"
1850 if model == "pt" or model == "portuguese":
1851 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_options.json"
1852 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_weights.hdf5"
1853 if model == "pubmed":
1854 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_options.json"
1855 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_weights_PubMed_only.hdf5"
1857 if embedding_mode == "all":
1858 self.embedding_mode_fn = self.use_layers_all
1859 elif embedding_mode == "top":
1860 self.embedding_mode_fn = self.use_layers_top
1861 elif embedding_mode == "average":
1862 self.embedding_mode_fn = self.use_layers_average
1864 # put on CUDA if available
1865 from flair import device
1867 if re.fullmatch(r"cuda:[0-9]+", str(device)):
1868 cuda_device = int(str(device).split(":")[-1])
1869 elif str(device) == "cpu":
1870 cuda_device = -1
1871 else:
1872 cuda_device = 0
1874 self.ee = allennlp.commands.elmo.ElmoEmbedder(
1875 options_file=options_file, weight_file=weight_file, cuda_device=cuda_device
1876 )
1878 # embed a dummy sentence to determine embedding_length
1879 dummy_sentence: Sentence = Sentence()
1880 dummy_sentence.add_token(Token("hello"))
1881 embedded_dummy = self.embed(dummy_sentence)
1882 self.__embedding_length: int = len(
1883 embedded_dummy[0].get_token(1).get_embedding()
1884 )
1886 @property
1887 def embedding_length(self) -> int:
1888 return self.__embedding_length
1890 def use_layers_all(self, x):
1891 return torch.cat(x, 0)
1893 def use_layers_top(self, x):
1894 return x[-1]
1896 def use_layers_average(self, x):
1897 return torch.mean(torch.stack(x), 0)
1899 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
1900 # ELMoEmbeddings before Release 0.5 did not set self.embedding_mode_fn
1901 if not getattr(self, "embedding_mode_fn", None):
1902 self.embedding_mode_fn = self.use_layers_all
1904 sentence_words: List[List[str]] = []
1905 for sentence in sentences:
1906 sentence_words.append([token.text for token in sentence])
1908 embeddings = self.ee.embed_batch(sentence_words)
1910 for i, sentence in enumerate(sentences):
1912 sentence_embeddings = embeddings[i]
1914 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
1915 elmo_embedding_layers = [
1916 torch.FloatTensor(sentence_embeddings[0, token_idx, :]),
1917 torch.FloatTensor(sentence_embeddings[1, token_idx, :]),
1918 torch.FloatTensor(sentence_embeddings[2, token_idx, :])
1919 ]
1920 word_embedding = self.embedding_mode_fn(elmo_embedding_layers)
1921 token.set_embedding(self.name, word_embedding)
1923 return sentences
1925 def extra_repr(self):
1926 return "model={}".format(self.name)
1928 def __str__(self):
1929 return self.name
1931 def __setstate__(self, state):
1932 self.__dict__ = state
1934 if re.fullmatch(r"cuda:[0-9]+", str(flair.device)):
1935 cuda_device = int(str(flair.device).split(":")[-1])
1936 elif str(flair.device) == "cpu":
1937 cuda_device = -1
1938 else:
1939 cuda_device = 0
1941 self.ee.cuda_device = cuda_device
1943 self.ee.elmo_bilm.to(device=flair.device)
1944 self.ee.elmo_bilm._elmo_lstm._states = tuple(
1945 [state.to(flair.device) for state in self.ee.elmo_bilm._elmo_lstm._states])
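A usage sketch, assuming allennlp==0.9.0 is installed as the warning above requests; embedding_mode selects how the three biLM layers are combined:

# requires: pip install allennlp==0.9.0
from flair.data import Sentence
from flair.embeddings import ELMoEmbeddings

embedding = ELMoEmbeddings(model="small", embedding_mode="top")
sentence = Sentence("The grass is green .")
embedding.embed(sentence)
# "top" uses only the last layer, "all" concatenates all three layers,
# and "average" takes their element-wise mean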
1948class NILCEmbeddings(WordEmbeddings):
1949 def __init__(self, embeddings: str, model: str = "skip", size: int = 100):
1950 """
1951 Initializes Portuguese classic word embeddings trained by the NILC Lab (http://www.nilc.icmc.usp.br/embeddings).
1952 Constructor downloads required files if not there.
1953 :param embeddings: one of: 'fasttext', 'glove', 'wang2vec' or 'word2vec'
1954 :param model: one of: 'skip' or 'cbow'. This is not applicable to glove.
1955 :param size: one of: 50, 100, 300, 600 or 1000.
1956 """
1958 self.instance_parameters = self.get_instance_parameters(locals=locals())
1960 base_path = "http://143.107.183.175:22980/download.php?file=embeddings/"
1962 cache_dir = Path("embeddings") / embeddings.lower()
1964 # GLOVE embeddings
1965 if embeddings.lower() == "glove":
1966 cached_path(
1967 f"{base_path}{embeddings}/{embeddings}_s{size}.zip", cache_dir=cache_dir
1968 )
1969 embeddings = cached_path(
1970 f"{base_path}{embeddings}/{embeddings}_s{size}.zip", cache_dir=cache_dir
1971 )
1973 elif embeddings.lower() in ["fasttext", "wang2vec", "word2vec"]:
1974 cached_path(
1975 f"{base_path}{embeddings}/{model}_s{size}.zip", cache_dir=cache_dir
1976 )
1977 embeddings = cached_path(
1978 f"{base_path}{embeddings}/{model}_s{size}.zip", cache_dir=cache_dir
1979 )
1981 elif not Path(embeddings).exists():
1982 raise ValueError(
1983 f'The given embeddings "{embeddings}" is not available or is not a valid path.'
1984 )
1986 self.name: str = str(embeddings)
1987 self.static_embeddings = True
1989 log.info("Reading embeddings from %s" % embeddings)
1990 self.precomputed_word_embeddings = gensim.models.KeyedVectors.load_word2vec_format(
1991 open_inside_zip(str(embeddings), cache_dir=cache_dir)
1992 )
1994 self.__embedding_length: int = self.precomputed_word_embeddings.vector_size
1995 super(TokenEmbeddings, self).__init__()
1997 @property
1998 def embedding_length(self) -> int:
1999 return self.__embedding_length
2001 def __str__(self):
2002 return self.name
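A usage sketch for the Portuguese NILC embeddings (the export path from flair.embeddings is an assumption):

from flair.data import Sentence
from flair.embeddings import NILCEmbeddings  # export path is an assumption

embedding = NILCEmbeddings("word2vec", model="skip", size=100)
sentence = Sentence("O gato dorme no telhado .")
embedding.embed(sentence)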
2005def replace_with_language_code(string: str):
2006 string = string.replace("arabic-", "ar-")
2007 string = string.replace("basque-", "eu-")
2008 string = string.replace("bulgarian-", "bg-")
2009 string = string.replace("croatian-", "hr-")
2010 string = string.replace("czech-", "cs-")
2011 string = string.replace("danish-", "da-")
2012 string = string.replace("dutch-", "nl-")
2013 string = string.replace("farsi-", "fa-")
2014 string = string.replace("persian-", "fa-")
2015 string = string.replace("finnish-", "fi-")
2016 string = string.replace("french-", "fr-")
2017 string = string.replace("german-", "de-")
2018 string = string.replace("hebrew-", "he-")
2019 string = string.replace("hindi-", "hi-")
2020 string = string.replace("indonesian-", "id-")
2021 string = string.replace("italian-", "it-")
2022 string = string.replace("japanese-", "ja-")
2023 string = string.replace("norwegian-", "no")
2024 string = string.replace("polish-", "pl-")
2025 string = string.replace("portuguese-", "pt-")
2026 string = string.replace("slovenian-", "sl-")
2027 string = string.replace("spanish-", "es-")
2028 string = string.replace("swedish-", "sv-")
2029 return string
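A quick check of the helper above; language-name prefixes are rewritten to two-letter language codes:

print(replace_with_language_code("german-forward"))    # 'de-forward'
print(replace_with_language_code("persian-backward"))  # 'fa-backward'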