Coverage for flair/flair/training_utils.py: 25%
import random
import logging
from collections import defaultdict
from enum import Enum
from math import inf
from pathlib import Path
from typing import Union, List

from torch.optim import Optimizer

import flair
from flair.data import Dictionary, Sentence
from functools import reduce
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr, spearmanr
class Result(object):
    def __init__(self,
                 main_score: float,
                 log_header: str,
                 log_line: str,
                 detailed_results: str,
                 loss: float,
                 classification_report: dict = None,
                 ):
        self.main_score: float = main_score
        self.log_header: str = log_header
        self.log_line: str = log_line
        self.detailed_results: str = detailed_results
        self.classification_report: dict = classification_report
        self.loss: float = loss

    def __str__(self):
        return f"{self.detailed_results}\nLoss: {self.loss}"
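# Hedged usage sketch (not part of the original module): constructing a Result by hand
# to show what an evaluation routine is expected to hand back to the trainer. All
# values below are placeholders, not real evaluation output.
def _example_result_usage():
    result = Result(
        main_score=0.85,
        log_header="PRECISION\tRECALL\tF1",
        log_line="0.86\t0.84\t0.85",
        detailed_results="per-class results would go here",
        loss=0.42,
    )
    print(result)  # detailed results followed by the loss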
class MetricRegression(object):
    def __init__(self, name):
        self.name = name

        self.true = []
        self.pred = []

    def mean_squared_error(self):
        return mean_squared_error(self.true, self.pred)

    def mean_absolute_error(self):
        return mean_absolute_error(self.true, self.pred)

    def pearsonr(self):
        return pearsonr(self.true, self.pred)[0]

    def spearmanr(self):
        return spearmanr(self.true, self.pred)[0]

    # dummy return to fulfill trainer.train() needs
    def micro_avg_f_score(self):
        return self.mean_squared_error()

    def to_tsv(self):
        return "{}\t{}\t{}\t{}".format(
            self.mean_squared_error(),
            self.mean_absolute_error(),
            self.pearsonr(),
            self.spearmanr(),
        )

    @staticmethod
    def tsv_header(prefix=None):
        if prefix:
            return "{0}_MEAN_SQUARED_ERROR\t{0}_MEAN_ABSOLUTE_ERROR\t{0}_PEARSON\t{0}_SPEARMAN".format(
                prefix
            )

        return "MEAN_SQUARED_ERROR\tMEAN_ABSOLUTE_ERROR\tPEARSON\tSPEARMAN"

    @staticmethod
    def to_empty_tsv():
        return "\t_\t_\t_\t_"

    def __str__(self):
        line = "mean squared error: {0:.4f} - mean absolute error: {1:.4f} - pearson: {2:.4f} - spearman: {3:.4f}".format(
            self.mean_squared_error(),
            self.mean_absolute_error(),
            self.pearsonr(),
            self.spearmanr(),
        )
        return line
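# Hedged usage sketch (not part of the original module): filling MetricRegression with
# a few gold/predicted values and printing the summary line and TSV row. The numbers
# are arbitrary placeholders.
def _example_metric_regression_usage():
    metric = MetricRegression("dev")
    metric.true.extend([1.0, 2.0, 3.0, 4.0])
    metric.pred.extend([0.9, 2.2, 2.8, 4.1])
    print(metric)                         # human-readable summary of all four metrics
    print(MetricRegression.tsv_header())  # column names for a results file
    print(metric.to_tsv())                # matching tab-separated values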
class EvaluationMetric(Enum):
    MICRO_ACCURACY = "micro-average accuracy"
    MICRO_F1_SCORE = "micro-average f1-score"
    MACRO_ACCURACY = "macro-average accuracy"
    MACRO_F1_SCORE = "macro-average f1-score"
    MEAN_SQUARED_ERROR = "mean squared error"
class WeightExtractor(object):
    def __init__(self, directory: Union[str, Path], number_of_weights: int = 10):
        if type(directory) is str:
            directory = Path(directory)
        self.weights_file = init_output_file(directory, "weights.txt")
        self.weights_dict = defaultdict(lambda: defaultdict(lambda: list()))
        self.number_of_weights = number_of_weights

    def extract_weights(self, state_dict, iteration):
        for key in state_dict.keys():

            vec = state_dict[key]
            # skip entries without a usable tensor shape (e.g. zero-dim scalars)
            try:
                weights_to_watch = min(
                    self.number_of_weights, reduce(lambda x, y: x * y, list(vec.size()))
                )
            except:
                continue

            if key not in self.weights_dict:
                self._init_weights_index(key, state_dict, weights_to_watch)

            for i in range(weights_to_watch):
                vec = state_dict[key]
                for index in self.weights_dict[key][i]:
                    vec = vec[index]

                value = vec.item()

                with open(self.weights_file, "a") as f:
                    f.write("{}\t{}\t{}\t{}\n".format(iteration, key, i, float(value)))

    def _init_weights_index(self, key, state_dict, weights_to_watch):
        indices = {}

        i = 0
        while len(indices) < weights_to_watch:
            vec = state_dict[key]
            cur_indices = []

            for x in range(len(vec.size())):
                index = random.randint(0, len(vec) - 1)
                vec = vec[index]
                cur_indices.append(index)

            if cur_indices not in list(indices.values()):
                indices[i] = cur_indices
                i += 1

        self.weights_dict[key] = indices
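# Hedged usage sketch (not part of the original module): tracking a handful of weights
# of a small torch module across a few pseudo-iterations. Writes to
# "weight_logs/weights.txt"; the directory name is arbitrary.
def _example_weight_extractor_usage():
    import torch

    model = torch.nn.Linear(4, 2)
    extractor = WeightExtractor("weight_logs", number_of_weights=5)
    for iteration in range(3):
        extractor.extract_weights(model.state_dict(), iteration)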
class AnnealOnPlateau(object):
    """This class is a modification of
    torch.optim.lr_scheduler.ReduceLROnPlateau that enables
    setting an "auxiliary metric" to break ties.

    Reduce learning rate when a metric has stopped improving.
    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler reads a metric
    quantity and, if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        aux_mode (str): One of `min`, `max`. Direction of improvement for
            the auxiliary metric that breaks ties when the main metric
            is unchanged. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): Number of epochs with no improvement after
            which learning rate will be reduced. For example, if
            `patience = 2`, then we will ignore the first 2 epochs
            with no improvement, and will only decrease the LR after the
            3rd epoch if the loss still hasn't improved then.
            Default: 10.
        initial_extra_patience (int): Extra patience granted before the
            first learning-rate reduction only. Default: 0.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = AnnealOnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """
    def __init__(self, optimizer, mode='min', aux_mode='min', factor=0.1, patience=10, initial_extra_patience=0,
                 verbose=False, cooldown=0, min_lr=0, eps=1e-8):

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        if isinstance(min_lr, list) or isinstance(min_lr, tuple):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.default_patience = patience
        self.effective_patience = patience + initial_extra_patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.aux_mode = aux_mode
        self.best = None
        self.best_aux = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.eps = eps
        self.last_epoch = 0
        self._init_is_better(mode=mode)
        self._reset()
    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0
    def step(self, metric, auxiliary_metric=None):
        # convert `metric` to float, in case it's a zero-dim Tensor
        current = float(metric)
        epoch = self.last_epoch + 1
        self.last_epoch = epoch

        is_better = False

        if self.mode == 'min':
            if current < self.best:
                is_better = True

        if self.mode == 'max':
            if current > self.best:
                is_better = True

        if current == self.best and auxiliary_metric:
            current_aux = float(auxiliary_metric)
            if self.best_aux is None:
                # no auxiliary value recorded yet, so treat this one as an improvement
                is_better = True
            elif self.aux_mode == 'min':
                if current_aux < self.best_aux:
                    is_better = True
            elif self.aux_mode == 'max':
                if current_aux > self.best_aux:
                    is_better = True

        if is_better:
            self.best = current
            if auxiliary_metric:
                self.best_aux = float(auxiliary_metric)
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.effective_patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0
            self.effective_patience = self.default_patience

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
    def _reduce_lr(self, epoch):
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr
                if self.verbose:
                    print('Epoch {:5d}: reducing learning rate'
                          ' of group {} to {:.4e}.'.format(epoch, i, new_lr))
    @property
    def in_cooldown(self):
        return self.cooldown_counter > 0

    def _init_is_better(self, mode):
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')

        if mode == 'min':
            self.mode_worse = inf
        else:  # mode == 'max':
            self.mode_worse = -inf

        self.mode = mode

    def state_dict(self):
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode)
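# Hedged usage sketch (not part of the original module): annealing on a main dev loss
# while using a dev score to break ties. The constant loss/score values are placeholders
# standing in for real validation results.
def _example_anneal_on_plateau_usage():
    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = AnnealOnPlateau(optimizer, mode='min', aux_mode='max', patience=2, verbose=True)
    for epoch in range(10):
        dev_loss = 1.0   # placeholder: loss plateaus immediately
        dev_score = 0.5  # placeholder: tie-breaking metric does not improve either
        scheduler.step(dev_loss, auxiliary_metric=dev_score)
        # after `patience` bad epochs the learning rate is multiplied by `factor` (0.1)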
def init_output_file(base_path: Union[str, Path], file_name: str) -> Path:
    """
    Creates a local file.
    :param base_path: the path to the directory
    :param file_name: the file name
    :return: the created file
    """
    if type(base_path) is str:
        base_path = Path(base_path)
    base_path.mkdir(parents=True, exist_ok=True)

    file = base_path / file_name
    open(file, "w", encoding="utf-8").close()
    return file
def convert_labels_to_one_hot(
    label_list: List[List[str]], label_dict: Dictionary
) -> List[List[int]]:
    """
    Convert a list of labels (strings) to a one-hot list.
    :param label_list: list of labels
    :param label_dict: label dictionary
    :return: converted label list
    """
    return [
        [1 if l in labels else 0 for l in label_dict.get_items()]
        for labels in label_list
    ]
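# Hedged usage sketch (not part of the original module): building a small label
# Dictionary and one-hot encoding two multi-label examples. Assumes the usual
# Dictionary API (add_item / get_items) from flair.data.
def _example_one_hot_usage():
    label_dict = Dictionary(add_unk=False)
    label_dict.add_item("POSITIVE")
    label_dict.add_item("NEGATIVE")
    one_hot = convert_labels_to_one_hot(
        [["POSITIVE"], ["NEGATIVE", "POSITIVE"]], label_dict
    )
    print(one_hot)  # expected to look like [[1, 0], [1, 1]]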
def log_line(log):
    log.info("-" * 100)
def add_file_handler(log, output_file):
    init_output_file(output_file.parents[0], output_file.name)
    fh = logging.FileHandler(output_file, mode="w", encoding="utf-8")
    fh.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)-15s %(message)s")
    fh.setFormatter(formatter)
    log.addHandler(fh)
    return fh
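# Hedged usage sketch (not part of the original module): attaching a file handler so
# that log output is mirrored to "logs/training.log" (an arbitrary path), then removing
# it again.
def _example_logging_usage():
    log = logging.getLogger("flair_example")
    log.setLevel(logging.INFO)
    handler = add_file_handler(log, Path("logs") / "training.log")
    log_line(log)  # separator line of 100 dashes
    log.info("training started")
    log.removeHandler(handler)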
def store_embeddings(sentences: List[Sentence], storage_mode: str):
    # if memory mode option 'none' delete everything
    if storage_mode == "none":
        for sentence in sentences:
            sentence.clear_embeddings()

    # else delete only dynamic embeddings (otherwise autograd will keep everything in memory)
    else:
        # find out which ones are dynamic embeddings
        delete_keys = []
        if type(sentences[0]) == Sentence:
            for name, vector in sentences[0][0]._embeddings.items():
                if sentences[0][0]._embeddings[name].requires_grad:
                    delete_keys.append(name)

        # drop the dynamic embeddings from all sentences
        for sentence in sentences:
            sentence.clear_embeddings(delete_keys)

    # memory management - option 1: send everything to CPU (pin to memory if we train on GPU)
    if storage_mode == "cpu":
        pin_memory = False if str(flair.device) == "cpu" else True
        for sentence in sentences:
            sentence.to("cpu", pin_memory=pin_memory)

    # record current embedding storage mode to allow optimization (for instance in FlairEmbeddings class)
    flair.embedding_storage_mode = storage_mode
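# Hedged usage sketch (not part of the original module): after a (hypothetical) forward
# pass has attached embeddings to the sentences, keep them on the CPU between epochs.
# With storage_mode="none" all embeddings would be dropped instead.
def _example_store_embeddings_usage():
    sentences = [Sentence("Berlin is a city .")]
    # ... embed the sentences with some embedding class here ...
    store_embeddings(sentences, storage_mode="cpu")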