Pytorch深度学习-代码篇:加载数据(1)
理论基础
数据读取
Dataset类
提供一种方式去获取数据及其label:
- 如何获取每一个数据及其label
- 总共有多少个数据
DataLoader类
为后边的网络提供不同的数据形式
代码实现
通过Dataset构建自定义数据集类
对于一个结构如图(主要特征:文件夹名为标签)所示的数据集:

我们通过如下代码来构建Dataset类:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23from torch.utils.data import dataset
from PIL import Image
import os
class MyData(Dataset):
# 初始化函数
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir # 根目录文件夹名
self.label_dir = label_dir # label标签名(这是由该数据集特征决定的,具体情况具体分析)
self.path = os.path.join(self.root_dir, self.label_dir) # 图像存储文件夹名
self.img_path = os.listdir(self.path) # 由所有图片文件名构成的列表
# 获取指定元素方法
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img,label
# 获取数据集总数
def __len__(self):
return len(self.img_path)
实例化数据集
然后构建一个蚂蚁数据集:1
2
3root_dir = "dataset/train"
ants_label_dir = "ants"
ants_dataset = MyData(root_dir, ants_label_dir)
如果我们调用数据集第一个元素,会返回其值和标签。1
ants_dataset[0]
1 | (<PIL.JpegImageplugin.JpegImageFile image mode=RGB size=768x512 at 0x26F03A17080>,'ants') |
数据集相加
我们还可以对两个数据集进行相加操作。1
2
3bees_label_dir = "bees"
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset