Coverage for flair/flair/models/sequence_tagger_model.py: 53%

475 statements  

import logging
import sys

from pathlib import Path
from typing import List, Union, Optional, Dict, Tuple
from warnings import warn

import numpy as np
import torch
import torch.nn
import torch.nn.functional as F
from requests import HTTPError
from tabulate import tabulate
from torch.nn.parameter import Parameter
from tqdm import tqdm

import flair.nn
from flair.data import Dictionary, Sentence, Label
from flair.datasets import SentenceDataset, DataLoader
from flair.embeddings import TokenEmbeddings, StackedEmbeddings, Embeddings
from flair.file_utils import cached_path, unzip_file
from flair.training_utils import store_embeddings

log = logging.getLogger("flair")

START_TAG: str = "<START>"
STOP_TAG: str = "<STOP>"


def to_scalar(var):
    return var.view(-1).detach().tolist()[0]


def argmax(vec):
    _, idx = torch.max(vec, 1)
    return to_scalar(idx)


# numerically stable log-sum-exp for a single row vector
def log_sum_exp(vec):
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))


def argmax_batch(vecs):
    _, idx = torch.max(vecs, 1)
    return idx


# numerically stable log-sum-exp over a batch of rows
def log_sum_exp_batch(vecs):
    maxi = torch.max(vecs, 1)[0]
    maxi_bc = maxi[:, None].repeat(1, vecs.shape[1])
    recti_ = torch.log(torch.sum(torch.exp(vecs - maxi_bc), 1))
    return maxi + recti_
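
# A minimal numeric sketch (illustrative, not part of the original module) of
# why log_sum_exp_batch subtracts the row-wise maximum before exponentiating:
# the shift keeps exp() from overflowing and is added back afterwards, so the
# result is unchanged.
#
#   vecs = torch.tensor([[1000.0, 1000.0]])
#   torch.log(torch.sum(torch.exp(vecs), 1))  # tensor([inf]) -- overflow
#   log_sum_exp_batch(vecs)                   # tensor([1000.6931]) = 1000 + log(2)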

def pad_tensors(tensor_list):
    ml = max([x.shape[0] for x in tensor_list])
    shape = [len(tensor_list), ml] + list(tensor_list[0].shape[1:])
    template = torch.zeros(*shape, dtype=torch.long, device=flair.device)
    lens_ = [x.shape[0] for x in tensor_list]
    for i, tensor in enumerate(tensor_list):
        template[i, : lens_[i]] = tensor

    return template, lens_
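
# Illustrative use of pad_tensors (the values are made up for this sketch):
# two tag sequences of lengths 3 and 1 are zero-padded into one dense batch.
#
#   a = torch.tensor([4, 7, 2])
#   b = torch.tensor([5])
#   batch, lens_ = pad_tensors([a, b])
#   # batch -> tensor([[4, 7, 2],
#   #                  [5, 0, 0]])    lens_ -> [3, 1]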

class SequenceTagger(flair.nn.Classifier):
    def __init__(
        self,
        hidden_size: int,
        embeddings: TokenEmbeddings,
        tag_dictionary: Dictionary,
        tag_type: str,
        use_crf: bool = True,
        use_rnn: bool = True,
        rnn_layers: int = 1,
        dropout: float = 0.0,
        word_dropout: float = 0.05,
        locked_dropout: float = 0.5,
        reproject_embeddings: Union[bool, int] = True,
        train_initial_hidden_state: bool = False,
        rnn_type: str = "LSTM",
        beta: float = 1.0,
        loss_weights: Optional[Dict[str, float]] = None,
    ):
        """
        Initializes a SequenceTagger
        :param hidden_size: number of hidden states in RNN
        :param embeddings: word embeddings used in tagger
        :param tag_dictionary: dictionary of tags you want to predict
        :param tag_type: string identifier for tag type
        :param use_crf: if True use CRF decoder, else project directly to tag space
        :param use_rnn: if True use RNN layer, otherwise use word embeddings directly
        :param rnn_layers: number of RNN layers
        :param dropout: dropout probability
        :param word_dropout: word dropout probability
        :param locked_dropout: locked dropout probability
        :param reproject_embeddings: if True, adds a trainable linear map on top of the embedding layer;
            if False, no map. If set to an integer, the integer controls the dimensionality of the
            reprojection layer
        :param train_initial_hidden_state: if True, trains initial hidden state of RNN
        :param beta: parameter for F-beta score for evaluation and training annealing
        :param loss_weights: dictionary of weights for classes (tags) for the loss function
            (if any tag's weight is unspecified it will default to 1.0)
        """
        super(SequenceTagger, self).__init__()

        self.use_rnn = use_rnn
        self.hidden_size = hidden_size
        self.use_crf: bool = use_crf
        self.rnn_layers: int = rnn_layers

        self.trained_epochs: int = 0

        self.embeddings = embeddings

        # set the dictionaries
        self.tag_dictionary: Dictionary = tag_dictionary
        # if we use a CRF, we must add special START and STOP tags to the dictionary
        if use_crf:
            self.tag_dictionary.add_item(START_TAG)
            self.tag_dictionary.add_item(STOP_TAG)

        self.tag_type: str = tag_type
        self.tagset_size: int = len(tag_dictionary)

        self.beta = beta

        self.weight_dict = loss_weights
        # initialize the weight tensor
        if loss_weights is not None:
            n_classes = len(self.tag_dictionary)
            weight_list = [1.0 for i in range(n_classes)]
            for i, tag in enumerate(self.tag_dictionary.get_items()):
                if tag in loss_weights.keys():
                    weight_list[i] = loss_weights[tag]
            self.loss_weights = torch.FloatTensor(weight_list).to(flair.device)
        else:
            self.loss_weights = None

        # initialize the network architecture
        self.nlayers: int = rnn_layers
        self.hidden_word = None

        # dropouts
        self.use_dropout: float = dropout
        self.use_word_dropout: float = word_dropout
        self.use_locked_dropout: float = locked_dropout

        if dropout > 0.0:
            self.dropout = torch.nn.Dropout(dropout)

        if word_dropout > 0.0:
            self.word_dropout = flair.nn.WordDropout(word_dropout)

        if locked_dropout > 0.0:
            self.locked_dropout = flair.nn.LockedDropout(locked_dropout)

        embedding_dim: int = self.embeddings.embedding_length
        rnn_input_dim: int = embedding_dim

        # optional reprojection layer on top of word embeddings
        self.reproject_embeddings = reproject_embeddings
        if self.reproject_embeddings:
            if type(self.reproject_embeddings) == int:
                rnn_input_dim = self.reproject_embeddings

            self.embedding2nn = torch.nn.Linear(embedding_dim, rnn_input_dim)

        self.train_initial_hidden_state = train_initial_hidden_state
        self.bidirectional = True
        self.rnn_type = rnn_type

        # bidirectional LSTM on top of embedding layer
        if self.use_rnn:
            num_directions = 2 if self.bidirectional else 1

            if self.rnn_type in ["LSTM", "GRU"]:

                self.rnn = getattr(torch.nn, self.rnn_type)(
                    rnn_input_dim,
                    hidden_size,
                    num_layers=self.nlayers,
                    dropout=0.0 if self.nlayers == 1 else 0.5,
                    bidirectional=True,
                    batch_first=True,
                )
                # create initial hidden state and initialize it
                if self.train_initial_hidden_state:
                    self.hs_initializer = torch.nn.init.xavier_normal_

                    self.lstm_init_h = Parameter(
                        torch.randn(self.nlayers * num_directions, self.hidden_size),
                        requires_grad=True,
                    )

                    self.lstm_init_c = Parameter(
                        torch.randn(self.nlayers * num_directions, self.hidden_size),
                        requires_grad=True,
                    )

                    # TODO: Decide how to initialize the hidden state variables
                    # self.hs_initializer(self.lstm_init_h)
                    # self.hs_initializer(self.lstm_init_c)

            # final linear map to tag space
            self.linear = torch.nn.Linear(
                hidden_size * num_directions, len(tag_dictionary)
            )
        else:
            self.linear = torch.nn.Linear(
                rnn_input_dim, len(tag_dictionary)
            )

        if self.use_crf:
            self.transitions = torch.nn.Parameter(
                torch.randn(self.tagset_size, self.tagset_size)
            )

            # forbid transitions into START and out of STOP
            self.transitions.detach()[
                self.tag_dictionary.get_idx_for_item(START_TAG), :
            ] = -10000

            self.transitions.detach()[
                :, self.tag_dictionary.get_idx_for_item(STOP_TAG)
            ] = -10000

        self.to(flair.device)
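
    # A construction sketch (the corpus and embedding choices below are
    # assumptions for illustration, not taken from this file):
    #
    #   from flair.datasets import CONLL_03
    #   from flair.embeddings import WordEmbeddings
    #   corpus = CONLL_03()
    #   tagger = SequenceTagger(
    #       hidden_size=256,
    #       embeddings=WordEmbeddings("glove"),
    #       tag_dictionary=corpus.make_tag_dictionary(tag_type="ner"),
    #       tag_type="ner",
    #   )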

    def _get_state_dict(self):
        model_state = {
            "state_dict": self.state_dict(),
            "embeddings": self.embeddings,
            "hidden_size": self.hidden_size,
            "train_initial_hidden_state": self.train_initial_hidden_state,
            "tag_dictionary": self.tag_dictionary,
            "tag_type": self.tag_type,
            "use_crf": self.use_crf,
            "use_rnn": self.use_rnn,
            "rnn_layers": self.rnn_layers,
            "use_dropout": self.use_dropout,
            "use_word_dropout": self.use_word_dropout,
            "use_locked_dropout": self.use_locked_dropout,
            "rnn_type": self.rnn_type,
            "beta": self.beta,
            "weight_dict": self.weight_dict,
            "reproject_embeddings": self.reproject_embeddings,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):

        rnn_type = "LSTM" if "rnn_type" not in state.keys() else state["rnn_type"]
        use_dropout = 0.0 if "use_dropout" not in state.keys() else state["use_dropout"]
        use_word_dropout = 0.0 if "use_word_dropout" not in state.keys() else state["use_word_dropout"]
        use_locked_dropout = 0.0 if "use_locked_dropout" not in state.keys() else state["use_locked_dropout"]

        train_initial_hidden_state = (
            False
            if "train_initial_hidden_state" not in state.keys()
            else state["train_initial_hidden_state"]
        )
        beta = 1.0 if "beta" not in state.keys() else state["beta"]
        weights = None if "weight_dict" not in state.keys() else state["weight_dict"]
        reproject_embeddings = True if "reproject_embeddings" not in state.keys() else state["reproject_embeddings"]
        # older checkpoints store this setting under the legacy key "reproject_to"
        if "reproject_to" in state.keys():
            reproject_embeddings = state["reproject_to"]

        model = SequenceTagger(
            hidden_size=state["hidden_size"],
            embeddings=state["embeddings"],
            tag_dictionary=state["tag_dictionary"],
            tag_type=state["tag_type"],
            use_crf=state["use_crf"],
            use_rnn=state["use_rnn"],
            rnn_layers=state["rnn_layers"],
            dropout=use_dropout,
            word_dropout=use_word_dropout,
            locked_dropout=use_locked_dropout,
            train_initial_hidden_state=train_initial_hidden_state,
            rnn_type=rnn_type,
            beta=beta,
            loss_weights=weights,
            reproject_embeddings=reproject_embeddings,
        )
        model.load_state_dict(state["state_dict"])
        return model
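
    # Checkpoint round-trip sketch (the path is illustrative): _get_state_dict
    # and _init_model_with_state_dict are the hooks the flair.nn.Model base
    # class uses when saving and loading.
    #
    #   tagger.save("resources/taggers/example-ner/final-model.pt")
    #   tagger = SequenceTagger.load("resources/taggers/example-ner/final-model.pt")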

    def predict(
        self,
        sentences: Union[List[Sentence], Sentence],
        mini_batch_size=32,
        all_tag_prob: bool = False,
        verbose: bool = False,
        label_name: Optional[str] = None,
        return_loss=False,
        embedding_storage_mode="none",
    ):
        """
        Predict sequence tags for Named Entity Recognition task
        :param sentences: a Sentence or a List of Sentence
        :param mini_batch_size: size of the minibatch; a larger batch is usually faster but consumes
            more memory, up to a point where it has no further effect
        :param all_tag_prob: True to compute the score for each tag on each token,
            otherwise only the score of the best tag is returned
        :param verbose: set to True to display a progress bar
        :param label_name: set this to change the name of the label type that is predicted
        :param return_loss: set to True to return loss
        :param embedding_storage_mode: default is 'none', which is always best. Only set to 'cpu' or 'gpu'
            if you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory
            respectively.
        """
        if label_name is None:
            label_name = self.tag_type

        with torch.no_grad():
            if not sentences:
                return sentences

            if isinstance(sentences, Sentence):
                sentences = [sentences]

            # reverse sort all sequences by their length
            rev_order_len_index = sorted(
                range(len(sentences)), key=lambda k: len(sentences[k]), reverse=True
            )

            reordered_sentences: List[Union[Sentence, str]] = [
                sentences[index] for index in rev_order_len_index
            ]

            dataloader = DataLoader(
                dataset=SentenceDataset(reordered_sentences), batch_size=mini_batch_size
            )

            if self.use_crf:
                transitions = self.transitions.detach().cpu().numpy()
            else:
                transitions = None

            # progress bar for verbosity
            if verbose:
                dataloader = tqdm(dataloader)

            overall_loss = 0
            overall_count = 0
            batch_no = 0
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(f"Inferencing on batch {batch_no}")

                batch = self._filter_empty_sentences(batch)
                # stop if all sentences are empty
                if not batch:
                    continue

                feature = self.forward(batch)

                if return_loss:
                    loss_and_count = self._calculate_loss(feature, batch)
                    overall_loss += loss_and_count[0]
                    overall_count += loss_and_count[1]

                tags, all_tags = self._obtain_labels(
                    feature=feature,
                    batch_sentences=batch,
                    transitions=transitions,
                    get_all_tags=all_tag_prob,
                )

                for (sentence, sent_tags) in zip(batch, tags):
                    for (token, tag) in zip(sentence.tokens, sent_tags):
                        token.add_tag_label(label_name, tag)

                # all_tags will be empty if all_tag_prob is set to False, so the for loop will be avoided
                for (sentence, sent_all_tags) in zip(batch, all_tags):
                    for (token, token_all_tags) in zip(sentence.tokens, sent_all_tags):
                        token.add_tags_proba_dist(label_name, token_all_tags)

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

            if return_loss:
                return overall_loss, overall_count
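
    # Typical inference sketch (the model key and text are illustrative):
    #
    #   tagger = SequenceTagger.load("ner")
    #   sentence = Sentence("George Washington went to Washington.")
    #   tagger.predict(sentence)
    #   for entity in sentence.get_spans("ner"):
    #       print(entity)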

    def forward_loss(
        self, data_points: Union[List[Sentence], Sentence], sort=True
    ) -> torch.tensor:
        features = self.forward(data_points)
        return self._calculate_loss(features, data_points)

    def forward(self, sentences: List[Sentence]):

        self.embeddings.embed(sentences)

        names = self.embeddings.get_names()

        lengths: List[int] = [len(sentence.tokens) for sentence in sentences]
        longest_token_sequence_in_batch: int = max(lengths)

        pre_allocated_zero_tensor = torch.zeros(
            self.embeddings.embedding_length * longest_token_sequence_in_batch,
            dtype=torch.float,
            device=flair.device,
        )

        all_embs = list()
        for sentence in sentences:
            all_embs += [
                emb for token in sentence for emb in token.get_each_embedding(names)
            ]
            nb_padding_tokens = longest_token_sequence_in_batch - len(sentence)

            if nb_padding_tokens > 0:
                t = pre_allocated_zero_tensor[
                    : self.embeddings.embedding_length * nb_padding_tokens
                ]
                all_embs.append(t)

        sentence_tensor = torch.cat(all_embs).view(
            [
                len(sentences),
                longest_token_sequence_in_batch,
                self.embeddings.embedding_length,
            ]
        )

        # --------------------------------------------------------------------
        # FF PART
        # --------------------------------------------------------------------
        if self.use_dropout > 0.0:
            sentence_tensor = self.dropout(sentence_tensor)
        if self.use_word_dropout > 0.0:
            sentence_tensor = self.word_dropout(sentence_tensor)
        if self.use_locked_dropout > 0.0:
            sentence_tensor = self.locked_dropout(sentence_tensor)

        if self.reproject_embeddings:
            sentence_tensor = self.embedding2nn(sentence_tensor)

        if self.use_rnn:
            packed = torch.nn.utils.rnn.pack_padded_sequence(
                sentence_tensor, lengths, enforce_sorted=False, batch_first=True
            )

            # if initial hidden state is trainable, use this state
            if self.train_initial_hidden_state:
                initial_hidden_state = [
                    self.lstm_init_h.unsqueeze(1).repeat(1, len(sentences), 1),
                    self.lstm_init_c.unsqueeze(1).repeat(1, len(sentences), 1),
                ]
                rnn_output, hidden = self.rnn(packed, initial_hidden_state)
            else:
                rnn_output, hidden = self.rnn(packed)

            sentence_tensor, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(
                rnn_output, batch_first=True
            )

            if self.use_dropout > 0.0:
                sentence_tensor = self.dropout(sentence_tensor)
            # word dropout only before LSTM - TODO: more experimentation needed
            # if self.use_word_dropout > 0.0:
            #     sentence_tensor = self.word_dropout(sentence_tensor)
            if self.use_locked_dropout > 0.0:
                sentence_tensor = self.locked_dropout(sentence_tensor)

        features = self.linear(sentence_tensor)

        return features
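
    # Shape sketch for forward() (illustrative numbers): for a batch of 2
    # sentences whose longest member has 5 tokens, and a tag dictionary of
    # size 20 (including START/STOP when a CRF is used), the returned
    # emission scores have shape
    #
    #   features.shape  ->  torch.Size([2, 5, 20])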

    def _score_sentence(self, feats, tags, lens_):

        start = torch.tensor(
            [self.tag_dictionary.get_idx_for_item(START_TAG)], device=flair.device
        )
        start = start[None, :].repeat(tags.shape[0], 1)

        stop = torch.tensor(
            [self.tag_dictionary.get_idx_for_item(STOP_TAG)], device=flair.device
        )
        stop = stop[None, :].repeat(tags.shape[0], 1)

        pad_start_tags = torch.cat([start, tags], 1)
        pad_stop_tags = torch.cat([tags, stop], 1)

        for i in range(len(lens_)):
            pad_stop_tags[i, lens_[i]:] = self.tag_dictionary.get_idx_for_item(
                STOP_TAG
            )

        score = torch.FloatTensor(feats.shape[0]).to(flair.device)

        for i in range(feats.shape[0]):
            r = torch.LongTensor(range(lens_[i])).to(flair.device)

            score[i] = torch.sum(
                self.transitions[
                    pad_stop_tags[i, : lens_[i] + 1], pad_start_tags[i, : lens_[i] + 1]
                ]
            ) + torch.sum(feats[i, r, tags[i, : lens_[i]]])

        return score

    def _calculate_loss(
        self, features: torch.tensor, sentences: List[Sentence]
    ) -> Tuple[float, int]:

        lengths: List[int] = [len(sentence.tokens) for sentence in sentences]

        tag_list: List = []
        token_count = 0
        for s_id, sentence in enumerate(sentences):
            # get the tags in this sentence
            tag_idx: List[int] = [
                self.tag_dictionary.get_idx_for_item(token.get_tag(self.tag_type).value)
                for token in sentence
            ]
            token_count += len(tag_idx)
            # add tags as tensor
            tag = torch.tensor(tag_idx, device=flair.device)
            tag_list.append(tag)

        if self.use_crf:
            # pad tags if using batch-CRF decoder
            tags, _ = pad_tensors(tag_list)

            forward_score = self._forward_alg(features, lengths)
            gold_score = self._score_sentence(features, tags, lengths)

            score = forward_score - gold_score

            return score.sum(), token_count

        else:
            score = 0
            for sentence_feats, sentence_tags, sentence_length in zip(
                features, tag_list, lengths
            ):
                sentence_feats = sentence_feats[:sentence_length]
                score += torch.nn.functional.cross_entropy(
                    sentence_feats, sentence_tags, weight=self.loss_weights, reduction='sum',
                )

            return score, token_count
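
    # The CRF branch above minimizes the negative log-likelihood of the gold
    # tag sequence y for input x:
    #
    #   loss(x, y) = log Z(x) - score(x, y)
    #
    # where the partition term log Z(x) comes from _forward_alg and the gold
    # path score from _score_sentence.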

    def _obtain_labels(
        self,
        feature: torch.Tensor,
        batch_sentences: List[Sentence],
        transitions: Optional[np.ndarray],
        get_all_tags: bool,
    ) -> Tuple[List[List[Label]], List[List[List[Label]]]]:
        """
        Returns a tuple of two lists:
        - The first list corresponds to the most likely `Label` per token in each sentence.
        - The second list contains a probability distribution over all `Labels` for each token
          in a sentence for all sentences.
        """

        lengths: List[int] = [len(sentence.tokens) for sentence in batch_sentences]

        tags = []
        all_tags = []
        feature = feature.cpu()
        if self.use_crf:
            feature = feature.numpy()
        else:
            for index, length in enumerate(lengths):
                feature[index, length:] = 0
            softmax_batch = F.softmax(feature, dim=2).cpu()
            scores_batch, prediction_batch = torch.max(softmax_batch, dim=2)
            feature = zip(softmax_batch, scores_batch, prediction_batch)

        for feats, length in zip(feature, lengths):
            if self.use_crf:
                confidences, tag_seq, scores = self._viterbi_decode(
                    feats=feats[:length],
                    transitions=transitions,
                    all_scores=get_all_tags,
                )
            else:
                softmax, score, prediction = feats
                confidences = score[:length].tolist()
                tag_seq = prediction[:length].tolist()
                scores = softmax[:length].tolist()

            tags.append(
                [
                    Label(self.tag_dictionary.get_item_for_index(tag), conf)
                    for conf, tag in zip(confidences, tag_seq)
                ]
            )

            if get_all_tags:
                all_tags.append(
                    [
                        [
                            Label(
                                self.tag_dictionary.get_item_for_index(score_id), score
                            )
                            for score_id, score in enumerate(score_dist)
                        ]
                        for score_dist in scores
                    ]
                )

        return tags, all_tags

    @staticmethod
    def _softmax(x, axis):
        # reduce raw values to avoid NaN during exp
        x_norm = x - x.max(axis=axis, keepdims=True)
        y = np.exp(x_norm)
        return y / y.sum(axis=axis, keepdims=True)

    def _viterbi_decode(
        self, feats: np.ndarray, transitions: np.ndarray, all_scores: bool
    ):
        id_start = self.tag_dictionary.get_idx_for_item(START_TAG)
        id_stop = self.tag_dictionary.get_idx_for_item(STOP_TAG)

        backpointers = np.empty(shape=(feats.shape[0], self.tagset_size), dtype=np.int_)
        backscores = np.empty(
            shape=(feats.shape[0], self.tagset_size), dtype=np.float32
        )

        init_vvars = np.expand_dims(
            np.repeat(-10000.0, self.tagset_size), axis=0
        ).astype(np.float32)
        init_vvars[0][id_start] = 0

        forward_var = init_vvars
        for index, feat in enumerate(feats):
            # broadcasting will do the job of reshaping and is more efficient than calling repeat
            next_tag_var = forward_var + transitions
            bptrs_t = next_tag_var.argmax(axis=1)
            viterbivars_t = next_tag_var[np.arange(bptrs_t.shape[0]), bptrs_t]
            forward_var = viterbivars_t + feat
            backscores[index] = forward_var
            forward_var = forward_var[np.newaxis, :]
            backpointers[index] = bptrs_t

        terminal_var = forward_var.squeeze() + transitions[id_stop]
        terminal_var[id_stop] = -10000.0
        terminal_var[id_start] = -10000.0
        best_tag_id = terminal_var.argmax()

        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)

        start = best_path.pop()
        assert start == id_start
        best_path.reverse()

        best_scores_softmax = self._softmax(backscores, axis=1)
        best_scores_np = np.max(best_scores_softmax, axis=1)

        # default value
        all_scores_np = np.zeros(0, dtype=np.float64)
        if all_scores:
            all_scores_np = best_scores_softmax
            for index, (tag_id, tag_scores) in enumerate(zip(best_path, all_scores_np)):
                if type(tag_id) != int and tag_id.item() != tag_scores.argmax():
                    swap_index_score = tag_scores.argmax()
                    (
                        all_scores_np[index][tag_id.item()],
                        all_scores_np[index][swap_index_score],
                    ) = (
                        all_scores_np[index][swap_index_score],
                        all_scores_np[index][tag_id.item()],
                    )
                elif type(tag_id) == int and tag_id != tag_scores.argmax():
                    swap_index_score = tag_scores.argmax()
                    (
                        all_scores_np[index][tag_id],
                        all_scores_np[index][swap_index_score],
                    ) = (
                        all_scores_np[index][swap_index_score],
                        all_scores_np[index][tag_id],
                    )

        return best_scores_np.tolist(), best_path, all_scores_np.tolist()
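
    # The decoder above implements the standard Viterbi recurrence; with
    # transitions[to, from] as used in this class, each step computes
    #
    #   delta[t, y] = feats[t, y] + max_{y'} (delta[t-1, y'] + transitions[y, y'])
    #
    # while backpointers[t, y] stores the maximizing y', from which the best
    # path is recovered by walking backwards from the STOP tag.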

    def _forward_alg(self, feats, lens_):

        init_alphas = torch.FloatTensor(self.tagset_size).fill_(-10000.0)
        init_alphas[self.tag_dictionary.get_idx_for_item(START_TAG)] = 0.0

        forward_var = torch.zeros(
            feats.shape[0],
            feats.shape[1] + 1,
            feats.shape[2],
            dtype=torch.float,
            device=flair.device,
        )

        forward_var[:, 0, :] = init_alphas[None, :].repeat(feats.shape[0], 1)

        transitions = self.transitions.view(
            1, self.transitions.shape[0], self.transitions.shape[1]
        ).repeat(feats.shape[0], 1, 1)

        for i in range(feats.shape[1]):
            emit_score = feats[:, i, :]

            tag_var = (
                emit_score[:, :, None].repeat(1, 1, transitions.shape[2])
                + transitions
                + forward_var[:, i, :][:, :, None]
                .repeat(1, 1, transitions.shape[2])
                .transpose(2, 1)
            )

            max_tag_var, _ = torch.max(tag_var, dim=2)

            tag_var = tag_var - max_tag_var[:, :, None].repeat(
                1, 1, transitions.shape[2]
            )

            agg_ = torch.log(torch.sum(torch.exp(tag_var), dim=2))

            cloned = forward_var.clone()
            cloned[:, i + 1, :] = max_tag_var + agg_

            forward_var = cloned

        forward_var = forward_var[range(forward_var.shape[0]), lens_, :]

        terminal_var = forward_var + self.transitions[
            self.tag_dictionary.get_idx_for_item(STOP_TAG)
        ][None, :].repeat(forward_var.shape[0], 1)

        alpha = log_sum_exp_batch(terminal_var)

        return alpha
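
    # _forward_alg mirrors the Viterbi recurrence with logsumexp in place of
    # max, yielding the log partition function:
    #
    #   alpha[t, y] = feats[t, y] + logsumexp_{y'} (alpha[t-1, y'] + transitions[y, y'])
    #   log Z(x)    = logsumexp_y (alpha[T, y] + transitions[STOP, y])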

    @staticmethod
    def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
        filtered_sentences = [sentence for sentence in sentences if sentence.tokens]
        if len(sentences) != len(filtered_sentences):
            log.warning(
                f"Ignore {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens."
            )
        return filtered_sentences

    @staticmethod
    def _filter_empty_string(texts: List[str]) -> List[str]:
        filtered_texts = [text for text in texts if text]
        if len(texts) != len(filtered_texts):
            log.warning(
                f"Ignore {len(texts) - len(filtered_texts)} string(s) with no tokens."
            )
        return filtered_texts

    @staticmethod
    def _fetch_model(model_name) -> str:

        # core Flair models on Huggingface ModelHub
        huggingface_model_map = {
            "ner": "flair/ner-english",
            "ner-fast": "flair/ner-english-fast",
            "ner-ontonotes": "flair/ner-english-ontonotes",
            "ner-ontonotes-fast": "flair/ner-english-ontonotes-fast",
            # Large NER models
            "ner-large": "flair/ner-english-large",
            "ner-ontonotes-large": "flair/ner-english-ontonotes-large",
            "de-ner-large": "flair/ner-german-large",
            "nl-ner-large": "flair/ner-dutch-large",
            "es-ner-large": "flair/ner-spanish-large",
            # Multilingual NER models
            "ner-multi": "flair/ner-multi",
            "multi-ner": "flair/ner-multi",
            "ner-multi-fast": "flair/ner-multi-fast",
            # English POS models
            "upos": "flair/upos-english",
            "upos-fast": "flair/upos-english-fast",
            "pos": "flair/pos-english",
            "pos-fast": "flair/pos-english-fast",
            # Multilingual POS models
            "pos-multi": "flair/upos-multi",
            "multi-pos": "flair/upos-multi",
            "pos-multi-fast": "flair/upos-multi-fast",
            "multi-pos-fast": "flair/upos-multi-fast",
            # English SRL models
            "frame": "flair/frame-english",
            "frame-fast": "flair/frame-english-fast",
            # English chunking models
            "chunk": "flair/chunk-english",
            "chunk-fast": "flair/chunk-english-fast",
            # Language-specific NER models
            "ar-ner": "megantosh/flair-arabic-multi-ner",
            "ar-pos": "megantosh/flair-arabic-dialects-codeswitch-egy-lev",
            "da-ner": "flair/ner-danish",
            "de-ner": "flair/ner-german",
            "de-ler": "flair/ner-german-legal",
            "de-ner-legal": "flair/ner-german-legal",
            "fr-ner": "flair/ner-french",
            "nl-ner": "flair/ner-dutch",
        }

        hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models"

        hu_model_map = {
            # English NER models
            "ner": "/".join([hu_path, "ner", "en-ner-conll03-v0.4.pt"]),
            "ner-pooled": "/".join([hu_path, "ner-pooled", "en-ner-conll03-pooled-v0.5.pt"]),
            "ner-fast": "/".join([hu_path, "ner-fast", "en-ner-fast-conll03-v0.4.pt"]),
            "ner-ontonotes": "/".join([hu_path, "ner-ontonotes", "en-ner-ontonotes-v0.4.pt"]),
            "ner-ontonotes-fast": "/".join([hu_path, "ner-ontonotes-fast", "en-ner-ontonotes-fast-v0.4.pt"]),
            # Multilingual NER models
            "ner-multi": "/".join([hu_path, "multi-ner", "quadner-large.pt"]),
            "multi-ner": "/".join([hu_path, "multi-ner", "quadner-large.pt"]),
            "ner-multi-fast": "/".join([hu_path, "multi-ner-fast", "ner-multi-fast.pt"]),
            # English POS models
            "upos": "/".join([hu_path, "upos", "en-pos-ontonotes-v0.4.pt"]),
            "upos-fast": "/".join([hu_path, "upos-fast", "en-upos-ontonotes-fast-v0.4.pt"]),
            "pos": "/".join([hu_path, "pos", "en-pos-ontonotes-v0.5.pt"]),
            "pos-fast": "/".join([hu_path, "pos-fast", "en-pos-ontonotes-fast-v0.5.pt"]),
            # Multilingual POS models
            "pos-multi": "/".join([hu_path, "multi-pos", "pos-multi-v0.1.pt"]),
            "multi-pos": "/".join([hu_path, "multi-pos", "pos-multi-v0.1.pt"]),
            "pos-multi-fast": "/".join([hu_path, "multi-pos-fast", "pos-multi-fast.pt"]),
            "multi-pos-fast": "/".join([hu_path, "multi-pos-fast", "pos-multi-fast.pt"]),
            # English SRL models
            "frame": "/".join([hu_path, "frame", "en-frame-ontonotes-v0.4.pt"]),
            "frame-fast": "/".join([hu_path, "frame-fast", "en-frame-ontonotes-fast-v0.4.pt"]),
            # English chunking models
            "chunk": "/".join([hu_path, "chunk", "en-chunk-conll2000-v0.4.pt"]),
            "chunk-fast": "/".join([hu_path, "chunk-fast", "en-chunk-conll2000-fast-v0.4.pt"]),
            # Danish models
            "da-pos": "/".join([hu_path, "da-pos", "da-pos-v0.1.pt"]),
            "da-ner": "/".join([hu_path, "NER-danish", "da-ner-v0.1.pt"]),
            # German models
            "de-pos": "/".join([hu_path, "de-pos", "de-pos-ud-hdt-v0.5.pt"]),
            "de-pos-tweets": "/".join([hu_path, "de-pos-tweets", "de-pos-twitter-v0.1.pt"]),
            "de-ner": "/".join([hu_path, "de-ner", "de-ner-conll03-v0.4.pt"]),
            "de-ner-germeval": "/".join([hu_path, "de-ner-germeval", "de-ner-germeval-0.4.1.pt"]),
            "de-ler": "/".join([hu_path, "de-ner-legal", "de-ner-legal.pt"]),
            "de-ner-legal": "/".join([hu_path, "de-ner-legal", "de-ner-legal.pt"]),
            # French models
            "fr-ner": "/".join([hu_path, "fr-ner", "fr-ner-wikiner-0.4.pt"]),
            # Dutch models
            "nl-ner": "/".join([hu_path, "nl-ner", "nl-ner-bert-conll02-v0.8.pt"]),
            "nl-ner-rnn": "/".join([hu_path, "nl-ner-rnn", "nl-ner-conll02-v0.5.pt"]),
            # Malayalam models
            "ml-pos": "https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/malayalam-xpos-model.pt",
            "ml-upos": "https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/malayalam-upos-model.pt",
            # Portuguese models
            "pt-pos-clinical": "/".join([hu_path, "pt-pos-clinical", "pucpr-flair-clinical-pos-tagging-best-model.pt"]),
            # Keyphrase models
            "keyphrase": "/".join([hu_path, "keyphrase", "keyphrase-en-scibert.pt"]),
            "negation-speculation": "/".join(
                [hu_path, "negation-speculation", "negation-speculation-model.pt"]
            ),
            # Biomedical models
            "hunflair-paper-cellline": "/".join(
                [hu_path, "hunflair_smallish_models", "cellline", "hunflair-celline-v1.0.pt"]
            ),
            "hunflair-paper-chemical": "/".join(
                [hu_path, "hunflair_smallish_models", "chemical", "hunflair-chemical-v1.0.pt"]
            ),
            "hunflair-paper-disease": "/".join(
                [hu_path, "hunflair_smallish_models", "disease", "hunflair-disease-v1.0.pt"]
            ),
            "hunflair-paper-gene": "/".join(
                [hu_path, "hunflair_smallish_models", "gene", "hunflair-gene-v1.0.pt"]
            ),
            "hunflair-paper-species": "/".join(
                [hu_path, "hunflair_smallish_models", "species", "hunflair-species-v1.0.pt"]
            ),
            "hunflair-cellline": "/".join(
                [hu_path, "hunflair_smallish_models", "cellline", "hunflair-celline-v1.0.pt"]
            ),
            "hunflair-chemical": "/".join(
                [hu_path, "hunflair_allcorpus_models", "huner-chemical", "hunflair-chemical-full-v1.0.pt"]
            ),
            "hunflair-disease": "/".join(
                [hu_path, "hunflair_allcorpus_models", "huner-disease", "hunflair-disease-full-v1.0.pt"]
            ),
            "hunflair-gene": "/".join(
                [hu_path, "hunflair_allcorpus_models", "huner-gene", "hunflair-gene-full-v1.0.pt"]
            ),
            "hunflair-species": "/".join(
                [hu_path, "hunflair_allcorpus_models", "huner-species", "hunflair-species-full-v1.1.pt"]
            ),
        }

        cache_dir = Path("models")

        get_from_model_hub = False

        # check if model name is a valid local file
        if Path(model_name).exists():
            model_path = model_name

        # check if model key is remapped to HF key - if so, print out information
        elif model_name in huggingface_model_map:

            # get mapped name
            hf_model_name = huggingface_model_map[model_name]

            # output information
            log.info("-" * 80)
            log.info(
                f"The model key '{model_name}' now maps to 'https://huggingface.co/{hf_model_name}' on the HuggingFace ModelHub"
            )
            log.info(" - The most current version of the model is automatically downloaded from there.")
            if model_name in hu_model_map:
                log.info(
                    f" - (you can alternatively manually download the original model at {hu_model_map[model_name]})"
                )
            log.info("-" * 80)

            # use mapped name instead
            model_name = hf_model_name
            get_from_model_hub = True

        # if not, check if model key is remapped to direct download location. If so, download model
        elif model_name in hu_model_map:
            model_path = cached_path(hu_model_map[model_name], cache_dir=cache_dir)

        # special handling for the taggers by the @redewiedergabe project (TODO: move to model hub)
        elif model_name == "de-historic-indirect":
            model_file = flair.cache_root / cache_dir / 'indirect' / 'final-model.pt'
            if not model_file.exists():
                cached_path('http://www.redewiedergabe.de/models/indirect.zip', cache_dir=cache_dir)
                unzip_file(flair.cache_root / cache_dir / 'indirect.zip', flair.cache_root / cache_dir)
            model_path = str(flair.cache_root / cache_dir / 'indirect' / 'final-model.pt')

        elif model_name == "de-historic-direct":
            model_file = flair.cache_root / cache_dir / 'direct' / 'final-model.pt'
            if not model_file.exists():
                cached_path('http://www.redewiedergabe.de/models/direct.zip', cache_dir=cache_dir)
                unzip_file(flair.cache_root / cache_dir / 'direct.zip', flair.cache_root / cache_dir)
            model_path = str(flair.cache_root / cache_dir / 'direct' / 'final-model.pt')

        elif model_name == "de-historic-reported":
            model_file = flair.cache_root / cache_dir / 'reported' / 'final-model.pt'
            if not model_file.exists():
                cached_path('http://www.redewiedergabe.de/models/reported.zip', cache_dir=cache_dir)
                unzip_file(flair.cache_root / cache_dir / 'reported.zip', flair.cache_root / cache_dir)
            model_path = str(flair.cache_root / cache_dir / 'reported' / 'final-model.pt')

        elif model_name == "de-historic-free-indirect":
            model_file = flair.cache_root / cache_dir / 'freeIndirect' / 'final-model.pt'
            if not model_file.exists():
                cached_path('http://www.redewiedergabe.de/models/freeIndirect.zip', cache_dir=cache_dir)
                unzip_file(flair.cache_root / cache_dir / 'freeIndirect.zip', flair.cache_root / cache_dir)
            model_path = str(flair.cache_root / cache_dir / 'freeIndirect' / 'final-model.pt')

        # for all other cases (not local file or special download location), use HF model hub
        else:
            get_from_model_hub = True

        # if not a local file, get from model hub
        if get_from_model_hub:
            hf_model_name = "pytorch_model.bin"
            revision = "main"

            if "@" in model_name:
                model_name_split = model_name.split("@")
                revision = model_name_split[-1]
                model_name = model_name_split[0]

            # use model name as subfolder
            if "/" in model_name:
                model_folder = model_name.split("/", maxsplit=1)[1]
            else:
                model_folder = model_name

            # lazy import
            from huggingface_hub import hf_hub_url, cached_download

            url = hf_hub_url(model_name, revision=revision, filename=hf_model_name)

            try:
                model_path = cached_download(
                    url=url,
                    library_name="flair",
                    library_version=flair.__version__,
                    cache_dir=flair.cache_root / 'models' / model_folder,
                )
            except HTTPError as e:
                # output information
                log.error("-" * 80)
                log.error(
                    f"ACHTUNG: The key '{model_name}' was neither found on the ModelHub nor is this a valid path to a file on your system!"
                )
                # log.error(f" - Error message: {e}")
                log.error(" -> Please check https://huggingface.co/models?filter=flair for all available models.")
                log.error(" -> Alternatively, point to a model file on your local drive.")
                log.error("-" * 80)
                Path(flair.cache_root / 'models' / model_folder).rmdir()  # remove folder again if not valid
                # re-raise so the caller does not hit an unbound `model_path` below
                raise

        return model_path
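
    # Key resolution sketch (the model keys below are illustrative):
    #
    #   SequenceTagger.load("ner")                     # remapped to flair/ner-english on the ModelHub
    #   SequenceTagger.load("flair/ner-english@main")  # explicit hub repo with a pinned revision
    #   SequenceTagger.load("path/to/final-model.pt")  # existing local file, used as-is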

    def get_transition_matrix(self):
        data = []
        for to_idx, row in enumerate(self.transitions):
            for from_idx, column in enumerate(row):
                data.append(
                    [
                        self.tag_dictionary.get_item_for_index(from_idx),
                        self.tag_dictionary.get_item_for_index(to_idx),
                        column.item(),
                    ]
                )
            data.append(["----"])
        print(tabulate(data, headers=["FROM", "TO", "SCORE"]))

    def __str__(self):
        return super(flair.nn.Model, self).__str__().rstrip(')') + \
               f'  (beta): {self.beta}\n' + \
               f'  (weights): {self.weight_dict}\n' + \
               f'  (weight_tensor) {self.loss_weights}\n)'

    @property
    def label_type(self):
        return self.tag_type


class MultiTagger:
    def __init__(self, name_to_tagger: Dict[str, SequenceTagger]):
        super().__init__()
        self.name_to_tagger = name_to_tagger

    def predict(
        self,
        sentences: Union[List[Sentence], Sentence],
        mini_batch_size=32,
        all_tag_prob: bool = False,
        verbose: bool = False,
        return_loss: bool = False,
    ):
        """
        Predict sequence tags for Named Entity Recognition task
        :param sentences: a Sentence or a List of Sentence
        :param mini_batch_size: size of the minibatch; a larger batch is usually faster but consumes
            more memory, up to a point where it has no further effect
        :param all_tag_prob: True to compute the score for each tag on each token,
            otherwise only the score of the best tag is returned
        :param verbose: set to True to display a progress bar
        :param return_loss: set to True to return loss
        """
        if any("hunflair" in name for name in self.name_to_tagger.keys()):
            if "spacy" not in sys.modules:
                warn(
                    "We recommend to use SciSpaCy for tokenization and sentence splitting "
                    "if HunFlair is applied to biomedical text, e.g.\n\n"
                    "from flair.tokenization import SciSpacySentenceSplitter\n"
                    "sentence = Sentence('Your biomed text', use_tokenizer=SciSpacySentenceSplitter())\n"
                )

        if isinstance(sentences, Sentence):
            sentences = [sentences]
        for name, tagger in self.name_to_tagger.items():
            tagger.predict(
                sentences=sentences,
                mini_batch_size=mini_batch_size,
                # all_tag_prob=all_tag_prob,
                verbose=verbose,
                label_name=name,
                return_loss=return_loss,
                embedding_storage_mode="cpu",
            )

        # clear embeddings after predicting
        for sentence in sentences:
            sentence.clear_embeddings()

    @classmethod
    def load(cls, model_names: Union[List[str], str]):
        if model_names == "hunflair-paper":
            model_names = [
                "hunflair-paper-cellline",
                "hunflair-paper-chemical",
                "hunflair-paper-disease",
                "hunflair-paper-gene",
                "hunflair-paper-species",
            ]
        elif model_names == "hunflair" or model_names == "bioner":
            model_names = [
                "hunflair-cellline",
                "hunflair-chemical",
                "hunflair-disease",
                "hunflair-gene",
                "hunflair-species",
            ]
        elif isinstance(model_names, str):
            model_names = [model_names]

        taggers = {}
        models = []

        # load each model
        for model_name in model_names:

            model = SequenceTagger.load(model_name)

            # check if the same embeddings were already loaded previously
            # if the model uses StackedEmbeddings, make a new stack with previous objects
            if type(model.embeddings) == StackedEmbeddings:

                # sort embeddings by key alphabetically
                new_stack = []
                d = model.embeddings.get_named_embeddings_dict()
                import collections
                od = collections.OrderedDict(sorted(d.items()))

                for k, embedding in od.items():

                    # check previous embeddings and add if found
                    embedding_found = False
                    for previous_model in models:

                        # only re-use static embeddings
                        if not embedding.static_embeddings: continue

                        if embedding.name in previous_model.embeddings.get_named_embeddings_dict():
                            previous_embedding = previous_model.embeddings.get_named_embeddings_dict()[embedding.name]
                            previous_embedding.name = previous_embedding.name[2:]
                            new_stack.append(previous_embedding)
                            embedding_found = True
                            break

                    # if not found, use existing embedding
                    if not embedding_found:
                        embedding.name = embedding.name[2:]
                        new_stack.append(embedding)

                # initialize new stack
                model.embeddings = None
                model.embeddings = StackedEmbeddings(new_stack)

            else:
                # if the model uses a regular embedding, re-use a previously loaded version if one is found
                if not model.embeddings.static_embeddings:

                    for previous_model in models:
                        if model.embeddings.name in previous_model.embeddings.get_named_embeddings_dict():
                            previous_embedding = previous_model.embeddings.get_named_embeddings_dict()[
                                model.embeddings.name]
                            if not previous_embedding.static_embeddings:
                                model.embeddings = previous_embedding
                                break

            taggers[model_name] = model
            models.append(model)

        return cls(taggers)
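
# MultiTagger usage sketch (the model keys are illustrative): each tagger
# writes its predictions under its own label name.
#
#   tagger = MultiTagger.load(["pos", "ner"])
#   sentence = Sentence("George Washington went to Washington.")
#   tagger.predict(sentence)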