PyTorch模型的保存与加载是怎么样的
PyTorch模型的保存与加载是怎么样的,针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。
成都创新互联公司是专业的鄂温克网站建设公司,鄂温克接单;提供网站设计、成都网站建设,网页设计,网站设计,建网站,PHP网站建设等专业做网站服务;采用PHP框架,可快速的进行鄂温克网站开发网页制作和功能扩展;专业做搜索引擎喜爱的网站,专业的做网站团队,希望更多企业前来合作!
torch.save()和torch.load():
torch.save()和torch.load()配合使用, 分别用来保存一个对象(任何对象, 不一定要是PyTorch中的对象)到文件,和从文件中加载一个对象. 加载的时候可以指明是否需要数据在CPU和GPU中相互移动.
Module.state_dict()和Module.load_state_dict():
Module.state_dict()返回一个字典, 该字典以键值对的方式保存了Module的整个状态. Module.load_state_dict()可以从一个字典中加载参数到这个module和其后代, 如果strict是True, 那么所加载的字典和该module本身state_dict()方法返回的关键字必须严格确切的匹配上. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function. 返回值是一个命名元组: NamedTuple with missing_keys and unexpected_keys fields, 分别保存缺失的关键字和未预料到的关键字. 如果自己的模型跟预训练模型只有部分层是相同的, 那么可以只加载这部分相同的参数, 只要设置strict参数为False来忽略那些没有匹配到的keys即可。
# 方式1:# model_path = 'model_name.pth'# model_params_path = 'params_name.pth'# ----保存----# torch.save(model, model_path)# ----加载----# model = torch.load(model_path)# 方式2:#----保存----# torch.save(model.state_dict(), model_params_path) #保存的文件名后缀一般是.pt或.pth #----加载----# model=Model().cuda() #定义模型结构 # model.load_state_dict(torch.load(model_params_path)) #加载模型参数
说明:
# 保存/加载整个模型 torch.save(model, PATH) model = torch.load(PATH) model.eval() 这种保存/加载模型的过程使用了最直观的语法, 所用代码量少。这使用Python的pickle保存所有模块。 这种方法的缺点是,保存模型的时候, 序列化的数据被绑定到了特定的类和确切的目录。 这是因为pickle不保存模型类本身,而是保存这个类的路径, 并且在加载的时候会使用。因此, 当在其他项目里使用或者重构的时候,加载模型的时候会出错。 # 保存/加载 state_dict(推荐) torch.save(model.state_dict(), PATH) model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
自己选择要保存的参数,设置checkpoint:
#----保存----torch.save({ 'epoch': epoch + 1,'arch': args.arch,'state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(), 'loss': loss,'best_prec1': best_prec1,}, 'checkpoint_name.tar' )#----加载----checkpoint = torch.load('checkpoint_name.tar')#按关键字获取保存的参数 start_epoch = checkpoint['epoch']best_prec1 = checkpoint['best_prec1']state_dict=checkpoint['state_dict']model=Model()#定义模型结构 model.load_state_dict(state_dict)
保存多个模型到同一个文件:
#----保存----torch.save({ 'modelA_state_dict': modelA.state_dict(), 'modelB_state_dict': modelB.state_dict(), 'optimizerA_state_dict': optimizerA.state_dict(), 'optimizerB_state_dict': optimizerB.state_dict(), ... }, PATH)#----加载----modelA = TheModelAClass(*args, **kwargs)modelB = TheModelAClass(*args, **kwargs)optimizerA = TheOptimizerAClass(*args, **kwargs)optimizerB = TheOptimizerBClass(*args, **kwargs)checkpoint = torch.load(PATH)modelA.load_state_dict(checkpoint['modelA_state_dict']modelB.load_state_dict(checkpoint['modelB_state_dict']optimizerA.load_state_dict(checkpoint['optimizerA_state_dict']optimizerB.load_state_dict(checkpoint['optimizerB_state_dict']modelA.eval()modelB.eval()# or modelA.train()modelB.train()# 在这里,保存完模型后加载的时候有时会 # 遇到CUDA out of memory的问题, # 我google到的解决方法是加上map_location=‘cpu’ checkpoint = torch.load(PATH,map_location='cpu')
加载预训练模型的部分:
resnet152 = models.resnet152(pretrained=True) #加载模型结构和参数 pretrained_dict = resnet152.state_dict()"""加载torchvision中的预训练模型和参数后通过state_dict()方法提取参数 也可以直接从官方model_zoo下载: pretrained_dict = model_zoo.load_url(model_urls['resnet152'])""" model_dict = model.state_dict()# 将pretrained_dict里不属于model_dict的键剔除掉 pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict}# 更新现有的model_dict model_dict.update(pretrained_dict)# 加载我们真正需要的state_dict model.load_state_dict(model_dict)
或者写详细一点:
model_dict = model.state_dict()state_dict = { }for k, v in pretrained_dict.items():if k in model_dict.keys():# state_dict.setdefault(k, v)state_dict[k] = velse:print("Missing key(s) in state_dict :{}".format(k))model_dict.update(state_dict)model.load_state_dict(model_dict)
关于PyTorch模型的保存与加载是怎么样的问题的解答就分享到这里了,希望以上内容可以对大家有一定的帮助,如果你还有很多疑惑没有解开,可以关注创新互联行业资讯频道了解更多相关知识。
本文标题:PyTorch模型的保存与加载是怎么样的
文章出自:http://pcwzsj.com/article/gooecc.html