FashionMNIST 数据集
大约 3 分钟
FashionMNIST 数据集
FashionMNIST 数据集是一个包含 60,000 个训练图像和 10,000 个测试图像的数据集。引入该数据集后,我们可以更方便和直观地比较模型精度和计算效率。
导入本文所需要的包
import torch.utils.data as Data
import torchvision
from . import d2lzh_pytorch as d2l # 需要根据实际目录更改
获取数据集
我们通过torchvision.datasets
来下载数据集,第一次调用时会从网上获取数据,该函数包括多个参数,写完代码之后我们一一来看各参数的含义。
mnist_train = torchvision.datasets.FashionMNIST(
root="./Datasets/",
train=True,
download=True,
transform=torchvision.transforms.ToTensor(),
)
mnist_test = torchvision.datasets.FashionMNIST(
root="./Datasets/",
train=False,
download=True,
transform=torchvision.transforms.ToTensor(),
)
root
数据集的保存目录,注意这里是在 linux 环境下下载的,所以采用/
作为路径分隔符train
是否为训练集,True 表示训练集,False 表示测试集download
是否下载数据集,如果数据集已经下载过,设置为 False 可以避免重复下载。transform
数据预处理操作,这里使用ToTensor()
将 PIL 图片(尺寸为HeightxWidthxChannels
且数据位于[0,255])或者 数据类型为np.uint8
的 numpy 数组 转换为 尺寸为ChannelsxHeightxWidth
,数据类型为torch.float32
且位于[0.0,1.0]的Tensor
- 函数返回的
mnist_test
和mnist_train
都是torch.utils.data.Dataset
的子类,对应了线性回归模型中制作数据集的步骤,直接获得了features
和labels
的映射。
对 Dataset 的基本操作
len(mnist_train) # 读取数据集数量
type(mnist_train) # 读取数据集类型
feature, label = mnist_train[0] # 读取数据集第一个元素
mnist_train[0][0] # 读取数据集第一个元素的特征
mnist_train[0][1] # 读取数据集第一个元素的标签
值得注意的是,这里的feature
维度需要特别关注,是CxHxW,第一维为通道数。数据集中的图形都是28x28
的灰度图像,只有一个通道,所以feature
的尺寸是1x28x28
。并且由于我们使用了ToTensor()
,数据类型已经转换为torch.float32
,数据均位于[0-1]
区间。
将数值标签转换为文本标签
Fashion-MNIST 中一共包括了 10 个类别,分别为 t-shirt(T 恤)、trouser(裤子)、pullover(套衫)、dress(连衣裙)、coat(外套)、sandal(凉鞋)、shirt(衬衫)、sneaker(运动鞋)、bag(包)和 ankle boot(短靴)。以下函数可以将数值标签转成相应的文本标签。
# 本函数已保存在 d2lzh 包中以方便后续使用
def get_fashion_mnist_labels(labels):
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
一行里面画出多张图像和标签的函数
# 本函数已保存在 d2lzh 包中以方便后续使用
from IPython import display
import matplotlib.pyplot as plt
def use_svg_display():
display.set_matplotlib_formats('svg')
def show_fashion_mnist(images, labels):
d2l.use_svg_display()
# 这里的_表示我们忽略(不使用)的变量
_, figs = plt.subplots(1, len(images), figsize=(12, 12))
for f, img, lbl in zip(figs, images, labels):
f.imshow(img.view((28, 28)).numpy())
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
plt.show()
因此我们可以显示出 10 个数据集中的图像。
X , y = [], []
for i in range(10):
X.append(mnist_train[i][0])
y.append(mnist_train[i][1])
d2l.show_fashion_mnist(X, d2l.get_fashion_mnist_labels(y))