Coverage for flair/flair/models/tars_model.py: 17%

import logging
import random
from collections import OrderedDict
from pathlib import Path
from typing import Union, List, Set, Optional

import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import minmax_scale
from tqdm import tqdm

import flair
from flair.data import Dictionary, Sentence
from flair.datasets import SentenceDataset, DataLoader
from flair.embeddings import TokenEmbeddings
from flair.file_utils import cached_path
from flair.models import SequenceTagger, TextClassifier
from flair.training_utils import store_embeddings

log = logging.getLogger("flair")


class FewshotClassifier(flair.nn.Classifier):

    def __init__(self):
        self._current_task = None
        self._task_specific_attributes = {}
        self.label_nearest_map = None
        self.clean_up_labels: bool = True

        super(FewshotClassifier, self).__init__()

    def forward_loss(
            self, data_points: Union[List[Sentence], Sentence]
    ) -> torch.Tensor:

        if isinstance(data_points, Sentence):
            data_points = [data_points]

        # Transform input data into TARS format
        sentences = self._get_tars_formatted_sentences(data_points)

        loss = self.tars_model.forward_loss(sentences)
        return loss

    @property
    def tars_embeddings(self):
        raise NotImplementedError

    def _get_tars_formatted_sentence(self, label, sentence):
        raise NotImplementedError

    def _get_tars_formatted_sentences(self, sentences: List[Sentence]):
        label_text_pairs = []
        all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]
        for sentence in sentences:
            label_text_pairs_for_sentence = []
            if self.training and self.num_negative_labels_to_sample is not None:

                positive_labels = list(OrderedDict.fromkeys(
                    [label.value for label in sentence.get_labels(self.label_type)]))

                sampled_negative_labels = self._get_nearest_labels_for(positive_labels)

                for label in positive_labels:
                    label_text_pairs_for_sentence.append(self._get_tars_formatted_sentence(label, sentence))
                for label in sampled_negative_labels:
                    label_text_pairs_for_sentence.append(self._get_tars_formatted_sentence(label, sentence))

            else:
                for label in all_labels:
                    label_text_pairs_for_sentence.append(self._get_tars_formatted_sentence(label, sentence))
            label_text_pairs.extend(label_text_pairs_for_sentence)

        return label_text_pairs

    def _get_nearest_labels_for(self, labels):

        # if there are no labels, return a random sample as negatives
        if len(labels) == 0:
            tags = self.get_current_label_dictionary().get_items()
            sample = random.sample(tags, k=self.num_negative_labels_to_sample)
            return sample

        already_sampled_negative_labels = set()

        # otherwise, go through all labels
        for label in labels:

            plausible_labels = []
            plausible_label_probabilities = []
            for plausible_label in self.label_nearest_map[label]:
                if plausible_label in already_sampled_negative_labels or plausible_label in labels:
                    continue
                else:
                    plausible_labels.append(plausible_label)
                    plausible_label_probabilities.append(self.label_nearest_map[label][plausible_label])

            # make sure the probabilities always sum up to 1
            plausible_label_probabilities = np.array(plausible_label_probabilities, dtype='float64')
            plausible_label_probabilities += 1e-08
            plausible_label_probabilities /= np.sum(plausible_label_probabilities)

            if len(plausible_labels) > 0:
                num_samples = min(self.num_negative_labels_to_sample, len(plausible_labels))
                sampled_negative_labels = np.random.choice(plausible_labels,
                                                           num_samples,
                                                           replace=False,
                                                           p=plausible_label_probabilities)
                already_sampled_negative_labels.update(sampled_negative_labels)

        return already_sampled_negative_labels
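
    # --- Illustrative sketch (not part of the original module; names and values are invented) ---
    # `_get_nearest_labels_for` draws negatives proportionally to the similarity
    # scores stored in `label_nearest_map`. A standalone toy version of that
    # sampling step could look like this:
    #
    #     import numpy as np
    #     nearest = {"politics": 0.2, "business": 0.3, "science": 0.5}   # fake similarities
    #     candidates = list(nearest)
    #     probs = np.array(list(nearest.values()), dtype="float64")
    #     probs /= probs.sum()                                           # normalize to 1
    #     negatives = np.random.choice(candidates, size=2, replace=False, p=probs)
    #     # e.g. array(['science', 'business'], dtype='<U8')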

    def train(self, mode=True):
        """Populate the label similarity map based on cosine similarity before running an epoch.

        If `num_negative_labels_to_sample` is set to an integer value, then before starting
        each epoch the model creates a similarity measure between the label names based
        on cosine distances between their BERT-encoded embeddings.
        """
        if mode and self.num_negative_labels_to_sample is not None:
            self._compute_label_similarity_for_current_epoch()

        super().train(mode)

    def _compute_label_similarity_for_current_epoch(self):
        """
        Compute the similarity between all labels for better sampling of negatives
        """

        # get and embed all labels by making a Sentence object that contains only the label text
        all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]
        label_sentences = [Sentence(label) for label in all_labels]

        self.tars_embeddings.eval()  # TODO: check if this is necessary
        self.tars_embeddings.embed(label_sentences)
        self.tars_embeddings.train()

        # get each label embedding and scale between 0 and 1
        if isinstance(self.tars_embeddings, TokenEmbeddings):
            encodings_np = [sentence[0].get_embedding().cpu().detach().numpy() for sentence in label_sentences]
        else:
            encodings_np = [sentence.get_embedding().cpu().detach().numpy() for sentence in label_sentences]

        normalized_encoding = minmax_scale(encodings_np)

        # compute similarity matrix
        similarity_matrix = cosine_similarity(normalized_encoding)

        # the higher the similarity, the greater the chance that a label is
        # sampled as a negative example
        negative_label_probabilities = {}
        for row_index, label in enumerate(all_labels):
            negative_label_probabilities[label] = {}
            for column_index, other_label in enumerate(all_labels):
                if label != other_label:
                    negative_label_probabilities[label][other_label] = \
                        similarity_matrix[row_index][column_index]
        self.label_nearest_map = negative_label_probabilities
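
    # --- Illustrative sketch (not part of the original module; vectors are invented) ---
    # The map built above has the form {label: {other_label: similarity}}. With
    # toy 2-d "embeddings" the same computation looks like this:
    #
    #     import numpy as np
    #     from sklearn.metrics.pairwise import cosine_similarity
    #     from sklearn.preprocessing import minmax_scale
    #
    #     labels = ["sports", "politics", "business"]
    #     fake_encodings = np.array([[0.9, 0.1], [0.2, 0.8], [0.4, 0.6]])
    #     sims = cosine_similarity(minmax_scale(fake_encodings))
    #     nearest_map = {a: {b: sims[i][j] for j, b in enumerate(labels) if a != b}
    #                    for i, a in enumerate(labels)}
    #     # nearest_map["sports"] -> {"politics": ..., "business": ...}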

    def get_current_label_dictionary(self):
        label_dictionary = self._task_specific_attributes[self._current_task]['label_dictionary']
        if self.clean_up_labels:
            # default: make new dictionary with modified labels (no underscores)
            dictionary = Dictionary(add_unk=False)
            for label in label_dictionary.get_items():
                dictionary.add_item(label.replace("_", " "))
            return dictionary
        else:
            return label_dictionary

    def get_current_label_type(self):
        return self._task_specific_attributes[self._current_task]['label_type']

    def is_current_task_multi_label(self):
        return self._task_specific_attributes[self._current_task]['multi_label']

    def add_and_switch_to_new_task(self,
                                   task_name,
                                   label_dictionary: Union[List, Set, Dictionary, str],
                                   label_type: str,
                                   multi_label: bool = True,
                                   force_switch: bool = False,
                                   ):
        """
        Adds a new task to an existing TARS model. Sets the necessary attributes and finally 'switches'
        to the new task. Parameters are similar to the constructor except for model choice, batch
        size and negative sampling. This method does not store the resulting model on disk.
        :param task_name: a string depicting the name of the task
        :param label_dictionary: dictionary of the labels you want to predict
        :param label_type: string to identify the label type ('ner', 'sentiment', etc.)
        :param multi_label: whether this task is a multi-label prediction problem
        :param force_switch: if True, will overwrite an existing task with the same name
        """
        if task_name in self._task_specific_attributes and not force_switch:
            log.warning("Task `%s` already exists in TARS model. Switching to it.", task_name)
        else:
            # make label list if no Dictionary object is passed
            if isinstance(label_dictionary, Dictionary):
                label_dictionary = label_dictionary.get_items()
            if isinstance(label_dictionary, str):
                label_dictionary = [label_dictionary]

            # prepare dictionary of tags (without B-/I- prefixes and without UNK)
            tag_dictionary = Dictionary(add_unk=False)
            for tag in label_dictionary:
                if tag in ('<unk>', 'O'):
                    continue
                if len(tag) > 2 and tag[1] == "-":
                    tag = tag[2:]
                tag_dictionary.add_item(tag)

            self._task_specific_attributes[task_name] = {'label_dictionary': tag_dictionary,
                                                         'label_type': label_type,
                                                         'multi_label': multi_label}

        self.switch_to_task(task_name)
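
    # --- Illustrative usage sketch (not part of the original module; task and labels are invented) ---
    # Registering a new few-shot task on an already constructed TARS model:
    #
    #     tars.add_and_switch_to_new_task(task_name="question classification",
    #                                     label_dictionary=["question about product",
    #                                                       "question about shipping"],
    #                                     label_type="question_class")
    #     print(tars.list_existing_tasks())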

    def list_existing_tasks(self) -> Set[str]:
        """
        Returns the set of task names registered in this TARS model.
        """
        return set(self._task_specific_attributes.keys())

    def switch_to_task(self, task_name):
        """
        Switches to a task which was previously added.
        """
        if task_name not in self._task_specific_attributes:
            log.error("Provided `%s` does not exist in the model. Consider calling "
                      "`add_and_switch_to_new_task` first.", task_name)
        else:
            self._current_task = task_name

    def _drop_task(self, task_name):
        if task_name in self._task_specific_attributes:
            if self._current_task == task_name:
                log.error("`%s` is the current task."
                          " Switch to some other task before dropping this.", task_name)
            else:
                self._task_specific_attributes.pop(task_name)
        else:
            log.warning("No task exists with the name `%s`.", task_name)

    @staticmethod
    def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
        filtered_sentences = [sentence for sentence in sentences if sentence.tokens]
        if len(sentences) != len(filtered_sentences):
            log.warning(
                f"Ignoring {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens."
            )
        return filtered_sentences

    @property
    def label_type(self):
        return self.get_current_label_type()

    def predict_zero_shot(self,
                          sentences: Union[List[Sentence], Sentence],
                          candidate_label_set: Union[List[str], Set[str], str],
                          multi_label: bool = True):
        """
        Make zero-shot predictions with the TARS model.
        :param sentences: input sentence objects to classify
        :param candidate_label_set: set of candidate labels
        :param multi_label: indicates whether multi-label or single-class prediction. Defaults to True.
        """

        # check if candidate_label_set is empty
        if candidate_label_set is None or len(candidate_label_set) == 0:
            log.warning("Provided candidate_label_set is empty")
            return

        # make a set if only one candidate label is passed
        if isinstance(candidate_label_set, str):
            candidate_label_set = {candidate_label_set}

        # create label dictionary
        label_dictionary = Dictionary(add_unk=False)
        for label in candidate_label_set:
            label_dictionary.add_item(label)

        # note the current task
        existing_current_task = self._current_task

        # create a temporary task
        self.add_and_switch_to_new_task(task_name="ZeroShot",
                                        label_dictionary=label_dictionary,
                                        label_type='-'.join(label_dictionary.get_items()),
                                        multi_label=multi_label)

        try:
            # make zero-shot predictions
            self.predict(sentences)
        finally:
            # switch back to the pre-existing task
            self.switch_to_task(existing_current_task)
            self._drop_task("ZeroShot")
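
    # --- Illustrative usage sketch (not part of the original module; text and labels are invented) ---
    # `predict_zero_shot` wraps `predict()` in a temporary "ZeroShot" task built
    # from the candidate labels and then restores the previous task:
    #
    #     sentence = Sentence("I really enjoyed this movie")
    #     tars.predict_zero_shot(sentence, {"happy", "sad"}, multi_label=False)
    #     print(sentence.get_labels())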


class TARSTagger(FewshotClassifier):
    """
    TARS model for sequence tagging. In the backend, the model uses a BERT-based 5-class
    sequence labeler which, given a <label, text> pair, predicts the probability of each word
    belonging to one of the BIOES classes. The input data is a usual Sentence object which is
    inflated by the model internally before being pushed through the transformer stack of BERT.
    """

    static_label_type = "tars_label"

    def __init__(
            self,
            task_name: Optional[str] = None,
            label_dictionary: Optional[Dictionary] = None,
            label_type: Optional[str] = None,
            embeddings: str = 'bert-base-uncased',
            num_negative_labels_to_sample: int = 2,
            prefix: bool = True,
            **tagger_args,
    ):
        """
        Initializes a TARSTagger.
        :param task_name: a string depicting the name of the task
        :param label_dictionary: dictionary of labels you want to predict
        :param label_type: string to identify the label type ('ner', 'pos', etc.)
        :param embeddings: name of the pre-trained transformer model, e.g. 'bert-base-uncased'
        :param num_negative_labels_to_sample: number of negative labels to sample for each
            positive label of a sentence during training. Defaults to 2 negative labels per
            positive label. If None is passed, the model samples all negative labels, which
            slows down training considerably.
        :param prefix: if True, the label text is prepended to the sentence, otherwise appended
        """
        super(TARSTagger, self).__init__()

        from flair.embeddings import TransformerWordEmbeddings

        if not isinstance(embeddings, TransformerWordEmbeddings):
            embeddings = TransformerWordEmbeddings(model=embeddings,
                                                   fine_tune=True,
                                                   layers='-1',
                                                   layer_mean=False,
                                                   )

        # prepare TARS dictionary
        tars_dictionary = Dictionary(add_unk=False)
        tars_dictionary.add_item('O')
        tars_dictionary.add_item('S-')
        tars_dictionary.add_item('B-')
        tars_dictionary.add_item('E-')
        tars_dictionary.add_item('I-')

        # initialize a bare-bones sequence tagger (the first positional argument is the
        # RNN hidden size, which is not used here since use_rnn=False)
        self.tars_model = SequenceTagger(123,
                                         embeddings,
                                         tag_dictionary=tars_dictionary,
                                         tag_type=self.static_label_type,
                                         use_crf=False,
                                         use_rnn=False,
                                         reproject_embeddings=False,
                                         **tagger_args,
                                         )

        # transformer separator
        self.separator = str(self.tars_embeddings.tokenizer.sep_token)
        if self.tars_embeddings.tokenizer._bos_token:
            self.separator += str(self.tars_embeddings.tokenizer.bos_token)

        self.prefix = prefix
        self.num_negative_labels_to_sample = num_negative_labels_to_sample

        if task_name and label_dictionary and label_type:
            # Store task specific labels since TARS can handle multiple tasks
            self.add_and_switch_to_new_task(task_name, label_dictionary, label_type)
        else:
            log.info("TARS initialized without a task. You need to call .add_and_switch_to_new_task() "
                     "before training this model")

    def _get_tars_formatted_sentence(self, label, sentence):

        original_text = sentence.to_tokenized_string()

        label_text_pair = f"{label} {self.separator} {original_text}" if self.prefix \
            else f"{original_text} {self.separator} {label}"

        label_length = 0 if not self.prefix else len(label.split(" ")) + len(self.separator.split(" "))

        # make a tars sentence where all labels are O by default
        tars_sentence = Sentence(label_text_pair, use_tokenizer=False)
        for token in tars_sentence:
            token.add_tag(self.static_label_type, "O")

        # overwrite O labels with tags
        for token in sentence:
            tag = token.get_tag(self.get_current_label_type()).value

            if tag == "O" or tag == "":
                tars_tag = "O"
            elif tag == label:
                tars_tag = "S-"
            elif len(tag) > 2 and tag[1] == "-" and tag[2:] == label:
                tars_tag = tag.split('-')[0] + '-'
            else:
                tars_tag = "O"

            tars_sentence.get_token(token.idx + label_length).add_tag(self.static_label_type, tars_tag)

        return tars_sentence
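
    # --- Illustrative example (not part of the original module; text, label and tags are invented) ---
    # With prefix=True and a BERT separator, a sentence such as
    #     "George Washington went to Washington"
    # (with "George Washington" tagged B-person / E-person) is inflated for the
    # candidate label "person" roughly into
    #     "person [SEP] George Washington went to Washington"
    # and the tars_label sequence over those seven tokens becomes
    #     O O B- E- O O O
    # i.e. only tokens whose original tag matches the queried label keep a
    # prefix-only BIOES tag; everything else is "O".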

    def _get_state_dict(self):
        model_state = {
            "state_dict": self.state_dict(),
            "current_task": self._current_task,
            "tag_type": self.get_current_label_type(),
            "tag_dictionary": self.get_current_label_dictionary(),
            "tars_model": self.tars_model,
            "num_negative_labels_to_sample": self.num_negative_labels_to_sample,
            "prefix": self.prefix,
            "task_specific_attributes": self._task_specific_attributes,
        }
        return model_state

    @staticmethod
    def _fetch_model(model_name) -> str:

        if model_name == "tars-ner":
            cache_dir = Path("models")
            model_name = cached_path("https://nlp.informatik.hu-berlin.de/resources/models/tars-ner/tars-ner.pt",
                                     cache_dir=cache_dir)

        return model_name

    @staticmethod
    def _init_model_with_state_dict(state):

        # init new TARS tagger
        model = TARSTagger(
            task_name=state["current_task"],
            label_dictionary=state["tag_dictionary"],
            label_type=state["tag_type"],
            embeddings=state["tars_model"].embeddings,
            num_negative_labels_to_sample=state["num_negative_labels_to_sample"],
            prefix=state["prefix"],
        )
        # set all task information
        model._task_specific_attributes = state["task_specific_attributes"]

        # linear layers of internal classifier
        model.load_state_dict(state["state_dict"])
        return model

    @property
    def tars_embeddings(self):
        return self.tars_model.embeddings

    def predict(
            self,
            sentences: Union[List[Sentence], Sentence],
            mini_batch_size=32,
            verbose: bool = False,
            label_name: Optional[str] = None,
            return_loss=False,
            embedding_storage_mode="none",
            most_probable_first: bool = True
    ):
        """
        Predict sequence tags for the Named Entity Recognition task.
        :param sentences: a Sentence or a List of Sentence
        :param mini_batch_size: size of the minibatch; larger batches are usually faster but consume
            more memory, up to a point where larger values have no further effect
        :param verbose: set to True to display a progress bar
        :param label_name: set this to change the name of the label type that is predicted
        :param return_loss: set to True to return loss
        :param embedding_storage_mode: default is 'none' which is always best. Only set to 'cpu' or 'gpu' if
            you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively.
        :param most_probable_first: if True, spans are assigned in order of decreasing confidence so that
            tokens already covered by a more confident span are not overwritten
        """
        if label_name is None:
            label_name = self.get_current_label_type()

        if not sentences:
            return sentences

        if isinstance(sentences, Sentence):
            sentences = [sentences]

        # reverse sort all sequences by their length
        rev_order_len_index = sorted(range(len(sentences)), key=lambda k: len(sentences[k]), reverse=True)

        reordered_sentences: List[Union[Sentence, str]] = [sentences[index] for index in rev_order_len_index]

        dataloader = DataLoader(dataset=SentenceDataset(reordered_sentences), batch_size=mini_batch_size)

        # progress bar for verbosity
        if verbose:
            dataloader = tqdm(dataloader)

        overall_loss = 0
        overall_count = 0
        batch_no = 0
        with torch.no_grad():
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(f"Inferencing on batch {batch_no}")

                batch = self._filter_empty_sentences(batch)
                # skip the batch if all its sentences are empty
                if not batch:
                    continue

                # go through each sentence in the batch
                for sentence in batch:

                    # always remove tags first
                    for token in sentence:
                        token.remove_labels(label_name)

                    all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]

                    all_detected = {}
                    for label in all_labels:
                        tars_sentence = self._get_tars_formatted_sentence(label, sentence)

                        label_length = 0 if not self.prefix else len(label.split(" ")) + len(self.separator.split(" "))

                        loss_and_count = self.tars_model.predict(tars_sentence,
                                                                 label_name=label_name,
                                                                 all_tag_prob=True,
                                                                 return_loss=True)
                        overall_loss += loss_and_count[0].item()
                        overall_count += loss_and_count[1]

                        for span in tars_sentence.get_spans(label_name):
                            span.set_label('tars_temp_label', label)
                            all_detected[span] = span.score

                        if not most_probable_first:
                            for span in tars_sentence.get_spans(label_name):
                                for token in span:
                                    corresponding_token = sentence.get_token(token.idx - label_length)
                                    if corresponding_token is None:
                                        continue
                                    if corresponding_token.get_tag(label_name).value != '' and \
                                            corresponding_token.get_tag(label_name).score > token.get_tag(label_name).score:
                                        continue
                                    corresponding_token.add_tag(
                                        label_name,
                                        token.get_tag(label_name).value + label,
                                        token.get_tag(label_name).score,
                                    )

                    if most_probable_first:
                        # process detected spans in order of decreasing confidence
                        sorted_spans = sorted(all_detected.items(), key=lambda kv: kv[1], reverse=True)
                        for span, _score in sorted_spans:
                            # get the span's label
                            label = span.get_labels('tars_temp_label')[0].value
                            label_length = 0 if not self.prefix else len(label.split(" ")) + len(
                                self.separator.split(" "))

                            # determine whether tokens in this span already have a label
                            tag_this = True
                            for token in span:
                                corresponding_token = sentence.get_token(token.idx - label_length)
                                if corresponding_token is None:
                                    tag_this = False
                                    continue
                                if corresponding_token.get_tag(label_name).value != '' and \
                                        corresponding_token.get_tag(label_name).score > token.get_tag(label_name).score:
                                    tag_this = False
                                    continue

                            # only add if no token has a label yet
                            if tag_this:
                                for token in span:
                                    corresponding_token = sentence.get_token(token.idx - label_length)
                                    corresponding_token.add_tag(
                                        label_name,
                                        token.get_tag(label_name).value + label,
                                        token.get_tag(label_name).score,
                                    )

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

        if return_loss:
            return overall_loss, overall_count
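
    # --- Illustrative usage sketch (not part of the original module; entity classes and text are invented) ---
    # Zero-shot NER with the released tars-ner model (loading triggers a download):
    #
    #     tars_tagger = TARSTagger.load("tars-ner")
    #     tars_tagger.add_and_switch_to_new_task("zero-shot NER",
    #                                            ["vehicle", "city"], label_type="ner")
    #     sentence = Sentence("The Tesla sped through Berlin at night")
    #     tars_tagger.predict(sentence)
    #     print(sentence.to_tagged_string("ner"))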


class TARSClassifier(FewshotClassifier):
    """
    TARS model for text classification. In the backend, the model uses a BERT-based binary
    text classifier which, given a <label, text> pair, predicts whether the label matches
    the text (classes "YES" and "NO"). The input data is a usual Sentence object which is
    inflated by the model internally before being pushed through the transformer stack of BERT.
    """

    static_label_type = "tars_label"
    LABEL_MATCH = "YES"
    LABEL_NO_MATCH = "NO"

    def __init__(
            self,
            task_name: Optional[str] = None,
            label_dictionary: Optional[Dictionary] = None,
            label_type: Optional[str] = None,
            embeddings: str = 'bert-base-uncased',
            num_negative_labels_to_sample: int = 2,
            prefix: bool = True,
            **tagger_args,
    ):
        """
        Initializes a TARSClassifier.
        :param task_name: a string depicting the name of the task
        :param label_dictionary: dictionary of labels you want to predict
        :param label_type: string to identify the label type ('sentiment', 'topic', etc.)
        :param embeddings: name of the pre-trained transformer model, e.g. 'bert-base-uncased'
        :param num_negative_labels_to_sample: number of negative labels to sample for each
            positive label of a sentence during training. Defaults to 2 negative labels per
            positive label. If None is passed, the model samples all negative labels, which
            slows down training considerably.
        :param prefix: if True, the label text is prepended to the sentence, otherwise appended
        Remaining keyword arguments (e.g. multi_label, multi_label_threshold, beta) are passed
        through to the internal TextClassifier.
        """
        super(TARSClassifier, self).__init__()

        from flair.embeddings import TransformerDocumentEmbeddings

        if not isinstance(embeddings, TransformerDocumentEmbeddings):
            embeddings = TransformerDocumentEmbeddings(model=embeddings,
                                                       fine_tune=True,
                                                       layers='-1',
                                                       layer_mean=False,
                                                       )

        # prepare TARS dictionary
        tars_dictionary = Dictionary(add_unk=False)
        tars_dictionary.add_item(self.LABEL_NO_MATCH)
        tars_dictionary.add_item(self.LABEL_MATCH)

        # initialize a bare-bones text classifier
        self.tars_model = TextClassifier(document_embeddings=embeddings,
                                         label_dictionary=tars_dictionary,
                                         label_type=self.static_label_type,
                                         **tagger_args,
                                         )

        # transformer separator
        self.separator = str(self.tars_embeddings.tokenizer.sep_token)
        if self.tars_embeddings.tokenizer._bos_token:
            self.separator += str(self.tars_embeddings.tokenizer.bos_token)

        self.prefix = prefix
        self.num_negative_labels_to_sample = num_negative_labels_to_sample

        if task_name and label_dictionary and label_type:
            # Store task specific labels since TARS can handle multiple tasks
            self.add_and_switch_to_new_task(task_name, label_dictionary, label_type)
        else:
            log.info("TARS initialized without a task. You need to call .add_and_switch_to_new_task() "
                     "before training this model")

    def _get_tars_formatted_sentence(self, label, sentence):

        original_text = sentence.to_tokenized_string()

        label_text_pair = f"{label} {self.separator} {original_text}" if self.prefix \
            else f"{original_text} {self.separator} {label}"

        sentence_labels = [sentence_label.value for sentence_label in
                           sentence.get_labels(self.get_current_label_type())]

        tars_label = self.LABEL_MATCH if label in sentence_labels else self.LABEL_NO_MATCH

        tars_sentence = Sentence(label_text_pair, use_tokenizer=False).add_label(self.static_label_type, tars_label)

        return tars_sentence
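
    # --- Illustrative example (not part of the original module; text and labels are invented) ---
    # With prefix=True, a sentence "The plot was gripping" with gold label "positive"
    # is turned into one binary example per candidate label:
    #
    #     "positive [SEP] The plot was gripping"  ->  tars_label = YES
    #     "negative [SEP] The plot was gripping"  ->  tars_label = NO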

    def _get_state_dict(self):
        model_state = {
            "state_dict": self.state_dict(),
            "current_task": self._current_task,
            "label_type": self.get_current_label_type(),
            "label_dictionary": self.get_current_label_dictionary(),
            "tars_model": self.tars_model,
            "num_negative_labels_to_sample": self.num_negative_labels_to_sample,
            "task_specific_attributes": self._task_specific_attributes,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):

        # init new TARS classifier
        label_dictionary = state["label_dictionary"]
        label_type = "default_label" if not state["label_type"] else state["label_type"]

        model: TARSClassifier = TARSClassifier(
            task_name=state["current_task"],
            label_dictionary=label_dictionary,
            label_type=label_type,
            embeddings=state["tars_model"].document_embeddings,
            num_negative_labels_to_sample=state["num_negative_labels_to_sample"],
        )

        # set all task information
        model._task_specific_attributes = state["task_specific_attributes"]

        # linear layers of internal classifier
        model.load_state_dict(state["state_dict"])
        return model

    @staticmethod
    def _fetch_model(model_name) -> str:

        model_map = {}
        hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models"

        model_map["tars-base"] = "/".join([hu_path, "tars-base", "tars-base-v8.pt"])

        cache_dir = Path("models")
        if model_name in model_map:
            model_name = cached_path(model_map[model_name], cache_dir=cache_dir)

        return model_name

    @property
    def tars_embeddings(self):
        return self.tars_model.document_embeddings

    def predict(
            self,
            sentences: Union[List[Sentence], Sentence],
            mini_batch_size=32,
            verbose: bool = False,
            label_name: Optional[str] = None,
            return_loss=False,
            embedding_storage_mode="none",
            label_threshold: float = 0.5,
            multi_label: Optional[bool] = None,
    ):
        """
        Predict class labels for the text classification task.
        :param sentences: a Sentence or a List of Sentence
        :param mini_batch_size: size of the minibatch; larger batches are usually faster but consume
            more memory, up to a point where larger values have no further effect
        :param verbose: set to True to display a progress bar
        :param label_name: set this to change the name of the label type that is predicted
        :param return_loss: set to True to return loss
        :param embedding_storage_mode: default is 'none' which is always best. Only set to 'cpu' or 'gpu' if
            you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively.
        :param label_threshold: minimum confidence a candidate label needs to be added to a sentence
        :param multi_label: if None, defaults to the multi-label setting of the current task
        """
        if not label_name:
            label_name = self.get_current_label_type()

        if multi_label is None:
            multi_label = self.is_current_task_multi_label()

        if not sentences:
            return sentences

        if isinstance(sentences, Sentence):
            sentences = [sentences]

        # set context if not set already
        previous_sentence = None
        for sentence in sentences:
            if sentence.is_context_set():
                continue
            sentence._previous_sentence = previous_sentence
            sentence._next_sentence = None
            if previous_sentence:
                previous_sentence._next_sentence = sentence
            previous_sentence = sentence

        # reverse sort all sequences by their length
        rev_order_len_index = sorted(range(len(sentences)), key=lambda k: len(sentences[k]), reverse=True)

        reordered_sentences: List[Union[Sentence, str]] = [sentences[index] for index in rev_order_len_index]

        dataloader = DataLoader(dataset=SentenceDataset(reordered_sentences), batch_size=mini_batch_size)

        # progress bar for verbosity
        if verbose:
            dataloader = tqdm(dataloader)

        overall_loss = 0
        overall_count = 0
        batch_no = 0
        with torch.no_grad():
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(f"Inferencing on batch {batch_no}")

                batch = self._filter_empty_sentences(batch)
                # skip the batch if all its sentences are empty
                if not batch:
                    continue

                # go through each sentence in the batch
                for sentence in batch:

                    # always remove tags first
                    sentence.remove_labels(label_name)

                    all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]

                    for label in all_labels:
                        tars_sentence = self._get_tars_formatted_sentence(label, sentence)

                        loss_and_count = self.tars_model.predict(tars_sentence,
                                                                 label_name=label_name,
                                                                 return_loss=True,
                                                                 return_probabilities_for_all_classes=label_threshold < 0.5,
                                                                 )

                        overall_loss += loss_and_count[0].item()
                        overall_count += loss_and_count[1]

                        # add all labels that according to TARS match the text and are above threshold
                        for predicted_tars_label in tars_sentence.get_labels(label_name):
                            if predicted_tars_label.value == self.LABEL_MATCH \
                                    and predicted_tars_label.score > label_threshold:
                                # do not add labels below the confidence threshold
                                sentence.add_label(label_name, label, predicted_tars_label.score)

                    # only use the label with the highest confidence if enforcing single-label predictions
                    if not multi_label:
                        if len(sentence.get_labels(label_name)) > 0:

                            # get all label scores and do an argmax to get the best label
                            label_scores = torch.tensor([label.score for label in sentence.get_labels(label_name)],
                                                        dtype=torch.float)
                            best_label = sentence.get_labels(label_name)[torch.argmax(label_scores)]

                            # remove previously added labels and only add the best label
                            sentence.remove_labels(label_name)
                            sentence.add_label(typename=label_name, value=best_label.value, score=best_label.score)

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

        if return_loss:
            return overall_loss, overall_count
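

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal end-to-end example with the released "tars-base" classifier; the
# candidate labels and sentence are invented, and loading the model triggers a
# download from the URL registered in `_fetch_model`.
if __name__ == "__main__":
    # load the pre-trained zero-shot classification model
    tars: TARSClassifier = TARSClassifier.load("tars-base")

    # classify a sentence against an ad-hoc label set without any training
    sentence = Sentence("The food was cold and the service was slow")
    tars.predict_zero_shot(sentence, {"positive", "negative"}, multi_label=False)
    print(sentence.get_labels())

    # the tasks stored inside the loaded model can be inspected and switched
    print(tars.list_existing_tasks())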