您当前的位置:首页 > IT编程 > python
| C语言 | Java | VB | VC | python | Android | TensorFlow | C++ | oracle | 学术与代码 | cnn卷积神经网络 | gnn | 图像修复 | Keras | 数据集 | Neo4j | 自然语言处理 | 深度学习 | 医学CAD | 医学影像 | 超参数 | pointnet | pytorch | 异常检测 | Transformers | 情感分类 | 知识图谱 |

自学教程:pytorch 数据加载性能对比分析

51自学网 2021-10-30 22:53:14
  python
这篇教程pytorch 数据加载性能对比分析写得很实用,希望能帮到您。

传统方式需要10s,dat方式需要0.6s

import osimport timeimport torchimport randomfrom common.coco_dataset import COCODatasetdef gen_data(batch_size,data_path,target_path): os.makedirs(target_path,exist_ok=True) dataloader = torch.utils.data.DataLoader(COCODataset(data_path,               (352, 352),               is_training=False, is_scene=True),            batch_size=batch_size,            shuffle=False, num_workers=0, pin_memory=False,            drop_last=True) # DataLoader start = time.time() for step, samples in enumerate(dataloader):  images, labels, image_paths = samples["image"], samples["label"], samples["img_path"]  print("time", images.size(0), time.time() - start)  start = time.time()  # torch.save(samples,target_path+ '/' + str(step) + '.dat')  print(step)def cat_100(target_path,batch_size=100): paths = os.listdir(target_path) li = [i for i in range(len(paths))] random.shuffle(li) images = [] labels = [] image_paths = [] start = time.time() for i in range(len(paths)):  samples = torch.load(target_path + str(li[i]) + ".dat")  image, label, image_path = samples["image"], samples["label"], samples["img_path"]  images.append(image.cuda())  labels.append(label.cuda())  image_paths.append(image_path)  if i % batch_size == batch_size - 1:   images = torch.cat((images), 0)   print("time", images.size(0), time.time() - start)   images = []   labels = []   image_paths = []   start = time.time()  i += 1if __name__ == '__main__': os.environ["CUDA_VISIBLE_DEVICES"] = '3' batch_size=320 # target_path='d:/test_1000/' target_path='d:/img_2/' data_path = r'D:/dataset/origin_all_datas/_2train' gen_data(batch_size,data_path,target_path) # get_data(target_path,batch_size) # cat_100(target_path,batch_size)

这个读取数据也比较快:320 batch_size 450ms

def cat_100(target_path,batch_size=100): paths = os.listdir(target_path) li = [i for i in range(len(paths))] random.shuffle(li) images = [] labels = [] image_paths = [] start = time.time() for i in range(len(paths)):  samples = torch.load(target_path + str(li[i]) + ".dat")  image, label, image_path = samples["image"], samples["label"], samples["img_path"]  images.append(image)#.cuda())  labels.append(label)#.cuda())  image_paths.append(image_path)  if i % batch_size < batch_size - 1:   i += 1   continue  i += 1  images = torch.cat(([image.cuda() for image in images]), 0)  print("time", images.size(0), time.time() - start)  images = []  labels = []  image_paths = []  start = time.time()

补充:pytorch数据加载和处理问题解决方案

最近跟着pytorch中文文档学习遇到一些小问题,已经解决,在此对这些错误进行记录:

在读取数据集时报错:

AttributeError: 'Series' object has no attribute 'as_matrix'

在显示图片是时报错:

ValueError: Masked arrays must be 1-D

显示单张图片时figure一闪而过

在显示多张散点图的时候报错:

TypeError: show_landmarks() got an unexpected keyword argument 'image'

解决方案

主要问题在这一行: 最终目的是将Series转为Matrix,即调用np.mat即可完成。

修改前

landmarks =landmarks_frame.iloc[n, 1:].as_matrix()

修改后

landmarks =np.mat(landmarks_frame.iloc[n, 1:])

打散点的x和y坐标应该均为向量或列表,故将landmarks后使用tolist()方法即可

修改前

plt.scatter(landmarks[:,0],landmarks[:,1],s=10,marker='.',c='r')

修改后

plt.scatter(landmarks[:,0].tolist(),landmarks[:,1].tolist(),s=10,marker='.',c='r')

前面使用plt.ion()打开交互模式,则后面在plt.show()之前一定要加上plt.ioff()。这里直接加到函数里面,避免每次plt.show()之前都用plt.ioff()

修改前

def show_landmarks(imgs,landmarks): '''显示带有地标的图片''' plt.imshow(imgs) plt.scatter(landmarks[:,0].tolist(),landmarks[:,1].tolist(),s=10,marker='.',c='r')#打上红色散点 plt.pause(1)#绘图窗口延时

修改后

def show_landmarks(imgs,landmarks): '''显示带有地标的图片''' plt.imshow(imgs) plt.scatter(landmarks[:,0].tolist(),landmarks[:,1].tolist(),s=10,marker='.',c='r')#打上红色散点 plt.pause(1)#绘图窗口延时 plt.ioff()

网上说对于字典类型的sample可通过 **sample的方式获取每个键下的值,但是会报错,于是把输入写的详细一点,就成功了。

修改前

show_landmarks(**sample)

修改后

show_landmarks(sample['image'],sample['landmarks'])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持51zixue.net。如有错误或未考虑完全的地方,望不吝赐教。


pandas之query方法和sample随机抽样操作
pytorch从csv加载自定义数据模板的操作
万事OK自学网:51自学网_软件自学网_CAD自学网自学excel、自学PS、自学CAD、自学C语言、自学css3实例,是一个通过网络自主学习工作技能的自学平台,网友喜欢的软件自学网站。