Coverage for flair/flair/models/tars_model.py: 17%


393 statements  

import logging
from collections import OrderedDict
from pathlib import Path
from typing import Union, List, Set, Optional

import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import minmax_scale
from tqdm import tqdm

import flair
from flair.data import Dictionary, Sentence
from flair.datasets import SentenceDataset, DataLoader
from flair.embeddings import TokenEmbeddings
from flair.file_utils import cached_path
from flair.models import SequenceTagger, TextClassifier
from flair.training_utils import store_embeddings

log = logging.getLogger("flair")

class FewshotClassifier(flair.nn.Classifier):

    def __init__(self):
        self._current_task = None
        self._task_specific_attributes = {}
        self.label_nearest_map = None

        super(FewshotClassifier, self).__init__()

    def forward_loss(
            self, data_points: Union[List[Sentence], Sentence]
    ) -> torch.tensor:

        if isinstance(data_points, Sentence):
            data_points = [data_points]

        # transform input data into TARS format
        sentences = self._get_tars_formatted_sentences(data_points)

        loss = self.tars_model.forward_loss(sentences)
        return loss

    @property
    def tars_embeddings(self):
        raise NotImplementedError

    def _get_tars_formatted_sentence(self, label, sentence):
        raise NotImplementedError

    def _get_tars_formatted_sentences(self, sentences: List[Sentence]):
        label_text_pairs = []
        all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]
        for sentence in sentences:
            label_text_pairs_for_sentence = []
            if self.training and self.num_negative_labels_to_sample is not None:

                positive_labels = list(OrderedDict.fromkeys(
                    [label.value for label in sentence.get_labels(self.label_type)]))

                sampled_negative_labels = self._get_nearest_labels_for(positive_labels)

                for label in positive_labels:
                    label_text_pairs_for_sentence.append(self._get_tars_formatted_sentence(label, sentence))
                for label in sampled_negative_labels:
                    label_text_pairs_for_sentence.append(self._get_tars_formatted_sentence(label, sentence))

            else:
                for label in all_labels:
                    label_text_pairs_for_sentence.append(self._get_tars_formatted_sentence(label, sentence))
            label_text_pairs.extend(label_text_pairs_for_sentence)

        return label_text_pairs

    def _get_nearest_labels_for(self, labels):

        # if there are no labels, return a random sample as negatives
        if len(labels) == 0:
            tags = self.get_current_label_dictionary().get_items()
            import random
            sample = random.sample(tags, k=self.num_negative_labels_to_sample)
            return sample

        already_sampled_negative_labels = set()

        # otherwise, go through all labels
        for label in labels:

            plausible_labels = []
            plausible_label_probabilities = []
            for plausible_label in self.label_nearest_map[label]:
                if plausible_label in already_sampled_negative_labels or plausible_label in labels:
                    continue
                else:
                    plausible_labels.append(plausible_label)
                    plausible_label_probabilities.append(self.label_nearest_map[label][plausible_label])

            # make sure the probabilities always sum up to 1
            plausible_label_probabilities = np.array(plausible_label_probabilities, dtype='float64')
            plausible_label_probabilities += 1e-08
            plausible_label_probabilities /= np.sum(plausible_label_probabilities)

            if len(plausible_labels) > 0:
                num_samples = min(self.num_negative_labels_to_sample, len(plausible_labels))
                sampled_negative_labels = np.random.choice(plausible_labels,
                                                           num_samples,
                                                           replace=False,
                                                           p=plausible_label_probabilities)
                already_sampled_negative_labels.update(sampled_negative_labels)

        return already_sampled_negative_labels

    def train(self, mode=True):
        """Populate the label similarity map based on cosine similarity before running an epoch

        If `num_negative_labels_to_sample` is set to an integer value, then before starting
        each epoch the model creates a similarity measure between the label names based
        on the cosine distances between their BERT-encoded embeddings.
        """
        if mode and self.num_negative_labels_to_sample is not None:
            self._compute_label_similarity_for_current_epoch()

        super().train(mode)

    def _compute_label_similarity_for_current_epoch(self):
        """
        Compute the similarity between all labels for better sampling of negatives
        """

        # get and embed all labels by making a Sentence object that contains only the label text
        all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]
        label_sentences = [Sentence(label) for label in all_labels]

        self.tars_embeddings.eval()  # TODO: check if this is necessary
        self.tars_embeddings.embed(label_sentences)
        self.tars_embeddings.train()

        # get each label embedding and scale between 0 and 1
        if isinstance(self.tars_embeddings, TokenEmbeddings):
            encodings_np = [sentence[0].get_embedding().cpu().detach().numpy() for sentence in label_sentences]
        else:
            encodings_np = [sentence.get_embedding().cpu().detach().numpy() for sentence in label_sentences]

        normalized_encoding = minmax_scale(encodings_np)

        # compute similarity matrix
        similarity_matrix = cosine_similarity(normalized_encoding)

        # the higher the similarity, the greater the chance that a label is
        # sampled as a negative example
        negative_label_probabilities = {}
        for row_index, label in enumerate(all_labels):
            negative_label_probabilities[label] = {}
            for column_index, other_label in enumerate(all_labels):
                if label != other_label:
                    negative_label_probabilities[label][other_label] = \
                        similarity_matrix[row_index][column_index]
        self.label_nearest_map = negative_label_probabilities
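
    # Not part of the original module: a minimal, self-contained sketch of the sampling step that
    # consumes the similarity map built above. The labels and similarity scores are made up.
    #
    # >>> import numpy as np
    # >>> plausible_labels = ["politics", "soccer"]
    # >>> probabilities = np.array([0.21, 0.87], dtype="float64")   # made-up similarities to "sports"
    # >>> probabilities += 1e-08
    # >>> probabilities /= probabilities.sum()
    # >>> np.random.choice(plausible_labels, 1, replace=False, p=probabilities)   # e.g. array(['soccer'], ...)
    #
    # The semantically closer label is the more likely negative (see _get_nearest_labels_for above).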

    def get_current_label_dictionary(self):
        label_dictionary = self._task_specific_attributes[self._current_task]['label_dictionary']
        return label_dictionary

    def get_current_label_type(self):
        return self._task_specific_attributes[self._current_task]['label_type']

    def is_current_task_multi_label(self):
        return self._task_specific_attributes[self._current_task]['multi_label']

    def add_and_switch_to_new_task(self,
                                   task_name,
                                   label_dictionary: Union[List, Set, Dictionary, str],
                                   label_type: str,
                                   multi_label: bool = True,
                                   force_switch: bool = False,
                                   ):
        """
        Adds a new task to an existing TARS model. Sets the necessary attributes and finally
        'switches' to the new task. Parameters are similar to the constructor except for model
        choice, batch size and negative sampling. This method does not store the resulting model
        on disk.
        :param task_name: a string depicting the name of the task
        :param label_dictionary: dictionary of the labels you want to predict
        :param label_type: string to identify the label type ('ner', 'sentiment', etc.)
        :param multi_label: whether this task is a multi-label prediction problem
        :param force_switch: if True, overwrites an existing task with the same name
        """
        if task_name in self._task_specific_attributes and not force_switch:
            log.warning("Task `%s` already exists in TARS model. Switching to it.", task_name)
        else:
            # make a label list if no Dictionary object is passed
            if isinstance(label_dictionary, Dictionary):
                label_dictionary = label_dictionary.get_items()
            if isinstance(label_dictionary, str):
                label_dictionary = [label_dictionary]

            # prepare dictionary of tags (without B-/I- prefixes and without UNK)
            tag_dictionary = Dictionary(add_unk=False)
            for tag in label_dictionary:
                if tag in ('<unk>', 'O'):
                    continue
                if len(tag) > 1 and tag[1] == "-":
                    tag = tag[2:]
                tag_dictionary.add_item(tag)

            self._task_specific_attributes[task_name] = {'label_dictionary': tag_dictionary,
                                                         'label_type': label_type,
                                                         'multi_label': multi_label}

        self.switch_to_task(task_name)
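
    # Not part of the original module: a hedged usage sketch for the task-management methods
    # above. The task name, labels and label type are made-up examples; 'tars-base' matches
    # TARSClassifier._fetch_model further below.
    #
    # >>> tars = TARSClassifier.load('tars-base')
    # >>> tars.add_and_switch_to_new_task(task_name="toy_sentiment",
    # ...                                 label_dictionary=["happy", "sad"],
    # ...                                 label_type="sentiment")
    # >>> tars.list_existing_tasks()          # the new task is now registered
    # >>> tars.switch_to_task("toy_sentiment")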

    def list_existing_tasks(self) -> Set[str]:
        """
        Returns a set of all tasks currently registered in the loaded TARS model.
        """
        return set(self._task_specific_attributes.keys())

    def switch_to_task(self, task_name):
        """
        Switches to a task which was previously added.
        """
        if task_name not in self._task_specific_attributes:
            log.error("Provided `%s` does not exist in the model. Consider calling "
                      "`add_and_switch_to_new_task` first.", task_name)
        else:
            self._current_task = task_name

    def _drop_task(self, task_name):
        if task_name in self._task_specific_attributes:
            if self._current_task == task_name:
                log.error("`%s` is the current task."
                          " Switch to some other task before dropping this.", task_name)
            else:
                self._task_specific_attributes.pop(task_name)
        else:
            log.warning("No task exists with the name `%s`.", task_name)

    @staticmethod
    def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
        filtered_sentences = [sentence for sentence in sentences if sentence.tokens]
        if len(sentences) != len(filtered_sentences):
            log.warning(
                f"Ignoring {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens."
            )
        return filtered_sentences

    @property
    def label_type(self):
        return self.get_current_label_type()

    def predict_zero_shot(self,
                          sentences: Union[List[Sentence], Sentence],
                          candidate_label_set: Union[List[str], Set[str], str],
                          multi_label: bool = True):
        """
        Makes zero-shot predictions with the TARS model
        :param sentences: input sentence objects to classify
        :param candidate_label_set: set of candidate labels
        :param multi_label: indicates whether to allow multiple labels per sentence or only the
            single best label. Defaults to True.
        """

        # check if candidate_label_set is empty
        if candidate_label_set is None or len(candidate_label_set) == 0:
            log.warning("Provided candidate_label_set is empty")
            return

        # wrap in a set if only one candidate label is passed
        if isinstance(candidate_label_set, str):
            candidate_label_set = {candidate_label_set}

        # create label dictionary
        label_dictionary = Dictionary(add_unk=False)
        for label in candidate_label_set:
            label_dictionary.add_item(label)

        # note current task
        existing_current_task = self._current_task

        # create a temporary task
        self.add_and_switch_to_new_task(task_name="ZeroShot",
                                        label_dictionary=label_dictionary,
                                        label_type='-'.join(label_dictionary.get_items()),
                                        multi_label=multi_label)

        try:
            # make zero-shot predictions
            self.predict(sentences)
        finally:
            # switch back to the pre-existing task
            self.switch_to_task(existing_current_task)
            self._drop_task("ZeroShot")
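
# Not part of the original module: a hedged usage sketch of the zero-shot API defined in
# FewshotClassifier above. The model name 'tars-base' corresponds to TARSClassifier._fetch_model
# further below; the sentence and candidate labels are made-up examples.
#
# >>> from flair.data import Sentence
# >>> from flair.models import TARSClassifier
# >>> tars = TARSClassifier.load('tars-base')
# >>> sentence = Sentence("The match ended in a dramatic penalty shootout.")
# >>> tars.predict_zero_shot(sentence, candidate_label_set=["sports", "politics", "weather"])
# >>> print(sentence.get_labels())   # candidate label(s) that TARS judged to match, with scores
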

class TARSTagger(FewshotClassifier):
    """
    TARS model for sequence tagging. In the backend, the model uses a BERT-based 5-class
    sequence labeler which, given a <label, text> pair, predicts for each word the probability
    of belonging to one of the BIOES classes. The input data is a regular Sentence object which
    the model inflates internally before pushing it through the transformer stack of BERT.
    """

    static_label_type = "tars_label"

    def __init__(
            self,
            task_name: Optional[str] = None,
            label_dictionary: Optional[Dictionary] = None,
            label_type: Optional[str] = None,
            embeddings: str = 'bert-base-uncased',
            num_negative_labels_to_sample: int = 2,
            prefix: bool = True,
            **tagger_args,
    ):
        """
        Initializes a TARSTagger
        :param task_name: a string depicting the name of the task
        :param label_dictionary: dictionary of labels you want to predict
        :param label_type: string to identify the label type ('ner', etc.)
        :param embeddings: name of the pre-trained transformer model, e.g. 'bert-base-uncased'
        :param num_negative_labels_to_sample: number of negative labels to sample for each
            positive label against a sentence during training. Defaults to 2 negative
            labels per positive label. If None is passed, the model samples all negative
            labels, which slows down training considerably.
        :param prefix: if True, the label text is prepended to the sentence; otherwise it is appended
        """
        super(TARSTagger, self).__init__()

        from flair.embeddings import TransformerWordEmbeddings

        if not isinstance(embeddings, TransformerWordEmbeddings):
            embeddings = TransformerWordEmbeddings(model=embeddings,
                                                   fine_tune=True,
                                                   layers='-1',
                                                   layer_mean=False,
                                                   )

        # prepare TARS dictionary
        tars_dictionary = Dictionary(add_unk=False)
        tars_dictionary.add_item('O')
        tars_dictionary.add_item('S-')
        tars_dictionary.add_item('B-')
        tars_dictionary.add_item('E-')
        tars_dictionary.add_item('I-')

        # initialize a bare-bones sequence tagger (the hidden size is irrelevant since no RNN is used)
        self.tars_model = SequenceTagger(123,
                                         embeddings,
                                         tag_dictionary=tars_dictionary,
                                         tag_type=self.static_label_type,
                                         use_crf=False,
                                         use_rnn=False,
                                         reproject_embeddings=False,
                                         **tagger_args,
                                         )

        # transformer separator
        self.separator = str(self.tars_embeddings.tokenizer.sep_token)
        if self.tars_embeddings.tokenizer._bos_token:
            self.separator += str(self.tars_embeddings.tokenizer.bos_token)

        self.prefix = prefix
        self.num_negative_labels_to_sample = num_negative_labels_to_sample

        if task_name and label_dictionary and label_type:
            # store task-specific labels since TARS can handle multiple tasks
            self.add_and_switch_to_new_task(task_name, label_dictionary, label_type)
        else:
            log.info("TARS initialized without a task. You need to call .add_and_switch_to_new_task() "
                     "before training this model")

    def _get_tars_formatted_sentence(self, label, sentence):

        original_text = sentence.to_tokenized_string()

        label_text_pair = f"{label} {self.separator} {original_text}" if self.prefix \
            else f"{original_text} {self.separator} {label}"

        label_length = 0 if not self.prefix else len(label.split(" ")) + len(self.separator.split(" "))

        # make a TARS sentence where all labels are O by default
        tars_sentence = Sentence(label_text_pair, use_tokenizer=False)
        for token in tars_sentence:
            token.add_tag(self.static_label_type, "O")

        # overwrite O labels with tags
        for token in sentence:
            tag = token.get_tag(self.get_current_label_type()).value

            if tag == "O" or tag == "":
                tars_tag = "O"
            elif tag == label:
                tars_tag = "S-"
            elif len(tag) > 1 and tag[1] == "-" and tag[2:] == label:
                tars_tag = tag.split('-')[0] + '-'
            else:
                tars_tag = "O"

            tars_sentence.get_token(token.idx + label_length).add_tag(self.static_label_type, tars_tag)

        return tars_sentence
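
    # Not part of the original module: a worked illustration of the reformatting above, with
    # made-up input. Assuming prefix=True, a plain '[SEP]' separator, the candidate label "person",
    # and a sentence in which the span "George Washington" is annotated with that label, the
    # generated TARS sentence and its tars_label tags look roughly like this:
    #
    #     person [SEP] George Washington went to Washington
    #     O      O     B-     E-         O    O  O
    #
    # Only tokens whose gold tag matches the queried label keep a (prefix-only) BIOES tag; every
    # other token, including the label word and the separator, is tagged "O".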

    def _get_state_dict(self):
        model_state = {
            "state_dict": self.state_dict(),

            "current_task": self._current_task,
            "tag_type": self.get_current_label_type(),
            "tag_dictionary": self.get_current_label_dictionary(),
            "tars_model": self.tars_model,
            "num_negative_labels_to_sample": self.num_negative_labels_to_sample,
            "prefix": self.prefix,

            "task_specific_attributes": self._task_specific_attributes,
        }
        return model_state

    @staticmethod
    def _fetch_model(model_name) -> str:

        if model_name == "tars-ner":
            cache_dir = Path("models")
            model_name = cached_path("https://nlp.informatik.hu-berlin.de/resources/models/tars-ner/tars-ner.pt",
                                     cache_dir=cache_dir)

        return model_name

    @staticmethod
    def _init_model_with_state_dict(state):

        # init a new TARS tagger
        model = TARSTagger(
            task_name=state["current_task"],
            label_dictionary=state["tag_dictionary"],
            label_type=state["tag_type"],
            embeddings=state["tars_model"].embeddings,
            num_negative_labels_to_sample=state["num_negative_labels_to_sample"],
            prefix=state["prefix"],
        )
        # set all task information
        model._task_specific_attributes = state["task_specific_attributes"]

        # load the weights of the internal model
        model.load_state_dict(state["state_dict"])
        return model

    @property
    def tars_embeddings(self):
        return self.tars_model.embeddings

    def predict(
            self,
            sentences: Union[List[Sentence], Sentence],
            mini_batch_size=32,
            verbose: bool = False,
            label_name: Optional[str] = None,
            return_loss=False,
            embedding_storage_mode="none",
            most_probable_first: bool = True
    ):
        """
        Predicts sequence tags (e.g. named entities) for the given sentences using the currently active TARS task
        :param sentences: a Sentence or a List of Sentence
        :param mini_batch_size: size of the minibatch; usually bigger is faster but consumes more memory,
            up to a point where it has no further effect.
        :param verbose: set to True to display a progress bar
        :param return_loss: set to True to return loss
        :param label_name: set this to change the name of the label type that is predicted
        :param embedding_storage_mode: default is 'none', which is always best. Only set to 'cpu' or 'gpu' if
            you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively.
        :param most_probable_first: if True, spans detected for any candidate label are assigned in order of
            decreasing confidence, so that higher-scoring spans take precedence where they overlap
        """
        if label_name is None:
            label_name = self.get_current_label_type()

        if not sentences:
            return sentences

        if isinstance(sentences, Sentence):
            sentences = [sentences]

        # reverse sort all sequences by their length
        rev_order_len_index = sorted(range(len(sentences)), key=lambda k: len(sentences[k]), reverse=True)

        reordered_sentences: List[Union[Sentence, str]] = [sentences[index] for index in rev_order_len_index]

        dataloader = DataLoader(dataset=SentenceDataset(reordered_sentences), batch_size=mini_batch_size)

        # progress bar for verbosity
        if verbose:
            dataloader = tqdm(dataloader)

        overall_loss = 0
        overall_count = 0
        batch_no = 0
        with torch.no_grad():
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(f"Inferencing on batch {batch_no}")

                batch = self._filter_empty_sentences(batch)
                # skip this batch if all sentences are empty
                if not batch:
                    continue

                # go through each sentence in the batch
                for sentence in batch:

                    # always remove tags first
                    for token in sentence:
                        token.remove_labels(label_name)

                    all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]

                    all_detected = {}
                    for label in all_labels:
                        tars_sentence = self._get_tars_formatted_sentence(label, sentence)

                        label_length = 0 if not self.prefix else len(label.split(" ")) + len(self.separator.split(" "))

                        loss_and_count = self.tars_model.predict(tars_sentence,
                                                                 label_name=label_name,
                                                                 all_tag_prob=True,
                                                                 return_loss=True)
                        overall_loss += loss_and_count[0].item()
                        overall_count += loss_and_count[1]

                        for span in tars_sentence.get_spans(label_name):
                            span.set_label('tars_temp_label', label)
                            all_detected[span] = span.score

                        if not most_probable_first:
                            for span in tars_sentence.get_spans(label_name):
                                for token in span:
                                    corresponding_token = sentence.get_token(token.idx - label_length)
                                    if corresponding_token is None:
                                        continue
                                    if corresponding_token.get_tag(label_name).value != '' and \
                                            corresponding_token.get_tag(label_name).score > token.get_tag(label_name).score:
                                        continue
                                    corresponding_token.add_tag(
                                        label_name,
                                        token.get_tag(label_name).value + label,
                                        token.get_tag(label_name).score,
                                    )

                    if most_probable_first:
                        import operator
                        sorted_x = sorted(all_detected.items(), key=operator.itemgetter(1))
                        sorted_x.reverse()
                        for span, _ in sorted_x:
                            # get the label of this span
                            label = span.get_labels('tars_temp_label')[0].value
                            label_length = 0 if not self.prefix else len(label.split(" ")) + len(
                                self.separator.split(" "))

                            # determine whether tokens in this span already have a label
                            tag_this = True
                            for token in span:
                                corresponding_token = sentence.get_token(token.idx - label_length)
                                if corresponding_token is None:
                                    tag_this = False
                                    continue
                                if corresponding_token.get_tag(label_name).value != '' and \
                                        corresponding_token.get_tag(label_name).score > token.get_tag(label_name).score:
                                    tag_this = False
                                    continue

                            # only add if all tokens have no label
                            if tag_this:
                                for token in span:
                                    corresponding_token = sentence.get_token(token.idx - label_length)
                                    corresponding_token.add_tag(
                                        label_name,
                                        token.get_tag(label_name).value + label,
                                        token.get_tag(label_name).score,
                                    )

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

        if return_loss:
            return overall_loss, overall_count
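
# Not part of the original module: a hedged few-shot/zero-shot usage sketch for TARSTagger above.
# The model name 'tars-ner' matches TARSTagger._fetch_model; the task, labels and sentence are
# made-up examples.
#
# >>> from flair.data import Sentence
# >>> from flair.models import TARSTagger
# >>> tagger = TARSTagger.load('tars-ner')
# >>> tagger.add_and_switch_to_new_task("toy_ner",
# ...                                   label_dictionary=["person", "location"],
# ...                                   label_type="ner")
# >>> sentence = Sentence("George Washington went to Washington.")
# >>> tagger.predict(sentence)
# >>> print(sentence.to_tagged_string())
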

class TARSClassifier(FewshotClassifier):
    """
    TARS model for text classification. In the backend, the model uses a BERT-based binary
    text classifier which, given a <label, text> pair, predicts the probability of the two
    classes "YES" (the label matches the text) and "NO" (it does not). The input data is a
    regular Sentence object which the model inflates internally before pushing it through
    the transformer stack of BERT.
    """

    static_label_type = "tars_label"
    LABEL_MATCH = "YES"
    LABEL_NO_MATCH = "NO"

    def __init__(
            self,
            task_name: Optional[str] = None,
            label_dictionary: Optional[Dictionary] = None,
            label_type: Optional[str] = None,
            embeddings: str = 'bert-base-uncased',
            num_negative_labels_to_sample: int = 2,
            prefix: bool = True,
            **tagger_args,
    ):
        """
        Initializes a TARSClassifier
        :param task_name: a string depicting the name of the task
        :param label_dictionary: dictionary of labels you want to predict
        :param label_type: string to identify the label type ('sentiment', 'topic', etc.)
        :param embeddings: name of the pre-trained transformer model, e.g. 'bert-base-uncased'
        :param num_negative_labels_to_sample: number of negative labels to sample for each
            positive label against a sentence during training. Defaults to 2 negative
            labels per positive label. If None is passed, the model samples all negative
            labels, which slows down training considerably.
        :param prefix: if True, the label text is prepended to the sentence; otherwise it is appended
        :param tagger_args: additional keyword arguments passed through to the underlying
            TextClassifier, e.g. multi_label (auto-detected by default; set to True or False to
            force multi-label or single-label prediction), multi_label_threshold (the prediction
            threshold when multi-label), and beta (the F-beta parameter used for evaluation and
            training annealing)
        """
        super(TARSClassifier, self).__init__()

        from flair.embeddings import TransformerDocumentEmbeddings

        if not isinstance(embeddings, TransformerDocumentEmbeddings):
            embeddings = TransformerDocumentEmbeddings(model=embeddings,
                                                       fine_tune=True,
                                                       layers='-1',
                                                       layer_mean=False,
                                                       )

        # prepare TARS dictionary
        tars_dictionary = Dictionary(add_unk=False)
        tars_dictionary.add_item(self.LABEL_NO_MATCH)
        tars_dictionary.add_item(self.LABEL_MATCH)

        # initialize a bare-bones text classifier
        self.tars_model = TextClassifier(document_embeddings=embeddings,
                                         label_dictionary=tars_dictionary,
                                         label_type=self.static_label_type,
                                         **tagger_args,
                                         )

        # transformer separator
        self.separator = str(self.tars_embeddings.tokenizer.sep_token)
        if self.tars_embeddings.tokenizer._bos_token:
            self.separator += str(self.tars_embeddings.tokenizer.bos_token)

        self.prefix = prefix
        self.num_negative_labels_to_sample = num_negative_labels_to_sample

        if task_name and label_dictionary and label_type:
            # store task-specific labels since TARS can handle multiple tasks
            self.add_and_switch_to_new_task(task_name, label_dictionary, label_type)
        else:
            log.info("TARS initialized without a task. You need to call .add_and_switch_to_new_task() "
                     "before training this model")

        self.clean_up_labels = True

    def _clean(self, label_value: str) -> str:
        if self.clean_up_labels:
            return label_value.replace("_", " ")
        else:
            return label_value

    def _get_tars_formatted_sentence(self, label, sentence):

        label = self._clean(label)

        original_text = sentence.to_tokenized_string()

        label_text_pair = f"{label} {self.separator} {original_text}" if self.prefix \
            else f"{original_text} {self.separator} {label}"

        sentence_labels = [self._clean(sentence_label.value)
                           for sentence_label in sentence.get_labels(self.get_current_label_type())]

        tars_label = self.LABEL_MATCH if label in sentence_labels else self.LABEL_NO_MATCH

        tars_sentence = Sentence(label_text_pair, use_tokenizer=False).add_label(self.static_label_type, tars_label)

        return tars_sentence

    def _get_state_dict(self):
        model_state = {
            "state_dict": self.state_dict(),

            "current_task": self._current_task,
            "label_type": self.get_current_label_type(),
            "label_dictionary": self.get_current_label_dictionary(),
            "tars_model": self.tars_model,
            "num_negative_labels_to_sample": self.num_negative_labels_to_sample,

            "task_specific_attributes": self._task_specific_attributes,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):

        # init a new TARS classifier
        label_dictionary = state["label_dictionary"]
        label_type = "default_label" if not state["label_type"] else state["label_type"]

        model: TARSClassifier = TARSClassifier(
            task_name=state["current_task"],
            label_dictionary=label_dictionary,
            label_type=label_type,
            embeddings=state["tars_model"].document_embeddings,
            num_negative_labels_to_sample=state["num_negative_labels_to_sample"],
        )

        # set all task information
        model._task_specific_attributes = state["task_specific_attributes"]

        # load the weights of the internal classifier
        model.load_state_dict(state["state_dict"])
        return model

    @staticmethod
    def _fetch_model(model_name) -> str:

        model_map = {}
        hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models"

        model_map["tars-base"] = "/".join([hu_path, "tars-base", "tars-base-v8.pt"])

        cache_dir = Path("models")
        if model_name in model_map:
            model_name = cached_path(model_map[model_name], cache_dir=cache_dir)

        return model_name

    @property
    def tars_embeddings(self):
        return self.tars_model.document_embeddings

    def predict(
            self,
            sentences: Union[List[Sentence], Sentence],
            mini_batch_size=32,
            verbose: bool = False,
            label_name: Optional[str] = None,
            return_loss=False,
            embedding_storage_mode="none",
            label_threshold: float = 0.5,
            multi_label: Optional[bool] = None,
    ):
        """
        Predicts class labels for the given sentences using the currently active TARS task
        :param sentences: a Sentence or a List of Sentence
        :param mini_batch_size: size of the minibatch; usually bigger is faster but consumes more memory,
            up to a point where it has no further effect.
        :param verbose: set to True to display a progress bar
        :param return_loss: set to True to return loss
        :param label_name: set this to change the name of the label type that is predicted
        :param embedding_storage_mode: default is 'none', which is always best. Only set to 'cpu' or 'gpu' if
            you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively.
        :param label_threshold: minimum confidence a candidate label needs in order to be added to a sentence
        :param multi_label: set to True or False to force multi-label or single-label prediction;
            defaults to the setting of the current task
        """
        if not label_name:
            label_name = self.get_current_label_type()

        if multi_label is None:
            multi_label = self.is_current_task_multi_label()

        if not sentences:
            return sentences

        if isinstance(sentences, Sentence):
            sentences = [sentences]

        # set context if not set already
        previous_sentence = None
        for sentence in sentences:
            if sentence.is_context_set():
                continue
            sentence._previous_sentence = previous_sentence
            sentence._next_sentence = None
            if previous_sentence:
                previous_sentence._next_sentence = sentence
            previous_sentence = sentence

        # reverse sort all sequences by their length
        rev_order_len_index = sorted(range(len(sentences)), key=lambda k: len(sentences[k]), reverse=True)

        reordered_sentences: List[Union[Sentence, str]] = [sentences[index] for index in rev_order_len_index]

        dataloader = DataLoader(dataset=SentenceDataset(reordered_sentences), batch_size=mini_batch_size)

        # progress bar for verbosity
        if verbose:
            dataloader = tqdm(dataloader)

        overall_loss = 0
        overall_count = 0
        batch_no = 0
        with torch.no_grad():
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(f"Inferencing on batch {batch_no}")

                batch = self._filter_empty_sentences(batch)
                # skip this batch if all sentences are empty
                if not batch:
                    continue

                # go through each sentence in the batch
                for sentence in batch:

                    # always remove tags first
                    sentence.remove_labels(label_name)

                    all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]

                    best_label = None
                    for label in all_labels:
                        tars_sentence = self._get_tars_formatted_sentence(label, sentence)

                        loss_and_count = self.tars_model.predict(tars_sentence,
                                                                 label_name=label_name,
                                                                 return_loss=True,
                                                                 return_probabilities_for_all_classes=label_threshold < 0.5,
                                                                 )

                        overall_loss += loss_and_count[0].item()
                        overall_count += loss_and_count[1]

                        # add all labels that according to TARS match the text and are above threshold
                        for predicted_tars_label in tars_sentence.get_labels(label_name):
                            if predicted_tars_label.value == self.LABEL_MATCH \
                                    and predicted_tars_label.score > label_threshold:
                                # do not add labels below the confidence threshold
                                sentence.add_label(label_name, label, predicted_tars_label.score)

                    # only use the label with the highest confidence if enforcing single-label predictions
                    if not multi_label:
                        if len(sentence.get_labels(label_name)) > 0:
                            # get all label scores and do an argmax to get the best label
                            label_scores = torch.tensor([label.score for label in sentence.get_labels(label_name)],
                                                        dtype=torch.float)
                            best_label = sentence.get_labels(label_name)[torch.argmax(label_scores)]

                            # remove previously added labels and only add the best label
                            sentence.remove_labels(label_name)
                            sentence.add_label(typename=label_name, value=best_label.value, score=best_label.score)

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

        if return_loss:
            return overall_loss, overall_count
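
# Not part of the original module: a hedged usage sketch for TARSClassifier.predict above, using
# the pre-trained 'tars-base' model from _fetch_model. The sentence is a made-up example and the
# task name is only valid if such a task is stored in the loaded checkpoint.
#
# >>> from flair.data import Sentence
# >>> from flair.models import TARSClassifier
# >>> tars = TARSClassifier.load('tars-base')
# >>> tars.list_existing_tasks()           # check which tasks the checkpoint provides
# >>> tars.switch_to_task("GO_EMOTIONS")   # assumes this task exists in the checkpoint
# >>> sentence = Sentence("I am so glad you liked it!")
# >>> tars.predict(sentence)
# >>> print(sentence.get_labels())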