PyTorch 基础概念 - 数据集 Dataset

更新日期: 2022-09-05 阅读次数: 235 字数: 926 分类: AI

Datasets and Dataloaders

  • Dataset (torch.utils.data.Dataset) 存储了样本及其对应的标签。
  • DataLoader (torch.utils.data.DataLoader) 方便访问 Dataset。

Dataset 的类型

  • 图片
  • 文本
  • 音频

等等。

现成的 Dataset 有哪些

例如, FashionMNIST。

>>> from torchvision import datasets
>>> dir(datasets)
['CIFAR10', 'CIFAR100', 'CLEVRClassification', 'Caltech101', 'Caltech256', 'CelebA', 'Cityscapes', 'CocoCaptions', 'CocoDetection', 'Country211', 'DTD', 'DatasetFolder', 'EMNIST', 'EuroSAT', 'FER2013', 'FGVCAircraft', 'FakeData', 'FashionMNIST', 'Flickr30k', 'Flickr8k', 'Flowers102', 'FlyingChairs', 'FlyingThings3D', 'Food101', 'GTSRB', 'HD1K', 'HMDB51', 'INaturalist', 'ImageFolder', 'ImageNet', 'KMNIST', 'Kinetics', 'Kinetics400', 'Kitti', 'KittiFlow', 'LFWPairs', 'LFWPeople', 'LSUN', 'LSUNClass', 'MNIST', 'Omniglot', 'OxfordIIITPet', 'PCAM', 'PhotoTour', 'Places365', 'QMNIST', 'RenderedSST2', 'SBDataset', 'SBU', 'SEMEION', 'STL10', 'SUN397', 'SVHN', 'Sintel', 'StanfordCars', 'UCF101', 'USPS', 'VOCDetection', 'VOCSegmentation', 'VisionDataset', 'WIDERFace', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_optical_flow', 'caltech', 'celeba', 'cifar', 'cityscapes', 'clevr', 'coco', 'country211', 'dtd', 'eurosat', 'fakedata', 'fer2013', 'fgvc_aircraft', 'flickr', 'flowers102', 'folder', 'food101', 'gtsrb', 'hmdb51', 'imagenet', 'inaturalist', 'kinetics', 'kitti', 'lfw', 'lsun', 'mnist', 'omniglot', 'oxford_iiit_pet', 'pcam', 'phototour', 'places365', 'rendered_sst2', 'sbd', 'sbu', 'semeion', 'stanford_cars', 'stl10', 'sun397', 'svhn', 'ucf101', 'usps', 'utils', 'video_utils', 'vision', 'voc', 'widerface']
>>>

可以看到除了 FashionMNIST,还有很多其他的数据集。

但是,我有个疑惑。为何 FashionMNIST 不是在 torch.utils.data.Dataset 中,而是在 torchvision.datasets 中。

MNIST 是什么?

FashionMNIST 的 MNIST 后缀代表什么呢?

https://en.wikipedia.org/wiki/MNIST_database

The MNIST database (Modified National Institute of Standards and Technology database) is a large database of handwritten digits that is commonly used for training various image processing systems.

MNIST database 即,改进后的美国国家标准与技术研究院数据库。

  • MNIST:手写数字图片的数据集合。包含 6万张训练图片及1万张测试图片。大小为统一的 28x28 像素图片。
  • EMNIST: 增加了大小写字母,并包含了手写数字。
  • FashionMNIST: 时尚物品数据集合。包含 10 个分类的物品,例如 t shirt,鞋子,西装之类。同样是 28x28 像素的,灰度图片(灰度值0~255)。

从 torchvision.datasets 可以看到,既有 MNIST,也有 EMNIST,FashionMNIST。

数据集的下载

from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

training_data = datasets.FashionMNIST(
    root="data", train=True, download=True, transform=ToTensor()
)
  • root 参数,即数据集所在的目录。
  • 在 download 为 True 时,最自动下载数据集文件
  • train 代表,这是训练数据

是从 AWS S3 上下载的压缩包文件。

>python main.py
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data\FashionMNIST\raw\train-images-idx3-ubyte.gz
Extracting data\FashionMNIST\raw\train-images-idx3-ubyte.gz to data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data\FashionMNIST\raw\train-labels-idx1-ubyte.gz
Extracting data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz
Extracting data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to data\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to data\FashionMNIST\raw

文件大小:

> tree -h
.
├── [4.0K]  data
│   └── [4.0K]  FashionMNIST
│       └── [4.0K]  raw
│           ├── [7.5M]  t10k-images-idx3-ubyte
│           ├── [4.2M]  t10k-images-idx3-ubyte.gz
│           ├── [9.8K]  t10k-labels-idx1-ubyte
│           ├── [5.0K]  t10k-labels-idx1-ubyte.gz
│           ├── [ 45M]  train-images-idx3-ubyte
│           ├── [ 25M]  train-images-idx3-ubyte.gz
│           ├── [ 59K]  train-labels-idx1-ubyte
│           └── [ 29K]  train-labels-idx1-ubyte.gz
└── [ 346]  main.py

3 directories, 9 files

完整代码

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST(
    root="data", train=True, download=True, transform=ToTensor()
)
print(len(training_data))  # 60000

test_data = datasets.FashionMNIST(
    root="data", train=False, download=True, transform=ToTensor()
)
print(len(test_data))  # 10000

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

运行结果:

数据集在有本地文件的情况下,加载非常快,1秒都不用。打开那些本地数据集文件,可以看到并不是预想的单个图片文件,而且序列化到一起的二进制文件。

Label 与 Feature

  • Label: 输出。即上面 demo 里的物品名称。也是我们用 PyTorch 预测的结果输出。
  • Feature: 输入。即作为预测模型的输入,这里是图片的像素 pattern (patterns in the images pixels)。我不知道 pattern 翻译为什么比较好。
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

这里的参数:

  • transform:用于修改 feature
  • target_transform: 用于修改 label

修改、转换的目的是,使数据适用于训练。

DataLoader 的作用

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

提升效率,批量从数据集合中加载数据,例如,这里一次加载 64 个,同时设置了随机数据读取。这样每次读出来的就是多个图片和 label。

参考

  • https://docs.microsoft.com/zh-cn/learn/modules/intro-machine-learning-pytorch/3-data

tags: pytorch

爱评论不评论