PyTorch:torch.max、min、argmax、argmin
1、torch.max函数定义:torch.max(input, dim, max=None, max_indices=None) -> (Tensor, LongTensor)作用:找出给定tensor的指定维度dim上的上的最大值,并返回最大值在该维度上的位置索引。应用举例:例1——返回相应维度上的最大值import torcha = torch.randint(2, 10,(6,4))
·
目录
1、torch.max
函数定义:
torch.max(input, dim, max=None, max_indices=None, keepdim=False) -> (Tensor, LongTensor)
作用:找出给定tensor的指定维度dim上的上的最大值,并返回最大值在该维度上的值和位置索引。
应用举例:
例1——返回相应维度上的最大值
import torch
a = torch.randint(2, 10,(6,4)) # 创建shape为6*4,值为[2,10]的随机整数的tensor
b, max_index = torch.max(a, dim=1) # 找出a的第1维度(列)上的最大值,返回结果和最大值在相应维度的序号
print('a:', a)
print('b:', b)
print('max_index:', max_index)
''' 输出结果 '''
a: tensor([[9, 6, 6, 5],
[5, 7, 5, 8],
[2, 2, 7, 9],
[8, 9, 3, 5],
[8, 7, 3, 3],
[9, 6, 9, 3]])
b: tensor([9, 8, 9, 9, 8, 9])
max_index: tensor([0, 3, 3, 1, 0, 2])
例2——如果max的参数只有一个tensor,则返回该tensor里所有值中的最大值。
import torch
a = torch.randint(2, 10,(6,4)) # 创建shape为6*4,值为[2,10]的随机整数的tensor
b = torch.max(a) # 找出a的所有元素中的最大值,返回结果
print('a:', a)
print('b:', b)
''' 输出结果 '''
a: tensor([[8, 2, 2, 4],
[7, 4, 3, 4],
[4, 4, 3, 4],
[9, 7, 7, 2],
[5, 4, 7, 9],
[4, 5, 7, 5]])
b: tensor(9)
例3——如果max的参数是两个相同shape的tensor,则返回两tensor对应的最大值的新tensor。
import torch
a = torch.randint(2, 10,(6,4)) # 创建shape为6*4,值为[2,10]的随机整数的tensor
b = torch.randint(2, 10,(6,4)) # 找出a的第1维度(列)上的最大值,返回结果和最大值在相应维度的序号
c = torch.max(a, b) # 找出a的第1维度(列)上的最大值,返回结果和最大值在相应维度的序号
print('a:', a)
print('b:', b)
print('c:', c)
''' 运行结果 '''
a: tensor([[4, 6, 3, 4],
[2, 2, 8, 3],
[6, 2, 6, 8],
[3, 9, 8, 5],
[4, 7, 4, 4],
[9, 5, 8, 3]])
b: tensor([[8, 2, 3, 9],
[6, 7, 4, 6],
[8, 9, 3, 6],
[8, 4, 7, 5],
[9, 3, 7, 6],
[4, 7, 9, 6]])
c: tensor([[8, 6, 3, 9],
[6, 7, 8, 6],
[8, 9, 6, 8],
[8, 9, 8, 5],
[9, 7, 7, 6],
[9, 7, 9, 6]])
例4——keepdim=True, 返回的值和位置索引保持原有的维度数。
import torch
a = torch.randint(2, 10,(6,4)) # 创建shape为6*4,值为[2,10]的随机整数的tensor
b, max_index = torch.max(a, dim=1, keepdim=True) # 找出a的第1维度(列)上的最大值,返回结果和最大值在相应维度的序号
print('a:', a)
print('b:', b)
print('max_index:', max_index)
#=============运行结果===============#
a: tensor([[6, 7, 6, 5],
[7, 6, 2, 3],
[2, 3, 7, 3],
[4, 7, 4, 8],
[5, 7, 7, 6],
[5, 4, 5, 6]])
b: tensor([[7],
[7],
[7],
[8],
[7],
[6]])
max_index: tensor([[1],
[0],
[2],
[3],
[2],
[3]])
2、torch.argmax
定义:
torch.argmax(input, dim, keepdim=False) → LongTensor
作用:返回输入张量中指定维度的最大值的索引。
举例说明:
例1——指定维度:返回相应维度最大值的索引
import torch
a = torch.randint(9,(3, 3))
max_index = torch.argmax(a, dim=0)
print('a:\n', a)
print('max_index:\n', max_index)
''' 运行结果 '''
a:
tensor([[1, 1, 5],
[2, 8, 1],
[3, 7, 3]])
max_index:
tensor([2, 1, 0])
例2——不指定维度,返回整体上最大值的序号
import torch
a = torch.randint(9,(3, 3))
max_index = torch.argmax(a)
print('a:\n', a)
print('max_index:\n', max_index)
''' 运行结果 '''
a:
tensor([[5, 2, 2],
[7, 2, 0],
[8, 0, 6]])
max_index:
tensor(6) # 注:tensor在内存中是顺序存储,所以8所在的序号是6
3、torch.min
用法同max。
4、torch.argmin
用法同argmax。
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
已为社区贡献16条内容
所有评论(0)