您当前的位置:首页 > IT编程 > python
| C语言 | Java | VB | VC | python | Android | TensorFlow | C++ | oracle | 学术与代码 | cnn卷积神经网络 | gnn | 图像修复 | Keras | 数据集 | Neo4j | 自然语言处理 | 深度学习 | 医学CAD | 医学影像 | 超参数 | pointnet | pytorch | 异常检测 | Transformers | 情感分类 | 知识图谱 |

自学教程:pytorch __init__、forward与__call__的用法小结

51自学网 2021-10-30 22:54:02
  python
这篇教程pytorch __init__、forward与__call__的用法小结写得很实用,希望能帮到您。

1.介绍

当我们使用pytorch来构建网络框架的时候,也会遇到和tensorflow(tensorflow __init__、build 和call小结)类似的情况,即经常会遇到__init__、forward和call这三个互相搭配着使用,那么它们的主要区别又在哪里呢?

1)__init__主要用来做参数初始化用,比如我们要初始化卷积的一些参数,就可以放到这里面,这点和tf里面的用法是一样的

2)forward是表示一个前向传播,构建网络层的先后运算步骤

3)__call__的功能其实和forward类似,所以很多时候,我们构建网络的时候,可以用__call__替代forward函数,但它们两个的区别又在哪里呢?

当网络构建完之后,调__call__的时候,会去先调forward,即__call__其实是包了一层forward,所以会导致两者的功能类似。

在pytorch在nn.Module中,实现了__call__方法,而在__call__方法中调用了forward函数:

https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

2.代码

import torchimport torch.nn as nnimport torch.nn.functional as F class Net(nn.Module): def __init__(self, in_channels, mid_channels, out_channels): super(Net, self).__init__() self.conv0 = torch.nn.Sequential( torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), torch.nn.LeakyReLU()) self.conv1 = torch.nn.Sequential( torch.nn.Conv2d(mid_channels, out_channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))  def forward(self, x): x = self.conv0(x) x = self.conv1(x) return x class Net(nn.Module): def __init__(self, in_channels, mid_channels, out_channels): super(Net, self).__init__() self.conv0 = torch.nn.Sequential( torch.nn.Conv2d(in_channels, mid_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), torch.nn.LeakyReLU()) self.conv1 = torch.nn.Sequential( torch.nn.Conv2d(mid_channels, out_channels * 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))  def __call__(self, x): x = self.conv0(x) x = self.conv1(x) return x

补充:torch/nn目录结构以及__init__.py

torch/nn目录结构以及init.py

torch/nn目录结构

__init__.py:

from .modules import *#nn.modules  导入modules目录下内容 定义容器modulesfrom .parameter import Parameter#nn.Parameter 导入parameter.py  定义parameterfrom .parallel import DataParallel#导入parallel目录下data_parallel.py中的DataParallel类from . import init#nn.init   导入init.py   参数初始化from . import utils#nn.utils  导入utils目录下内容 官网api下nn.utils下api

对于backends, functional.py, _functions 需要在代码前重新Import

例如我们常用的

import torch.nn.functional as F 就是导入了functional.py

backends和_functions是functional.py实现各种函数时所用到的。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持51zixue.net。如有错误或未考虑完全的地方,望不吝赐教。


python 实现有道翻译功能
python如何发送带有附件、正文为HTML的邮件
万事OK自学网:51自学网_软件自学网_CAD自学网自学excel、自学PS、自学CAD、自学C语言、自学css3实例,是一个通过网络自主学习工作技能的自学平台,网友喜欢的软件自学网站。