pytorch梯度累积

news/2024/7/24 13:21:07 标签: pytorch, 人工智能, python

梯度累加其实是为了变相扩大batch_size,用来解决显存受限问题。

常规训练方式,每次从train_loader读取出一个batch的数据:

python">for x,y in train_loader:
	pred = model(x)
	loss = criterion(pred, label)
	# 反向传播
	loss.backward()
	# 根据新的梯度更新网络参数
	optimizer.step()
	# 清空以往梯度,通过下面反向传播重新计算梯度
	optimizer.zero_grad()

        pytorch每次forward完都会得到一个用于梯度回传的计算图,pytorch构建的计算图是动态的,其实在每次backward后计算图都会从内存中释放掉,但是梯度不会清空的。所以若不显示的进行optimizer.zero_grad()清空过往梯度这一步操作,backward()的时候就会累加过往梯度。

梯度累加的做法:

python">accumulation_steps = 4
for i,(x,y) in enumerate(train_loader):
	pred = model(x)
	loss = criterion(pred, label)
	
	# 相当于对累加后的梯度取平均
	loss = loss/accumulation_steps
	# 反向传播
	loss.backward()

	if (i+1) % accumulation_steps == 0:
		# 根据新的梯度更新网络参数
		optimizer.step()
		# 清空以往梯度,通过下面反向传播重新计算梯度
		optimizer.zero_grad()

        代码中设置accumulation_steps = 4,意思就是变相扩大batch_size四倍。因为代码中每隔4次迭代才清空梯度,更新参数。
        至于为啥loss = loss/accumulation_steps,因为梯度累加了四次呀,那就要取平均,除以4。那我每次loss取4,其实就相当于最后将累加后的梯度除4咯。同时,因为累计了4个batch,那学习率也应该扩大4倍,让更新的步子跨大点。

 看网上的帖子有讨论对BN层是否有影响,因为BN的估算阶段(计算batch内均值、方差)是在forward阶段完成的,那真实的batch_size放大4倍效果肯定是比通过梯度累加放大4倍效果好的,毕竟计算真实的大batch_size内的均值、方差肯定更精确。

 还有讨论说通过调低BN参数momentum可以得到更长序列的统计信息,应该意思是能够记忆更久远的统计信息(均值、方差),以逼近真实的扩大batch_size的效果。

参考

pytorch骚操作之梯度累加,变相增大batch size


http://www.niftyadmin.cn/n/5402581.html

相关文章

QT之液晶电子时钟

根据qt的<QLDNumber>做了一个qt液晶电子时钟. 结果 实时显示当前时间,左键可以拖动时钟在屏幕的位置,右键点击关闭显示. 实现过程 新建一个class文件,让这个文件的父类是QLCDNumber 相关功能变量定义和函数实现 .c文件代码 这里需要注意的一点是event->button是获取的…

【python、nlp、transformer】transformer学习部分

注&#xff1a; 此博文仅为了解transformer架构&#xff0c;如果使用&#xff0c;建议直接调用库就行了 Transformer的优势 相比之前占领市场的LSTM和GRU模型&#xff0c;Transformer有两个显著的优势&#xff1a; 1. Transformer能够利用分布式GPU进行并行训练&#xff0c…

探寻2024年国内热门低代码平台排行!| 功能特点一览

低代码开发是一项革命性的技术&#xff0c;主要目的是尽量避免程序研发的复杂性&#xff0c;让外行开发者也能加入到应用程序的搭建中。低代码平台的核心概念和构成部分通常包括用户界面和拖拽设计、预构件和模块、自动化工作内容与数据库集成和扩展应用&#xff0c;应用低代码…

第二章:数据类型 第五节:数组

一、数组的概念 数组不同于其他语言的数组&#xff0c;是多维向量&#xff0c;至少要三维或者更多 二、创建数组 数组可以使用array(向量、维度、维度名称)函数创建,这边创建一个3x4的数组&#xff0c;则将dim中的维度设置为c(3,4) 同样&#xff0c;我们可以利用数组创建三维…

什么是依赖注入(Dependency Injection)?它在 C++ 中是如何实现的?

什么是依赖注入&#xff08;Dependency Injection&#xff09;&#xff1f;它在 C 中是如何实现的&#xff1f; 依赖注入&#xff08;Dependency Injection&#xff0c;DI&#xff09;是一种设计模式&#xff0c;用于减少软件组件之间的耦合度&#xff0c;提高代码的可测试性、…

前端【技术类】资源学习网站整理(那些年的小网站)

学习网站整理 值得分享的视频博主&#xff1a;学习网站链接 百度首页的资源收藏里的截图&#xff08;排列顺序没有任何意义&#xff0c;随性而已~&#xff09;&#xff0c;可根据我标注的关键词百度搜索到这些网站呀&#xff0c;本篇末尾会一一列出来&#xff0c;供大家学习呀 …

React富文本编辑器开发(一)

这是一个系统的完整的教程&#xff0c;每一节文章的内容都很重要。这个教程学完后自己可以开发出一个相当完美的富文本编辑器了。下面就开始我们今天的内容&#xff1a; 安装 是的&#xff0c;我们的开发是基于Slate的开发基础&#xff0c;所以要安装它&#xff1a; yarn ad…

【R语言教程】

R语言简介&#xff1a; R 是一种自由、开源的编程语言&#xff0c;专门用于统计分析、数据挖掘、数据可视化以及整理和清洗数据。 R 的强大功能和丰富的扩展包使得它在全球统计学家、数据科学家甚至其它领域的研究员和技术人员中都非常受欢迎。 R语言环境&#xff1a; 要开始…