Pytorch自定义图片数据集

发布时间:2022-07-04 发布网站:脚本宝典
脚本宝典收集整理的这篇文章主要介绍了Pytorch自定义图片数据集脚本宝典觉得挺不错的,现在分享给大家,也给大家做个参考。

本篇博客旨在实现pytorch读取图片并自定义图片数据集

图像加载方法

主流的图像加载方法主要有三种

下表中xxx表示图片的路径

函数/方法 返回值 图像像素格式 像素值范围 图像矩阵表示
skimage io.imread(xxx) numpy.ndarray RGB [0, 255] (H X W X C)
cv2 cv2.imread(xxx) numpy.ndarray BGR [0, 255] (H X W X C)
Pillow(PIL) Image.oPEn(xxx) PIL.Image.Image对象 根据图像格式,一般为RGB [0, 255]

这里使用三种方式读取一张图片

import matplotlib.pyplot as plt
import skimage.io as io
import cv2
From PIL import Image
import numpy as np
import torch

# width = 1081,  height=1920, channel=3

# 使用skimage读取图像
img_skimage = io.imread('./image/BackGround1.jpg')        # skimage.io imread()-----np.ndarray,  (H x W x C), [0, 255],RGB
PRint(img_skimage.Shape)

# 使用Opencv读取图像
img_cv = cv2.imread('./image/BackGround1.jpg')            # cv2.imread()------np.array, (H x W xC), [0, 255], BGR
print(img_cv.shape)

# 使用PIL读取
img_pil = Image.open('./image/BackGround1.jpg')         # PIL.Image.Image对象
img_pil_1 = np.array(img_pil)           # (H x W x C), [0, 255], RGB
print(img_pil_1.shape)

plt.figure()
for i, im in enumerate([img_skimage, img_cv, img_pil_1]):
    ax = plt.subplot(1, 3, i + 1)
    ax.imshow(im)

plt.show()

'''
三种方式输出的shape都是(1081, 1920, 3)
'''

将图片转化为Torch.Tensor

使用np.transpose进行转化,同时要注意numpy和torch中图片维度顺序的不同,因此需要进行转化

  • numpy image: H x W x C
  • torch image: C x H x W
tensor_skimage = torch.from_numpy(np.transpose(img_skimage, (2, 0, 1)))
print(tensor_skimage.shape)
tensor_cv = torch.from_numpy(np.transpose(img_cv, (2, 0, 1)))
print(tensor_cv.shape)
tensor_pil = torch.from_numpy(np.transpose(img_pil_1, (2, 0, 1)))
print(tensor_pil.shape)

'''
输出结果均为torch.Size([3, 1081, 1920])
'''

使用ImageFolder类

ImageFolder是torchvision提供好的一个类,可以让我们直接直接对某一个目录下文件夹内的图片加载为数据集,会自动检测jpg,jpeg,png等图片格式

下面为ImageFolder初始化的

super(ImageFolder, self).__inIT__(root, loader, IMG_extensions if is_valid_file is None else None,
                                  transform=transform,
                                  t@R_512_2604@et_transform=target_transform,
                                  is_valid_file=is_valid_file)
  • root:加载的路径,注意如果root='./',实际上会检测'./'的下级目录,比如'./image/'下的图片,不会再 './'直接检测

  • transform:可以添加transforms.COMpose()进行图片的预处理,例如

    transforms.Compose(
        [
            transforms.ToTensor()
        ]
    )
    # 可以将图片转化为张量
    

下面展示一个使用ImageFolder的例子

Pytorch自定义图片数据集

这是'./image/'下的三张图片

from torchvision import transforms
import torchvision as tv
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

transform = transforms.Compose(
    [
        transforms.ToTensor()
    ]
)

train_set = tv.datasets.ImageFolder(root='./', transform=transform)
data_loader = DataLoader(dataset=train_set)

# transforms提供的类,注意不是方法,需要先实例化,可以torch.tensor转化为PIL.Image.Image对象
to_pil_image = transforms.ToPILImage()
# 因为ToPILImage()类中定义了__call__方法,因此可以使用to_pil_image(xxx)的方式来调用__call__方法,详情可以见源码

for image, label in data_loader:
    # [Batch, Channels, Height, Width]所以第一维度会是1
    print(type(image)) # torch.Size([1, 3, 1081, 1920]) #
    # 下面使用两张展示的方法
    # 方法1:Image.show()
    # 第一种方法会自动打开脑默认的图片软件来展示图片
    # transforms.ToPILImage()中有一句
    # npimg = np.transpose(pic.numpy(), (1, 2, 0))
    # 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
    # image[0]即为torch.Size([3, 1081, 1920])
    img = to_pil_image(image[0])
    img.show()

    # 方法2:plt.imshow(ndarray)
    img = image[0]  # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
    img = img.numpy()  # FloatTensor转为ndarray
    img = np.transpose(img, (1, 2, 0))  # 把channel那一维放到最后

    # 显示图片
    plt.imshow(img)
    plt.show()

重写DataSet类

DataSet是torch中的一个抽象类,用于进行重写自己的类

我们可以重写以下方法

def __getitem__(self, index) -> T_co:
    raise NotImplementedError

def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
    return ConcatDataset([self, other])

def __len__(self)
	return

例如

import os
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

class MyDataset(Dataset):
    def __init__(self, file_path, transform = None):
        super(MyDataset, self).__init__()
        self.file_path = file_path
        self.transform = transform  # 对输入图像进行预处理,这里并没有做,预设为None
        self.image_names = os.listdir(self.file_path)  # 文件名的列表
        print(self.image_names)

    def __getitem__(self, idx):
        image = self.image_names[idx]
        image = io.imread(os.path.join(self.file_path, image))
        if self.transform:
            image= self.transform(image)

        return image

    def __len__(self):
        return len(self.image_names)

在自定义的MyDataSet中,会将file_path路径下的所有图片加入数据集,使用的是skimage.io.read将其写成np.array,可以使用transform将其装变成torch.tensor

将DataSet加载进DataLoader

# 设置自己存放的数据集位置,并尝试转化为PIL.Image.Image对象进行展示
transform = transforms.Compose(
    [
        transforms.ToTensor()
    ]
)
imageloader = MyDataset(file_path="./leaves/",transform = None)

the_dataloader = DataLoader(dataset=imageloader,batch_size=2,shuffle=True)

to_pil_image = transforms.ToPILImage()

for i_batch,batch_data in enumerate(the_dataloader):
    print(i_batch)
    print(len(batch_data)) # 2,即上面设计的batch_size
    for X in batch_data:
        to_pil_image(X.numpy()).show()

在DataLoader中设定dataset,batch_size,和是否打乱shuffle

然后可以通过enumerate来遍历,注意通过enumerate,每一个图片的大小需要一直

其中i_batch表示label,batch_data表示一个batch的数据集

脚本宝典总结

以上是脚本宝典为你收集整理的Pytorch自定义图片数据集全部内容,希望文章能够帮你解决Pytorch自定义图片数据集所遇到的问题。

如果觉得脚本宝典网站内容还不错,欢迎将脚本宝典推荐好友。

本图文内容来源于网友网络收集整理提供,作为学习参考使用,版权属于原作者。
如您有任何意见或建议可联系处理。小编QQ:384754419,请注明来意。