保存和加载模型
关于保存和加载模型,有三个核心功能需要熟悉:
- torch.save:将序列化的对象保存到磁盘。此函数使用Python的 pickle实用程序进行序列化。使用此功能可以保存各种对象的模型,张量和字典。
- torch.load:使用pickle的解码功能将序列化的目标文件反序列化到内存中。
- torch.nn.Module.load_state_dict:使用反序列化的state_dict加载模型的参数字典 。
什么是state_dict
?
在PyTorch中,模型的可学习参数(即权重和偏差) torch.nn.Module
包含在模型的参数中 (通过访问model.parameters()
)state_dict是一个简单的Python字典对象,每个层映射到其参数张量。请注意,只有具有可学习参数的层(卷积层,线性层等)和已注册的缓冲区(batchnorm的running_mean)才在模型的state_dict中具有条目。优化器对象(torch.optim
)还具有state_dict,其中包含有关优化器状态以及所用超参数的信息。
由于 state_dict 对象是Python词典,因此可以轻松地保存,更新,更改和还原它们,从而为PyTorch模型和优化器增加了很多模块化。
例如:
让我们从训练分类器 教程中使用的简单模型 看一下sta