Coverage for flair/flair/training_utils.py: 66%

218 statements  

import random
import logging
from collections import defaultdict
from enum import Enum
from math import inf
from pathlib import Path
from typing import Union, List

from torch.optim import Optimizer

import flair
from flair.data import Dictionary, Sentence
from functools import reduce
from sklearn.metrics import mean_squared_error, mean_absolute_error
from scipy.stats import pearsonr, spearmanr


class Result(object):
    def __init__(
        self,
        main_score: float,
        log_header: str,
        log_line: str,
        detailed_results: str,
        loss: float,
        classification_report: dict = None,
    ):
        self.main_score: float = main_score
        self.log_header: str = log_header
        self.log_line: str = log_line
        self.detailed_results: str = detailed_results
        self.classification_report: dict = classification_report
        self.loss: float = loss

    def __str__(self):
        return f"{str(self.detailed_results)}\nLoss: {self.loss}"

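# --- Illustrative sketch, not part of the original module -------------------
# A minimal example of how a Result object is typically constructed by a
# model's evaluate() routine: main_score drives model selection, log_header
# and log_line feed the training log, and detailed_results carries the full
# report. All values and the function name below are hypothetical.
def _example_result_usage():
    result = Result(
        main_score=0.91,
        log_header="PRECISION\tRECALL\tF1",
        log_line="0.90\t0.92\t0.91",
        detailed_results="per-class results would go here",
        loss=0.35,
    )
    print(result)  # detailed results followed by "Loss: 0.35"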

class MetricRegression(object):
    def __init__(self, name):
        self.name = name

        self.true = []
        self.pred = []

    def mean_squared_error(self):
        return mean_squared_error(self.true, self.pred)

    def mean_absolute_error(self):
        return mean_absolute_error(self.true, self.pred)

    def pearsonr(self):
        return pearsonr(self.true, self.pred)[0]

    def spearmanr(self):
        return spearmanr(self.true, self.pred)[0]

    # dummy return to fulfill trainer.train() needs
    def micro_avg_f_score(self):
        return self.mean_squared_error()

    def to_tsv(self):
        return "{}\t{}\t{}\t{}".format(
            self.mean_squared_error(),
            self.mean_absolute_error(),
            self.pearsonr(),
            self.spearmanr(),
        )

    @staticmethod
    def tsv_header(prefix=None):
        if prefix:
            return "{0}_MEAN_SQUARED_ERROR\t{0}_MEAN_ABSOLUTE_ERROR\t{0}_PEARSON\t{0}_SPEARMAN".format(
                prefix
            )

        return "MEAN_SQUARED_ERROR\tMEAN_ABSOLUTE_ERROR\tPEARSON\tSPEARMAN"

    @staticmethod
    def to_empty_tsv():
        return "\t_\t_\t_\t_"

    def __str__(self):
        line = "mean squared error: {0:.4f} - mean absolute error: {1:.4f} - pearson: {2:.4f} - spearman: {3:.4f}".format(
            self.mean_squared_error(),
            self.mean_absolute_error(),
            self.pearsonr(),
            self.spearmanr(),
        )
        return line

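# --- Illustrative sketch, not part of the original module -------------------
# MetricRegression simply accumulates gold values in `true` and predictions in
# `pred`, then defers to scikit-learn / scipy for the actual statistics. The
# function name and values below are hypothetical.
def _example_metric_regression_usage():
    metric = MetricRegression("eval")
    metric.true.extend([1.0, 2.0, 3.0])
    metric.pred.extend([0.9, 2.1, 3.2])
    print(metric)           # human-readable summary (mse, mae, pearson, spearman)
    print(metric.to_tsv())  # tab-separated values for result files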

class EvaluationMetric(Enum):
    MICRO_ACCURACY = "micro-average accuracy"
    MICRO_F1_SCORE = "micro-average f1-score"
    MACRO_ACCURACY = "macro-average accuracy"
    MACRO_F1_SCORE = "macro-average f1-score"
    MEAN_SQUARED_ERROR = "mean squared error"


class WeightExtractor(object):
    def __init__(self, directory: Union[str, Path], number_of_weights: int = 10):
        if type(directory) is str:
            directory = Path(directory)
        self.weights_file = init_output_file(directory, "weights.txt")
        self.weights_dict = defaultdict(lambda: defaultdict(lambda: list()))
        self.number_of_weights = number_of_weights

    def extract_weights(self, state_dict, iteration):
        # for each parameter tensor, log a fixed random subset of its weights
        for key in state_dict.keys():

            vec = state_dict[key]
            try:
                weights_to_watch = min(
                    self.number_of_weights, reduce(lambda x, y: x * y, list(vec.size()))
                )
            except Exception:
                # skip entries that are not indexable tensors
                continue

            if key not in self.weights_dict:
                self._init_weights_index(key, state_dict, weights_to_watch)

            for i in range(weights_to_watch):
                vec = state_dict[key]
                for index in self.weights_dict[key][i]:
                    vec = vec[index]

                value = vec.item()

                with open(self.weights_file, "a") as f:
                    f.write("{}\t{}\t{}\t{}\n".format(iteration, key, i, float(value)))

    def _init_weights_index(self, key, state_dict, weights_to_watch):
        # sample a distinct random index tuple for each weight to watch
        indices = {}

        i = 0
        while len(indices) < weights_to_watch:
            vec = state_dict[key]
            cur_indices = []

            for x in range(len(vec.size())):
                index = random.randint(0, len(vec) - 1)
                vec = vec[index]
                cur_indices.append(index)

            if cur_indices not in list(indices.values()):
                indices[i] = cur_indices
                i += 1

        self.weights_dict[key] = indices

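# --- Illustrative sketch, not part of the original module -------------------
# WeightExtractor is created once with an output directory and then fed the
# model's state_dict at each logging step; a fixed random subset of weights
# per parameter tensor is appended to weights.txt. The model argument, path
# and function name below are hypothetical.
def _example_weight_extractor_usage(model, n_iterations=3):
    extractor = WeightExtractor("resources/weights_demo", number_of_weights=5)
    for iteration in range(n_iterations):
        # ... one training step on `model` would happen here ...
        extractor.extract_weights(model.state_dict(), iteration)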

class AnnealOnPlateau(object):
    """This class is a modification of
    torch.optim.lr_scheduler.ReduceLROnPlateau that enables
    setting an "auxiliary metric" to break ties.

    Reduce learning rate when a metric has stopped improving.
    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler reads a metric
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        aux_mode (str): One of `min`, `max`. Direction in which the auxiliary
            metric is compared when the main metric exactly ties the best
            value seen so far. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): Number of epochs with no improvement after
            which learning rate will be reduced. For example, if
            `patience = 2`, then we will ignore the first 2 epochs
            with no improvement, and will only decrease the LR after the
            3rd epoch if the loss still hasn't improved then.
            Default: 10.
        initial_extra_patience (int): Additional patience granted only until
            the first learning rate reduction. Default: 0.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = AnnealOnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """


    def __init__(self, optimizer, mode='min', aux_mode='min', factor=0.1, patience=10, initial_extra_patience=0,
                 verbose=False, cooldown=0, min_lr=0, eps=1e-8):

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        if isinstance(min_lr, list) or isinstance(min_lr, tuple):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.default_patience = patience
        self.effective_patience = patience + initial_extra_patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.aux_mode = aux_mode
        self.best = None
        self.best_aux = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.eps = eps
        self.last_epoch = 0
        self._init_is_better(mode=mode)
        self._reset()

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metric, auxiliary_metric=None):
        # convert `metric` to float, in case it's a zero-dim Tensor
        current = float(metric)
        epoch = self.last_epoch + 1
        self.last_epoch = epoch

        is_better = False

        if self.mode == 'min':
            if current < self.best:
                is_better = True

        if self.mode == 'max':
            if current > self.best:
                is_better = True

        if current == self.best and auxiliary_metric:
            current_aux = float(auxiliary_metric)
            if self.aux_mode == 'min':
                if current_aux < self.best_aux:
                    is_better = True

            if self.aux_mode == 'max':
                if current_aux > self.best_aux:
                    is_better = True

        if is_better:
            self.best = current
            if auxiliary_metric:
                self.best_aux = auxiliary_metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.effective_patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0
            self.effective_patience = self.default_patience

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def _reduce_lr(self, epoch):
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr
                if self.verbose:
                    print('Epoch {:5d}: reducing learning rate'
                          ' of group {} to {:.4e}.'.format(epoch, i, new_lr))

    @property
    def in_cooldown(self):
        return self.cooldown_counter > 0

    def _init_is_better(self, mode):
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')

        if mode == 'min':
            self.mode_worse = inf
        else:  # mode == 'max'
            self.mode_worse = -inf

        self.mode = mode

    def state_dict(self):
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode)

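# --- Illustrative sketch, not part of the original module -------------------
# The behaviour that distinguishes AnnealOnPlateau from the stock
# ReduceLROnPlateau: when the main metric exactly ties the best value so far,
# the auxiliary metric breaks the tie. The optimizer argument, metric values
# and function name below are hypothetical.
def _example_anneal_on_plateau_usage(optimizer):
    scheduler = AnnealOnPlateau(optimizer, mode="max", aux_mode="min",
                                factor=0.5, patience=3)
    # main metric improves -> counts as a good epoch, best values are stored
    scheduler.step(0.80, auxiliary_metric=1.20)
    # main metric ties, but auxiliary loss is lower -> still a good epoch
    scheduler.step(0.80, auxiliary_metric=1.10)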

def init_output_file(base_path: Union[str, Path], file_name: str) -> Path:
    """
    Creates a local file.
    :param base_path: the path to the directory
    :param file_name: the file name
    :return: the created file
    """
    if type(base_path) is str:
        base_path = Path(base_path)
    base_path.mkdir(parents=True, exist_ok=True)

    file = base_path / file_name
    open(file, "w", encoding="utf-8").close()
    return file


def convert_labels_to_one_hot(
    label_list: List[List[str]], label_dict: Dictionary
) -> List[List[int]]:
    """
    Convert a list of label-string lists to a list of one-hot vectors.
    :param label_list: list of labels
    :param label_dict: label dictionary
    :return: converted label list
    """
    return [
        [1 if l in labels else 0 for l in label_dict.get_items()]
        for labels in label_list
    ]

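# --- Illustrative sketch, not part of the original module -------------------
# Each inner list of label strings becomes a 0/1 vector ordered by the items
# of the label dictionary. The labels and function name below are
# hypothetical.
def _example_one_hot_usage():
    label_dict = Dictionary(add_unk=False)
    label_dict.add_item("POSITIVE")
    label_dict.add_item("NEGATIVE")
    # -> [[1, 0], [0, 1], [1, 1]]
    return convert_labels_to_one_hot(
        [["POSITIVE"], ["NEGATIVE"], ["POSITIVE", "NEGATIVE"]], label_dict
    )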

def log_line(log):
    log.info("-" * 100)


def add_file_handler(log, output_file):
    init_output_file(output_file.parents[0], output_file.name)
    fh = logging.FileHandler(output_file, mode="w", encoding="utf-8")
    fh.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)-15s %(message)s")
    fh.setFormatter(formatter)
    log.addHandler(fh)
    return fh

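# --- Illustrative sketch, not part of the original module -------------------
# The logging helpers in combination: add_file_handler mirrors a logger into a
# file (creating the directory via init_output_file) and log_line writes a
# horizontal separator. The logger name, path and function name below are
# hypothetical.
def _example_logging_usage():
    log = logging.getLogger("flair_demo")
    log.setLevel(logging.INFO)
    handler = add_file_handler(log, Path("resources/demo") / "training.log")
    log_line(log)
    log.info("starting demo run")
    log.removeHandler(handler)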

def store_embeddings(sentences: List[Sentence], storage_mode: str):
    # if memory mode option 'none', delete everything
    if storage_mode == "none":
        for sentence in sentences:
            sentence.clear_embeddings()

    # else delete only dynamic embeddings (otherwise autograd will keep everything in memory)
    else:
        # find out which embeddings are dynamic (i.e. still require gradients)
        delete_keys = []
        if type(sentences[0]) == Sentence:
            for name, vector in sentences[0][0]._embeddings.items():
                if sentences[0][0]._embeddings[name].requires_grad:
                    delete_keys.append(name)

        # delete the dynamic embeddings from all sentences
        for sentence in sentences:
            sentence.clear_embeddings(delete_keys)

    # memory management - option 1: send everything to CPU (pin to memory if we train on GPU)
    if storage_mode == "cpu":
        pin_memory = False if str(flair.device) == "cpu" else True
        for sentence in sentences:
            sentence.to("cpu", pin_memory=pin_memory)

    # record current embedding storage mode to allow optimization (for instance in FlairEmbeddings class)
    flair.embedding_storage_mode = storage_mode
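# --- Illustrative sketch, not part of the original module -------------------
# Typical use of store_embeddings at the end of a training batch: with storage
# mode "none" all embeddings are dropped, with "cpu" only dynamic embeddings
# are dropped and the remainder are moved off the GPU. The batch argument and
# function name below are hypothetical.
def _example_store_embeddings_usage(batch: List[Sentence]):
    # ... forward pass and loss computation on `batch` would happen here ...
    store_embeddings(batch, storage_mode="none")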