【小白学PyTorch】3 浅谈Dataset和Dataloader

news/2024/7/10 1:44:03 标签: python, 深度学习, java, tensorflow, vue

这个系列是重新整理的一个《小白学PyTorch系列》。文章来自微信公众号【机器学习炼丹术】,喜欢的话动动小手关注下公众号吧~

文章目录:

文章目录

    • 1 Dataset基类
    • 2 构建Dataset子类
      • 2.1 __Init__
      • 2.2 __getitem__
    • 3 dataloader

1 Dataset基类

PyTorch 读取其他的数据,主要是通过 Dataset 类,所以先简单了解一下 Dataset 类。在看很多PyTorch的代码的时候,也会经常看到dataset这个东西的存在。Dataset类作为所有的 datasets 的基类存在,所有的 datasets 都需要继承它。

先看一下源码:
在这里插入图片描述

这里有一个__getitem__函数,__getitem__函数接收一个index,然后返回图片数据和标签,这个index通常是指一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。之后会举例子来讲解这个逻辑

其实说着了些都没用,因为在训练代码里是感觉不到这些操作的,只会看到通过DataLoader就可以获取一个batch的数据,这是触发去读取图片这些操作的是DataLoader里的__iter__(self)(后面再讲)。

2 构建Dataset子类

下面我们构建一下Dataset的子类,叫他MyDataset类:

python">import torch 
from torch.utils.data import Dataset,DataLoader

class MyDataset(Dataset):
    def __init__(self):
        self.data = torch.tensor([[1,2,3],[2,3,4],[3,4,5],[4,5,6]])
        self.label = torch.LongTensor([1,1,0,0])

    def __getitem__(self,index):
        return self.data[index],self.label[index]

    def __len__(self):
        return len(self.data)

2.1 Init

  • 初始化中,一般是把数据直接保存在这个类的属性中。像是self.data,self.label

2.2 getitem

  • index是一个索引,这个索引的取值范围是要根据__len__这个返回值确定的,在上面的例子中,__len__的返回值是4,所以这个index会在0,1,2,3这个范围内。

3 dataloader

从上文中,我们知道了MyDataset这个类中的__getitem__的返回值,应该是某一个样本的数据和标签(如果是测试集的dataset,那么就只返回数据),在梯度下降的过程中,一般是需要将多个数据组成batch,这个需要我们自己来组合吗?不需要的,所以PyTorch中存在DataLoader这个迭代器(这个名词用的准不准确有待考究)。

继续上面的代码,我们接着写代码:

python">mydataloader = DataLoader(dataset=mydataset,
                          batch_size=1)

我们现在创建了一个DataLoader的实例,并且把之前实例化的mydataset作为参数输入进去,并且还输入了batch_size这个参数,现在我们使用的batch_size是1.下面来用for循环来遍历这个dataloader:

python">for i,(data,label) in enumerate(mydataloader):
    print(data,label)

输出结果是:

意料之中的结果,总共输出了4个batch,每个batch都是只有1个样本(数据+标签),值得注意的是,这个输出过程是顺序的

我们稍微修改一下上面的DataLoader的参数:

python">mydataloader = DataLoader(dataset=mydataset,
                          batch_size=2,
                          shuffle=True)

for i,(data,label) in enumerate(mydataloader):
    print(data,label)

结果是:

可以看到每一个batch内出现了2个样本。假如我们再运行一遍上面的代码,得到:

两次结果不同,这是因为shuffle=True,dataset中的index不再是按照顺序从0到3了,而是乱序,可能是[0,1,2,3],也可能是[2,3,1,0]。

【个人感想】

Dataloader和Dataset两个类是非常方便的,因为这个可以快速的做出来batch数据,修改batch_size和乱序都非常地方便。有下面两个希望注意的地方:

  1. 一般标签值应该是Long整数的,所以标签的tensor可以用torch.LongTensor(数据)或者用.long()来转化成Long整数的形式。
  2. 如果要使用PyTorch的GPU训练的话,一般是先判断cuda是否可用,然后把数据标签都用to()放到GPU显存上进行GPU加速。
python">device = 'cuda' if torch.cuda.is_available() else 'cpu'
for i,(data,label) in enumerate(mydataloader):
    data = data.to(device)
    label = label.to(device)
    print(data,label)

看一下输出:


http://www.niftyadmin.cn/n/1358362.html

相关文章

【小白学PyTorch】4 构建模型三要素与权重初始化

这个系列是重新整理的一个《小白学PyTorch系列》。文章来自微信公众号【机器学习炼丹术】,喜欢的话动动小手关注下公众号吧~ 文章目录: 文章目录1 模型三要素2 参数初始化3 完整运行代码4 更细致的看参数1 模型三要素 三要素其实很简单 必须要继承nn.…

【阿里云训练营】python查漏补缺 1

文章来自微信公众号:【机器学习炼丹术】。欢迎大家关注,是我的个人学习干活分享基地。干货文章100 文章目录1 注释2 is 与 3 运算优先级4 查找所有属性和方法5 type和isinstance36 位运算6.1 原码、反码和补码6.2 按位运算6.3 利用位运算实现快速计算6.4…

【阿里云训练营】python查漏补缺 2

参考目录: 文章目录1. if 语句2. if - else 语句3. if - elif - else 语句4. assert 关键词5. while 循环6. while - else 循环7. for 循环8. for - else 循环9. range() 函数6. enumerate()函数7. break 语句8. continue 语句9. pass 语句10. 推导式1. if 语句 if…

【小白学PyTorch】7 最新版本torchvision.transforms常用API翻译与讲解

文章转自:微信公众号【机器学习炼丹术】。 有需要的话,可以添加作者微信交流:cyx645016617。朋友圈经常抽奖送书送红包哈哈。 参考目录: 文章目录1 基本函数1.1 Compose1.2 RandomChoice1.3 RandomOrder2 PIL上的操作2.1 中心切割…

【小白学PyTorch】8 实战之MNIST小试牛刀

文章来自微信公众号【机器学习炼丹术】。有什么问题都可以咨询作者WX:cyx645016617。想交个朋友占一个好友位也是可以的~好友位快满了不过。 文章目录1 探索性数据分析1.1 数据集基本信息1.2 数据集可视化1.3 类别是否均衡2 训练与推理2.1 构建dataset2.2 构建模型类…

【小白学PyTorch】9 tensor数据结构与存储结构

文章来自微信公众号【机器学习炼丹术】。 上一节课,讲解了MNIST图像分类的一个小实战,现在我们继续深入学习一下pytorch的一些有的没的的小知识来作为只是储备。 参考目录: 文章目录1 pytorch数据结构1.1 默认整数与浮点数1.2 dtype修改变量…

【小白学PyTorch】10 pytorch常见运算详解

文章来自微信公众号【机器学习炼丹术】。有问题可以咨询“炼丹兄”,WX:cyx645016617 这一课主要是讲解PyTorch中的一些运算,加减乘除这些,当然还有矩阵的乘法这些。这一课内容不多,作为一个知识储备。在后续的内容中&…

【小白学PyTorch】11 MobileNet详解及PyTorch实现

文章来自微信公众号【机器学习炼丹术】。我是炼丹兄,欢迎加我微信好友交流学习:cyx645016617。 文章目录1 背景2 深度可分离卷积2.2 一般卷积计算量2.2 深度可分离卷积计算量2.3 网络结构3 PyTorch实现本来计划是想在今天讲EfficientNet PyTorch的&#…