百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

Day236:addmm()和addmm_()的用法详解

ztj100 2024-11-03 16:15 10 浏览 0 评论

函数解释

torch/_C/_VariableFunctions.py的有该定义,意义就是实现一下公式:

换句话说,就是需要传入5个参数mat里的每个元素乘以betamat1mat2进行矩阵乘法左行乘右列)后再乘以alpha,最后将这2个结果加在一起。但是这样说可能没啥概念,接下来博主为大家写上一段代码,大家就明白了~

    def addmm(self, beta=1, mat, alpha=1, mat1, mat2, out=None): # real signature unknown; restored from __doc__
        """
        addmm(beta=1, mat, alpha=1, mat1, mat2, out=None) -> Tensor
        
        Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`.
        The matrix :attr:`mat` is added to the final result.
        
        If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a
        :math:`(m \times p)` tensor, then :attr:`mat` must be
        :ref:`broadcastable <broadcasting-semantics>` with a :math:`(n \times p)` tensor
        and :attr:`out` will be a :math:`(n \times p)` tensor.
        
        :attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between
        :attr:`mat1` and :attr`mat2` and the added matrix :attr:`mat` respectively.
        
        .. math::
            out = \beta\ mat + \alpha\ (mat1_i \mathbin{@} mat2_i)
        
        For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and
        :attr:`alpha` must be real numbers, otherwise they should be integers.
        
        Args:
            beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
            mat (Tensor): matrix to be added
            alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
            mat1 (Tensor): the first matrix to be multiplied
            mat2 (Tensor): the second matrix to be multiplied
            out (Tensor, optional): the output tensor
        
        Example::
        
            >>> M = torch.randn(2, 3)
            >>> mat1 = torch.randn(2, 3)
            >>> mat2 = torch.randn(3, 3)
            >>> torch.addmm(M, mat1, mat2)
            tensor([[-4.8716,  1.4671, -1.3746],
                    [ 0.7573, -3.9555, -2.8681]])
        """
        pass

代码范例

1.先摆出代码,大家可以先复制粘贴运行一下,在之后会一一讲解

"""
@author:nickhuang1996
"""
import torch
 
rectangle_height = 3
rectangle_width = 3
inputs = torch.randn(rectangle_height, rectangle_width)
for i in range(rectangle_height):
    for j in range(rectangle_width):
        inputs[i] = i * torch.ones(rectangle_width)
'''
inputs and its transpose
-->inputs   =   tensor([[0., 0., 0.],
                        [1., 1., 1.],
                        [2., 2., 2.]])
-->inputs_t =   tensor([[0., 1., 2.],
                        [0., 1., 2.],
                        [0., 1., 2.]])
'''
print("inputs:\n", inputs)
inputs_t = inputs.t()
print("inputs_t:\n", inputs_t)
'''
inputs_t @ inputs_t    [[0., 1., 2.],       [[0., 1., 2.],          [[0., 3., 6.]
                    =   [0., 1., 2.],   @    [0., 1., 2.],     =     [0., 3., 6.]
                        [0., 1., 2.]]        [0., 1., 2.]]           [0., 3., 6.]]
'''
 
'''a, b, c and d = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
a = torch.addmm(input=inputs, mat1=inputs_t, mat2=inputs_t)
b = inputs.addmm(mat1=inputs_t, mat2=inputs_t)
c = torch.addmm(input=inputs, beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
d = inputs.addmm(beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
'''e and f = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
e = torch.addmm(inputs, inputs_t, inputs_t)
f = inputs.addmm(inputs_t, inputs_t)
'''1 * inputs + 1 * (inputs_t @ inputs_t)'''
g = inputs.addmm(1, inputs_t, inputs_t)
'''2 * inputs + 1 * (inputs_t @ inputs_t)'''
g2 = inputs.addmm(2, inputs_t, inputs_t)
'''h = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
h = inputs.addmm(1, 1, inputs_t, inputs_t)
'''h12 = 1 * inputs + 2 * (inputs_t @ inputs_t)'''
h12 = inputs.addmm(1, 2, inputs_t, inputs_t)
'''h21 = 2 * inputs + 1 * (inputs_t @ inputs_t)'''
h21 = inputs.addmm(2, 1, inputs_t, inputs_t)
print("a:\n", a)
print("b:\n", b)
print("c:\n", c)
print("d:\n", d)
 
print("e:\n", e)
print("f:\n", f)
 
print("g:\n", g)
print("g2:\n", g2)
 
print("h:\n", h)
print("h12:\n", h12)
print("h21:\n", h21)
print("inputs:\n", inputs)
'''inputs = 1 * inputs - 2 * (inputs @ inputs_t)'''
'''
inputs @ inputs_t       [[0., 0., 0.],       [[0., 1., 2.],          [[0., 0., 0.]
                    =    [1., 1., 1.],   @    [0., 1., 2.],     =     [0., 3., 6.]
                         [2., 2., 2.]]        [0., 1., 2.]]           [0., 6., 12.]]
'''
inputs.addmm_(1, -2, inputs, inputs_t)  # In-place
print("inputs:\n", inputs)

2.其中

inputs是一个3×3的矩阵,为

tensor([[0., 0., 0.],
        [1., 1., 1.],
        [2., 2., 2.]])

inputs_t也是一个3×3的矩阵,是inputs转置矩阵,为

tensor([[0., 1., 2.],
        [0., 1., 2.],
        [0., 1., 2.]])

* inputs_t @ inputs_t

'''
inputs_t @ inputs_t    [[0., 1., 2.],       [[0., 1., 2.],          [[0., 3., 6.]
                    =   [0., 1., 2.],   @    [0., 1., 2.],     =     [0., 3., 6.]
                        [0., 1., 2.]]        [0., 1., 2.]]           [0., 3., 6.]]
'''

3.代码中abcd展示的是完全形式,即标明了位置参数传入参数。可以看到input这个位置参数可以写在函数的前面,即

torch.addmm(input, mat1, mat2) = inputs.addmm(mat1, mat2)

完成的公式为:

1 × inputs + 1 ×(inputs_t @ inputs_t)

'''a, b, c and d = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
a = torch.addmm(input=inputs, mat1=inputs_t, mat2=inputs_t)
b = inputs.addmm(mat1=inputs_t, mat2=inputs_t)
c = torch.addmm(input=inputs, beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
d = inputs.addmm(beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)	
a:
tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
b:
tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
c:
tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
d:
tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])

4.下面的例子更好了说明了input参数的位置可变性,并且betaalpha缺省了:

完成的公式为:

1 × inputs + 1 ×(inputs_t @ inputs_t)

'''e and f = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
e = torch.addmm(inputs, inputs_t, inputs_t)
f = inputs.addmm(inputs_t, inputs_t)
e:
tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
f:
tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])

5.加一个参数,实际上是添加了beta这个参数

完成的公式为:

g = 1 × inputs + 1 ×(inputs_t @ inputs_t)

g2 = 2 × inputs + 1 ×(inputs_t @ inputs_t)

'''1 * inputs + 1 * (inputs_t @ inputs_t)'''
g = inputs.addmm(1, inputs_t, inputs_t)
'''2 * inputs + 1 * (inputs_t @ inputs_t)'''
g2 = inputs.addmm(2, inputs_t, inputs_t)
g:
tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
g2:
tensor([[ 0.,  3.,  6.],
        [ 2.,  5.,  8.],
        [ 4.,  7., 10.]])

6.再加一个参数,实际上是添加了alpha这个参数

完成的公式为:

h = 1 × inputs + 1 ×(inputs_t @ inputs_t)

h12 = 1 × inputs + 2 ×(inputs_t @ inputs_t)

h21 = 2 × inputs + 1 ×(inputs_t @ inputs_t)

'''h = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
h = inputs.addmm(1, 1, inputs_t, inputs_t)
'''h12 = 1 * inputs + 2 * (inputs_t @ inputs_t)'''
h12 = inputs.addmm(1, 2, inputs_t, inputs_t)
'''h21 = 2 * inputs + 1 * (inputs_t @ inputs_t)'''
h21 = inputs.addmm(2, 1, inputs_t, inputs_t)
h:
tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
h12:
tensor([[ 0.,  6., 12.],
        [ 1.,  7., 13.],
        [ 2.,  8., 14.]])
h21:
tensor([[ 0.,  3.,  6.],
        [ 2.,  5.,  8.],
        [ 4.,  7., 10.]])

7.当然,以上的步骤inputs没有变化,还是为

inputs:
tensor([[0., 0., 0.],
        [1., 1., 1.],
        [2., 2., 2.]])

*8.addmm_()的操作和addmm()函数功能相同,区别就是addmm_()inplace的操作,也就是在原对象基础上进行修改,即把改变之后的变量再赋给原来的变量。例如:

inputs的值变成了改变之后的值,不用再去写 某个变量=addmm_() ,因为inputs就是改变之后的变量

*inputs@ inputs_t

'''
inputs @ inputs_t       [[0., 0., 0.],       [[0., 1., 2.],          [[0., 0., 0.]
                    =    [1., 1., 1.],   @    [0., 1., 2.],     =     [0., 3., 6.]
                         [2., 2., 2.]]        [0., 1., 2.]]           [0., 6., 12.]]
'''

完成的公式为:

inputs = 1 × inputs - 2 ×(inputs @ inputs_t)

'''inputs = 1 * inputs - 2 * (inputs @ inputs_t)'''
inputs.addmm_(1, -2, inputs, inputs_t)  # In-place
inputs:
tensor([[  0.,   0.,   0.],
        [  1.,  -5., -11.],
        [  2., -10., -22.]])

三、代码运行结果

inputs:
tensor([[0., 0., 0.],
        [1., 1., 1.],
        [2., 2., 2.]])
inputs_t:
tensor([[0., 1., 2.],
        [0., 1., 2.],
        [0., 1., 2.]])
a:
tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
b:
tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
c:
tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
d:
tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
e:
tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
f:
tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
g:
tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
g2:
tensor([[ 0.,  3.,  6.],
        [ 2.,  5.,  8.],
        [ 4.,  7., 10.]])
h:
tensor([[0., 3., 6.],
        [1., 4., 7.],
        [2., 5., 8.]])
h12:
tensor([[ 0.,  6., 12.],
        [ 1.,  7., 13.],
        [ 2.,  8., 14.]])
h21:
tensor([[ 0.,  3.,  6.],
        [ 2.,  5.,  8.],
        [ 4.,  7., 10.]])
inputs:
tensor([[0., 0., 0.],
        [1., 1., 1.],
        [2., 2., 2.]])
inputs:
tensor([[  0.,   0.,   0.],
        [  1.,  -5., -11.],
        [  2., -10., -22.]])

原文:https://blog.csdn.net/qq_36556893/article/details/90638449

相关推荐

Whoosh,纯python编写轻量级搜索工具

引言在许多应用程序中,搜索功能是至关重要的。Whoosh是一个纯Python编写的轻量级搜索引擎库,可以帮助我们快速构建搜索功能。无论是在网站、博客还是本地应用程序中,Whoosh都能提供高效的全文搜...

如何用Python实现二分搜索算法(python二分法查找代码)

如何用Python实现二分搜索算法二分搜索(BinarySearch)是一种高效的查找算法,适用于在有序数组中快速定位目标值。其核心思想是通过不断缩小搜索范围,每次将问题规模减半,时间复杂度为(O...

路径扫描 -- dirsearch(路径查找器怎么使用)

外表干净是尊重别人,内心干净是尊重自己,干净,在今天这个时代,应该是一种极高的赞美和珍贵。。。----网易云热评一、软件介绍Dirsearch是一种命令行工具,可以强制获取web服务器中的目录和文件...

78行Python代码帮你复现微信撤回消息!

来源:悟空智能科技本文约700字,建议阅读5分钟。本文基于python的微信开源库itchat,教你如何收集私聊撤回的信息。...

从零开始学习 Python!2《进阶知识》 Python进阶之路

欢迎来到Python学习的进阶篇章!如果你说已经掌握了基础语法,那么这篇就是你开启高手之路的大门。我们将一起探讨面向对象编程...

白帽黑客如何通过dirsearch脚本工具扫描和收集网站敏感文件

一、背景介绍...

Python之txt数据预定替换word预定义定位标记生成word报告(四)

续接Python之txt数据预定替换word预定义定位标记生成word报告(一)https://mp.toutiao.com/profile_v4/graphic/preview?pgc_id=748...

假期苦短,我用Python!这有个自动回复拜年信息的小程序

...

Python——字符串和正则表达式中的反斜杠(&#39;\&#39;)问题详解

在本篇文章里小编给大家整理的是关于Python字符串和正则表达式中的反斜杠('\')问题以及相关知识点,有需要的朋友们可以学习下。在Python普通字符串中在Python中,我们用'\'来转义某些普通...

Python re模块:正则表达式综合指南

Python...

Python中re模块详解(rem python)

在《...

python之re模块(python re模块sub)

re模块一.re模块的介绍1.什么是正则表达式"定义:正则表达式是一种对字符和特殊字符操作的一种逻辑公式,从特定的字符中,用正则表达字符来过滤的逻辑。(也是一种文本模式;)2、正则表达式可以帮助我们...

MySQL、PostgreSQL、SQL Server 数据库导入导出实操全解

在数字化时代,数据是关键资产,数据库的导入导出操作则是连接数据与应用场景的桥梁。以下是常见数据库导入导出的实用方法及代码,包含更多细节和特殊情况处理,助你应对各种实际场景。一、MySQL数据库...

Zabbix监控系统系列之六:监控 mysql

zabbix监控mysql1、监控规划在创建监控项之前要尽量考虑清楚要监控什么,怎么监控,监控数据如何存储,监控数据如何展现,如何处理报警等。要进行监控的系统规划需要对Zabbix很了解,这里只是...

mysql系列之一文详解Navicat工具的使用(二)

本章内容是系列内容的第二部分,主要介绍Navicat工具的使用。若查看第一部分请见:...

取消回复欢迎 发表评论: