Coverage for tvo/utils/CyclicLR.py: 58%

80 statements  

coverage.py v7.4.3, created at 2024-03-01 11:33 +0000

# This is a backport from pytorch:
# https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py
#
# TODO: Switch to the upstream version when:
# 1. pytorch makes a release that contains the fix for
#    https://github.com/pytorch/pytorch/issues/19003, and
# 2. GOLD supports that pytorch version
#
# LICENSING AND COPYRIGHT FOR THIS FILE:
# See pytorch's license at https://github.com/pytorch/pytorch/blob/master/LICENSE

from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
import math

class CyclicLR(_LRScheduler):
    """Sets the learning rate of each parameter group according to
    cyclical learning rate policy (CLR). The policy cycles the learning
    rate between two boundaries with a constant frequency, as detailed in
    the paper `Cyclical Learning Rates for Training Neural Networks`_.
    The distance between the two boundaries can be scaled on a per-iteration
    or per-cycle basis.

    Cyclical learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This class has three built-in policies, as put forth in the paper:

    "triangular":
        A basic triangular cycle with no amplitude scaling.
    "triangular2":
        A basic triangular cycle that scales the initial amplitude by half each cycle.
    "exp_range":
        A cycle that scales the initial amplitude by gamma**(cycle iterations) at each
        cycle iteration.

    This implementation was adapted from the github repo: `bckenstler/CLR`_

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        base_lr (float or list): Initial learning rate which is the
            lower boundary in the cycle for each parameter group.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_lr - base_lr).
            The lr at any cycle is the sum of base_lr
            and some scaling of the amplitude; therefore
            max_lr may not actually be reached depending on
            the scaling function.
        step_size_up (int): Number of training iterations in the
            increasing half of a cycle. Default: 2000
        step_size_down (int): Number of training iterations in the
            decreasing half of a cycle. If step_size_down is None,
            it is set to step_size_up. Default: None
        mode (str): One of {triangular, triangular2, exp_range}.
            Values correspond to the policies detailed above.
            If scale_fn is not None, this argument is ignored.
            Default: 'triangular'
        gamma (float): Constant in the 'exp_range' scaling function:
            gamma**(cycle iterations).
            Default: 1.0
        scale_fn (function): Custom scaling policy defined by a single-argument
            lambda function, where 0 <= scale_fn(x) <= 1 for all x >= 0.
            If specified, then 'mode' is ignored.
            Default: None
        scale_mode (str): {'cycle', 'iterations'}.
            Defines whether scale_fn is evaluated on
            cycle number or cycle iterations (training
            iterations since start of cycle).
            Default: 'cycle'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.8
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            The momentum at any cycle is the difference of max_momentum
            and some scaling of the amplitude; therefore
            base_momentum may not actually be reached depending on
            the scaling function. Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'.
            Default: 0.9
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()

    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
    """

    def __init__(
        self,
        optimizer,
        base_lr,
        max_lr,
        step_size_up=2000,
        step_size_down=None,
        mode="triangular",
        gamma=1.0,
        scale_fn=None,
        scale_mode="cycle",
        cycle_momentum=True,
        base_momentum=0.8,
        max_momentum=0.9,
        last_epoch=-1,
    ):
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer

        base_lrs = self._format_param("base_lr", optimizer, base_lr)
        if last_epoch == -1:
            for lr, group in zip(base_lrs, optimizer.param_groups):
                group["lr"] = lr

        self.max_lrs = self._format_param("max_lr", optimizer, max_lr)

        step_size_up = float(step_size_up)
        step_size_down = float(step_size_down) if step_size_down is not None else step_size_up
        self.total_size = step_size_up + step_size_down
        self.step_ratio = step_size_up / self.total_size

        if mode not in ["triangular", "triangular2", "exp_range"] and scale_fn is None:
            raise ValueError("mode is invalid and scale_fn is None")

        self.mode = mode
        self.gamma = gamma

        if scale_fn is None:
            if self.mode == "triangular":
                self.scale_fn = self._triangular_scale_fn
                self.scale_mode = "cycle"
            elif self.mode == "triangular2":
                self.scale_fn = self._triangular2_scale_fn
                self.scale_mode = "cycle"
            elif self.mode == "exp_range":
                self.scale_fn = self._exp_range_scale_fn
                self.scale_mode = "iterations"
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode

        self.cycle_momentum = cycle_momentum
        if cycle_momentum:
            if "momentum" not in optimizer.defaults:
                raise ValueError(
                    "optimizer must support momentum with `cycle_momentum` option enabled"
                )

            base_momentums = self._format_param("base_momentum", optimizer, base_momentum)
            if last_epoch == -1:
                for momentum, group in zip(base_momentums, optimizer.param_groups):
                    group["momentum"] = momentum
            self.base_momentums = list(map(lambda group: group["momentum"], optimizer.param_groups))
            self.max_momentums = self._format_param("max_momentum", optimizer, max_momentum)

        super(CyclicLR, self).__init__(optimizer, last_epoch)

    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group."""
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError(
                    "expected {} values for {}, got {}".format(
                        len(optimizer.param_groups), name, len(param)
                    )
                )
            return param
        else:
            return [param] * len(optimizer.param_groups)

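    # Illustrative note (not from the upstream source): _format_param broadcasts a
    # scalar such as base_lr=0.01 to one value per param group, while a list such
    # as base_lr=[0.01, 0.001] must match len(optimizer.param_groups) exactly or a
    # ValueError is raised.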

    def _triangular_scale_fn(self, x):
        return 1.0

    def _triangular2_scale_fn(self, x):
        return 1 / (2.0 ** (x - 1))

    def _exp_range_scale_fn(self, x):
        return self.gamma ** x

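    # Illustrative values (not part of the upstream source): for cycle k and
    # iteration t, the built-in policies multiply the cycle amplitude by
    #   triangular : 1.0, 1.0, 1.0, ...   (constant)
    #   triangular2: 1.0, 0.5, 0.25, ...  (1 / 2**(k - 1), halved every cycle)
    #   exp_range  : gamma**t             (e.g. 0.999**t, decayed every iteration)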

    def get_lr(self):
        """Calculates the learning rate at batch index. This function treats
        `self.last_epoch` as the last batch index.

        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.
        """
        # `cycle` is the 1-based index of the current cycle; `x` is the fraction
        # of that cycle completed so far.
        cycle = math.floor(1 + self.last_epoch / self.total_size)
        x = 1.0 + self.last_epoch / self.total_size - cycle
        if x <= self.step_ratio:
            # Rising half of the cycle.
            scale_factor = x / self.step_ratio
        else:
            # Falling half of the cycle.
            scale_factor = (x - 1) / (self.step_ratio - 1)

        lrs = []
        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
            base_height = (max_lr - base_lr) * scale_factor
            if self.scale_mode == "cycle":
                lr = base_lr + base_height * self.scale_fn(cycle)
            else:
                lr = base_lr + base_height * self.scale_fn(self.last_epoch)
            lrs.append(lr)

        if self.cycle_momentum:
            # Momentum is cycled inversely to the learning rate.
            momentums = []
            for base_momentum, max_momentum in zip(self.base_momentums, self.max_momentums):
                base_height = (max_momentum - base_momentum) * scale_factor
                if self.scale_mode == "cycle":
                    momentum = max_momentum - base_height * self.scale_fn(cycle)
                else:
                    momentum = max_momentum - base_height * self.scale_fn(self.last_epoch)
                momentums.append(momentum)
            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
                param_group["momentum"] = momentum

        return lrs
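
The snippet below is a minimal usage sketch, not part of the backported file. It assumes the class is importable from this module's location (`tvo.utils.CyclicLR`) and mirrors the docstring example with a dummy model and a momentum-capable SGD optimizer, stepping the scheduler once per batch.

import torch

from tvo.utils.CyclicLR import CyclicLR  # assumed import path for this backport

model = torch.nn.Linear(10, 1)  # dummy model so the optimizer has parameters
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Short half-cycles (4 batches up, 4 down) make the triangular pattern visible quickly;
# "triangular2" halves the amplitude after every full cycle.
scheduler = CyclicLR(optimizer, base_lr=0.01, max_lr=0.1, step_size_up=4, mode="triangular2")

for batch_idx in range(16):
    # ...forward pass, loss.backward() and optimizer.step() would go here in real training...
    optimizer.step()
    scheduler.step()  # one step() per batch, as the docstring requires
    print(batch_idx, round(optimizer.param_groups[0]["lr"], 4))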