您当前的位置:首页 > IT编程 > python
| C语言 | Java | VB | VC | python | Android | TensorFlow | C++ | oracle | 学术与代码 | cnn卷积神经网络 | gnn | 图像修复 | Keras | 数据集 | Neo4j | 自然语言处理 | 深度学习 | 医学CAD | 医学影像 | 超参数 | pointnet | pytorch | 异常检测 | Transformers | 情感分类 | 知识图谱 |

自学教程:pytorch 6 batch_train 批训练操作

51自学网 2021-10-30 22:36:45
  python
这篇教程pytorch 6 batch_train 批训练操作写得很实用,希望能帮到您。

看代码吧~

import torchimport torch.utils.data as Datatorch.manual_seed(1)    # reproducible# BATCH_SIZE = 5  BATCH_SIZE = 8      # 每次使用8个数据同时传入网路x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)torch_dataset = Data.TensorDataset(x, y)loader = Data.DataLoader(    dataset=torch_dataset,      # torch TensorDataset format    batch_size=BATCH_SIZE,      # mini batch size    shuffle=False,              # 设置不随机打乱数据 random shuffle for training    num_workers=2,              # 使用两个进程提取数据,subprocesses for loading data)def show_batch():    for epoch in range(3):   # 全部的数据使用3遍,train entire dataset 3 times        for step, (batch_x, batch_y) in enumerate(loader):  # for each training step            # train your data...            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',                  batch_x.numpy(), '| batch y: ', batch_y.numpy())if __name__ == '__main__':    show_batch()

BATCH_SIZE = 8 , 所有数据利用三次

Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10.  9.  8.  7.  6.  5.  4.  3.]Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]

补充:pytorch批训练bug

问题描述:

在进行pytorch神经网络批训练的时候,有时会出现报错 

TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>

解决办法:

第一步:

检查(重点!!!!!):

train_dataset = Data.TensorDataset(train_x, train_y)

train_x,和train_y格式,要求是tensor类,我第一次出错就是因为传入的是variable

可以这样将数据变为tensor类:

train_x = torch.FloatTensor(train_x)

第二步:

train_loader = Data.DataLoader(        dataset=train_dataset,        batch_size=batch_size,        shuffle=True    )

实例化一个DataLoader对象

第三步:

    for epoch in range(epochs):        for step, (batch_x, batch_y) in enumerate(train_loader):            batch_x, batch_y = Variable(batch_x), Variable(batch_y)

这样就可以批训练了

需要注意的是:train_loader输出的是tensor,在训练网络时,需要变成Variable

以上为个人经验,希望能给大家一个参考,也希望大家多多支持51zixue.net。


pytorch 如何使用batch训练lstm网络
Keras多线程机制与flask多线程冲突的解决方案
万事OK自学网:51自学网_软件自学网_CAD自学网自学excel、自学PS、自学CAD、自学C语言、自学css3实例,是一个通过网络自主学习工作技能的自学平台,网友喜欢的软件自学网站。