PyTorch碎片:深刻透彻理解Torch中Tensor.contiguous()函数
1.函数定义Returns a contiguous tensor containing the same data as self tensor.返回一个与原始tensor相同元素数据的 “连续”tensor类型If self tensor is contiguous, this function returns the self tensor.如果原始tensor本身就是连续的,则返...
1.函数定义
Returns a contiguous tensor containing the same data as self tensor.
返回一个与原始tensor相同元素数据的 “连续”tensor类型
If self tensor is contiguous, this function returns the self tensor.
如果原始tensor本身就是连续的,则返回原始tensor
2.定义理解
定义本身有两个重要的点:
对原始tensor进行复制
返回contiguous“类型”的一个tensor
Tensor.contiguous()函数不会对原始数据进行任何修改,而仅仅对其进行复制,并在内存空间上进行对齐,即在内存空间上,tensor元素的内存地址保持连续。
这么做的目的是,在对tensor元素进行转换和维度变换等操作之后,元素地址在内存空间中保证连续性,在后续利用指针对tensor元素进行读取时,能够减少读取便利,提高内存空间优化。
3.数据案例分析
import torch
src_t = torch.randn((2,3))
print(src_t.shape)
print(src_t.is_contiguous())
输出:
>>> torch.Size([2, 3])
>>> True
可以看出,在利用torch.randn函数进行tensor创建时,获取的tensor元素地址是连续内存空间保存的。那么,如果对创建的tensor进行transpose变换操作:
trans_t = src_t.transpose(0,1)
print(trans_t.shape)
print(trans_t.is_contiguous())
输出:
>>> torch.Size([3, 2])
>>> False
我们发现经过transpose变换以后,tensor变成非连续保存类型(uncontiguous)。
那么,变成这种非连续保存类型会造成什么样的后果呢?
简单的以view函数为例:
trans_t.view(-1,3)
当尝试对uncontiguous类型tensor进行维度变换时,就会出现下面错误:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
错误提示告诉我们,至少有一个维度数据在内存空间上跨越了两个连续子空间!此时,我们输出trans_t的连续保存类型是什么:
print(trans_t.is_contiguous())
>>> False
因此,为了能够实现对张量trans_t的维度变换,需要先对tensor进行contiguous内存地址对齐操作,然后再进行view操作:
print(trans_t.shape)
trans_t.contiguous().view(-1,3)
print(trans_t.shape)
>>> torch.Size([3, 2])
>>> torch.Size([2, 3])
4.总结
总结一下,为了保证代码的可读性和严谨性,当对tensor进行维度变化时,常需要配合contiguous函数使用,但是哪些函数会造成原始tensor变的uncontiguous呢?
transpose()
narrow()
expand()
有其他函数,我会进一步补充,有错误欢迎指正,谢谢!
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)