欢迎来到某某鲜果配送有限公司!

专注鲜果配送

新鲜 / 健康 / 便利 / 快速 / 放心

全国咨询热线020-88888888
摩登7-摩登娱乐-摩登注册登录入口

新闻中心

 

推荐产品

24小时服务热线 020-88888888

行业动态

【转载】pytorch optimizer.step()

发布日期:2024-03-04 12:33浏览次数:

因为有人问我optimizer的step为什么不能放在min-batch那个循环之外,还有optimizer.step和loss.backward的区别;那么我想把答案记录下来。

首先需要明确optimzier优化器的作用, 形象地来说,优化器就是需要根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用,这也是机器学习里面最一般的方法论。

从优化器的作用出发,要使得优化器能够起作用,需要主要两个东西:1. 优化器需要知道当前的网络或者别的什么模型的参数空间,这也就是为什么在训练文件中,正式开始训练之前需要将网络的参数放到优化器里面,比如使用pytorch的话总会出现类似如下的代码:

   
   
  1. optimizer_G = Adam(model_G.parameters(), lr=train_c.lr_G) # lr 使用的是初始lr
  2. optimizer_D = Adam(model_D.parameters(), lr=train_c.lr_D)

2. 需要知道反向传播的梯度信息,我们还是从代码入手,如下所示是Pytorch 中SGD优化算法的step()函数具体写法,具体SGD的写法放在参考部分

   
   
  1. def step( self, closure=None):
  2. """Performs a single optimization step.
  3. Arguments:
  4. closure (callable, optional): A closure that reevaluates the model
  5. and returns the loss.
  6. """
  7. loss = None
  8. if closure is not None:
  9. loss = closure()
  10. for group in self.param_groups:
  11. weight_decay = group[ 'weight_decay']
  12. momentum = group[ 'momentum']
  13. dampening = group[ 'dampening']
  14. nesterov = group[ 'nesterov']
  15. for p in group[ 'params']:
  16. if p.grad is None:
  17. continue
  18. d_p = p.grad.data
  19. if weight_decay != 0:
  20. d_p.add_(weight_decay, p.data)
  21. if momentum != 0:
  22. param_state = self.state[p]
  23. if 'momentum_buffer' not in param_state:
  24. buf = param_state[ 'momentum_buffer'] = d_p.clone()
  25. else:
  26. buf = param_state[ 'momentum_buffer']
  27. buf.mul_(momentum).add_( 1 - dampening, d_p)
  28. if nesterov:
  29. d_p = d_p.add(momentum, buf)
  30. else:
  31. d_p = buf
  32. p.data.add_(-group[ 'lr'], d_p)
  33. return loss

从上面的代码可以看到step这个函数使用的是参数空间(param_groups)中的grad,也就是当前参数空间对应的梯度,这也就解释了为什么optimzier使用之前需要zero清零一下,因为如果不清零,那么使用的这个grad就得同上一个mini-batch有关,这不是我们需要的结果。再回过头来看,我们知道optimizer更新参数空间需要基于反向梯度,因此,当调用optimizer.step()的时候应当是loss.backward()的时候(loss.backward()的具体运算过程可以参看Pytorch 入门),这也就是经常会碰到,如下情况

   
   
  1. total_loss.backward()
  2. optimizer_G.step()

loss.backward()在前,然后跟一个step。

那么为什么optimizer.step()需要放在每一个batch训练中,而不是epoch训练中,这是因为现在的mini-batch训练模式是假定每一个训练集就只有mini-batch这样大,因此实际上可以将每一次mini-batch看做是一次训练,一次训练更新一次参数空间,因而optimizer.step()放在这里。

scheduler.step()按照Pytorch的定义是用来更新优化器的学习率的,一般是按照epoch为单位进行更换,即多少个epoch后更换一次学习率,因而scheduler.step()放在epoch这个大循环下。

?

?

参考资料:

Pytorch SGD 代码

https://discuss.pytorch.org/t/how-are-optimizer-step-and-loss-backward-related/7350/3

https://discuss.pytorch.org/t/taking-a-step-with-torch-optim-object/20371

020-88888888

平台注册入口