TensorFlow 和 PyTorch 模型之间的转换
摘要:参考链接 https://github.com/bermanmaxim/jaccardSegment/blob/master/ckpt_to_dd.py ;一、TensorFlow 模型转 PyTorch 模型(完整代码见下文)。
·
参考链接:
https://github.com/bermanmaxim/jaccardSegment/blob/master/ckpt_to_dd.py
一、TensorFlow 模型转 PyTorch 模型
import tensorflow as tf
import deepdish as dd
import argparse
import os
import numpy as np
def tr(v):
    """Reorder a TensorFlow weight array into PyTorch's axis layout.

    4-D arrays are assumed to be TF conv kernels laid out as
    (H, W, C_in, C_out) and are permuted to (C_out, C_in, H, W).
    2-D arrays are assumed to be TF dense kernels (C_in, C_out) and are
    transposed to (C_out, C_in). Converted arrays are made C-contiguous.
    Arrays of any other rank are returned unchanged (the same object).

    NOTE(review): depthwise kernels (H, W, C_in, multiplier) are also 4-D and
    would need a different permutation; confirm the checkpoint has none.
    """
    # Axis order to apply, keyed by the array's rank.
    permutation_for_rank = {4: (3, 2, 0, 1), 2: (1, 0)}
    axes = permutation_for_rank.get(v.ndim)
    if axes is None:
        return v
    return np.ascontiguousarray(v.transpose(*axes))
def read_ckpt(ckpt):
    """Read every variable stored in a TensorFlow checkpoint and convert it with tr().

    Args:
        ckpt: checkpoint path as accepted by TensorFlow's checkpoint reader
            (e.g. ``model.ckpt-22177``).

    Returns:
        dict mapping each variable name found in the checkpoint to its value
        after tr() has reordered it into PyTorch's axis layout.
    """
    # https://github.com/tensorflow/tensorflow/issues/1823
    # tf.train.NewCheckpointReader is absent from the TF 2.x top-level API
    # (it only survives as tf.compat.v1.train.NewCheckpointReader), whereas
    # tf.train.load_checkpoint returns the same kind of reader on TF 2.x and
    # on later TF 1.x releases. Prefer it and fall back for old TF 1.x.
    load_checkpoint = getattr(tf.train, 'load_checkpoint', None)
    if load_checkpoint is not None:
        reader = load_checkpoint(ckpt)
    else:
        reader = tf.train.NewCheckpointReader(ckpt)
    # Convert while reading. Each raw array can be freed right after tr(),
    # instead of keeping a full unconverted copy of the checkpoint alive.
    return {name: tr(reader.get_tensor(name))
            for name in reader.get_variable_to_shape_map()}
if __name__ == '__main__':
    # Command-line entry point: convert a TF checkpoint to a deepdish HDF5 file.
    parser = argparse.ArgumentParser(description="Converts ckpt weights to deepdish hdf5")
    parser.add_argument("infile", type=str,
                        help="Path to the ckpt.")  # e.g. model.ckpt-22177
    parser.add_argument("outfile", type=str, nargs='?', default='',
                        help="Output file (inferred if missing).")
    args = parser.parse_args()
    if not args.outfile:
        # Default: the checkpoint path with its last extension replaced by
        # .h5, e.g. model.ckpt-22177 -> model.h5.
        args.outfile = os.path.splitext(args.infile)[0] + '.h5'
    outdir = os.path.dirname(args.outfile)
    # dirname() is '' for a bare file name (the default case above).
    # os.makedirs('') would raise, and the current directory needs no creation.
    if outdir and not os.path.exists(outdir):
        os.makedirs(outdir)
    weights = read_ckpt(args.infile)
    dd.io.save(args.outfile, weights)
1. 运行上述代码后会得到 model.h5 文件(未指定输出路径时,默认与输入 ckpt 同名、扩展名改为 .h5,例如 model.ckpt-22177 → model.h5)。
备注:请保持 TensorFlow 环境和 PyTorch 环境使用的 Python 版本一致。
2. 使用:在 PyTorch 中加载该模型:
这里假设 model.h5 中保存的参数名与 PyTorch 网络 state_dict() 的键名完全一致;若不一致,需要先自行建立名称映射。
# Load the converted model.h5 weights into a PyTorch network.
# Assumes the names stored in model.h5 match net.state_dict() keys;
# TODO: add a name mapping if the TF and PyTorch naming schemes differ.
import numpy as np
import torch
import deepdish as dd

# Placeholder: construct the network with the same architecture as the TF model.
net = resnet50()
model_dict = net.state_dict()

# Convert the numpy arrays to tensors. torch.as_tensor keeps each array's dtype
# and treats scalars as 0-d values. torch.Tensor(v) would force float32 and can
# interpret an integer scalar (e.g. a saved global_step) as a *size*.
# Entries the model does not have are dropped: load_state_dict() is strict by
# default and would otherwise fail on these unexpected keys.
pretrained_dict = dd.io.load('./model.h5')
new_pre_dict = {k: torch.as_tensor(np.asarray(v))
                for k, v in pretrained_dict.items()
                if k in model_dict}
skipped = sorted(set(pretrained_dict) - set(new_pre_dict))
if skipped:
    print('Skipped %d checkpoint entries not in the model: %s' % (len(skipped), skipped))

# Update the model's own state dict, so parameters missing from the file keep
# their current values.
model_dict.update(new_pre_dict)
# Load (shape mismatches raise here).
net.load_state_dict(model_dict)
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
已为社区贡献1条内容
所有评论(0)