深入浅出PyTorch(三)pytorch模型定义

此博客为datawhle开源学习平台2025年2月组队学习打卡笔记链接,开源学习手册链接视频教程链接

pytorch模型定义

模型在深度学习中扮演着重要的角色,好的模型极大地促进了深度学习的发展进步,比如CNN的提出解决了图像、视频处理中的诸多问题,RNN/LSTM模型解决了序列数据处理的问题,GNN在图模型上发挥着重要的作用。当我们在向他人介绍一项深度学习工作的时候,对方可能首先要问的就是使用了哪些模型。因此,在PyTorch进阶操作的第一部分中,我们首先来学习PyTorch模型相关的内容。

Module 类是 torch.nn 模块里提供的一个模型构造类 (nn.Module),是所有神经⽹网络模块的基类,我们可以继承它来定义我们想要的模型
PyTorch模型定义应包括两个主要部分:

  1. 各个部分的初始化__init__
  2. 数据流向定义(forward)

PyTorch中模型定义的三种方式

基于nn.Module,我们可以通过Sequential,ModuleList和ModuleDict三种方式定义PyTorch模型

Sequential

对应模块为nn.Sequential()
定义方式:

from collections import OrderedDict
class MySequential(nn.Module):
    def __init__(self, *args):
        super(MySequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict): # 如果传入的是一个OrderedDict
            for key, module in args[0].items():
                self.add_module(key, module)  
                # add_module方法会将module添加进self._modules(一个OrderedDict)
        else:  # 传入的是一些Module
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)
    def forward(self, input):
        # self._modules返回一个 OrderedDict,保证会按照成员添加时的顺序遍历成
        for module in self._modules.values():
            input = module(input)
        return input

用Sequential来定义模型只需要将模型的层按序排列起来即可,根据层名的不同,排列的时候有两种方式:

  1. 直接排列
import torch.nn as nn
net = nn.Sequential(
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Linear(256, 10), 
        )
print(net)
Sequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)
  1. 使用OrderedDict
import collections
import torch.nn as nn
net2 = nn.Sequential(collections.OrderedDict([
          ('fc1', nn.Linear(784, 256)),
          ('relu1', nn.ReLU()),
          ('fc2', nn.Linear(256, 10))
          ]))
print(net2)
Sequential(
  (fc1): Linear(in_features=784, out_features=256, bias=True)
  (relu1): ReLU()
  (fc2): Linear(in_features=256, out_features=10, bias=True)
)

优缺点:使用Sequential定义模型的好处在于简单、易读,同时使用Sequential定义的模型不需要再写forward,因为顺序已经定义好了。但使用Sequential也会使得模型定义丧失灵活性,比如需要在模型中间加入一个外部输入时就不适合用Sequential的方式实现。使用时需根据实际需求加以选择。

ModuleList

对应模块为nn.ModuleList()
ModuleList 接收一个子模块(或层,需属于nn.Module类)的列表作为输入,然后也可以类似List那样进行append和extend操作。同时,子模块或层的权重也会自动添加到网络中来。

net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10)) # # 类似List的append操作
print(net[-1])  # 类似List的索引访问
print(net)
Linear(in_features=256, out_features=10, bias=True)
ModuleList(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Techhuman

创作不易,请您支持!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值