Pytorch的模型文件一般会保存为.pth文件,C++接口一般读取的是.pt文件,因此,C++在调用Pytorch训练好的模型文件的时候就需要进行一个转换,转换为.pt文件,才能够读取。

所以在转换的时候,首先就需要先将模型文件读取进来,然后利用pytorch提供的函数torch.jit.trace进行转换,这个函数的声明为:

def trace(func,
          example_inputs,
          optimize=True,
          check_trace=True,
          check_inputs=None,
          check_tolerance=1e-5,
          _force_outplace=False,
          _module_class=None):

也就是,第一个参数为输入的模型,第二个参数为输入的带测试数据,通常其数据形式要跟模型的输入数据的形式是一样的。

转换的代码例子如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchsummary import summary

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, 1)
        self.conv2 = nn.Conv2d(32, 64, 5, 1)
        self.fc1 = nn.Linear(4*4*64, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = torch.load("mnist_cnn.pth")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

summary(model, input_size=(1, 28, 28))
model = model.to(device)
traced_script_module = torch.jit.trace(model, torch.ones(1, 1, 28, 28).to(device))
traced_script_module.save("mnist_cnn_cc1.pt")

 

凤兮凤兮,何德之衰。

往者不可谏,来者犹可追。

已而已而。今之从政者殆而。

-- 《楚狂接舆歌》

 

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐