import torch import numpy as np import torch.utils.data as Data from torchvision import transforms, datasets, utils from PIL import Image import matplotlib.pyplot as plt
""" torchvision 有 4 个子模块:datasets、models、transforms 和 utils。model 将在后面讲解,在 3.4 中已经使用过 datasets 下载过数据集。 现在主要介绍如何使用 transforms 对原数据进行预处理、增强等,以及 datasets 下的 ImageFolder 处理自定义数据。 """
""" transforms 提供了对 PIL Image 和 Tensor 对象的常用操作。 A 对 PIL Image 的操作: 1) Resize(size, interpolation=2): 调整图像大小,size 可以是一个整数或一个 (width, height) 元组。 2) CenterCrop(size): 从图像中心裁剪出指定大小的区域,size 可以是一个整数或一个 (width, height) 元组。 RandomCrop 随机裁剪,RandomResizedCrop 裁剪随机大小。 3) Pad(padding, fill=0, padding_mode='constant'): 对图像进行填充,padding 可以是一个整数或一个 (left, top, right, bottom) 元组。 4) ToTensor(): 将 PIL Image 或 numpy.ndarray 转换为 Tensor,并将像素值归一化到 [0, 1] 范围内。如把 (H x W x C) 的图像转换为 (C x H x W) 的 Tensor。 5) RandomHorizontalFlip(p=0.5): 以概率 p 随机水平翻转图像。 6) RandomVerticalFlip(p=0.5): 以概率 p 随机垂直翻转图像。 7) ColorJitter(brightness=0, contrast=0, saturation=0, hue=0): 随机改变图像的亮度、对比度、饱和度和色调。 B 对 Tensor 的操作: 1) Normalize(mean, std, inplace=False): 使用均值和标准差对 Tensor 进行归一化。mean 和 std 应该是与图像通道数相同的列表。 2) ToPILImage(mode=None): 将 Tensor 或 numpy.ndarray 转换为 PIL Image。
如果要对数据集进行多个变换,可以使用 transforms.Compose() 将多个变换组合在一起,按顺序依次执行。类似于 nn.Sequential()。 """ transforms.Compose([ transforms.CenterCrop(10), transforms.RandomCrop(20, padding=0), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ])
transforms.Lambda(lambda x: x + 10)
""" 当文件依据标签处于不同文件夹下时,如: ————— data |———— label1 | |————— 001.jpg | |————— 002.jpg |———— label2 | |————— 001.jpg | |————— 002.jpg ....... 就可以利用 torchvision.datasets.ImageFolder 来直接构造出 dataset,代码如下: dataset = datasets.ImageFolder(Path) loader = utils.data.DataLoader(dataset) ImageFolder 会自动为每个子文件夹分配一个标签,按字母顺序排序,从 0 开始编号。这样载入 DataLoader 后,每个 batch 的数据就会包含图像和对应的标签。 """
my_transforms = transforms.Compose([ transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), ]) train_data = datasets.ImageFolder(root='./data', transform=my_transforms) train_loader = Data.DataLoader(dataset=train_data, batch_size=6, shuffle=True)
for i_batch, img in enumerate(train_loader): if i_batch == 0: print(img[1]) fig = plt.figure() grid = utils.make_grid(img[0]) plt.imshow(grid.numpy().transpose((1, 2, 0))) plt.show() utils.save_image(grid, 'cat_dog.png') break
img = Image.open('cat_dog.png')
|