Coverage for flair/flair/models/tars_model.py: 17%


import logging
import random
from collections import OrderedDict
from pathlib import Path
from typing import Union, List, Set, Optional

import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import minmax_scale
from tqdm import tqdm

import flair
from flair.data import Dictionary, Sentence
from flair.datasets import SentenceDataset, DataLoader
from flair.embeddings import TokenEmbeddings
from flair.file_utils import cached_path
from flair.models import SequenceTagger, TextClassifier
from flair.training_utils import store_embeddings

log = logging.getLogger("flair")


class FewshotClassifier(flair.nn.Classifier):

    def __init__(self):
        self._current_task = None
        self._task_specific_attributes = {}
        self.label_nearest_map = None
        self.clean_up_labels: bool = True

        super(FewshotClassifier, self).__init__()

    def forward_loss(
            self, data_points: Union[List[Sentence], Sentence]
    ) -> torch.Tensor:

        if isinstance(data_points, Sentence):
            data_points = [data_points]

        # transform input data into TARS format
        sentences = self._get_tars_formatted_sentences(data_points)

        loss = self.tars_model.forward_loss(sentences)
        return loss

    @property
    def tars_embeddings(self):
        raise NotImplementedError

    def _get_tars_formatted_sentence(self, label, sentence):
        raise NotImplementedError

    def _get_tars_formatted_sentences(self, sentences: List[Sentence]):
        label_text_pairs = []
        all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]
        for sentence in sentences:
            label_text_pairs_for_sentence = []
            if self.training and self.num_negative_labels_to_sample is not None:

                # during training, pair each sentence with its positive labels
                # plus a sample of similar negative labels
                positive_labels = list(OrderedDict.fromkeys(
                    [label.value for label in sentence.get_labels(self.label_type)]))

                sampled_negative_labels = self._get_nearest_labels_for(positive_labels)

                for label in positive_labels:
                    label_text_pairs_for_sentence.append(self._get_tars_formatted_sentence(label, sentence))
                for label in sampled_negative_labels:
                    label_text_pairs_for_sentence.append(self._get_tars_formatted_sentence(label, sentence))

            else:
                # at prediction time, pair each sentence with every candidate label
                for label in all_labels:
                    label_text_pairs_for_sentence.append(self._get_tars_formatted_sentence(label, sentence))
            label_text_pairs.extend(label_text_pairs_for_sentence)

        return label_text_pairs
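
    # Illustrative sketch of the pairing this produces (example values, not from
    # the original source): with candidate labels {"sports", "politics"}, prefix
    # mode and a BERT tokenizer, the sentence "Germany wins the match" yields one
    # TARS sentence per label, roughly:
    #   "sports [SEP] Germany wins the match"
    #   "politics [SEP] Germany wins the match"
    # where [SEP] stands for the tokenizer's separator token.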

    def _get_nearest_labels_for(self, labels):

        # if there are no labels, return a random sample as negatives
        if len(labels) == 0:
            tags = self.get_current_label_dictionary().get_items()
            sample = random.sample(tags, k=self.num_negative_labels_to_sample)
            return sample

        already_sampled_negative_labels = set()

        # otherwise, go through all labels
        for label in labels:

            plausible_labels = []
            plausible_label_probabilities = []
            for plausible_label in self.label_nearest_map[label]:
                if plausible_label in already_sampled_negative_labels or plausible_label in labels:
                    continue
                else:
                    plausible_labels.append(plausible_label)
                    plausible_label_probabilities.append(self.label_nearest_map[label][plausible_label])

            # make sure the probabilities always sum up to 1
            plausible_label_probabilities = np.array(plausible_label_probabilities, dtype='float64')
            plausible_label_probabilities += 1e-08
            plausible_label_probabilities /= np.sum(plausible_label_probabilities)

            if len(plausible_labels) > 0:
                num_samples = min(self.num_negative_labels_to_sample, len(plausible_labels))
                sampled_negative_labels = np.random.choice(plausible_labels,
                                                           num_samples,
                                                           replace=False,
                                                           p=plausible_label_probabilities)
                already_sampled_negative_labels.update(sampled_negative_labels)

        return already_sampled_negative_labels
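
    # Minimal sketch of the weighted draw above (illustrative values only): with
    # similarity scores {"food": 0.8, "travel": 0.2} for a "restaurant" gold label,
    #   np.random.choice(["food", "travel"], 1, replace=False, p=[0.8, 0.2])
    # mostly picks "food", so harder negatives are sampled more often.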

    def train(self, mode=True):
        """Populate label similarity map based on cosine similarity before running epoch

        If `num_negative_labels_to_sample` is set to an integer value, then before starting
        each epoch the model creates a similarity measure between the label names based
        on cosine distances between their BERT-encoded embeddings.
        """
        if mode and self.num_negative_labels_to_sample is not None:
            self._compute_label_similarity_for_current_epoch()

        super().train(mode)

    def _compute_label_similarity_for_current_epoch(self):
        """
        Compute the similarity between all labels for better sampling of negatives
        """

        # get and embed all labels by making a Sentence object that contains only the label text
        all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]
        label_sentences = [Sentence(label) for label in all_labels]

        self.tars_embeddings.eval()  # TODO: check if this is necessary
        self.tars_embeddings.embed(label_sentences)
        self.tars_embeddings.train()

        # get each label embedding and scale between 0 and 1
        if isinstance(self.tars_embeddings, TokenEmbeddings):
            encodings_np = [sentence[0].get_embedding().cpu().detach().numpy() for sentence in label_sentences]
        else:
            encodings_np = [sentence.get_embedding().cpu().detach().numpy() for sentence in label_sentences]

        normalized_encoding = minmax_scale(encodings_np)

        # compute similarity matrix
        similarity_matrix = cosine_similarity(normalized_encoding)

        # the higher the similarity, the greater the chance that a label is
        # sampled as a negative example
        negative_label_probabilities = {}
        for row_index, label in enumerate(all_labels):
            negative_label_probabilities[label] = {}
            for column_index, other_label in enumerate(all_labels):
                if label != other_label:
                    negative_label_probabilities[label][other_label] = \
                        similarity_matrix[row_index][column_index]
        self.label_nearest_map = negative_label_probabilities
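
    # Standalone sketch of the similarity pipeline above (illustrative vectors,
    # assuming scikit-learn):
    #   enc = minmax_scale([[0.1, 0.9], [0.2, 0.8], [0.9, 0.1]])
    #   sim = cosine_similarity(enc)  # sim[i][j] becomes the sampling weight of label j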

    def get_current_label_dictionary(self):
        label_dictionary = self._task_specific_attributes[self._current_task]['label_dictionary']
        if self.clean_up_labels:
            # default: make new dictionary with modified labels (no underscores)
            dictionary = Dictionary(add_unk=False)
            for label in label_dictionary.get_items():
                dictionary.add_item(label.replace("_", " "))
            return dictionary
        else:
            return label_dictionary

    def get_current_label_type(self):
        return self._task_specific_attributes[self._current_task]['label_type']

    def is_current_task_multi_label(self):
        return self._task_specific_attributes[self._current_task]['multi_label']

    def add_and_switch_to_new_task(self,
                                   task_name,
                                   label_dictionary: Union[List, Set, Dictionary, str],
                                   label_type: str,
                                   multi_label: bool = True,
                                   force_switch: bool = False,
                                   ):
        """
        Adds a new task to an existing TARS model. Sets necessary attributes and finally 'switches'
        to the new task. Parameters are similar to the constructor except for model choice, batch
        size and negative sampling. This method does not store the resulting model on disk.
        :param task_name: a string depicting the name of the task
        :param label_dictionary: dictionary of the labels you want to predict
        :param label_type: string to identify the label type ('ner', 'sentiment', etc.)
        :param multi_label: whether this task is a multi-label prediction problem
        :param force_switch: if True, will overwrite an existing task with the same name
        """
        if task_name in self._task_specific_attributes and not force_switch:
            log.warning("Task `%s` already exists in TARS model. Switching to it.", task_name)
        else:
            # make label dictionary if no Dictionary object is passed
            if isinstance(label_dictionary, Dictionary):
                label_dictionary = label_dictionary.get_items()
            if isinstance(label_dictionary, str):
                label_dictionary = [label_dictionary]

            # prepare dictionary of tags (without B- I- prefixes and without UNK)
            tag_dictionary = Dictionary(add_unk=False)
            for tag in label_dictionary:
                if tag == '<unk>' or tag == 'O':
                    continue
                # strip a BIOES prefix such as 'B-' if present
                if len(tag) > 2 and tag[1] == "-":
                    tag = tag[2:]
                tag_dictionary.add_item(tag)

            self._task_specific_attributes[task_name] = {'label_dictionary': tag_dictionary,
                                                         'label_type': label_type,
                                                         'multi_label': multi_label}

        self.switch_to_task(task_name)
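
    # Illustrative usage sketch (task name and labels are made up):
    #   tars.add_and_switch_to_new_task(task_name="question_type",
    #                                   label_dictionary=["how-to", "definition"],
    #                                   label_type="question_type")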

    def list_existing_tasks(self) -> Set[str]:
        """
        Returns the set of task names that exist in the loaded TARS model.
        """
        return set(self._task_specific_attributes.keys())

    def switch_to_task(self, task_name):
        """
        Switches to a task which was previously added.
        """
        if task_name not in self._task_specific_attributes:
            log.error("Provided task `%s` does not exist in the model. Consider calling "
                      "`add_and_switch_to_new_task` first.", task_name)
        else:
            self._current_task = task_name

    def _drop_task(self, task_name):
        if task_name in self._task_specific_attributes:
            if self._current_task == task_name:
                log.error("`%s` is the current task."
                          " Switch to some other task before dropping this one.", task_name)
            else:
                self._task_specific_attributes.pop(task_name)
        else:
            log.warning("No task exists with the name `%s`.", task_name)

    @staticmethod
    def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
        filtered_sentences = [sentence for sentence in sentences if sentence.tokens]
        if len(sentences) != len(filtered_sentences):
            log.warning(
                f"Ignoring {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens."
            )
        return filtered_sentences

    @property
    def label_type(self):
        return self.get_current_label_type()

    def predict_zero_shot(self,
                          sentences: Union[List[Sentence], Sentence],
                          candidate_label_set: Union[List[str], Set[str], str],
                          multi_label: bool = True):
        """
        Makes zero-shot predictions with the TARS model.
        :param sentences: input sentence objects to classify
        :param candidate_label_set: set of candidate labels
        :param multi_label: indicates whether multi-label or single-class prediction. Defaults to True.
        """

        # check if candidate_label_set is empty
        if candidate_label_set is None or len(candidate_label_set) == 0:
            log.warning("Provided candidate_label_set is empty")
            return

        # make a set if only one candidate label is passed
        if isinstance(candidate_label_set, str):
            candidate_label_set = {candidate_label_set}

        # create label dictionary
        label_dictionary = Dictionary(add_unk=False)
        for label in candidate_label_set:
            label_dictionary.add_item(label)

        # note current task
        existing_current_task = self._current_task

        # create a temporary task
        self.add_and_switch_to_new_task(task_name="ZeroShot",
                                        label_dictionary=label_dictionary,
                                        label_type='-'.join(label_dictionary.get_items()),
                                        multi_label=multi_label)

        try:
            # make zero-shot predictions
            self.predict(sentences)
        finally:
            # switch back to the pre-existing task and drop the temporary one
            self.switch_to_task(existing_current_task)
            self._drop_task("ZeroShot")
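
    # Illustrative zero-shot usage sketch (labels are made up; 'tars-base' is the
    # pretrained model fetched by TARSClassifier._fetch_model below):
    #   tars = TARSClassifier.load('tars-base')
    #   sentence = Sentence("I am so glad you liked it!")
    #   tars.predict_zero_shot(sentence, {"happy", "sad"})
    #   print(sentence.get_labels())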


class TARSTagger(FewshotClassifier):
    """
    TARS model for sequence tagging. In the backend, the model uses a BERT-based 5-class
    sequence labeler which, given a <label, text> pair, predicts the probability for each word
    to belong to one of the BIOES classes. The input data is a usual Sentence object which is
    inflated by the model internally before being pushed through the transformer stack of BERT.
    """

    static_label_type = "tars_label"

    def __init__(
            self,
            task_name: Optional[str] = None,
            label_dictionary: Optional[Dictionary] = None,
            label_type: Optional[str] = None,
            embeddings: str = 'bert-base-uncased',
            num_negative_labels_to_sample: int = 2,
            prefix: bool = True,
            **tagger_args,
    ):
        """
        Initializes a TARSTagger
        :param task_name: a string depicting the name of the task
        :param label_dictionary: dictionary of labels you want to predict
        :param embeddings: name of the pre-trained transformer model, e.g.
        'bert-base-uncased'
        :param num_negative_labels_to_sample: number of negative labels to sample for each
        positive label against a sentence during training. Defaults to 2 negative
        labels for each positive label. If None is passed, the model samples all
        negative labels, which slows down training considerably.
        """
        super(TARSTagger, self).__init__()

        from flair.embeddings import TransformerWordEmbeddings

        if not isinstance(embeddings, TransformerWordEmbeddings):
            embeddings = TransformerWordEmbeddings(model=embeddings,
                                                   fine_tune=True,
                                                   layers='-1',
                                                   layer_mean=False,
                                                   )

        # prepare TARS dictionary
        tars_dictionary = Dictionary(add_unk=False)
        tars_dictionary.add_item('O')
        tars_dictionary.add_item('S-')
        tars_dictionary.add_item('B-')
        tars_dictionary.add_item('E-')
        tars_dictionary.add_item('I-')

        # initialize a bare-bones sequence tagger (the hidden size argument is
        # irrelevant here since neither RNN nor CRF is used)
        self.tars_model = SequenceTagger(123,
                                         embeddings,
                                         tag_dictionary=tars_dictionary,
                                         tag_type=self.static_label_type,
                                         use_crf=False,
                                         use_rnn=False,
                                         reproject_embeddings=False,
                                         **tagger_args,
                                         )

        # transformer separator
        self.separator = str(self.tars_embeddings.tokenizer.sep_token)
        if self.tars_embeddings.tokenizer._bos_token:
            self.separator += str(self.tars_embeddings.tokenizer.bos_token)

        self.prefix = prefix
        self.num_negative_labels_to_sample = num_negative_labels_to_sample

        if task_name and label_dictionary and label_type:
            # store task-specific labels since TARS can handle multiple tasks
            self.add_and_switch_to_new_task(task_name, label_dictionary, label_type)
        else:
            log.info("TARS initialized without a task. You need to call .add_and_switch_to_new_task() "
                     "before training this model")
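
    # Illustrative construction sketch (task name and labels are made up; a plain
    # list is accepted because add_and_switch_to_new_task converts it):
    #   tagger = TARSTagger(task_name="my_ner",
    #                       label_dictionary=["person", "location"],
    #                       label_type="ner")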

    def _get_tars_formatted_sentence(self, label, sentence):

        original_text = sentence.to_tokenized_string()

        label_text_pair = f"{label} {self.separator} {original_text}" if self.prefix \
            else f"{original_text} {self.separator} {label}"

        label_length = 0 if not self.prefix else len(label.split(" ")) + len(self.separator.split(" "))

        # make a TARS sentence where all labels are O by default
        tars_sentence = Sentence(label_text_pair, use_tokenizer=False)
        for token in tars_sentence:
            token.add_tag(self.static_label_type, "O")

        # overwrite O labels with gold tags
        for token in sentence:
            tag = token.get_tag(self.get_current_label_type()).value

            if tag == "O" or tag == "":
                tars_tag = "O"
            elif tag == label:
                tars_tag = "S-"
            elif len(tag) > 2 and tag[1] == "-" and tag[2:] == label:
                tars_tag = tag.split('-')[0] + '-'
            else:
                tars_tag = "O"

            tars_sentence.get_token(token.idx + label_length).add_tag(self.static_label_type, tars_tag)

        return tars_sentence
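
    # Illustrative remapping sketch (made-up example): for the pair
    # "person [SEP] George Washington went to Washington" with gold tags
    # B-person/E-person on "George Washington", the TARS sentence gets the
    # prefix-only tags B-/E- on those tokens and O everywhere else.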

    def _get_state_dict(self):
        model_state = {
            "state_dict": self.state_dict(),

            "current_task": self._current_task,
            "tag_type": self.get_current_label_type(),
            "tag_dictionary": self.get_current_label_dictionary(),
            "tars_model": self.tars_model,
            "num_negative_labels_to_sample": self.num_negative_labels_to_sample,
            "prefix": self.prefix,

            "task_specific_attributes": self._task_specific_attributes,
        }
        return model_state

    @staticmethod
    def _fetch_model(model_name) -> str:

        if model_name == "tars-ner":
            cache_dir = Path("models")
            model_name = cached_path("https://nlp.informatik.hu-berlin.de/resources/models/tars-ner/tars-ner.pt",
                                     cache_dir=cache_dir)

        return model_name

    @staticmethod
    def _init_model_with_state_dict(state):

        # init a new TARS tagger
        model = TARSTagger(
            task_name=state["current_task"],
            label_dictionary=state["tag_dictionary"],
            label_type=state["tag_type"],
            embeddings=state["tars_model"].embeddings,
            num_negative_labels_to_sample=state["num_negative_labels_to_sample"],
            prefix=state["prefix"],
        )
        # set all task information
        model._task_specific_attributes = state["task_specific_attributes"]

        # linear layers of internal classifier
        model.load_state_dict(state["state_dict"])
        return model

    @property
    def tars_embeddings(self):
        return self.tars_model.embeddings

    def predict(
            self,
            sentences: Union[List[Sentence], Sentence],
            mini_batch_size=32,
            verbose: bool = False,
            label_name: Optional[str] = None,
            return_loss=False,
            embedding_storage_mode="none",
            most_probable_first: bool = True
    ):
        """
        Predicts sequence tags for the given sentences (e.g. for Named Entity Recognition).
        :param sentences: a Sentence or a List of Sentence
        :param mini_batch_size: size of the minibatch; usually bigger is faster but consumes more memory,
        up to a point where it has no further effect.
        :param verbose: set to True to display a progress bar
        :param label_name: set this to change the name of the label type that is predicted
        :param return_loss: set to True to return loss
        :param embedding_storage_mode: default is 'none' which is always best. Only set to 'cpu' or 'gpu' if
        you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively.
        :param most_probable_first: if True, assign detected spans in order of decreasing confidence
        """
        if label_name is None:
            label_name = self.get_current_label_type()

        if not sentences:
            return sentences

        if isinstance(sentences, Sentence):
            sentences = [sentences]

        # reverse sort all sequences by their length
        rev_order_len_index = sorted(range(len(sentences)), key=lambda k: len(sentences[k]), reverse=True)

        reordered_sentences: List[Union[Sentence, str]] = [sentences[index] for index in rev_order_len_index]

        dataloader = DataLoader(dataset=SentenceDataset(reordered_sentences), batch_size=mini_batch_size)

        # progress bar for verbosity
        if verbose:
            dataloader = tqdm(dataloader)

        overall_loss = 0
        overall_count = 0
        batch_no = 0
        with torch.no_grad():
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(f"Inferencing on batch {batch_no}")

                batch = self._filter_empty_sentences(batch)
                # skip if all sentences are empty
                if not batch:
                    continue

                # go through each sentence in the batch
                for sentence in batch:

                    # always remove tags first
                    for token in sentence:
                        token.remove_labels(label_name)

                    all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]

                    all_detected = {}
                    for label in all_labels:
                        tars_sentence = self._get_tars_formatted_sentence(label, sentence)

                        label_length = 0 if not self.prefix else len(label.split(" ")) + len(self.separator.split(" "))

                        loss_and_count = self.tars_model.predict(tars_sentence,
                                                                 label_name=label_name,
                                                                 all_tag_prob=True,
                                                                 return_loss=True)
                        overall_loss += loss_and_count[0].item()
                        overall_count += loss_and_count[1]

                        for span in tars_sentence.get_spans(label_name):
                            span.set_label('tars_temp_label', label)
                            all_detected[span] = span.score

                        if not most_probable_first:
                            for span in tars_sentence.get_spans(label_name):
                                for token in span:
                                    corresponding_token = sentence.get_token(token.idx - label_length)
                                    if corresponding_token is None:
                                        continue
                                    if corresponding_token.get_tag(label_name).value != '' and \
                                            corresponding_token.get_tag(label_name).score > token.get_tag(
                                            label_name).score:
                                        continue
                                    corresponding_token.add_tag(
                                        label_name,
                                        token.get_tag(label_name).value + label,
                                        token.get_tag(label_name).score,
                                    )

                    if most_probable_first:
                        # assign spans in order of decreasing confidence
                        sorted_spans = sorted(all_detected.items(), key=lambda item: item[1], reverse=True)
                        for span, _score in sorted_spans:
                            # get the span's label
                            label = span.get_labels('tars_temp_label')[0].value
                            label_length = 0 if not self.prefix else len(label.split(" ")) + len(
                                self.separator.split(" "))

                            # determine whether tokens in this span already have a label
                            tag_this = True
                            for token in span:
                                corresponding_token = sentence.get_token(token.idx - label_length)
                                if corresponding_token is None:
                                    tag_this = False
                                    continue
                                if corresponding_token.get_tag(label_name).value != '' and \
                                        corresponding_token.get_tag(label_name).score > token.get_tag(label_name).score:
                                    tag_this = False
                                    continue

                            # only add if no token in the span already has a better label
                            if tag_this:
                                for token in span:
                                    corresponding_token = sentence.get_token(token.idx - label_length)
                                    corresponding_token.add_tag(
                                        label_name,
                                        token.get_tag(label_name).value + label,
                                        token.get_tag(label_name).score,
                                    )

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

        if return_loss:
            return overall_loss, overall_count


class TARSClassifier(FewshotClassifier):
    """
    TARS model for text classification. In the backend, the model uses a BERT-based binary
    text classifier which, given a <label, text> pair, predicts the probability of two classes
    "True" and "False". The input data is a usual Sentence object which is inflated
    by the model internally before being pushed through the transformer stack of BERT.
    """

    static_label_type = "tars_label"
    LABEL_MATCH = "YES"
    LABEL_NO_MATCH = "NO"

    def __init__(
            self,
            task_name: Optional[str] = None,
            label_dictionary: Optional[Dictionary] = None,
            label_type: Optional[str] = None,
            embeddings: str = 'bert-base-uncased',
            num_negative_labels_to_sample: int = 2,
            prefix: bool = True,
            **tagger_args,
    ):
        """
        Initializes a TARSClassifier
        :param task_name: a string depicting the name of the task
        :param label_dictionary: dictionary of labels you want to predict
        :param embeddings: name of the pre-trained transformer model, e.g.
        'bert-base-uncased'
        :param num_negative_labels_to_sample: number of negative labels to sample for each
        positive label against a sentence during training. Defaults to 2 negative
        labels for each positive label. If None is passed, the model samples all
        negative labels, which slows down training considerably.
        Additional arguments for the underlying TextClassifier (e.g. multi_label,
        multi_label_threshold, beta) can be passed via **tagger_args.
        """
        super(TARSClassifier, self).__init__()

        from flair.embeddings import TransformerDocumentEmbeddings

        if not isinstance(embeddings, TransformerDocumentEmbeddings):
            embeddings = TransformerDocumentEmbeddings(model=embeddings,
                                                       fine_tune=True,
                                                       layers='-1',
                                                       layer_mean=False,
                                                       )

        # prepare TARS dictionary
        tars_dictionary = Dictionary(add_unk=False)
        tars_dictionary.add_item(self.LABEL_NO_MATCH)
        tars_dictionary.add_item(self.LABEL_MATCH)

        # initialize a bare-bones text classifier
        self.tars_model = TextClassifier(document_embeddings=embeddings,
                                         label_dictionary=tars_dictionary,
                                         label_type=self.static_label_type,
                                         **tagger_args,
                                         )

        # transformer separator
        self.separator = str(self.tars_embeddings.tokenizer.sep_token)
        if self.tars_embeddings.tokenizer._bos_token:
            self.separator += str(self.tars_embeddings.tokenizer.bos_token)

        self.prefix = prefix
        self.num_negative_labels_to_sample = num_negative_labels_to_sample

        if task_name and label_dictionary and label_type:
            # store task-specific labels since TARS can handle multiple tasks
            self.add_and_switch_to_new_task(task_name, label_dictionary, label_type)
        else:
            log.info("TARS initialized without a task. You need to call .add_and_switch_to_new_task() "
                     "before training this model")

    def _get_tars_formatted_sentence(self, label, sentence):

        original_text = sentence.to_tokenized_string()

        label_text_pair = f"{label} {self.separator} {original_text}" if self.prefix \
            else f"{original_text} {self.separator} {label}"

        sentence_labels = [sentence_label.value for sentence_label in
                           sentence.get_labels(self.get_current_label_type())]

        tars_label = self.LABEL_MATCH if label in sentence_labels else self.LABEL_NO_MATCH

        tars_sentence = Sentence(label_text_pair, use_tokenizer=False).add_label(self.static_label_type, tars_label)

        return tars_sentence
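
    # Illustrative sketch (made-up example): with prefix=True and gold label
    # "sports", the sentence "Germany wins the match" becomes
    #   "sports [SEP] Germany wins the match"  ->  tars_label = YES
    # while the candidate "politics" yields the same text with tars_label = NO.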

    def _get_state_dict(self):
        model_state = {
            "state_dict": self.state_dict(),

            "current_task": self._current_task,
            "label_type": self.get_current_label_type(),
            "label_dictionary": self.get_current_label_dictionary(),
            "tars_model": self.tars_model,
            "num_negative_labels_to_sample": self.num_negative_labels_to_sample,

            "task_specific_attributes": self._task_specific_attributes,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):

        # init new TARS classifier
        label_dictionary = state["label_dictionary"]
        label_type = "default_label" if not state["label_type"] else state["label_type"]

        model: TARSClassifier = TARSClassifier(
            task_name=state["current_task"],
            label_dictionary=label_dictionary,
            label_type=label_type,
            embeddings=state["tars_model"].document_embeddings,
            num_negative_labels_to_sample=state["num_negative_labels_to_sample"],
        )

        # set all task information
        model._task_specific_attributes = state["task_specific_attributes"]

        # linear layers of internal classifier
        model.load_state_dict(state["state_dict"])
        return model

    @staticmethod
    def _fetch_model(model_name) -> str:

        model_map = {}
        hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models"

        model_map["tars-base"] = "/".join([hu_path, "tars-base", "tars-base-v8.pt"])

        cache_dir = Path("models")
        if model_name in model_map:
            model_name = cached_path(model_map[model_name], cache_dir=cache_dir)

        return model_name

    @property
    def tars_embeddings(self):
        return self.tars_model.document_embeddings

    def predict(
            self,
            sentences: Union[List[Sentence], Sentence],
            mini_batch_size=32,
            verbose: bool = False,
            label_name: Optional[str] = None,
            return_loss=False,
            embedding_storage_mode="none",
            label_threshold: float = 0.5,
            multi_label: Optional[bool] = None,
    ):
        """
        Predicts the class labels for the given sentences.
        :param sentences: a Sentence or a List of Sentence
        :param mini_batch_size: size of the minibatch; usually bigger is faster but consumes more memory,
        up to a point where it has no further effect.
        :param verbose: set to True to display a progress bar
        :param label_name: set this to change the name of the label type that is predicted
        :param return_loss: set to True to return loss
        :param embedding_storage_mode: default is 'none' which is always best. Only set to 'cpu' or 'gpu' if
        you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively.
        :param label_threshold: confidence threshold a predicted label must exceed to be added
        :param multi_label: set to True/False to force multi-label or single-label prediction;
        defaults to the current task's setting
        """
        if not label_name:
            label_name = self.get_current_label_type()

        if multi_label is None:
            multi_label = self.is_current_task_multi_label()

        if not sentences:
            return sentences

        if isinstance(sentences, Sentence):
            sentences = [sentences]

        # set context if not set already
        previous_sentence = None
        for sentence in sentences:
            if sentence.is_context_set():
                continue
            sentence._previous_sentence = previous_sentence
            sentence._next_sentence = None
            if previous_sentence:
                previous_sentence._next_sentence = sentence
            previous_sentence = sentence

        # reverse sort all sequences by their length
        rev_order_len_index = sorted(range(len(sentences)), key=lambda k: len(sentences[k]), reverse=True)

        reordered_sentences: List[Union[Sentence, str]] = [sentences[index] for index in rev_order_len_index]

        dataloader = DataLoader(dataset=SentenceDataset(reordered_sentences), batch_size=mini_batch_size)

        # progress bar for verbosity
        if verbose:
            dataloader = tqdm(dataloader)

        overall_loss = 0
        overall_count = 0
        batch_no = 0
        with torch.no_grad():
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(f"Inferencing on batch {batch_no}")

                batch = self._filter_empty_sentences(batch)
                # skip if all sentences are empty
                if not batch:
                    continue

                # go through each sentence in the batch
                for sentence in batch:

                    # always remove tags first
                    sentence.remove_labels(label_name)

                    all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]

                    for label in all_labels:
                        tars_sentence = self._get_tars_formatted_sentence(label, sentence)

                        loss_and_count = self.tars_model.predict(tars_sentence,
                                                                 label_name=label_name,
                                                                 return_loss=True,
                                                                 return_probabilities_for_all_classes=label_threshold < 0.5,
                                                                 )

                        overall_loss += loss_and_count[0].item()
                        overall_count += loss_and_count[1]

                        # add all labels that according to TARS match the text and are above the threshold
                        for predicted_tars_label in tars_sentence.get_labels(label_name):
                            if predicted_tars_label.value == self.LABEL_MATCH \
                                    and predicted_tars_label.score > label_threshold:
                                # do not add labels below the confidence threshold
                                sentence.add_label(label_name, label, predicted_tars_label.score)

                    # only use the label with the highest confidence if enforcing single-label predictions
                    if not multi_label:
                        if len(sentence.get_labels(label_name)) > 0:
                            # get all label scores and do an argmax to get the best label
                            label_scores = torch.tensor([label.score for label in sentence.get_labels(label_name)],
                                                        dtype=torch.float)
                            best_label = sentence.get_labels(label_name)[torch.argmax(label_scores)]

                            # remove previously added labels and only add the best label
                            sentence.remove_labels(label_name)
                            sentence.add_label(typename=label_name, value=best_label.value, score=best_label.score)

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

        if return_loss:
            return overall_loss, overall_count
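
# Illustrative end-to-end sketch (sentence text and candidate labels are made up;
# 'tars-base' matches the _fetch_model map above):
#   from flair.models import TARSClassifier
#   from flair.data import Sentence
#   tars = TARSClassifier.load('tars-base')
#   sentence = Sentence("This new camera takes stunning photos")
#   tars.predict_zero_shot(sentence, {"technology", "cooking"}, multi_label=False)
#   print(sentence.get_labels())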