深度学习模型的保存和加载

保存和加载模型

关于保存和加载模型,有三个核心功能需要熟悉:

  1. torch.save:将序列化的对象保存到磁盘。此函数使用Python的 pickle实用程序进行序列化。使用此功能可以保存各种对象的模型,张量和字典。
  2. torch.load:使用pickle的解码功能将序列化的目标文件反序列化到内存中。
  3. torch.nn.Module.load_state_dict:使用反序列化的state_dict加载模型的参数字典 。

什么是state_dict

在PyTorch中,模型的可学习参数(即权重和偏差) torch.nn.Module 包含在模型的参数中 (通过访问model.parameters())state_dict是一个简单的Python字典对象,每个层映射到其参数张量。请注意,只有具有可学习参数的层(卷积层,线性层等)和已注册的缓冲区(batchnorm的running_mean)才在模型的state_dict中具有条目。优化器对象(torch.optim)还具有state_dict,其中包含有关优化器状态以及所用超参数的信息。

由于 state_dict 对象是Python词典,因此可以轻松地保存,更新,更改和还原它们,从而为PyTorch模型和优化器增加了很多模块化。

例如:

让我们从训练分类器 教程中使用的简单模型 看一下sta

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值