Coverage for flair/flair/embeddings/document.py: 33% (409 statements)


1import logging 

2from abc import abstractmethod 

3from typing import List, Union 

4 

5import torch 

6from sklearn.feature_extraction.text import TfidfVectorizer 

7from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 

8from transformers import AutoTokenizer, AutoConfig, AutoModel, CONFIG_MAPPING, PreTrainedTokenizer 

9 

10import flair 

11from flair.data import Sentence 

12from flair.embeddings.base import Embeddings, ScalarMix 

13from flair.embeddings.token import TokenEmbeddings, StackedEmbeddings, FlairEmbeddings 

14from flair.nn import LockedDropout, WordDropout 

15 

16log = logging.getLogger("flair") 

17 

18 

19class DocumentEmbeddings(Embeddings): 

20 """Abstract base class for all document-level embeddings. Every new type of document embedding must implement these methods.""" 

21 

22 @property 

23 @abstractmethod 

24 def embedding_length(self) -> int: 

25 """Returns the length of the embedding vector.""" 

26 pass 

27 

28 @property 

29 def embedding_type(self) -> str: 

30 return "sentence-level" 

31 

32 
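# --- Editorial sketch (not part of the original module) ----------------------
# A minimal, hypothetical DocumentEmbeddings subclass illustrating the contract
# used by the concrete classes below: implement `embedding_length` and
# `_add_embeddings_internal`, store the vector on the Sentence under `self.name`.
# The class and attribute names here are illustrative only.
class _ConstantDocumentEmbeddings(DocumentEmbeddings):
    """Toy embedding that assigns the same zero vector to every sentence."""

    def __init__(self, length: int = 8):
        super().__init__()
        self._length = length
        self.name = "constant-document"
        self.static_embeddings = True  # nothing to train, so never recompute

    @property
    def embedding_length(self) -> int:
        return self._length

    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        for sentence in sentences:
            sentence.set_embedding(self.name, torch.zeros(self._length, device=flair.device))
        return sentences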

33class TransformerDocumentEmbeddings(DocumentEmbeddings): 

34 def __init__( 

35 self, 

36 model: str = "bert-base-uncased", 

37 fine_tune: bool = True, 

38 layers: str = "-1", 

39 layer_mean: bool = False, 

40 pooling: str = "cls", 

41 **kwargs 

42 ): 

43 """ 

44 Document-level embeddings computed from various transformer architectures by pooling over subtoken embeddings.

45 :param model: name of transformer model (see https://huggingface.co/transformers/pretrained_models.html for

46 options)

47 :param fine_tune: if True, the transformer weights are fine-tuned during downstream training

48 :param layers: string indicating which layers to take for embedding ('-1' is the topmost layer)

49 :param layer_mean: if True, uses a scalar mix of the selected layers as embedding

50 :param pooling: pooling strategy for combining token-level embeddings; options are 'cls', 'max' and 'mean'

51 :param kwargs: additional keyword arguments passed on to the Hugging Face tokenizer, config and model

52 loading functions

53 """

54 super().__init__() 

55 

56 if pooling not in ['cls', 'max', 'mean']: 

57 raise ValueError(f"Pooling operation `{pooling}` is not defined for TransformerDocumentEmbeddings") 

58 

59 # temporary fix to disable tokenizer parallelism warning 

60 # (see https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning) 

61 import os 

62 os.environ["TOKENIZERS_PARALLELISM"] = "false" 

63 

64 # do not print transformer warnings as these are confusing in this case 

65 from transformers import logging 

66 logging.set_verbosity_error() 

67 

68 # load tokenizer and transformer model 

69 self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model, **kwargs) 

70 if 'config' not in kwargs: 

71 config = AutoConfig.from_pretrained(model, output_hidden_states=True, **kwargs) 

72 self.model = AutoModel.from_pretrained(model, config=config) 

73 else: 

74 self.model = AutoModel.from_pretrained(None, **kwargs) 

75 

76 logging.set_verbosity_warning() 

77 

78 # model name 

79 self.name = 'transformer-document-' + str(model) 

80 self.base_model_name = str(model) 

81 

82 # when initializing, embeddings are in eval mode by default 

83 self.model.eval() 

84 self.model.to(flair.device) 

85 

86 # embedding parameters 

87 if layers == 'all': 

88 # send a dummy token through the model to check how many layers it has 

89 hidden_states = self.model(torch.tensor([1], device=flair.device).unsqueeze(0))[-1] 

90 self.layer_indexes = [int(x) for x in range(len(hidden_states))] 

91 else: 

92 self.layer_indexes = [int(x) for x in layers.split(",")] 

93 

94 self.layer_mean = layer_mean 

95 self.fine_tune = fine_tune 

96 self.static_embeddings = not self.fine_tune 

97 self.pooling = pooling 

98 

99 # check whether CLS is at beginning or end 

100 self.initial_cls_token: bool = self._has_initial_cls_token(tokenizer=self.tokenizer) 

101 

102 @staticmethod 

103 def _has_initial_cls_token(tokenizer: PreTrainedTokenizer) -> bool: 

104 # some models (e.g. XLNet, XLM) place the CLS token at the end of the sequence, while BERT-style models place it first 

105 tokens = tokenizer.encode('a') 

106 initial_cls_token: bool = False 

107 if tokens[0] == tokenizer.cls_token_id: initial_cls_token = True 

108 return initial_cls_token 

109 

110 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

111 """Add a document embedding to each sentence in the given list of sentences.""" 

112 

113 # gradients are enabled if fine-tuning is enabled 

114 gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad() 

115 

116 with gradient_context: 

117 

118 # first, subtokenize each sentence into its subtoken ids 

119 subtokenized_sentences = [] 

120 

121 # subtokenize sentences 

122 for sentence in sentences: 

123 # tokenize and truncate to max subtokens (TODO: check better truncation strategies) 

124 subtokenized_sentence = self.tokenizer.encode(sentence.to_tokenized_string(), 

125 add_special_tokens=True, 

126 max_length=self.tokenizer.model_max_length, 

127 truncation=True, 

128 ) 

129 

130 subtokenized_sentences.append( 

131 torch.tensor(subtokenized_sentence, dtype=torch.long, device=flair.device)) 

132 

133 # find longest sentence in batch 

134 longest_sequence_in_batch: int = len(max(subtokenized_sentences, key=len)) 

135 

136 # initialize batch tensors and mask 

137 input_ids = torch.zeros( 

138 [len(sentences), longest_sequence_in_batch], 

139 dtype=torch.long, 

140 device=flair.device, 

141 ) 

142 mask = torch.zeros( 

143 [len(sentences), longest_sequence_in_batch], 

144 dtype=torch.long, 

145 device=flair.device, 

146 ) 

147 for s_id, sentence in enumerate(subtokenized_sentences): 

148 sequence_length = len(sentence) 

149 input_ids[s_id][:sequence_length] = sentence 

150 mask[s_id][:sequence_length] = torch.ones(sequence_length) 

151 

152 # put encoded batch through transformer model to get all hidden states of all encoder layers 

153 hidden_states = self.model(input_ids, attention_mask=mask)[-1] if len(sentences) > 1 \ 

154 else self.model(input_ids)[-1] 

155 

156 # iterate over all subtokenized sentences 

157 for sentence_idx, (sentence, subtokens) in enumerate(zip(sentences, subtokenized_sentences)): 

158 

159 if self.pooling == "cls": 

160 index_of_CLS_token = 0 if self.initial_cls_token else len(subtokens) - 1 

161 

162 cls_embeddings_all_layers: List[torch.FloatTensor] = \ 

163 [hidden_states[layer][sentence_idx][index_of_CLS_token] for layer in self.layer_indexes] 

164 

165 embeddings_all_layers = cls_embeddings_all_layers 

166 

167 elif self.pooling == "mean": 

168 mean_embeddings_all_layers: List[torch.FloatTensor] = \ 

169 [torch.mean(hidden_states[layer][sentence_idx][:len(subtokens), :], dim=0) for layer in 

170 self.layer_indexes] 

171 

172 embeddings_all_layers = mean_embeddings_all_layers 

173 

174 elif self.pooling == "max": 

175 max_embeddings_all_layers: List[torch.FloatTensor] = \ 

176 [torch.max(hidden_states[layer][sentence_idx][:len(subtokens), :], dim=0)[0] for layer in 

177 self.layer_indexes] 

178 

179 embeddings_all_layers = max_embeddings_all_layers 

180 

181 # use scalar mix of embeddings if so selected 

182 if self.layer_mean: 

183 sm = ScalarMix(mixture_size=len(embeddings_all_layers)) 

184 sm_embeddings = sm(embeddings_all_layers) 

185 

186 embeddings_all_layers = [sm_embeddings] 

187 

188 # set the extracted embedding for the sentence 

189 sentence.set_embedding(self.name, torch.cat(embeddings_all_layers)) 

190 

191 return sentences 

192 

193 @property 

194 @abstractmethod 

195 def embedding_length(self) -> int: 

196 """Returns the length of the embedding vector.""" 

197 return ( 

198 len(self.layer_indexes) * self.model.config.hidden_size 

199 if not self.layer_mean 

200 else self.model.config.hidden_size 

201 ) 

202 

203 def __getstate__(self): 

204 # special handling for serializing transformer models 

205 config_state_dict = self.model.config.__dict__ 

206 model_state_dict = self.model.state_dict() 

207 

208 if not hasattr(self, "base_model_name"): self.base_model_name = self.name.split('transformer-document-')[-1] 

209 

210 # serialize the transformer models and the constructor arguments (but nothing else) 

211 model_state = { 

212 "config_state_dict": config_state_dict, 

213 "model_state_dict": model_state_dict, 

214 "embedding_length_internal": self.embedding_length, 

215 

216 "base_model_name": self.base_model_name, 

217 "fine_tune": self.fine_tune, 

218 "layer_indexes": self.layer_indexes, 

219 "layer_mean": self.layer_mean, 

220 "pooling": self.pooling, 

221 } 

222 

223 return model_state 

224 

225 def __setstate__(self, d): 

226 self.__dict__ = d 

227 

228 # necessary for backward compatibility with Flair <= 0.7 

229 if 'use_scalar_mix' in self.__dict__: 

230 self.__dict__['layer_mean'] = d['use_scalar_mix'] 

231 

232 # special handling for deserializing transformer models 

233 if "config_state_dict" in d: 

234 

235 # load transformer model 

236 model_type = d["config_state_dict"]["model_type"] if "model_type" in d["config_state_dict"] else "bert" 

237 config_class = CONFIG_MAPPING[model_type] 

238 loaded_config = config_class.from_dict(d["config_state_dict"]) 

239 

240 # constructor arguments 

241 layers = ','.join([str(idx) for idx in self.__dict__['layer_indexes']]) 

242 

243 # re-initialize transformer word embeddings with constructor arguments 

244 embedding = TransformerDocumentEmbeddings( 

245 model=self.__dict__['base_model_name'], 

246 fine_tune=self.__dict__['fine_tune'], 

247 layers=layers, 

248 layer_mean=self.__dict__['layer_mean'], 

249 

250 config=loaded_config, 

251 state_dict=d["model_state_dict"], 

252 pooling=self.__dict__['pooling'] if 'pooling' in self.__dict__ else 'cls', 

253 # for backward compatibility with previous models 

254 ) 

255 

256 # copy the attributes of the freshly constructed embedding over to self so this instance is fully initialized 

257 for key in embedding.__dict__.keys(): 

258 self.__dict__[key] = embedding.__dict__[key] 

259 

260 else: 

261 model_name = self.__dict__['name'].split('transformer-document-')[-1] 

262 # reload tokenizer to get around serialization issues 

263 try: 

264 self.tokenizer = AutoTokenizer.from_pretrained(model_name) 

265 except Exception: 

266 log.warning(f"Could not reload tokenizer for model '{model_name}'") 

267 self.tokenizer = None 

268 

269 
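# --- Editorial usage sketch (not part of the original module) ----------------
# A minimal example of TransformerDocumentEmbeddings; the checkpoint name is an
# assumption, any Hugging Face model accepted by AutoModel should work.
def _example_transformer_document_embeddings():
    embedding = TransformerDocumentEmbeddings(
        model="distilbert-base-uncased",  # assumed checkpoint
        fine_tune=False,
        layers="-1",
        pooling="cls",
    )
    sentence = Sentence("Berlin is the capital of Germany .")
    embedding.embed([sentence])
    # one vector per sentence; with a single layer this equals the model's hidden size
    print(sentence.get_embedding().shape)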

270class DocumentPoolEmbeddings(DocumentEmbeddings): 

271 def __init__( 

272 self, 

273 embeddings: List[TokenEmbeddings], 

274 fine_tune_mode: str = "none", 

275 pooling: str = "mean", 

276 ): 

277 """The constructor takes a list of embeddings to be combined. 

278 :param embeddings: a list of token embeddings 

279 :param fine_tune_mode: if set to "linear" a trainable layer is added, if set to 

280 "nonlinear", a nonlinearity is added as well. Set this to make the pooling trainable. 

281 :param pooling: a string which can be any value from ['mean', 'max', 'min'] 

282 """ 

283 super().__init__() 

284 

285 self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embeddings) 

286 self.__embedding_length = self.embeddings.embedding_length 

287 

288 # optional fine-tuning on top of embedding layer 

289 self.fine_tune_mode = fine_tune_mode 

290 if self.fine_tune_mode in ["nonlinear", "linear"]: 

291 self.embedding_flex = torch.nn.Linear( 

292 self.embedding_length, self.embedding_length, bias=False 

293 ) 

294 self.embedding_flex.weight.data.copy_(torch.eye(self.embedding_length)) 

295 

296 if self.fine_tune_mode in ["nonlinear"]: 

297 self.embedding_flex_nonlinear = torch.nn.ReLU()  # ReLU takes no dimensionality argument 

298 self.embedding_flex_nonlinear_map = torch.nn.Linear( 

299 self.embedding_length, self.embedding_length 

300 ) 

301 

302 self.__embedding_length: int = self.embeddings.embedding_length 

303 

304 self.to(flair.device) 

305 

306 if pooling not in ['min', 'max', 'mean']: 

307 raise ValueError(f"Pooling operation `{pooling}` is not defined for DocumentPoolEmbeddings") 

308 

309 self.pooling = pooling 

310 self.name: str = f"document_{self.pooling}" 

311 

312 @property 

313 def embedding_length(self) -> int: 

314 return self.__embedding_length 

315 

316 def embed(self, sentences: Union[List[Sentence], Sentence]): 

317 """Add embeddings to every sentence in the given list of sentences. If embeddings are already added, updates 

318 only if embeddings are non-static.""" 

319 

320 # if only one sentence is passed, convert to list of sentence 

321 if isinstance(sentences, Sentence): 

322 sentences = [sentences] 

323 

324 self.embeddings.embed(sentences) 

325 

326 for sentence in sentences: 

327 word_embeddings = [] 

328 for token in sentence.tokens: 

329 word_embeddings.append(token.get_embedding().unsqueeze(0)) 

330 

331 word_embeddings = torch.cat(word_embeddings, dim=0).to(flair.device) 

332 

333 if self.fine_tune_mode in ["nonlinear", "linear"]: 

334 word_embeddings = self.embedding_flex(word_embeddings) 

335 

336 if self.fine_tune_mode in ["nonlinear"]: 

337 word_embeddings = self.embedding_flex_nonlinear(word_embeddings) 

338 word_embeddings = self.embedding_flex_nonlinear_map(word_embeddings) 

339 

340 if self.pooling == "mean": 

341 pooled_embedding = torch.mean(word_embeddings, 0) 

342 elif self.pooling == "max": 

343 pooled_embedding, _ = torch.max(word_embeddings, 0) 

344 elif self.pooling == "min": 

345 pooled_embedding, _ = torch.min(word_embeddings, 0) 

346 

347 sentence.set_embedding(self.name, pooled_embedding) 

348 

349 def _add_embeddings_internal(self, sentences: List[Sentence]): 

350 pass 

351 

352 def extra_repr(self): 

353 return f"fine_tune_mode={self.fine_tune_mode}, pooling={self.pooling}" 

354 

355 
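# --- Editorial usage sketch (not part of the original module) ----------------
# Mean-pooling classic word embeddings into a document vector; WordEmbeddings
# and the 'glove' key are assumptions taken from flair.embeddings.token.
def _example_document_pool_embeddings():
    from flair.embeddings.token import WordEmbeddings  # assumed import path

    embedding = DocumentPoolEmbeddings([WordEmbeddings("glove")], pooling="mean")
    sentence = Sentence("The grass is green .")
    embedding.embed(sentence)  # a single Sentence is accepted as well
    print(sentence.get_embedding().shape)  # equals the stacked token embedding length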

356class DocumentTFIDFEmbeddings(DocumentEmbeddings): 

357 def __init__( 

358 self, 

359 train_dataset, 

360 **vectorizer_params, 

361 ): 

362 """The constructor for DocumentTFIDFEmbeddings. 

363 :param train_dataset: the train dataset which will be used to construct vectorizer 

364 :param vectorizer_params: parameters given to Scikit-learn's TfidfVectorizer constructor 

365 """ 

366 super().__init__() 

367 

368 import numpy as np 

369 self.vectorizer = TfidfVectorizer(dtype=np.float32, **vectorizer_params) 

370 self.vectorizer.fit([s.to_original_text() for s in train_dataset]) 

371 

372 self.__embedding_length: int = len(self.vectorizer.vocabulary_) 

373 

374 self.to(flair.device) 

375 

376 self.name: str = "document_tfidf" 

377 

378 @property 

379 def embedding_length(self) -> int: 

380 return self.__embedding_length 

381 

382 def embed(self, sentences: Union[List[Sentence], Sentence]): 

383 """Add embeddings to every sentence in the given list of sentences.""" 

384 

385 # if only one sentence is passed, convert to list of sentence 

386 if isinstance(sentences, Sentence): 

387 sentences = [sentences] 

388 

389 raw_sentences = [s.to_original_text() for s in sentences] 

390 tfidf_vectors = torch.from_numpy(self.vectorizer.transform(raw_sentences).A) 

391 

392 for sentence_id, sentence in enumerate(sentences): 

393 sentence.set_embedding(self.name, tfidf_vectors[sentence_id]) 

394 

395 def _add_embeddings_internal(self, sentences: List[Sentence]): 

396 pass 

397 

398 
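# --- Editorial usage sketch (not part of the original module) ----------------
# Fitting the TF-IDF vectorizer on a small in-memory corpus; in practice
# `train_dataset` would be the train split of a flair Corpus.
def _example_document_tfidf_embeddings():
    train_sentences = [Sentence("the grass is green"), Sentence("the sky is blue")]
    embedding = DocumentTFIDFEmbeddings(train_sentences, min_df=1)  # extra kwargs go to TfidfVectorizer
    test_sentence = Sentence("the grass is blue")
    embedding.embed(test_sentence)
    print(test_sentence.get_embedding().shape)  # equals the fitted vocabulary size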

399class DocumentRNNEmbeddings(DocumentEmbeddings): 

400 def __init__( 

401 self, 

402 embeddings: List[TokenEmbeddings], 

403 hidden_size=128, 

404 rnn_layers=1, 

405 reproject_words: bool = True, 

406 reproject_words_dimension: int = None, 

407 bidirectional: bool = False, 

408 dropout: float = 0.5, 

409 word_dropout: float = 0.0, 

410 locked_dropout: float = 0.0, 

411 rnn_type="GRU", 

412 fine_tune: bool = True, 

413 ): 

414 """The constructor takes a list of embeddings to be combined. 

415 :param embeddings: a list of token embeddings 

416 :param hidden_size: the number of hidden states in the rnn 

417 :param rnn_layers: the number of layers for the rnn 

418 :param reproject_words: boolean value, indicating whether to reproject the token embeddings in a separate linear 

419 layer before putting them into the rnn or not 

420 :param reproject_words_dimension: output dimension of reprojecting token embeddings. If None the same output 

421 dimension as before will be taken. 

422 :param bidirectional: boolean value, indicating whether to use a bidirectional rnn or not 

423 :param dropout: the dropout value to be used 

424 :param word_dropout: the word dropout value to be used, if 0.0 word dropout is not used 

425 :param locked_dropout: the locked dropout value to be used, if 0.0 locked dropout is not used 

426 :param rnn_type: 'GRU' or 'LSTM' 

427 """ 

428 super().__init__() 

429 

430 self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embeddings) 

431 

432 self.rnn_type = rnn_type 

433 

434 self.reproject_words = reproject_words 

435 self.bidirectional = bidirectional 

436 

437 self.length_of_all_token_embeddings: int = self.embeddings.embedding_length 

438 

439 self.static_embeddings = not fine_tune 

440 

441 self.__embedding_length: int = hidden_size 

442 if self.bidirectional: 

443 self.__embedding_length *= 4 

444 

445 self.embeddings_dimension: int = self.length_of_all_token_embeddings 

446 if self.reproject_words and reproject_words_dimension is not None: 

447 self.embeddings_dimension = reproject_words_dimension 

448 

449 self.word_reprojection_map = torch.nn.Linear( 

450 self.length_of_all_token_embeddings, self.embeddings_dimension 

451 ) 

452 

453 # RNN on top of the embedding layer (optionally bidirectional) 

454 if rnn_type == "LSTM": 

455 self.rnn = torch.nn.LSTM( 

456 self.embeddings_dimension, 

457 hidden_size, 

458 num_layers=rnn_layers, 

459 bidirectional=self.bidirectional, 

460 batch_first=True, 

461 ) 

462 else: 

463 self.rnn = torch.nn.GRU( 

464 self.embeddings_dimension, 

465 hidden_size, 

466 num_layers=rnn_layers, 

467 bidirectional=self.bidirectional, 

468 batch_first=True, 

469 ) 

470 

471 self.name = "document_" + self.rnn._get_name() 

472 

473 # dropouts 

474 self.dropout = torch.nn.Dropout(dropout) if dropout > 0.0 else None 

475 self.locked_dropout = ( 

476 LockedDropout(locked_dropout) if locked_dropout > 0.0 else None 

477 ) 

478 self.word_dropout = WordDropout(word_dropout) if word_dropout > 0.0 else None 

479 

480 torch.nn.init.xavier_uniform_(self.word_reprojection_map.weight) 

481 

482 self.to(flair.device) 

483 

484 self.eval() 

485 

486 @property 

487 def embedding_length(self) -> int: 

488 return self.__embedding_length 

489 

490 def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]): 

491 """Add embeddings to all sentences in the given list of sentences. If embeddings are already added, update 

492 only if embeddings are non-static.""" 

493 

494 # TODO: remove in future versions 

495 if not hasattr(self, "locked_dropout"): 

496 self.locked_dropout = None 

497 if not hasattr(self, "word_dropout"): 

498 self.word_dropout = None 

499 

500 if type(sentences) is Sentence: 

501 sentences = [sentences] 

502 

503 self.rnn.zero_grad() 

504 

505 # embed words in the sentence 

506 self.embeddings.embed(sentences) 

507 

508 lengths: List[int] = [len(sentence.tokens) for sentence in sentences] 

509 longest_token_sequence_in_batch: int = max(lengths) 

510 

511 pre_allocated_zero_tensor = torch.zeros( 

512 self.embeddings.embedding_length * longest_token_sequence_in_batch, 

513 dtype=torch.float, 

514 device=flair.device, 

515 ) 

516 

517 all_embs: List[torch.Tensor] = list() 

518 for sentence in sentences: 

519 all_embs += [ 

520 emb for token in sentence for emb in token.get_each_embedding() 

521 ] 

522 nb_padding_tokens = longest_token_sequence_in_batch - len(sentence) 

523 

524 if nb_padding_tokens > 0: 

525 t = pre_allocated_zero_tensor[ 

526 : self.embeddings.embedding_length * nb_padding_tokens 

527 ] 

528 all_embs.append(t) 

529 

530 sentence_tensor = torch.cat(all_embs).view( 

531 [ 

532 len(sentences), 

533 longest_token_sequence_in_batch, 

534 self.embeddings.embedding_length, 

535 ] 

536 ) 

537 

538 # before-RNN dropout 

539 if self.dropout: 

540 sentence_tensor = self.dropout(sentence_tensor) 

541 if self.locked_dropout: 

542 sentence_tensor = self.locked_dropout(sentence_tensor) 

543 if self.word_dropout: 

544 sentence_tensor = self.word_dropout(sentence_tensor) 

545 

546 # reproject if set 

547 if self.reproject_words: 

548 sentence_tensor = self.word_reprojection_map(sentence_tensor) 

549 

550 # push through RNN 

551 packed = pack_padded_sequence( 

552 sentence_tensor, lengths, enforce_sorted=False, batch_first=True 

553 ) 

554 rnn_out, hidden = self.rnn(packed) 

555 outputs, output_lengths = pad_packed_sequence(rnn_out, batch_first=True) 

556 

557 # after-RNN dropout 

558 if self.dropout: 

559 outputs = self.dropout(outputs) 

560 if self.locked_dropout: 

561 outputs = self.locked_dropout(outputs) 

562 

563 # extract embeddings from RNN 

564 for sentence_no, length in enumerate(lengths): 

565 last_rep = outputs[sentence_no, length - 1] 

566 

567 embedding = last_rep 

568 if self.bidirectional: 

569 first_rep = outputs[sentence_no, 0] 

570 embedding = torch.cat([first_rep, last_rep], 0) 

571 

572 if self.static_embeddings: 

573 embedding = embedding.detach() 

574 

575 sentence = sentences[sentence_no] 

576 sentence.set_embedding(self.name, embedding) 

577 

578 def _apply(self, fn): 

579 

580 # models that were serialized using torch versions older than 1.4.0 lack the _flat_weights_names attribute 

581 # check if this is the case and if so, set it 

582 for child_module in self.children(): 

583 if isinstance(child_module, torch.nn.RNNBase) and not hasattr(child_module, "_flat_weights_names"): 

584 _flat_weights_names = [] 

585 

586 if child_module.__dict__["bidirectional"]: 

587 num_direction = 2 

588 else: 

589 num_direction = 1 

590 for layer in range(child_module.__dict__["num_layers"]): 

591 for direction in range(num_direction): 

592 suffix = "_reverse" if direction == 1 else "" 

593 param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"] 

594 if child_module.__dict__["bias"]: 

595 param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"] 

596 param_names = [ 

597 x.format(layer, suffix) for x in param_names 

598 ] 

599 _flat_weights_names.extend(param_names) 

600 

601 setattr(child_module, "_flat_weights_names", 

602 _flat_weights_names) 

603 

604 child_module._apply(fn) 

605 

606 def __getstate__(self): 

607 

608 # serialize the constructor arguments and the state dict (but nothing else) 

609 model_state = { 

610 "state_dict": self.state_dict(), 

611 

612 "embeddings": self.embeddings.embeddings, 

613 "hidden_size": self.rnn.hidden_size, 

614 "rnn_layers": self.rnn.num_layers, 

615 "reproject_words": self.reproject_words, 

616 "reproject_words_dimension": self.embeddings_dimension, 

617 "bidirectional": self.bidirectional, 

618 "dropout": self.dropout.p if self.dropout is not None else 0., 

619 "word_dropout": self.word_dropout.p if self.word_dropout is not None else 0., 

620 "locked_dropout": self.locked_dropout.p if self.locked_dropout is not None else 0., 

621 "rnn_type": self.rnn_type, 

622 "fine_tune": not self.static_embeddings, 

623 } 

624 

625 return model_state 

626 

627 def __setstate__(self, d): 

628 

629 # special handling for deserializing language models 

630 if "state_dict" in d: 

631 

632 # re-initialize language model with constructor arguments 

633 language_model = DocumentRNNEmbeddings( 

634 embeddings=d['embeddings'], 

635 hidden_size=d['hidden_size'], 

636 rnn_layers=d['rnn_layers'], 

637 reproject_words=d['reproject_words'], 

638 reproject_words_dimension=d['reproject_words_dimension'], 

639 bidirectional=d['bidirectional'], 

640 dropout=d['dropout'], 

641 word_dropout=d['word_dropout'], 

642 locked_dropout=d['locked_dropout'], 

643 rnn_type=d['rnn_type'], 

644 fine_tune=d['fine_tune'], 

645 ) 

646 

647 language_model.load_state_dict(d['state_dict']) 

648 

649 # copy over state dictionary to self 

650 for key in language_model.__dict__.keys(): 

651 self.__dict__[key] = language_model.__dict__[key] 

652 

653 # set the language model to eval() by default (this is necessary since FlairEmbeddings "protect" the LM 

654 # in their "self.train()" method) 

655 self.eval() 

656 

657 else: 

658 self.__dict__ = d 

659 

660 
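# --- Editorial usage sketch (not part of the original module) ----------------
# A GRU over GloVe token embeddings; unlike the pooling variants this embedding
# is trainable and is normally optimized as part of a downstream model.
# WordEmbeddings and 'glove' are assumptions taken from flair.embeddings.token.
def _example_document_rnn_embeddings():
    from flair.embeddings.token import WordEmbeddings  # assumed import path

    embedding = DocumentRNNEmbeddings(
        [WordEmbeddings("glove")],
        hidden_size=64,
        bidirectional=False,
        rnn_type="GRU",
    )
    sentence = Sentence("The grass is green .")
    embedding.embed([sentence])
    print(sentence.get_embedding().shape)  # torch.Size([64]) for a unidirectional GRU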

661class DocumentLMEmbeddings(DocumentEmbeddings): 

662 def __init__(self, flair_embeddings: List[FlairEmbeddings]): 

663 super().__init__() 

664 

665 self.embeddings = flair_embeddings 

666 self.name = "document_lm" 

667 

668 # IMPORTANT: add embeddings as torch modules 

669 for i, embedding in enumerate(flair_embeddings): 

670 self.add_module("lm_embedding_{}".format(i), embedding) 

671 if not embedding.static_embeddings: 

672 self.static_embeddings = False 

673 

674 self._embedding_length: int = sum( 

675 embedding.embedding_length for embedding in flair_embeddings 

676 ) 

677 

678 @property 

679 def embedding_length(self) -> int: 

680 return self._embedding_length 

681 

682 def _add_embeddings_internal(self, sentences: List[Sentence]): 

683 if type(sentences) is Sentence: 

684 sentences = [sentences] 

685 

686 for embedding in self.embeddings: 

687 embedding.embed(sentences) 

688 

689 # iterate over sentences 

690 for sentence in sentences: 

691 sentence: Sentence = sentence 

692 

693 # if it's a forward LM, take the last state 

694 if embedding.is_forward_lm: 

695 sentence.set_embedding( 

696 embedding.name, 

697 sentence[len(sentence) - 1]._embeddings[embedding.name], 

698 ) 

699 else: 

700 sentence.set_embedding( 

701 embedding.name, sentence[0]._embeddings[embedding.name] 

702 ) 

703 

704 return sentences 

705 

706 
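# --- Editorial usage sketch (not part of the original module) ----------------
# Using the final hidden state of a character language model as the document
# embedding; 'news-forward-fast' is an assumed FlairEmbeddings model id.
def _example_document_lm_embeddings():
    embedding = DocumentLMEmbeddings([FlairEmbeddings("news-forward-fast")])
    sentence = Sentence("The grass is green .")
    embedding.embed([sentence])
    print(sentence.get_embedding().shape)  # the language model's hidden state size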

707class SentenceTransformerDocumentEmbeddings(DocumentEmbeddings): 

708 def __init__( 

709 self, 

710 model: str = "bert-base-nli-mean-tokens", 

711 batch_size: int = 1, 

712 convert_to_numpy: bool = False, 

713 ): 

714 """ 

715 :param model: string name of a pretrained model from the sentence-transformers library (any name 

716 accepted by the SentenceTransformer class) 

717 :param batch_size: number of sentences to process in one batch 

718 :param convert_to_numpy: whether encode() should return numpy arrays instead of PyTorch tensors 

719 """ 

720 super().__init__() 

721 

722 try: 

723 from sentence_transformers import SentenceTransformer 

724 except ModuleNotFoundError: 

725 log.warning("-" * 100) 

726 log.warning('ATTENTION! The library "sentence-transformers" is not installed!') 

727 log.warning( 

728 'To use Sentence Transformers, please first install with "pip install sentence-transformers"' 

729 ) 

730 log.warning("-" * 100) 

731 raise 

732 

733 self.model = SentenceTransformer(model) 

734 self.name = 'sentence-transformers-' + str(model) 

735 self.batch_size = batch_size 

736 self.convert_to_numpy = convert_to_numpy 

737 self.static_embeddings = True 

738 

739 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

740 

741 sentence_batches = [sentences[i * self.batch_size:(i + 1) * self.batch_size] 

742 for i in range((len(sentences) + self.batch_size - 1) // self.batch_size)] 

743 

744 for batch in sentence_batches: 

745 self._add_embeddings_to_sentences(batch) 

746 

747 return sentences 

748 

749 def _add_embeddings_to_sentences(self, sentences: List[Sentence]): 

750 

751 # convert to plain strings, embedded in a list for the encode function 

752 sentences_plain_text = [sentence.to_plain_string() for sentence in sentences] 

753 

754 embeddings = self.model.encode(sentences_plain_text, convert_to_numpy=self.convert_to_numpy) 

755 for sentence, embedding in zip(sentences, embeddings): 

756 sentence.set_embedding(self.name, embedding) 

757 

758 @property 

759 @abstractmethod 

760 def embedding_length(self) -> int: 

761 """Returns the length of the embedding vector.""" 

762 return self.model.get_sentence_embedding_dimension() 

763 

764 
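# --- Editorial usage sketch (not part of the original module) ----------------
# Requires the optional sentence-transformers dependency; the model name below
# is the class default and is assumed to be downloadable.
def _example_sentence_transformer_document_embeddings():
    embedding = SentenceTransformerDocumentEmbeddings("bert-base-nli-mean-tokens")
    sentence = Sentence("The grass is green .")
    embedding.embed([sentence])
    print(sentence.get_embedding().shape)  # one vector per sentence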

765class DocumentCNNEmbeddings(DocumentEmbeddings): 

766 def __init__( 

767 self, 

768 embeddings: List[TokenEmbeddings], 

769 kernels=((100, 3), (100, 4), (100, 5)), 

770 reproject_words: bool = True, 

771 reproject_words_dimension: int = None, 

772 dropout: float = 0.5, 

773 word_dropout: float = 0.0, 

774 locked_dropout: float = 0.0, 

775 fine_tune: bool = True, 

776 ): 

777 """The constructor takes a list of embeddings to be combined. 

778 :param embeddings: a list of token embeddings 

779 :param kernels: list of (number of kernels, kernel size) 

780 :param reproject_words: boolean value, indicating whether to reproject the token embeddings in a separate linear 

781 layer before putting them into the cnn or not 

782 :param reproject_words_dimension: output dimension of reprojecting token embeddings. If None the same output 

783 dimension as before will be taken. 

784 :param dropout: the dropout value to be used 

785 :param word_dropout: the word dropout value to be used, if 0.0 word dropout is not used 

786 :param locked_dropout: the locked dropout value to be used, if 0.0 locked dropout is not used 

787 """ 

788 super().__init__() 

789 

790 self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embeddings) 

791 self.length_of_all_token_embeddings: int = self.embeddings.embedding_length 

792 

793 self.kernels = kernels 

794 self.reproject_words = reproject_words 

795 

796 self.static_embeddings = not fine_tune 

797 

798 self.embeddings_dimension: int = self.length_of_all_token_embeddings 

799 if self.reproject_words and reproject_words_dimension is not None: 

800 self.embeddings_dimension = reproject_words_dimension 

801 

802 self.word_reprojection_map = torch.nn.Linear( 

803 self.length_of_all_token_embeddings, self.embeddings_dimension 

804 ) 

805 

806 # CNN 

807 self.__embedding_length: int = sum([kernel_num for kernel_num, kernel_size in self.kernels]) 

808 self.convs = torch.nn.ModuleList( 

809 [ 

810 torch.nn.Conv1d(self.embeddings_dimension, kernel_num, kernel_size) for kernel_num, kernel_size in 

811 self.kernels 

812 ] 

813 ) 

814 self.pool = torch.nn.AdaptiveMaxPool1d(1) 

815 

816 self.name = "document_cnn" 

817 

818 # dropouts 

819 self.dropout = torch.nn.Dropout(dropout) if dropout > 0.0 else None 

820 self.locked_dropout = ( 

821 LockedDropout(locked_dropout) if locked_dropout > 0.0 else None 

822 ) 

823 self.word_dropout = WordDropout(word_dropout) if word_dropout > 0.0 else None 

824 

825 torch.nn.init.xavier_uniform_(self.word_reprojection_map.weight) 

826 

827 self.to(flair.device) 

828 

829 self.eval() 

830 

831 @property 

832 def embedding_length(self) -> int: 

833 return self.__embedding_length 

834 

835 def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]): 

836 """Add embeddings to all sentences in the given list of sentences. If embeddings are already added, update 

837 only if embeddings are non-static.""" 

838 

839 # TODO: remove in future versions 

840 if not hasattr(self, "locked_dropout"): 

841 self.locked_dropout = None 

842 if not hasattr(self, "word_dropout"): 

843 self.word_dropout = None 

844 

845 if type(sentences) is Sentence: 

846 sentences = [sentences] 

847 

848 self.zero_grad()  # clear any previously accumulated gradients 

849 

850 # embed words in the sentence 

851 self.embeddings.embed(sentences) 

852 

853 lengths: List[int] = [len(sentence.tokens) for sentence in sentences] 

854 longest_token_sequence_in_batch: int = max(lengths) 

855 

856 pre_allocated_zero_tensor = torch.zeros( 

857 self.embeddings.embedding_length * longest_token_sequence_in_batch, 

858 dtype=torch.float, 

859 device=flair.device, 

860 ) 

861 

862 all_embs: List[torch.Tensor] = list() 

863 for sentence in sentences: 

864 all_embs += [ 

865 emb for token in sentence for emb in token.get_each_embedding() 

866 ] 

867 nb_padding_tokens = longest_token_sequence_in_batch - len(sentence) 

868 

869 if nb_padding_tokens > 0: 

870 t = pre_allocated_zero_tensor[ 

871 : self.embeddings.embedding_length * nb_padding_tokens 

872 ] 

873 all_embs.append(t) 

874 

875 sentence_tensor = torch.cat(all_embs).view( 

876 [ 

877 len(sentences), 

878 longest_token_sequence_in_batch, 

879 self.embeddings.embedding_length, 

880 ] 

881 ) 

882 

883 # before-CNN dropout 

884 if self.dropout: 

885 sentence_tensor = self.dropout(sentence_tensor) 

886 if self.locked_dropout: 

887 sentence_tensor = self.locked_dropout(sentence_tensor) 

888 if self.word_dropout: 

889 sentence_tensor = self.word_dropout(sentence_tensor) 

890 

891 # reproject if set 

892 if self.reproject_words: 

893 sentence_tensor = self.word_reprojection_map(sentence_tensor) 

894 

895 # push through the CNN 

896 x = sentence_tensor 

897 x = x.permute(0, 2, 1) 

898 

899 rep = [self.pool(torch.nn.functional.relu(conv(x))) for conv in self.convs] 

900 outputs = torch.cat(rep, 1) 

901 

902 outputs = outputs.reshape(outputs.size(0), -1) 

903 

904 # after-CNN dropout 

905 if self.dropout: 

906 outputs = self.dropout(outputs) 

907 if self.locked_dropout: 

908 outputs = self.locked_dropout(outputs) 

909 

910 # extract embeddings from CNN 

911 for sentence_no, length in enumerate(lengths): 

912 embedding = outputs[sentence_no] 

913 

914 if self.static_embeddings: 

915 embedding = embedding.detach() 

916 

917 sentence = sentences[sentence_no] 

918 sentence.set_embedding(self.name, embedding) 

919 

920 def _apply(self, fn): 

921 for child_module in self.children(): 

922 child_module._apply(fn)
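
# --- Editorial usage sketch (not part of the original module) ----------------
# Max-pooled 1-D convolutions over token embeddings; the output length is the
# sum of the kernel counts, here 100 + 100 + 100 = 300. WordEmbeddings and
# 'glove' are assumptions taken from flair.embeddings.token.
def _example_document_cnn_embeddings():
    from flair.embeddings.token import WordEmbeddings  # assumed import path

    embedding = DocumentCNNEmbeddings(
        [WordEmbeddings("glove")],
        kernels=((100, 3), (100, 4), (100, 5)),
    )
    # the sentence must be at least as long as the largest kernel size (5 here)
    sentence = Sentence("The grass is green and the sky is blue .")
    embedding.embed([sentence])
    print(sentence.get_embedding().shape)  # torch.Size([300])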