Coverage for flair/flair/nn/model.py: 15% (330 statements)


import itertools
import logging
import warnings
from abc import abstractmethod
from collections import Counter
from pathlib import Path
from typing import Union, List, Tuple, Dict, Optional

import torch.nn
from torch.utils.data.dataset import Dataset
from tqdm import tqdm

import flair
from flair import file_utils
from flair.data import DataPoint, Sentence, Dictionary, SpanLabel
from flair.datasets import DataLoader, SentenceDataset
from flair.training_utils import Result, store_embeddings

log = logging.getLogger("flair")


class Model(torch.nn.Module):
    """Abstract base class for all downstream task models in Flair, such as SequenceTagger and TextClassifier.
    Every new type of model must implement these methods."""

    @property
    @abstractmethod
    def label_type(self):
        """Each model predicts labels of a certain type. TODO: can we find a better name for this?"""
        raise NotImplementedError

    @abstractmethod
    def forward_loss(self, data_points: Union[List[DataPoint], DataPoint]) -> torch.Tensor:
        """Performs a forward pass and returns a loss tensor for backpropagation. Implement this to enable training."""
        raise NotImplementedError
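
    # A minimal sketch of a conforming subclass (hypothetical names for
    # illustration only; `_embed_and_score` and `_loss` stand in for the
    # subclass's real forward logic):
    #
    #     class MyTagger(Model):
    #
    #         @property
    #         def label_type(self):
    #             return "my_label"
    #
    #         def forward_loss(self, data_points):
    #             scores = self._embed_and_score(data_points)
    #             return self._loss(scores, data_points)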

    @abstractmethod
    def evaluate(
            self,
            sentences: Union[List[Sentence], Dataset],
            gold_label_type: str,
            out_path: Union[str, Path] = None,
            embedding_storage_mode: str = "none",
            mini_batch_size: int = 32,
            num_workers: int = 8,
            main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
            exclude_labels: List[str] = [],
            gold_label_dictionary: Optional[Dictionary] = None,
    ) -> Result:
        """Evaluates the model. Returns a Result object containing evaluation
        results and a loss value. Implement this to enable evaluation.
        :param sentences: the sentences or Dataset to evaluate on
        :param gold_label_type: the name of the gold label type to evaluate against
        :param out_path: optional output path to store predictions
        :param embedding_storage_mode: one of 'none', 'cpu' or 'gpu'. 'none' means all embeddings are deleted and
        freshly recomputed, 'cpu' means all embeddings are stored on CPU, 'gpu' means all embeddings are stored on GPU
        :return: a Result object containing the evaluation scores and the evaluation loss
        """
        raise NotImplementedError
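
    # Usage sketch, assuming a trained model and a corpus with a test split
    # (identifiers below are placeholders, not part of this module):
    #
    #     result = model.evaluate(corpus.test,
    #                             gold_label_type=model.label_type,
    #                             out_path="predictions.txt")
    #     print(result.main_score)        # value of the main evaluation metric
    #     print(result.detailed_results)  # per-class report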

    @abstractmethod
    def _get_state_dict(self):
        """Returns the state dictionary for this model. Implementing this enables the save() and save_checkpoint()
        functionality."""
        raise NotImplementedError

    @staticmethod
    @abstractmethod
    def _init_model_with_state_dict(state):
        """Initialize the model from a state dictionary. Implementing this enables the load() and load_checkpoint()
        functionality."""
        raise NotImplementedError

    @staticmethod
    def _fetch_model(model_name) -> str:
        return model_name

    def save(self, model_file: Union[str, Path], checkpoint: bool = False):
        """
        Saves the current model to the provided file.
        :param model_file: the model file
        :param checkpoint: if True, the optimizer and scheduler state dictionaries are saved as well
        """
        model_state = self._get_state_dict()

        # in Flair <0.9.1, the optimizer and scheduler used to train the model were not saved
        optimizer = scheduler = None

        # write out a "model card" if one is set
        if hasattr(self, 'model_card'):

            # special handling for optimizer: remember optimizer class and state dictionary
            if 'training_parameters' in self.model_card:
                training_parameters = self.model_card['training_parameters']

                if 'optimizer' in training_parameters:
                    optimizer = training_parameters['optimizer']
                    if checkpoint:
                        training_parameters['optimizer_state_dict'] = optimizer.state_dict()
                    training_parameters['optimizer'] = optimizer.__class__

                if 'scheduler' in training_parameters:
                    scheduler = training_parameters['scheduler']
                    if checkpoint:
                        with warnings.catch_warnings():
                            warnings.simplefilter("ignore")
                            training_parameters['scheduler_state_dict'] = scheduler.state_dict()
                    training_parameters['scheduler'] = scheduler.__class__

            model_state['model_card'] = self.model_card

        # save model
        torch.save(model_state, str(model_file), pickle_protocol=4)

        # restore optimizer and scheduler to model card if set
        if optimizer:
            self.model_card['training_parameters']['optimizer'] = optimizer
        if scheduler:
            self.model_card['training_parameters']['scheduler'] = scheduler
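
    # Usage sketch (file names are placeholders): a plain save() stores the
    # weights and model card only, while checkpoint=True additionally stores
    # the optimizer and scheduler state dictionaries for resuming training:
    #
    #     model.save("final-model.pt")
    #     model.save("checkpoint.pt", checkpoint=True)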

    @classmethod
    def load(cls, model: Union[str, Path]):
        """
        Loads the model from the given file.
        :param model: the model file
        :return: the loaded model
        """
        model_file = cls._fetch_model(str(model))

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore")
            # load_big_file is a workaround by https://github.com/highway11git to load models on some Mac/Windows setups
            # see https://github.com/zalandoresearch/flair/issues/351
            f = file_utils.load_big_file(str(model_file))
            state = torch.load(f, map_location='cpu')

        model = cls._init_model_with_state_dict(state)

        if 'model_card' in state:
            model.model_card = state['model_card']

        model.eval()
        model.to(flair.device)

        return model
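
    # Usage sketch (path and subclass are placeholders): load() takes a local
    # file path, or any identifier that a subclass's _fetch_model() can
    # resolve to one (e.g. a named model, for subclasses that override it):
    #
    #     model = MyTagger.load("final-model.pt")
    #     model.print_model_card()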

    def print_model_card(self):
        if hasattr(self, 'model_card'):
            param_out = "\n------------------------------------\n"
            param_out += "--------- Flair Model Card ---------\n"
            param_out += "------------------------------------\n"
            param_out += "- this Flair model was trained with:\n"
            param_out += f"-- Flair version {self.model_card['flair_version']}\n"
            param_out += f"-- PyTorch version {self.model_card['pytorch_version']}\n"
            if 'transformers_version' in self.model_card:
                param_out += f"-- Transformers version {self.model_card['transformers_version']}\n"
            param_out += "------------------------------------\n"

            param_out += "------- Training Parameters: -------\n"
            param_out += "------------------------------------\n"
            training_params = '\n'.join(f'-- {param} = {self.model_card["training_parameters"][param]}'
                                        for param in self.model_card['training_parameters'])
            param_out += training_params + "\n"
            param_out += "------------------------------------\n"

            log.info(param_out)
        else:
            log.info(
                "This model has no model card (likely because it is not yet trained or was trained with Flair version < 0.9.1)")


class Classifier(Model):
    """Abstract base class for all Flair models that do classification, both single- and multi-label.
    It inherits from flair.nn.Model and adds a unified evaluate() function so that all classification models
    use the same evaluation routines and compute the same numbers.
    Currently, the SequenceTagger implements this class directly, while all other classifiers in Flair
    implement the DefaultClassifier base class, which implements Classifier."""

    def evaluate(
            self,
            data_points: Union[List[DataPoint], Dataset],
            gold_label_type: str,
            out_path: Union[str, Path] = None,
            embedding_storage_mode: str = "none",
            mini_batch_size: int = 32,
            num_workers: int = 8,
            main_evaluation_metric: Tuple[str, str] = ("micro avg", "f1-score"),
            exclude_labels: List[str] = [],
            gold_label_dictionary: Optional[Dictionary] = None,
    ) -> Result:
        import numpy as np
        import sklearn.metrics

        # read Dataset into data loader (if list of sentences passed, make Dataset first)
        if not isinstance(data_points, Dataset):
            data_points = SentenceDataset(data_points)
        data_loader = DataLoader(data_points, batch_size=mini_batch_size, num_workers=num_workers)

        with torch.no_grad():

            # loss calculation
            eval_loss = 0
            average_over = 0

            # variables for printing
            lines: List[str] = []
            is_word_level = False

            # variables for computing scores
            all_spans: List[str] = []
            all_true_values = {}
            all_predicted_values = {}

            sentence_id = 0
            for batch in data_loader:

                # remove any previously predicted labels
                for datapoint in batch:
                    datapoint.remove_labels('predicted')

                # predict for batch
                loss_and_count = self.predict(batch,
                                              embedding_storage_mode=embedding_storage_mode,
                                              mini_batch_size=mini_batch_size,
                                              label_name='predicted',
                                              return_loss=True)

                if isinstance(loss_and_count, tuple):
                    average_over += loss_and_count[1]
                    eval_loss += loss_and_count[0]
                else:
                    eval_loss += loss_and_count

                # get the gold labels
                for datapoint in batch:

                    for gold_label in datapoint.get_labels(gold_label_type):
                        representation = str(sentence_id) + ': ' + gold_label.identifier

                        value = gold_label.value
                        if gold_label_dictionary and gold_label_dictionary.get_idx_for_item(value) == 0:
                            value = '<unk>'

                        if representation not in all_true_values:
                            all_true_values[representation] = [value]
                        else:
                            all_true_values[representation].append(value)

                        if representation not in all_spans:
                            all_spans.append(representation)

                        # span-level gold labels mean the printout is word-level
                        if type(gold_label) == SpanLabel:
                            is_word_level = True

                    for predicted_span in datapoint.get_labels("predicted"):
                        representation = str(sentence_id) + ': ' + predicted_span.identifier

                        # add to all_predicted_values
                        if representation not in all_predicted_values:
                            all_predicted_values[representation] = [predicted_span.value]
                        else:
                            all_predicted_values[representation].append(predicted_span.value)

                        if representation not in all_spans:
                            all_spans.append(representation)

                    sentence_id += 1

                store_embeddings(batch, embedding_storage_mode)

                # make printout lines
                if out_path:
                    for datapoint in batch:

                        # if the model is span-level, transfer to word-level annotations for printout
                        if is_word_level:

                            # all labels default to "O"
                            for token in datapoint:
                                token.set_label("gold_bio", "O")
                                token.set_label("predicted_bio", "O")

                            # set gold token-level
                            for gold_label in datapoint.get_labels(gold_label_type):
                                gold_label: SpanLabel = gold_label
                                prefix = "B-"
                                for token in gold_label.span:
                                    token.set_label("gold_bio", prefix + gold_label.value)
                                    prefix = "I-"

                            # set predicted token-level
                            for predicted_label in datapoint.get_labels("predicted"):
                                predicted_label: SpanLabel = predicted_label
                                prefix = "B-"
                                for token in predicted_label.span:
                                    token.set_label("predicted_bio", prefix + predicted_label.value)
                                    prefix = "I-"

                            # now print labels in CoNLL format
                            for token in datapoint:
                                eval_line = f"{token.text} " \
                                            f"{token.get_tag('gold_bio').value} " \
                                            f"{token.get_tag('predicted_bio').value}\n"
                                lines.append(eval_line)
                            lines.append("\n")
                        else:
                            # check if there is a label mismatch
                            g = [label.identifier + label.value for label in datapoint.get_labels(gold_label_type)]
                            p = [label.identifier + label.value for label in datapoint.get_labels('predicted')]
                            g.sort()
                            p.sort()
                            correct_string = " -> MISMATCH!\n" if g != p else ""
                            # print info
                            eval_line = f"{datapoint.to_original_text()}\n" \
                                        f" - Gold: {datapoint.get_labels(gold_label_type)}\n" \
                                        f" - Pred: {datapoint.get_labels('predicted')}\n{correct_string}\n"
                            lines.append(eval_line)

            # write all_predicted_values to out_file if set
            if out_path:
                with open(Path(out_path), "w", encoding="utf-8") as outfile:
                    outfile.write("".join(lines))

            # make the evaluation dictionary
            evaluation_label_dictionary = Dictionary(add_unk=False)
            evaluation_label_dictionary.add_item("O")
            for true_values in all_true_values.values():
                for label in true_values:
                    evaluation_label_dictionary.add_item(label)
            for predicted_values in all_predicted_values.values():
                for label in predicted_values:
                    evaluation_label_dictionary.add_item(label)

            # finally, compute numbers
            y_true = []
            y_pred = []

            for span in all_spans:

                true_values = all_true_values[span] if span in all_true_values else ['O']
                predicted_values = all_predicted_values[span] if span in all_predicted_values else ['O']

                y_true_instance = np.zeros(len(evaluation_label_dictionary), dtype=int)
                for true_value in true_values:
                    y_true_instance[evaluation_label_dictionary.get_idx_for_item(true_value)] = 1
                y_true.append(y_true_instance.tolist())

                y_pred_instance = np.zeros(len(evaluation_label_dictionary), dtype=int)
                for predicted_value in predicted_values:
                    y_pred_instance[evaluation_label_dictionary.get_idx_for_item(predicted_value)] = 1
                y_pred.append(y_pred_instance.tolist())
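
            # Worked example with hypothetical values: given the dictionary
            # ['O', 'PER', 'LOC'], a span whose gold label is PER but which
            # was predicted as LOC contributes
            #     y_true row:  [0, 1, 0]
            #     y_pred row:  [0, 0, 1]
            # while a span with a gold label but no prediction falls back to
            # the 'O' row [1, 0, 0] in y_pred. Encoding every span as a
            # multi-hot row lets the same sklearn call score single- and
            # multi-label setups alike.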

            # now, calculate evaluation numbers
            target_names = []
            labels = []

            counter = Counter()
            counter.update(list(itertools.chain.from_iterable(all_true_values.values())))
            counter.update(list(itertools.chain.from_iterable(all_predicted_values.values())))

            for label_name, count in counter.most_common():
                if label_name == 'O':
                    continue
                if label_name in exclude_labels:
                    continue
                target_names.append(label_name)
                labels.append(evaluation_label_dictionary.get_idx_for_item(label_name))

            # there is at least one gold label or one prediction (default)
            if len(all_true_values) + len(all_predicted_values) > 1:
                classification_report = sklearn.metrics.classification_report(
                    y_true, y_pred, digits=4, target_names=target_names, zero_division=0, labels=labels,
                )

                classification_report_dict = sklearn.metrics.classification_report(
                    y_true, y_pred, target_names=target_names, zero_division=0, output_dict=True, labels=labels,
                )

                accuracy_score = round(sklearn.metrics.accuracy_score(y_true, y_pred), 4)

                precision_score = round(classification_report_dict["micro avg"]["precision"], 4)
                recall_score = round(classification_report_dict["micro avg"]["recall"], 4)
                micro_f_score = round(classification_report_dict["micro avg"]["f1-score"], 4)
                macro_f_score = round(classification_report_dict["macro avg"]["f1-score"], 4)

                main_score = classification_report_dict[main_evaluation_metric[0]][main_evaluation_metric[1]]

            else:
                # issue an error and default all evaluation numbers to 0.
                log.error(
                    "ACHTUNG! No gold labels and no predictions found! Could be an error in your corpus or in how you "
                    "initialize the trainer!")
                accuracy_score = precision_score = recall_score = micro_f_score = macro_f_score = main_score = 0.
                classification_report = ""
                classification_report_dict = {}

            detailed_result = (
                "\nResults:"
                f"\n- F-score (micro) {micro_f_score}"
                f"\n- F-score (macro) {macro_f_score}"
                f"\n- Accuracy {accuracy_score}"
                "\n\nBy class:\n" + classification_report
            )

            # line for log file
            log_header = "PRECISION\tRECALL\tF1\tACCURACY"
            log_line = f"{precision_score}\t" f"{recall_score}\t" f"{micro_f_score}\t" f"{accuracy_score}"

        if average_over > 0:
            eval_loss /= average_over

        result = Result(
            main_score=main_score,
            log_line=log_line,
            log_header=log_header,
            detailed_results=detailed_result,
            classification_report=classification_report_dict,
            loss=eval_loss
        )

        return result


class DefaultClassifier(Classifier):
    """Default base class for all Flair models that do classification, both single- and multi-label.
    It inherits from flair.nn.Classifier and thus from flair.nn.Model. All features shared by all classifiers
    are implemented here, including the loss calculation and the predict() method.
    Currently, the TextClassifier, RelationExtractor, TextPairClassifier and SimpleSequenceTagger implement
    this class. You only need to implement the forward_pass() method to implement this base class.
    """

    def forward_pass(self,
                     sentences: Union[List[DataPoint], DataPoint],
                     return_label_candidates: bool = False,
                     ):
        """This method does a forward pass through the model given a list of data points as input.
        Returns the tuple (scores, labels) if return_label_candidates = False, where scores is a tensor of logits
        produced by the decoder and labels are the string labels for each data point.
        Returns the tuple (scores, labels, data_points, candidate_labels) if return_label_candidates = True,
        where data_points are the data points to which labels are added (commonly either Sentence or Token objects)
        and candidate_labels are empty Label objects for each prediction (depending on the task: Label,
        SpanLabel or RelationLabel)."""
        raise NotImplementedError
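
    # A minimal sketch of a conforming implementation (hypothetical
    # single-label text classifier; `self.embeddings` and `self.decoder` are
    # assumed to be created in the subclass's __init__, and Label would come
    # from flair.data):
    #
    #     def forward_pass(self, sentences, return_label_candidates=False):
    #         self.embeddings.embed(sentences)
    #         embedded = torch.stack([s.embedding for s in sentences])
    #         scores = self.decoder(embedded)  # logits, one row per sentence
    #         labels = [[label.value for label in s.get_labels(self.label_type)]
    #                   for s in sentences]
    #         if return_label_candidates:
    #             label_candidates = [Label(value=None) for _ in sentences]
    #             return scores, labels, sentences, label_candidates
    #         return scores, labels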

    def __init__(self,
                 label_dictionary: Dictionary,
                 multi_label: bool = False,
                 multi_label_threshold: float = 0.5,
                 loss_weights: Dict[str, float] = None,
                 ):

        super().__init__()

        # initialize the label dictionary
        self.label_dictionary: Dictionary = label_dictionary

        # set up multi-label logic
        self.multi_label = multi_label
        self.multi_label_threshold = multi_label_threshold

        # loss weights and loss function
        self.weight_dict = loss_weights
        # initialize the weight tensor
        if loss_weights is not None:
            n_classes = len(self.label_dictionary)
            weight_list = [1.0 for i in range(n_classes)]
            for i, tag in enumerate(self.label_dictionary.get_items()):
                if tag in loss_weights.keys():
                    weight_list[i] = loss_weights[tag]
            self.loss_weights = torch.FloatTensor(weight_list).to(flair.device)
        else:
            self.loss_weights = None

        if self.multi_label:
            self.loss_function = torch.nn.BCEWithLogitsLoss(weight=self.loss_weights)
        else:
            self.loss_function = torch.nn.CrossEntropyLoss(weight=self.loss_weights)
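
    # Worked example with hypothetical labels: for a label_dictionary with
    # items ['O', 'PER', 'LOC'] and
    #     loss_weights = {'PER': 10.0}
    # the weight tensor becomes [1.0, 10.0, 1.0], so mistakes on 'PER' weigh
    # ten times as much in the loss as mistakes on the other classes.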

    @property
    def multi_label_threshold(self):
        return self._multi_label_threshold

    @multi_label_threshold.setter
    def multi_label_threshold(self, x):  # setter method
        if type(x) is dict:
            if 'default' in x:
                self._multi_label_threshold = x
            else:
                raise Exception('multi_label_threshold dict should have a "default" key')
        else:
            self._multi_label_threshold = {'default': x}
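
    # Sketch of both accepted forms (label names are hypothetical): a float
    # sets one threshold for every label, while a dict sets per-label
    # thresholds and must carry a 'default' fallback:
    #
    #     model.multi_label_threshold = 0.5
    #     model.multi_label_threshold = {'default': 0.5, 'POSITIVE': 0.8}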

    def forward_loss(self, sentences: Union[List[DataPoint], DataPoint]) -> Tuple[torch.Tensor, int]:
        scores, labels = self.forward_pass(sentences)
        return self._calculate_loss(scores, labels)

    def _calculate_loss(self, scores, labels):

        # if there are no labels at all, return a zero loss over a count of 1
        if not any(labels):
            return torch.tensor(0., requires_grad=True, device=flair.device), 1

        if self.multi_label:
            labels = torch.tensor([[1 if l in all_labels_for_point else 0 for l in self.label_dictionary.get_items()]
                                   for all_labels_for_point in labels], dtype=torch.float, device=flair.device)

        else:
            labels = torch.tensor([self.label_dictionary.get_idx_for_item(label[0]) if len(label) > 0
                                   else self.label_dictionary.get_idx_for_item('O')
                                   for label in labels], dtype=torch.long, device=flair.device)

        return self.loss_function(scores, labels), len(labels)
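
    # Worked example with the hypothetical dictionary ['O', 'PER', 'LOC']:
    # in multi-label mode, labels = [['PER', 'LOC'], ['O']] becomes the float
    # matrix [[0, 1, 1], [1, 0, 0]] for BCEWithLogitsLoss; in single-label
    # mode, labels = [['PER'], []] becomes the index vector [1, 0] for
    # CrossEntropyLoss, with empty label lists mapped to the index of 'O'.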

    def predict(
            self,
            sentences: Union[List[Sentence], Sentence],
            mini_batch_size: int = 32,
            return_probabilities_for_all_classes: bool = False,
            verbose: bool = False,
            label_name: Optional[str] = None,
            return_loss=False,
            embedding_storage_mode="none",
    ):
        """
        Predicts the class labels for the given sentences. The labels are directly added to the sentences.
        :param sentences: list of sentences
        :param mini_batch_size: mini batch size to use
        :param return_probabilities_for_all_classes: return probabilities for all classes instead of only the best predicted one
        :param verbose: set to True to display a progress bar
        :param return_loss: set to True to return loss
        :param label_name: set this to change the name of the label type that is predicted
        :param embedding_storage_mode: default is 'none' which is always best. Only set to 'cpu' or 'gpu' if
        you wish to not only predict, but also keep the generated embeddings in CPU or GPU memory respectively.
        """
        if label_name is None:
            label_name = self.label_type if self.label_type is not None else "label"

        with torch.no_grad():
            if not sentences:
                return sentences

            if isinstance(sentences, DataPoint):
                sentences = [sentences]

            # filter empty sentences
            if isinstance(sentences[0], DataPoint):
                sentences = [sentence for sentence in sentences if len(sentence) > 0]
            if len(sentences) == 0:
                return sentences

            # reverse sort all sequences by their length
            rev_order_len_index = sorted(range(len(sentences)), key=lambda k: len(sentences[k]), reverse=True)

            reordered_sentences: List[Union[DataPoint, str]] = [sentences[index] for index in rev_order_len_index]

            dataloader = DataLoader(dataset=SentenceDataset(reordered_sentences), batch_size=mini_batch_size)
            # progress bar for verbosity
            if verbose:
                dataloader = tqdm(dataloader)

            overall_loss = 0
            batch_no = 0
            label_count = 0
            for batch in dataloader:

                batch_no += 1

                if verbose:
                    dataloader.set_description(f"Inferencing on batch {batch_no}")

                # stop if all sentences are empty
                if not batch:
                    continue

                scores, gold_labels, data_points, label_candidates = self.forward_pass(batch,
                                                                                       return_label_candidates=True)
                # remove previously predicted labels of this type
                for sentence in data_points:
                    sentence.remove_labels(label_name)

                if return_loss:
                    overall_loss += self._calculate_loss(scores, gold_labels)[0]
                    label_count += len(label_candidates)

                # if anything could possibly be predicted
                if len(label_candidates) > 0:
                    if self.multi_label:
                        sigmoided = torch.sigmoid(scores)  # size: (n_sentences, n_classes)
                        n_labels = sigmoided.size(1)
                        for s_idx, (data_point, label_candidate) in enumerate(zip(data_points, label_candidates)):
                            for l_idx in range(n_labels):
                                label_value = self.label_dictionary.get_item_for_index(l_idx)
                                if label_value == 'O':
                                    continue
                                label_threshold = self._get_label_threshold(label_value)
                                label_score = sigmoided[s_idx, l_idx].item()
                                if label_score > label_threshold or return_probabilities_for_all_classes:
                                    label = label_candidate.spawn(value=label_value, score=label_score)
                                    data_point.add_complex_label(label_name, label)
                    else:
                        softmax = torch.nn.functional.softmax(scores, dim=-1)

                        if return_probabilities_for_all_classes:
                            n_labels = softmax.size(1)
                            for s_idx, (data_point, label_candidate) in enumerate(zip(data_points, label_candidates)):
                                for l_idx in range(n_labels):
                                    label_value = self.label_dictionary.get_item_for_index(l_idx)
                                    if label_value == 'O':
                                        continue
                                    label_score = softmax[s_idx, l_idx].item()
                                    label = label_candidate.spawn(value=label_value, score=label_score)
                                    data_point.add_complex_label(label_name, label)
                        else:
                            conf, idx = torch.max(softmax, dim=-1)
                            for data_point, label_candidate, c, i in zip(data_points, label_candidates, conf, idx):
                                label_value = self.label_dictionary.get_item_for_index(i.item())
                                if label_value == 'O':
                                    continue
                                label = label_candidate.spawn(value=label_value, score=c.item())
                                data_point.add_complex_label(label_name, label)

                store_embeddings(batch, storage_mode=embedding_storage_mode)

            if return_loss:
                return overall_loss, label_count
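
    # Usage sketch (a trained subclass instance is assumed):
    #
    #     sentence = Sentence("George Washington went to Washington.")
    #     model.predict(sentence)
    #     for label in sentence.get_labels(model.label_type):
    #         print(label)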

    def _get_label_threshold(self, label_value):
        label_threshold = self.multi_label_threshold['default']
        if label_value in self.multi_label_threshold:
            label_threshold = self.multi_label_threshold[label_value]

        return label_threshold

    def __str__(self):
        # print the torch.nn.Module representation (bypassing flair.nn.Model),
        # with the loss weights appended before the closing parenthesis
        return super(flair.nn.Model, self).__str__().rstrip(')') + \
               f'  (weights): {self.weight_dict}\n' + \
               f'  (weight_tensor) {self.loss_weights}\n)'