Coverage for flair/flair/embeddings/token.py: 48%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

995 statements  

1import hashlib 

2import inspect 

3import logging 

4import os 

5import re 

6from abc import abstractmethod 

7from collections import Counter 

8from pathlib import Path 

9from typing import List, Union, Dict 

10 

11import gensim 

12import numpy as np 

13import torch 

14from bpemb import BPEmb 

15from transformers import AutoTokenizer, AutoConfig, AutoModel, CONFIG_MAPPING, PreTrainedTokenizer, XLNetModel, \ 

16 TransfoXLModel 

17 

18import flair 

19from flair.data import Sentence, Token, Corpus, Dictionary 

20from flair.embeddings.base import Embeddings 

21from flair.file_utils import cached_path, open_inside_zip, instance_lru_cache 

22 

23log = logging.getLogger("flair") 

24 

25 

26class TokenEmbeddings(Embeddings): 

27 """Abstract base class for all token-level embeddings. Ever new type of word embedding must implement these methods.""" 

28 

29 @property 

30 @abstractmethod 

31 def embedding_length(self) -> int: 

32 """Returns the length of the embedding vector.""" 

33 pass 

34 

35 @property 

36 def embedding_type(self) -> str: 

37 return "word-level" 

38 

39 @staticmethod 

40 def get_instance_parameters(locals: dict) -> dict: 

41 class_definition = locals.get("__class__") 

42 instance_parameters = set(inspect.getfullargspec(class_definition.__init__).args) 

43 instance_parameters.difference_update(set(["self"])) 

44 instance_parameters.update(set(["__class__"])) 

45 instance_parameters = {class_attribute: attribute_value for class_attribute, attribute_value in locals.items() 

46 if class_attribute in instance_parameters} 

47 return instance_parameters 

48 

49 

50class StackedEmbeddings(TokenEmbeddings): 

51 """A stack of embeddings, used if you need to combine several different embedding types.""" 

52 

53 def __init__(self, embeddings: List[TokenEmbeddings]): 

54 """The constructor takes a list of embeddings to be combined.""" 

55 super().__init__() 

56 

57 self.embeddings = embeddings 

58 

59 # IMPORTANT: add embeddings as torch modules 

60 for i, embedding in enumerate(embeddings): 

61 embedding.name = f"{str(i)}-{embedding.name}" 

62 self.add_module(f"list_embedding_{str(i)}", embedding) 

63 

64 self.name: str = "Stack" 

65 self.static_embeddings: bool = True 

66 

67 self.__embedding_type: str = embeddings[0].embedding_type 

68 

69 self.__embedding_length: int = 0 

70 for embedding in embeddings: 

71 self.__embedding_length += embedding.embedding_length 

72 

73 def embed( 

74 self, sentences: Union[Sentence, List[Sentence]], static_embeddings: bool = True 

75 ): 

76 # if only one sentence is passed, convert to list of sentence 

77 if type(sentences) is Sentence: 

78 sentences = [sentences] 

79 

80 for embedding in self.embeddings: 

81 embedding.embed(sentences) 

82 

83 @property 

84 def embedding_type(self) -> str: 

85 return self.__embedding_type 

86 

87 @property 

88 def embedding_length(self) -> int: 

89 return self.__embedding_length 

90 

91 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

92 

93 for embedding in self.embeddings: 

94 embedding._add_embeddings_internal(sentences) 

95 

96 return sentences 

97 

98 def __str__(self): 

99 return f'StackedEmbeddings [{",".join([str(e) for e in self.embeddings])}]' 

100 

101 def get_names(self) -> List[str]: 

102 """Returns a list of embedding names. In most cases, it is just a list with one item, namely the name of 

103 this embedding. But in some cases, the embedding is made up by different embeddings (StackedEmbedding). 

104 Then, the list contains the names of all embeddings in the stack.""" 

105 names = [] 

106 for embedding in self.embeddings: 

107 names.extend(embedding.get_names()) 

108 return names 

109 

110 def get_named_embeddings_dict(self) -> Dict: 

111 

112 named_embeddings_dict = {} 

113 for embedding in self.embeddings: 

114 named_embeddings_dict.update(embedding.get_named_embeddings_dict()) 

115 

116 return named_embeddings_dict 

117 

118 

119class WordEmbeddings(TokenEmbeddings): 

120 """Standard static word embeddings, such as GloVe or FastText.""" 

121 

122 def __init__(self, embeddings: str, field: str = None): 

123 """ 

124 Initializes classic word embeddings. Constructor downloads required files if not there. 

125 :param embeddings: one of: 'glove', 'extvec', 'crawl' or two-letter language code or custom 

126 If you want to use a custom embedding file, just pass the path to the embeddings as embeddings variable. 

127 """ 

128 self.embeddings = embeddings 

129 

130 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

131 

132 hu_path: str = "https://flair.informatik.hu-berlin.de/resources/embeddings/token" 

133 

134 cache_dir = Path("embeddings") 

135 

136 # GLOVE embeddings 

137 if embeddings.lower() == "glove" or embeddings.lower() == "en-glove": 

138 cached_path(f"{hu_path}/glove.gensim.vectors.npy", cache_dir=cache_dir) 

139 embeddings = cached_path(f"{hu_path}/glove.gensim", cache_dir=cache_dir) 

140 

141 # TURIAN embeddings 

142 elif embeddings.lower() == "turian" or embeddings.lower() == "en-turian": 

143 cached_path(f"{hu_path}/turian.vectors.npy", cache_dir=cache_dir) 

144 embeddings = cached_path(f"{hu_path}/turian", cache_dir=cache_dir) 

145 

146 # KOMNINOS embeddings 

147 elif embeddings.lower() == "extvec" or embeddings.lower() == "en-extvec": 

148 cached_path(f"{hu_path}/extvec.gensim.vectors.npy", cache_dir=cache_dir) 

149 embeddings = cached_path(f"{hu_path}/extvec.gensim", cache_dir=cache_dir) 

150 

151 # pubmed embeddings 

152 elif embeddings.lower() == "pubmed" or embeddings.lower() == "en-pubmed": 

153 cached_path(f"{hu_path}/pubmed_pmc_wiki_sg_1M.gensim.vectors.npy", cache_dir=cache_dir) 

154 embeddings = cached_path(f"{hu_path}/pubmed_pmc_wiki_sg_1M.gensim", cache_dir=cache_dir) 

155 

156 # FT-CRAWL embeddings 

157 elif embeddings.lower() == "crawl" or embeddings.lower() == "en-crawl": 

158 cached_path(f"{hu_path}/en-fasttext-crawl-300d-1M.vectors.npy", cache_dir=cache_dir) 

159 embeddings = cached_path(f"{hu_path}/en-fasttext-crawl-300d-1M", cache_dir=cache_dir) 

160 

161 # FT-CRAWL embeddings 

162 elif embeddings.lower() in ["news", "en-news", "en"]: 

163 cached_path(f"{hu_path}/en-fasttext-news-300d-1M.vectors.npy", cache_dir=cache_dir) 

164 embeddings = cached_path(f"{hu_path}/en-fasttext-news-300d-1M", cache_dir=cache_dir) 

165 

166 # twitter embeddings 

167 elif embeddings.lower() in ["twitter", "en-twitter"]: 

168 cached_path(f"{hu_path}/twitter.gensim.vectors.npy", cache_dir=cache_dir) 

169 embeddings = cached_path(f"{hu_path}/twitter.gensim", cache_dir=cache_dir) 

170 

171 # two-letter language code wiki embeddings 

172 elif len(embeddings.lower()) == 2: 

173 cached_path(f"{hu_path}/{embeddings}-wiki-fasttext-300d-1M.vectors.npy", cache_dir=cache_dir) 

174 embeddings = cached_path(f"{hu_path}/{embeddings}-wiki-fasttext-300d-1M", cache_dir=cache_dir) 

175 

176 # two-letter language code wiki embeddings 

177 elif len(embeddings.lower()) == 7 and embeddings.endswith("-wiki"): 

178 cached_path(f"{hu_path}/{embeddings[:2]}-wiki-fasttext-300d-1M.vectors.npy", cache_dir=cache_dir) 

179 embeddings = cached_path(f"{hu_path}/{embeddings[:2]}-wiki-fasttext-300d-1M", cache_dir=cache_dir) 

180 

181 # two-letter language code crawl embeddings 

182 elif len(embeddings.lower()) == 8 and embeddings.endswith("-crawl"): 

183 cached_path(f"{hu_path}/{embeddings[:2]}-crawl-fasttext-300d-1M.vectors.npy", cache_dir=cache_dir) 

184 embeddings = cached_path(f"{hu_path}/{embeddings[:2]}-crawl-fasttext-300d-1M", cache_dir=cache_dir) 

185 

186 elif not Path(embeddings).exists(): 

187 raise ValueError( 

188 f'The given embeddings "{embeddings}" is not available or is not a valid path.' 

189 ) 

190 

191 self.name: str = str(embeddings) 

192 self.static_embeddings = True 

193 

194 if str(embeddings).endswith(".bin"): 

195 self.precomputed_word_embeddings = gensim.models.KeyedVectors.load_word2vec_format( 

196 str(embeddings), binary=True 

197 ) 

198 else: 

199 self.precomputed_word_embeddings = gensim.models.KeyedVectors.load( 

200 str(embeddings) 

201 ) 

202 

203 self.field = field 

204 

205 self.__embedding_length: int = self.precomputed_word_embeddings.vector_size 

206 super().__init__() 

207 

208 @property 

209 def embedding_length(self) -> int: 

210 return self.__embedding_length 

211 

212 @instance_lru_cache(maxsize=10000, typed=False) 

213 def get_cached_vec(self, word: str) -> torch.Tensor: 

214 if word in self.precomputed_word_embeddings: 

215 word_embedding = self.precomputed_word_embeddings[word] 

216 elif word.lower() in self.precomputed_word_embeddings: 

217 word_embedding = self.precomputed_word_embeddings[word.lower()] 

218 elif re.sub(r"\d", "#", word.lower()) in self.precomputed_word_embeddings: 

219 word_embedding = self.precomputed_word_embeddings[ 

220 re.sub(r"\d", "#", word.lower()) 

221 ] 

222 elif re.sub(r"\d", "0", word.lower()) in self.precomputed_word_embeddings: 

223 word_embedding = self.precomputed_word_embeddings[ 

224 re.sub(r"\d", "0", word.lower()) 

225 ] 

226 else: 

227 word_embedding = np.zeros(self.embedding_length, dtype="float") 

228 

229 word_embedding = torch.tensor( 

230 word_embedding.tolist(), device=flair.device, dtype=torch.float 

231 ) 

232 return word_embedding 

233 

234 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

235 

236 for i, sentence in enumerate(sentences): 

237 

238 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))): 

239 

240 if "field" not in self.__dict__ or self.field is None: 

241 word = token.text 

242 else: 

243 word = token.get_tag(self.field).value 

244 

245 word_embedding = self.get_cached_vec(word=word) 

246 

247 token.set_embedding(self.name, word_embedding) 

248 

249 return sentences 

250 

251 def __str__(self): 

252 return self.name 

253 

254 def extra_repr(self): 

255 # fix serialized models 

256 if "embeddings" not in self.__dict__: 

257 self.embeddings = self.name 

258 

259 return f"'{self.embeddings}'" 

260 

261 

262class CharacterEmbeddings(TokenEmbeddings): 

263 """Character embeddings of words, as proposed in Lample et al., 2016.""" 

264 

265 def __init__( 

266 self, 

267 path_to_char_dict: str = None, 

268 char_embedding_dim: int = 25, 

269 hidden_size_char: int = 25, 

270 ): 

271 """Uses the default character dictionary if none provided.""" 

272 

273 super().__init__() 

274 self.name = "Char" 

275 self.static_embeddings = False 

276 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

277 

278 # use list of common characters if none provided 

279 if path_to_char_dict is None: 

280 self.char_dictionary: Dictionary = Dictionary.load("common-chars") 

281 else: 

282 self.char_dictionary: Dictionary = Dictionary.load_from_file(path_to_char_dict) 

283 

284 self.char_embedding_dim: int = char_embedding_dim 

285 self.hidden_size_char: int = hidden_size_char 

286 self.char_embedding = torch.nn.Embedding( 

287 len(self.char_dictionary.item2idx), self.char_embedding_dim 

288 ) 

289 self.char_rnn = torch.nn.LSTM( 

290 self.char_embedding_dim, 

291 self.hidden_size_char, 

292 num_layers=1, 

293 bidirectional=True, 

294 ) 

295 

296 self.__embedding_length = self.hidden_size_char * 2 

297 

298 self.to(flair.device) 

299 

300 @property 

301 def embedding_length(self) -> int: 

302 return self.__embedding_length 

303 

304 def _add_embeddings_internal(self, sentences: List[Sentence]): 

305 

306 for sentence in sentences: 

307 

308 tokens_char_indices = [] 

309 

310 # translate words in sentence into ints using dictionary 

311 for token in sentence.tokens: 

312 char_indices = [ 

313 self.char_dictionary.get_idx_for_item(char) for char in token.text 

314 ] 

315 tokens_char_indices.append(char_indices) 

316 

317 # sort words by length, for batching and masking 

318 tokens_sorted_by_length = sorted( 

319 tokens_char_indices, key=lambda p: len(p), reverse=True 

320 ) 

321 d = {} 

322 for i, ci in enumerate(tokens_char_indices): 

323 for j, cj in enumerate(tokens_sorted_by_length): 

324 if ci == cj: 

325 d[j] = i 

326 continue 

327 chars2_length = [len(c) for c in tokens_sorted_by_length] 

328 longest_token_in_sentence = max(chars2_length) 

329 tokens_mask = torch.zeros( 

330 (len(tokens_sorted_by_length), longest_token_in_sentence), 

331 dtype=torch.long, 

332 device=flair.device, 

333 ) 

334 

335 for i, c in enumerate(tokens_sorted_by_length): 

336 tokens_mask[i, : chars2_length[i]] = torch.tensor( 

337 c, dtype=torch.long, device=flair.device 

338 ) 

339 

340 # chars for rnn processing 

341 chars = tokens_mask 

342 

343 character_embeddings = self.char_embedding(chars).transpose(0, 1) 

344 

345 packed = torch.nn.utils.rnn.pack_padded_sequence( 

346 character_embeddings, chars2_length 

347 ) 

348 

349 lstm_out, self.hidden = self.char_rnn(packed) 

350 

351 outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_out) 

352 outputs = outputs.transpose(0, 1) 

353 chars_embeds_temp = torch.zeros( 

354 (outputs.size(0), outputs.size(2)), 

355 dtype=torch.float, 

356 device=flair.device, 

357 ) 

358 for i, index in enumerate(output_lengths): 

359 chars_embeds_temp[i] = outputs[i, index - 1] 

360 character_embeddings = chars_embeds_temp.clone() 

361 for i in range(character_embeddings.size(0)): 

362 character_embeddings[d[i]] = chars_embeds_temp[i] 

363 

364 for token_number, token in enumerate(sentence.tokens): 

365 token.set_embedding(self.name, character_embeddings[token_number]) 

366 

367 def __str__(self): 

368 return self.name 

369 

370 

371class FlairEmbeddings(TokenEmbeddings): 

372 """Contextual string embeddings of words, as proposed in Akbik et al., 2018.""" 

373 

374 def __init__(self, 

375 model, 

376 fine_tune: bool = False, 

377 chars_per_chunk: int = 512, 

378 with_whitespace: bool = True, 

379 tokenized_lm: bool = True, 

380 is_lower: bool = False, 

381 ): 

382 """ 

383 initializes contextual string embeddings using a character-level language model. 

384 :param model: model string, one of 'news-forward', 'news-backward', 'news-forward-fast', 'news-backward-fast', 

385 'mix-forward', 'mix-backward', 'german-forward', 'german-backward', 'polish-backward', 'polish-forward', 

386 etc (see https://github.com/flairNLP/flair/blob/master/resources/docs/embeddings/FLAIR_EMBEDDINGS.md) 

387 depending on which character language model is desired. 

388 :param fine_tune: if set to True, the gradient will propagate into the language model. This dramatically slows 

389 down training and often leads to overfitting, so use with caution. 

390 :param chars_per_chunk: max number of chars per rnn pass to control speed/memory tradeoff. Higher means faster 

391 but requires more memory. Lower means slower but less memory. 

392 :param with_whitespace: If True, use hidden state after whitespace after word. If False, use hidden 

393 state at last character of word. 

394 :param tokenized_lm: Whether this lm is tokenized. Default is True, but for LMs trained over unprocessed text 

395 False might be better. 

396 """ 

397 super().__init__() 

398 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

399 

400 cache_dir = Path("embeddings") 

401 

402 hu_path: str = "https://flair.informatik.hu-berlin.de/resources/embeddings/flair" 

403 clef_hipe_path: str = "https://files.ifi.uzh.ch/cl/siclemat/impresso/clef-hipe-2020/flair" 

404 

405 self.is_lower: bool = is_lower 

406 

407 self.PRETRAINED_MODEL_ARCHIVE_MAP = { 

408 # multilingual models 

409 "multi-forward": f"{hu_path}/lm-jw300-forward-v0.1.pt", 

410 "multi-backward": f"{hu_path}/lm-jw300-backward-v0.1.pt", 

411 "multi-v0-forward": f"{hu_path}/lm-multi-forward-v0.1.pt", 

412 "multi-v0-backward": f"{hu_path}/lm-multi-backward-v0.1.pt", 

413 "multi-forward-fast": f"{hu_path}/lm-multi-forward-fast-v0.1.pt", 

414 "multi-backward-fast": f"{hu_path}/lm-multi-backward-fast-v0.1.pt", 

415 # English models 

416 "en-forward": f"{hu_path}/news-forward-0.4.1.pt", 

417 "en-backward": f"{hu_path}/news-backward-0.4.1.pt", 

418 "en-forward-fast": f"{hu_path}/lm-news-english-forward-1024-v0.2rc.pt", 

419 "en-backward-fast": f"{hu_path}/lm-news-english-backward-1024-v0.2rc.pt", 

420 "news-forward": f"{hu_path}/news-forward-0.4.1.pt", 

421 "news-backward": f"{hu_path}/news-backward-0.4.1.pt", 

422 "news-forward-fast": f"{hu_path}/lm-news-english-forward-1024-v0.2rc.pt", 

423 "news-backward-fast": f"{hu_path}/lm-news-english-backward-1024-v0.2rc.pt", 

424 "mix-forward": f"{hu_path}/lm-mix-english-forward-v0.2rc.pt", 

425 "mix-backward": f"{hu_path}/lm-mix-english-backward-v0.2rc.pt", 

426 # Arabic 

427 "ar-forward": f"{hu_path}/lm-ar-opus-large-forward-v0.1.pt", 

428 "ar-backward": f"{hu_path}/lm-ar-opus-large-backward-v0.1.pt", 

429 # Bulgarian 

430 "bg-forward-fast": f"{hu_path}/lm-bg-small-forward-v0.1.pt", 

431 "bg-backward-fast": f"{hu_path}/lm-bg-small-backward-v0.1.pt", 

432 "bg-forward": f"{hu_path}/lm-bg-opus-large-forward-v0.1.pt", 

433 "bg-backward": f"{hu_path}/lm-bg-opus-large-backward-v0.1.pt", 

434 # Czech 

435 "cs-forward": f"{hu_path}/lm-cs-opus-large-forward-v0.1.pt", 

436 "cs-backward": f"{hu_path}/lm-cs-opus-large-backward-v0.1.pt", 

437 "cs-v0-forward": f"{hu_path}/lm-cs-large-forward-v0.1.pt", 

438 "cs-v0-backward": f"{hu_path}/lm-cs-large-backward-v0.1.pt", 

439 # Danish 

440 "da-forward": f"{hu_path}/lm-da-opus-large-forward-v0.1.pt", 

441 "da-backward": f"{hu_path}/lm-da-opus-large-backward-v0.1.pt", 

442 # German 

443 "de-forward": f"{hu_path}/lm-mix-german-forward-v0.2rc.pt", 

444 "de-backward": f"{hu_path}/lm-mix-german-backward-v0.2rc.pt", 

445 "de-historic-ha-forward": f"{hu_path}/lm-historic-hamburger-anzeiger-forward-v0.1.pt", 

446 "de-historic-ha-backward": f"{hu_path}/lm-historic-hamburger-anzeiger-backward-v0.1.pt", 

447 "de-historic-wz-forward": f"{hu_path}/lm-historic-wiener-zeitung-forward-v0.1.pt", 

448 "de-historic-wz-backward": f"{hu_path}/lm-historic-wiener-zeitung-backward-v0.1.pt", 

449 "de-historic-rw-forward": f"{hu_path}/redewiedergabe_lm_forward.pt", 

450 "de-historic-rw-backward": f"{hu_path}/redewiedergabe_lm_backward.pt", 

451 # Spanish 

452 "es-forward": f"{hu_path}/lm-es-forward.pt", 

453 "es-backward": f"{hu_path}/lm-es-backward.pt", 

454 "es-forward-fast": f"{hu_path}/lm-es-forward-fast.pt", 

455 "es-backward-fast": f"{hu_path}/lm-es-backward-fast.pt", 

456 # Basque 

457 "eu-forward": f"{hu_path}/lm-eu-opus-large-forward-v0.2.pt", 

458 "eu-backward": f"{hu_path}/lm-eu-opus-large-backward-v0.2.pt", 

459 "eu-v1-forward": f"{hu_path}/lm-eu-opus-large-forward-v0.1.pt", 

460 "eu-v1-backward": f"{hu_path}/lm-eu-opus-large-backward-v0.1.pt", 

461 "eu-v0-forward": f"{hu_path}/lm-eu-large-forward-v0.1.pt", 

462 "eu-v0-backward": f"{hu_path}/lm-eu-large-backward-v0.1.pt", 

463 # Persian 

464 "fa-forward": f"{hu_path}/lm-fa-opus-large-forward-v0.1.pt", 

465 "fa-backward": f"{hu_path}/lm-fa-opus-large-backward-v0.1.pt", 

466 # Finnish 

467 "fi-forward": f"{hu_path}/lm-fi-opus-large-forward-v0.1.pt", 

468 "fi-backward": f"{hu_path}/lm-fi-opus-large-backward-v0.1.pt", 

469 # French 

470 "fr-forward": f"{hu_path}/lm-fr-charlm-forward.pt", 

471 "fr-backward": f"{hu_path}/lm-fr-charlm-backward.pt", 

472 # Hebrew 

473 "he-forward": f"{hu_path}/lm-he-opus-large-forward-v0.1.pt", 

474 "he-backward": f"{hu_path}/lm-he-opus-large-backward-v0.1.pt", 

475 # Hindi 

476 "hi-forward": f"{hu_path}/lm-hi-opus-large-forward-v0.1.pt", 

477 "hi-backward": f"{hu_path}/lm-hi-opus-large-backward-v0.1.pt", 

478 # Croatian 

479 "hr-forward": f"{hu_path}/lm-hr-opus-large-forward-v0.1.pt", 

480 "hr-backward": f"{hu_path}/lm-hr-opus-large-backward-v0.1.pt", 

481 # Indonesian 

482 "id-forward": f"{hu_path}/lm-id-opus-large-forward-v0.1.pt", 

483 "id-backward": f"{hu_path}/lm-id-opus-large-backward-v0.1.pt", 

484 # Italian 

485 "it-forward": f"{hu_path}/lm-it-opus-large-forward-v0.1.pt", 

486 "it-backward": f"{hu_path}/lm-it-opus-large-backward-v0.1.pt", 

487 # Japanese 

488 "ja-forward": f"{hu_path}/japanese-forward.pt", 

489 "ja-backward": f"{hu_path}/japanese-backward.pt", 

490 # Malayalam 

491 "ml-forward": f"https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/ml-forward.pt", 

492 "ml-backward": f"https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/ml-backward.pt", 

493 # Dutch 

494 "nl-forward": f"{hu_path}/lm-nl-opus-large-forward-v0.1.pt", 

495 "nl-backward": f"{hu_path}/lm-nl-opus-large-backward-v0.1.pt", 

496 "nl-v0-forward": f"{hu_path}/lm-nl-large-forward-v0.1.pt", 

497 "nl-v0-backward": f"{hu_path}/lm-nl-large-backward-v0.1.pt", 

498 # Norwegian 

499 "no-forward": f"{hu_path}/lm-no-opus-large-forward-v0.1.pt", 

500 "no-backward": f"{hu_path}/lm-no-opus-large-backward-v0.1.pt", 

501 # Polish 

502 "pl-forward": f"{hu_path}/lm-polish-forward-v0.2.pt", 

503 "pl-backward": f"{hu_path}/lm-polish-backward-v0.2.pt", 

504 "pl-opus-forward": f"{hu_path}/lm-pl-opus-large-forward-v0.1.pt", 

505 "pl-opus-backward": f"{hu_path}/lm-pl-opus-large-backward-v0.1.pt", 

506 # Portuguese 

507 "pt-forward": f"{hu_path}/lm-pt-forward.pt", 

508 "pt-backward": f"{hu_path}/lm-pt-backward.pt", 

509 # Pubmed 

510 "pubmed-forward": f"{hu_path}/pubmed-forward.pt", 

511 "pubmed-backward": f"{hu_path}/pubmed-backward.pt", 

512 "pubmed-2015-forward": f"{hu_path}/pubmed-2015-fw-lm.pt", 

513 "pubmed-2015-backward": f"{hu_path}/pubmed-2015-bw-lm.pt", 

514 # Slovenian 

515 "sl-forward": f"{hu_path}/lm-sl-opus-large-forward-v0.1.pt", 

516 "sl-backward": f"{hu_path}/lm-sl-opus-large-backward-v0.1.pt", 

517 "sl-v0-forward": f"{hu_path}/lm-sl-large-forward-v0.1.pt", 

518 "sl-v0-backward": f"{hu_path}/lm-sl-large-backward-v0.1.pt", 

519 # Swedish 

520 "sv-forward": f"{hu_path}/lm-sv-opus-large-forward-v0.1.pt", 

521 "sv-backward": f"{hu_path}/lm-sv-opus-large-backward-v0.1.pt", 

522 "sv-v0-forward": f"{hu_path}/lm-sv-large-forward-v0.1.pt", 

523 "sv-v0-backward": f"{hu_path}/lm-sv-large-backward-v0.1.pt", 

524 # Tamil 

525 "ta-forward": f"{hu_path}/lm-ta-opus-large-forward-v0.1.pt", 

526 "ta-backward": f"{hu_path}/lm-ta-opus-large-backward-v0.1.pt", 

527 # Spanish clinical 

528 "es-clinical-forward": f"{hu_path}/es-clinical-forward.pt", 

529 "es-clinical-backward": f"{hu_path}/es-clinical-backward.pt", 

530 # CLEF HIPE Shared task 

531 "de-impresso-hipe-v1-forward": f"{clef_hipe_path}/de-hipe-flair-v1-forward/best-lm.pt", 

532 "de-impresso-hipe-v1-backward": f"{clef_hipe_path}/de-hipe-flair-v1-backward/best-lm.pt", 

533 "en-impresso-hipe-v1-forward": f"{clef_hipe_path}/en-flair-v1-forward/best-lm.pt", 

534 "en-impresso-hipe-v1-backward": f"{clef_hipe_path}/en-flair-v1-backward/best-lm.pt", 

535 "fr-impresso-hipe-v1-forward": f"{clef_hipe_path}/fr-hipe-flair-v1-forward/best-lm.pt", 

536 "fr-impresso-hipe-v1-backward": f"{clef_hipe_path}/fr-hipe-flair-v1-backward/best-lm.pt", 

537 } 

538 

539 if type(model) == str: 

540 

541 # load model if in pretrained model map 

542 if model.lower() in self.PRETRAINED_MODEL_ARCHIVE_MAP: 

543 base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[model.lower()] 

544 

545 # Fix for CLEF HIPE models (avoid overwriting best-lm.pt in cache_dir) 

546 if "impresso-hipe" in model.lower(): 

547 cache_dir = cache_dir / model.lower() 

548 # CLEF HIPE models are lowercased 

549 self.is_lower = True 

550 model = cached_path(base_path, cache_dir=cache_dir) 

551 

552 elif replace_with_language_code(model) in self.PRETRAINED_MODEL_ARCHIVE_MAP: 

553 base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[ 

554 replace_with_language_code(model) 

555 ] 

556 model = cached_path(base_path, cache_dir=cache_dir) 

557 

558 elif not Path(model).exists(): 

559 raise ValueError( 

560 f'The given model "{model}" is not available or is not a valid path.' 

561 ) 

562 

563 from flair.models import LanguageModel 

564 

565 if type(model) == LanguageModel: 

566 self.lm: LanguageModel = model 

567 self.name = f"Task-LSTM-{self.lm.hidden_size}-{self.lm.nlayers}-{self.lm.is_forward_lm}" 

568 else: 

569 self.lm: LanguageModel = LanguageModel.load_language_model(model) 

570 self.name = str(model) 

571 

572 # embeddings are static if we don't do finetuning 

573 self.fine_tune = fine_tune 

574 self.static_embeddings = not fine_tune 

575 

576 self.is_forward_lm: bool = self.lm.is_forward_lm 

577 self.with_whitespace: bool = with_whitespace 

578 self.tokenized_lm: bool = tokenized_lm 

579 self.chars_per_chunk: int = chars_per_chunk 

580 

581 # embed a dummy sentence to determine embedding_length 

582 dummy_sentence: Sentence = Sentence() 

583 dummy_sentence.add_token(Token("hello")) 

584 embedded_dummy = self.embed(dummy_sentence) 

585 self.__embedding_length: int = len( 

586 embedded_dummy[0].get_token(1).get_embedding() 

587 ) 

588 

589 # set to eval mode 

590 self.eval() 

591 

592 def train(self, mode=True): 

593 

594 # make compatible with serialized models (TODO: remove) 

595 if "fine_tune" not in self.__dict__: 

596 self.fine_tune = False 

597 if "chars_per_chunk" not in self.__dict__: 

598 self.chars_per_chunk = 512 

599 

600 # unless fine-tuning is set, do not set language model to train() in order to disallow language model dropout 

601 if not self.fine_tune: 

602 pass 

603 else: 

604 super(FlairEmbeddings, self).train(mode) 

605 

606 @property 

607 def embedding_length(self) -> int: 

608 return self.__embedding_length 

609 

610 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

611 

612 # make compatible with serialized models (TODO: remove) 

613 if "with_whitespace" not in self.__dict__: 

614 self.with_whitespace = True 

615 if "tokenized_lm" not in self.__dict__: 

616 self.tokenized_lm = True 

617 if "is_lower" not in self.__dict__: 

618 self.is_lower = False 

619 

620 # gradients are enable if fine-tuning is enabled 

621 gradient_context = torch.enable_grad() if self.fine_tune else torch.no_grad() 

622 

623 with gradient_context: 

624 

625 # if this is not possible, use LM to generate embedding. First, get text sentences 

626 text_sentences = [sentence.to_tokenized_string() for sentence in sentences] if self.tokenized_lm \ 

627 else [sentence.to_plain_string() for sentence in sentences] 

628 

629 if self.is_lower: 

630 text_sentences = [sentence.lower() for sentence in text_sentences] 

631 

632 start_marker = self.lm.document_delimiter if "document_delimiter" in self.lm.__dict__ else '\n' 

633 end_marker = " " 

634 

635 # get hidden states from language model 

636 all_hidden_states_in_lm = self.lm.get_representation( 

637 text_sentences, start_marker, end_marker, self.chars_per_chunk 

638 ) 

639 

640 if not self.fine_tune: 

641 all_hidden_states_in_lm = all_hidden_states_in_lm.detach() 

642 

643 # take first or last hidden states from language model as word representation 

644 for i, sentence in enumerate(sentences): 

645 sentence_text = sentence.to_tokenized_string() if self.tokenized_lm else sentence.to_plain_string() 

646 

647 offset_forward: int = len(start_marker) 

648 offset_backward: int = len(sentence_text) + len(start_marker) 

649 

650 for token in sentence.tokens: 

651 

652 offset_forward += len(token.text) 

653 if self.is_forward_lm: 

654 offset_with_whitespace = offset_forward 

655 offset_without_whitespace = offset_forward - 1 

656 else: 

657 offset_with_whitespace = offset_backward 

658 offset_without_whitespace = offset_backward - 1 

659 

660 # offset mode that extracts at whitespace after last character 

661 if self.with_whitespace: 

662 embedding = all_hidden_states_in_lm[offset_with_whitespace, i, :] 

663 # offset mode that extracts at last character 

664 else: 

665 embedding = all_hidden_states_in_lm[offset_without_whitespace, i, :] 

666 

667 if self.tokenized_lm or token.whitespace_after: 

668 offset_forward += 1 

669 offset_backward -= 1 

670 

671 offset_backward -= len(token.text) 

672 

673 # only clone if optimization mode is 'gpu' 

674 if flair.embedding_storage_mode == "gpu": 

675 embedding = embedding.clone() 

676 

677 token.set_embedding(self.name, embedding) 

678 

679 del all_hidden_states_in_lm 

680 

681 return sentences 

682 

683 def __str__(self): 

684 return self.name 

685 

686 

687class PooledFlairEmbeddings(TokenEmbeddings): 

688 def __init__( 

689 self, 

690 contextual_embeddings: Union[str, FlairEmbeddings], 

691 pooling: str = "min", 

692 only_capitalized: bool = False, 

693 **kwargs, 

694 ): 

695 

696 super().__init__() 

697 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

698 

699 # use the character language model embeddings as basis 

700 if type(contextual_embeddings) is str: 

701 self.context_embeddings: FlairEmbeddings = FlairEmbeddings( 

702 contextual_embeddings, **kwargs 

703 ) 

704 else: 

705 self.context_embeddings: FlairEmbeddings = contextual_embeddings 

706 

707 # length is twice the original character LM embedding length 

708 self.embedding_length = self.context_embeddings.embedding_length * 2 

709 self.name = self.context_embeddings.name + "-context" 

710 

711 # these fields are for the embedding memory 

712 self.word_embeddings = {} 

713 self.word_count = {} 

714 

715 # whether to add only capitalized words to memory (faster runtime and lower memory consumption) 

716 self.only_capitalized = only_capitalized 

717 

718 # we re-compute embeddings dynamically at each epoch 

719 self.static_embeddings = False 

720 

721 # set the memory method 

722 self.pooling = pooling 

723 

724 def train(self, mode=True): 

725 super().train(mode=mode) 

726 if mode: 

727 # memory is wiped each time we do a training run 

728 print("train mode resetting embeddings") 

729 self.word_embeddings = {} 

730 self.word_count = {} 

731 

732 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

733 

734 self.context_embeddings.embed(sentences) 

735 

736 # if we keep a pooling, it needs to be updated continuously 

737 for sentence in sentences: 

738 for token in sentence.tokens: 

739 

740 # update embedding 

741 local_embedding = token._embeddings[self.context_embeddings.name].cpu() 

742 

743 # check token.text is empty or not 

744 if token.text: 

745 if token.text[0].isupper() or not self.only_capitalized: 

746 

747 if token.text not in self.word_embeddings: 

748 self.word_embeddings[token.text] = local_embedding 

749 self.word_count[token.text] = 1 

750 else: 

751 

752 # set aggregation operation 

753 if self.pooling == "mean": 

754 aggregated_embedding = torch.add(self.word_embeddings[token.text], local_embedding) 

755 elif self.pooling == "fade": 

756 aggregated_embedding = torch.add(self.word_embeddings[token.text], local_embedding) 

757 aggregated_embedding /= 2 

758 elif self.pooling == "max": 

759 aggregated_embedding = torch.max(self.word_embeddings[token.text], local_embedding) 

760 elif self.pooling == "min": 

761 aggregated_embedding = torch.min(self.word_embeddings[token.text], local_embedding) 

762 

763 self.word_embeddings[token.text] = aggregated_embedding 

764 self.word_count[token.text] += 1 

765 

766 # add embeddings after updating 

767 for sentence in sentences: 

768 for token in sentence.tokens: 

769 if token.text in self.word_embeddings: 

770 base = ( 

771 self.word_embeddings[token.text] / self.word_count[token.text] 

772 if self.pooling == "mean" 

773 else self.word_embeddings[token.text] 

774 ) 

775 else: 

776 base = token._embeddings[self.context_embeddings.name] 

777 

778 token.set_embedding(self.name, base) 

779 

780 return sentences 

781 

782 def embedding_length(self) -> int: 

783 return self.embedding_length 

784 

785 def get_names(self) -> List[str]: 

786 return [self.name, self.context_embeddings.name] 

787 

788 def __setstate__(self, d): 

789 self.__dict__ = d 

790 

791 if flair.device != 'cpu': 

792 for key in self.word_embeddings: 

793 self.word_embeddings[key] = self.word_embeddings[key].cpu() 

794 

795 

796class TransformerWordEmbeddings(TokenEmbeddings): 

797 NO_MAX_SEQ_LENGTH_MODELS = [XLNetModel, TransfoXLModel] 

798 

799 def __init__( 

800 self, 

801 model: str = "bert-base-uncased", 

802 layers: str = "all", 

803 subtoken_pooling: str = "first", 

804 layer_mean: bool = True, 

805 fine_tune: bool = False, 

806 allow_long_sentences: bool = True, 

807 use_context: Union[bool, int] = False, 

808 memory_effective_training: bool = True, 

809 respect_document_boundaries: bool = True, 

810 context_dropout: float = 0.5, 

811 **kwargs 

812 ): 

813 """ 

814 Bidirectional transformer embeddings of words from various transformer architectures. 

815 :param model: name of transformer model (see https://huggingface.co/transformers/pretrained_models.html for 

816 options) 

817 :param layers: string indicating which layers to take for embedding (-1 is topmost layer) 

818 :param subtoken_pooling: how to get from token piece embeddings to token embedding. Either take the first 

819 subtoken ('first'), the last subtoken ('last'), both first and last ('first_last') or a mean over all ('mean') 

820 :param layer_mean: If True, uses a scalar mix of layers as embedding 

821 :param fine_tune: If True, allows transformers to be fine-tuned during training 

822 """ 

823 super().__init__() 

824 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

825 

826 # temporary fix to disable tokenizer parallelism warning 

827 # (see https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning) 

828 import os 

829 os.environ["TOKENIZERS_PARALLELISM"] = "false" 

830 

831 # do not print transformer warnings as these are confusing in this case 

832 from transformers import logging 

833 logging.set_verbosity_error() 

834 

835 # load tokenizer and transformer model 

836 self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model, **kwargs) 

837 if not 'config' in kwargs: 

838 config = AutoConfig.from_pretrained(model, output_hidden_states=True, **kwargs) 

839 self.model = AutoModel.from_pretrained(model, config=config) 

840 else: 

841 self.model = AutoModel.from_pretrained(None, **kwargs) 

842 

843 logging.set_verbosity_warning() 

844 

845 if type(self.model) not in self.NO_MAX_SEQ_LENGTH_MODELS: 

846 self.allow_long_sentences = allow_long_sentences 

847 self.truncate = True 

848 self.max_subtokens_sequence_length = self.tokenizer.model_max_length 

849 self.stride = self.tokenizer.model_max_length // 2 if allow_long_sentences else 0 

850 else: 

851 # in the end, these models don't need this configuration 

852 self.allow_long_sentences = False 

853 self.truncate = False 

854 self.max_subtokens_sequence_length = None 

855 self.stride = 0 

856 

857 self.use_lang_emb = hasattr(self.model, "use_lang_emb") and self.model.use_lang_emb 

858 

859 # model name 

860 self.name = 'transformer-word-' + str(model) 

861 self.base_model = str(model) 

862 

863 # whether to detach gradients on overlong sentences 

864 self.memory_effective_training = memory_effective_training 

865 

866 # store whether to use context (and how much) 

867 if type(use_context) == bool: 

868 self.context_length: int = 64 if use_context else 0 

869 if type(use_context) == int: 

870 self.context_length: int = use_context 

871 

872 # dropout contexts 

873 self.context_dropout = context_dropout 

874 

875 # if using context, can we cross document boundaries? 

876 self.respect_document_boundaries = respect_document_boundaries 

877 

878 # send self to flair-device 

879 self.to(flair.device) 

880 

881 # embedding parameters 

882 if layers == 'all': 

883 # send mini-token through to check how many layers the model has 

884 hidden_states = self.model(torch.tensor([1], device=flair.device).unsqueeze(0))[-1] 

885 self.layer_indexes = [int(x) for x in range(len(hidden_states))] 

886 else: 

887 self.layer_indexes = [int(x) for x in layers.split(",")] 

888 

889 self.pooling_operation = subtoken_pooling 

890 self.layer_mean = layer_mean 

891 self.fine_tune = fine_tune 

892 self.static_embeddings = not self.fine_tune 

893 

894 # calculate embedding length 

895 if not self.layer_mean: 

896 length = len(self.layer_indexes) * self.model.config.hidden_size 

897 else: 

898 length = self.model.config.hidden_size 

899 if self.pooling_operation == 'first_last': length *= 2 

900 

901 # return length 

902 self.embedding_length_internal = length 

903 

904 self.special_tokens = [] 

905 # check if special tokens exist to circumvent error message 

906 if self.tokenizer._bos_token: 

907 self.special_tokens.append(self.tokenizer.bos_token) 

908 if self.tokenizer._cls_token: 

909 self.special_tokens.append(self.tokenizer.cls_token) 

910 

911 # most models have an intial BOS token, except for XLNet, T5 and GPT2 

912 self.begin_offset = self._get_begin_offset_of_tokenizer(tokenizer=self.tokenizer) 

913 

914 # when initializing, embeddings are in eval mode by default 

915 self.eval() 

916 

917 @staticmethod 

918 def _get_begin_offset_of_tokenizer(tokenizer: PreTrainedTokenizer) -> int: 

919 test_string = 'a' 

920 tokens = tokenizer.encode(test_string) 

921 

922 for begin_offset, token in enumerate(tokens): 

923 if tokenizer.decode([token]) == test_string or tokenizer.decode([token]) == tokenizer.unk_token: 

924 break 

925 return begin_offset 

926 

927 @staticmethod 

928 def _remove_special_markup(text: str): 

929 # remove special markup 

930 text = re.sub('^Ġ', '', text) # RoBERTa models 

931 text = re.sub('^##', '', text) # BERT models 

932 text = re.sub('^▁', '', text) # XLNet models 

933 text = re.sub('</w>$', '', text) # XLM models 

934 return text 

935 

936 def _get_processed_token_text(self, token: Token) -> str: 

937 pieces = self.tokenizer.tokenize(token.text) 

938 token_text = '' 

939 for piece in pieces: 

940 token_text += self._remove_special_markup(piece) 

941 token_text = token_text.lower() 

942 return token_text 

943 

944 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

945 

946 # we require encoded subtokenized sentences, the mapping to original tokens and the number of 

947 # parts that each sentence produces 

948 subtokenized_sentences = [] 

949 all_token_subtoken_lengths = [] 

950 

951 # if we also use context, first expand sentence to include context 

952 if self.context_length > 0: 

953 

954 # set context if not set already 

955 previous_sentence = None 

956 for sentence in sentences: 

957 if sentence.is_context_set(): continue 

958 sentence._previous_sentence = previous_sentence 

959 sentence._next_sentence = None 

960 if previous_sentence: previous_sentence._next_sentence = sentence 

961 previous_sentence = sentence 

962 

963 original_sentences = [] 

964 expanded_sentences = [] 

965 context_offsets = [] 

966 

967 for sentence in sentences: 

968 # in case of contextualization, we must remember non-expanded sentence 

969 original_sentence = sentence 

970 original_sentences.append(original_sentence) 

971 

972 # create expanded sentence and remember context offsets 

973 expanded_sentence, context_offset = self._expand_sentence_with_context(sentence) 

974 expanded_sentences.append(expanded_sentence) 

975 context_offsets.append(context_offset) 

976 

977 # overwrite sentence with expanded sentence 

978 sentence = expanded_sentence 

979 

980 sentences = expanded_sentences 

981 

982 tokenized_sentences = [] 

983 for sentence in sentences: 

984 

985 # subtokenize the sentence 

986 tokenized_string = sentence.to_tokenized_string() 

987 

988 # transformer specific tokenization 

989 subtokenized_sentence = self.tokenizer.tokenize(tokenized_string) 

990 

991 # set zero embeddings for empty sentences and exclude 

992 if len(subtokenized_sentence) == 0: 

993 for token in sentence: 

994 token.set_embedding(self.name, torch.zeros(self.embedding_length)) 

995 continue 

996 

997 # determine into how many subtokens each token is split 

998 token_subtoken_lengths = self.reconstruct_tokens_from_subtokens(sentence, subtokenized_sentence) 

999 

1000 # remember tokenized sentences and their subtokenization 

1001 tokenized_sentences.append(tokenized_string) 

1002 all_token_subtoken_lengths.append(token_subtoken_lengths) 

1003 

1004 # encode inputs 

1005 batch_encoding = self.tokenizer(tokenized_sentences, 

1006 max_length=self.max_subtokens_sequence_length, 

1007 stride=self.stride, 

1008 return_overflowing_tokens=self.allow_long_sentences, 

1009 truncation=self.truncate, 

1010 padding=True, 

1011 return_tensors='pt', 

1012 ) 

1013 

1014 input_ids = batch_encoding['input_ids'].to(flair.device) 

1015 attention_mask = batch_encoding['attention_mask'].to(flair.device) 

1016 

1017 # determine which sentence was split into how many parts 

1018 sentence_parts_lengths = torch.ones(len(tokenized_sentences), dtype=torch.int) if not self.allow_long_sentences \ 

1019 else torch.unique(batch_encoding['overflow_to_sample_mapping'], return_counts=True, sorted=True)[1].tolist() 

1020 

1021 model_kwargs = {} 

1022 # set language IDs for XLM-style transformers 

1023 if self.use_lang_emb: 

1024 model_kwargs["langs"] = torch.zeros_like(input_ids, dtype=input_ids.dtype) 

1025 

1026 for s_id, sentence in enumerate(tokenized_sentences): 

1027 sequence_length = len(sentence) 

1028 lang_id = self.tokenizer.lang2id.get(sentences[s_id].get_language_code(), 0) 

1029 model_kwargs["langs"][s_id][:sequence_length] = lang_id 

1030 

1031 # put encoded batch through transformer model to get all hidden states of all encoder layers 

1032 hidden_states = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)[-1] 

1033 # make the tuple a tensor; makes working with it easier. 

1034 hidden_states = torch.stack(hidden_states) 

1035 

1036 sentence_idx_offset = 0 

1037 

1038 # gradients are enabled if fine-tuning is enabled 

1039 gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad() 

1040 

1041 with gradient_context: 

1042 

1043 # iterate over all subtokenized sentences 

1044 for sentence_idx, (sentence, subtoken_lengths, nr_sentence_parts) in enumerate( 

1045 zip(sentences, all_token_subtoken_lengths, sentence_parts_lengths)): 

1046 

1047 sentence_hidden_state = hidden_states[:, sentence_idx + sentence_idx_offset, ...] 

1048 

1049 for i in range(1, nr_sentence_parts): 

1050 sentence_idx_offset += 1 

1051 remainder_sentence_hidden_state = hidden_states[:, sentence_idx + sentence_idx_offset, ...] 

1052 # remove stride_size//2 at end of sentence_hidden_state, and half at beginning of remainder, 

1053 # in order to get some context into the embeddings of these words. 

1054 # also don't include the embedding of the extra [CLS] and [SEP] tokens. 

1055 sentence_hidden_state = torch.cat((sentence_hidden_state[:, :-1 - self.stride // 2, :], 

1056 remainder_sentence_hidden_state[:, 1 + self.stride // 2:, 

1057 :]), 1) 

1058 

1059 subword_start_idx = self.begin_offset 

1060 

1061 # for each token, get embedding 

1062 for token_idx, (token, number_of_subtokens) in enumerate(zip(sentence, subtoken_lengths)): 

1063 

1064 # some tokens have no subtokens at all (if omitted by BERT tokenizer) so return zero vector 

1065 if number_of_subtokens == 0: 

1066 token.set_embedding(self.name, torch.zeros(self.embedding_length)) 

1067 continue 

1068 

1069 subword_end_idx = subword_start_idx + number_of_subtokens 

1070 

1071 subtoken_embeddings: List[torch.FloatTensor] = [] 

1072 

1073 # get states from all selected layers, aggregate with pooling operation 

1074 for layer in self.layer_indexes: 

1075 current_embeddings = sentence_hidden_state[layer][subword_start_idx:subword_end_idx] 

1076 

1077 if self.pooling_operation == "first": 

1078 final_embedding: torch.FloatTensor = current_embeddings[0] 

1079 

1080 if self.pooling_operation == "last": 

1081 final_embedding: torch.FloatTensor = current_embeddings[-1] 

1082 

1083 if self.pooling_operation == "first_last": 

1084 final_embedding: torch.Tensor = torch.cat( 

1085 [current_embeddings[0], current_embeddings[-1]]) 

1086 

1087 if self.pooling_operation == "mean": 

1088 all_embeddings: List[torch.FloatTensor] = [ 

1089 embedding.unsqueeze(0) for embedding in current_embeddings 

1090 ] 

1091 final_embedding: torch.Tensor = torch.mean(torch.cat(all_embeddings, dim=0), dim=0) 

1092 

1093 subtoken_embeddings.append(final_embedding) 

1094 

1095 # use layer mean of embeddings if so selected 

1096 if self.layer_mean and len(self.layer_indexes) > 1: 

1097 sm_embeddings = torch.mean(torch.stack(subtoken_embeddings, dim=1), dim=1) 

1098 subtoken_embeddings = [sm_embeddings] 

1099 

1100 # set the extracted embedding for the token 

1101 token.set_embedding(self.name, torch.cat(subtoken_embeddings)) 

1102 

1103 subword_start_idx += number_of_subtokens 

1104 

1105 # move embeddings from context back to original sentence (if using context) 

1106 if self.context_length > 0: 

1107 for original_sentence, expanded_sentence, context_offset in zip(original_sentences, 

1108 sentences, 

1109 context_offsets): 

1110 for token_idx, token in enumerate(original_sentence): 

1111 token.set_embedding(self.name, 

1112 expanded_sentence[token_idx + context_offset].get_embedding(self.name)) 

1113 sentence = original_sentence 

1114 

1115 def _expand_sentence_with_context(self, sentence): 

1116 

1117 # remember original sentence 

1118 original_sentence = sentence 

1119 

1120 import random 

1121 expand_context = False if self.training and random.randint(1, 100) <= (self.context_dropout * 100) else True 

1122 

1123 left_context = '' 

1124 right_context = '' 

1125 

1126 if expand_context: 

1127 

1128 # get left context 

1129 while True: 

1130 sentence = sentence.previous_sentence() 

1131 if sentence is None: break 

1132 

1133 if self.respect_document_boundaries and sentence.is_document_boundary: break 

1134 

1135 left_context = sentence.to_tokenized_string() + ' ' + left_context 

1136 left_context = left_context.strip() 

1137 if len(left_context.split(" ")) > self.context_length: 

1138 left_context = " ".join(left_context.split(" ")[-self.context_length:]) 

1139 break 

1140 original_sentence.left_context = left_context 

1141 

1142 sentence = original_sentence 

1143 

1144 # get right context 

1145 while True: 

1146 sentence = sentence.next_sentence() 

1147 if sentence is None: break 

1148 if self.respect_document_boundaries and sentence.is_document_boundary: break 

1149 

1150 right_context += ' ' + sentence.to_tokenized_string() 

1151 right_context = right_context.strip() 

1152 if len(right_context.split(" ")) > self.context_length: 

1153 right_context = " ".join(right_context.split(" ")[:self.context_length]) 

1154 break 

1155 

1156 original_sentence.right_context = right_context 

1157 

1158 left_context_split = left_context.split(" ") 

1159 right_context_split = right_context.split(" ") 

1160 

1161 # empty contexts should not introduce whitespace tokens 

1162 if left_context_split == [""]: left_context_split = [] 

1163 if right_context_split == [""]: right_context_split = [] 

1164 

1165 # make expanded sentence 

1166 expanded_sentence = Sentence() 

1167 expanded_sentence.tokens = [Token(token) for token in left_context_split + 

1168 original_sentence.to_tokenized_string().split(" ") + 

1169 right_context_split] 

1170 

1171 context_length = len(left_context_split) 

1172 return expanded_sentence, context_length 

1173 

1174 def reconstruct_tokens_from_subtokens(self, sentence, subtokens): 

1175 word_iterator = iter(sentence) 

1176 token = next(word_iterator) 

1177 token_text = self._get_processed_token_text(token) 

1178 token_subtoken_lengths = [] 

1179 reconstructed_token = '' 

1180 subtoken_count = 0 

1181 # iterate over subtokens and reconstruct tokens 

1182 for subtoken_id, subtoken in enumerate(subtokens): 

1183 

1184 # remove special markup 

1185 subtoken = self._remove_special_markup(subtoken) 

1186 

1187 # TODO check if this is necessary is this method is called before prepare_for_model 

1188 # check if reconstructed token is special begin token ([CLS] or similar) 

1189 if subtoken in self.special_tokens and subtoken_id == 0: 

1190 continue 

1191 

1192 # some BERT tokenizers somehow omit words - in such cases skip to next token 

1193 if subtoken_count == 0 and not token_text.startswith(subtoken.lower()): 

1194 

1195 while True: 

1196 token_subtoken_lengths.append(0) 

1197 token = next(word_iterator) 

1198 token_text = self._get_processed_token_text(token) 

1199 if token_text.startswith(subtoken.lower()): break 

1200 

1201 subtoken_count += 1 

1202 

1203 # append subtoken to reconstruct token 

1204 reconstructed_token = reconstructed_token + subtoken 

1205 

1206 # check if reconstructed token is the same as current token 

1207 if reconstructed_token.lower() == token_text: 

1208 

1209 # if so, add subtoken count 

1210 token_subtoken_lengths.append(subtoken_count) 

1211 

1212 # reset subtoken count and reconstructed token 

1213 reconstructed_token = '' 

1214 subtoken_count = 0 

1215 

1216 # break from loop if all tokens are accounted for 

1217 if len(token_subtoken_lengths) < len(sentence): 

1218 token = next(word_iterator) 

1219 token_text = self._get_processed_token_text(token) 

1220 else: 

1221 break 

1222 

1223 # if tokens are unaccounted for 

1224 while len(token_subtoken_lengths) < len(sentence) and len(token.text) == 1: 

1225 token_subtoken_lengths.append(0) 

1226 if len(token_subtoken_lengths) == len(sentence): break 

1227 token = next(word_iterator) 

1228 

1229 # check if all tokens were matched to subtokens 

1230 if token != sentence[-1]: 

1231 log.error(f"Tokenization MISMATCH in sentence '{sentence.to_tokenized_string()}'") 

1232 log.error(f"Last matched: '{token}'") 

1233 log.error(f"Last sentence: '{sentence[-1]}'") 

1234 log.error(f"subtokenized: '{subtokens}'") 

1235 return token_subtoken_lengths 

1236 

1237 @property 

1238 def embedding_length(self) -> int: 

1239 

1240 if "embedding_length_internal" in self.__dict__.keys(): 

1241 return self.embedding_length_internal 

1242 

1243 # """Returns the length of the embedding vector.""" 

1244 if not self.layer_mean: 

1245 length = len(self.layer_indexes) * self.model.config.hidden_size 

1246 else: 

1247 length = self.model.config.hidden_size 

1248 

1249 if self.pooling_operation == 'first_last': length *= 2 

1250 

1251 self.__embedding_length = length 

1252 

1253 return length 

1254 

1255 def __getstate__(self): 

1256 # special handling for serializing transformer models 

1257 config_state_dict = self.model.config.__dict__ 

1258 model_state_dict = self.model.state_dict() 

1259 

1260 if not hasattr(self, "base_model_name"): self.base_model_name = self.name.split('transformer-word-')[-1] 

1261 

1262 # serialize the transformer models and the constructor arguments (but nothing else) 

1263 model_state = { 

1264 "config_state_dict": config_state_dict, 

1265 "model_state_dict": model_state_dict, 

1266 "embedding_length_internal": self.embedding_length, 

1267 

1268 "base_model_name": self.base_model_name, 

1269 "name": self.name, 

1270 "layer_indexes": self.layer_indexes, 

1271 "subtoken_pooling": self.pooling_operation, 

1272 "context_length": self.context_length, 

1273 "layer_mean": self.layer_mean, 

1274 "fine_tune": self.fine_tune, 

1275 "allow_long_sentences": self.allow_long_sentences, 

1276 "memory_effective_training": self.memory_effective_training, 

1277 "respect_document_boundaries": self.respect_document_boundaries, 

1278 "context_dropout": self.context_dropout, 

1279 } 

1280 

1281 return model_state 

1282 

1283 def __setstate__(self, d): 

1284 self.__dict__ = d 

1285 

1286 # necessary for reverse compatibility with Flair <= 0.7 

1287 if 'use_scalar_mix' in self.__dict__.keys(): 

1288 self.__dict__['layer_mean'] = d['use_scalar_mix'] 

1289 if not 'memory_effective_training' in self.__dict__.keys(): 

1290 self.__dict__['memory_effective_training'] = True 

1291 if 'pooling_operation' in self.__dict__.keys(): 

1292 self.__dict__['subtoken_pooling'] = d['pooling_operation'] 

1293 if not 'context_length' in self.__dict__.keys(): 

1294 self.__dict__['context_length'] = 0 

1295 if 'use_context' in self.__dict__.keys(): 

1296 self.__dict__['context_length'] = 64 if self.__dict__['use_context'] == True else 0 

1297 

1298 if not 'context_dropout' in self.__dict__.keys(): 

1299 self.__dict__['context_dropout'] = 0.5 

1300 if not 'respect_document_boundaries' in self.__dict__.keys(): 

1301 self.__dict__['respect_document_boundaries'] = True 

1302 if not 'memory_effective_training' in self.__dict__.keys(): 

1303 self.__dict__['memory_effective_training'] = True 

1304 if not 'base_model_name' in self.__dict__.keys(): 

1305 self.__dict__['base_model_name'] = self.__dict__['name'].split('transformer-word-')[-1] 

1306 

1307 # special handling for deserializing transformer models 

1308 if "config_state_dict" in d: 

1309 

1310 # load transformer model 

1311 model_type = d["config_state_dict"]["model_type"] if "model_type" in d["config_state_dict"] else "bert" 

1312 config_class = CONFIG_MAPPING[model_type] 

1313 loaded_config = config_class.from_dict(d["config_state_dict"]) 

1314 

1315 # constructor arguments 

1316 layers = ','.join([str(idx) for idx in self.__dict__['layer_indexes']]) 

1317 

1318 # re-initialize transformer word embeddings with constructor arguments 

1319 embedding = TransformerWordEmbeddings( 

1320 model=self.__dict__['base_model_name'], 

1321 layers=layers, 

1322 subtoken_pooling=self.__dict__['subtoken_pooling'], 

1323 use_context=self.__dict__['context_length'], 

1324 layer_mean=self.__dict__['layer_mean'], 

1325 fine_tune=self.__dict__['fine_tune'], 

1326 allow_long_sentences=self.__dict__['allow_long_sentences'], 

1327 respect_document_boundaries=self.__dict__['respect_document_boundaries'], 

1328 memory_effective_training=self.__dict__['memory_effective_training'], 

1329 context_dropout=self.__dict__['context_dropout'], 

1330 

1331 config=loaded_config, 

1332 state_dict=d["model_state_dict"], 

1333 ) 

1334 

1335 # I have no idea why this is necessary, but otherwise it doesn't work 

1336 for key in embedding.__dict__.keys(): 

1337 self.__dict__[key] = embedding.__dict__[key] 

1338 

1339 else: 

1340 

1341 # reload tokenizer to get around serialization issues 

1342 model_name = self.__dict__['name'].split('transformer-word-')[-1] 

1343 try: 

1344 tokenizer = AutoTokenizer.from_pretrained(model_name) 

1345 except: 

1346 pass 

1347 

1348 self.tokenizer = tokenizer 

1349 

1350 

1351class FastTextEmbeddings(TokenEmbeddings): 

1352 """FastText Embeddings with oov functionality""" 

1353 

1354 def __init__(self, embeddings: str, use_local: bool = True, field: str = None): 

1355 """ 

1356 Initializes fasttext word embeddings. Constructor downloads required embedding file and stores in cache 

1357 if use_local is False. 

1358 

1359 :param embeddings: path to your embeddings '.bin' file 

1360 :param use_local: set this to False if you are using embeddings from a remote source 

1361 """ 

1362 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

1363 

1364 cache_dir = Path("embeddings") 

1365 

1366 if use_local: 

1367 if not Path(embeddings).exists(): 

1368 raise ValueError( 

1369 f'The given embeddings "{embeddings}" is not available or is not a valid path.' 

1370 ) 

1371 else: 

1372 embeddings = cached_path(f"{embeddings}", cache_dir=cache_dir) 

1373 

1374 self.embeddings = embeddings 

1375 

1376 self.name: str = str(embeddings) 

1377 

1378 self.static_embeddings = True 

1379 

1380 self.precomputed_word_embeddings = gensim.models.FastText.load_fasttext_format( 

1381 str(embeddings) 

1382 ) 

1383 

1384 self.__embedding_length: int = self.precomputed_word_embeddings.vector_size 

1385 

1386 self.field = field 

1387 super().__init__() 

1388 

1389 @property 

1390 def embedding_length(self) -> int: 

1391 return self.__embedding_length 

1392 

1393 @instance_lru_cache(maxsize=10000, typed=False) 

1394 def get_cached_vec(self, word: str) -> torch.Tensor: 

1395 try: 

1396 word_embedding = self.precomputed_word_embeddings[word] 

1397 except: 

1398 word_embedding = np.zeros(self.embedding_length, dtype="float") 

1399 

1400 word_embedding = torch.tensor( 

1401 word_embedding.tolist(), device=flair.device, dtype=torch.float 

1402 ) 

1403 return word_embedding 

1404 

1405 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

1406 

1407 for i, sentence in enumerate(sentences): 

1408 

1409 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))): 

1410 

1411 if "field" not in self.__dict__ or self.field is None: 

1412 word = token.text 

1413 else: 

1414 word = token.get_tag(self.field).value 

1415 

1416 word_embedding = self.get_cached_vec(word) 

1417 

1418 token.set_embedding(self.name, word_embedding) 

1419 

1420 return sentences 

1421 

1422 def __str__(self): 

1423 return self.name 

1424 

1425 def extra_repr(self): 

1426 return f"'{self.embeddings}'" 

1427 

1428 

1429class OneHotEmbeddings(TokenEmbeddings): 

1430 """One-hot encoded embeddings. """ 

1431 

1432 def __init__( 

1433 self, 

1434 corpus: Corpus, 

1435 field: str = "text", 

1436 embedding_length: int = 300, 

1437 min_freq: int = 3, 

1438 ): 

1439 """ 

1440 Initializes one-hot encoded word embeddings and a trainable embedding layer 

1441 :param corpus: you need to pass a Corpus in order to construct the vocabulary 

1442 :param field: by default, the 'text' of tokens is embedded, but you can also embed tags such as 'pos' 

1443 :param embedding_length: dimensionality of the trainable embedding layer 

1444 :param min_freq: minimum frequency of a word to become part of the vocabulary 

1445 """ 

1446 super().__init__() 

1447 self.name = "one-hot" 

1448 self.static_embeddings = False 

1449 self.min_freq = min_freq 

1450 self.field = field 

1451 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

1452 

1453 tokens = list(map((lambda s: s.tokens), corpus.train)) 

1454 tokens = [token for sublist in tokens for token in sublist] 

1455 

1456 if field == "text": 

1457 most_common = Counter(list(map((lambda t: t.text), tokens))).most_common() 

1458 else: 

1459 most_common = Counter( 

1460 list(map((lambda t: t.get_tag(field).value), tokens)) 

1461 ).most_common() 

1462 

1463 tokens = [] 

1464 for token, freq in most_common: 

1465 if freq < min_freq: 

1466 break 

1467 tokens.append(token) 

1468 

1469 self.vocab_dictionary: Dictionary = Dictionary() 

1470 for token in tokens: 

1471 self.vocab_dictionary.add_item(token) 

1472 

1473 # max_tokens = 500 

1474 self.__embedding_length = embedding_length 

1475 

1476 print(self.vocab_dictionary.idx2item) 

1477 print(f"vocabulary size of {len(self.vocab_dictionary)}") 

1478 

1479 # model architecture 

1480 self.embedding_layer = torch.nn.Embedding( 

1481 len(self.vocab_dictionary), self.__embedding_length 

1482 ) 

1483 torch.nn.init.xavier_uniform_(self.embedding_layer.weight) 

1484 

1485 self.to(flair.device) 

1486 

1487 @property 

1488 def embedding_length(self) -> int: 

1489 return self.__embedding_length 

1490 

1491 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

1492 

1493 one_hot_sentences = [] 

1494 for i, sentence in enumerate(sentences): 

1495 

1496 if self.field == "text": 

1497 context_idxs = [ 

1498 self.vocab_dictionary.get_idx_for_item(t.text) 

1499 for t in sentence.tokens 

1500 ] 

1501 else: 

1502 context_idxs = [ 

1503 self.vocab_dictionary.get_idx_for_item(t.get_tag(self.field).value) 

1504 for t in sentence.tokens 

1505 ] 

1506 

1507 one_hot_sentences.extend(context_idxs) 

1508 

1509 one_hot_sentences = torch.tensor(one_hot_sentences, dtype=torch.long).to( 

1510 flair.device 

1511 ) 

1512 

1513 embedded = self.embedding_layer.forward(one_hot_sentences) 

1514 

1515 index = 0 

1516 for sentence in sentences: 

1517 for token in sentence: 

1518 embedding = embedded[index] 

1519 token.set_embedding(self.name, embedding) 

1520 index += 1 

1521 

1522 return sentences 

1523 

1524 def __str__(self): 

1525 return self.name 

1526 

1527 def extra_repr(self): 

1528 return "min_freq={}".format(self.min_freq) 

1529 

1530 

1531class HashEmbeddings(TokenEmbeddings): 

1532 """Standard embeddings with Hashing Trick.""" 

1533 

1534 def __init__( 

1535 self, num_embeddings: int = 1000, embedding_length: int = 300, hash_method="md5" 

1536 ): 

1537 

1538 super().__init__() 

1539 self.name = "hash" 

1540 self.static_embeddings = False 

1541 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

1542 

1543 self.__num_embeddings = num_embeddings 

1544 self.__embedding_length = embedding_length 

1545 

1546 self.__hash_method = hash_method 

1547 

1548 # model architecture 

1549 self.embedding_layer = torch.nn.Embedding( 

1550 self.__num_embeddings, self.__embedding_length 

1551 ) 

1552 torch.nn.init.xavier_uniform_(self.embedding_layer.weight) 

1553 

1554 self.to(flair.device) 

1555 

1556 @property 

1557 def num_embeddings(self) -> int: 

1558 return self.__num_embeddings 

1559 

1560 @property 

1561 def embedding_length(self) -> int: 

1562 return self.__embedding_length 

1563 

1564 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

1565 def get_idx_for_item(text): 

1566 hash_function = hashlib.new(self.__hash_method) 

1567 hash_function.update(bytes(str(text), "utf-8")) 

1568 return int(hash_function.hexdigest(), 16) % self.__num_embeddings 

1569 

1570 hash_sentences = [] 

1571 for i, sentence in enumerate(sentences): 

1572 context_idxs = [get_idx_for_item(t.text) for t in sentence.tokens] 

1573 

1574 hash_sentences.extend(context_idxs) 

1575 

1576 hash_sentences = torch.tensor(hash_sentences, dtype=torch.long).to(flair.device) 

1577 

1578 embedded = self.embedding_layer.forward(hash_sentences) 

1579 

1580 index = 0 

1581 for sentence in sentences: 

1582 for token in sentence: 

1583 embedding = embedded[index] 

1584 token.set_embedding(self.name, embedding) 

1585 index += 1 

1586 

1587 return sentences 

1588 

1589 def __str__(self): 

1590 return self.name 

1591 

1592 

1593class MuseCrosslingualEmbeddings(TokenEmbeddings): 

1594 def __init__(self, ): 

1595 self.name: str = f"muse-crosslingual" 

1596 self.static_embeddings = True 

1597 self.__embedding_length: int = 300 

1598 self.language_embeddings = {} 

1599 super().__init__() 

1600 

1601 @instance_lru_cache(maxsize=10000, typed=False) 

1602 def get_cached_vec(self, language_code: str, word: str) -> torch.Tensor: 

1603 current_embedding_model = self.language_embeddings[language_code] 

1604 if word in current_embedding_model: 

1605 word_embedding = current_embedding_model[word] 

1606 elif word.lower() in current_embedding_model: 

1607 word_embedding = current_embedding_model[word.lower()] 

1608 elif re.sub(r"\d", "#", word.lower()) in current_embedding_model: 

1609 word_embedding = current_embedding_model[re.sub(r"\d", "#", word.lower())] 

1610 elif re.sub(r"\d", "0", word.lower()) in current_embedding_model: 

1611 word_embedding = current_embedding_model[re.sub(r"\d", "0", word.lower())] 

1612 else: 

1613 word_embedding = np.zeros(self.embedding_length, dtype="float") 

1614 word_embedding = torch.tensor( 

1615 word_embedding, device=flair.device, dtype=torch.float 

1616 ) 

1617 return word_embedding 

1618 

1619 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

1620 

1621 for i, sentence in enumerate(sentences): 

1622 

1623 language_code = sentence.get_language_code() 

1624 supported = [ 

1625 "en", 

1626 "de", 

1627 "bg", 

1628 "ca", 

1629 "hr", 

1630 "cs", 

1631 "da", 

1632 "nl", 

1633 "et", 

1634 "fi", 

1635 "fr", 

1636 "el", 

1637 "he", 

1638 "hu", 

1639 "id", 

1640 "it", 

1641 "mk", 

1642 "no", 

1643 "pl", 

1644 "pt", 

1645 "ro", 

1646 "ru", 

1647 "sk", 

1648 ] 

1649 if language_code not in supported: 

1650 language_code = "en" 

1651 

1652 if language_code not in self.language_embeddings: 

1653 log.info(f"Loading up MUSE embeddings for '{language_code}'!") 

1654 # download if necessary 

1655 hu_path: str = "https://flair.informatik.hu-berlin.de/resources/embeddings/muse" 

1656 cache_dir = Path("embeddings") / "MUSE" 

1657 cached_path( 

1658 f"{hu_path}/muse.{language_code}.vec.gensim.vectors.npy", 

1659 cache_dir=cache_dir, 

1660 ) 

1661 embeddings_file = cached_path( 

1662 f"{hu_path}/muse.{language_code}.vec.gensim", cache_dir=cache_dir 

1663 ) 

1664 

1665 # load the model 

1666 self.language_embeddings[ 

1667 language_code 

1668 ] = gensim.models.KeyedVectors.load(str(embeddings_file)) 

1669 

1670 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))): 

1671 

1672 if "field" not in self.__dict__ or self.field is None: 

1673 word = token.text 

1674 else: 

1675 word = token.get_tag(self.field).value 

1676 

1677 word_embedding = self.get_cached_vec( 

1678 language_code=language_code, word=word 

1679 ) 

1680 

1681 token.set_embedding(self.name, word_embedding) 

1682 

1683 return sentences 

1684 

1685 @property 

1686 def embedding_length(self) -> int: 

1687 return self.__embedding_length 

1688 

1689 def __str__(self): 

1690 return self.name 

1691 

1692 

1693# TODO: keep for backwards compatibility, but remove in future 

1694class BPEmbSerializable(BPEmb): 

1695 def __getstate__(self): 

1696 state = self.__dict__.copy() 

1697 # save the sentence piece model as binary file (not as path which may change) 

1698 state["spm_model_binary"] = open(self.model_file, mode="rb").read() 

1699 state["spm"] = None 

1700 return state 

1701 

1702 def __setstate__(self, state): 

1703 from bpemb.util import sentencepiece_load 

1704 

1705 model_file = self.model_tpl.format(lang=state["lang"], vs=state["vs"]) 

1706 self.__dict__ = state 

1707 

1708 # write out the binary sentence piece model into the expected directory 

1709 self.cache_dir: Path = flair.cache_root / "embeddings" 

1710 if "spm_model_binary" in self.__dict__: 

1711 # if the model was saved as binary and it is not found on disk, write to appropriate path 

1712 if not os.path.exists(self.cache_dir / state["lang"]): 

1713 os.makedirs(self.cache_dir / state["lang"]) 

1714 self.model_file = self.cache_dir / model_file 

1715 with open(self.model_file, "wb") as out: 

1716 out.write(self.__dict__["spm_model_binary"]) 

1717 else: 

1718 # otherwise, use normal process and potentially trigger another download 

1719 self.model_file = self._load_file(model_file) 

1720 

1721 # once the modes if there, load it with sentence piece 

1722 state["spm"] = sentencepiece_load(self.model_file) 

1723 

1724 

1725class BytePairEmbeddings(TokenEmbeddings): 

1726 def __init__( 

1727 self, 

1728 language: str = None, 

1729 dim: int = 50, 

1730 syllables: int = 100000, 

1731 cache_dir=None, 

1732 model_file_path: Path = None, 

1733 embedding_file_path: Path = None, 

1734 **kwargs, 

1735 ): 

1736 """ 

1737 Initializes BP embeddings. Constructor downloads required files if not there. 

1738 """ 

1739 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

1740 

1741 if not cache_dir: 

1742 cache_dir = flair.cache_root / "embeddings" 

1743 if language: 

1744 self.name: str = f"bpe-{language}-{syllables}-{dim}" 

1745 else: 

1746 assert ( 

1747 model_file_path is not None and embedding_file_path is not None 

1748 ), "Need to specify model_file_path and embedding_file_path if no language is given in BytePairEmbeddings(...)" 

1749 dim = None 

1750 

1751 self.embedder = BPEmbSerializable( 

1752 lang=language, 

1753 vs=syllables, 

1754 dim=dim, 

1755 cache_dir=cache_dir, 

1756 model_file=model_file_path, 

1757 emb_file=embedding_file_path, 

1758 **kwargs, 

1759 ) 

1760 

1761 if not language: 

1762 self.name: str = f"bpe-custom-{self.embedder.vs}-{self.embedder.dim}" 

1763 self.static_embeddings = True 

1764 

1765 self.__embedding_length: int = self.embedder.emb.vector_size * 2 

1766 super().__init__() 

1767 

1768 @property 

1769 def embedding_length(self) -> int: 

1770 return self.__embedding_length 

1771 

1772 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

1773 

1774 for i, sentence in enumerate(sentences): 

1775 

1776 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))): 

1777 

1778 if "field" not in self.__dict__ or self.field is None: 

1779 word = token.text 

1780 else: 

1781 word = token.get_tag(self.field).value 

1782 

1783 if word.strip() == "": 

1784 # empty words get no embedding 

1785 token.set_embedding( 

1786 self.name, torch.zeros(self.embedding_length, dtype=torch.float) 

1787 ) 

1788 else: 

1789 # all other words get embedded 

1790 embeddings = self.embedder.embed(word.lower()) 

1791 embedding = np.concatenate( 

1792 (embeddings[0], embeddings[len(embeddings) - 1]) 

1793 ) 

1794 token.set_embedding( 

1795 self.name, torch.tensor(embedding, dtype=torch.float) 

1796 ) 

1797 

1798 return sentences 

1799 

1800 def __str__(self): 

1801 return self.name 

1802 

1803 def extra_repr(self): 

1804 return "model={}".format(self.name) 

1805 

1806 

1807class ELMoEmbeddings(TokenEmbeddings): 

1808 """Contextual word embeddings using word-level LM, as proposed in Peters et al., 2018. 

1809 ELMo word vectors can be constructed by combining layers in different ways. 

1810 Default is to concatene the top 3 layers in the LM.""" 

1811 

1812 def __init__( 

1813 self, model: str = "original", options_file: str = None, weight_file: str = None, 

1814 embedding_mode: str = "all" 

1815 ): 

1816 super().__init__() 

1817 

1818 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

1819 

1820 try: 

1821 import allennlp.commands.elmo 

1822 except ModuleNotFoundError: 

1823 log.warning("-" * 100) 

1824 log.warning('ATTENTION! The library "allennlp" is not installed!') 

1825 log.warning( 

1826 'To use ELMoEmbeddings, please first install with "pip install allennlp==0.9.0"' 

1827 ) 

1828 log.warning("-" * 100) 

1829 pass 

1830 

1831 assert embedding_mode in ["all", "top", "average"] 

1832 

1833 self.name = f"elmo-{model}-{embedding_mode}" 

1834 self.static_embeddings = True 

1835 

1836 if not options_file or not weight_file: 

1837 # the default model for ELMo is the 'original' model, which is very large 

1838 options_file = allennlp.commands.elmo.DEFAULT_OPTIONS_FILE 

1839 weight_file = allennlp.commands.elmo.DEFAULT_WEIGHT_FILE 

1840 # alternatively, a small, medium or portuguese model can be selected by passing the appropriate mode name 

1841 if model == "small": 

1842 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_options.json" 

1843 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x1024_128_2048cnn_1xhighway/elmo_2x1024_128_2048cnn_1xhighway_weights.hdf5" 

1844 if model == "medium": 

1845 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_options.json" 

1846 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x2048_256_2048cnn_1xhighway/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5" 

1847 if model in ["large", "5.5B"]: 

1848 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json" 

1849 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5" 

1850 if model == "pt" or model == "portuguese": 

1851 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_options.json" 

1852 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pt/elmo_pt_weights.hdf5" 

1853 if model == "pubmed": 

1854 options_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_options.json" 

1855 weight_file = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/contributed/pubmed/elmo_2x4096_512_2048cnn_2xhighway_weights_PubMed_only.hdf5" 

1856 

1857 if embedding_mode == "all": 

1858 self.embedding_mode_fn = self.use_layers_all 

1859 elif embedding_mode == "top": 

1860 self.embedding_mode_fn = self.use_layers_top 

1861 elif embedding_mode == "average": 

1862 self.embedding_mode_fn = self.use_layers_average 

1863 

1864 # put on Cuda if available 

1865 from flair import device 

1866 

1867 if re.fullmatch(r"cuda:[0-9]+", str(device)): 

1868 cuda_device = int(str(device).split(":")[-1]) 

1869 elif str(device) == "cpu": 

1870 cuda_device = -1 

1871 else: 

1872 cuda_device = 0 

1873 

1874 self.ee = allennlp.commands.elmo.ElmoEmbedder( 

1875 options_file=options_file, weight_file=weight_file, cuda_device=cuda_device 

1876 ) 

1877 

1878 # embed a dummy sentence to determine embedding_length 

1879 dummy_sentence: Sentence = Sentence() 

1880 dummy_sentence.add_token(Token("hello")) 

1881 embedded_dummy = self.embed(dummy_sentence) 

1882 self.__embedding_length: int = len( 

1883 embedded_dummy[0].get_token(1).get_embedding() 

1884 ) 

1885 

1886 @property 

1887 def embedding_length(self) -> int: 

1888 return self.__embedding_length 

1889 

1890 def use_layers_all(self, x): 

1891 return torch.cat(x, 0) 

1892 

1893 def use_layers_top(self, x): 

1894 return x[-1] 

1895 

1896 def use_layers_average(self, x): 

1897 return torch.mean(torch.stack(x), 0) 

1898 

1899 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

1900 # ELMoEmbeddings before Release 0.5 did not set self.embedding_mode_fn 

1901 if not getattr(self, "embedding_mode_fn", None): 

1902 self.embedding_mode_fn = self.use_layers_all 

1903 

1904 sentence_words: List[List[str]] = [] 

1905 for sentence in sentences: 

1906 sentence_words.append([token.text for token in sentence]) 

1907 

1908 embeddings = self.ee.embed_batch(sentence_words) 

1909 

1910 for i, sentence in enumerate(sentences): 

1911 

1912 sentence_embeddings = embeddings[i] 

1913 

1914 for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))): 

1915 elmo_embedding_layers = [ 

1916 torch.FloatTensor(sentence_embeddings[0, token_idx, :]), 

1917 torch.FloatTensor(sentence_embeddings[1, token_idx, :]), 

1918 torch.FloatTensor(sentence_embeddings[2, token_idx, :]) 

1919 ] 

1920 word_embedding = self.embedding_mode_fn(elmo_embedding_layers) 

1921 token.set_embedding(self.name, word_embedding) 

1922 

1923 return sentences 

1924 

1925 def extra_repr(self): 

1926 return "model={}".format(self.name) 

1927 

1928 def __str__(self): 

1929 return self.name 

1930 

1931 def __setstate__(self, state): 

1932 self.__dict__ = state 

1933 

1934 if re.fullmatch(r"cuda:[0-9]+", str(flair.device)): 

1935 cuda_device = int(str(flair.device).split(":")[-1]) 

1936 elif str(flair.device) == "cpu": 

1937 cuda_device = -1 

1938 else: 

1939 cuda_device = 0 

1940 

1941 self.ee.cuda_device = cuda_device 

1942 

1943 self.ee.elmo_bilm.to(device=flair.device) 

1944 self.ee.elmo_bilm._elmo_lstm._states = tuple( 

1945 [state.to(flair.device) for state in self.ee.elmo_bilm._elmo_lstm._states]) 

1946 

1947 

1948class NILCEmbeddings(WordEmbeddings): 

1949 def __init__(self, embeddings: str, model: str = "skip", size: int = 100): 

1950 """ 

1951 Initializes portuguese classic word embeddings trained by NILC Lab (http://www.nilc.icmc.usp.br/embeddings). 

1952 Constructor downloads required files if not there. 

1953 :param embeddings: one of: 'fasttext', 'glove', 'wang2vec' or 'word2vec' 

1954 :param model: one of: 'skip' or 'cbow'. This is not applicable to glove. 

1955 :param size: one of: 50, 100, 300, 600 or 1000. 

1956 """ 

1957 

1958 self.instance_parameters = self.get_instance_parameters(locals=locals()) 

1959 

1960 base_path = "http://143.107.183.175:22980/download.php?file=embeddings/" 

1961 

1962 cache_dir = Path("embeddings") / embeddings.lower() 

1963 

1964 # GLOVE embeddings 

1965 if embeddings.lower() == "glove": 

1966 cached_path( 

1967 f"{base_path}{embeddings}/{embeddings}_s{size}.zip", cache_dir=cache_dir 

1968 ) 

1969 embeddings = cached_path( 

1970 f"{base_path}{embeddings}/{embeddings}_s{size}.zip", cache_dir=cache_dir 

1971 ) 

1972 

1973 elif embeddings.lower() in ["fasttext", "wang2vec", "word2vec"]: 

1974 cached_path( 

1975 f"{base_path}{embeddings}/{model}_s{size}.zip", cache_dir=cache_dir 

1976 ) 

1977 embeddings = cached_path( 

1978 f"{base_path}{embeddings}/{model}_s{size}.zip", cache_dir=cache_dir 

1979 ) 

1980 

1981 elif not Path(embeddings).exists(): 

1982 raise ValueError( 

1983 f'The given embeddings "{embeddings}" is not available or is not a valid path.' 

1984 ) 

1985 

1986 self.name: str = str(embeddings) 

1987 self.static_embeddings = True 

1988 

1989 log.info("Reading embeddings from %s" % embeddings) 

1990 self.precomputed_word_embeddings = gensim.models.KeyedVectors.load_word2vec_format( 

1991 open_inside_zip(str(embeddings), cache_dir=cache_dir) 

1992 ) 

1993 

1994 self.__embedding_length: int = self.precomputed_word_embeddings.vector_size 

1995 super(TokenEmbeddings, self).__init__() 

1996 

1997 @property 

1998 def embedding_length(self) -> int: 

1999 return self.__embedding_length 

2000 

2001 def __str__(self): 

2002 return self.name 

2003 

2004 

2005def replace_with_language_code(string: str): 

2006 string = string.replace("arabic-", "ar-") 

2007 string = string.replace("basque-", "eu-") 

2008 string = string.replace("bulgarian-", "bg-") 

2009 string = string.replace("croatian-", "hr-") 

2010 string = string.replace("czech-", "cs-") 

2011 string = string.replace("danish-", "da-") 

2012 string = string.replace("dutch-", "nl-") 

2013 string = string.replace("farsi-", "fa-") 

2014 string = string.replace("persian-", "fa-") 

2015 string = string.replace("finnish-", "fi-") 

2016 string = string.replace("french-", "fr-") 

2017 string = string.replace("german-", "de-") 

2018 string = string.replace("hebrew-", "he-") 

2019 string = string.replace("hindi-", "hi-") 

2020 string = string.replace("indonesian-", "id-") 

2021 string = string.replace("italian-", "it-") 

2022 string = string.replace("japanese-", "ja-") 

2023 string = string.replace("norwegian-", "no") 

2024 string = string.replace("polish-", "pl-") 

2025 string = string.replace("portuguese-", "pt-") 

2026 string = string.replace("slovenian-", "sl-") 

2027 string = string.replace("spanish-", "es-") 

2028 string = string.replace("swedish-", "sv-") 

2029 return string