您当前的位置:首页 > IT编程 > python
| C语言 | Java | VB | VC | python | Android | TensorFlow | C++ | oracle | 学术与代码 | cnn卷积神经网络 | gnn | 图像修复 | Keras | 数据集 | Neo4j | 自然语言处理 | 深度学习 | 医学CAD | 医学影像 | 超参数 | pointnet | pytorch | 异常检测 | Transformers | 情感分类 | 知识图谱 |

自学教程:pytorch中的matmul与mm,bmm区别说明

51自学网 2021-10-30 22:40:47
  python
这篇教程pytorch中的matmul与mm,bmm区别说明写得很实用,希望能帮到您。

pytorch中matmul和mm和bmm区别 matmulmmbmm结论

先看下官网上对这三个函数的介绍。

matmul

在这里插入图片描述

mm

在这里插入图片描述

bmm

顾名思义, 就是两个batch矩阵乘法.

在这里插入图片描述

结论

从官方文档可以看出

1、mm只能进行矩阵乘法,也就是输入的两个tensor维度只能是( n × m ) (n/times m)(n×m)和( m × p ) (m/times p)(m×p)

2、bmm是两个三维张量相乘, 两个输入tensor维度是( b × n × m ) (b/times n/times m)(b×n×m)和( b × m × p ) (b/times m/times p)(b×m×p), 第一维b代表batch size,输出为( b × n × p ) (b/times n /times p)(b×n×p)

3、matmul可以进行张量乘法, 输入可以是高维.

补充:torch中的几种乘法。torch.mm, torch.mul, torch.matmul

一、点乘

点乘都是broadcast的,可以用torch.mul(a, b)实现,也可以直接用*实现。

>>> a = torch.ones(3,4)>>> atensor([[1., 1., 1., 1.],        [1., 1., 1., 1.],        [1., 1., 1., 1.]])>>> b = torch.Tensor([1,2,3]).reshape((3,1))>>> btensor([[1.],        [2.],        [3.]])>>> torch.mul(a, b)tensor([[1., 1., 1., 1.],        [2., 2., 2., 2.],        [3., 3., 3., 3.]])

当a, b维度不一致时,会自动填充到相同维度相点乘。

二、矩阵乘

矩阵相乘有torch.mm和torch.matmul两个函数。其中前一个是针对二维矩阵,后一个是高维。当torch.mm用于大于二维时将报错。

>>> a = torch.ones(3,4)>>> b = torch.ones(4,2)>>> torch.mm(a, b)tensor([[4., 4.],        [4., 4.],        [4., 4.]])
>>> a = torch.ones(3,4)>>> b = torch.ones(5,4,2)>>> torch.matmul(a, b).shapetorch.Size([5, 3, 2])
>>> a = torch.ones(5,4,2)>>> b = torch.ones(5,2,3)>>> torch.matmul(a, b).shapetorch.Size([5, 4, 3])
>>> a = torch.ones(5,4,2)>>> b = torch.ones(5,2,3)>>> torch.matmul(b, a).shape报错。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持51zixue.net。如有错误或未考虑完全的地方,望不吝赐教。


pytorch-autograde-计算图的特点说明
解决Numpy与Pytorch彼此转换时的坑
万事OK自学网:51自学网_软件自学网_CAD自学网自学excel、自学PS、自学CAD、自学C语言、自学css3实例,是一个通过网络自主学习工作技能的自学平台,网友喜欢的软件自学网站。