# This is a backport from pytorch:
# https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py
#
# TODO: Switch to the upstream version when:
# 1. pytorch makes a release that contains the fix for
#    https://github.com/pytorch/pytorch/issues/19003, and
# 2. GOLD supports that pytorch version
#
# LICENSING AND COPYRIGHT FOR THIS FILE:
# See pytorch's license at https://github.com/pytorch/pytorch/blob/master/LICENSE

from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
import math


class CyclicLR(_LRScheduler):
    """Sets the learning rate of each parameter group according to
    cyclical learning rate policy (CLR). The policy cycles the learning
    rate between two boundaries with a constant frequency, as detailed in
    the paper `Cyclical Learning Rates for Training Neural Networks`_.
    The distance between the two boundaries can be scaled on a per-iteration
    or per-cycle basis.
    Cyclical learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.
    This class has three built-in policies, as put forth in the paper:
    "triangular":
        A basic triangular cycle w/ no amplitude scaling.
    "triangular2":
        A basic triangular cycle that scales initial amplitude by half each cycle.
    "exp_range":
        A cycle that scales initial amplitude by gamma**(cycle iterations) at each
        cycle iteration.
    This implementation was adapted from the github repo: `bckenstler/CLR`_
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        base_lr (float or list): Initial learning rate which is the
            lower boundary in the cycle for each parameter group.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_lr - base_lr).
            The lr at any cycle is the sum of base_lr
            and some scaling of the amplitude; therefore
            max_lr may not actually be reached depending on
            scaling function.
        step_size_up (int): Number of training iterations in the
            increasing half of a cycle. Default: 2000
        step_size_down (int): Number of training iterations in the
            decreasing half of a cycle. If step_size_down is None,
            it is set to step_size_up. Default: None
        mode (str): One of {triangular, triangular2, exp_range}.
            Values correspond to policies detailed above.
            If scale_fn is not None, this argument is ignored.
            Default: 'triangular'
        gamma (float): Constant in 'exp_range' scaling function:
            gamma**(cycle iterations)
            Default: 1.0
        scale_fn (function): Custom scaling policy defined by a single
            argument lambda function, where
            0 <= scale_fn(x) <= 1 for all x >= 0.
            If specified, then 'mode' is ignored.
            Default: None
        scale_mode (str): {'cycle', 'iterations'}.
            Defines whether scale_fn is evaluated on
            cycle number or cycle iterations (training
            iterations since start of cycle).
            Default: 'cycle'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.8
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            The momentum at any cycle is the difference of max_momentum
            and some scaling of the amplitude; therefore
            base_momentum may not actually be reached depending on
            scaling function. Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'
            Default: 0.9
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1
    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()
    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
    """

    def __init__(
        self,
        optimizer,
        base_lr,
        max_lr,
        step_size_up=2000,
        step_size_down=None,
        mode="triangular",
        gamma=1.0,
        scale_fn=None,
        scale_mode="cycle",
        cycle_momentum=True,
        base_momentum=0.8,
        max_momentum=0.9,
        last_epoch=-1,
    ):
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer

        base_lrs = self._format_param("base_lr", optimizer, base_lr)
        if last_epoch == -1:
            for lr, group in zip(base_lrs, optimizer.param_groups):
                group["lr"] = lr

        self.max_lrs = self._format_param("max_lr", optimizer, max_lr)

        step_size_up = float(step_size_up)
        step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
        self.total_size = step_size_up + step_size_down
        self.step_ratio = step_size_up / self.total_size

        if mode not in ["triangular", "triangular2", "exp_range"] and scale_fn is None:
            raise ValueError("mode is invalid and scale_fn is None")

        self.mode = mode
        self.gamma = gamma

        if scale_fn is None:
            if self.mode == "triangular":
                self.scale_fn = self._triangular_scale_fn
                self.scale_mode = "cycle"
            elif self.mode == "triangular2":
                self.scale_fn = self._triangular2_scale_fn
                self.scale_mode = "cycle"
            elif self.mode == "exp_range":
                self.scale_fn = self._exp_range_scale_fn
                self.scale_mode = "iterations"
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode

        self.cycle_momentum = cycle_momentum
        if cycle_momentum:
            if "momentum" not in optimizer.defaults:
                raise ValueError(
                    "optimizer must support momentum with `cycle_momentum` option enabled"
                )

            base_momentums = self._format_param("base_momentum", optimizer, base_momentum)
            if last_epoch == -1:
                for momentum, group in zip(base_momentums, optimizer.param_groups):
                    group["momentum"] = momentum
            self.base_momentums = list(map(lambda group: group["momentum"], optimizer.param_groups))
            self.max_momentums = self._format_param("max_momentum", optimizer, max_momentum)

        super(CyclicLR, self).__init__(optimizer, last_epoch)

    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group."""
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError(
                    "expected {} values for {}, got {}".format(
                        len(optimizer.param_groups), name, len(param)
                    )
                )
            return param
        else:
            return [param] * len(optimizer.param_groups)

    def _triangular_scale_fn(self, x):
        return 1.0

    def _triangular2_scale_fn(self, x):
        return 1 / (2.0 ** (x - 1))

    def _exp_range_scale_fn(self, x):
        return self.gamma ** x

    def get_lr(self):
        """Calculates the learning rate at batch index. This function treats
        `self.last_epoch` as the last batch index.
        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.
        """
        cycle = math.floor(1 + self.last_epoch / self.total_size)
        x = 1.0 + self.last_epoch / self.total_size - cycle
        if x <= self.step_ratio:
            scale_factor = x / self.step_ratio
        else:
            scale_factor = (x - 1) / (self.step_ratio - 1)

        lrs = []
        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
            base_height = (max_lr - base_lr) * scale_factor
            if self.scale_mode == "cycle":
                lr = base_lr + base_height * self.scale_fn(cycle)
            else:
                lr = base_lr + base_height * self.scale_fn(self.last_epoch)
            lrs.append(lr)

        if self.cycle_momentum:
            momentums = []
            for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
                base_height = (max_momentum - base_momentum) * scale_factor
                if self.scale_mode == "cycle":
                    momentum = max_momentum - base_height * self.scale_fn(cycle)
                else:
                    momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
                momentums.append(momentum)
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                param_group["momentum"] = momentum

        return lrs
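

if __name__ == "__main__":
    # Minimal usage sketch (illustration only): cycle the learning rate of an SGD
    # optimizer on a dummy parameter and print a few scheduled values; the exact
    # numbers depend on the base_lr/max_lr/step_size_up chosen here.
    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=0.1, momentum=0.9)
    scheduler = CyclicLR(optimizer, base_lr=0.01, max_lr=0.1, step_size_up=5)
    for batch_idx in range(12):
        # ... forward/backward and optimizer.step() for one batch would go here ...
        scheduler.step()  # called once per batch, as the class docstring requires
        print(batch_idx, [group["lr"] for group in optimizer.param_groups])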