pytorch中获取模型参数:state_dict和parameters两个方法的差异比较
一、本文的模型案例代码如下:import torchimport torch.nn.functional as Ffrom torch.optim import SGDclass MyNet(torch.nn.Module):def __init__(self):super(MyNet, self).__init__()# 第一句话,调用父类的构...
一、本文的模型案例
代码如下:
import torch
import torch.nn.functional as F
from torch.optim import SGD
class MyNet(torch.nn.Module):
def __init__(self):
super(MyNet, self).__init__() # 第一句话,调用父类的构造函数
self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
self.relu1=torch.nn.ReLU()
self.max_pooling1=torch.nn.MaxPool2d(2,1)
self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1)
self.relu2=torch.nn.ReLU()
self.max_pooling2=torch.nn.MaxPool2d(2,1)
self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
self.dense2 = torch.nn.Linear(128, 10)
def forward(self, x):
x = self.conv1(x)
x = self.relu1(x)
x = self.max_pooling1(x)
x = self.conv2(x)
x = self.relu2(x)
x = self.max_pooling2(x)
x = self.dense1(x)
x = self.dense2(x)
return x
model = MyNet() # 构造模型
二、model.state_dict()方法
pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)。这个方法的作用一方面是方便查看某一个层的权值和偏置数据,另一方面更多的是在模型保存的时候使用。
2.1 Module的层的权值以及bias查看
print(type(model.state_dict())) # 查看state_dict所返回的类型,是一个“顺序字典OrderedDict”
for param_tensor in model.state_dict(): # 字典的遍历默认是遍历 key,所以param_tensor实际上是键值
print(param_tensor,'\t',model.state_dict()[param_tensor].size())
'''
conv1.weight torch.Size([32, 3, 3, 3])
conv1.bias torch.Size([32])
conv2.weight torch.Size([32, 3, 3, 3])
conv2.bias torch.Size([32])
dense1.weight torch.Size([128, 288])
dense1.bias torch.Size([128])
dense2.weight torch.Size([10, 128])
dense2.bias torch.Size([10])
'''
2.2 优化器optimizer的state_dict()方法
优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)
optimizer = SGD(model.parameters(),lr=0.001,momentum=0.9)
for var_name in optimizer.state_dict():
print(var_name,'\t',optimizer.state_dict()[var_name])
'''
state {}
param_groups [{'lr': 0.001,
'momentum': 0.9,
'dampening': 0,
'weight_decay': 0,
'nesterov': False,
'params': [1412966600640, 1412966613064, 1412966613136, 1412966613208,
1412966613280, 1412966613352, 1412966613496, 1412966613568]
}]
'''
三、model.parameters()方法
这个方法也会获得模型的参数信息,如下:
print(type(model.parameters())) # 返回的是一个generator
for para in model.parameters():
print(para.size()) # 只查看形状
'''
torch.Size([32, 3, 3, 3])
torch.Size([32])
torch.Size([32, 3, 3, 3])
torch.Size([32])
torch.Size([128, 288])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
'''
从这里可以看出,其实这个state_dict方法所得到结果差不多,不同的是,model.parameters()方法返回的是一个生成器generator,每一个元素是从开头到结尾的参数,parameters没有对应的key名称,是一个由纯参数组成的generator,而state_dict是一个字典,包含了一个key。
其实Module还有一个与parameters类似的函数,named_parameters,而且parameters正是通过named_parameters来实现的,
看一下parameters的定义,很简单:
def parameters(self, recurse=True):
for name, param in self.named_parameters(recurse=recurse):
yield param
来一起看一下named_parameters的简单使用。
print(type(model.named_parameters())) # 返回的是一个generator
for para in model.named_parameters(): # 返回的每一个元素是一个元组 tuple
'''
是一个元组 tuple ,元组的第一个元素是参数所对应的名称,第二个元素就是对应的参数值
'''
print(para[0],'\t',para[1].size())
'''
conv1.weight torch.Size([32, 3, 3, 3])
conv1.bias torch.Size([32])
conv2.weight torch.Size([32, 3, 3, 3])
conv2.bias torch.Size([32])
dense1.weight torch.Size([128, 288])
dense1.bias torch.Size([128])
dense2.weight torch.Size([10, 128])
dense2.bias torch.Size([10])
'''
总结:model.state_dict()、model.parameters()、model.named_parameters()这三个方法都可以查看Module的参数信息,用于更新参数,或者用于模型的保存。
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)