宝塔服务器面板,一键全能部署及管理,送你10850元礼包,点我领取

使用PyTorch求平方根报错如何解决?针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。

问题描述

初步使用PyTorch进行平方根计算,通过range)创建一个张量,然后对其求平方根。

a = torch.tensorlistrange9)))
b = torch.sqrta)

报出以下错误:

RuntimeError: sqrt_vml_cpu not implemented for 'Long'

原因

Long类型的数据不支持log对数运算, 为什么Tensor是Long类型? 因为创建List数组时默认使用的是int, 所以从List转成torch.Tensor后, 数据类型变成了Long。

printa.dtype)

torch.int64

解决方法

提前将数据类型指定为浮点型, 重新执行:

b = torch.sqrta.totorch.double))
printb)

tensor[0.0000, 1.0000, 1.4142, 1.7321, 2.0000, 2.2361, 2.4495, 2.6458, 2.8284], dtype=torch.float64)

补充:pytorch20 pytorch常见运算详解

矩阵与标量

这个是矩阵(张量)每一个元素与标量进行操作。

import torch
a = torch.tensor[1,2])
printa+1)
>>> tensor[2, 3])

哈达玛积

这个就是两个相同尺寸的张量相乘,然后对应元素的相乘就是这个哈达玛积,也成为element wise。

a = torch.tensor[1,2])
b = torch.tensor[2,3])
printa*b)
printtorch.mula,b))
>>> tensor[2, 6])
>>> tensor[2, 6])

这个torch.mul)和*是等价的。

当然,除法也是类似的:

a = torch.tensor[1.,2.])
b = torch.tensor[2.,3.])
printa/b)
printtorch.diva/b))
>>> tensor[0.5000, 0.6667])
>>> tensor[0.5000, 0.6667])

我们可以发现的torch.div)其实就是/, 类似的:torch.add就是+,torch.sub)就是-,不过符号的运算更简单常用。

矩阵乘法

如果我们想实现线性代数中的矩阵相乘怎么办呢?

这样的操作有三个写法:

torch.mm)

torch.matmul)

@,这个需要记忆,不然遇到这个可能会挺蒙蔽的

a = torch.tensor[[1.],[2.]])
b = torch.tensor[2.,3.]).view1,2)
printtorch.mma, b))
printtorch.matmula, b))
printa @ b)

使用PyTorch求平方根报错如何解决-风君子博客

这是对二维矩阵而言的,假如参与运算的是一个多维张量,那么只有torch.matmul)可以使用。等等,多维张量怎么进行矩阵的乘法?在多维张量中,参与矩阵运算的其实只有后两个维度,前面的维度其实就像是索引一样,举个例子:

a = torch.rand1,2,64,32))
b = torch.rand1,2,32,64))
printtorch.matmula, b).shape)
>>> torch.Size[1, 2, 64, 64])

使用PyTorch求平方根报错如何解决-风君子博客

a = torch.rand3,2,64,32))
b = torch.rand1,2,32,64))
printtorch.matmula, b).shape)
>>> torch.Size[3, 2, 64, 64])

这样也是可以相乘的,因为这里涉及一个自动传播Broadcasting机制,这个在后面会讲,这里就知道,如果这种情况下,会把b的第一维度复制3次 ,然后变成和a一样的尺寸,进行矩阵相乘。

幂与开方

print'幂运算')
a = torch.tensor[1.,2.])
b = torch.tensor[2.,3.])
c1 = a ** b
c2 = torch.powa, b)
printc1,c2)
>>> tensor[1., 8.]) tensor[1., 8.])

和上面一样,不多说了。开方运算可以用torch.sqrt),当然也可以用a**0.5)。

对数运算

在上学的时候,我们知道ln是以e为底的,但是在pytorch中,并不是这样。

pytorch中log是以e自然数为底数的,然后log2和log10才是以2和10为底数的运算。

import numpy as np
print'对数运算')
a = torch.tensor[2,10,np.e])
printtorch.loga))
printtorch.log2a))
printtorch.log10a))
>>> tensor[0.6931, 2.3026, 1.0000])
>>> tensor[1.0000, 3.3219, 1.4427])
>>> tensor[0.3010, 1.0000, 0.4343])

近似值运算

.ceil) 向上取整

.floor)向下取整

.trunc)取整数

.frac)取小数

.round)四舍五入

.ceil) 向上取整.floor)向下取整.trunc)取整数.frac)取小数.round)四舍五入

a = torch.tensor1.2345)
printa.ceil))
>>>tensor2.)
printa.floor))
>>> tensor1.)
printa.trunc))
>>> tensor1.)
printa.frac))
>>> tensor0.2345)
printa.round))
>>> tensor1.)

剪裁运算

这个是让一个数,限制在你自己设置的一个范围内[min,max],小于min的话就被设置为min,大于max的话就被设置为max。这个操作在一些对抗生成网络中,好像是WGAN-GP,通过强行限制模型的参数的值。

a = torch.rand5)
printa)
printa.clamp0.3,0.7))

使用PyTorch求平方根报错如何解决-风君子博客

pytorch的优点

1.PyTorch是相当简洁且高效快速的框架;2.设计追求最少的封装;3.设计符合人类思维,它让用户尽可能地专注于实现自己的想法;4.与google的Tensorflow类似,FAIR的支持足以确保PyTorch获得持续的开发更新;5.PyTorch作者亲自维护的论坛 供用户交流和求教问题6.入门简单