Coverage for flair/flair/optim.py: 20%


95 statements  

import logging
import math

import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau, LambdaLR


log = logging.getLogger("flair")


class SGDW(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum) with
    weight decay from the paper `Fixing Weight Decay Regularization in Adam`_.

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay factor (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    .. _Fixing Weight Decay Regularization in Adam:
        https://arxiv.org/abs/1711.05101

    Example:
        >>> optimizer = SGDW(model.parameters(), lr=0.1, momentum=0.9,
        ...                  weight_decay=1e-5)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et. al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            v = \rho * v + g \\
            p = p - lr * v

        where p, g, v and :math:`\rho` denote the parameters, gradient,
        velocity, and momentum respectively.

        This is in contrast to Sutskever et. al. and
        other frameworks which employ an update of the form

        .. math::
            v = \rho * v + lr * g \\
            p = p - v

        The Nesterov version is analogously modified.
    """


    def __init__(
        self,
        params,
        lr=required,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
    ):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGDW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGDW, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)


    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                d_p = p.grad.data

                if momentum != 0:
                    param_state = self.state[p]
                    if "momentum_buffer" not in param_state:
                        buf = param_state["momentum_buffer"] = torch.zeros_like(p.data)
                        buf.mul_(momentum).add_(d_p)
                    else:
                        buf = param_state["momentum_buffer"]
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                # Decoupled weight decay: shrink the weights directly instead of
                # adding an L2 term to the gradient.
                if weight_decay != 0:
                    p.data.add_(p.data, alpha=-weight_decay)

                p.data.add_(d_p, alpha=-group["lr"])

        return loss

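For orientation, here is a minimal usage sketch of the SGDW class above (an illustration, not part of flair/optim.py; the linear model, loss, and data are placeholders). It follows the docstring example and shows the decoupled update: weight decay shrinks the parameters directly, while the gradient step is applied separately.

# Hedged usage sketch -- illustration only, not part of flair/optim.py.
import torch
from flair.optim import SGDW  # assumes this module is importable as flair.optim

model = torch.nn.Linear(10, 1)   # placeholder model
loss_fn = torch.nn.MSELoss()
optimizer = SGDW(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-5)

x, y = torch.randn(4, 10), torch.randn(4, 1)   # dummy batch

optimizer.zero_grad()
loss_fn(model(x), y).backward()
optimizer.step()   # p <- p - weight_decay * p - lr * v  (decay decoupled from the gradient)
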

class ExpAnnealLR(_LRScheduler):
    """Exponentially anneal the learning rate of each parameter group
    from the initial lr to end_lr over a number of iterations.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        end_lr (float): The final learning rate.
        iterations (int): The number of iterations over which to anneal the
            learning rate.
        last_epoch (int): The index of the last iteration. Default: -1.
    """

    def __init__(self, optimizer, end_lr, iterations, last_epoch=-1):
        self.end_lr = end_lr
        self.iterations = iterations
        super(ExpAnnealLR, self).__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        iteration = self.last_epoch + 1
        pct = iteration / self.iterations
        return [base_lr * (self.end_lr / base_lr) ** pct for base_lr in self.base_lrs]

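A rough sketch of what get_lr produces (throwaway SGD optimizer and made-up numbers; not part of the module): each call to scheduler.step() scales the learning rate by the constant factor (end_lr / base_lr) ** (1 / iterations), so the lr moves geometrically from its initial value toward end_lr.

# Hedged sketch -- illustration only, not part of flair/optim.py.
import torch
from flair.optim import ExpAnnealLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = ExpAnnealLR(optimizer, end_lr=1e-4, iterations=100)

for _ in range(100):
    optimizer.step()
    scheduler.step()
    # lr is multiplied by (1e-4 / 0.1) ** (1 / 100) per step, ending near 1e-4
    current_lr = optimizer.param_groups[0]["lr"]
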

class LinearSchedulerWithWarmup(LambdaLR):
    """Linearly increase the learning rate from 0 to the initial learning rate during
    warmup and decrease the learning rate to 0 after the warmup. Uses the LambdaLR
    scheduler, where the learning rate is multiplied by a lambda factor after calling
    scheduler.step().

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        num_train_steps (int): total number of training steps (number of batches * epochs).
        num_warmup_steps (int): number of training steps for learning rate warmup.
        last_epoch (int): The index of the last iteration. Default: -1. The scheduler
            will simply restart when resuming training from a checkpoint.
    """

    def __init__(self, optimizer, num_train_steps, num_warmup_steps, last_epoch=-1):

        def linear_lr_lambda(current_step: int):
            lambda_during_warmup = float(current_step) / float(max(1, num_warmup_steps))
            lambda_after_warmup = max(
                0.0,
                float(num_train_steps - current_step)
                / float(max(1, num_train_steps - num_warmup_steps)),
            )
            if current_step < num_warmup_steps:
                return lambda_during_warmup
            return lambda_after_warmup

        super(LinearSchedulerWithWarmup, self).__init__(
            optimizer, lr_lambda=linear_lr_lambda, last_epoch=last_epoch
        )

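A quick sketch of the resulting schedule (dummy optimizer and made-up step counts; not part of the module): the lambda factor ramps linearly from 0 to 1 over the warmup steps and then decays linearly back to 0 by num_train_steps, so the lr peaks at its initial value right after warmup.

# Hedged sketch -- illustration only, not part of flair/optim.py.
import torch
from flair.optim import LinearSchedulerWithWarmup

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=5e-5)
scheduler = LinearSchedulerWithWarmup(
    optimizer, num_train_steps=1000, num_warmup_steps=100
)

for step in range(1000):
    optimizer.step()
    scheduler.step()
# factor = step / 100 during the first 100 steps, then (1000 - step) / 900,
# so the lr climbs to 5e-5 and decays to roughly 0 by the last step.
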

class ReduceLRWDOnPlateau(ReduceLROnPlateau):
    """Reduce learning rate and weight decay when a metric has stopped
    improving. Models often benefit from reducing the learning rate by
    a factor of 2-10 once learning stagnates. This scheduler reads a metric
    quantity and, if no improvement is seen for a 'patience' number
    of epochs, the learning rate and weight decay factor are reduced for
    optimizers that implement the weight decay method from the paper
    `Fixing Weight Decay Regularization in Adam`_.

    .. _Fixing Weight Decay Regularization in Adam:
        https://arxiv.org/abs/1711.05101

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): Number of epochs with no improvement after
            which learning rate will be reduced. For example, if
            `patience = 2`, then we will ignore the first 2 epochs
            with no improvement, and will only decrease the LR after the
            3rd epoch if the loss still hasn't improved then.
            Default: 10.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in 'max'
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.

    Example:
        >>> optimizer = AdamW(model.parameters(), lr=0.1, weight_decay=1e-3)
        >>> scheduler = ReduceLRWDOnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """


    def step(self, metrics, epoch=None):
        current = metrics
        if epoch is None:
            epoch = self.last_epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self._reduce_weight_decay(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

    def _reduce_weight_decay(self, epoch):
        for i, param_group in enumerate(self.optimizer.param_groups):
            if param_group["weight_decay"] != 0:
                old_weight_decay = float(param_group["weight_decay"])
                new_weight_decay = max(old_weight_decay * self.factor, self.min_lrs[i])
                if old_weight_decay - new_weight_decay > self.eps:
                    param_group["weight_decay"] = new_weight_decay
                    if self.verbose:
                        log.info(
                            f"Epoch {epoch}: reducing weight decay factor of group {i} to {new_weight_decay:.4e}."
                        )