您当前的位置:首页 > IT编程 > python
| C语言 | Java | VB | VC | python | Android | TensorFlow | C++ | oracle | 学术与代码 | cnn卷积神经网络 | gnn | 图像修复 | Keras | 数据集 | Neo4j | 自然语言处理 | 深度学习 | 医学CAD | 医学影像 | 超参数 | pointnet | pytorch | 异常检测 | Transformers | 情感分类 | 知识图谱 |

自学教程:Pytorch DataLoader shuffle验证方式

51自学网 2021-10-30 22:34:43
  python
这篇教程Pytorch DataLoader shuffle验证方式写得很实用,希望能帮到您。

shuffle = False时,不打乱数据顺序

shuffle = True,随机打乱

import numpy as npimport h5pyimport torchfrom torch.utils.data import DataLoader, Dataset  h5f = h5py.File('train.h5', 'w');data1 = np.array([[1,2,3],               [2,5,6],              [3,5,6],              [4,5,6]])data2 = np.array([[1,1,1],                   [1,2,6],                  [1,3,6],                  [1,4,6]])h5f.create_dataset(str('data'), data=data1)h5f.create_dataset(str('label'), data=data2)class Dataset(Dataset):    def __init__(self):        h5f = h5py.File('train.h5', 'r')        self.data = h5f['data']        self.label = h5f['label']    def __getitem__(self, index):        data = torch.from_numpy(self.data[index])        label = torch.from_numpy(self.label[index])        return data, label     def __len__(self):        assert self.data.shape[0] == self.label.shape[0], "wrong data length"        return self.data.shape[0]  dataset_train = Dataset()loader_train = DataLoader(dataset=dataset_train,                           batch_size=2,                           shuffle = True) for i, data in enumerate(loader_train):    train_data, label = data    print(train_data) 

pytorch DataLoader使用细节

背景:

我一开始是对数据扩增这一块有疑问, 只看到了数据变换(torchvisiom.transforms),但是没看到数据扩增, 后来搞明白了, 数据扩增在pytorch指的是torchvisiom.transforms + torch.utils.data.DataLoader+多个epoch共同作用下完成的,

数据变换共有以下内容

composed = transforms.Compose([transforms.Resize((448, 448)), #  resize                               transforms.RandomCrop(300), # random crop                               transforms.ToTensor(),                               transforms.Normalize(mean=[0.5, 0.5, 0.5],  # normalize                                                    std=[0.5, 0.5, 0.5])])

简单的数据读取类, 进返回PIL格式的image:

class MyDataset(data.Dataset):        def __init__(self, labels_file, root_dir, transform=None):        with open(labels_file) as csvfile:            self.labels_file = list(csv.reader(csvfile))        self.root_dir = root_dir        self.transform = transform            def __len__(self):        return len(self.labels_file)        def __getitem__(self, idx):        im_name = os.path.join(root_dir, self.labels_file[idx][0])        im = Image.open(im_name)                if self.transform:            im = self.transform(im)                    return im

下面是主程序

labels_file = "F:/test_temp/labels.csv"root_dir = "F:/test_temp"dataset_transform = MyDataset(labels_file, root_dir, transform=composed)dataloader = data.DataLoader(dataset_transform, batch_size=1, shuffle=False)"""原始数据集共3张图片, 以batch_size=1, epoch为2 展示所有图片(共6张)  """for eopch in range(2):    plt.figure(figsize=(6, 6))     for ind, i in enumerate(dataloader):        a = i[0, :, :, :].numpy().transpose((1, 2, 0))        plt.subplot(1, 3, ind+1)        plt.imshow(a)

从上述图片总可以看到, 在每个eopch阶段实际上是对原始图片重新使用了transform, , 这就造就了数据的扩增

以上为个人经验,希望能给大家一个参考,也希望大家多多支持51zixue.net。


python 爬取吉首大学网站成绩单
Python爬虫实战之爬取携程评论
万事OK自学网:51自学网_软件自学网_CAD自学网自学excel、自学PS、自学CAD、自学C语言、自学css3实例,是一个通过网络自主学习工作技能的自学平台,网友喜欢的软件自学网站。