import logging
import random
from collections import OrderedDict
from pathlib import Path
from typing import Union, List, Set, Optional

import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import minmax_scale
from tqdm import tqdm

import flair
from flair.data import Dictionary, Sentence
from flair.datasets import SentenceDataset, DataLoader
from flair.embeddings import TokenEmbeddings
from flair.file_utils import cached_path
from flair.models import SequenceTagger, TextClassifier
from flair.training_utils import store_embeddings

log = logging.getLogger("flair")


class FewshotClassifier(flair.nn.Classifier):

    def __init__(self):
        self._current_task = None
        self._task_specific_attributes = {}
        self.label_nearest_map = None

        super(FewshotClassifier, self).__init__()

    def forward_loss(
            self, data_points: Union[List[Sentence], Sentence]
    ) -> torch.tensor:

        if isinstance(data_points, Sentence):
            data_points = [data_points]

        # Transform input data into TARS format
        sentences = self._get_tars_formatted_sentences(data_points)

        loss = self.tars_model.forward_loss(sentences)
        return loss

    @property
    def tars_embeddings(self):
        raise NotImplementedError

    def _get_tars_formatted_sentence(self, label, sentence):
        raise NotImplementedError

    def _get_tars_formatted_sentences(self, sentences: List[Sentence]):
        label_text_pairs = []
        all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]
        for sentence in sentences:
            label_text_pairs_for_sentence = []
            if self.training and self.num_negative_labels_to_sample is not None:

                positive_labels = list(OrderedDict.fromkeys(
                    [label.value for label in sentence.get_labels(self.label_type)]))

                sampled_negative_labels = self._get_nearest_labels_for(positive_labels)

                for label in positive_labels:
                    label_text_pairs_for_sentence.append(self._get_tars_formatted_sentence(label, sentence))
                for label in sampled_negative_labels:
                    label_text_pairs_for_sentence.append(self._get_tars_formatted_sentence(label, sentence))

            else:
                for label in all_labels:
                    label_text_pairs_for_sentence.append(self._get_tars_formatted_sentence(label, sentence))
            label_text_pairs.extend(label_text_pairs_for_sentence)

        return label_text_pairs

    def _get_nearest_labels_for(self, labels):

        # if there are no labels, return a random sample as negatives
        if len(labels) == 0:
            tags = self.get_current_label_dictionary().get_items()
            sample = random.sample(tags, k=self.num_negative_labels_to_sample)
            return sample

        already_sampled_negative_labels = set()

        # otherwise, go through all labels
        for label in labels:

            plausible_labels = []
            plausible_label_probabilities = []
            for plausible_label in self.label_nearest_map[label]:
                if plausible_label in already_sampled_negative_labels or plausible_label in labels:
                    continue
                else:
                    plausible_labels.append(plausible_label)
                    plausible_label_probabilities.append(self.label_nearest_map[label][plausible_label])

            # make sure the probabilities always sum up to 1
            plausible_label_probabilities = np.array(plausible_label_probabilities, dtype='float64')
            plausible_label_probabilities += 1e-08
            plausible_label_probabilities /= np.sum(plausible_label_probabilities)

            if len(plausible_labels) > 0:
                num_samples = min(self.num_negative_labels_to_sample, len(plausible_labels))
                sampled_negative_labels = np.random.choice(plausible_labels,
                                                           num_samples,
                                                           replace=False,
                                                           p=plausible_label_probabilities)
                already_sampled_negative_labels.update(sampled_negative_labels)

        return already_sampled_negative_labels

    def train(self, mode=True):
        """Populate the label similarity map based on cosine similarity before running an epoch.

        If `num_negative_labels_to_sample` is set to an integer value, then before starting
        each epoch the model computes a similarity measure between the label names based
        on cosine distances between their BERT-encoded embeddings.
        """
        if mode and self.num_negative_labels_to_sample is not None:
            self._compute_label_similarity_for_current_epoch()

        super().train(mode)

    def _compute_label_similarity_for_current_epoch(self):
        """
        Compute the similarity between all labels for better sampling of negatives
        """

        # get and embed all labels by making a Sentence object that contains only the label text
        all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]
        label_sentences = [Sentence(label) for label in all_labels]

        self.tars_embeddings.eval()  # TODO: check if this is necessary
        self.tars_embeddings.embed(label_sentences)
        self.tars_embeddings.train()

        # get each label embedding and scale between 0 and 1
        if isinstance(self.tars_embeddings, TokenEmbeddings):
            encodings_np = [sentence[0].get_embedding().cpu().detach().numpy() for sentence in label_sentences]
        else:
            encodings_np = [sentence.get_embedding().cpu().detach().numpy() for sentence in label_sentences]

        normalized_encoding = minmax_scale(encodings_np)

        # compute similarity matrix
        similarity_matrix = cosine_similarity(normalized_encoding)

        # the higher the similarity, the greater the chance that a label is
        # sampled as a negative example
        negative_label_probabilities = {}
        for row_index, label in enumerate(all_labels):
            negative_label_probabilities[label] = {}
            for column_index, other_label in enumerate(all_labels):
                if label != other_label:
                    negative_label_probabilities[label][other_label] = \
                        similarity_matrix[row_index][column_index]
        self.label_nearest_map = negative_label_probabilities
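
        # Shape of the resulting map, with hypothetical labels and similarity
        # values for illustration:
        #   {"PER": {"ORG": 0.71, "LOC": 0.64},
        #    "ORG": {"PER": 0.71, "LOC": 0.58}, ...}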

    def get_current_label_dictionary(self):
        return self._task_specific_attributes[self._current_task]['label_dictionary']

    def get_current_label_type(self):
        return self._task_specific_attributes[self._current_task]['label_type']

    def is_current_task_multi_label(self):
        return self._task_specific_attributes[self._current_task]['multi_label']

    def add_and_switch_to_new_task(self,
                                   task_name,
                                   label_dictionary: Union[List, Set, Dictionary, str],
                                   label_type: str,
                                   multi_label: bool = True,
                                   force_switch: bool = False,
                                   ):
        """
        Adds a new task to an existing TARS model, sets the necessary attributes and finally
        'switches' to the new task. Parameters are similar to the constructor except for model
        choice, batch size and negative sampling. This method does not store the resulting
        model to disk.
        :param task_name: a string depicting the name of the task
        :param label_dictionary: dictionary of the labels you want to predict
        :param label_type: string to identify the label type ('ner', 'sentiment', etc.)
        :param multi_label: whether this task is a multi-label prediction problem
        :param force_switch: if True, will overwrite the existing task with the same name
        """
        if task_name in self._task_specific_attributes and not force_switch:
            log.warning("Task `%s` already exists in TARS model. Switching to it.", task_name)
        else:
            # make a label dictionary if no Dictionary object is passed
            if isinstance(label_dictionary, Dictionary):
                label_dictionary = label_dictionary.get_items()
            if isinstance(label_dictionary, str):
                label_dictionary = [label_dictionary]

            # prepare dictionary of tags (without B- I- prefixes and without UNK)
            tag_dictionary = Dictionary(add_unk=False)
            for tag in label_dictionary:
                if tag in ('<unk>', 'O'):
                    continue
                if len(tag) > 1 and tag[1] == "-":
                    tag = tag[2:]
                tag_dictionary.add_item(tag)

            self._task_specific_attributes[task_name] = {'label_dictionary': tag_dictionary,
                                                         'label_type': label_type,
                                                         'multi_label': multi_label}

        self.switch_to_task(task_name)
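
    # Illustrative sketch of adding a task to a loaded TARS model; the model
    # name 'tars-base' is real, while the task name, labels and label type
    # are hypothetical examples:
    #
    #   tars = TARSClassifier.load('tars-base')
    #   tars.add_and_switch_to_new_task(task_name="question classification",
    #                                   label_dictionary=["product", "shipping"],
    #                                   label_type="question_class")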

    def list_existing_tasks(self) -> Set[str]:
        """
        Returns the set of tasks existing in the loaded TARS model.
        """
        return set(self._task_specific_attributes.keys())

    def switch_to_task(self, task_name):
        """
        Switches to a task which was previously added.
        """
        if task_name not in self._task_specific_attributes:
            log.error("Provided `%s` does not exist in the model. Consider calling "
                      "`add_and_switch_to_new_task` first.", task_name)
        else:
            self._current_task = task_name

    def _drop_task(self, task_name):
        if task_name in self._task_specific_attributes:
            if self._current_task == task_name:
                log.error("`%s` is the current task."
                          " Switch to some other task before dropping this.", task_name)
            else:
                self._task_specific_attributes.pop(task_name)
        else:
            log.warning("No task exists with the name `%s`.", task_name)

    @staticmethod
    def _filter_empty_sentences(sentences: List[Sentence]) -> List[Sentence]:
        filtered_sentences = [sentence for sentence in sentences if sentence.tokens]
        if len(sentences) != len(filtered_sentences):
            log.warning(
                f"Ignoring {len(sentences) - len(filtered_sentences)} sentence(s) with no tokens."
            )
        return filtered_sentences

    @property
    def label_type(self):
        return self.get_current_label_type()

    def predict_zero_shot(self,
                          sentences: Union[List[Sentence], Sentence],
                          candidate_label_set: Union[List[str], Set[str], str],
                          multi_label: bool = True):
        """
        Makes zero-shot predictions with the TARS model.
        :param sentences: input sentence objects to classify
        :param candidate_label_set: set of candidate labels
        :param multi_label: indicates whether multi-label or single-class prediction. Defaults to True.
        """

        # check if candidate_label_set is empty
        if candidate_label_set is None or len(candidate_label_set) == 0:
            log.warning("Provided candidate_label_set is empty.")
            return

        # make a set if only one candidate label is passed
        if isinstance(candidate_label_set, str):
            candidate_label_set = {candidate_label_set}

        # create label dictionary
        label_dictionary = Dictionary(add_unk=False)
        for label in candidate_label_set:
            label_dictionary.add_item(label)

        # note the current task
        existing_current_task = self._current_task

        # create a temporary task
        self.add_and_switch_to_new_task(task_name="ZeroShot",
                                        label_dictionary=label_dictionary,
                                        label_type='-'.join(label_dictionary.get_items()),
                                        multi_label=multi_label)

        try:
            # make zero-shot predictions
            self.predict(sentences)
        finally:
            # switch back to the pre-existing task
            self.switch_to_task(existing_current_task)
            self._drop_task("ZeroShot")
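
    # Illustrative zero-shot sketch, assuming the pretrained 'tars-base' text
    # classification model; the sentence and candidate labels are hypothetical:
    #
    #   tars = TARSClassifier.load('tars-base')
    #   sentence = Sentence("I am so glad you liked it!")
    #   tars.predict_zero_shot(sentence, {"happy", "sad"})
    #   print(sentence.get_labels())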


class TARSTagger(FewshotClassifier):
    """
    TARS model for sequence tagging. In the backend, the model uses a BERT-based 5-class
    sequence labeler which, given a <label, text> pair, predicts the probability for each word
    to belong to one of the BIOES classes. The input data is a usual Sentence object which is
    inflated by the model internally before pushing it through the transformer stack of BERT.
    """

    static_label_type = "tars_label"

    def __init__(
            self,
            task_name: Optional[str] = None,
            label_dictionary: Optional[Dictionary] = None,
            label_type: Optional[str] = None,
            embeddings: str = 'bert-base-uncased',
            num_negative_labels_to_sample: int = 2,
            prefix: bool = True,
            **tagger_args,
    ):
        """
        Initializes a TARSTagger.
        :param task_name: a string depicting the name of the task
        :param label_dictionary: dictionary of labels you want to predict
        :param label_type: string to identify the label type ('ner', 'sentiment', etc.)
        :param embeddings: name of the pre-trained transformer model e.g.,
        'bert-base-uncased' etc
        :param num_negative_labels_to_sample: number of negative labels to sample for each
        positive label against a sentence during training. Defaults to 2 negative labels
        for each positive label. If None is passed, the model samples all negative labels,
        which slows down training considerably.
        :param prefix: if True, the label is prepended to the sentence; otherwise appended
        """
        super(TARSTagger, self).__init__()

        from flair.embeddings import TransformerWordEmbeddings

        if not isinstance(embeddings, TransformerWordEmbeddings):
            embeddings = TransformerWordEmbeddings(model=embeddings,
                                                   fine_tune=True,
                                                   layers='-1',
                                                   layer_mean=False,
                                                   )

        # prepare TARS dictionary
        tars_dictionary = Dictionary(add_unk=False)
        tars_dictionary.add_item('O')
        tars_dictionary.add_item('S-')
        tars_dictionary.add_item('B-')
        tars_dictionary.add_item('E-')
        tars_dictionary.add_item('I-')

        # initialize a bare-bones sequence tagger (the hidden size of 123 is
        # irrelevant since the tagger runs without an RNN or CRF)
        self.tars_model = SequenceTagger(123,
                                         embeddings,
                                         tag_dictionary=tars_dictionary,
                                         tag_type=self.static_label_type,
                                         use_crf=False,
                                         use_rnn=False,
                                         reproject_embeddings=False,
                                         **tagger_args,
                                         )

        # transformer separator
        self.separator = str(self.tars_embeddings.tokenizer.sep_token)
        if self.tars_embeddings.tokenizer._bos_token:
            self.separator += str(self.tars_embeddings.tokenizer.bos_token)

        self.prefix = prefix
        self.num_negative_labels_to_sample = num_negative_labels_to_sample

        if task_name and label_dictionary and label_type:
            # store task-specific labels since TARS can handle multiple tasks
            self.add_and_switch_to_new_task(task_name, label_dictionary, label_type)
        else:
            log.info("TARS initialized without a task. You need to call .add_and_switch_to_new_task() "
                     "before training this model.")

    def _get_tars_formatted_sentence(self, label, sentence):

        original_text = sentence.to_tokenized_string()

        label_text_pair = f"{label} {self.separator} {original_text}" if self.prefix \
            else f"{original_text} {self.separator} {label}"

        label_length = 0 if not self.prefix else len(label.split(" ")) + len(self.separator.split(" "))

        # make a tars sentence where all labels are O by default
        tars_sentence = Sentence(label_text_pair, use_tokenizer=False)
        for token in tars_sentence:
            token.add_tag(self.static_label_type, "O")

        # overwrite O labels with tags
        for token in sentence:
            tag = token.get_tag(self.get_current_label_type()).value

            if tag == "O" or tag == "":
                tars_tag = "O"
            elif tag == label:
                tars_tag = "S-"
            elif len(tag) > 2 and tag[1] == "-" and tag[2:] == label:
                tars_tag = tag.split('-')[0] + '-'
            else:
                tars_tag = "O"

            tars_sentence.get_token(token.idx + label_length).add_tag(self.static_label_type, tars_tag)

        return tars_sentence
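
    # For illustration (hypothetical label and text): with prefix=True and a
    # BERT '[SEP]' separator, the sentence "George went home" paired with the
    # label "person" becomes
    #   "person [SEP] George went home"
    # where the token "George" carries the tars tag "S-" and all other tokens
    # carry "O".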

    def _get_state_dict(self):
        model_state = {
            "state_dict": self.state_dict(),
            "current_task": self._current_task,
            "tag_type": self.get_current_label_type(),
            "tag_dictionary": self.get_current_label_dictionary(),
            "tars_model": self.tars_model,
            "num_negative_labels_to_sample": self.num_negative_labels_to_sample,
            "prefix": self.prefix,
            "task_specific_attributes": self._task_specific_attributes,
        }
        return model_state

    @staticmethod
    def _fetch_model(model_name) -> str:

        if model_name == "tars-ner":
            cache_dir = Path("models")
            model_name = cached_path("https://nlp.informatik.hu-berlin.de/resources/models/tars-ner/tars-ner.pt",
                                     cache_dir=cache_dir)

        return model_name

    @staticmethod
    def _init_model_with_state_dict(state):

        # init a new TARS tagger
        model = TARSTagger(
            task_name=state["current_task"],
            label_dictionary=state["tag_dictionary"],
            label_type=state["tag_type"],
            embeddings=state["tars_model"].embeddings,
            num_negative_labels_to_sample=state["num_negative_labels_to_sample"],
            prefix=state["prefix"],
        )
        # set all task information
        model._task_specific_attributes = state["task_specific_attributes"]

        # load the weights of the internal model
        model.load_state_dict(state["state_dict"])
        return model

    @property
    def tars_embeddings(self):
        return self.tars_model.embeddings

    def predict(
            self,
            sentences: Union[List[Sentence], Sentence],
            mini_batch_size=32,
            verbose: bool = False,
            label_name: Optional[str] = None,
            return_loss=False,
            embedding_storage_mode="none",
            most_probable_first: bool = True
    ):
        """
        Predicts sequence tags for the given sentences.
        :param sentences: a Sentence or a List of Sentence
        :param mini_batch_size: size of the minibatch; larger batches are usually faster but
        consume more memory, up to a point where a further increase has no effect.
        :param verbose: set to True to display a progress bar
        :param label_name: set this to change the name of the label type that is predicted
        :param return_loss: set to True to return loss
        :param embedding_storage_mode: default is 'none' which is always best. Only set to 'cpu' or 'gpu' if
        you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively.
        :param most_probable_first: if True, assign labels in order of descending confidence,
        skipping tokens that already carry a higher-scoring label
        """
        if label_name is None:
            label_name = self.get_current_label_type()

        if not sentences:
            return sentences

        if isinstance(sentences, Sentence):
            sentences = [sentences]

        # reverse sort all sequences by their length
        rev_order_len_index = sorted(range(len(sentences)), key=lambda k: len(sentences[k]), reverse=True)

        reordered_sentences: List[Union[Sentence, str]] = [sentences[index] for index in rev_order_len_index]

        dataloader = DataLoader(dataset=SentenceDataset(reordered_sentences), batch_size=mini_batch_size)

        # progress bar for verbosity
        if verbose:
            dataloader = tqdm(dataloader)

        overall_loss = 0
        overall_count = 0
        batch_no = 0
        with torch.no_grad():
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(f"Running inference on batch {batch_no}")

                batch = self._filter_empty_sentences(batch)
                # stop if all sentences are empty
                if not batch:
                    continue

                # go through each sentence in the batch
                for sentence in batch:

                    # always remove tags first
                    for token in sentence:
                        token.remove_labels(label_name)

                    all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]

                    all_detected = {}
                    for label in all_labels:
                        tars_sentence = self._get_tars_formatted_sentence(label, sentence)

                        label_length = 0 if not self.prefix else len(label.split(" ")) + len(self.separator.split(" "))

                        loss_and_count = self.tars_model.predict(tars_sentence,
                                                                 label_name=label_name,
                                                                 all_tag_prob=True,
                                                                 return_loss=True)
                        overall_loss += loss_and_count[0].item()
                        overall_count += loss_and_count[1]

                        for span in tars_sentence.get_spans(label_name):
                            span.set_label('tars_temp_label', label)
                            all_detected[span] = span.score

                        if not most_probable_first:
                            for span in tars_sentence.get_spans(label_name):
                                for token in span:
                                    corresponding_token = sentence.get_token(token.idx - label_length)
                                    if corresponding_token is None:
                                        continue
                                    if corresponding_token.get_tag(label_name).value != '' and \
                                            corresponding_token.get_tag(label_name).score > token.get_tag(label_name).score:
                                        continue
                                    corresponding_token.add_tag(
                                        label_name,
                                        token.get_tag(label_name).value + label,
                                        token.get_tag(label_name).score,
                                    )

                    if most_probable_first:
                        # go through the detected spans in order of descending confidence
                        sorted_spans = sorted(all_detected.items(), key=lambda item: item[1], reverse=True)
                        for span, _ in sorted_spans:
                            label = span.get_labels('tars_temp_label')[0].value
                            label_length = 0 if not self.prefix else len(label.split(" ")) + len(
                                self.separator.split(" "))

                            # determine whether tokens in this span already have a label
                            tag_this = True
                            for token in span:
                                corresponding_token = sentence.get_token(token.idx - label_length)
                                if corresponding_token is None:
                                    tag_this = False
                                    continue
                                if corresponding_token.get_tag(label_name).value != '' and \
                                        corresponding_token.get_tag(label_name).score > token.get_tag(label_name).score:
                                    tag_this = False
                                    continue

                            # only add the label if no token in the span is already tagged
                            if tag_this:
                                for token in span:
                                    corresponding_token = sentence.get_token(token.idx - label_length)
                                    corresponding_token.add_tag(
                                        label_name,
                                        token.get_tag(label_name).value + label,
                                        token.get_tag(label_name).score,
                                    )

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

        if return_loss:
            return overall_loss, overall_count
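
    # Illustrative few-shot NER sketch, assuming the pretrained 'tars-ner'
    # model; the task name, label set and sentence are hypothetical:
    #
    #   tars = TARSTagger.load('tars-ner')
    #   tars.add_and_switch_to_new_task('my_ner', {'person', 'location'}, label_type='ner')
    #   sentence = Sentence("George Washington went to Washington.")
    #   tars.predict(sentence)
    #   print(sentence.to_tagged_string())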


class TARSClassifier(FewshotClassifier):
    """
    TARS model for text classification. In the backend, the model uses a BERT-based binary
    text classifier which, given a <label, text> pair, predicts the probability of two classes
    "True" and "False". The input data is a usual Sentence object which is inflated
    by the model internally before pushing it through the transformer stack of BERT.
    """

    static_label_type = "tars_label"
    LABEL_MATCH = "YES"
    LABEL_NO_MATCH = "NO"

    def __init__(
            self,
            task_name: Optional[str] = None,
            label_dictionary: Optional[Dictionary] = None,
            label_type: Optional[str] = None,
            embeddings: str = 'bert-base-uncased',
            num_negative_labels_to_sample: int = 2,
            prefix: bool = True,
            **tagger_args,
    ):
        """
        Initializes a TARSClassifier.
        :param task_name: a string depicting the name of the task
        :param label_dictionary: dictionary of labels you want to predict
        :param label_type: string to identify the label type ('sentiment', 'topic', etc.)
        :param embeddings: name of the pre-trained transformer model e.g.,
        'bert-base-uncased' etc
        :param num_negative_labels_to_sample: number of negative labels to sample for each
        positive label against a sentence during training. Defaults to 2 negative labels
        for each positive label. If None is passed, the model samples all negative labels,
        which slows down training considerably.
        :param prefix: if True, the label is prepended to the sentence; otherwise appended
        """
        super(TARSClassifier, self).__init__()

        from flair.embeddings import TransformerDocumentEmbeddings

        if not isinstance(embeddings, TransformerDocumentEmbeddings):
            embeddings = TransformerDocumentEmbeddings(model=embeddings,
                                                       fine_tune=True,
                                                       layers='-1',
                                                       layer_mean=False,
                                                       )

        # prepare TARS dictionary
        tars_dictionary = Dictionary(add_unk=False)
        tars_dictionary.add_item(self.LABEL_NO_MATCH)
        tars_dictionary.add_item(self.LABEL_MATCH)

        # initialize a bare-bones text classifier
        self.tars_model = TextClassifier(document_embeddings=embeddings,
                                         label_dictionary=tars_dictionary,
                                         label_type=self.static_label_type,
                                         **tagger_args,
                                         )

        # transformer separator
        self.separator = str(self.tars_embeddings.tokenizer.sep_token)
        if self.tars_embeddings.tokenizer._bos_token:
            self.separator += str(self.tars_embeddings.tokenizer.bos_token)

        self.prefix = prefix
        self.num_negative_labels_to_sample = num_negative_labels_to_sample

        if task_name and label_dictionary and label_type:
            # store task-specific labels since TARS can handle multiple tasks
            self.add_and_switch_to_new_task(task_name, label_dictionary, label_type)
        else:
            log.info("TARS initialized without a task. You need to call .add_and_switch_to_new_task() "
                     "before training this model.")

        self.clean_up_labels = True

    def _clean(self, label_value: str) -> str:
        if self.clean_up_labels:
            return label_value.replace("_", " ")
        else:
            return label_value

    def _get_tars_formatted_sentence(self, label, sentence):

        label = self._clean(label)

        original_text = sentence.to_tokenized_string()

        label_text_pair = f"{label} {self.separator} {original_text}" if self.prefix \
            else f"{original_text} {self.separator} {label}"

        sentence_labels = [self._clean(label.value) for label in sentence.get_labels(self.get_current_label_type())]

        tars_label = self.LABEL_MATCH if label in sentence_labels else self.LABEL_NO_MATCH

        tars_sentence = Sentence(label_text_pair, use_tokenizer=False).add_label(self.static_label_type, tars_label)

        return tars_sentence
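
    # For illustration (hypothetical label and text): with prefix=True and a
    # BERT '[SEP]' separator, the sentence "great value for money" paired with
    # the label "positive" becomes
    #   "positive [SEP] great value for money"
    # and is labeled "YES" if the sentence's gold labels contain "positive",
    # otherwise "NO".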

    def _get_state_dict(self):
        model_state = {
            "state_dict": self.state_dict(),
            "current_task": self._current_task,
            "label_type": self.get_current_label_type(),
            "label_dictionary": self.get_current_label_dictionary(),
            "tars_model": self.tars_model,
            "num_negative_labels_to_sample": self.num_negative_labels_to_sample,
            "task_specific_attributes": self._task_specific_attributes,
        }
        return model_state

    @staticmethod
    def _init_model_with_state_dict(state):

        # init a new TARS classifier
        label_dictionary = state["label_dictionary"]
        label_type = "default_label" if not state["label_type"] else state["label_type"]

        model: TARSClassifier = TARSClassifier(
            task_name=state["current_task"],
            label_dictionary=label_dictionary,
            label_type=label_type,
            embeddings=state["tars_model"].document_embeddings,
            num_negative_labels_to_sample=state["num_negative_labels_to_sample"],
        )

        # set all task information
        model._task_specific_attributes = state["task_specific_attributes"]

        # load the weights of the internal model
        model.load_state_dict(state["state_dict"])
        return model

    @staticmethod
    def _fetch_model(model_name) -> str:

        model_map = {}
        hu_path: str = "https://nlp.informatik.hu-berlin.de/resources/models"

        model_map["tars-base"] = "/".join([hu_path, "tars-base", "tars-base-v8.pt"])

        cache_dir = Path("models")
        if model_name in model_map:
            model_name = cached_path(model_map[model_name], cache_dir=cache_dir)

        return model_name

    @property
    def tars_embeddings(self):
        return self.tars_model.document_embeddings

    def predict(
            self,
            sentences: Union[List[Sentence], Sentence],
            mini_batch_size=32,
            verbose: bool = False,
            label_name: Optional[str] = None,
            return_loss=False,
            embedding_storage_mode="none",
            label_threshold: float = 0.5,
            multi_label: Optional[bool] = None,
    ):
        """
        Predicts the class labels for the given sentences.
        :param sentences: a Sentence or a List of Sentence
        :param mini_batch_size: size of the minibatch; larger batches are usually faster but
        consume more memory, up to a point where a further increase has no effect.
        :param verbose: set to True to display a progress bar
        :param label_name: set this to change the name of the label type that is predicted
        :param return_loss: set to True to return loss
        :param embedding_storage_mode: default is 'none' which is always best. Only set to 'cpu' or 'gpu' if
        you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively.
        :param label_threshold: minimum confidence a label must have to be predicted
        :param multi_label: if None, use the multi-label setting of the current task
        """
        if not label_name:
            label_name = self.get_current_label_type()

        if multi_label is None:
            multi_label = self.is_current_task_multi_label()

        if not sentences:
            return sentences

        if isinstance(sentences, Sentence):
            sentences = [sentences]

        # set context if not set already
        previous_sentence = None
        for sentence in sentences:
            if sentence.is_context_set():
                continue
            sentence._previous_sentence = previous_sentence
            sentence._next_sentence = None
            if previous_sentence:
                previous_sentence._next_sentence = sentence
            previous_sentence = sentence

        # reverse sort all sequences by their length
        rev_order_len_index = sorted(range(len(sentences)), key=lambda k: len(sentences[k]), reverse=True)

        reordered_sentences: List[Union[Sentence, str]] = [sentences[index] for index in rev_order_len_index]

        dataloader = DataLoader(dataset=SentenceDataset(reordered_sentences), batch_size=mini_batch_size)

        # progress bar for verbosity
        if verbose:
            dataloader = tqdm(dataloader)

        overall_loss = 0
        overall_count = 0
        batch_no = 0
        with torch.no_grad():
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(f"Running inference on batch {batch_no}")

                batch = self._filter_empty_sentences(batch)
                # stop if all sentences are empty
                if not batch:
                    continue

                # go through each sentence in the batch
                for sentence in batch:

                    # always remove tags first
                    sentence.remove_labels(label_name)

                    all_labels = [label.decode("utf-8") for label in self.get_current_label_dictionary().idx2item]

                    for label in all_labels:
                        tars_sentence = self._get_tars_formatted_sentence(label, sentence)

                        loss_and_count = self.tars_model.predict(tars_sentence,
                                                                 label_name=label_name,
                                                                 return_loss=True,
                                                                 return_probabilities_for_all_classes=label_threshold < 0.5,
                                                                 )

                        overall_loss += loss_and_count[0].item()
                        overall_count += loss_and_count[1]

                        # add all labels that according to TARS match the text and are above threshold
                        for predicted_tars_label in tars_sentence.get_labels(label_name):
                            if predicted_tars_label.value == self.LABEL_MATCH \
                                    and predicted_tars_label.score > label_threshold:
                                sentence.add_label(label_name, label, predicted_tars_label.score)

                    # only use the label with the highest confidence if enforcing single-label predictions
                    if not multi_label:
                        if len(sentence.get_labels(label_name)) > 0:
                            # get all label scores and do an argmax to get the best label
                            label_scores = torch.tensor([label.score for label in sentence.get_labels(label_name)],
                                                        dtype=torch.float)
                            best_label = sentence.get_labels(label_name)[torch.argmax(label_scores)]

                            # remove previously added labels and only add the best label
                            sentence.remove_labels(label_name)
                            sentence.add_label(typename=label_name, value=best_label.value, score=best_label.score)

                # clearing token embeddings to save memory
                store_embeddings(batch, storage_mode=embedding_storage_mode)

        if return_loss:
            return overall_loss, overall_count
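

# Illustrative end-to-end sketch, assuming the pretrained 'tars-base' model;
# the task, labels and sentence are hypothetical:
#
#   from flair.data import Sentence
#   from flair.models import TARSClassifier
#
#   tars = TARSClassifier.load('tars-base')
#   tars.add_and_switch_to_new_task("sentiment", ["positive", "negative"],
#                                   label_type="sentiment", multi_label=False)
#   sentence = Sentence("great value for money")
#   tars.predict(sentence)
#   print(sentence.get_labels("sentiment"))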