Coverage for flair/flair/optim.py: 20%
import logging
import math

import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau, LambdaLR


log = logging.getLogger("flair")

class SGDW(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum) with
    weight decay from the paper `Fixing Weight Decay Regularization in Adam`_.

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay factor (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)

    .. _Fixing Weight Decay Regularization in Adam:
        https://arxiv.org/abs/1711.05101

    Example:
        >>> optimizer = SGDW(model.parameters(), lr=0.1, momentum=0.9,
        ...                  weight_decay=1e-5)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            v = \rho * v + g \\
            p = p - lr * v

        where p, g, v and :math:`\rho` denote the parameters, gradient,
        velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
            v = \rho * v + lr * g \\
            p = p - v

        The Nesterov version is analogously modified.
    """
    def __init__(
        self,
        params,
        lr=required,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
    ):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGDW, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGDW, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                d_p = p.grad.data

                if momentum != 0:
                    param_state = self.state[p]
                    if "momentum_buffer" not in param_state:
                        buf = param_state["momentum_buffer"] = torch.zeros_like(p.data)
                        buf.mul_(momentum).add_(d_p)
                    else:
                        buf = param_state["momentum_buffer"]
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                # decoupled weight decay: shrink the weights directly instead of
                # adding the decay term to the gradient
                if weight_decay != 0:
                    p.data.add_(p.data, alpha=-weight_decay)

                p.data.add_(d_p, alpha=-group["lr"])

        return loss
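
# Usage sketch (not part of the original module): how SGDW is typically driven in a
# training loop. `model`, `loss_fn`, and `batches` are assumed to be provided by the
# caller and serve only as placeholders here.
def _example_sgdw_training_loop(model, loss_fn, batches):
    optimizer = SGDW(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-5)
    for inputs, targets in batches:
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()  # momentum update followed by the decoupled weight decay
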

class ExpAnnealLR(_LRScheduler):
    """Exponentially anneal the learning rate of each parameter group
    from the initial lr to end_lr over a number of iterations.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        end_lr (float): The final learning rate.
        iterations (int): The number of iterations over which to anneal the
            learning rate.
        last_epoch (int): The index of the last iteration. Default: -1.
    """

    def __init__(self, optimizer, end_lr, iterations, last_epoch=-1):
        self.end_lr = end_lr
        self.iterations = iterations
        super(ExpAnnealLR, self).__init__(optimizer, last_epoch=last_epoch)

    def get_lr(self):
        iteration = self.last_epoch + 1
        pct = iteration / self.iterations
        return [base_lr * (self.end_lr / base_lr) ** pct for base_lr in self.base_lrs]
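
# Usage sketch (illustrative, not part of the original file): ExpAnnealLR moves each
# group's lr geometrically from its initial value to `end_lr` over `iterations` calls
# to step(), e.g. for a learning-rate sweep. The optimizer is assumed to be prepared
# by the caller; a real sweep would run a forward/backward pass before each step.
def _example_lr_sweep(optimizer, end_lr=10.0, iterations=100):
    scheduler = ExpAnnealLR(optimizer, end_lr=end_lr, iterations=iterations)
    observed_lrs = []
    for _ in range(iterations):
        optimizer.step()  # placeholder for a real training step
        scheduler.step()  # one geometric interpolation step towards end_lr
        observed_lrs.append(scheduler.get_last_lr()[0])
    return observed_lrs
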

class LinearSchedulerWithWarmup(LambdaLR):
    """Linearly increases the learning rate from 0 to the initial learning rate during
    warmup and then linearly decreases it to 0 after the warmup. Uses the LambdaLR
    scheduler, in which the learning rate is multiplied by a lambda factor after
    calling scheduler.step().

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        num_train_steps (int): total number of training steps (number of batches * epochs).
        num_warmup_steps (int): number of training steps for learning rate warmup.
        last_epoch (int): The index of the last iteration. Default: -1. The scheduler
            will simply restart when resuming training from a checkpoint.
    """

    def __init__(self, optimizer, num_train_steps, num_warmup_steps, last_epoch=-1):

        def linear_lr_lambda(current_step: int):
            lambda_during_warmup = float(current_step) / float(max(1, num_warmup_steps))
            lambda_after_warmup = max(
                0.0,
                float(num_train_steps - current_step)
                / float(max(1, num_train_steps - num_warmup_steps)),
            )
            if current_step < num_warmup_steps:
                return lambda_during_warmup
            return lambda_after_warmup

        super(LinearSchedulerWithWarmup, self).__init__(
            optimizer, lr_lambda=linear_lr_lambda, last_epoch=last_epoch
        )
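
# Usage sketch (assumed helper, not part of the original module): because LambdaLR
# multiplies each base lr by lambda(step), the effective lr ramps from 0 to the base
# lr over `num_warmup_steps` and then decays linearly to 0 at `num_train_steps`.
def _example_warmup_schedule(optimizer, num_train_steps=100, num_warmup_steps=10):
    scheduler = LinearSchedulerWithWarmup(
        optimizer, num_train_steps=num_train_steps, num_warmup_steps=num_warmup_steps
    )
    effective_lrs = []
    for _ in range(num_train_steps):
        optimizer.step()  # placeholder for a real training step
        scheduler.step()  # advance one step of the warmup/decay schedule
        effective_lrs.append(scheduler.get_last_lr()[0])
    return effective_lrs
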

class ReduceLRWDOnPlateau(ReduceLROnPlateau):
    """Reduce learning rate and weight decay when a metric has stopped
    improving. Models often benefit from reducing the learning rate by
    a factor of 2-10 once learning stagnates. This scheduler reads a metric
    quantity and, if no improvement is seen for a 'patience' number
    of epochs, the learning rate and weight decay factor are reduced for
    optimizers that implement the weight decay method from the paper
    `Fixing Weight Decay Regularization in Adam`_.

    .. _Fixing Weight Decay Regularization in Adam:
        https://arxiv.org/abs/1711.05101

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): Number of epochs with no improvement after
            which learning rate will be reduced. For example, if
            `patience = 2`, then we will ignore the first 2 epochs
            with no improvement, and will only decrease the LR after the
            3rd epoch if the loss still hasn't improved then.
            Default: 10.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in `max`
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.

    Example:
        >>> optimizer = AdamW(model.parameters(), lr=0.1, weight_decay=1e-3)
        >>> scheduler = ReduceLRWDOnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """
    def step(self, metrics, epoch=None):
        current = metrics
        if epoch is None:
            epoch = self.last_epoch = self.last_epoch + 1
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self._reduce_weight_decay(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

    def _reduce_weight_decay(self, epoch):
        for i, param_group in enumerate(self.optimizer.param_groups):
            if param_group["weight_decay"] != 0:
                old_weight_decay = float(param_group["weight_decay"])
                new_weight_decay = max(old_weight_decay * self.factor, self.min_lrs[i])
                if old_weight_decay - new_weight_decay > self.eps:
                    param_group["weight_decay"] = new_weight_decay
                    if self.verbose:
                        log.info(
                            f"Epoch {epoch}: reducing weight decay factor"
                            f" of group {i} to {new_weight_decay:.4e}."
                        )
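
# Usage sketch (illustrative only, with assumed `train_one_epoch` and `validate`
# helpers): drive the scheduler with a validation metric once per epoch. With an
# optimizer such as SGDW, both the learning rate and the weight decay factor are
# multiplied by `factor` after `patience` epochs without improvement.
def _example_plateau_schedule(model, train_one_epoch, validate, epochs=10):
    optimizer = SGDW(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-3)
    scheduler = ReduceLRWDOnPlateau(optimizer, mode="min", factor=0.5, patience=2)
    for _ in range(epochs):
        train_one_epoch(model, optimizer)  # assumed helper: one pass over the data
        val_loss = validate(model)  # assumed helper: returns a scalar validation loss
        scheduler.step(val_loss)  # call after validation to reduce lr and weight decay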