Pytorch深度学习-代码篇:加载数据(1)

在pytorch中读取数据集并操作数据集的具体代码,主要介绍Dataset和DataLoader这两个比较关键的类。着重演示了构建Dataset类的基本流程。

理论基础

数据读取

Dataset类

提供一种方式去获取数据及其label:

  • 如何获取每一个数据及其label
  • 总共有多少个数据

DataLoader类

为后边的网络提供不同的数据形式

代码实现

通过Dataset构建自定义数据集类

对于一个结构如图(主要特征:文件夹名为标签)所示的数据集:

我们通过如下代码来构建Dataset类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from torch.utils.data import dataset
from PIL import Image
import os

class MyData(Dataset):
# 初始化函数
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir # 根目录文件夹名
self.label_dir = label_dir # label标签名(这是由该数据集特征决定的,具体情况具体分析)
self.path = os.path.join(self.root_dir, self.label_dir) # 图像存储文件夹名
self.img_path = os.listdir(self.path) # 由所有图片文件名构成的列表

# 获取指定元素方法
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img,label

# 获取数据集总数
def __len__(self):
return len(self.img_path)

实例化数据集

然后构建一个蚂蚁数据集:

1
2
3
root_dir = "dataset/train"
ants_label_dir = "ants"
ants_dataset = MyData(root_dir, ants_label_dir)

如果我们调用数据集第一个元素,会返回其值和标签。

1
ants_dataset[0]

1
(<PIL.JpegImageplugin.JpegImageFile image mode=RGB size=768x512 at 0x26F03A17080>,'ants')

数据集相加

我们还可以对两个数据集进行相加操作。

1
2
3
bees_label_dir = "bees"
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset

---------------------本文结束---------------------