Coverage for flair/flair/embeddings/token.py: 14%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1068 statements  

1import hashlib 

2import inspect 

3import logging 

4import os 

5import re 

6from abc import abstractmethod 

7from collections import Counter 

8from pathlib import Path 

9from typing import List, Union, Dict, Optional 

10 

11import gensim 

12import numpy as np 

13import torch 

14from bpemb import BPEmb 

15from gensim.models import KeyedVectors 

16from torch import nn 

17from transformers import AutoTokenizer, AutoConfig, AutoModel, CONFIG_MAPPING, PreTrainedTokenizer, XLNetModel, \ 

18 TransfoXLModel 

19 

20import flair 

21from flair.data import Sentence, Token, Corpus, Dictionary 

22from flair.embeddings.base import Embeddings 

23from flair.file_utils import cached_path, open_inside_zip, instance_lru_cache 

24 

25log = logging.getLogger("flair") 

26 

27 

28class TokenEmbeddings(Embeddings): 

29 """Abstract base class for all token-level embeddings. Ever new type of word embedding must implement these methods.""" 

30 

31 @property 

32 @abstractmethod 

33 def embedding_length(self) -> int: 

34 """Returns the length of the embedding vector.""" 

35 pass 

36 

37 @property 

38 def embedding_type(self) -> str: 

39 return "word-level" 

40 

41 @staticmethod 

42 def get_instance_parameters(locals: dict) -> dict: 

43 class_definition = locals.get("__class__") 

44 instance_parameters = set(inspect.getfullargspec(class_definition.__init__).args) 

45 instance_parameters.difference_update(set(["self"])) 

46 instance_parameters.update(set(["__class__"])) 

47 instance_parameters = {class_attribute: attribute_value for class_attribute, attribute_value in locals.items() 

48 if class_attribute in instance_parameters} 

49 return instance_parameters 

50 

51 

52class StackedEmbeddings(TokenEmbeddings): 

53 """A stack of embeddings, used if you need to combine several different embedding types.""" 

54 

55 def __init__(self, embeddings: List[TokenEmbeddings]): 

56 """The constructor takes a list of embeddings to be combined.""" 

57 super().__init__() 

58 

59 self.embeddings = embeddings 

60 

61 # IMPORTANT: add embeddings as torch modules 

62 for i, embedding in enumerate(embeddings): 

63 embedding.name = f"{str(i)}-{embedding.name}" 

64 self.add_module(f"list_embedding_{str(i)}", embedding) 

65 

66 self.name: str = "Stack" 

67 self.static_embeddings: bool = True 

68 

69 self.__embedding_type: str = embeddings[0].embedding_type 

70 

71 self.__embedding_length: int = 0 

72 for embedding in embeddings: 

73 self.__embedding_length += embedding.embedding_length 

74 

75 def embed( 

76 self, sentences: Union[Sentence, List[Sentence]], static_embeddings: bool = True 

77 ): 

78 # if only one sentence is passed, convert to list of sentence 

79 if type(sentences) is Sentence: 

80 sentences = [sentences] 

81 

82 for embedding in self.embeddings: 

83 embedding.embed(sentences) 

84 

85 @property 

86 def embedding_type(self) -> str: 

87 return self.__embedding_type 

88 

89 @property 

90 def embedding_length(self) -> int: 

91 return self.__embedding_length 

92 

93 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

94 

95 for embedding in self.embeddings: 

96 embedding._add_embeddings_internal(sentences) 

97 

98 return sentences 

99 

100 def __str__(self): 

101 return f'StackedEmbeddings [{",".join([str(e) for e in self.embeddings])}]' 

102 

103 def get_names(self) -> List[str]: 

104 """Returns a list of embedding names. In most cases, it is just a list with one item, namely the name of 

105 this embedding. But in some cases, the embedding is made up by different embeddings (StackedEmbedding). 

106 Then, the list contains the names of all embeddings in the stack.""" 

107 names = [] 

108 for embedding in self.embeddings: 

109 names.extend(embedding.get_names()) 

110 return names 

111 

112 def get_named_embeddings_dict(self) -> Dict: 

113 

114 named_embeddings_dict = {} 

115 for embedding in self.embeddings: 

116 named_embeddings_dict.update(embedding.get_named_embeddings_dict()) 

117 

118 return named_embeddings_dict 

119 

120 

121class WordEmbeddings(TokenEmbeddings): 

122 """Standard static word embeddings, such as GloVe or FastText.""" 

123 

124 def __init__(self, embeddings: str, field: str = None, fine_tune: bool = False, force_cpu: bool = True, 

125 stable: bool = False): 

126 """ 

127 Initializes classic word embeddings. Constructor downloads required files if not there. 

128 :param embeddings: one of: 'glove', 'extvec', 'crawl' or two-letter language code or custom 

129 If you want to use a custom embedding file, just pass the path to the embeddings as embeddings variable. 

130 set stable=True to use the stable embeddings as described in https://arxiv.org/abs/2110.02861 

131 """ 

132 self.embeddings = embeddings 

133 

134 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

135 

136 if fine_tune and force_cpu and flair.device.type != "cpu": 

137 raise ValueError("Cannot train WordEmbeddings on cpu if the model is trained on gpu, set force_cpu=False") 

138 

139 hu_path: str = "https://flair.informatik.hu-berlin.de/resources/embeddings/token" 

140 

141 cache_dir = Path("embeddings") 

142 

143 # GLOVE embeddings 

144 if embeddings.lower() == "glove" or embeddings.lower() == "en-glove": 

145 cached_path(f"{hu_path}/glove.gensim.vectors.npy", cache_dir=cache_dir) 

146 embeddings = cached_path(f"{hu_path}/glove.gensim", cache_dir=cache_dir) 

147 

148 # TURIAN embeddings 

149 elif embeddings.lower() == "turian" or embeddings.lower() == "en-turian": 

150 cached_path(f"{hu_path}/turian.vectors.npy", cache_dir=cache_dir) 

151 embeddings = cached_path(f"{hu_path}/turian", cache_dir=cache_dir) 

152 

153 # KOMNINOS embeddings 

154 elif embeddings.lower() == "extvec" or embeddings.lower() == "en-extvec": 

155 cached_path(f"{hu_path}/extvec.gensim.vectors.npy", cache_dir=cache_dir) 

156 embeddings = cached_path(f"{hu_path}/extvec.gensim", cache_dir=cache_dir) 

157 

158 # pubmed embeddings 

159 elif embeddings.lower() == "pubmed" or embeddings.lower() == "en-pubmed": 

160 cached_path(f"{hu_path}/pubmed_pmc_wiki_sg_1M.gensim.vectors.npy", cache_dir=cache_dir) 

161 embeddings = cached_path(f"{hu_path}/pubmed_pmc_wiki_sg_1M.gensim", cache_dir=cache_dir) 

162 

163 # FT-CRAWL embeddings 

164 elif embeddings.lower() == "crawl" or embeddings.lower() == "en-crawl": 

165 cached_path(f"{hu_path}/en-fasttext-crawl-300d-1M.vectors.npy", cache_dir=cache_dir) 

166 embeddings = cached_path(f"{hu_path}/en-fasttext-crawl-300d-1M", cache_dir=cache_dir) 

167 

168 # FT-CRAWL embeddings 

169 elif embeddings.lower() in ["news", "en-news", "en"]: 

170 cached_path(f"{hu_path}/en-fasttext-news-300d-1M.vectors.npy", cache_dir=cache_dir) 

171 embeddings = cached_path(f"{hu_path}/en-fasttext-news-300d-1M", cache_dir=cache_dir) 

172 

173 # twitter embeddings 

174 elif embeddings.lower() in ["twitter", "en-twitter"]: 

175 cached_path(f"{hu_path}/twitter.gensim.vectors.npy", cache_dir=cache_dir) 

176 embeddings = cached_path(f"{hu_path}/twitter.gensim", cache_dir=cache_dir) 

177 

178 # two-letter language code wiki embeddings 

179 elif len(embeddings.lower()) == 2: 

180 cached_path(f"{hu_path}/{embeddings}-wiki-fasttext-300d-1M.vectors.npy", cache_dir=cache_dir) 

181 embeddings = cached_path(f"{hu_path}/{embeddings}-wiki-fasttext-300d-1M", cache_dir=cache_dir) 

182 

183 # two-letter language code wiki embeddings 

184 elif len(embeddings.lower()) == 7 and embeddings.endswith("-wiki"): 

185 cached_path(f"{hu_path}/{embeddings[:2]}-wiki-fasttext-300d-1M.vectors.npy", cache_dir=cache_dir) 

186 embeddings = cached_path(f"{hu_path}/{embeddings[:2]}-wiki-fasttext-300d-1M", cache_dir=cache_dir) 

187 

188 # two-letter language code crawl embeddings 

189 elif len(embeddings.lower()) == 8 and embeddings.endswith("-crawl"): 

190 cached_path(f"{hu_path}/{embeddings[:2]}-crawl-fasttext-300d-1M.vectors.npy", cache_dir=cache_dir) 

191 embeddings = cached_path(f"{hu_path}/{embeddings[:2]}-crawl-fasttext-300d-1M", cache_dir=cache_dir) 

192 

193 elif not Path(embeddings).exists(): 

194 raise ValueError( 

195 f'The given embeddings "{embeddings}" is not available or is not a valid path.' 

196 ) 

197 

198 self.name: str = str(embeddings) 

199 self.static_embeddings = not fine_tune 

200 self.fine_tune = fine_tune 

201 self.force_cpu = force_cpu 

202 self.field = field 

203 self.stable = stable 

204 super().__init__() 

205 

206 if str(embeddings).endswith(".bin"): 

207 precomputed_word_embeddings = gensim.models.KeyedVectors.load_word2vec_format( 

208 str(embeddings), binary=True 

209 ) 

210 else: 

211 precomputed_word_embeddings = gensim.models.KeyedVectors.load( 

212 str(embeddings) 

213 ) 

214 

215 self.__embedding_length: int = precomputed_word_embeddings.vector_size 

216 

217 vectors = np.row_stack( 

218 (precomputed_word_embeddings.vectors, np.zeros(self.__embedding_length, dtype="float")) 

219 ) 

220 self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(vectors), freeze=not fine_tune) 

221 

222 try: 

223 # gensim version 4 

224 self.vocab = precomputed_word_embeddings.key_to_index 

225 except: 

226 # gensim version 3 

227 self.vocab = {k: v.index for k, v in precomputed_word_embeddings.vocab.items()} 

228 

229 if stable: 

230 self.layer_norm = nn.LayerNorm(self.__embedding_length, elementwise_affine=fine_tune) 

231 else: 

232 self.layer_norm = None 

233 

234 self.device = None 

235 self.to(flair.device) 

236 

237 @property 

238 def embedding_length(self) -> int: 

239 return self.__embedding_length 

240 

241 @instance_lru_cache(maxsize=100000, typed=False) 

242 def get_cached_token_index(self, word: str) -> int: 

243 if word in self.vocab: 

244 return self.vocab[word] 

245 elif word.lower() in self.vocab: 

246 return self.vocab[word.lower()] 

247 elif re.sub(r"\d", "#", word.lower()) in self.vocab: 

248 return self.vocab[ 

249 re.sub(r"\d", "#", word.lower()) 

250 ] 

251 elif re.sub(r"\d", "0", word.lower()) in self.vocab: 

252 return self.vocab[ 

253 re.sub(r"\d", "0", word.lower()) 

254 ] 

255 else: 

256 return len(self.vocab) # <unk> token 

257 

258 def get_vec(self, word: str) -> torch.Tensor: 

259 word_embedding = self.vectors[self.get_cached_token_index(word)] 

260 

261 word_embedding = torch.tensor( 

262 word_embedding.tolist(), device=flair.device, dtype=torch.float 

263 ) 

264 return word_embedding 

265 

266 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

267 

268 tokens = [token for sentence in sentences for token in sentence.tokens] 

269 

270 word_indices: List[int] = [] 

271 for token in tokens: 

272 if "field" not in self.__dict__ or self.field is None: 

273 word = token.text 

274 else: 

275 word = token.get_tag(self.field).value 

276 word_indices.append(self.get_cached_token_index(word)) 

277 

278 embeddings = self.embedding(torch.tensor(word_indices, dtype=torch.long, device=self.device)) 

279 if self.stable: 

280 embeddings = self.layer_norm(embeddings) 

281 

282 if self.force_cpu: 

283 embeddings = embeddings.to(flair.device) 

284 

285 for emb, token in zip(embeddings, tokens): 

286 token.set_embedding(self.name, emb) 

287 

288 return sentences 

289 

290 def __str__(self): 

291 return self.name 

292 

293 def extra_repr(self): 

294 # fix serialized models 

295 if "embeddings" not in self.__dict__: 

296 self.embeddings = self.name 

297 

298 return f"'{self.embeddings}'" 

299 

300 def train(self, mode=True): 

301 if not self.fine_tune: 

302 pass 

303 else: 

304 super(WordEmbeddings, self).train(mode) 

305 

306 def to(self, device): 

307 if self.force_cpu: 

308 device = torch.device("cpu") 

309 self.device = device 

310 super(WordEmbeddings, self).to(device) 

311 

312 def _apply(self, fn): 

313 if fn.__name__ == "convert" and self.force_cpu: 

314 # this is required to force the module on the cpu, 

315 # if a parent module is put to gpu, the _apply is called to each sub_module 

316 # self.to(..) actually sets the device properly 

317 if not hasattr(self, "device"): 

318 self.to(flair.device) 

319 return 

320 super(WordEmbeddings, self)._apply(fn) 

321 

322 def __getattribute__(self, item): 

323 # this ignores the get_cached_vec method when loading older versions 

324 # it is needed for compatibility reasons 

325 if "get_cached_vec" == item: 

326 return None 

327 return super().__getattribute__(item) 

328 

329 def __setstate__(self, state): 

330 if "get_cached_vec" in state: 

331 del state["get_cached_vec"] 

332 if "force_cpu" not in state: 

333 state["force_cpu"] = True 

334 if "fine_tune" not in state: 

335 state["fine_tune"] = False 

336 if "precomputed_word_embeddings" in state: 

337 precomputed_word_embeddings: KeyedVectors = state.pop("precomputed_word_embeddings") 

338 vectors = np.row_stack( 

339 (precomputed_word_embeddings.vectors, np.zeros(precomputed_word_embeddings.vector_size, dtype="float")) 

340 ) 

341 embedding = nn.Embedding.from_pretrained(torch.FloatTensor(vectors), freeze=not state["fine_tune"]) 

342 

343 try: 

344 # gensim version 4 

345 vocab = precomputed_word_embeddings.key_to_index 

346 except: 

347 # gensim version 3 

348 vocab = {k: v.index for k, v in precomputed_word_embeddings.__dict__["vocab"].items()} 

349 state["embedding"] = embedding 

350 state["vocab"] = vocab 

351 if "stable" not in state: 

352 state["stable"] = False 

353 state["layer_norm"] = None 

354 

355 super().__setstate__(state) 

356 

357 

358class CharacterEmbeddings(TokenEmbeddings): 

359 """Character embeddings of words, as proposed in Lample et al., 2016.""" 

360 

361 def __init__( 

362 self, 

363 path_to_char_dict: str = None, 

364 char_embedding_dim: int = 25, 

365 hidden_size_char: int = 25, 

366 ): 

367 """Uses the default character dictionary if none provided.""" 

368 

369 super().__init__() 

370 self.name = "Char" 

371 self.static_embeddings = False 

372 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

373 

374 # use list of common characters if none provided 

375 if path_to_char_dict is None: 

376 self.char_dictionary: Dictionary = Dictionary.load("common-chars") 

377 else: 

378 self.char_dictionary: Dictionary = Dictionary.load_from_file(path_to_char_dict) 

379 

380 self.char_embedding_dim: int = char_embedding_dim 

381 self.hidden_size_char: int = hidden_size_char 

382 self.char_embedding = torch.nn.Embedding( 

383 len(self.char_dictionary.item2idx), self.char_embedding_dim 

384 ) 

385 self.char_rnn = torch.nn.LSTM( 

386 self.char_embedding_dim, 

387 self.hidden_size_char, 

388 num_layers=1, 

389 bidirectional=True, 

390 ) 

391 

392 self.__embedding_length = self.hidden_size_char * 2 

393 

394 self.to(flair.device) 

395 

396 @property 

397 def embedding_length(self) -> int: 

398 return self.__embedding_length 

399 

400 def _add_embeddings_internal(self, sentences: List[Sentence]): 

401 

402 for sentence in sentences: 

403 

404 tokens_char_indices = [] 

405 

406 # translate words in sentence into ints using dictionary 

407 for token in sentence.tokens: 

408 char_indices = [ 

409 self.char_dictionary.get_idx_for_item(char) for char in token.text 

410 ] 

411 tokens_char_indices.append(char_indices) 

412 

413 # sort words by length, for batching and masking 

414 tokens_sorted_by_length = sorted( 

415 tokens_char_indices, key=lambda p: len(p), reverse=True 

416 ) 

417 d = {} 

418 for i, ci in enumerate(tokens_char_indices): 

419 for j, cj in enumerate(tokens_sorted_by_length): 

420 if ci == cj: 

421 d[j] = i 

422 continue 

423 chars2_length = [len(c) for c in tokens_sorted_by_length] 

424 longest_token_in_sentence = max(chars2_length) 

425 tokens_mask = torch.zeros( 

426 (len(tokens_sorted_by_length), longest_token_in_sentence), 

427 dtype=torch.long, 

428 device=flair.device, 

429 ) 

430 

431 for i, c in enumerate(tokens_sorted_by_length): 

432 tokens_mask[i, : chars2_length[i]] = torch.tensor( 

433 c, dtype=torch.long, device=flair.device 

434 ) 

435 

436 # chars for rnn processing 

437 chars = tokens_mask 

438 

439 character_embeddings = self.char_embedding(chars).transpose(0, 1) 

440 

441 packed = torch.nn.utils.rnn.pack_padded_sequence( 

442 character_embeddings, chars2_length 

443 ) 

444 

445 lstm_out, self.hidden = self.char_rnn(packed) 

446 

447 outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_out) 

448 outputs = outputs.transpose(0, 1) 

449 chars_embeds_temp = torch.zeros( 

450 (outputs.size(0), outputs.size(2)), 

451 dtype=torch.float, 

452 device=flair.device, 

453 ) 

454 for i, index in enumerate(output_lengths): 

455 chars_embeds_temp[i] = outputs[i, index - 1] 

456 character_embeddings = chars_embeds_temp.clone() 

457 for i in range(character_embeddings.size(0)): 

458 character_embeddings[d[i]] = chars_embeds_temp[i] 

459 

460 for token_number, token in enumerate(sentence.tokens): 

461 token.set_embedding(self.name, character_embeddings[token_number]) 

462 

463 def __str__(self): 

464 return self.name 

465 

466 

467class FlairEmbeddings(TokenEmbeddings): 

468 """Contextual string embeddings of words, as proposed in Akbik et al., 2018.""" 

469 

470 def __init__(self, 

471 model, 

472 fine_tune: bool = False, 

473 chars_per_chunk: int = 512, 

474 with_whitespace: bool = True, 

475 tokenized_lm: bool = True, 

476 is_lower: bool = False, 

477 ): 

478 """ 

479 initializes contextual string embeddings using a character-level language model. 

480 :param model: model string, one of 'news-forward', 'news-backward', 'news-forward-fast', 'news-backward-fast', 

481 'mix-forward', 'mix-backward', 'german-forward', 'german-backward', 'polish-backward', 'polish-forward', 

482 etc (see https://github.com/flairNLP/flair/blob/master/resources/docs/embeddings/FLAIR_EMBEDDINGS.md) 

483 depending on which character language model is desired. 

484 :param fine_tune: if set to True, the gradient will propagate into the language model. This dramatically slows 

485 down training and often leads to overfitting, so use with caution. 

486 :param chars_per_chunk: max number of chars per rnn pass to control speed/memory tradeoff. Higher means faster 

487 but requires more memory. Lower means slower but less memory. 

488 :param with_whitespace: If True, use hidden state after whitespace after word. If False, use hidden 

489 state at last character of word. 

490 :param tokenized_lm: Whether this lm is tokenized. Default is True, but for LMs trained over unprocessed text 

491 False might be better. 

492 """ 

493 super().__init__() 

494 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

495 

496 cache_dir = Path("embeddings") 

497 

498 hu_path: str = "https://flair.informatik.hu-berlin.de/resources/embeddings/flair" 

499 clef_hipe_path: str = "https://files.ifi.uzh.ch/cl/siclemat/impresso/clef-hipe-2020/flair" 

500 am_path : str = "http://ltdata1.informatik.uni-hamburg.de/amharic/models/flair/" 

501 

502 self.is_lower: bool = is_lower 

503 

504 self.PRETRAINED_MODEL_ARCHIVE_MAP = { 

505 # multilingual models 

506 "multi-forward": f"{hu_path}/lm-jw300-forward-v0.1.pt", 

507 "multi-backward": f"{hu_path}/lm-jw300-backward-v0.1.pt", 

508 "multi-v0-forward": f"{hu_path}/lm-multi-forward-v0.1.pt", 

509 "multi-v0-backward": f"{hu_path}/lm-multi-backward-v0.1.pt", 

510 "multi-forward-fast": f"{hu_path}/lm-multi-forward-fast-v0.1.pt", 

511 "multi-backward-fast": f"{hu_path}/lm-multi-backward-fast-v0.1.pt", 

512 # English models 

513 "en-forward": f"{hu_path}/news-forward-0.4.1.pt", 

514 "en-backward": f"{hu_path}/news-backward-0.4.1.pt", 

515 "en-forward-fast": f"{hu_path}/lm-news-english-forward-1024-v0.2rc.pt", 

516 "en-backward-fast": f"{hu_path}/lm-news-english-backward-1024-v0.2rc.pt", 

517 "news-forward": f"{hu_path}/news-forward-0.4.1.pt", 

518 "news-backward": f"{hu_path}/news-backward-0.4.1.pt", 

519 "news-forward-fast": f"{hu_path}/lm-news-english-forward-1024-v0.2rc.pt", 

520 "news-backward-fast": f"{hu_path}/lm-news-english-backward-1024-v0.2rc.pt", 

521 "mix-forward": f"{hu_path}/lm-mix-english-forward-v0.2rc.pt", 

522 "mix-backward": f"{hu_path}/lm-mix-english-backward-v0.2rc.pt", 

523 # Arabic 

524 "ar-forward": f"{hu_path}/lm-ar-opus-large-forward-v0.1.pt", 

525 "ar-backward": f"{hu_path}/lm-ar-opus-large-backward-v0.1.pt", 

526 # Bulgarian 

527 "bg-forward-fast": f"{hu_path}/lm-bg-small-forward-v0.1.pt", 

528 "bg-backward-fast": f"{hu_path}/lm-bg-small-backward-v0.1.pt", 

529 "bg-forward": f"{hu_path}/lm-bg-opus-large-forward-v0.1.pt", 

530 "bg-backward": f"{hu_path}/lm-bg-opus-large-backward-v0.1.pt", 

531 # Czech 

532 "cs-forward": f"{hu_path}/lm-cs-opus-large-forward-v0.1.pt", 

533 "cs-backward": f"{hu_path}/lm-cs-opus-large-backward-v0.1.pt", 

534 "cs-v0-forward": f"{hu_path}/lm-cs-large-forward-v0.1.pt", 

535 "cs-v0-backward": f"{hu_path}/lm-cs-large-backward-v0.1.pt", 

536 # Danish 

537 "da-forward": f"{hu_path}/lm-da-opus-large-forward-v0.1.pt", 

538 "da-backward": f"{hu_path}/lm-da-opus-large-backward-v0.1.pt", 

539 # German 

540 "de-forward": f"{hu_path}/lm-mix-german-forward-v0.2rc.pt", 

541 "de-backward": f"{hu_path}/lm-mix-german-backward-v0.2rc.pt", 

542 "de-historic-ha-forward": f"{hu_path}/lm-historic-hamburger-anzeiger-forward-v0.1.pt", 

543 "de-historic-ha-backward": f"{hu_path}/lm-historic-hamburger-anzeiger-backward-v0.1.pt", 

544 "de-historic-wz-forward": f"{hu_path}/lm-historic-wiener-zeitung-forward-v0.1.pt", 

545 "de-historic-wz-backward": f"{hu_path}/lm-historic-wiener-zeitung-backward-v0.1.pt", 

546 "de-historic-rw-forward": f"{hu_path}/redewiedergabe_lm_forward.pt", 

547 "de-historic-rw-backward": f"{hu_path}/redewiedergabe_lm_backward.pt", 

548 # Spanish 

549 "es-forward": f"{hu_path}/lm-es-forward.pt", 

550 "es-backward": f"{hu_path}/lm-es-backward.pt", 

551 "es-forward-fast": f"{hu_path}/lm-es-forward-fast.pt", 

552 "es-backward-fast": f"{hu_path}/lm-es-backward-fast.pt", 

553 # Basque 

554 "eu-forward": f"{hu_path}/lm-eu-opus-large-forward-v0.2.pt", 

555 "eu-backward": f"{hu_path}/lm-eu-opus-large-backward-v0.2.pt", 

556 "eu-v1-forward": f"{hu_path}/lm-eu-opus-large-forward-v0.1.pt", 

557 "eu-v1-backward": f"{hu_path}/lm-eu-opus-large-backward-v0.1.pt", 

558 "eu-v0-forward": f"{hu_path}/lm-eu-large-forward-v0.1.pt", 

559 "eu-v0-backward": f"{hu_path}/lm-eu-large-backward-v0.1.pt", 

560 # Persian 

561 "fa-forward": f"{hu_path}/lm-fa-opus-large-forward-v0.1.pt", 

562 "fa-backward": f"{hu_path}/lm-fa-opus-large-backward-v0.1.pt", 

563 # Finnish 

564 "fi-forward": f"{hu_path}/lm-fi-opus-large-forward-v0.1.pt", 

565 "fi-backward": f"{hu_path}/lm-fi-opus-large-backward-v0.1.pt", 

566 # French 

567 "fr-forward": f"{hu_path}/lm-fr-charlm-forward.pt", 

568 "fr-backward": f"{hu_path}/lm-fr-charlm-backward.pt", 

569 # Hebrew 

570 "he-forward": f"{hu_path}/lm-he-opus-large-forward-v0.1.pt", 

571 "he-backward": f"{hu_path}/lm-he-opus-large-backward-v0.1.pt", 

572 # Hindi 

573 "hi-forward": f"{hu_path}/lm-hi-opus-large-forward-v0.1.pt", 

574 "hi-backward": f"{hu_path}/lm-hi-opus-large-backward-v0.1.pt", 

575 # Croatian 

576 "hr-forward": f"{hu_path}/lm-hr-opus-large-forward-v0.1.pt", 

577 "hr-backward": f"{hu_path}/lm-hr-opus-large-backward-v0.1.pt", 

578 # Indonesian 

579 "id-forward": f"{hu_path}/lm-id-opus-large-forward-v0.1.pt", 

580 "id-backward": f"{hu_path}/lm-id-opus-large-backward-v0.1.pt", 

581 # Italian 

582 "it-forward": f"{hu_path}/lm-it-opus-large-forward-v0.1.pt", 

583 "it-backward": f"{hu_path}/lm-it-opus-large-backward-v0.1.pt", 

584 # Japanese 

585 "ja-forward": f"{hu_path}/japanese-forward.pt", 

586 "ja-backward": f"{hu_path}/japanese-backward.pt", 

587 # Malayalam 

588 "ml-forward": f"https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/ml-forward.pt", 

589 "ml-backward": f"https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/ml-backward.pt", 

590 # Dutch 

591 "nl-forward": f"{hu_path}/lm-nl-opus-large-forward-v0.1.pt", 

592 "nl-backward": f"{hu_path}/lm-nl-opus-large-backward-v0.1.pt", 

593 "nl-v0-forward": f"{hu_path}/lm-nl-large-forward-v0.1.pt", 

594 "nl-v0-backward": f"{hu_path}/lm-nl-large-backward-v0.1.pt", 

595 # Norwegian 

596 "no-forward": f"{hu_path}/lm-no-opus-large-forward-v0.1.pt", 

597 "no-backward": f"{hu_path}/lm-no-opus-large-backward-v0.1.pt", 

598 # Polish 

599 "pl-forward": f"{hu_path}/lm-polish-forward-v0.2.pt", 

600 "pl-backward": f"{hu_path}/lm-polish-backward-v0.2.pt", 

601 "pl-opus-forward": f"{hu_path}/lm-pl-opus-large-forward-v0.1.pt", 

602 "pl-opus-backward": f"{hu_path}/lm-pl-opus-large-backward-v0.1.pt", 

603 # Portuguese 

604 "pt-forward": f"{hu_path}/lm-pt-forward.pt", 

605 "pt-backward": f"{hu_path}/lm-pt-backward.pt", 

606 # Pubmed 

607 "pubmed-forward": f"{hu_path}/pubmed-forward.pt", 

608 "pubmed-backward": f"{hu_path}/pubmed-backward.pt", 

609 "pubmed-2015-forward": f"{hu_path}/pubmed-2015-fw-lm.pt", 

610 "pubmed-2015-backward": f"{hu_path}/pubmed-2015-bw-lm.pt", 

611 # Slovenian 

612 "sl-forward": f"{hu_path}/lm-sl-opus-large-forward-v0.1.pt", 

613 "sl-backward": f"{hu_path}/lm-sl-opus-large-backward-v0.1.pt", 

614 "sl-v0-forward": f"{hu_path}/lm-sl-large-forward-v0.1.pt", 

615 "sl-v0-backward": f"{hu_path}/lm-sl-large-backward-v0.1.pt", 

616 # Swedish 

617 "sv-forward": f"{hu_path}/lm-sv-opus-large-forward-v0.1.pt", 

618 "sv-backward": f"{hu_path}/lm-sv-opus-large-backward-v0.1.pt", 

619 "sv-v0-forward": f"{hu_path}/lm-sv-large-forward-v0.1.pt", 

620 "sv-v0-backward": f"{hu_path}/lm-sv-large-backward-v0.1.pt", 

621 # Tamil 

622 "ta-forward": f"{hu_path}/lm-ta-opus-large-forward-v0.1.pt", 

623 "ta-backward": f"{hu_path}/lm-ta-opus-large-backward-v0.1.pt", 

624 # Spanish clinical 

625 "es-clinical-forward": f"{hu_path}/es-clinical-forward.pt", 

626 "es-clinical-backward": f"{hu_path}/es-clinical-backward.pt", 

627 # CLEF HIPE Shared task 

628 "de-impresso-hipe-v1-forward": f"{clef_hipe_path}/de-hipe-flair-v1-forward/best-lm.pt", 

629 "de-impresso-hipe-v1-backward": f"{clef_hipe_path}/de-hipe-flair-v1-backward/best-lm.pt", 

630 "en-impresso-hipe-v1-forward": f"{clef_hipe_path}/en-flair-v1-forward/best-lm.pt", 

631 "en-impresso-hipe-v1-backward": f"{clef_hipe_path}/en-flair-v1-backward/best-lm.pt", 

632 "fr-impresso-hipe-v1-forward": f"{clef_hipe_path}/fr-hipe-flair-v1-forward/best-lm.pt", 

633 "fr-impresso-hipe-v1-backward": f"{clef_hipe_path}/fr-hipe-flair-v1-backward/best-lm.pt", 

634 # Amharic  

635 "am-forward": f"{am_path}/best-lm.pt", 

636 } 

637 

638 if type(model) == str: 

639 

640 # load model if in pretrained model map 

641 if model.lower() in self.PRETRAINED_MODEL_ARCHIVE_MAP: 

642 base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[model.lower()] 

643 

644 # Fix for CLEF HIPE models (avoid overwriting best-lm.pt in cache_dir) 

645 if "impresso-hipe" in model.lower(): 

646 cache_dir = cache_dir / model.lower() 

647 # CLEF HIPE models are lowercased 

648 self.is_lower = True 

649 model = cached_path(base_path, cache_dir=cache_dir) 

650 

651 elif replace_with_language_code(model) in self.PRETRAINED_MODEL_ARCHIVE_MAP: 

652 base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[ 

653 replace_with_language_code(model) 

654 ] 

655 model = cached_path(base_path, cache_dir=cache_dir) 

656 

657 elif not Path(model).exists(): 

658 raise ValueError( 

659 f'The given model "{model}" is not available or is not a valid path.' 

660 ) 

661 

662 from flair.models import LanguageModel 

663 

664 if type(model) == LanguageModel: 

665 self.lm: LanguageModel = model 

666 self.name = f"Task-LSTM-{self.lm.hidden_size}-{self.lm.nlayers}-{self.lm.is_forward_lm}" 

667 else: 

668 self.lm: LanguageModel = LanguageModel.load_language_model(model) 

669 self.name = str(model) 

670 

671 # embeddings are static if we don't do finetuning 

672 self.fine_tune = fine_tune 

673 self.static_embeddings = not fine_tune 

674 

675 self.is_forward_lm: bool = self.lm.is_forward_lm 

676 self.with_whitespace: bool = with_whitespace 

677 self.tokenized_lm: bool = tokenized_lm 

678 self.chars_per_chunk: int = chars_per_chunk 

679 

680 # embed a dummy sentence to determine embedding_length 

681 dummy_sentence: Sentence = Sentence() 

682 dummy_sentence.add_token(Token("hello")) 

683 embedded_dummy = self.embed(dummy_sentence) 

684 self.__embedding_length: int = len( 

685 embedded_dummy[0].get_token(1).get_embedding() 

686 ) 

687 

688 # set to eval mode 

689 self.eval() 

690 

691 def train(self, mode=True): 

692 

693 # make compatible with serialized models (TODO: remove) 

694 if "fine_tune" not in self.__dict__: 

695 self.fine_tune = False 

696 if "chars_per_chunk" not in self.__dict__: 

697 self.chars_per_chunk = 512 

698 

699 # unless fine-tuning is set, do not set language model to train() in order to disallow language model dropout 

700 if not self.fine_tune: 

701 pass 

702 else: 

703 super(FlairEmbeddings, self).train(mode) 

704 

705 @property 

706 def embedding_length(self) -> int: 

707 return self.__embedding_length 

708 

709 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

710 

711 # make compatible with serialized models (TODO: remove) 

712 if "with_whitespace" not in self.__dict__: 

713 self.with_whitespace = True 

714 if "tokenized_lm" not in self.__dict__: 

715 self.tokenized_lm = True 

716 if "is_lower" not in self.__dict__: 

717 self.is_lower = False 

718 

719 # gradients are enable if fine-tuning is enabled 

720 gradient_context = torch.enable_grad() if self.fine_tune else torch.no_grad() 

721 

722 with gradient_context: 

723 

724 # if this is not possible, use LM to generate embedding. First, get text sentences 

725 text_sentences = [sentence.to_tokenized_string() for sentence in sentences] if self.tokenized_lm \ 

726 else [sentence.to_plain_string() for sentence in sentences] 

727 

728 if self.is_lower: 

729 text_sentences = [sentence.lower() for sentence in text_sentences] 

730 

731 start_marker = self.lm.document_delimiter if "document_delimiter" in self.lm.__dict__ else '\n' 

732 end_marker = " " 

733 

734 # get hidden states from language model 

735 all_hidden_states_in_lm = self.lm.get_representation( 

736 text_sentences, start_marker, end_marker, self.chars_per_chunk 

737 ) 

738 

739 if not self.fine_tune: 

740 all_hidden_states_in_lm = all_hidden_states_in_lm.detach() 

741 

742 # take first or last hidden states from language model as word representation 

743 for i, sentence in enumerate(sentences): 

744 sentence_text = sentence.to_tokenized_string() if self.tokenized_lm else sentence.to_plain_string() 

745 

746 offset_forward: int = len(start_marker) 

747 offset_backward: int = len(sentence_text) + len(start_marker) 

748 

749 for token in sentence.tokens: 

750 

751 offset_forward += len(token.text) 

752 if self.is_forward_lm: 

753 offset_with_whitespace = offset_forward 

754 offset_without_whitespace = offset_forward - 1 

755 else: 

756 offset_with_whitespace = offset_backward 

757 offset_without_whitespace = offset_backward - 1 

758 

759 # offset mode that extracts at whitespace after last character 

760 if self.with_whitespace: 

761 embedding = all_hidden_states_in_lm[offset_with_whitespace, i, :] 

762 # offset mode that extracts at last character 

763 else: 

764 embedding = all_hidden_states_in_lm[offset_without_whitespace, i, :] 

765 

766 if self.tokenized_lm or token.whitespace_after: 

767 offset_forward += 1 

768 offset_backward -= 1 

769 

770 offset_backward -= len(token.text) 

771 

772 # only clone if optimization mode is 'gpu' 

773 if flair.embedding_storage_mode == "gpu": 

774 embedding = embedding.clone() 

775 

776 token.set_embedding(self.name, embedding) 

777 

778 del all_hidden_states_in_lm 

779 

780 return sentences 

781 

782 def __str__(self): 

783 return self.name 

784 

785 

786class PooledFlairEmbeddings(TokenEmbeddings): 

787 def __init__( 

788 self, 

789 contextual_embeddings: Union[str, FlairEmbeddings], 

790 pooling: str = "min", 

791 only_capitalized: bool = False, 

792 **kwargs, 

793 ): 

794 

795 super().__init__() 

796 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

797 

798 # use the character language model embeddings as basis 

799 if type(contextual_embeddings) is str: 

800 self.context_embeddings: FlairEmbeddings = FlairEmbeddings( 

801 contextual_embeddings, **kwargs 

802 ) 

803 else: 

804 self.context_embeddings: FlairEmbeddings = contextual_embeddings 

805 

806 # length is twice the original character LM embedding length 

807 self.embedding_length = self.context_embeddings.embedding_length * 2 

808 self.name = self.context_embeddings.name + "-context" 

809 

810 # these fields are for the embedding memory 

811 self.word_embeddings = {} 

812 self.word_count = {} 

813 

814 # whether to add only capitalized words to memory (faster runtime and lower memory consumption) 

815 self.only_capitalized = only_capitalized 

816 

817 # we re-compute embeddings dynamically at each epoch 

818 self.static_embeddings = False 

819 

820 # set the memory method 

821 self.pooling = pooling 

822 

823 def train(self, mode=True): 

824 super().train(mode=mode) 

825 if mode: 

826 # memory is wiped each time we do a training run 

827 print("train mode resetting embeddings") 

828 self.word_embeddings = {} 

829 self.word_count = {} 

830 

831 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

832 

833 self.context_embeddings.embed(sentences) 

834 

835 # if we keep a pooling, it needs to be updated continuously 

836 for sentence in sentences: 

837 for token in sentence.tokens: 

838 

839 # update embedding 

840 local_embedding = token._embeddings[self.context_embeddings.name].cpu() 

841 

842 # check token.text is empty or not 

843 if token.text: 

844 if token.text[0].isupper() or not self.only_capitalized: 

845 

846 if token.text not in self.word_embeddings: 

847 self.word_embeddings[token.text] = local_embedding 

848 self.word_count[token.text] = 1 

849 else: 

850 

851 # set aggregation operation 

852 if self.pooling == "mean": 

853 aggregated_embedding = torch.add(self.word_embeddings[token.text], local_embedding) 

854 elif self.pooling == "fade": 

855 aggregated_embedding = torch.add(self.word_embeddings[token.text], local_embedding) 

856 aggregated_embedding /= 2 

857 elif self.pooling == "max": 

858 aggregated_embedding = torch.max(self.word_embeddings[token.text], local_embedding) 

859 elif self.pooling == "min": 

860 aggregated_embedding = torch.min(self.word_embeddings[token.text], local_embedding) 

861 

862 self.word_embeddings[token.text] = aggregated_embedding 

863 self.word_count[token.text] += 1 

864 

865 # add embeddings after updating 

866 for sentence in sentences: 

867 for token in sentence.tokens: 

868 if token.text in self.word_embeddings: 

869 base = ( 

870 self.word_embeddings[token.text] / self.word_count[token.text] 

871 if self.pooling == "mean" 

872 else self.word_embeddings[token.text] 

873 ) 

874 else: 

875 base = token._embeddings[self.context_embeddings.name] 

876 

877 token.set_embedding(self.name, base) 

878 

879 return sentences 

880 

881 def embedding_length(self) -> int: 

882 return self.embedding_length 

883 

884 def get_names(self) -> List[str]: 

885 return [self.name, self.context_embeddings.name] 

886 

887 def __setstate__(self, d): 

888 self.__dict__ = d 

889 

890 if flair.device != 'cpu': 

891 for key in self.word_embeddings: 

892 self.word_embeddings[key] = self.word_embeddings[key].cpu() 

893 

894 

895class TransformerWordEmbeddings(TokenEmbeddings): 

896 NO_MAX_SEQ_LENGTH_MODELS = [XLNetModel, TransfoXLModel] 

897 

898 def __init__( 

899 self, 

900 model: str = "bert-base-uncased", 

901 layers: str = "all", 

902 subtoken_pooling: str = "first", 

903 layer_mean: bool = True, 

904 fine_tune: bool = False, 

905 allow_long_sentences: bool = True, 

906 use_context: Union[bool, int] = False, 

907 memory_effective_training: bool = True, 

908 respect_document_boundaries: bool = True, 

909 context_dropout: float = 0.5, 

910 **kwargs 

911 ): 

912 """ 

913 Bidirectional transformer embeddings of words from various transformer architectures. 

914 :param model: name of transformer model (see https://huggingface.co/transformers/pretrained_models.html for 

915 options) 

916 :param layers: string indicating which layers to take for embedding (-1 is topmost layer) 

917 :param subtoken_pooling: how to get from token piece embeddings to token embedding. Either take the first 

918 subtoken ('first'), the last subtoken ('last'), both first and last ('first_last') or a mean over all ('mean') 

919 :param layer_mean: If True, uses a scalar mix of layers as embedding 

920 :param fine_tune: If True, allows transformers to be fine-tuned during training 

921 """ 

922 super().__init__() 

923 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

924 

925 # temporary fix to disable tokenizer parallelism warning 

926 # (see https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning) 

927 import os 

928 os.environ["TOKENIZERS_PARALLELISM"] = "false" 

929 

930 # do not print transformer warnings as these are confusing in this case 

931 from transformers import logging 

932 logging.set_verbosity_error() 

933 

934 # load tokenizer and transformer model 

935 self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model, **kwargs) 

936 if self.tokenizer.model_max_length > 1000000000: 

937 self.tokenizer.model_max_length = 512 

938 log.info("No model_max_length in Tokenizer's config.json - setting it to 512. " 

939 "Specify desired model_max_length by passing it as attribute to embedding instance.") 

940 if not 'config' in kwargs: 

941 config = AutoConfig.from_pretrained(model, output_hidden_states=True, **kwargs) 

942 self.model = AutoModel.from_pretrained(model, config=config) 

943 else: 

944 self.model = AutoModel.from_pretrained(None, **kwargs) 

945 

946 logging.set_verbosity_warning() 

947 

948 if type(self.model) not in self.NO_MAX_SEQ_LENGTH_MODELS: 

949 self.allow_long_sentences = allow_long_sentences 

950 self.truncate = True 

951 self.max_subtokens_sequence_length = self.tokenizer.model_max_length 

952 self.stride = self.tokenizer.model_max_length // 2 if allow_long_sentences else 0 

953 else: 

954 # in the end, these models don't need this configuration 

955 self.allow_long_sentences = False 

956 self.truncate = False 

957 self.max_subtokens_sequence_length = None 

958 self.stride = 0 

959 

960 self.use_lang_emb = hasattr(self.model, "use_lang_emb") and self.model.use_lang_emb 

961 

962 # model name 

963 self.name = 'transformer-word-' + str(model) 

964 self.base_model = str(model) 

965 

966 # whether to detach gradients on overlong sentences 

967 self.memory_effective_training = memory_effective_training 

968 

969 # store whether to use context (and how much) 

970 if type(use_context) == bool: 

971 self.context_length: int = 64 if use_context else 0 

972 if type(use_context) == int: 

973 self.context_length: int = use_context 

974 

975 # dropout contexts 

976 self.context_dropout = context_dropout 

977 

978 # if using context, can we cross document boundaries? 

979 self.respect_document_boundaries = respect_document_boundaries 

980 

981 # send self to flair-device 

982 self.to(flair.device) 

983 

984 # embedding parameters 

985 if layers == 'all': 

986 # send mini-token through to check how many layers the model has 

987 hidden_states = self.model(torch.tensor([1], device=flair.device).unsqueeze(0))[-1] 

988 self.layer_indexes = [int(x) for x in range(len(hidden_states))] 

989 else: 

990 self.layer_indexes = [int(x) for x in layers.split(",")] 

991 

992 self.pooling_operation = subtoken_pooling 

993 self.layer_mean = layer_mean 

994 self.fine_tune = fine_tune 

995 self.static_embeddings = not self.fine_tune 

996 

997 # calculate embedding length 

998 if not self.layer_mean: 

999 length = len(self.layer_indexes) * self.model.config.hidden_size 

1000 else: 

1001 length = self.model.config.hidden_size 

1002 if self.pooling_operation == 'first_last': length *= 2 

1003 

1004 # return length 

1005 self.embedding_length_internal = length 

1006 

1007 self.special_tokens = [] 

1008 # check if special tokens exist to circumvent error message 

1009 if self.tokenizer._bos_token: 

1010 self.special_tokens.append(self.tokenizer.bos_token) 

1011 if self.tokenizer._cls_token: 

1012 self.special_tokens.append(self.tokenizer.cls_token) 

1013 

1014 # most models have an intial BOS token, except for XLNet, T5 and GPT2 

1015 self.begin_offset = self._get_begin_offset_of_tokenizer(tokenizer=self.tokenizer) 

1016 

1017 # when initializing, embeddings are in eval mode by default 

1018 self.eval() 

1019 

1020 @staticmethod 

1021 def _get_begin_offset_of_tokenizer(tokenizer: PreTrainedTokenizer) -> int: 

1022 test_string = 'a' 

1023 tokens = tokenizer.encode(test_string) 

1024 

1025 for begin_offset, token in enumerate(tokens): 

1026 if tokenizer.decode([token]) == test_string or tokenizer.decode([token]) == tokenizer.unk_token: 

1027 break 

1028 return begin_offset 

1029 

1030 @staticmethod 

1031 def _remove_special_markup(text: str): 

1032 # remove special markup 

1033 text = re.sub('^Ġ', '', text) # RoBERTa models 

1034 text = re.sub('^##', '', text) # BERT models 

1035 text = re.sub('^▁', '', text) # XLNet models 

1036 text = re.sub('</w>$', '', text) # XLM models 

1037 return text 

1038 

1039 def _get_processed_token_text(self, token: Token) -> str: 

1040 pieces = self.tokenizer.tokenize(token.text) 

1041 token_text = '' 

1042 for piece in pieces: 

1043 token_text += self._remove_special_markup(piece) 

1044 token_text = token_text.lower() 

1045 return token_text 

1046 

1047 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

1048 

1049 # we require encoded subtokenized sentences, the mapping to original tokens and the number of 

1050 # parts that each sentence produces 

1051 subtokenized_sentences = [] 

1052 all_token_subtoken_lengths = [] 

1053 

1054 # if we also use context, first expand sentence to include context 

1055 if self.context_length > 0: 

1056 

1057 # set context if not set already 

1058 previous_sentence = None 

1059 for sentence in sentences: 

1060 if sentence.is_context_set(): continue 

1061 sentence._previous_sentence = previous_sentence 

1062 sentence._next_sentence = None 

1063 if previous_sentence: previous_sentence._next_sentence = sentence 

1064 previous_sentence = sentence 

1065 

1066 original_sentences = [] 

1067 expanded_sentences = [] 

1068 context_offsets = [] 

1069 

1070 for sentence in sentences: 

1071 # in case of contextualization, we must remember non-expanded sentence 

1072 original_sentence = sentence 

1073 original_sentences.append(original_sentence) 

1074 

1075 # create expanded sentence and remember context offsets 

1076 expanded_sentence, context_offset = self._expand_sentence_with_context(sentence) 

1077 expanded_sentences.append(expanded_sentence) 

1078 context_offsets.append(context_offset) 

1079 

1080 # overwrite sentence with expanded sentence 

1081 sentence = expanded_sentence 

1082 

1083 sentences = expanded_sentences 

1084 

1085 tokenized_sentences = [] 

1086 for sentence in sentences: 

1087 

1088 # subtokenize the sentence 

1089 tokenized_string = sentence.to_tokenized_string() 

1090 

1091 # transformer specific tokenization 

1092 subtokenized_sentence = self.tokenizer.tokenize(tokenized_string) 

1093 

1094 # set zero embeddings for empty sentences and exclude 

1095 if len(subtokenized_sentence) == 0: 

1096 for token in sentence: 

1097 token.set_embedding(self.name, torch.zeros(self.embedding_length)) 

1098 continue 

1099 

1100 # determine into how many subtokens each token is split 

1101 token_subtoken_lengths = self.reconstruct_tokens_from_subtokens(sentence, subtokenized_sentence) 

1102 

1103 # remember tokenized sentences and their subtokenization 

1104 tokenized_sentences.append(tokenized_string) 

1105 all_token_subtoken_lengths.append(token_subtoken_lengths) 

1106 

1107 # encode inputs 

1108 batch_encoding = self.tokenizer(tokenized_sentences, 

1109 max_length=self.max_subtokens_sequence_length, 

1110 stride=self.stride, 

1111 return_overflowing_tokens=self.allow_long_sentences, 

1112 truncation=self.truncate, 

1113 padding=True, 

1114 return_tensors='pt', 

1115 ) 

1116 

1117 model_kwargs = {} 

1118 input_ids = batch_encoding['input_ids'].to(flair.device) 

1119 

1120 # Models such as FNet do not have an attention_mask 

1121 if 'attention_mask' in batch_encoding: 

1122 model_kwargs['attention_mask'] = batch_encoding['attention_mask'].to(flair.device) 

1123 

1124 # determine which sentence was split into how many parts 

1125 sentence_parts_lengths = torch.ones(len(tokenized_sentences), dtype=torch.int) if not self.allow_long_sentences \ 

1126 else torch.unique(batch_encoding['overflow_to_sample_mapping'], return_counts=True, sorted=True)[1].tolist() 

1127 

1128 # set language IDs for XLM-style transformers 

1129 if self.use_lang_emb: 

1130 model_kwargs["langs"] = torch.zeros_like(input_ids, dtype=input_ids.dtype) 

1131 

1132 for s_id, sentence in enumerate(tokenized_sentences): 

1133 sequence_length = len(sentence) 

1134 lang_id = self.tokenizer.lang2id.get(sentences[s_id].get_language_code(), 0) 

1135 model_kwargs["langs"][s_id][:sequence_length] = lang_id 

1136 

1137 # put encoded batch through transformer model to get all hidden states of all encoder layers 

1138 hidden_states = self.model(input_ids, **model_kwargs)[-1] 

1139 # make the tuple a tensor; makes working with it easier. 

1140 hidden_states = torch.stack(hidden_states) 

1141 

1142 sentence_idx_offset = 0 

1143 

1144 # gradients are enabled if fine-tuning is enabled 

1145 gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad() 

1146 

1147 with gradient_context: 

1148 

1149 # iterate over all subtokenized sentences 

1150 for sentence_idx, (sentence, subtoken_lengths, nr_sentence_parts) in enumerate( 

1151 zip(sentences, all_token_subtoken_lengths, sentence_parts_lengths)): 

1152 

1153 sentence_hidden_state = hidden_states[:, sentence_idx + sentence_idx_offset, ...] 

1154 

1155 for i in range(1, nr_sentence_parts): 

1156 sentence_idx_offset += 1 

1157 remainder_sentence_hidden_state = hidden_states[:, sentence_idx + sentence_idx_offset, ...] 

1158 # remove stride_size//2 at end of sentence_hidden_state, and half at beginning of remainder, 

1159 # in order to get some context into the embeddings of these words. 

1160 # also don't include the embedding of the extra [CLS] and [SEP] tokens. 

1161 sentence_hidden_state = torch.cat((sentence_hidden_state[:, :-1 - self.stride // 2, :], 

1162 remainder_sentence_hidden_state[:, 1 + self.stride // 2:, 

1163 :]), 1) 

1164 

1165 subword_start_idx = self.begin_offset 

1166 

1167 # for each token, get embedding 

1168 for token_idx, (token, number_of_subtokens) in enumerate(zip(sentence, subtoken_lengths)): 

1169 

1170 # some tokens have no subtokens at all (if omitted by BERT tokenizer) so return zero vector 

1171 if number_of_subtokens == 0: 

1172 token.set_embedding(self.name, torch.zeros(self.embedding_length)) 

1173 continue 

1174 

1175 subword_end_idx = subword_start_idx + number_of_subtokens 

1176 

1177 subtoken_embeddings: List[torch.FloatTensor] = [] 

1178 

1179 # get states from all selected layers, aggregate with pooling operation 

1180 for layer in self.layer_indexes: 

1181 current_embeddings = sentence_hidden_state[layer][subword_start_idx:subword_end_idx] 

1182 

1183 if self.pooling_operation == "first": 

1184 final_embedding: torch.FloatTensor = current_embeddings[0] 

1185 

1186 if self.pooling_operation == "last": 

1187 final_embedding: torch.FloatTensor = current_embeddings[-1] 

1188 

1189 if self.pooling_operation == "first_last": 

1190 final_embedding: torch.Tensor = torch.cat( 

1191 [current_embeddings[0], current_embeddings[-1]]) 

1192 

1193 if self.pooling_operation == "mean": 

1194 all_embeddings: List[torch.FloatTensor] = [ 

1195 embedding.unsqueeze(0) for embedding in current_embeddings 

1196 ] 

1197 final_embedding: torch.Tensor = torch.mean(torch.cat(all_embeddings, dim=0), dim=0) 

1198 

1199 subtoken_embeddings.append(final_embedding) 

1200 

1201 # use layer mean of embeddings if so selected 

1202 if self.layer_mean and len(self.layer_indexes) > 1: 

1203 sm_embeddings = torch.mean(torch.stack(subtoken_embeddings, dim=1), dim=1) 

1204 subtoken_embeddings = [sm_embeddings] 

1205 

1206 # set the extracted embedding for the token 

1207 token.set_embedding(self.name, torch.cat(subtoken_embeddings)) 

1208 

1209 subword_start_idx += number_of_subtokens 

1210 

1211 # move embeddings from context back to original sentence (if using context) 

1212 if self.context_length > 0: 

1213 for original_sentence, expanded_sentence, context_offset in zip(original_sentences, 

1214 sentences, 

1215 context_offsets): 

1216 for token_idx, token in enumerate(original_sentence): 

1217 token.set_embedding(self.name, 

1218 expanded_sentence[token_idx + context_offset].get_embedding(self.name)) 

1219 sentence = original_sentence 

1220 

1221 def _expand_sentence_with_context(self, sentence): 

1222 

1223 # remember original sentence 

1224 original_sentence = sentence 

1225 

1226 import random 

1227 expand_context = False if self.training and random.randint(1, 100) <= (self.context_dropout * 100) else True 

1228 

1229 left_context = '' 

1230 right_context = '' 

1231 

1232 if expand_context: 

1233 

1234 # get left context 

1235 while True: 

1236 sentence = sentence.previous_sentence() 

1237 if sentence is None: break 

1238 

1239 if self.respect_document_boundaries and sentence.is_document_boundary: break 

1240 

1241 left_context = sentence.to_tokenized_string() + ' ' + left_context 

1242 left_context = left_context.strip() 

1243 if len(left_context.split(" ")) > self.context_length: 

1244 left_context = " ".join(left_context.split(" ")[-self.context_length:]) 

1245 break 

1246 original_sentence.left_context = left_context 

1247 

1248 sentence = original_sentence 

1249 

1250 # get right context 

1251 while True: 

1252 sentence = sentence.next_sentence() 

1253 if sentence is None: break 

1254 if self.respect_document_boundaries and sentence.is_document_boundary: break 

1255 

1256 right_context += ' ' + sentence.to_tokenized_string() 

1257 right_context = right_context.strip() 

1258 if len(right_context.split(" ")) > self.context_length: 

1259 right_context = " ".join(right_context.split(" ")[:self.context_length]) 

1260 break 

1261 

1262 original_sentence.right_context = right_context 

1263 

1264 left_context_split = left_context.split(" ") 

1265 right_context_split = right_context.split(" ") 

1266 

1267 # empty contexts should not introduce whitespace tokens 

1268 if left_context_split == [""]: left_context_split = [] 

1269 if right_context_split == [""]: right_context_split = [] 

1270 

1271 # make expanded sentence 

1272 expanded_sentence = Sentence() 

1273 expanded_sentence.tokens = [Token(token) for token in left_context_split + 

1274 original_sentence.to_tokenized_string().split(" ") + 

1275 right_context_split] 

1276 

1277 context_length = len(left_context_split) 

1278 return expanded_sentence, context_length 

1279 

1280 def reconstruct_tokens_from_subtokens(self, sentence, subtokens): 

1281 word_iterator = iter(sentence) 

1282 token = next(word_iterator) 

1283 token_text = self._get_processed_token_text(token) 

1284 token_subtoken_lengths = [] 

1285 reconstructed_token = '' 

1286 subtoken_count = 0 

1287 # iterate over subtokens and reconstruct tokens 

1288 for subtoken_id, subtoken in enumerate(subtokens): 

1289 

1290 # remove special markup 

1291 subtoken = self._remove_special_markup(subtoken) 

1292 

1293 # TODO check if this is necessary is this method is called before prepare_for_model 

1294 # check if reconstructed token is special begin token ([CLS] or similar) 

1295 if subtoken in self.special_tokens and subtoken_id == 0: 

1296 continue 

1297 

1298 # some BERT tokenizers somehow omit words - in such cases skip to next token 

1299 if subtoken_count == 0 and not token_text.startswith(subtoken.lower()): 

1300 

1301 while True: 

1302 token_subtoken_lengths.append(0) 

1303 token = next(word_iterator) 

1304 token_text = self._get_processed_token_text(token) 

1305 if token_text.startswith(subtoken.lower()): break 

1306 

1307 subtoken_count += 1 

1308 

1309 # append subtoken to reconstruct token 

1310 reconstructed_token = reconstructed_token + subtoken 

1311 

1312 # check if reconstructed token is the same as current token 

1313 if reconstructed_token.lower() == token_text: 

1314 

1315 # if so, add subtoken count 

1316 token_subtoken_lengths.append(subtoken_count) 

1317 

1318 # reset subtoken count and reconstructed token 

1319 reconstructed_token = '' 

1320 subtoken_count = 0 

1321 

1322 # break from loop if all tokens are accounted for 

1323 if len(token_subtoken_lengths) < len(sentence): 

1324 token = next(word_iterator) 

1325 token_text = self._get_processed_token_text(token) 

1326 else: 

1327 break 

1328 

1329 # if tokens are unaccounted for 

1330 while len(token_subtoken_lengths) < len(sentence) and len(token.text) == 1: 

1331 token_subtoken_lengths.append(0) 

1332 if len(token_subtoken_lengths) == len(sentence): break 

1333 token = next(word_iterator) 

1334 

1335 # check if all tokens were matched to subtokens 

1336 if token != sentence[-1]: 

1337 log.error(f"Tokenization MISMATCH in sentence '{sentence.to_tokenized_string()}'") 

1338 log.error(f"Last matched: '{token}'") 

1339 log.error(f"Last sentence: '{sentence[-1]}'") 

1340 log.error(f"subtokenized: '{subtokens}'") 

1341 return token_subtoken_lengths 

1342 

1343 @property 

1344 def embedding_length(self) -> int: 

1345 

1346 if "embedding_length_internal" in self.__dict__.keys(): 

1347 return self.embedding_length_internal 

1348 

1349 # """Returns the length of the embedding vector.""" 

1350 if not self.layer_mean: 

1351 length = len(self.layer_indexes) * self.model.config.hidden_size 

1352 else: 

1353 length = self.model.config.hidden_size 

1354 

1355 if self.pooling_operation == 'first_last': length *= 2 

1356 

1357 self.__embedding_length = length 

1358 

1359 return length 

1360 

1361 def __getstate__(self): 

1362 # special handling for serializing transformer models 

1363 config_state_dict = self.model.config.__dict__ 

1364 model_state_dict = self.model.state_dict() 

1365 

1366 if not hasattr(self, "base_model_name"): self.base_model_name = self.name.split('transformer-word-')[-1] 

1367 

1368 # serialize the transformer models and the constructor arguments (but nothing else) 

1369 model_state = { 

1370 "config_state_dict": config_state_dict, 

1371 "model_state_dict": model_state_dict, 

1372 "embedding_length_internal": self.embedding_length, 

1373 

1374 "base_model_name": self.base_model_name, 

1375 "name": self.name, 

1376 "layer_indexes": self.layer_indexes, 

1377 "subtoken_pooling": self.pooling_operation, 

1378 "context_length": self.context_length, 

1379 "layer_mean": self.layer_mean, 

1380 "fine_tune": self.fine_tune, 

1381 "allow_long_sentences": self.allow_long_sentences, 

1382 "memory_effective_training": self.memory_effective_training, 

1383 "respect_document_boundaries": self.respect_document_boundaries, 

1384 "context_dropout": self.context_dropout, 

1385 } 

1386 

1387 return model_state 

1388 

1389 def __setstate__(self, d): 

1390 self.__dict__ = d 

1391 

1392 # necessary for reverse compatibility with Flair <= 0.7 

1393 if 'use_scalar_mix' in self.__dict__.keys(): 

1394 self.__dict__['layer_mean'] = d['use_scalar_mix'] 

1395 if not 'memory_effective_training' in self.__dict__.keys(): 

1396 self.__dict__['memory_effective_training'] = True 

1397 if 'pooling_operation' in self.__dict__.keys(): 

1398 self.__dict__['subtoken_pooling'] = d['pooling_operation'] 

1399 if not 'context_length' in self.__dict__.keys(): 

1400 self.__dict__['context_length'] = 0 

1401 if 'use_context' in self.__dict__.keys(): 

1402 self.__dict__['context_length'] = 64 if self.__dict__['use_context'] == True else 0 

1403 

1404 if not 'context_dropout' in self.__dict__.keys(): 

1405 self.__dict__['context_dropout'] = 0.5 

1406 if not 'respect_document_boundaries' in self.__dict__.keys(): 

1407 self.__dict__['respect_document_boundaries'] = True 

1408 if not 'memory_effective_training' in self.__dict__.keys(): 

1409 self.__dict__['memory_effective_training'] = True 

1410 if not 'base_model_name' in self.__dict__.keys(): 

1411 self.__dict__['base_model_name'] = self.__dict__['name'].split('transformer-word-')[-1] 

1412 

1413 # special handling for deserializing transformer models 

1414 if "config_state_dict" in d: 

1415 

1416 # load transformer model 

1417 model_type = d["config_state_dict"]["model_type"] if "model_type" in d["config_state_dict"] else "bert" 

1418 config_class = CONFIG_MAPPING[model_type] 

1419 loaded_config = config_class.from_dict(d["config_state_dict"]) 

1420 

1421 # constructor arguments 

1422 layers = ','.join([str(idx) for idx in self.__dict__['layer_indexes']]) 

1423 

1424 # re-initialize transformer word embeddings with constructor arguments 

1425 embedding = TransformerWordEmbeddings( 

1426 model=self.__dict__['base_model_name'], 

1427 layers=layers, 

1428 subtoken_pooling=self.__dict__['subtoken_pooling'], 

1429 use_context=self.__dict__['context_length'], 

1430 layer_mean=self.__dict__['layer_mean'], 

1431 fine_tune=self.__dict__['fine_tune'], 

1432 allow_long_sentences=self.__dict__['allow_long_sentences'], 

1433 respect_document_boundaries=self.__dict__['respect_document_boundaries'], 

1434 memory_effective_training=self.__dict__['memory_effective_training'], 

1435 context_dropout=self.__dict__['context_dropout'], 

1436 

1437 config=loaded_config, 

1438 state_dict=d["model_state_dict"], 

1439 ) 

1440 

1441 # I have no idea why this is necessary, but otherwise it doesn't work 

1442 for key in embedding.__dict__.keys(): 

1443 self.__dict__[key] = embedding.__dict__[key] 

1444 

1445 else: 

1446 

1447 # reload tokenizer to get around serialization issues 

1448 model_name = self.__dict__['name'].split('transformer-word-')[-1] 

1449 try: 

1450 tokenizer = AutoTokenizer.from_pretrained(model_name) 

1451 except: 

1452 pass 

1453 

1454 self.tokenizer = tokenizer 

1455 

1456 

1457class FastTextEmbeddings(TokenEmbeddings): 

1458 """FastText Embeddings with oov functionality""" 

1459 

1460 def __init__(self, embeddings: str, use_local: bool = True, field: str = None): 

1461 """ 

1462 Initializes fasttext word embeddings. Constructor downloads required embedding file and stores in cache 

1463 if use_local is False. 

1464 

1465 :param embeddings: path to your embeddings '.bin' file 

1466 :param use_local: set this to False if you are using embeddings from a remote source 

1467 """ 

1468 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

1469 

1470 cache_dir = Path("embeddings") 

1471 

1472 if use_local: 

1473 if not Path(embeddings).exists(): 

1474 raise ValueError( 

1475 f'The given embeddings "{embeddings}" is not available or is not a valid path.' 

1476 ) 

1477 else: 

1478 embeddings = cached_path(f"{embeddings}", cache_dir=cache_dir) 

1479 

1480 self.embeddings = embeddings 

1481 

1482 self.name: str = str(embeddings) 

1483 

1484 self.static_embeddings = True 

1485 

1486 self.precomputed_word_embeddings: gensim.models.FastText = gensim.models.FastText.load_fasttext_format( 

1487 str(embeddings) 

1488 ) 

1489 print(self.precomputed_word_embeddings) 

1490 

1491 self.__embedding_length: int = self.precomputed_word_embeddings.vector_size 

1492 

1493 self.field = field 

1494 super().__init__() 

1495 

1496 @property 

1497 def embedding_length(self) -> int: 

1498 return self.__embedding_length 

1499 

1500 @instance_lru_cache(maxsize=10000, typed=False) 

1501 def get_cached_vec(self, word: str) -> torch.Tensor: 

1502 try: 

1503 word_embedding = self.precomputed_word_embeddings.wv[word] 

1504 except: 

1505 word_embedding = np.zeros(self.embedding_length, dtype="float") 

1506 

1507 word_embedding = torch.tensor( 

1508 word_embedding.tolist(), device=flair.device, dtype=torch.float 

1509 ) 

1510 return word_embedding 

1511 

1512 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

1513 

1514 for i, sentence in enumerate(sentences): 

1515 

1516 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))): 

1517 

1518 if "field" not in self.__dict__ or self.field is None: 

1519 word = token.text 

1520 else: 

1521 word = token.get_tag(self.field).value 

1522 

1523 word_embedding = self.get_cached_vec(word) 

1524 

1525 token.set_embedding(self.name, word_embedding) 

1526 

1527 return sentences 

1528 

1529 def __str__(self): 

1530 return self.name 

1531 

1532 def extra_repr(self): 

1533 return f"'{self.embeddings}'" 

1534 

1535 

1536class OneHotEmbeddings(TokenEmbeddings): 

1537 """One-hot encoded embeddings. """ 

1538 

1539 def __init__( 

1540 self, 

1541 vocab_dictionary: Dictionary, 

1542 field: str = "text", 

1543 embedding_length: int = 300, 

1544 stable: bool = False, 

1545 ): 

1546 """ 

1547 Initializes one-hot encoded word embeddings and a trainable embedding layer 

1548 :param vocab_dictionary: the vocabulary that will be encoded 

1549 :param field: by default, the 'text' of tokens is embedded, but you can also embed tags such as 'pos' 

1550 :param embedding_length: dimensionality of the trainable embedding layer 

1551 :param stable: set stable=True to use the stable embeddings as described in https://arxiv.org/abs/2110.02861 

1552 """ 

1553 super().__init__() 

1554 self.name = f"one-hot-{field}" 

1555 self.static_embeddings = False 

1556 self.field = field 

1557 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

1558 self.__embedding_length = embedding_length 

1559 self.vocab_dictionary = vocab_dictionary 

1560 

1561 print(self.vocab_dictionary.idx2item) 

1562 print(f"vocabulary size of {len(self.vocab_dictionary)}") 

1563 

1564 # model architecture 

1565 self.embedding_layer = torch.nn.Embedding( 

1566 len(self.vocab_dictionary), self.__embedding_length 

1567 ) 

1568 torch.nn.init.xavier_uniform_(self.embedding_layer.weight) 

1569 if stable: 

1570 self.layer_norm = torch.nn.LayerNorm(embedding_length) 

1571 else: 

1572 self.layer_norm = None 

1573 

1574 self.to(flair.device) 

1575 

1576 @property 

1577 def embedding_length(self) -> int: 

1578 return self.__embedding_length 

1579 

1580 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

1581 

1582 tokens = [ 

1583 t 

1584 for sentence in sentences 

1585 for t in sentence.tokens 

1586 ] 

1587 

1588 if self.field == "text": 

1589 one_hot_sentences = [ 

1590 self.vocab_dictionary.get_idx_for_item(t.text) 

1591 for t in tokens 

1592 ] 

1593 else: 

1594 one_hot_sentences = [ 

1595 self.vocab_dictionary.get_idx_for_item(t.get_tag(self.field).value) 

1596 for t in tokens 

1597 ] 

1598 

1599 one_hot_sentences = torch.tensor(one_hot_sentences, dtype=torch.long).to( 

1600 flair.device 

1601 ) 

1602 

1603 embedded = self.embedding_layer.forward(one_hot_sentences) 

1604 if self.layer_norm: 

1605 embedded = self.layer_norm(embedded) 

1606 

1607 for emb, token in zip(embedded, tokens): 

1608 token.set_embedding(self.name, emb) 

1609 

1610 return sentences 

1611 

1612 def __str__(self): 

1613 return self.name 

1614 

1615 @classmethod 

1616 def from_corpus( 

1617 cls, 

1618 corpus: Corpus, 

1619 field: str = "text", 

1620 min_freq: int = 3, 

1621 **kwargs 

1622 ): 

1623 vocab_dictionary = Dictionary() 

1624 

1625 tokens = list(map((lambda s: s.tokens), corpus.train)) 

1626 tokens = [token for sublist in tokens for token in sublist] 

1627 

1628 if field == "text": 

1629 most_common = Counter(list(map((lambda t: t.text), tokens))).most_common() 

1630 else: 

1631 most_common = Counter( 

1632 list(map((lambda t: t.get_tag(field).value), tokens)) 

1633 ).most_common() 

1634 

1635 tokens = [] 

1636 for token, freq in most_common: 

1637 if freq < min_freq: 

1638 break 

1639 tokens.append(token) 

1640 

1641 for token in tokens: 

1642 vocab_dictionary.add_item(token) 

1643 

1644 return cls(vocab_dictionary, field=field, **kwargs) 

1645 

1646 

1647class HashEmbeddings(TokenEmbeddings): 

1648 """Standard embeddings with Hashing Trick.""" 

1649 

1650 def __init__( 

1651 self, num_embeddings: int = 1000, embedding_length: int = 300, hash_method="md5" 

1652 ): 

1653 

1654 super().__init__() 

1655 self.name = "hash" 

1656 self.static_embeddings = False 

1657 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

1658 

1659 self.__num_embeddings = num_embeddings 

1660 self.__embedding_length = embedding_length 

1661 

1662 self.__hash_method = hash_method 

1663 

1664 # model architecture 

1665 self.embedding_layer = torch.nn.Embedding( 

1666 self.__num_embeddings, self.__embedding_length 

1667 ) 

1668 torch.nn.init.xavier_uniform_(self.embedding_layer.weight) 

1669 

1670 self.to(flair.device) 

1671 

1672 @property 

1673 def num_embeddings(self) -> int: 

1674 return self.__num_embeddings 

1675 

1676 @property 

1677 def embedding_length(self) -> int: 

1678 return self.__embedding_length 

1679 

1680 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

1681 def get_idx_for_item(text): 

1682 hash_function = hashlib.new(self.__hash_method) 

1683 hash_function.update(bytes(str(text), "utf-8")) 

1684 return int(hash_function.hexdigest(), 16) % self.__num_embeddings 

1685 

1686 hash_sentences = [] 

1687 for i, sentence in enumerate(sentences): 

1688 context_idxs = [get_idx_for_item(t.text) for t in sentence.tokens] 

1689 

1690 hash_sentences.extend(context_idxs) 

1691 

1692 hash_sentences = torch.tensor(hash_sentences, dtype=torch.long).to(flair.device) 

1693 

1694 embedded = self.embedding_layer.forward(hash_sentences) 

1695 

1696 index = 0 

1697 for sentence in sentences: 

1698 for token in sentence: 

1699 embedding = embedded[index] 

1700 token.set_embedding(self.name, embedding) 

1701 index += 1 

1702 

1703 return sentences 

1704 

1705 def __str__(self): 

1706 return self.name 

1707 

1708 

1709class MuseCrosslingualEmbeddings(TokenEmbeddings): 

1710 def __init__(self, ): 

1711 self.name: str = f"muse-crosslingual" 

1712 self.static_embeddings = True 

1713 self.__embedding_length: int = 300 

1714 self.language_embeddings = {} 

1715 super().__init__() 

1716 

1717 @instance_lru_cache(maxsize=10000, typed=False) 

1718 def get_cached_vec(self, language_code: str, word: str) -> torch.Tensor: 

1719 current_embedding_model = self.language_embeddings[language_code] 

1720 if word in current_embedding_model: 

1721 word_embedding = current_embedding_model[word] 

1722 elif word.lower() in current_embedding_model: 

1723 word_embedding = current_embedding_model[word.lower()] 

1724 elif re.sub(r"\d", "#", word.lower()) in current_embedding_model: 

1725 word_embedding = current_embedding_model[re.sub(r"\d", "#", word.lower())] 

1726 elif re.sub(r"\d", "0", word.lower()) in current_embedding_model: 

1727 word_embedding = current_embedding_model[re.sub(r"\d", "0", word.lower())] 

1728 else: 

1729 word_embedding = np.zeros(self.embedding_length, dtype="float") 

1730 word_embedding = torch.tensor( 

1731 word_embedding, device=flair.device, dtype=torch.float 

1732 ) 

1733 return word_embedding 

1734 

1735 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

1736 

1737 for i, sentence in enumerate(sentences): 

1738 

1739 language_code = sentence.get_language_code() 

1740 supported = [ 

1741 "en", 

1742 "de", 

1743 "bg", 

1744 "ca", 

1745 "hr", 

1746 "cs", 

1747 "da", 

1748 "nl", 

1749 "et", 

1750 "fi", 

1751 "fr", 

1752 "el", 

1753 "he", 

1754 "hu", 

1755 "id", 

1756 "it", 

1757 "mk", 

1758 "no", 

1759 # "pl", 

1760 "pt", 

1761 "ro", 

1762 "ru", 

1763 "sk", 

1764 ] 

1765 if language_code not in supported: 

1766 language_code = "en" 

1767 

1768 if language_code not in self.language_embeddings: 

1769 log.info(f"Loading up MUSE embeddings for '{language_code}'!") 

1770 # download if necessary 

1771 hu_path: str = "https://flair.informatik.hu-berlin.de/resources/embeddings/muse" 

1772 cache_dir = Path("embeddings") / "MUSE" 

1773 cached_path( 

1774 f"{hu_path}/muse.{language_code}.vec.gensim.vectors.npy", 

1775 cache_dir=cache_dir, 

1776 ) 

1777 embeddings_file = cached_path( 

1778 f"{hu_path}/muse.{language_code}.vec.gensim", cache_dir=cache_dir 

1779 ) 

1780 

1781 # load the model 

1782 self.language_embeddings[ 

1783 language_code 

1784 ] = gensim.models.KeyedVectors.load(str(embeddings_file)) 

1785 

1786 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))): 

1787 

1788 if "field" not in self.__dict__ or self.field is None: 

1789 word = token.text 

1790 else: 

1791 word = token.get_tag(self.field).value 

1792 

1793 word_embedding = self.get_cached_vec( 

1794 language_code=language_code, word=word 

1795 ) 

1796 

1797 token.set_embedding(self.name, word_embedding) 

1798 

1799 return sentences 

1800 

1801 @property 

1802 def embedding_length(self) -> int: 

1803 return self.__embedding_length 

1804 

1805 def __str__(self): 

1806 return self.name 

1807 

1808 

1809# TODO: keep for backwards compatibility, but remove in future 

1810class BPEmbSerializable(BPEmb): 

1811 def __getstate__(self): 

1812 state = self.__dict__.copy() 

1813 # save the sentence piece model as binary file (not as path which may change) 

1814 state["spm_model_binary"] = open(self.model_file, mode="rb").read() 

1815 state["spm"] = None 

1816 return state 

1817 

1818 def __setstate__(self, state): 

1819 from bpemb.util import sentencepiece_load 

1820 

1821 model_file = self.model_tpl.format(lang=state["lang"], vs=state["vs"]) 

1822 self.__dict__ = state 

1823 

1824 # write out the binary sentence piece model into the expected directory 

1825 self.cache_dir: Path = flair.cache_root / "embeddings" 

1826 if "spm_model_binary" in self.__dict__: 

1827 # if the model was saved as binary and it is not found on disk, write to appropriate path 

1828 if not os.path.exists(self.cache_dir / state["lang"]): 

1829 os.makedirs(self.cache_dir / state["lang"]) 

1830 self.model_file = self.cache_dir / model_file 

1831 with open(self.model_file, "wb") as out: 

1832 out.write(self.__dict__["spm_model_binary"]) 

1833 else: 

1834 # otherwise, use normal process and potentially trigger another download 

1835 self.model_file = self._load_file(model_file) 

1836 

1837 # once the modes if there, load it with sentence piece 

1838 state["spm"] = sentencepiece_load(self.model_file) 

1839 

1840 

1841class BytePairEmbeddings(TokenEmbeddings): 

1842 def __init__( 

1843 self, 

1844 language: str = None, 

1845 dim: int = 50, 

1846 syllables: int = 100000, 

1847 cache_dir=None, 

1848 model_file_path: Path = None, 

1849 embedding_file_path: Path = None, 

1850 **kwargs, 

1851 ): 

1852 """ 

1853 Initializes BP embeddings. Constructor downloads required files if not there. 

1854 """ 

1855 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

1856 

1857 if not cache_dir: 

1858 cache_dir = flair.cache_root / "embeddings" 

1859 if language: 

1860 self.name: str = f"bpe-{language}-{syllables}-{dim}" 

1861 else: 

1862 assert ( 

1863 model_file_path is not None and embedding_file_path is not None 

1864 ), "Need to specify model_file_path and embedding_file_path if no language is given in BytePairEmbeddings(...)" 

1865 dim = None 

1866 

1867 self.embedder = BPEmbSerializable( 

1868 lang=language, 

1869 vs=syllables, 

1870 dim=dim, 

1871 cache_dir=cache_dir, 

1872 model_file=model_file_path, 

1873 emb_file=embedding_file_path, 

1874 **kwargs, 

1875 ) 

1876 

1877 if not language: 

1878 self.name: str = f"bpe-custom-{self.embedder.vs}-{self.embedder.dim}" 

1879 self.static_embeddings = True 

1880 

1881 self.__embedding_length: int = self.embedder.emb.vector_size * 2 

1882 super().__init__() 

1883 

1884 @property 

1885 def embedding_length(self) -> int: 

1886 return self.__embedding_length 

1887 

1888 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

1889 

1890 for i, sentence in enumerate(sentences): 

1891 

1892 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))): 

1893 

1894 if "field" not in self.__dict__ or self.field is None: 

1895 word = token.text 

1896 else: 

1897 word = token.get_tag(self.field).value 

1898 

1899 if word.strip() == "": 

1900 # empty words get no embedding 

1901 token.set_embedding( 

1902 self.name, torch.zeros(self.embedding_length, dtype=torch.float) 

1903 ) 

1904 else: 

1905 # all other words get embedded 

1906 embeddings = self.embedder.embed(word.lower()) 

1907 embedding = np.concatenate( 

1908 (embeddings[0], embeddings[len(embeddings) - 1]) 

1909 ) 

1910 token.set_embedding( 

1911 self.name, torch.tensor(embedding, dtype=torch.float) 

1912 ) 

1913 

1914 return sentences 

1915 

1916 def __str__(self): 

1917 return self.name 

1918 

1919 def extra_repr(self): 

1920 return "model={}".format(self.name) 

1921 

1922 

1923class ELMoEmbeddings(TokenEmbeddings): 

1924 """Contextual word embeddings using word-level LM, as proposed in Peters et al., 2018. 

1925 ELMo word vectors can be constructed by combining layers in different ways. 

1926 Default is to concatene the top 3 layers in the LM.""" 

1927 

1928 def __init__( 

1929 self, model: str = "original", options_file: str = None, weight_file: str = None, 

1930 embedding_mode: str = "all" 

1931 ): 

1932 super().__init__() 

1933 

1934 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

1935 

1936 try: 

1937 import allennlp.commands.elmo 

1938 except ModuleNotFoundError: 

1939 log.warning("-" * 100) 

1940 log.warning('ATTENTION! The library "allennlp" is not installed!') 

1941 log.warning( 

1942 'To use ELMoEmbeddings, please first install with "pip install allennlp==0.9.0"' 

1943 ) 

1944 log.warning("-" * 100) 

1945 pass 

1946 

1947 assert embedding_mode in ["all", "top", "average"] 

1948 

1949 self.name = f"elmo-{model}-{embedding_mode}" 

1950 self.static_embeddings = True 

1951 

1952 if not options_file or not weight_file: 

1953 # the default model for ELMo is the 'original' model, which is very large 

1954 options_file = allennlp.commands.elmo.DEFAULT_OPTIONS_FILE 

1955 weight_file = allennlp.commands.elmo.DEFAULT_WEIGHT_FILE 

1956 # alternatively, a small, medium or portuguese model can be selected by passing the appropriate mode name 

1957 if model == "small": 

1958 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json" 

1959 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5" 

1960 if model == "medium": 

1961 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_options.json" 

1962 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5" 

1963 if model in ["large", "5.5B"]: 

1964 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json" 

1965 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5" 

1966 if model == "pt" or model == "portuguese": 

1967 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_options.json" 

1968 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_weights.hdf5" 

1969 if model == "pubmed": 

1970 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_options.json" 

1971 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_weights_PubMed_only.hdf5" 

1972 

1973 if embedding_mode == "all": 

1974 self.embedding_mode_fn = self.use_layers_all 

1975 elif embedding_mode == "top": 

1976 self.embedding_mode_fn = self.use_layers_top 

1977 elif embedding_mode == "average": 

1978 self.embedding_mode_fn = self.use_layers_average 

1979 

1980 # put on Cuda if available 

1981 from flair import device 

1982 

1983 if re.fullmatch(r"cuda:[0-9]+", str(device)): 

1984 cuda_device = int(str(device).split(":")[-1]) 

1985 elif str(device) == "cpu": 

1986 cuda_device = -1 

1987 else: 

1988 cuda_device = 0 

1989 

1990 self.ee = allennlp.commands.elmo.ElmoEmbedder( 

1991 options_file=options_file, weight_file=weight_file, cuda_device=cuda_device 

1992 ) 

1993 

1994 # embed a dummy sentence to determine embedding_length 

1995 dummy_sentence: Sentence = Sentence() 

1996 dummy_sentence.add_token(Token("hello")) 

1997 embedded_dummy = self.embed(dummy_sentence) 

1998 self.__embedding_length: int = len( 

1999 embedded_dummy[0].get_token(1).get_embedding() 

2000 ) 

2001 

2002 @property 

2003 def embedding_length(self) -> int: 

2004 return self.__embedding_length 

2005 

2006 def use_layers_all(self, x): 

2007 return torch.cat(x, 0) 

2008 

2009 def use_layers_top(self, x): 

2010 return x[-1] 

2011 

2012 def use_layers_average(self, x): 

2013 return torch.mean(torch.stack(x), 0) 

2014 

2015 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

2016 # ELMoEmbeddings before Release 0.5 did not set self.embedding_mode_fn 

2017 if not getattr(self, "embedding_mode_fn", None): 

2018 self.embedding_mode_fn = self.use_layers_all 

2019 

2020 sentence_words: List[List[str]] = [] 

2021 for sentence in sentences: 

2022 sentence_words.append([token.text for token in sentence]) 

2023 

2024 embeddings = self.ee.embed_batch(sentence_words) 

2025 

2026 for i, sentence in enumerate(sentences): 

2027 

2028 sentence_embeddings = embeddings[i] 

2029 

2030 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))): 

2031 elmo_embedding_layers = [ 

2032 torch.FloatTensor(sentence_embeddings[0, token_idx, :]), 

2033 torch.FloatTensor(sentence_embeddings[1, token_idx, :]), 

2034 torch.FloatTensor(sentence_embeddings[2, token_idx, :]) 

2035 ] 

2036 word_embedding = self.embedding_mode_fn(elmo_embedding_layers) 

2037 token.set_embedding(self.name, word_embedding) 

2038 

2039 return sentences 

2040 

2041 def extra_repr(self): 

2042 return "model={}".format(self.name) 

2043 

2044 def __str__(self): 

2045 return self.name 

2046 

2047 def __setstate__(self, state): 

2048 self.__dict__ = state 

2049 

2050 if re.fullmatch(r"cuda:[0-9]+", str(flair.device)): 

2051 cuda_device = int(str(flair.device).split(":")[-1]) 

2052 elif str(flair.device) == "cpu": 

2053 cuda_device = -1 

2054 else: 

2055 cuda_device = 0 

2056 

2057 self.ee.cuda_device = cuda_device 

2058 

2059 self.ee.elmo_bilm.to(device=flair.device) 

2060 self.ee.elmo_bilm._elmo_lstm._states = tuple( 

2061 [state.to(flair.device) for state in self.ee.elmo_bilm._elmo_lstm._states]) 

2062 

2063 

2064class NILCEmbeddings(WordEmbeddings): 

2065 def __init__(self, embeddings: str, model: str = "skip", size: int = 100): 

2066 """ 

2067 Initializes portuguese classic word embeddings trained by NILC Lab (http://www.nilc.icmc.usp.br/embeddings). 

2068 Constructor downloads required files if not there. 

2069 :param embeddings: one of: 'fasttext', 'glove', 'wang2vec' or 'word2vec' 

2070 :param model: one of: 'skip' or 'cbow'. This is not applicable to glove. 

2071 :param size: one of: 50, 100, 300, 600 or 1000. 

2072 """ 

2073 

2074 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

2075 

2076 base_path = "http://143.107.183.175:22980/download.php?file=embeddings/" 

2077 

2078 cache_dir = Path("embeddings") / embeddings.lower() 

2079 

2080 # GLOVE embeddings 

2081 if embeddings.lower() == "glove": 

2082 cached_path( 

2083 f"{base_path}{embeddings}/{embeddings}_s{size}.zip", cache_dir=cache_dir 

2084 ) 

2085 embeddings = cached_path( 

2086 f"{base_path}{embeddings}/{embeddings}_s{size}.zip", cache_dir=cache_dir 

2087 ) 

2088 

2089 elif embeddings.lower() in ["fasttext", "wang2vec", "word2vec"]: 

2090 cached_path( 

2091 f"{base_path}{embeddings}/{model}_s{size}.zip", cache_dir=cache_dir 

2092 ) 

2093 embeddings = cached_path( 

2094 f"{base_path}{embeddings}/{model}_s{size}.zip", cache_dir=cache_dir 

2095 ) 

2096 

2097 elif not Path(embeddings).exists(): 

2098 raise ValueError( 

2099 f'The given embeddings "{embeddings}" is not available or is not a valid path.' 

2100 ) 

2101 

2102 self.name: str = str(embeddings) 

2103 self.static_embeddings = True 

2104 

2105 log.info("Reading embeddings from %s" % embeddings) 

2106 self.precomputed_word_embeddings = gensim.models.KeyedVectors.load_word2vec_format( 

2107 open_inside_zip(str(embeddings), cache_dir=cache_dir) 

2108 ) 

2109 

2110 self.__embedding_length: int = self.precomputed_word_embeddings.vector_size 

2111 super(TokenEmbeddings, self).__init__() 

2112 

2113 @property 

2114 def embedding_length(self) -> int: 

2115 return self.__embedding_length 

2116 

2117 def __str__(self): 

2118 return self.name 

2119 

2120 

2121def replace_with_language_code(string: str): 

2122 string = string.replace("arabic-", "ar-") 

2123 string = string.replace("basque-", "eu-") 

2124 string = string.replace("bulgarian-", "bg-") 

2125 string = string.replace("croatian-", "hr-") 

2126 string = string.replace("czech-", "cs-") 

2127 string = string.replace("danish-", "da-") 

2128 string = string.replace("dutch-", "nl-") 

2129 string = string.replace("farsi-", "fa-") 

2130 string = string.replace("persian-", "fa-") 

2131 string = string.replace("finnish-", "fi-") 

2132 string = string.replace("french-", "fr-") 

2133 string = string.replace("german-", "de-") 

2134 string = string.replace("hebrew-", "he-") 

2135 string = string.replace("hindi-", "hi-") 

2136 string = string.replace("indonesian-", "id-") 

2137 string = string.replace("italian-", "it-") 

2138 string = string.replace("japanese-", "ja-") 

2139 string = string.replace("norwegian-", "no") 

2140 string = string.replace("polish-", "pl-") 

2141 string = string.replace("portuguese-", "pt-") 

2142 string = string.replace("slovenian-", "sl-") 

2143 string = string.replace("spanish-", "es-") 

2144 string = string.replace("swedish-", "sv-") 

2145 return string