PyTorch 基础概念 - 数据集 Dataset

文章目录

    Datasets and Dataloaders

    • Dataset (torch.utils.data.Dataset) 存储了样本及其对应的标签。
    • DataLoader (torch.utils.data.DataLoader) 方便访问 Dataset。

    Dataset 的类型

    • 图片
    • 文本
    • 音频

    等等。

    现成的 Dataset 有哪些

    例如, FashionMNIST。

    >>> from torchvision import datasets
    >>> dir(datasets)
    ['CIFAR10', 'CIFAR100', 'CLEVRClassification', 'Caltech101', 'Caltech256', 'CelebA', 'Cityscapes', 'CocoCaptions', 'CocoDetection', 'Country211', 'DTD', 'DatasetFolder', 'EMNIST', 'EuroSAT', 'FER2013', 'FGVCAircraft', 'FakeData', 'FashionMNIST', 'Flickr30k', 'Flickr8k', 'Flowers102', 'FlyingChairs', 'FlyingThings3D', 'Food101', 'GTSRB', 'HD1K', 'HMDB51', 'INaturalist', 'ImageFolder', 'ImageNet', 'KMNIST', 'Kinetics', 'Kinetics400', 'Kitti', 'KittiFlow', 'LFWPairs', 'LFWPeople', 'LSUN', 'LSUNClass', 'MNIST', 'Omniglot', 'OxfordIIITPet', 'PCAM', 'PhotoTour', 'Places365', 'QMNIST', 'RenderedSST2', 'SBDataset', 'SBU', 'SEMEION', 'STL10', 'SUN397', 'SVHN', 'Sintel', 'StanfordCars', 'UCF101', 'USPS', 'VOCDetection', 'VOCSegmentation', 'VisionDataset', 'WIDERFace', '__all__', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '_optical_flow', 'caltech', 'celeba', 'cifar', 'cityscapes', 'clevr', 'coco', 'country211', 'dtd', 'eurosat', 'fakedata', 'fer2013', 'fgvc_aircraft', 'flickr', 'flowers102', 'folder', 'food101', 'gtsrb', 'hmdb51', 'imagenet', 'inaturalist', 'kinetics', 'kitti', 'lfw', 'lsun', 'mnist', 'omniglot', 'oxford_iiit_pet', 'pcam', 'phototour', 'places365', 'rendered_sst2', 'sbd', 'sbu', 'semeion', 'stanford_cars', 'stl10', 'sun397', 'svhn', 'ucf101', 'usps', 'utils', 'video_utils', 'vision', 'voc', 'widerface']
    >>>
    

    可以看到除了 FashionMNIST,还有很多其他的数据集。

    但是,我有个疑惑。为何 FashionMNIST 不是在 torch.utils.data.Dataset 中,而是在 torchvision.datasets 中。

    MNIST 是什么?

    FashionMNIST 的 MNIST 后缀代表什么呢?

    https://en.wikipedia.org/wiki/MNIST_database

    The MNIST database (Modified National Institute of Standards and Technology database) is a large database of handwritten digits that is commonly used for training various image processing systems.

    MNIST database 即,改进后的美国国家标准与技术研究院数据库。

    • MNIST:手写数字图片的数据集合。包含 6万张训练图片及1万张测试图片。大小为统一的 28x28 像素图片。
    • EMNIST: 增加了大小写字母,并包含了手写数字。
    • FashionMNIST: 时尚物品数据集合。包含 10 个分类的物品,例如 t shirt,鞋子,西装之类。同样是 28x28 像素的,灰度图片(灰度值0~255)。

    从 torchvision.datasets 可以看到,既有 MNIST,也有 EMNIST,FashionMNIST。

    数据集的下载

    from torchvision import datasets
    from torchvision.transforms import ToTensor, Lambda
    
    training_data = datasets.FashionMNIST(
        root="data", train=True, download=True, transform=ToTensor()
    )
    
    • root 参数,即数据集所在的目录。
    • 在 download 为 True 时,最自动下载数据集文件
    • train 代表,这是训练数据

    是从 AWS S3 上下载的压缩包文件。

    >python main.py
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data\FashionMNIST\raw\train-images-idx3-ubyte.gz
    Extracting data\FashionMNIST\raw\train-images-idx3-ubyte.gz to data\FashionMNIST\raw
    
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data\FashionMNIST\raw\train-labels-idx1-ubyte.gz
    Extracting data\FashionMNIST\raw\train-labels-idx1-ubyte.gz to data\FashionMNIST\raw
    
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz
    Extracting data\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to data\FashionMNIST\raw
    
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
    Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz
    Extracting data\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to data\FashionMNIST\raw
    

    文件大小:

    > tree -h
    .
    ├── [4.0K]  data
    │   └── [4.0K]  FashionMNIST
    │       └── [4.0K]  raw
    │           ├── [7.5M]  t10k-images-idx3-ubyte
    │           ├── [4.2M]  t10k-images-idx3-ubyte.gz
    │           ├── [9.8K]  t10k-labels-idx1-ubyte
    │           ├── [5.0K]  t10k-labels-idx1-ubyte.gz
    │           ├── [ 45M]  train-images-idx3-ubyte
    │           ├── [ 25M]  train-images-idx3-ubyte.gz
    │           ├── [ 59K]  train-labels-idx1-ubyte
    │           └── [ 29K]  train-labels-idx1-ubyte.gz
    └── [ 346]  main.py
    
    3 directories, 9 files
    

    完整代码

    import torch
    from torch.utils.data import Dataset
    from torchvision import datasets
    from torchvision.transforms import ToTensor, Lambda
    import matplotlib.pyplot as plt
    
    training_data = datasets.FashionMNIST(
        root="data", train=True, download=True, transform=ToTensor()
    )
    print(len(training_data))  # 60000
    
    test_data = datasets.FashionMNIST(
        root="data", train=False, download=True, transform=ToTensor()
    )
    print(len(test_data))  # 10000
    
    labels_map = {
        0: "T-Shirt",
        1: "Trouser",
        2: "Pullover",
        3: "Dress",
        4: "Coat",
        5: "Sandal",
        6: "Shirt",
        7: "Sneaker",
        8: "Bag",
        9: "Ankle Boot",
    }
    figure = plt.figure(figsize=(8, 8))
    cols, rows = 3, 3
    for i in range(1, cols * rows + 1):
        sample_idx = torch.randint(len(training_data), size=(1,)).item()
        img, label = training_data[sample_idx]
        figure.add_subplot(rows, cols, i)
        plt.title(labels_map[label])
        plt.axis("off")
        plt.imshow(img.squeeze(), cmap="gray")
    plt.show()
    

    运行结果:

    数据集在有本地文件的情况下,加载非常快,1秒都不用。打开那些本地数据集文件,可以看到并不是预想的单个图片文件,而且序列化到一起的二进制文件。

    Label 与 Feature

    • Label: 输出。即上面 demo 里的物品名称。也是我们用 PyTorch 预测的结果输出。
    • Feature: 输入。即作为预测模型的输入,这里是图片的像素 pattern (patterns in the images pixels)。我不知道 pattern 翻译为什么比较好。
      from torchvision import datasets from torchvision.transforms import ToTensor, Lambda ds = datasets.FashionMNIST( root="data", train=True, download=True, transform=ToTensor(), target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)) )

    这里的参数:

    • transform:用于修改 feature
    • target_transform: 用于修改 label

    修改、转换的目的是,使数据适用于训练。

    DataLoader 的作用

    from torch.utils.data import DataLoader
    
    train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
    

    提升效率,批量从数据集合中加载数据,例如,这里一次加载 64 个,同时设置了随机数据读取。这样每次读出来的就是多个图片和 label。

    参考

    • https://docs.microsoft.com/zh-cn/learn/modules/intro-machine-learning-pytorch/3-data

    关于作者 🌱

    我是来自山东烟台的一名开发者,有感兴趣的话题,或者软件开发需求,欢迎加微信 zhongwei 聊聊,或者关注我的个人公众号“大象工具”, 查看更多联系方式