Coverage for /home/ubuntu/Documents/Research/mut_p1/flair/flair/embeddings/document.py: 75%


412 statements  

1import logging 

2from abc import abstractmethod 

3from typing import List, Union 

4 

5import torch 

6from sklearn.feature_extraction.text import TfidfVectorizer 

7from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 

8from transformers import AutoTokenizer, AutoConfig, AutoModel, CONFIG_MAPPING, PreTrainedTokenizer 

9 

10import flair 

11from flair.data import Sentence 

12from flair.embeddings.base import Embeddings, ScalarMix 

13from flair.embeddings.token import TokenEmbeddings, StackedEmbeddings, FlairEmbeddings 

14from flair.nn import LockedDropout, WordDropout 

15 

16log = logging.getLogger("flair") 

17 

18 

19class DocumentEmbeddings(Embeddings): 

20 """Abstract base class for all document-level embeddings. Every new type of document embedding must implement these methods.""" 

21 

22 @property 

23 @abstractmethod 

24 def embedding_length(self) -> int: 

25 """Returns the length of the embedding vector.""" 

26 pass 

27 

28 @property 

29 def embedding_type(self) -> str: 

30 return "sentence-level" 

31 

32 

33class TransformerDocumentEmbeddings(DocumentEmbeddings): 

34 def __init__( 

35 self, 

36 model: str = "bert-base-uncased", 

37 fine_tune: bool = True, 

38 layers: str = "-1", 

39 layer_mean: bool = False, 

40 pooling: str = "cls", 

41 **kwargs 

42 ): 

43 """ 

44 Document embeddings computed by pooling the outputs of various transformer architectures. 

45 :param model: name of transformer model (see https://huggingface.co/transformers/pretrained_models.html for 

46 options) 

47 :param fine_tune: If True, allows the transformer to be fine-tuned during training 

48 :param layers: string indicating which layers to take for embedding (-1 is the topmost layer) 

49 :param layer_mean: If True, uses a scalar mix of layers as embedding 

50 :param pooling: pooling strategy for combining token-level embeddings; options are 'cls', 'max' and 'mean' 

51 :param kwargs: additional keyword arguments passed on to the underlying AutoTokenizer and 

52 AutoConfig loaders 

53 """ 

54 super().__init__() 

55 

56 if pooling not in ['cls', 'max', 'mean']: 

57 raise ValueError(f"Pooling operation `{pooling}` is not defined for TransformerDocumentEmbeddings") 

58 

59 # temporary fix to disable tokenizer parallelism warning 

60 # (see https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning) 

61 import os 

62 os.environ["TOKENIZERS_PARALLELISM"] = "false" 

63 

64 # do not print transformer warnings as these are confusing in this case 

65 from transformers import logging 

66 logging.set_verbosity_error() 

67 

68 # load tokenizer and transformer model 

69 self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(model, **kwargs) 

70 if self.tokenizer.model_max_length > 1000000000: 

71 self.tokenizer.model_max_length = 512 

72 log.info("No model_max_length in Tokenizer's config.json - setting it to 512. " 

73 "Specify desired model_max_length by passing it as attribute to embedding instance.") 

74 if 'config' not in kwargs: 

75 config = AutoConfig.from_pretrained(model, output_hidden_states=True, **kwargs) 

76 self.model = AutoModel.from_pretrained(model, config=config) 

77 else: 

78 self.model = AutoModel.from_pretrained(None, **kwargs) 

79 

80 logging.set_verbosity_warning() 

81 

82 # model name 

83 self.name = 'transformer-document-' + str(model) 

84 self.base_model_name = str(model) 

85 

86 # when initializing, embeddings are in eval mode by default 

87 self.model.eval() 

88 self.model.to(flair.device) 

89 

90 # embedding parameters 

91 if layers == 'all': 

92 # send mini-token through to check how many layers the model has 

93 hidden_states = self.model(torch.tensor([1], device=flair.device).unsqueeze(0))[-1] 

94 self.layer_indexes = [int(x) for x in range(len(hidden_states))] 

95 else: 

96 self.layer_indexes = [int(x) for x in layers.split(",")] 

97 

98 self.layer_mean = layer_mean 

99 self.fine_tune = fine_tune 

100 self.static_embeddings = not self.fine_tune 

101 self.pooling = pooling 

102 

103 # check whether CLS is at beginning or end 

104 self.initial_cls_token: bool = self._has_initial_cls_token(tokenizer=self.tokenizer) 

105 

106 @staticmethod 

107 def _has_initial_cls_token(tokenizer: PreTrainedTokenizer) -> bool: 

108 # some models (e.g. XLNet) place the CLS token at the end of the sequence, while BERT-style models place it first 

109 tokens = tokenizer.encode('a') 

110 initial_cls_token: bool = False 

111 if tokens[0] == tokenizer.cls_token_id: initial_cls_token = True 

112 return initial_cls_token 

113 
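For instance, a BERT tokenizer places [CLS] first, so the check above returns True for it. A quick sketch, assuming the 'bert-base-uncased' checkpoint can be downloaded:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained('bert-base-uncased')
    # BERT encodes 'a' as [CLS] a [SEP], so the first id is the CLS token id
    print(tok.encode('a')[0] == tok.cls_token_id)  # True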

114 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

115 """Add embeddings to all words in a list of sentences.""" 

116 

117 # gradients are enabled if fine-tuning is enabled 

118 gradient_context = torch.enable_grad() if (self.fine_tune and self.training) else torch.no_grad() 

119 

120 with gradient_context: 

121 

122 # first, subtokenize each sentence and find out into how many subtokens each token was divided 

123 subtokenized_sentences = [] 

124 

125 # subtokenize sentences 

126 for sentence in sentences: 

127 # tokenize and truncate to max subtokens (TODO: check better truncation strategies) 

128 subtokenized_sentence = self.tokenizer.encode(sentence.to_tokenized_string(), 

129 add_special_tokens=True, 

130 max_length=self.tokenizer.model_max_length, 

131 truncation=True, 

132 ) 

133 

134 subtokenized_sentences.append( 

135 torch.tensor(subtokenized_sentence, dtype=torch.long, device=flair.device)) 

136 

137 # find longest sentence in batch 

138 longest_sequence_in_batch: int = len(max(subtokenized_sentences, key=len)) 

139 

140 # initialize batch tensors and mask 

141 input_ids = torch.zeros( 

142 [len(sentences), longest_sequence_in_batch], 

143 dtype=torch.long, 

144 device=flair.device, 

145 ) 

146 mask = torch.zeros( 

147 [len(sentences), longest_sequence_in_batch], 

148 dtype=torch.long, 

149 device=flair.device, 

150 ) 

151 for s_id, sentence in enumerate(subtokenized_sentences): 

152 sequence_length = len(sentence) 

153 input_ids[s_id][:sequence_length] = sentence 

154 mask[s_id][:sequence_length] = torch.ones(sequence_length) 

155 

156 # put encoded batch through transformer model to get all hidden states of all encoder layers 

157 hidden_states = self.model(input_ids, attention_mask=mask)[-1] if len(sentences) > 1 \ 

158 else self.model(input_ids)[-1] 

159 

160 # iterate over all subtokenized sentences 

161 for sentence_idx, (sentence, subtokens) in enumerate(zip(sentences, subtokenized_sentences)): 

162 

163 if self.pooling == "cls": 

164 index_of_CLS_token = 0 if self.initial_cls_token else len(subtokens) - 1 

165 

166 cls_embeddings_all_layers: List[torch.FloatTensor] = \ 

167 [hidden_states[layer][sentence_idx][index_of_CLS_token] for layer in self.layer_indexes] 

168 

169 embeddings_all_layers = cls_embeddings_all_layers 

170 

171 elif self.pooling == "mean": 

172 mean_embeddings_all_layers: List[torch.FloatTensor] = \ 

173 [torch.mean(hidden_states[layer][sentence_idx][:len(subtokens), :], dim=0) for layer in 

174 self.layer_indexes] 

175 

176 embeddings_all_layers = mean_embeddings_all_layers 

177 

178 elif self.pooling == "max": 

179 max_embeddings_all_layers: List[torch.FloatTensor] = \ 

180 [torch.max(hidden_states[layer][sentence_idx][:len(subtokens), :], dim=0)[0] for layer in 

181 self.layer_indexes] 

182 

183 embeddings_all_layers = max_embeddings_all_layers 

184 

185 # use scalar mix of embeddings if so selected 

186 if self.layer_mean: 

187 sm = ScalarMix(mixture_size=len(embeddings_all_layers)) 

188 sm_embeddings = sm(embeddings_all_layers) 

189 

190 embeddings_all_layers = [sm_embeddings] 

191 

192 # set the extracted embedding for the sentence 

193 sentence.set_embedding(self.name, torch.cat(embeddings_all_layers)) 

194 

195 return sentences 

196 

197 @property 

198 @abstractmethod 

199 def embedding_length(self) -> int: 

200 """Returns the length of the embedding vector.""" 

201 return ( 

202 len(self.layer_indexes) * self.model.config.hidden_size 

203 if not self.layer_mean 

204 else self.model.config.hidden_size 

205 ) 

206 

207 def __getstate__(self): 

208 # special handling for serializing transformer models 

209 config_state_dict = self.model.config.__dict__ 

210 model_state_dict = self.model.state_dict() 

211 

212 if not hasattr(self, "base_model_name"): self.base_model_name = self.name.split('transformer-document-')[-1] 

213 

214 # serialize the transformer models and the constructor arguments (but nothing else) 

215 model_state = { 

216 "config_state_dict": config_state_dict, 

217 "model_state_dict": model_state_dict, 

218 "embedding_length_internal": self.embedding_length, 

219 

220 "base_model_name": self.base_model_name, 

221 "fine_tune": self.fine_tune, 

222 "layer_indexes": self.layer_indexes, 

223 "layer_mean": self.layer_mean, 

224 "pooling": self.pooling, 

225 } 

226 

227 return model_state 

228 

229 def __setstate__(self, d): 

230 self.__dict__ = d 

231 

232 # necessary for reverse compatibility with Flair <= 0.7 

233 if 'use_scalar_mix' in self.__dict__.keys(): 

234 self.__dict__['layer_mean'] = d['use_scalar_mix'] 

235 

236 # special handling for deserializing transformer models 

237 if "config_state_dict" in d: 

238 

239 # load transformer model 

240 model_type = d["config_state_dict"]["model_type"] if "model_type" in d["config_state_dict"] else "bert" 

241 config_class = CONFIG_MAPPING[model_type] 

242 loaded_config = config_class.from_dict(d["config_state_dict"]) 

243 

244 # constructor arguments 

245 layers = ','.join([str(idx) for idx in self.__dict__['layer_indexes']]) 

246 

247 # re-initialize transformer word embeddings with constructor arguments 

248 embedding = TransformerDocumentEmbeddings( 

249 model=self.__dict__['base_model_name'], 

250 fine_tune=self.__dict__['fine_tune'], 

251 layers=layers, 

252 layer_mean=self.__dict__['layer_mean'], 

253 

254 config=loaded_config, 

255 state_dict=d["model_state_dict"], 

256 pooling=self.__dict__['pooling'] if 'pooling' in self.__dict__ else 'cls', 

257 # for backward compatibility with previous models 

258 ) 

259 

260 # copy the attributes of the re-initialized embedding over to self 

261 for key in embedding.__dict__.keys(): 

262 self.__dict__[key] = embedding.__dict__[key] 

263 

264 else: 

265 model_name = self.__dict__['name'].split('transformer-document-')[-1] 

266 # reload tokenizer to get around serialization issues 

267 try: 

268 tokenizer = AutoTokenizer.from_pretrained(model_name) 

269 except Exception: 

270 tokenizer = None 

271 self.tokenizer = tokenizer 

272 

273 
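A minimal usage sketch for the class above, assuming the 'distilbert-base-uncased' checkpoint can be downloaded; the pooling choice is just an example:

    from flair.data import Sentence
    from flair.embeddings import TransformerDocumentEmbeddings

    # embed a whole sentence with a transformer, mean-pooling the sub-token vectors
    embedding = TransformerDocumentEmbeddings('distilbert-base-uncased', fine_tune=False, pooling='mean')
    sentence = Sentence('The grass is green .')
    embedding.embed(sentence)
    print(sentence.get_embedding().shape)  # one vector per sentence, of size embedding.embedding_length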

274class DocumentPoolEmbeddings(DocumentEmbeddings): 

275 def __init__( 

276 self, 

277 embeddings: List[TokenEmbeddings], 

278 fine_tune_mode: str = "none", 

279 pooling: str = "mean", 

280 ): 

281 """The constructor takes a list of embeddings to be combined. 

282 :param embeddings: a list of token embeddings 

283 :param fine_tune_mode: if set to "linear", a trainable linear layer is added; if set to 

284 "nonlinear", a nonlinearity is added as well. Set this to make the pooling trainable. 

285 :param pooling: a string which can be any value from ['mean', 'max', 'min'] 

286 """ 

287 super().__init__() 

288 

289 self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embeddings) 

290 self.__embedding_length = self.embeddings.embedding_length 

291 

292 # optional fine-tuning on top of embedding layer 

293 self.fine_tune_mode = fine_tune_mode 

294 if self.fine_tune_mode in ["nonlinear", "linear"]: 

295 self.embedding_flex = torch.nn.Linear( 

296 self.embedding_length, self.embedding_length, bias=False 

297 ) 

298 self.embedding_flex.weight.data.copy_(torch.eye(self.embedding_length)) 

299 

300 if self.fine_tune_mode in ["nonlinear"]: 

301 self.embedding_flex_nonlinear = torch.nn.ReLU(self.embedding_length) 

302 self.embedding_flex_nonlinear_map = torch.nn.Linear( 

303 self.embedding_length, self.embedding_length 

304 ) 

305 

306 self.__embedding_length: int = self.embeddings.embedding_length 

307 

308 self.to(flair.device) 

309 

310 if pooling not in ['min', 'max', 'mean']: 

311 raise ValueError(f"Pooling operation for {self.mode!r} is not defined") 

312 

313 self.pooling = pooling 

314 self.name: str = f"document_{self.pooling}" 

315 

316 @property 

317 def embedding_length(self) -> int: 

318 return self.__embedding_length 

319 

320 def embed(self, sentences: Union[List[Sentence], Sentence]): 

321 """Add embeddings to every sentence in the given list of sentences. If embeddings are already added, updates 

322 only if embeddings are non-static.""" 

323 

324 # if only one sentence is passed, convert to list of sentence 

325 if isinstance(sentences, Sentence): 

326 sentences = [sentences] 

327 

328 self.embeddings.embed(sentences) 

329 

330 for sentence in sentences: 

331 word_embeddings = [] 

332 for token in sentence.tokens: 

333 word_embeddings.append(token.get_embedding().unsqueeze(0)) 

334 

335 word_embeddings = torch.cat(word_embeddings, dim=0).to(flair.device) 

336 

337 if self.fine_tune_mode in ["nonlinear", "linear"]: 

338 word_embeddings = self.embedding_flex(word_embeddings) 

339 

340 if self.fine_tune_mode in ["nonlinear"]: 

341 word_embeddings = self.embedding_flex_nonlinear(word_embeddings) 

342 word_embeddings = self.embedding_flex_nonlinear_map(word_embeddings) 

343 

344 if self.pooling == "mean": 

345 pooled_embedding = torch.mean(word_embeddings, 0) 

346 elif self.pooling == "max": 

347 pooled_embedding, _ = torch.max(word_embeddings, 0) 

348 elif self.pooling == "min": 

349 pooled_embedding, _ = torch.min(word_embeddings, 0) 

350 

351 sentence.set_embedding(self.name, pooled_embedding) 

352 

353 def _add_embeddings_internal(self, sentences: List[Sentence]): 

354 pass 

355 

356 def extra_repr(self): 

357 return f"fine_tune_mode={self.fine_tune_mode}, pooling={self.pooling}" 

358 

359 
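A minimal usage sketch for DocumentPoolEmbeddings, assuming the 'glove' word embeddings can be downloaded:

    from flair.data import Sentence
    from flair.embeddings import WordEmbeddings, DocumentPoolEmbeddings

    # pool GloVe token embeddings into a single sentence vector
    glove = WordEmbeddings('glove')
    document_embeddings = DocumentPoolEmbeddings([glove], pooling='mean')
    sentence = Sentence('The grass is green .')
    document_embeddings.embed(sentence)
    print(sentence.get_embedding().shape)  # same length as the stacked token embeddings (100 for glove)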

360class DocumentTFIDFEmbeddings(DocumentEmbeddings): 

361 def __init__( 

362 self, 

363 train_dataset, 

364 **vectorizer_params, 

365 ): 

366 """The constructor for DocumentTFIDFEmbeddings. 

367 :param train_dataset: the train dataset which will be used to construct vectorizer 

368 :param vectorizer_params: parameters given to Scikit-learn's TfidfVectorizer constructor 

369 """ 

370 super().__init__() 

371 

372 import numpy as np 

373 self.vectorizer = TfidfVectorizer(dtype=np.float32, **vectorizer_params) 

374 self.vectorizer.fit([s.to_original_text() for s in train_dataset]) 

375 

376 self.__embedding_length: int = len(self.vectorizer.vocabulary_) 

377 

378 self.to(flair.device) 

379 

380 self.name: str = f"document_tfidf" 

381 

382 @property 

383 def embedding_length(self) -> int: 

384 return self.__embedding_length 

385 

386 def embed(self, sentences: Union[List[Sentence], Sentence]): 

387 """Add embeddings to every sentence in the given list of sentences.""" 

388 

389 # if only one sentence is passed, convert to list of sentence 

390 if isinstance(sentences, Sentence): 

391 sentences = [sentences] 

392 

393 raw_sentences = [s.to_original_text() for s in sentences] 

394 tfidf_vectors = torch.from_numpy(self.vectorizer.transform(raw_sentences).A) 

395 

396 for sentence_id, sentence in enumerate(sentences): 

397 sentence.set_embedding(self.name, tfidf_vectors[sentence_id]) 

398 

399 def _add_embeddings_internal(self, sentences: List[Sentence]): 

400 pass 

401 

402 
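A minimal usage sketch for DocumentTFIDFEmbeddings; the two training sentences are toy examples, and any TfidfVectorizer keyword (here min_df) can be forwarded through vectorizer_params:

    from flair.data import Sentence
    from flair.embeddings.document import DocumentTFIDFEmbeddings

    # fit the TF-IDF vectorizer on a (tiny) training corpus of Sentence objects
    train_sentences = [Sentence('the grass is green'), Sentence('the sky is blue')]
    embedding = DocumentTFIDFEmbeddings(train_sentences, min_df=1)

    sentence = Sentence('green grass under a blue sky')
    embedding.embed(sentence)
    print(sentence.get_embedding().shape)  # vocabulary size of the fitted vectorizer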

403class DocumentRNNEmbeddings(DocumentEmbeddings): 

404 def __init__( 

405 self, 

406 embeddings: List[TokenEmbeddings], 

407 hidden_size=128, 

408 rnn_layers=1, 

409 reproject_words: bool = True, 

410 reproject_words_dimension: int = None, 

411 bidirectional: bool = False, 

412 dropout: float = 0.5, 

413 word_dropout: float = 0.0, 

414 locked_dropout: float = 0.0, 

415 rnn_type="GRU", 

416 fine_tune: bool = True, 

417 ): 

418 """The constructor takes a list of embeddings to be combined. 

419 :param embeddings: a list of token embeddings 

420 :param hidden_size: the number of hidden states in the rnn 

421 :param rnn_layers: the number of layers for the rnn 

422 :param reproject_words: boolean value, indicating whether to reproject the token embeddings in a separate linear 

423 layer before putting them into the rnn or not 

424 :param reproject_words_dimension: output dimension of reprojecting token embeddings. If None the same output 

425 dimension as before will be taken. 

426 :param bidirectional: boolean value, indicating whether to use a bidirectional rnn or not 

427 :param dropout: the dropout value to be used 

428 :param word_dropout: the word dropout value to be used, if 0.0 word dropout is not used 

429 :param locked_dropout: the locked dropout value to be used, if 0.0 locked dropout is not used 

430 :param rnn_type: 'GRU' or 'LSTM' 

431 """ 

432 super().__init__() 

433 

434 self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embeddings) 

435 

436 self.rnn_type = rnn_type 

437 

438 self.reproject_words = reproject_words 

439 self.bidirectional = bidirectional 

440 

441 self.length_of_all_token_embeddings: int = self.embeddings.embedding_length 

442 

443 self.static_embeddings = not fine_tune 

444 

445 self.__embedding_length: int = hidden_size 

446 if self.bidirectional: 

447 self.__embedding_length *= 4 

448 

449 self.embeddings_dimension: int = self.length_of_all_token_embeddings 

450 if self.reproject_words and reproject_words_dimension is not None: 

451 self.embeddings_dimension = reproject_words_dimension 

452 

453 self.word_reprojection_map = torch.nn.Linear( 

454 self.length_of_all_token_embeddings, self.embeddings_dimension 

455 ) 

456 

457 # RNN on top of embedding layer (bidirectional if so configured) 

458 if rnn_type == "LSTM": 

459 self.rnn = torch.nn.LSTM( 

460 self.embeddings_dimension, 

461 hidden_size, 

462 num_layers=rnn_layers, 

463 bidirectional=self.bidirectional, 

464 batch_first=True, 

465 ) 

466 else: 

467 self.rnn = torch.nn.GRU( 

468 self.embeddings_dimension, 

469 hidden_size, 

470 num_layers=rnn_layers, 

471 bidirectional=self.bidirectional, 

472 batch_first=True, 

473 ) 

474 

475 self.name = "document_" + self.rnn._get_name() 

476 

477 # dropouts 

478 self.dropout = torch.nn.Dropout(dropout) if dropout > 0.0 else None 

479 self.locked_dropout = ( 

480 LockedDropout(locked_dropout) if locked_dropout > 0.0 else None 

481 ) 

482 self.word_dropout = WordDropout(word_dropout) if word_dropout > 0.0 else None 

483 

484 torch.nn.init.xavier_uniform_(self.word_reprojection_map.weight) 

485 

486 self.to(flair.device) 

487 

488 self.eval() 

489 

490 @property 

491 def embedding_length(self) -> int: 

492 return self.__embedding_length 

493 

494 def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]): 

495 """Add embeddings to all sentences in the given list of sentences. If embeddings are already added, update 

496 only if embeddings are non-static.""" 

497 

498 # TODO: remove in future versions 

499 if not hasattr(self, "locked_dropout"): 

500 self.locked_dropout = None 

501 if not hasattr(self, "word_dropout"): 

502 self.word_dropout = None 

503 

504 if type(sentences) is Sentence: 

505 sentences = [sentences] 

506 

507 self.rnn.zero_grad() 

508 

509 # embed words in the sentence 

510 self.embeddings.embed(sentences) 

511 

512 lengths: List[int] = [len(sentence.tokens) for sentence in sentences] 

513 longest_token_sequence_in_batch: int = max(lengths) 

514 

515 pre_allocated_zero_tensor = torch.zeros( 

516 self.embeddings.embedding_length * longest_token_sequence_in_batch, 

517 dtype=torch.float, 

518 device=flair.device, 

519 ) 

520 

521 all_embs: List[torch.Tensor] = list() 

522 for sentence in sentences: 

523 all_embs += [ 

524 emb for token in sentence for emb in token.get_each_embedding() 

525 ] 

526 nb_padding_tokens = longest_token_sequence_in_batch - len(sentence) 

527 

528 if nb_padding_tokens > 0: 

529 t = pre_allocated_zero_tensor[ 

530 : self.embeddings.embedding_length * nb_padding_tokens 

531 ] 

532 all_embs.append(t) 

533 

534 sentence_tensor = torch.cat(all_embs).view( 

535 [ 

536 len(sentences), 

537 longest_token_sequence_in_batch, 

538 self.embeddings.embedding_length, 

539 ] 

540 ) 

541 

542 # before-RNN dropout 

543 if self.dropout: 

544 sentence_tensor = self.dropout(sentence_tensor) 

545 if self.locked_dropout: 

546 sentence_tensor = self.locked_dropout(sentence_tensor) 

547 if self.word_dropout: 

548 sentence_tensor = self.word_dropout(sentence_tensor) 

549 

550 # reproject if set 

551 if self.reproject_words: 

552 sentence_tensor = self.word_reprojection_map(sentence_tensor) 

553 

554 # push through RNN 

555 packed = pack_padded_sequence( 

556 sentence_tensor, lengths, enforce_sorted=False, batch_first=True 

557 ) 

558 rnn_out, hidden = self.rnn(packed) 

559 outputs, output_lengths = pad_packed_sequence(rnn_out, batch_first=True) 

560 

561 # after-RNN dropout 

562 if self.dropout: 

563 outputs = self.dropout(outputs) 

564 if self.locked_dropout: 

565 outputs = self.locked_dropout(outputs) 

566 

567 # extract embeddings from RNN 

568 for sentence_no, length in enumerate(lengths): 

569 last_rep = outputs[sentence_no, length - 1] 

570 

571 embedding = last_rep 

572 if self.bidirectional: 

573 first_rep = outputs[sentence_no, 0] 

574 embedding = torch.cat([first_rep, last_rep], 0) 

575 

576 if self.static_embeddings: 

577 embedding = embedding.detach() 

578 

579 sentence = sentences[sentence_no] 

580 sentence.set_embedding(self.name, embedding) 

581 

582 def _apply(self, fn): 

583 

584 # models that were serialized using torch versions older than 1.4.0 lack the _flat_weights_names attribute 

585 # check if this is the case and if so, set it 

586 for child_module in self.children(): 

587 if isinstance(child_module, torch.nn.RNNBase) and not hasattr(child_module, "_flat_weights_names"): 

588 _flat_weights_names = [] 

589 

590 if child_module.__dict__["bidirectional"]: 

591 num_direction = 2 

592 else: 

593 num_direction = 1 

594 for layer in range(child_module.__dict__["num_layers"]): 

595 for direction in range(num_direction): 

596 suffix = "_reverse" if direction == 1 else "" 

597 param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"] 

598 if child_module.__dict__["bias"]: 

599 param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"] 

600 param_names = [ 

601 x.format(layer, suffix) for x in param_names 

602 ] 

603 _flat_weights_names.extend(param_names) 

604 

605 setattr(child_module, "_flat_weights_names", 

606 _flat_weights_names) 

607 

608 child_module._apply(fn) 

609 

610 def __getstate__(self): 

611 

612 # serialize the language models and the constructor arguments (but nothing else) 

613 model_state = { 

614 "state_dict": self.state_dict(), 

615 

616 "embeddings": self.embeddings.embeddings, 

617 "hidden_size": self.rnn.hidden_size, 

618 "rnn_layers": self.rnn.num_layers, 

619 "reproject_words": self.reproject_words, 

620 "reproject_words_dimension": self.embeddings_dimension, 

621 "bidirectional": self.bidirectional, 

622 "dropout": self.dropout.p if self.dropout is not None else 0., 

623 "word_dropout": self.word_dropout.p if self.word_dropout is not None else 0., 

624 "locked_dropout": self.locked_dropout.p if self.locked_dropout is not None else 0., 

625 "rnn_type": self.rnn_type, 

626 "fine_tune": not self.static_embeddings, 

627 } 

628 

629 return model_state 

630 

631 def __setstate__(self, d): 

632 

633 # special handling for deserializing language models 

634 if "state_dict" in d: 

635 

636 # re-initialize language model with constructor arguments 

637 language_model = DocumentRNNEmbeddings( 

638 embeddings=d['embeddings'], 

639 hidden_size=d['hidden_size'], 

640 rnn_layers=d['rnn_layers'], 

641 reproject_words=d['reproject_words'], 

642 reproject_words_dimension=d['reproject_words_dimension'], 

643 bidirectional=d['bidirectional'], 

644 dropout=d['dropout'], 

645 word_dropout=d['word_dropout'], 

646 locked_dropout=d['locked_dropout'], 

647 rnn_type=d['rnn_type'], 

648 fine_tune=d['fine_tune'], 

649 ) 

650 

651 language_model.load_state_dict(d['state_dict']) 

652 

653 # copy over state dictionary to self 

654 for key in language_model.__dict__.keys(): 

655 self.__dict__[key] = language_model.__dict__[key] 

656 

657 # set the language model to eval() by default (this is necessary since FlairEmbeddings "protect" the LM 

658 # in their "self.train()" method) 

659 self.eval() 

660 

661 else: 

662 self.__dict__ = d 

663 

664 
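A minimal usage sketch for DocumentRNNEmbeddings, again assuming the 'glove' token embeddings; hidden_size, bidirectional and rnn_type are arbitrary example values:

    from flair.data import Sentence
    from flair.embeddings import WordEmbeddings, DocumentRNNEmbeddings

    # run the token embeddings through a bidirectional LSTM and keep the boundary states
    embedding = DocumentRNNEmbeddings([WordEmbeddings('glove')], hidden_size=64,
                                      bidirectional=True, rnn_type='LSTM')
    sentence = Sentence('The grass is green .')
    embedding.embed(sentence)
    print(sentence.get_embedding().shape)  # 4 * hidden_size for a bidirectional RNN, i.e. 256 here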

665class DocumentLMEmbeddings(DocumentEmbeddings): 

666 def __init__(self, flair_embeddings: List[FlairEmbeddings]): 

667 super().__init__() 

668 

669 self.embeddings = flair_embeddings 

670 self.name = "document_lm" 

671 

672 # IMPORTANT: add embeddings as torch modules 

673 for i, embedding in enumerate(flair_embeddings): 

674 self.add_module("lm_embedding_{}".format(i), embedding) 

675 if not embedding.static_embeddings: 

676 self.static_embeddings = False 

677 

678 self._embedding_length: int = sum( 

679 embedding.embedding_length for embedding in flair_embeddings 

680 ) 

681 

682 @property 

683 def embedding_length(self) -> int: 

684 return self._embedding_length 

685 

686 def _add_embeddings_internal(self, sentences: List[Sentence]): 

687 if type(sentences) is Sentence: 

688 sentences = [sentences] 

689 

690 for embedding in self.embeddings: 

691 embedding.embed(sentences) 

692 

693 # iterate over sentences 

694 for sentence in sentences: 

695 sentence: Sentence = sentence 

696 

697 # if it's a forward LM, take the last state 

698 if embedding.is_forward_lm: 

699 sentence.set_embedding( 

700 embedding.name, 

701 sentence[len(sentence) - 1]._embeddings[embedding.name], 

702 ) 

703 else: 

704 sentence.set_embedding( 

705 embedding.name, sentence[0]._embeddings[embedding.name] 

706 ) 

707 

708 return sentences 

709 

710 
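A minimal usage sketch for DocumentLMEmbeddings, assuming the 'news-forward-fast' and 'news-backward-fast' character language models can be downloaded:

    from flair.data import Sentence
    from flair.embeddings import FlairEmbeddings
    from flair.embeddings.document import DocumentLMEmbeddings

    # use the final states of a forward and a backward character LM as the document embedding
    embedding = DocumentLMEmbeddings([FlairEmbeddings('news-forward-fast'),
                                      FlairEmbeddings('news-backward-fast')])
    sentence = Sentence('The grass is green .')
    embedding.embed(sentence)
    print(sentence.get_embedding().shape)  # sum of the individual language model embedding lengths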

711class SentenceTransformerDocumentEmbeddings(DocumentEmbeddings): 

712 def __init__( 

713 self, 

714 model: str = "bert-base-nli-mean-tokens", 

715 batch_size: int = 1, 

716 convert_to_numpy: bool = False, 

717 ): 

718 """ 

719 :param model: string name of a pretrained model from the sentence-transformers library 

720 :param batch_size: number of sentences to be processed in one batch 

721 :param convert_to_numpy: whether the underlying encode() call returns a numpy array 

722 or a PyTorch tensor 

723 """ 

724 super().__init__() 

725 

726 try: 

727 from sentence_transformers import SentenceTransformer 

728 except ModuleNotFoundError: 

729 log.warning("-" * 100) 

730 log.warning('ATTENTION! The library "sentence-transformers" is not installed!') 

731 log.warning( 

732 'To use Sentence Transformers, please first install with "pip install sentence-transformers"' 

733 ) 

734 log.warning("-" * 100) 

735 raise 

736 

737 self.model = SentenceTransformer(model) 

738 self.name = 'sentence-transformers-' + str(model) 

739 self.batch_size = batch_size 

740 self.convert_to_numpy = convert_to_numpy 

741 self.static_embeddings = True 

742 

743 def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]: 

744 

745 sentence_batches = [sentences[i * self.batch_size:(i + 1) * self.batch_size] 

746 for i in range((len(sentences) + self.batch_size - 1) // self.batch_size)] 

747 

748 for batch in sentence_batches: 

749 self._add_embeddings_to_sentences(batch) 

750 

751 return sentences 

752 

753 def _add_embeddings_to_sentences(self, sentences: List[Sentence]): 

754 

755 # convert to plain strings, embedded in a list for the encode function 

756 sentences_plain_text = [sentence.to_plain_string() for sentence in sentences] 

757 

758 embeddings = self.model.encode(sentences_plain_text, convert_to_numpy=self.convert_to_numpy) 

759 for sentence, embedding in zip(sentences, embeddings): 

760 sentence.set_embedding(self.name, embedding) 

761 

762 @property 

763 @abstractmethod 

764 def embedding_length(self) -> int: 

765 """Returns the length of the embedding vector.""" 

766 return self.model.get_sentence_embedding_dimension() 

767 

768 
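A minimal usage sketch for SentenceTransformerDocumentEmbeddings; it needs the optional sentence-transformers package and downloads the named checkpoint (the constructor default is used here):

    from flair.data import Sentence
    from flair.embeddings import SentenceTransformerDocumentEmbeddings

    # delegate sentence encoding entirely to the sentence-transformers model
    embedding = SentenceTransformerDocumentEmbeddings('bert-base-nli-mean-tokens')
    sentence = Sentence('The grass is green .')
    embedding.embed(sentence)
    print(sentence.get_embedding().shape)  # model.get_sentence_embedding_dimension() entries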

769class DocumentCNNEmbeddings(DocumentEmbeddings): 

770 def __init__( 

771 self, 

772 embeddings: List[TokenEmbeddings], 

773 kernels=((100, 3), (100, 4), (100, 5)), 

774 reproject_words: bool = True, 

775 reproject_words_dimension: int = None, 

776 dropout: float = 0.5, 

777 word_dropout: float = 0.0, 

778 locked_dropout: float = 0.0, 

779 fine_tune: bool = True, 

780 ): 

781 """The constructor takes a list of embeddings to be combined. 

782 :param embeddings: a list of token embeddings 

783 :param kernels: list of (number of kernels, kernel size) 

784 :param reproject_words: boolean value, indicating whether to reproject the token embeddings in a separate linear 

785 layer before putting them into the cnn or not 

786 :param reproject_words_dimension: output dimension of reprojecting token embeddings. If None the same output 

787 dimension as before will be taken. 

788 :param dropout: the dropout value to be used 

789 :param word_dropout: the word dropout value to be used, if 0.0 word dropout is not used 

790 :param locked_dropout: the locked dropout value to be used, if 0.0 locked dropout is not used 

791 """ 

792 super().__init__() 

793 

794 self.embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embeddings) 

795 self.length_of_all_token_embeddings: int = self.embeddings.embedding_length 

796 

797 self.kernels = kernels 

798 self.reproject_words = reproject_words 

799 

800 self.static_embeddings = not fine_tune 

801 

802 self.embeddings_dimension: int = self.length_of_all_token_embeddings 

803 if self.reproject_words and reproject_words_dimension is not None: 

804 self.embeddings_dimension = reproject_words_dimension 

805 

806 self.word_reprojection_map = torch.nn.Linear( 

807 self.length_of_all_token_embeddings, self.embeddings_dimension 

808 ) 

809 

810 # CNN 

811 self.__embedding_length: int = sum([kernel_num for kernel_num, kernel_size in self.kernels]) 

812 self.convs = torch.nn.ModuleList( 

813 [ 

814 torch.nn.Conv1d(self.embeddings_dimension, kernel_num, kernel_size) for kernel_num, kernel_size in 

815 self.kernels 

816 ] 

817 ) 

818 self.pool = torch.nn.AdaptiveMaxPool1d(1) 

819 

820 self.name = "document_cnn" 

821 

822 # dropouts 

823 self.dropout = torch.nn.Dropout(dropout) if dropout > 0.0 else None 

824 self.locked_dropout = ( 

825 LockedDropout(locked_dropout) if locked_dropout > 0.0 else None 

826 ) 

827 self.word_dropout = WordDropout(word_dropout) if word_dropout > 0.0 else None 

828 

829 torch.nn.init.xavier_uniform_(self.word_reprojection_map.weight) 

830 

831 self.to(flair.device) 

832 

833 self.eval() 

834 

835 @property 

836 def embedding_length(self) -> int: 

837 return self.__embedding_length 

838 

839 def _add_embeddings_internal(self, sentences: Union[List[Sentence], Sentence]): 

840 """Add embeddings to all sentences in the given list of sentences. If embeddings are already added, update 

841 only if embeddings are non-static.""" 

842 

843 # TODO: remove in future versions 

844 if not hasattr(self, "locked_dropout"): 

845 self.locked_dropout = None 

846 if not hasattr(self, "word_dropout"): 

847 self.word_dropout = None 

848 

849 if type(sentences) is Sentence: 

850 sentences = [sentences] 

851 

852 self.zero_grad() # reset any previously accumulated gradients 

853 

854 # embed words in the sentence 

855 self.embeddings.embed(sentences) 

856 

857 lengths: List[int] = [len(sentence.tokens) for sentence in sentences] 

858 longest_token_sequence_in_batch: int = max(lengths) 

859 

860 pre_allocated_zero_tensor = torch.zeros( 

861 self.embeddings.embedding_length * longest_token_sequence_in_batch, 

862 dtype=torch.float, 

863 device=flair.device, 

864 ) 

865 

866 all_embs: List[torch.Tensor] = list() 

867 for sentence in sentences: 

868 all_embs += [ 

869 emb for token in sentence for emb in token.get_each_embedding() 

870 ] 

871 nb_padding_tokens = longest_token_sequence_in_batch - len(sentence) 

872 

873 if nb_padding_tokens > 0: 

874 t = pre_allocated_zero_tensor[ 

875 : self.embeddings.embedding_length * nb_padding_tokens 

876 ] 

877 all_embs.append(t) 

878 

879 sentence_tensor = torch.cat(all_embs).view( 

880 [ 

881 len(sentences), 

882 longest_token_sequence_in_batch, 

883 self.embeddings.embedding_length, 

884 ] 

885 ) 

886 

887 # before-CNN dropout 

888 if self.dropout: 

889 sentence_tensor = self.dropout(sentence_tensor) 

890 if self.locked_dropout: 

891 sentence_tensor = self.locked_dropout(sentence_tensor) 

892 if self.word_dropout: 

893 sentence_tensor = self.word_dropout(sentence_tensor) 

894 

895 # reproject if set 

896 if self.reproject_words: 

897 sentence_tensor = self.word_reprojection_map(sentence_tensor) 

898 

899 # push through CNN 

900 x = sentence_tensor 

901 x = x.permute(0, 2, 1) 

902 

903 rep = [self.pool(torch.nn.functional.relu(conv(x))) for conv in self.convs] 

904 outputs = torch.cat(rep, 1) 

905 

906 outputs = outputs.reshape(outputs.size(0), -1) 

907 

908 # after-CNN dropout 

909 if self.dropout: 

910 outputs = self.dropout(outputs) 

911 if self.locked_dropout: 

912 outputs = self.locked_dropout(outputs) 

913 

914 # extract embeddings from CNN 

915 for sentence_no, length in enumerate(lengths): 

916 embedding = outputs[sentence_no] 

917 

918 if self.static_embeddings: 

919 embedding = embedding.detach() 

920 

921 sentence = sentences[sentence_no] 

922 sentence.set_embedding(self.name, embedding) 

923 

924 def _apply(self, fn): 

925 for child_module in self.children(): 

926 child_module._apply(fn)
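A minimal usage sketch for DocumentCNNEmbeddings, assuming the 'glove' token embeddings; the kernel configuration is an arbitrary example:

    from flair.data import Sentence
    from flair.embeddings import WordEmbeddings
    from flair.embeddings.document import DocumentCNNEmbeddings

    # two convolutional filter banks (50 filters of width 2, 50 of width 3) over the token embeddings
    embedding = DocumentCNNEmbeddings([WordEmbeddings('glove')], kernels=((50, 2), (50, 3)))
    sentence = Sentence('The grass is green .')
    embedding.embed(sentence)
    print(sentence.get_embedding().shape)  # sum of the kernel counts, i.e. 100 here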