各位小伙伴肯定看到过下面这段代码:

correct += (predicted == labels).sum().item()

这里面(predicted == labels)是布尔型,为什么可以接sum()呢?

我做了个测试,如果这里的predicted和labels是列表形式就会报错,如果是numpy的数组格式,会返回一个值,如果是tensor形式,就会返回一个张量。

举个例子:

import torch

a = torch.tensor([1,2,3])
b = torch.tensor([1,3,2])

print((a == b).sum())

上述代码的输出结果:

tensor(1)

如果将a和b改成numpy下的数组格式:

import numpy as np

a = np.array([1,2,3])
b = np.array([1,3,2])

print((a == b).sum())

上述代码的输出结果:

1

如果将a和b改成列表:

a = [1,2,3]
b = [1,3,2]

print((a == b).sum())

上述代码的输出结果:

Traceback (most recent call last):
  File "路径", line 4, in <module>
    print((a == b).sum())
AttributeError: 'bool' object has no attribute 'sum'

Process finished with exit code 1

Added:

.item()用于取出tensor中的值。

Logo

瓜分20万奖金 获得内推名额 丰厚实物奖励 易参与易上手

更多推荐