Coverage for flair/flair/trainers/trainer.py: 8%
1 import copy
2 import datetime
3 import inspect
4 import logging
5 import os
6 import sys
7 import time
8 import warnings
9 from inspect import signature
10 from pathlib import Path
11 from typing import Union, Tuple, Optional
13 import torch
14 from torch.optim.sgd import SGD
15 from torch.utils.data.dataset import ConcatDataset
17 from flair.nn import Model
19 try:
20 from apex import amp
21 except ImportError:
22 amp = None
24 import flair
25 import flair.nn
26 from flair.data import MultiCorpus, Corpus, Dictionary
27 from flair.datasets import DataLoader
28 from flair.optim import ExpAnnealLR, LinearSchedulerWithWarmup
29 from flair.training_utils import (
30 init_output_file,
31 WeightExtractor,
32 log_line,
33 add_file_handler,
34 Result,
35 store_embeddings,
36 AnnealOnPlateau,
37 )
38 from torch.optim.lr_scheduler import OneCycleLR
39 from flair.models import SequenceTagger
40 import random
42 log = logging.getLogger("flair")
45 class ModelTrainer:
46 def __init__(
47 self,
48 model: flair.nn.Model,
49 corpus: Corpus,
50 ):
51 """
52 Initialize a model trainer
53 :param model: The model that you want to train. The model should inherit from flair.nn.Model
54 :param corpus: The dataset used to train the model, should be of type Corpus
55 """
56 self.model: flair.nn.Model = model
57 self.corpus: Corpus = corpus
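# A minimal construction sketch; the corpus, embeddings, label type and variable names below are
# illustrative assumptions, not taken from this file:
#
#   from flair.datasets import UD_ENGLISH
#   from flair.embeddings import WordEmbeddings
#   from flair.models import SequenceTagger
#   from flair.trainers import ModelTrainer
#
#   corpus = UD_ENGLISH()
#   tagger = SequenceTagger(hidden_size=256,
#                           embeddings=WordEmbeddings("glove"),
#                           tag_dictionary=corpus.make_label_dictionary(label_type="upos"),
#                           tag_type="upos")
#   trainer = ModelTrainer(tagger, corpus)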
59 @staticmethod
60 def check_for_and_delete_previous_best_models(base_path):
61 all_best_model_names = [filename for filename in os.listdir(base_path) if
62 filename.startswith("best-model")]
63 if len(all_best_model_names) != 0:
64 warnings.warn(
65 "There should be no best model saved at epoch 1 except there is a model from previous trainings"
66 " in your training folder. All previous best models will be deleted.")
67 for single_model in all_best_model_names:
68 previous_best_path = os.path.join(base_path, single_model)
69 if os.path.exists(previous_best_path):
70 os.remove(previous_best_path)
72 def train(
73 self,
74 base_path: Union[Path, str],
75 learning_rate: float = 0.1,
76 mini_batch_size: int = 32,
77 mini_batch_chunk_size: Optional[int] = None,
78 max_epochs: int = 100,
79 train_with_dev: bool = False,
80 train_with_test: bool = False,
81 monitor_train: bool = False,
82 monitor_test: bool = False,
83 main_evaluation_metric: Tuple[str, str] = ("micro avg", 'f1-score'),
84 scheduler=AnnealOnPlateau,
85 anneal_factor: float = 0.5,
86 patience: int = 3,
87 min_learning_rate: float = 0.0001,
88 initial_extra_patience: int = 0,
89 optimizer: torch.optim.Optimizer = SGD,
90 cycle_momentum: bool = False,
91 warmup_fraction: float = 0.1,
92 embeddings_storage_mode: str = "cpu",
93 checkpoint: bool = False,
94 save_final_model: bool = True,
95 anneal_with_restarts: bool = False,
96 anneal_with_prestarts: bool = False,
97 anneal_against_dev_loss: bool = False,
98 batch_growth_annealing: bool = False,
99 shuffle: bool = True,
100 param_selection_mode: bool = False,
101 write_weights: bool = False,
102 num_workers: int = 6,
103 sampler=None,
104 use_amp: bool = False,
105 amp_opt_level: str = "O1",
106 eval_on_train_fraction: float = 0.0,
107 eval_on_train_shuffle: bool = False,
108 save_model_each_k_epochs: int = 0,
109 tensorboard_comment: str = '',
110 use_swa: bool = False,
111 use_final_model_for_eval: bool = False,
112 gold_label_dictionary_for_eval: Optional[Dictionary] = None,
113 create_file_logs: bool = True,
114 create_loss_file: bool = True,
115 epoch: int = 0,
116 use_tensorboard: bool = False,
117 tensorboard_log_dir=None,
118 metrics_for_tensorboard=[],
119 optimizer_state_dict: Optional = None,
120 scheduler_state_dict: Optional = None,
121 save_optimizer_state: bool = False,
122 **kwargs,
123 ) -> dict:
124 """
125 Trains any class that implements the flair.nn.Model interface.
126 :param base_path: Main path to which all output during training is logged and models are saved
127 :param learning_rate: Initial learning rate (or max, if scheduler is OneCycleLR)
128 :param mini_batch_size: Size of mini-batches during training
129 :param mini_batch_chunk_size: If mini-batches are larger than this number, they get broken down into chunks of this size for processing purposes
130 :param max_epochs: Maximum number of epochs to train. Terminates training if this number is surpassed.
131 :param scheduler: The learning rate scheduler to use
132 :param checkpoint: If True, a full checkpoint is saved at end of each epoch
133 :param cycle_momentum: If scheduler is OneCycleLR, whether the scheduler should cycle also the momentum
134 :param anneal_factor: The factor by which the learning rate is annealed
135 :param patience: The number of epochs with no improvement that the Trainer waits
136 before annealing the learning rate
137 :param min_learning_rate: If the learning rate falls below this threshold, training terminates
138 :param warmup_fraction: Fraction of warmup steps if the scheduler is LinearSchedulerWithWarmup
139 :param train_with_dev: If True, the data from dev split is added to the training data
140 :param train_with_test: If True, the data from test split is added to the training data
141 :param monitor_train: If True, training data is evaluated at end of each epoch
142 :param monitor_test: If True, test data is evaluated at end of each epoch
143 :param embeddings_storage_mode: One of 'none' (all embeddings are deleted and freshly recomputed),
144 'cpu' (embeddings are stored on CPU) or 'gpu' (embeddings are stored on GPU)
145 :param save_final_model: If True, final model is saved
146 :param anneal_with_restarts: If True, the last best model is restored when annealing the learning rate
147 :param shuffle: If True, data is shuffled during training
148 :param param_selection_mode: If True, testing is performed against dev data. Use this mode when doing
149 parameter selection.
150 :param num_workers: Number of workers in your data loader.
151 :param sampler: You can pass a data sampler here for special sampling of data.
152 :param eval_on_train_fraction: the fraction of the training data on which to evaluate;
153 if 0. no evaluation is performed on a fraction of the training data,
154 if 'dev' the fraction size is taken from the dev set size
155 :param eval_on_train_shuffle: if True, the train data fraction is randomly re-sampled at the beginning
156 of each epoch; otherwise it is determined once at the start of training and kept fixed
157 :param save_model_each_k_epochs: Every k epochs, a model state will be written out. If set to '5', a model will
158 be saved every 5 epochs. Default is 0, which means no model saving.
159 :param main_evaluation_metric: Type of metric to use for best model tracking and learning rate scheduling (if dev data is available, otherwise loss will be used), currently only applicable for text_classification_model
160 :param tensorboard_comment: Comment to use for tensorboard logging
161 :param create_file_logs: If True, the logs will also be stored in a file 'training.log' in the model folder
162 :param create_loss_file: If True, the loss will be written to a file 'loss.tsv' in the model folder
163 :param optimizer: The optimizer to use (typically SGD or Adam)
164 :param epoch: The starting epoch (normally 0, but could be higher if you continue training a previously trained model)
165 :param use_tensorboard: If True, writes out tensorboard information
166 :param tensorboard_log_dir: Directory into which tensorboard log files will be written
167 :param metrics_for_tensorboard: List of tuples that specify which metrics (in addition to the main_score) shall be plotted in tensorboard, could be [("macro avg", 'f1-score'), ("macro avg", 'precision')] for example
168 :param kwargs: Other arguments for the Optimizer
169 :return: dictionary with keys 'test_score', 'dev_score_history', 'train_loss_history' and 'dev_loss_history'
170 """
172 # create a model card for this model with Flair and PyTorch version
173 model_card = {'flair_version': flair.__version__, 'pytorch_version': torch.__version__}
175 # also record Transformers version if library is loaded
176 try:
177 import transformers
178 model_card['transformers_version'] = transformers.__version__
179 except:
180 pass
182 # remember all parameters used in train() call
183 local_variables = locals()
184 training_parameters = {}
185 for parameter in signature(self.train).parameters:
186 training_parameters[parameter] = local_variables[parameter]
187 model_card['training_parameters'] = training_parameters
189 # add model card to model
190 self.model.model_card = model_card
192 if use_tensorboard:
193 try:
194 from torch.utils.tensorboard import SummaryWriter
196 if tensorboard_log_dir is not None and not os.path.exists(tensorboard_log_dir):
197 os.mkdir(tensorboard_log_dir)
198 writer = SummaryWriter(log_dir=tensorboard_log_dir, comment=tensorboard_comment)
199 log.info(f"tensorboard logging path is {tensorboard_log_dir}")
201 except:
202 log_line(log)
203 log.warning("ATTENTION! PyTorch >= 1.1.0 and pillow are required for TensorBoard support!")
204 log_line(log)
205 use_tensorboard = False
206 pass
208 if use_amp:
209 if sys.version_info < (3, 0):
210 raise RuntimeError("Apex currently only supports Python 3. Aborting.")
211 if amp is None:
212 raise RuntimeError(
213 "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
214 "to enable mixed-precision training."
215 )
217 if mini_batch_chunk_size is None:
218 mini_batch_chunk_size = mini_batch_size
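# if the initial learning rate is already below the stopping threshold, lower the threshold so
# that training does not terminate immediately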
219 if learning_rate < min_learning_rate:
220 min_learning_rate = learning_rate / 10
222 initial_learning_rate = learning_rate
224 # cast string to Path
225 if type(base_path) is str:
226 base_path = Path(base_path)
227 base_path.mkdir(exist_ok=True, parents=True)
229 if create_file_logs:
230 log_handler = add_file_handler(log, base_path / "training.log")
231 else:
232 log_handler = None
234 log_line(log)
235 log.info(f'Model: "{self.model}"')
236 log_line(log)
237 log.info(f'Corpus: "{self.corpus}"')
238 log_line(log)
239 log.info("Parameters:")
240 log.info(f' - learning_rate: "{learning_rate}"')
241 log.info(f' - mini_batch_size: "{mini_batch_size}"')
242 log.info(f' - patience: "{patience}"')
243 log.info(f' - anneal_factor: "{anneal_factor}"')
244 log.info(f' - max_epochs: "{max_epochs}"')
245 log.info(f' - shuffle: "{shuffle}"')
246 log.info(f' - train_with_dev: "{train_with_dev}"')
247 log.info(f' - batch_growth_annealing: "{batch_growth_annealing}"')
248 log_line(log)
249 log.info(f'Model training base path: "{base_path}"')
250 log_line(log)
251 log.info(f"Device: {flair.device}")
252 log_line(log)
253 log.info(f"Embeddings storage mode: {embeddings_storage_mode}")
254 if isinstance(self.model, SequenceTagger) and self.model.weight_dict and self.model.use_crf:
255 log_line(log)
256 log.warning(f'WARNING: Specified class weights will not take effect when using CRF')
258 # check for previously saved best models in the current training folder and delete them
259 self.check_for_and_delete_previous_best_models(base_path)
261 # determine what splits (train, dev, test) to evaluate and log
262 log_train = True if monitor_train else False
263 log_test = True if (not param_selection_mode and self.corpus.test and monitor_test) else False
264 log_dev = False if train_with_dev or not self.corpus.dev else True
265 log_train_part = True if (eval_on_train_fraction == "dev" or eval_on_train_fraction > 0.0) else False
267 if log_train_part:
268 train_part_size = len(self.corpus.dev) if eval_on_train_fraction == "dev" \
269 else int(len(self.corpus.train) * eval_on_train_fraction)
271 assert train_part_size > 0
272 if not eval_on_train_shuffle:
273 train_part_indices = list(range(train_part_size))
274 train_part = torch.utils.data.dataset.Subset(self.corpus.train, train_part_indices)
276 # prepare loss logging file and set up header
277 loss_txt = init_output_file(base_path, "loss.tsv") if create_loss_file else None
279 weight_extractor = WeightExtractor(base_path)
281 # if optimizer class is passed, instantiate:
282 if inspect.isclass(optimizer):
283 optimizer: torch.optim.Optimizer = optimizer(self.model.parameters(), lr=learning_rate, **kwargs)
285 if use_swa:
286 import torchcontrib
287 optimizer = torchcontrib.optim.SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=learning_rate)
289 if use_amp:
290 self.model, optimizer = amp.initialize(
291 self.model, optimizer, opt_level=amp_opt_level
292 )
294 # load existing optimizer state dictionary if it exists
295 if optimizer_state_dict:
296 optimizer.load_state_dict(optimizer_state_dict)
298 # minimize training loss if training with dev data, else maximize dev score
299 anneal_mode = "min" if train_with_dev or anneal_against_dev_loss else "max"
300 best_validation_score = 100000000000 if train_with_dev or anneal_against_dev_loss else 0.
302 dataset_size = len(self.corpus.train)
303 if train_with_dev:
304 dataset_size += len(self.corpus.dev)
306 # if scheduler is passed as a class, instantiate
307 if inspect.isclass(scheduler):
308 if scheduler == OneCycleLR:
309 scheduler = OneCycleLR(optimizer,
310 max_lr=learning_rate,
311 steps_per_epoch=dataset_size // mini_batch_size + 1,
312 epochs=max_epochs - epoch,
313 # if we load a checkpoint, we have already trained for epoch
314 pct_start=0.0,
315 cycle_momentum=cycle_momentum)
316 elif scheduler == LinearSchedulerWithWarmup:
317 steps_per_epoch = (dataset_size + mini_batch_size - 1) / mini_batch_size
318 num_train_steps = int(steps_per_epoch * max_epochs)
319 num_warmup_steps = int(num_train_steps * warmup_fraction)
321 scheduler = LinearSchedulerWithWarmup(optimizer,
322 num_train_steps=num_train_steps,
323 num_warmup_steps=num_warmup_steps)
324 else:
325 scheduler = scheduler(
326 optimizer,
327 factor=anneal_factor,
328 patience=patience,
329 initial_extra_patience=initial_extra_patience,
330 mode=anneal_mode,
331 verbose=True,
332 )
334 # load existing scheduler state dictionary if it exists
335 if scheduler_state_dict:
336 scheduler.load_state_dict(scheduler_state_dict)
338 # update optimizer and scheduler in model card
339 model_card['training_parameters']['optimizer'] = optimizer
340 model_card['training_parameters']['scheduler'] = scheduler
342 if isinstance(scheduler, OneCycleLR) and batch_growth_annealing:
343 raise ValueError("Batch growth with OneCycle policy is not implemented.")
345 train_data = self.corpus.train
347 # if training also uses dev/train data, include in training set
348 if train_with_dev or train_with_test:
350 parts = [self.corpus.train]
351 if train_with_dev: parts.append(self.corpus.dev)
352 if train_with_test: parts.append(self.corpus.test)
354 train_data = ConcatDataset(parts)
356 # initialize sampler if provided
357 if sampler is not None:
358 # init with default values if only class is provided
359 if inspect.isclass(sampler):
360 sampler = sampler()
361 # set dataset to sample from
362 sampler.set_dataset(train_data)
363 shuffle = False
365 dev_score_history = []
366 dev_loss_history = []
367 train_loss_history = []
369 micro_batch_size = mini_batch_chunk_size
371 # At any point you can hit Ctrl + C to break out of training early.
372 try:
373 previous_learning_rate = learning_rate
374 momentum = 0
375 for group in optimizer.param_groups:
376 if "momentum" in group:
377 momentum = group["momentum"]
379 for epoch in range(epoch + 1, max_epochs + 1):
380 log_line(log)
382 # update epoch in model card
383 self.model.model_card['training_parameters']['epoch'] = epoch
385 if anneal_with_prestarts:
386 last_epoch_model_state_dict = copy.deepcopy(self.model.state_dict())
388 if eval_on_train_shuffle:
389 train_part_indices = list(range(len(self.corpus.train)))
390 random.shuffle(train_part_indices)
391 train_part_indices = train_part_indices[:train_part_size]
392 train_part = torch.utils.data.dataset.Subset(self.corpus.train, train_part_indices)
394 # get new learning rate
395 for group in optimizer.param_groups:
396 learning_rate = group["lr"]
398 if learning_rate != previous_learning_rate and batch_growth_annealing:
399 mini_batch_size *= 2
401 # reload last best model if annealing with restarts is enabled
402 if (
403 (anneal_with_restarts or anneal_with_prestarts)
404 and learning_rate != previous_learning_rate
405 and os.path.exists(base_path / "best-model.pt")
406 ):
407 if anneal_with_restarts:
408 log.info("resetting to best model")
409 self.model.load_state_dict(
410 self.model.load(base_path / "best-model.pt").state_dict()
411 )
412 if anneal_with_prestarts:
413 log.info("resetting to pre-best model")
414 self.model.load_state_dict(
415 self.model.load(base_path / "pre-best-model.pt").state_dict()
416 )
418 previous_learning_rate = learning_rate
419 if use_tensorboard:
420 writer.add_scalar("learning_rate", learning_rate, epoch)
422 # stop training if learning rate becomes too small
423 if ((not isinstance(scheduler, (OneCycleLR, LinearSchedulerWithWarmup)) and
424 learning_rate < min_learning_rate)):
425 log_line(log)
426 log.info("learning rate too small - quitting training!")
427 log_line(log)
428 break
430 batch_loader = DataLoader(
431 train_data,
432 batch_size=mini_batch_size,
433 shuffle=shuffle if epoch > 1 else False, # never shuffle the first epoch
434 num_workers=num_workers,
435 sampler=sampler,
436 )
438 self.model.train()
440 train_loss: float = 0
442 seen_batches = 0
443 total_number_of_batches = len(batch_loader)
445 modulo = max(1, int(total_number_of_batches / 10))
447 # process mini-batches
448 batch_time = 0
449 average_over = 0
450 for batch_no, batch in enumerate(batch_loader):
452 start_time = time.time()
454 # zero the gradients on the model and optimizer
455 self.model.zero_grad()
456 optimizer.zero_grad()
458 # if necessary, split the batch into smaller chunks for gradient accumulation
459 batch_steps = [batch]
460 if len(batch) > micro_batch_size:
461 batch_steps = [batch[x: x + micro_batch_size] for x in range(0, len(batch), micro_batch_size)]
463 # forward and backward for batch
464 for batch_step in batch_steps:
466 # forward pass
467 loss = self.model.forward_loss(batch_step)
469 if isinstance(loss, Tuple):
470 average_over += loss[1]
471 loss = loss[0]
473 # Backward
474 if use_amp:
475 with amp.scale_loss(loss, optimizer) as scaled_loss:
476 scaled_loss.backward()
477 else:
478 loss.backward()
479 train_loss += loss.item()
481 # do the optimizer step
482 torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
483 optimizer.step()
485 # do the scheduler step if one-cycle or linear decay
486 if isinstance(scheduler, (OneCycleLR, LinearSchedulerWithWarmup)):
487 scheduler.step()
488 # get new learning rate
489 for group in optimizer.param_groups:
490 learning_rate = group["lr"]
491 if "momentum" in group:
492 momentum = group["momentum"]
493 if "betas" in group:
494 momentum, _ = group["betas"]
496 seen_batches += 1
498 # depending on memory mode, embeddings are moved to CPU, GPU or deleted
499 store_embeddings(batch, embeddings_storage_mode)
501 batch_time += time.time() - start_time
502 if seen_batches % modulo == 0:
503 momentum_info = f' - momentum: {momentum:.4f}' if cycle_momentum else ''
504 intermittent_loss = train_loss / average_over if average_over > 0 else train_loss / seen_batches
505 log.info(
506 f"epoch {epoch} - iter {seen_batches}/{total_number_of_batches} - loss "
507 f"{intermittent_loss:.8f} - samples/sec: {mini_batch_size * modulo / batch_time:.2f}"
508 f" - lr: {learning_rate:.6f}{momentum_info}"
509 )
510 batch_time = 0
511 iteration = epoch * total_number_of_batches + batch_no
512 if not param_selection_mode and write_weights:
513 weight_extractor.extract_weights(self.model.state_dict(), iteration)
515 if average_over != 0:
516 train_loss /= average_over
518 self.model.eval()
520 log_line(log)
521 log.info(f"EPOCH {epoch} done: loss {train_loss:.4f} - lr {learning_rate:.7f}")
523 if use_tensorboard:
524 writer.add_scalar("train_loss", train_loss, epoch)
526 # evaluate on train / dev / test split depending on training settings
527 result_line: str = ""
529 if log_train:
530 train_eval_result = self.model.evaluate(
531 self.corpus.train,
532 gold_label_type=self.model.label_type,
533 mini_batch_size=mini_batch_chunk_size,
534 num_workers=num_workers,
535 embedding_storage_mode=embeddings_storage_mode,
536 main_evaluation_metric=main_evaluation_metric,
537 gold_label_dictionary=gold_label_dictionary_for_eval,
538 )
539 result_line += f"\t{train_eval_result.log_line}"
541 # depending on memory mode, embeddings are moved to CPU, GPU or deleted
542 store_embeddings(self.corpus.train, embeddings_storage_mode)
544 if log_train_part:
545 train_part_eval_result, train_part_loss = self.model.evaluate(
546 train_part,
547 gold_label_type=self.model.label_type,
548 mini_batch_size=mini_batch_chunk_size,
549 num_workers=num_workers,
550 embedding_storage_mode=embeddings_storage_mode,
551 main_evaluation_metric=main_evaluation_metric,
552 gold_label_dictionary=gold_label_dictionary_for_eval,
553 )
554 result_line += f"\t{train_part_loss}\t{train_part_eval_result.log_line}"
556 log.info(
557 f"TRAIN_SPLIT : loss {train_part_loss} - {main_evaluation_metric[1]} ({main_evaluation_metric[0]}) {round(train_part_eval_result.main_score, 4)}"
558 )
559 if use_tensorboard:
560 for (metric_class_avg_type, metric_type) in metrics_for_tensorboard:
561 writer.add_scalar(
562 f"train_{metric_class_avg_type}_{metric_type}",
563 train_part_eval_result.classification_report[metric_class_avg_type][metric_type], epoch
564 )
566 if log_dev:
567 dev_eval_result = self.model.evaluate(
568 self.corpus.dev,
569 gold_label_type=self.model.label_type,
570 mini_batch_size=mini_batch_chunk_size,
571 num_workers=num_workers,
572 out_path=base_path / "dev.tsv",
573 embedding_storage_mode=embeddings_storage_mode,
574 main_evaluation_metric=main_evaluation_metric,
575 gold_label_dictionary=gold_label_dictionary_for_eval,
576 )
577 result_line += f"\t{dev_eval_result.loss}\t{dev_eval_result.log_line}"
578 log.info(
579 f"DEV : loss {dev_eval_result.loss} - {main_evaluation_metric[1]} ({main_evaluation_metric[0]}) {round(dev_eval_result.main_score, 4)}"
580 )
581 # calculate scores using dev data if available
582 # append dev score to score history
583 dev_score_history.append(dev_eval_result.main_score)
584 dev_loss_history.append(dev_eval_result.loss)
586 dev_score = dev_eval_result.main_score
588 # depending on memory mode, embeddings are moved to CPU, GPU or deleted
589 store_embeddings(self.corpus.dev, embeddings_storage_mode)
591 if use_tensorboard:
592 writer.add_scalar("dev_loss", dev_eval_result.loss, epoch)
593 writer.add_scalar("dev_score", dev_eval_result.main_score, epoch)
594 for (metric_class_avg_type, metric_type) in metrics_for_tensorboard:
595 writer.add_scalar(
596 f"dev_{metric_class_avg_type}_{metric_type}",
597 dev_eval_result.classification_report[metric_class_avg_type][metric_type], epoch
598 )
600 if log_test:
601 test_eval_result = self.model.evaluate(
602 self.corpus.test,
603 gold_label_type=self.model.label_type,
604 mini_batch_size=mini_batch_chunk_size,
605 num_workers=num_workers,
606 out_path=base_path / "test.tsv",
607 embedding_storage_mode=embeddings_storage_mode,
608 main_evaluation_metric=main_evaluation_metric,
609 gold_label_dictionary=gold_label_dictionary_for_eval,
610 )
611 result_line += f"\t{test_eval_result.loss}\t{test_eval_result.log_line}"
612 log.info(
613 f"TEST : loss {test_eval_result.loss} - {main_evaluation_metric[1]} ({main_evaluation_metric[0]}) {round(test_eval_result.main_score, 4)}"
614 )
616 # depending on memory mode, embeddings are moved to CPU, GPU or deleted
617 store_embeddings(self.corpus.test, embeddings_storage_mode)
619 if use_tensorboard:
620 writer.add_scalar("test_loss", test_eval_result.loss, epoch)
621 writer.add_scalar("test_score", test_eval_result.main_score, epoch)
622 for (metric_class_avg_type, metric_type) in metrics_for_tensorboard:
623 writer.add_scalar(
624 f"test_{metric_class_avg_type}_{metric_type}",
625 test_eval_result.classification_report[metric_class_avg_type][metric_type], epoch
626 )
628 # determine if this is the best model or if we need to anneal
629 current_epoch_has_best_model_so_far = False
630 # default mode: anneal against dev score
631 if not train_with_dev and not anneal_against_dev_loss:
632 if dev_score > best_validation_score:
633 current_epoch_has_best_model_so_far = True
634 best_validation_score = dev_score
636 if isinstance(scheduler, AnnealOnPlateau):
637 scheduler.step(dev_score, dev_eval_result.loss)
639 # alternative: anneal against dev loss
640 if not train_with_dev and anneal_against_dev_loss:
641 if dev_eval_result.loss < best_validation_score:
642 current_epoch_has_best_model_so_far = True
643 best_validation_score = dev_eval_result.loss
645 if isinstance(scheduler, AnnealOnPlateau):
646 scheduler.step(dev_eval_result.loss)
648 # alternative: anneal against train loss
649 if train_with_dev:
650 if train_loss < best_validation_score:
651 current_epoch_has_best_model_so_far = True
652 best_validation_score = train_loss
654 if isinstance(scheduler, AnnealOnPlateau):
655 scheduler.step(train_loss)
657 train_loss_history.append(train_loss)
659 # determine bad epoch number
660 try:
661 bad_epochs = scheduler.num_bad_epochs
662 except:
663 bad_epochs = 0
664 for group in optimizer.param_groups:
665 new_learning_rate = group["lr"]
666 if new_learning_rate != previous_learning_rate:
667 bad_epochs = patience + 1
668 if previous_learning_rate == initial_learning_rate: bad_epochs += initial_extra_patience
670 # log bad epochs
671 log.info(f"BAD EPOCHS (no improvement): {bad_epochs}")
673 if create_loss_file:
674 # output log file
675 with open(loss_txt, "a") as f:
677 # make headers on first epoch
678 if epoch == 1:
679 f.write(f"EPOCH\tTIMESTAMP\tBAD_EPOCHS\tLEARNING_RATE\tTRAIN_LOSS")
681 if log_train:
682 f.write("\tTRAIN_" + "\tTRAIN_".join(train_eval_result.log_header.split("\t")))
684 if log_train_part:
685 f.write("\tTRAIN_PART_LOSS\tTRAIN_PART_" + "\tTRAIN_PART_".join(
686 train_part_eval_result.log_header.split("\t")))
688 if log_dev:
689 f.write("\tDEV_LOSS\tDEV_" + "\tDEV_".join(dev_eval_result.log_header.split("\t")))
691 if log_test:
692 f.write("\tTEST_LOSS\tTEST_" + "\tTEST_".join(test_eval_result.log_header.split("\t")))
694 f.write(
695 f"\n{epoch}\t{datetime.datetime.now():%H:%M:%S}\t{bad_epochs}\t{learning_rate:.4f}\t{train_loss}"
696 )
697 f.write(result_line)
699 # if checkpoint is enabled, save model at each epoch
700 if checkpoint and not param_selection_mode:
701 self.model.save(base_path / "checkpoint.pt", checkpoint=True)
703 # Check whether to save best model
704 if (
705 (not train_with_dev or anneal_with_restarts or anneal_with_prestarts)
706 and not param_selection_mode
707 and current_epoch_has_best_model_so_far
708 and not use_final_model_for_eval
709 ):
710 log.info("saving best model")
711 self.model.save(base_path / "best-model.pt", checkpoint=save_optimizer_state)
713 if anneal_with_prestarts:
714 current_state_dict = self.model.state_dict()
715 self.model.load_state_dict(last_epoch_model_state_dict)
716 self.model.save(base_path / "pre-best-model.pt")
717 self.model.load_state_dict(current_state_dict)
719 if save_model_each_k_epochs > 0 and not epoch % save_model_each_k_epochs:
720 log.info("saving model of current epoch")
721 model_name = "model_epoch_" + str(epoch) + ".pt"
722 self.model.save(base_path / model_name, checkpoint=save_optimizer_state)
724 if use_swa:
725 optimizer.swap_swa_sgd()
727 # if we do not use dev data for model selection, save final model
728 if save_final_model and not param_selection_mode:
729 self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)
731 except KeyboardInterrupt:
732 log_line(log)
733 log.info("Exiting from training early.")
735 if use_tensorboard:
736 writer.close()
738 if not param_selection_mode:
739 log.info("Saving model ...")
740 self.model.save(base_path / "final-model.pt", checkpoint=save_optimizer_state)
741 log.info("Done.")
743 # test best model if test data is present
744 if self.corpus.test and not train_with_test:
745 final_score = self.final_test(
746 base_path=base_path,
747 eval_mini_batch_size=mini_batch_chunk_size,
748 num_workers=num_workers,
749 main_evaluation_metric=main_evaluation_metric,
750 gold_label_dictionary_for_eval=gold_label_dictionary_for_eval,
751 )
752 else:
753 final_score = 0
754 log.info("Test data not provided, setting final score to 0")
756 if create_file_logs:
757 log_handler.close()
758 log.removeHandler(log_handler)
760 if use_tensorboard:
761 writer.close()
763 return {
764 "test_score": final_score,
765 "dev_score_history": dev_score_history,
766 "train_loss_history": train_loss_history,
767 "dev_loss_history": dev_loss_history,
768 }
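# A minimal training-call sketch, assuming a ModelTrainer instance named `trainer`; the base path
# and hyperparameters are illustrative. The returned dict has the keys assembled directly above:
#
#   result = trainer.train("resources/taggers/example-upos",
#                          learning_rate=0.1,
#                          mini_batch_size=32,
#                          max_epochs=10)
#   print(result["test_score"])           # score from final_test() on the test split
#   print(result["dev_score_history"])    # one dev score per epoch (if dev data was used)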
770 def resume(self,
771 model: Optional[Model],
772 **trainer_args,
773 ):
775 self.model = model
777 # recover all arguments that were used to train this model
778 args_used_to_train_model = self.model.model_card['training_parameters']
780 # you can overwrite params with your own
781 for param in trainer_args:
782 args_used_to_train_model[param] = trainer_args[param]
783 if param == 'optimizer' and 'optimizer_state_dict' in args_used_to_train_model:
784 del args_used_to_train_model['optimizer_state_dict']
785 if param == 'scheduler' and 'scheduler_state_dict' in args_used_to_train_model:
786 del args_used_to_train_model['scheduler_state_dict']
788 # surface nested arguments
789 kwargs = args_used_to_train_model['kwargs']
790 del args_used_to_train_model['kwargs']
792 # resume training with these parameters
793 self.train(**args_used_to_train_model, **kwargs)
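# A sketch of resuming an interrupted run, assuming training was started with checkpoint=True so
# that "checkpoint.pt" (which carries the model card) exists in the base path; the model class,
# corpus and paths are illustrative assumptions:
#
#   checkpoint_model = SequenceTagger.load("resources/taggers/example-upos/checkpoint.pt")
#   trainer = ModelTrainer(checkpoint_model, corpus)
#   trainer.resume(model=checkpoint_model, max_epochs=25)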
795 def fine_tune(self,
796 base_path: Union[Path, str],
797 learning_rate: float = 5e-5,
798 max_epochs: int = 10,
799 optimizer=torch.optim.AdamW,
800 scheduler=LinearSchedulerWithWarmup,
801 warmup_fraction: float = 0.1,
802 mini_batch_size: int = 4,
803 embeddings_storage_mode: str = 'none',
804 use_final_model_for_eval: bool = True,
805 **trainer_args,
806 ):
808 return self.train(
809 base_path=base_path,
810 learning_rate=learning_rate,
811 max_epochs=max_epochs,
812 optimizer=optimizer,
813 scheduler=scheduler,
814 warmup_fraction=warmup_fraction,
815 mini_batch_size=mini_batch_size,
816 embeddings_storage_mode=embeddings_storage_mode,
817 use_final_model_for_eval=use_final_model_for_eval,
818 **trainer_args,
819 )
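# A fine-tuning sketch using the defaults above (AdamW with linear warmup); the classifier
# construction, label type and paths are illustrative assumptions:
#
#   from flair.embeddings import TransformerDocumentEmbeddings
#   from flair.models import TextClassifier
#
#   document_embeddings = TransformerDocumentEmbeddings("distilbert-base-uncased", fine_tune=True)
#   classifier = TextClassifier(document_embeddings,
#                               label_dictionary=corpus.make_label_dictionary(label_type="topic"),
#                               label_type="topic")
#   trainer = ModelTrainer(classifier, corpus)
#   trainer.fine_tune("resources/classifiers/example-topic", learning_rate=5e-5, mini_batch_size=4)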
821 def final_test(
822 self,
823 base_path: Union[Path, str],
824 eval_mini_batch_size: int,
825 main_evaluation_metric: Tuple[str, str],
826 num_workers: int = 8,
827 gold_label_dictionary_for_eval: Optional[Dictionary] = None
828 ):
829 if type(base_path) is str:
830 base_path = Path(base_path)
831 base_path.mkdir(exist_ok=True, parents=True)
833 log_line(log)
835 self.model.eval()
837 if (base_path / "best-model.pt").exists():
838 self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict())
839 else:
840 log.info("Testing using last state of model ...")
842 test_results = self.model.evaluate(
843 self.corpus.test,
844 gold_label_type=self.model.label_type,
845 mini_batch_size=eval_mini_batch_size,
846 num_workers=num_workers,
847 out_path=base_path / "test.tsv",
848 embedding_storage_mode="none",
849 main_evaluation_metric=main_evaluation_metric,
850 gold_label_dictionary=gold_label_dictionary_for_eval,
851 )
853 test_results: Result = test_results
854 log.info(test_results.log_line)
855 log.info(test_results.detailed_results)
856 log_line(log)
858 # if we are training over multiple datasets, do evaluation for each
859 if type(self.corpus) is MultiCorpus:
860 for subcorpus in self.corpus.corpora:
861 log_line(log)
862 if subcorpus.test:
863 subcorpus_results = self.model.evaluate(
864 subcorpus.test,
865 gold_label_type=self.model.label_type,
866 mini_batch_size=eval_mini_batch_size,
867 num_workers=num_workers,
868 out_path=base_path / f"{subcorpus.name}-test.tsv",
869 embedding_storage_mode="none",
870 main_evaluation_metric=main_evaluation_metric
871 )
872 log.info(subcorpus.name)
873 log.info(subcorpus_results.log_line)
875 # get and return the final test score of best model
876 final_score = test_results.main_score
878 return final_score
880 def find_learning_rate(
881 self,
882 base_path: Union[Path, str],
883 optimizer,
884 mini_batch_size: int = 32,
885 start_learning_rate: float = 1e-7,
886 end_learning_rate: float = 10,
887 iterations: int = 1000,
888 stop_early: bool = True,
889 file_name: str = "learning_rate.tsv",
890 **kwargs,
891 ) -> Path:
892 best_loss = None
894 # cast string to Path
895 if type(base_path) is str:
896 base_path = Path(base_path)
897 base_path.mkdir(exist_ok=True, parents=True)
898 learning_rate_tsv = init_output_file(base_path, file_name)
900 with open(learning_rate_tsv, "a") as f:
901 f.write("ITERATION\tLEARNING_RATE\tTRAIN_LOSS\tMOVING_AVG_LOSS\tLOSS_DROP\n")
903 optimizer = optimizer(self.model.parameters(), lr=start_learning_rate, **kwargs)
905 train_data = self.corpus.train
907 scheduler = ExpAnnealLR(optimizer, end_learning_rate, iterations)
909 model_state = self.model.state_dict()
910 self.model.train()
912 step = 0
914 loss_list = []
915 average_loss_list = []
917 while step < iterations:
919 batch_loader = DataLoader(train_data, batch_size=mini_batch_size, shuffle=True)
921 for batch in batch_loader:
922 step += 1
924 # forward pass
925 loss = self.model.forward_loss(batch)
926 if isinstance(loss, Tuple):
927 loss = loss[0]
929 # update optimizer and scheduler
930 optimizer.zero_grad()
931 loss.backward()
932 torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
933 optimizer.step()
934 scheduler.step()
936 learning_rate = scheduler.get_lr()[0]
938 # append current loss to list of losses for all iterations
939 loss_list.append(loss.item())
941 # compute averaged loss
942 import statistics
943 moving_avg_loss = statistics.mean(loss_list)
944 average_loss_list.append(moving_avg_loss)
946 if len(average_loss_list) > 10:
947 drop = average_loss_list[-10] - moving_avg_loss
948 else:
949 drop = 0.
951 if not best_loss or moving_avg_loss < best_loss:
952 best_loss = moving_avg_loss
954 if step > iterations:
955 break
957 if stop_early and (moving_avg_loss > 4 * best_loss or torch.isnan(loss)):
958 log_line(log)
959 log.info("loss diverged - stopping early!")
960 step = iterations
961 break
963 with open(str(learning_rate_tsv), "a") as f:
964 f.write(f"{step}\t{learning_rate}\t{loss.item()}\t{moving_avg_loss}\t{drop}\n")
966 self.model.load_state_dict(model_state)
967 self.model.to(flair.device)
969 log_line(log)
970 log.info(f"learning rate finder finished - plot {learning_rate_tsv}")
971 log_line(log)
973 return Path(learning_rate_tsv)
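# A sketch of running the learning rate finder and reading back the written TSV, assuming a
# ModelTrainer instance named `trainer`; the base path is illustrative and the columns match the
# header written at the top of this method:
#
#   from torch.optim import SGD
#
#   lr_tsv = trainer.find_learning_rate("resources/taggers/example-upos", optimizer=SGD)
#   with open(lr_tsv) as f:
#       next(f)  # skip header row
#       for row in f:
#           iteration, lr, loss, avg_loss, drop = row.rstrip("\n").split("\t")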