您当前的位置:首页 > IT编程 > python
| C语言 | Java | VB | VC | python | Android | TensorFlow | C++ | oracle | 学术与代码 | cnn卷积神经网络 | gnn | 图像修复 | Keras | 数据集 | Neo4j | 自然语言处理 | 深度学习 | 医学CAD | 医学影像 | 超参数 | pointnet | pytorch | 异常检测 | Transformers | 情感分类 | 知识图谱 |

自学教程:Pytorch 使用tensor特定条件判断索引

51自学网 2021-10-30 22:46:10
  python
这篇教程Pytorch 使用tensor特定条件判断索引写得很实用,希望能帮到您。

torch.where() 用于将两个broadcastable的tensor组合成新的tensor,类似于c++中的三元操作符“?:”

区别于python numpy中的where()直接可以找到特定条件元素的index

想要实现numpy中where()的功能,可以借助nonzero()

对应numpy中的where()操作效果:

补充:Pytorch torch.Tensor.detach()方法的用法及修改指定模块权重的方法

detach

detach的中文意思是分离,官方解释是返回一个新的Tensor,从当前的计算图中分离出来

需要注意的是,返回的Tensor和原Tensor共享相同的存储空间,但是返回的 Tensor 永远不会需要梯度

import torch as ta = t.ones(10,)b = a.detach()print(b)tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

那么这个函数有什么作用?

–假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改B网络的参数,但是不想修改A网络的参数,这个时候就可以使用detcah()方法

a = A(input)a = detach()b = B(a)loss = criterion(b, target)loss.backward()

来看一个实际的例子:

import torch as tx = t.ones(1, requires_grad=True)x.requires_grad   #Truey = t.ones(1, requires_grad=True)y.requires_grad   #Truex = x.detach()   #分离之后x.requires_grad   #Falsey = x+y         #tensor([2.])y.requires_grad   #我还是Truey.retain_grad()   #y不是叶子张量,要加上这一行z = t.pow(y, 2)z.backward()    #反向传播y.grad        #tensor([4.])x.grad        #None

以上代码就说明了反向传播到y就结束了,没有到达x,所以x的grad属性为None

既然谈到了修改模型的权重问题,那么还有一种情况是:

–假如A网络输出了一个Tensor类型的变量a, a要作为输入传入到B网络中,如果我想通过损失函数反向传播修改A网络的参数,但是不想修改B网络的参数,这个时候又应该怎么办了?

这时可以使用Tensor.requires_grad属性,只需要将requires_grad修改为False即可.

for param in B.parameters(): param.requires_grad = Falsea = A(input)b = B(a)loss = criterion(b, target)loss.backward()

以上为个人经验,希望能给大家一个参考,也希望大家多多支持51zixue.net。如有错误或未考虑完全的地方,望不吝赐教。


Python面向对象封装继承和多态示例讲解
python 实现简单的吃豆人游戏
万事OK自学网:51自学网_软件自学网_CAD自学网自学excel、自学PS、自学CAD、自学C语言、自学css3实例,是一个通过网络自主学习工作技能的自学平台,网友喜欢的软件自学网站。