
Self-study tutorial: how to save CIFAR-100 data into per-class folders with PyTorch

51 Self-Study Network (51zixue.net), 2021-10-30 22:41:24

Sampling in few-shot learning

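Few-shot learning trains a model task by task. In N-way K-shot learning, each meta-training task contains N classes: K samples are drawn from each class to form the support set, and the query set is then sampled from the remaining samples of those same N classes (the sampling can be uniform or non-uniform).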
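The task construction described above can be sketched in plain Python. This is a minimal illustration under assumptions that are not in the original: the data is already grouped as a dict mapping each class label to a list of samples, query samples are drawn uniformly, and the function name sample_episode and its parameters are hypothetical.

```python
import random


def sample_episode(by_class, n_way, k_shot, q_query, rng=random):
    """Sample one N-way K-shot task (hypothetical helper).

    by_class: dict mapping class label -> list of samples (assumed input format).
    Returns (support, query) as lists of (sample, episode_label) pairs, where
    episode_label runs from 0 to n_way - 1 within this task.
    """
    # Only classes with at least K + q samples can supply both sets.
    eligible = [c for c, items in by_class.items() if len(items) >= k_shot + q_query]
    classes = rng.sample(eligible, n_way)
    support, query = [], []
    for episode_label, c in enumerate(classes):
        # Draw K + q distinct samples: the first K form the support set, the rest
        # are a uniform draw from the remaining samples of the same class.
        picked = rng.sample(by_class[c], k_shot + q_query)
        support += [(x, episode_label) for x in picked[:k_shot]]
        query += [(x, episode_label) for x in picked[k_shot:]]
    return support, query


# Toy data: 10 classes with 20 string "images" each.
by_class = {c: [f"c{c}_img{j}" for j in range(20)] for c in range(10)}
support, query = sample_episode(by_class, n_way=5, k_shot=1, q_query=15, rng=random.Random(0))
print(len(support), len(query))  # 5 75
```

Relabeling the chosen classes as 0..N-1 inside each task is a common choice, since the classifier only has to separate the N classes of the current task.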

Grouping data by class label

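For this setup we need a dataset in which each class is stored in its own folder. Sometimes, however, the data is not organized by class, and it has to be reorganized first.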
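At its core, grouping by class label means bucketing (sample, label) pairs by label. A minimal sketch, assuming the data is available as an iterable of (x, y) pairs; group_by_label is a hypothetical helper, and the strings stand in for images:

```python
from collections import defaultdict


def group_by_label(samples):
    """Group (x, y) pairs into {y: [x, ...]}, keeping the original order within each class."""
    groups = defaultdict(list)
    for x, y in samples:
        groups[y].append(x)
    return dict(groups)


samples = [("img0", 2), ("img1", 0), ("img2", 2), ("img3", 1)]
print(group_by_label(samples))  # {2: ['img0', 'img2'], 0: ['img1'], 1: ['img3']}
```

Once the groups exist, each list can be written to its own directory, which is what the CIFAR-100 script does with the list-of-lists character.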

The following uses CIFAR-100 as an example (N-way K-shot sampling itself is not included):

import os

import numpy as np
import torchvision as tv
from skimage import io


def save_split(root, split_name, classes):
    """Write every class in `classes` to root/split_name/character_<i>/<j>.jpg."""
    for i, per_class in enumerate(classes):
        character_path = os.path.join(root, split_name, 'character_' + str(i))
        os.makedirs(character_path, exist_ok=True)  # also creates root/split_name
        for j, img in enumerate(per_class):
            io.imsave(os.path.join(character_path, str(j) + '.jpg'), img)


def Cifar100(root):
    character = [[] for _ in range(100)]
    train_set = tv.datasets.CIFAR100(root, train=True, download=True)
    test_set = tv.datasets.CIFAR100(root, train=False, download=True)
    # Current torchvision exposes images as .data (N x 32 x 32 x 3, uint8) and labels
    # as .targets; old releases used train_data/train_labels and test_data/test_labels.
    dataset = []
    for X, Y in zip(train_set.data, train_set.targets):  # read training images and labels
        dataset.append((X, Y))
    for X, Y in zip(test_set.data, test_set.targets):  # read test images and labels
        dataset.append((X, Y))
    for X, Y in dataset:
        character[Y].append(X)  # each image is 32*32*3
    character = np.array(character)  # shape (100, 600, 32, 32, 3)
    # Shuffle the order of the classes (fixed seed for reproducibility)
    np.random.seed(6)
    shuffle_class = np.arange(len(character))
    np.random.shuffle(shuffle_class)
    character = character[shuffle_class]
    # meta_training : meta_validation : meta_testing = 64 : 16 : 20 classes
    meta_training, meta_validation, meta_testing = \
        character[:64], character[64:80], character[80:]
    dataset = []  # the list of (image, label) pairs is no longer needed
    save_split(root, 'meta_training', meta_training)
    save_split(root, 'meta_validation', meta_validation)
    save_split(root, 'meta_testing', meta_testing)


if __name__ == '__main__':
    root = '/home/xie/文档/datasets/cifar_100'
    Cifar100(root)
    print("-----------------")

Supplement: classifying the CIFAR-10 dataset with PyTorch

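The main steps are:

1. Download and preprocess the dataset

2. Define the network structure

3. Define the loss function and optimizer

4. Train the network and update its parameters

5. Evaluate the network on the test set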

# Data loading and preprocessing
# Classification experiment on CIFAR-10
import torch as t
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage

show = ToPILImage()  # converts a Tensor back to an Image for visualization

# Preprocessing applied to every image
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # normalize each channel
])

# Training set
trainset = tv.datasets.CIFAR10(
    root='./data/',
    train=True,
    download=True,
    transform=transform)

trainloader = t.utils.data.DataLoader(
    trainset,
    batch_size=4,
    shuffle=True,
    num_workers=2)

# Test set
testset = tv.datasets.CIFAR10(
    root='./data/',
    train=False,
    download=True,
    transform=transform)

testloader = t.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
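ToTensor scales uint8 pixel values from 0..255 to 0.0..1.0, and Normalize then computes (x - mean) / std for each channel, so mean = std = 0.5 maps every channel into [-1, 1]. The scalar sketch below mimics that arithmetic in plain Python; the helper names are hypothetical, not torchvision functions. The denormalize step is what you would apply to an image tensor before passing it to show for display.

```python
def to_tensor_value(pixel):
    """Mimic ToTensor on one uint8 pixel value: scale 0..255 to 0.0..1.0."""
    return pixel / 255.0


def normalize(value, mean=0.5, std=0.5):
    """Mimic Normalize on one channel value: (x - mean) / std."""
    return (value - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Invert normalize, e.g. before converting a tensor back to an image."""
    return value * std + mean


for p in (0, 128, 255):
    print(p, round(normalize(to_tensor_value(p)), 4))
# 0 -1.0
# 128 0.0039
# 255 1.0
```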

The first run takes some time because the dataset has to be downloaded.

import time

import torch.nn as nn
import torch.nn.functional as F

start = time.time()  # start timing


# Network definition
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size()[0], -1)  # flatten to (batch, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()
print(net)
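The 16*5*5 input size of fc1 follows from the layer sizes. For a convolution or pooling window with kernel k, stride s and padding p, an input of side n produces an output of side floor((n + 2p - k) / s) + 1, and max_pool2d(x, 2) uses stride 2 because the stride defaults to the kernel size. The sketch traces a 32x32 CIFAR-10 image through Net; conv2d_out is a hypothetical helper.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Output side length of a square convolution or pooling window (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                              # CIFAR-10 images are 32x32
size = conv2d_out(size, 5)             # conv1, kernel 5     -> 28
size = conv2d_out(size, 2, stride=2)   # max_pool2d(., 2)    -> 14
size = conv2d_out(size, 5)             # conv2, kernel 5     -> 10
size = conv2d_out(size, 2, stride=2)   # max_pool2d(., 2)    -> 5
print(16 * size * size)                # 400 == 16*5*5, the in_features of fc1
```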

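Calling print(net) displays the network structure, listing conv1, conv2, fc1, fc2 and fc3 together with their sizes.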

# Loss function and optimizer
loss_func = nn.CrossEntropyLoss()  # cross-entropy loss
optimizer = t.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Train the network
for epoch in range(2):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        outputs = net(inputs)
        loss = loss_func(outputs, labels)

        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:  # report the mean loss every 2000 mini-batches
            print('epoch:', epoch + 1, '|i:', i + 1, '|loss:%.3f' % (running_loss / 2000))
            running_loss = 0.0

end = time.time()
time_using = end - start
print('finish training')
print('time:', time_using)
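optimizer.zero_grad() is called before loss.backward() because PyTorch accumulates gradients into each parameter's .grad. With momentum=0.9 and the other SGD options left at their defaults, PyTorch's update is v <- momentum * v + g followed by p <- p - lr * v. The sketch below applies that rule to plain Python floats on the toy function f(p) = p^2; it only illustrates the update rule, it is not torch.optim.SGD itself, and sgd_momentum_step is a hypothetical name.

```python
def sgd_momentum_step(params, grads, velocity, lr=0.001, momentum=0.9):
    """One SGD-with-momentum update on lists of floats, in place:
    v <- momentum * v + g ; p <- p - lr * v  (no dampening, no weight decay)."""
    for i, g in enumerate(grads):
        velocity[i] = momentum * velocity[i] + g
        params[i] -= lr * velocity[i]
    return params, velocity


# Minimize f(p) = p^2, whose gradient is 2p, starting from p = 1.0.
params, velocity = [1.0], [0.0]
for _ in range(3):
    grads = [2 * p for p in params]
    sgd_momentum_step(params, grads, velocity, lr=0.1)
print(params)  # about [0.062]
```

The velocity term keeps part of the previous update, so consecutive steps in the same direction grow larger than plain gradient descent would make them.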

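During training, the loss averaged over the last 2000 mini-batches is printed periodically, followed by the total training time.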

Next, evaluate the network on the test set:

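# Test the network
correct = 0  # number of correctly predicted images
total = 0    # total number of images
with t.no_grad():  # gradients are not needed for evaluation
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predict = t.max(outputs, 1)  # index of the highest score = predicted class
        total += labels.size(0)
        correct += (predict == labels).sum().item()  # convert the count to a Python int
print('Accuracy on the test set: %d%%' % (100 * correct / total))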
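t.max(outputs, 1) returns both the maximum scores and their indices along dimension 1; the indices are the predicted classes, and accuracy is the share of predictions that equal the labels. A plain-Python sketch of the same computation on a hypothetical batch of three score rows (accuracy is a hypothetical helper):

```python
def accuracy(scores, labels):
    """Percentage of rows whose highest-scoring index (argmax) equals the label."""
    predictions = [max(range(len(row)), key=row.__getitem__) for row in scores]
    correct = sum(p == y for p, y in zip(predictions, labels))
    return 100 * correct / len(labels)


scores = [[0.1, 2.0, -1.0],   # predicted class 1
          [1.5, 0.2, 0.3],    # predicted class 0
          [0.0, 0.1, 0.9]]    # predicted class 2
labels = [1, 2, 2]
print(accuracy(scores, labels))  # 2 of 3 correct, about 66.67
```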

The script then prints the accuracy on the test set.

Even this simple network does noticeably better than the 10% expected from random guessing among 10 classes :)

Training on the GPU:

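# Train on the GPU
device = t.device('cuda:0' if t.cuda.is_available() else 'cpu')
net.to(device)  # moves the model parameters to the device in place

# Every batch must be on the same device as the model, so move it inside the loop
# (do the same for images and labels in the test loop):
for i, (inputs, labels) in enumerate(trainloader, 0):
    inputs = inputs.to(device)
    labels = labels.to(device)
    outputs = net(inputs)
    loss = loss_func(outputs, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()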

The above is based on personal experience and is offered for reference. If there are mistakes or things I have not fully considered, corrections are welcome.

