Pytorch 数据和模型存取

本方法总结自《动手学深度学习》(Pytorch版)github项目

  • Pytorch 存储和读取主要依靠 load 和 save 函数
  • 模型存取依靠 load_state_dict() 函数
数据存储与读取
import torch

path = 'p.pth'  # 'p.pt'
a = torch.tensor(1)
torch.save(a, path)
b = torch.load(path)
模型存取
  • 仅存储/加载模型参数
model = net()
state_dict = model.state_dict()  # 模型状态
torch.save(state_dict, path)
model2 = net()
model2.load_state_dict(torch.load(path))
  • 存储/加载整个模型
model = net()
torch.save(model, path)
model2 = torch.load(path)
Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐