Coverage for /home/ubuntu/Documents/Research/mut_p1/flair/flair/nn/model.py: 71%
import itertools
import logging
import warnings
from abc import abstractmethod
from collections import Counter
from pathlib import Path
from typing import Union, List, Tuple, Dict, Optional

import torch.nn
from torch.utils.data.dataset import Dataset
from tqdm import tqdm

import flair
from flair import file_utils
from flair.data import DataPoint, Sentence, Dictionary, SpanLabel
from flair.datasets import DataLoader, SentenceDataset
from flair.training_utils import Result, store_embeddings

log = logging.getLogger("flair")


class Model(torch.nn.Module):
    """Abstract base class for all downstream task models in Flair, such as SequenceTagger and TextClassifier.
    Every new type of model must implement these methods."""

    @property
    @abstractmethod
    def label_type(self):
        """Each model predicts labels of a certain type. TODO: can we find a better name for this?"""
        raise NotImplementedError

    @abstractmethod
    def forward_loss(self, data_points: Union[List[DataPoint], DataPoint]) -> torch.Tensor:
        """Performs a forward pass and returns a loss tensor for backpropagation. Implement this to enable training."""
        raise NotImplementedError

    @abstractmethod
    def evaluate(
            self,
            sentences: Union[List[Sentence], Dataset],
            gold_label_type: str,
            out_path: Optional[Union[str, Path]] = None,
            embedding_storage_mode: str = "none",
            mini_batch_size: int = 32,
            num_workers: int = 8,
            main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
            exclude_labels: List[str] = [],
            gold_label_dictionary: Optional[Dictionary] = None,
    ) -> Result:
        """Evaluates the model. Returns a Result object containing evaluation
        results and a loss value. Implement this to enable evaluation.
        :param sentences: list of sentences or a Dataset over which to evaluate
        :param gold_label_type: the label type against which predictions are evaluated
        :param out_path: optional output path to store predictions
        :param embedding_storage_mode: one of 'none', 'cpu' or 'gpu'. 'none' means all embeddings are deleted and
            freshly recomputed, 'cpu' means all embeddings are stored on CPU, 'gpu' means all embeddings are stored on GPU
        :return: a Result object containing the evaluation scores and the evaluation loss
        """
        raise NotImplementedError

    @abstractmethod
    def _get_state_dict(self):
        """Returns the state dictionary for this model. Implementing this enables the save() and save_checkpoint()
        functionality."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def _init_model_with_state_dict(state):
        """Initialize the model from a state dictionary. Implementing this enables the load() and load_checkpoint()
        functionality."""
        raise NotImplementedError

    @staticmethod
    def _fetch_model(model_name) -> str:
        # default implementation assumes the model name is a local path; subclasses
        # override this to resolve names of pretrained models to downloadable files
        return model_name

    def save(self, model_file: Union[str, Path], checkpoint: bool = False):
        """
        Saves the current model to the provided file.
        :param model_file: the model file
        :param checkpoint: if True, the optimizer and scheduler state dictionaries are saved along with the model
        """
        model_state = self._get_state_dict()

        # in Flair <0.9.1, optimizer and scheduler used to train the model are not saved
        optimizer = scheduler = None

        # write out a "model card" if one is set
        if hasattr(self, 'model_card'):

            # special handling for optimizer: remember optimizer class and state dictionary
            if 'training_parameters' in self.model_card:
                training_parameters = self.model_card['training_parameters']

                if 'optimizer' in training_parameters:
                    optimizer = training_parameters['optimizer']
                    if checkpoint:
                        training_parameters['optimizer_state_dict'] = optimizer.state_dict()
                    training_parameters['optimizer'] = optimizer.__class__

                if 'scheduler' in training_parameters:
                    scheduler = training_parameters['scheduler']
                    if checkpoint:
                        with warnings.catch_warnings():
                            warnings.simplefilter("ignore")
                            training_parameters['scheduler_state_dict'] = scheduler.state_dict()
                    training_parameters['scheduler'] = scheduler.__class__

            model_state['model_card'] = self.model_card

        # save model
        torch.save(model_state, str(model_file), pickle_protocol=4)

        # restore optimizer and scheduler to model card if set
        if optimizer:
            self.model_card['training_parameters']['optimizer'] = optimizer
        if scheduler:
            self.model_card['training_parameters']['scheduler'] = scheduler
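
    # Example usage (editor's sketch, not part of the original file), assuming
    # `tagger` is any trained flair.nn.Model subclass:
    #
    #     tagger.save('final-model.pt')                  # weights and model card only
    #     tagger.save('checkpoint.pt', checkpoint=True)  # also optimizer/scheduler state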

    @classmethod
    def load(cls, model: Union[str, Path]):
        """
        Loads the model from the given file.
        :param model: the model file
        :return: the loaded model
        """
        model_file = cls._fetch_model(str(model))

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            # load_big_file is a workaround by https://github.com/highway11git to load models on some Mac/Windows setups
            # see https://github.com/zalandoresearch/flair/issues/351
            f = file_utils.load_big_file(str(model_file))
            state = torch.load(f, map_location='cpu')

        model = cls._init_model_with_state_dict(state)

        if 'model_card' in state:
            model.model_card = state['model_card']

        model.eval()
        model.to(flair.device)

        return model
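
    # Example usage (editor's sketch): loading a model back from a local file; the
    # same call also accepts names of pretrained models where a subclass overrides
    # _fetch_model() to resolve them:
    #
    #     from flair.models import SequenceTagger
    #     tagger = SequenceTagger.load('final-model.pt')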

    def print_model_card(self):
        if hasattr(self, 'model_card'):
            param_out = "\n------------------------------------\n"
            param_out += "--------- Flair Model Card ---------\n"
            param_out += "------------------------------------\n"
            param_out += "- this Flair model was trained with:\n"
            param_out += f"-- Flair version {self.model_card['flair_version']}\n"
            param_out += f"-- PyTorch version {self.model_card['pytorch_version']}\n"
            if 'transformers_version' in self.model_card:
                param_out += f"-- Transformers version {self.model_card['transformers_version']}\n"
            param_out += "------------------------------------\n"

            param_out += "------- Training Parameters: -------\n"
            param_out += "------------------------------------\n"
            training_params = '\n'.join(f'-- {param} = {self.model_card["training_parameters"][param]}'
                                        for param in self.model_card['training_parameters'])
            param_out += training_params + "\n"
            param_out += "------------------------------------\n"

            log.info(param_out)
        else:
            log.info(
                "This model has no model card (likely because it is not yet trained or was trained with Flair version < 0.9.1)")


class Classifier(Model):
    """Abstract base class for all Flair models that do classification, both single- and multi-label.
    It inherits from flair.nn.Model and adds a unified evaluate() function so that all classification models
    use the same evaluation routines and compute the same numbers.
    Currently, the SequenceTagger implements this class directly, while all other classifiers in Flair
    implement the DefaultClassifier base class, which in turn implements Classifier."""

    def evaluate(
            self,
            data_points: Union[List[DataPoint], Dataset],
            gold_label_type: str,
            out_path: Optional[Union[str, Path]] = None,
            embedding_storage_mode: str = "none",
            mini_batch_size: int = 32,
            num_workers: int = 8,
            main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
            exclude_labels: List[str] = [],
            gold_label_dictionary: Optional[Dictionary] = None,
    ) -> Result:
        import numpy as np
        import sklearn.metrics

        # read Dataset into data loader (if list of sentences passed, make Dataset first)
        if not isinstance(data_points, Dataset):
            data_points = SentenceDataset(data_points)
        data_loader = DataLoader(data_points, batch_size=mini_batch_size, num_workers=num_workers)

        with torch.no_grad():

            # loss calculation
            eval_loss = 0
            average_over = 0

            # variables for printing
            lines: List[str] = []
            is_word_level = False

            # variables for computing scores
            all_spans: List[str] = []
            all_true_values = {}
            all_predicted_values = {}

            sentence_id = 0
            for batch in data_loader:

                # remove any previously predicted labels
                for datapoint in batch:
                    datapoint.remove_labels('predicted')

                # predict for batch
                loss_and_count = self.predict(batch,
                                              embedding_storage_mode=embedding_storage_mode,
                                              mini_batch_size=mini_batch_size,
                                              label_name='predicted',
                                              return_loss=True)

                if isinstance(loss_and_count, tuple):
                    average_over += loss_and_count[1]
                    eval_loss += loss_and_count[0]
                else:
                    eval_loss += loss_and_count

                # get the gold labels
                for datapoint in batch:

                    for gold_label in datapoint.get_labels(gold_label_type):
                        representation = str(sentence_id) + ': ' + gold_label.identifier

                        value = gold_label.value
                        # map to '<unk>' if the gold label is not contained in the given label dictionary
                        if gold_label_dictionary and gold_label_dictionary.get_idx_for_item(value) == 0:
                            value = '<unk>'

                        if representation not in all_true_values:
                            all_true_values[representation] = [value]
                        else:
                            all_true_values[representation].append(value)

                        if representation not in all_spans:
                            all_spans.append(representation)

                        if isinstance(gold_label, SpanLabel):
                            is_word_level = True

                    for predicted_span in datapoint.get_labels("predicted"):
                        representation = str(sentence_id) + ': ' + predicted_span.identifier

                        # add to all_predicted_values
                        if representation not in all_predicted_values:
                            all_predicted_values[representation] = [predicted_span.value]
                        else:
                            all_predicted_values[representation].append(predicted_span.value)

                        if representation not in all_spans:
                            all_spans.append(representation)

                    sentence_id += 1

                store_embeddings(batch, embedding_storage_mode)

                # make printout lines
                if out_path:
                    for datapoint in batch:

                        # if the model is span-level, transfer to word-level annotations for printout
                        if is_word_level:

                            # all labels default to "O"
                            for token in datapoint:
                                token.set_label("gold_bio", "O")
                                token.set_label("predicted_bio", "O")

                            # set gold token-level labels
                            for gold_label in datapoint.get_labels(gold_label_type):
                                gold_label: SpanLabel = gold_label
                                prefix = "B-"
                                for token in gold_label.span:
                                    token.set_label("gold_bio", prefix + gold_label.value)
                                    prefix = "I-"

                            # set predicted token-level labels
                            for predicted_label in datapoint.get_labels("predicted"):
                                predicted_label: SpanLabel = predicted_label
                                prefix = "B-"
                                for token in predicted_label.span:
                                    token.set_label("predicted_bio", prefix + predicted_label.value)
                                    prefix = "I-"

                            # now print labels in CoNLL format
                            for token in datapoint:
                                eval_line = f"{token.text} " \
                                            f"{token.get_tag('gold_bio').value} " \
                                            f"{token.get_tag('predicted_bio').value}\n"
                                lines.append(eval_line)
                            lines.append("\n")
                        else:
                            # check if there is a label mismatch
                            g = [label.identifier + label.value for label in datapoint.get_labels(gold_label_type)]
                            p = [label.identifier + label.value for label in datapoint.get_labels('predicted')]
                            g.sort()
                            p.sort()
                            correct_string = " -> MISMATCH!\n" if g != p else ""
                            # print info
                            eval_line = f"{datapoint.to_original_text()}\n" \
                                        f" - Gold: {datapoint.get_labels(gold_label_type)}\n" \
                                        f" - Pred: {datapoint.get_labels('predicted')}\n{correct_string}\n"
                            lines.append(eval_line)

            # write all prediction lines to out_file if set
            if out_path:
                with open(Path(out_path), "w", encoding="utf-8") as outfile:
                    outfile.write("".join(lines))

            # make the evaluation dictionary
            evaluation_label_dictionary = Dictionary(add_unk=False)
            evaluation_label_dictionary.add_item("O")
            for true_values in all_true_values.values():
                for label in true_values:
                    evaluation_label_dictionary.add_item(label)
            for predicted_values in all_predicted_values.values():
                for label in predicted_values:
                    evaluation_label_dictionary.add_item(label)

        # finally, compute numbers
        y_true = []
        y_pred = []

        for span in all_spans:

            true_values = all_true_values[span] if span in all_true_values else ['O']
            predicted_values = all_predicted_values[span] if span in all_predicted_values else ['O']

            y_true_instance = np.zeros(len(evaluation_label_dictionary), dtype=int)
            for true_value in true_values:
                y_true_instance[evaluation_label_dictionary.get_idx_for_item(true_value)] = 1
            y_true.append(y_true_instance.tolist())

            y_pred_instance = np.zeros(len(evaluation_label_dictionary), dtype=int)
            for predicted_value in predicted_values:
                y_pred_instance[evaluation_label_dictionary.get_idx_for_item(predicted_value)] = 1
            y_pred.append(y_pred_instance.tolist())

        # now, calculate evaluation numbers
        target_names = []
        labels = []

        counter = Counter()
        counter.update(list(itertools.chain.from_iterable(all_true_values.values())))
        counter.update(list(itertools.chain.from_iterable(all_predicted_values.values())))

        for label_name, count in counter.most_common():
            if label_name == 'O':
                continue
            if label_name in exclude_labels:
                continue
            target_names.append(label_name)
            labels.append(evaluation_label_dictionary.get_idx_for_item(label_name))

        # there is at least one gold label or one prediction (default)
        if len(all_true_values) + len(all_predicted_values) > 1:
            classification_report = sklearn.metrics.classification_report(
                y_true, y_pred, digits=4, target_names=target_names, zero_division=0, labels=labels,
            )

            classification_report_dict = sklearn.metrics.classification_report(
                y_true, y_pred, target_names=target_names, zero_division=0, output_dict=True, labels=labels,
            )

            accuracy_score = round(sklearn.metrics.accuracy_score(y_true, y_pred), 4)

            precision_score = round(classification_report_dict["micro avg"]["precision"], 4)
            recall_score = round(classification_report_dict["micro avg"]["recall"], 4)
            micro_f_score = round(classification_report_dict["micro avg"]["f1-score"], 4)
            macro_f_score = round(classification_report_dict["macro avg"]["f1-score"], 4)

            main_score = classification_report_dict[main_evaluation_metric[0]][main_evaluation_metric[1]]

        else:
            # issue an error and default all evaluation numbers to 0.
            log.error(
                "WARNING! No gold labels and no predicted labels found! This could be an error in your corpus or in "
                "how you initialize the trainer!")
            accuracy_score = precision_score = recall_score = micro_f_score = macro_f_score = main_score = 0.
            classification_report = ""
            classification_report_dict = {}

        detailed_result = (
            "\nResults:"
            f"\n- F-score (micro) {micro_f_score}"
            f"\n- F-score (macro) {macro_f_score}"
            f"\n- Accuracy {accuracy_score}"
            "\n\nBy class:\n" + classification_report
        )

        # line for log file
        log_header = "PRECISION\tRECALL\tF1\tACCURACY"
        log_line = f"{precision_score}\t{recall_score}\t{micro_f_score}\t{accuracy_score}"

        if average_over > 0:
            eval_loss /= average_over

        result = Result(
            main_score=main_score,
            log_line=log_line,
            log_header=log_header,
            detailed_results=detailed_result,
            classification_report=classification_report_dict,
            loss=eval_loss
        )

        return result
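
# Example usage (editor's sketch): evaluating a trained classifier on a test split,
# assuming `corpus` is a flair Corpus and `model` a trained Classifier subclass:
#
#     result = model.evaluate(corpus.test, gold_label_type=model.label_type,
#                             out_path='predictions.txt')
#     print(result.detailed_results)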


class DefaultClassifier(Classifier):
    """Default base class for all Flair models that do classification, both single- and multi-label.
    It inherits from flair.nn.Classifier and thus from flair.nn.Model. All features shared by all classifiers
    are implemented here, including the loss calculation and the predict() method.
    Currently, the TextClassifier, RelationExtractor, TextPairClassifier and SimpleSequenceTagger implement
    this class. You only need to implement the forward_pass() method to implement this base class.
    """

    def forward_pass(self,
                     sentences: Union[List[DataPoint], DataPoint],
                     return_label_candidates: bool = False,
                     ):
        """This method does a forward pass through the model given a list of data points as input.
        Returns the tuple (scores, labels) if return_label_candidates = False, where scores is a tensor of logits
        produced by the decoder and labels are the string labels for each data point.
        Returns the tuple (scores, labels, data_points, candidate_labels) if return_label_candidates = True,
        where data_points are the data points to which labels are added (commonly either Sentence or Token objects)
        and candidate_labels are empty Label objects for each prediction (depending on the task: Label,
        SpanLabel or RelationLabel)."""
        raise NotImplementedError
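
    # Example (editor's sketch of a hypothetical subclass): a minimal forward_pass()
    # for a sentence-level classifier, with document embeddings `self.embeddings` and
    # a linear decoder `self.decoder` as assumed attribute names:
    #
    #     from flair.data import Label
    #
    #     def forward_pass(self, sentences, return_label_candidates=False):
    #         self.embeddings.embed(sentences)
    #         scores = self.decoder(torch.stack([s.embedding for s in sentences]))
    #         labels = [[label.value for label in s.get_labels(self.label_type)] for s in sentences]
    #         if return_label_candidates:
    #             label_candidates = [Label(value=None) for _ in sentences]
    #             return scores, labels, sentences, label_candidates
    #         return scores, labels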

    def __init__(self,
                 label_dictionary: Dictionary,
                 multi_label: bool = False,
                 multi_label_threshold: float = 0.5,
                 loss_weights: Optional[Dict[str, float]] = None,
                 ):

        super().__init__()

        # initialize the label dictionary
        self.label_dictionary: Dictionary = label_dictionary

        # set up multi-label logic
        self.multi_label = multi_label
        self.multi_label_threshold = multi_label_threshold

        # loss weights and loss function
        self.weight_dict = loss_weights
        # initialize the weight tensor
        if loss_weights is not None:
            n_classes = len(self.label_dictionary)
            weight_list = [1.0 for i in range(n_classes)]
            for i, tag in enumerate(self.label_dictionary.get_items()):
                if tag in loss_weights.keys():
                    weight_list[i] = loss_weights[tag]
            self.loss_weights = torch.FloatTensor(weight_list).to(flair.device)
        else:
            self.loss_weights = None

        if self.multi_label:
            self.loss_function = torch.nn.BCEWithLogitsLoss(weight=self.loss_weights)
        else:
            self.loss_function = torch.nn.CrossEntropyLoss(weight=self.loss_weights)
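
    # Example (editor's sketch): up-weighting a rare class in the loss, assuming a
    # hypothetical subclass `MyClassifier` and a label dictionary containing 'NEGATIVE':
    #
    #     classifier = MyClassifier(label_dictionary=label_dict,
    #                               loss_weights={'NEGATIVE': 5.0})  # other classes keep weight 1.0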

    @property
    def multi_label_threshold(self):
        return self._multi_label_threshold

    @multi_label_threshold.setter
    def multi_label_threshold(self, x):
        if type(x) is dict:
            if 'default' in x:
                self._multi_label_threshold = x
            else:
                raise Exception('multi_label_threshold dict should have a "default" key')
        else:
            self._multi_label_threshold = {'default': x}
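
    # Example (editor's sketch): per-label thresholds for multi-label prediction; any
    # label without an explicit entry falls back to the 'default' threshold:
    #
    #     classifier.multi_label_threshold = {'default': 0.5, 'SPORTS': 0.8}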

    def forward_loss(self, sentences: Union[List[DataPoint], DataPoint]) -> torch.Tensor:
        scores, labels = self.forward_pass(sentences)
        return self._calculate_loss(scores, labels)

    def _calculate_loss(self, scores, labels):

        if not any(labels):
            return torch.tensor(0., requires_grad=True, device=flair.device), 1

        if self.multi_label:
            # multi-hot encode the labels: one column per item in the label dictionary
            labels = torch.tensor([[1 if l in all_labels_for_point else 0 for l in self.label_dictionary.get_items()]
                                   for all_labels_for_point in labels], dtype=torch.float, device=flair.device)

        else:
            # encode each data point as the index of its first label ('O' if it has none)
            labels = torch.tensor([self.label_dictionary.get_idx_for_item(label[0]) if len(label) > 0
                                   else self.label_dictionary.get_idx_for_item('O')
                                   for label in labels], dtype=torch.long, device=flair.device)

        return self.loss_function(scores, labels), len(labels)
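
    # Worked example (editor's note): with label dictionary items ['O', 'A', 'B'],
    # multi-label targets [['A', 'B'], []] are encoded as the multi-hot matrix
    # [[0, 1, 1], [0, 0, 0]], while single-label targets [['A'], ['B']] become the
    # index vector [1, 2]; an empty label list falls back to the index of 'O'.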

    def predict(
            self,
            sentences: Union[List[Sentence], Sentence],
            mini_batch_size: int = 32,
            return_probabilities_for_all_classes: bool = False,
            verbose: bool = False,
            label_name: Optional[str] = None,
            return_loss=False,
            embedding_storage_mode="none",
    ):
        """
        Predicts the class labels for the given sentences. The labels are directly added to the sentences.
        :param sentences: list of sentences
        :param mini_batch_size: mini batch size to use
        :param return_probabilities_for_all_classes: return probabilities for all classes instead of only the best predicted one
        :param verbose: set to True to display a progress bar
        :param return_loss: set to True to return loss
        :param label_name: set this to change the name of the label type that is predicted
        :param embedding_storage_mode: default is 'none', which is always best. Only set to 'cpu' or 'gpu' if
            you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory, respectively.
        """
        if label_name is None:
            label_name = self.label_type if self.label_type is not None else "label"

        with torch.no_grad():
            if not sentences:
                return sentences

            if isinstance(sentences, DataPoint):
                sentences = [sentences]

            # filter empty sentences
            if isinstance(sentences[0], DataPoint):
                sentences = [sentence for sentence in sentences if len(sentence) > 0]
            if len(sentences) == 0:
                return sentences

            # reverse sort all sequences by their length
            rev_order_len_index = sorted(range(len(sentences)), key=lambda k: len(sentences[k]), reverse=True)

            reordered_sentences: List[Union[DataPoint, str]] = [sentences[index] for index in rev_order_len_index]

            dataloader = DataLoader(dataset=SentenceDataset(reordered_sentences), batch_size=mini_batch_size)
            # progress bar for verbosity
            if verbose:
                dataloader = tqdm(dataloader)

            overall_loss = 0
            batch_no = 0
            label_count = 0
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(f"Inferencing on batch {batch_no}")

                # skip the batch if all its sentences are empty
                if not batch:
                    continue

                scores, gold_labels, data_points, label_candidates = self.forward_pass(batch,
                                                                                       return_label_candidates=True)
                # remove previously predicted labels of this type
                for sentence in data_points:
                    sentence.remove_labels(label_name)

                if return_loss:
                    overall_loss += self._calculate_loss(scores, gold_labels)[0]
                    label_count += len(label_candidates)

                # if anything could possibly be predicted
                if len(label_candidates) > 0:
                    if self.multi_label:
                        sigmoided = torch.sigmoid(scores)  # size: (n_sentences, n_classes)
                        n_labels = sigmoided.size(1)
                        for s_idx, (data_point, label_candidate) in enumerate(zip(data_points, label_candidates)):
                            for l_idx in range(n_labels):
                                label_value = self.label_dictionary.get_item_for_index(l_idx)
                                if label_value == 'O':
                                    continue
                                label_threshold = self._get_label_threshold(label_value)
                                label_score = sigmoided[s_idx, l_idx].item()
                                if label_score > label_threshold or return_probabilities_for_all_classes:
                                    label = label_candidate.spawn(value=label_value, score=label_score)
                                    data_point.add_complex_label(label_name, label)
                    else:
                        softmax = torch.nn.functional.softmax(scores, dim=-1)

                        if return_probabilities_for_all_classes:
                            n_labels = softmax.size(1)
                            for s_idx, (data_point, label_candidate) in enumerate(zip(data_points, label_candidates)):
                                for l_idx in range(n_labels):
                                    label_value = self.label_dictionary.get_item_for_index(l_idx)
                                    if label_value == 'O':
                                        continue
                                    label_score = softmax[s_idx, l_idx].item()
                                    label = label_candidate.spawn(value=label_value, score=label_score)
                                    data_point.add_complex_label(label_name, label)
                        else:
                            conf, idx = torch.max(softmax, dim=-1)
                            for data_point, label_candidate, c, i in zip(data_points, label_candidates, conf, idx):
                                label_value = self.label_dictionary.get_item_for_index(i.item())
                                if label_value == 'O':
                                    continue
                                label = label_candidate.spawn(value=label_value, score=c.item())
                                data_point.add_complex_label(label_name, label)

                store_embeddings(batch, storage_mode=embedding_storage_mode)

            if return_loss:
                return overall_loss, label_count
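
    # Example usage (editor's sketch): predicting labels for a new sentence, assuming
    # `classifier` is a trained DefaultClassifier subclass:
    #
    #     from flair.data import Sentence
    #     sentence = Sentence('The movie was great!')
    #     classifier.predict(sentence)
    #     print(sentence.get_labels())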

    def _get_label_threshold(self, label_value):
        label_threshold = self.multi_label_threshold['default']
        if label_value in self.multi_label_threshold:
            label_threshold = self.multi_label_threshold[label_value]

        return label_threshold

    def __str__(self):
        return super(flair.nn.Model, self).__str__().rstrip(')') + \
               f' (weights): {self.weight_dict}\n' + \
               f' (weight_tensor) {self.loss_weights}\n)'