Pytorch 学习(九):Pytorch 数据和模型存取
Pytorch 网络模型创建本方法总结自《动手学深度学习》(Pytorch版)github项目Pytorch 存储和读取主要依靠 load 和 save 函数模型存取依靠 load_state_dict() 函数数据存储与读取import torchpath = 'p.pth'# 'p.pt'a = torch.tensor(1)torch.save(a, path)b = torch.load(
·
Pytorch 数据和模型存取
本方法总结自《动手学深度学习》(Pytorch版)github项目
- Pytorch 存储和读取主要依靠 load 和 save 函数
- 模型存取依靠 load_state_dict() 函数
数据存储与读取
import torch
path = 'p.pth' # 'p.pt'
a = torch.tensor(1)
torch.save(a, path)
b = torch.load(path)
模型存取
- 仅存储/加载模型参数
model = net()
state_dict = model.state_dict() # 模型状态
torch.save(state_dict, path)
model2 = net()
model2.load_state_dict(torch.load(path))
- 存储/加载整个模型
model = net()
torch.save(model, path)
model2 = torch.load(path)
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
已为社区贡献1条内容
所有评论(0)