在 PyTorch 中,可以使用 requires_grad 属性来决定是否需要计算梯度,从而固定某些网络参数不进行更新。具体地,可以将需要固定的参数的 requires_grad 属性设置为 False,如下所示:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.fc1 = nn.Linear(32 * 8 * 8, 128)
self.fc2 = nn.Linear(128, 10)
# 将 conv1 的参数固定,不进行更新
self.conv1.weight.requires_grad = False
self.conv1.bias.requires_grad = False
def forward(self, x):
x = self.conv1(x)
x = nn.functional.relu(x)
x = self.conv2(x)
x = nn.functional.relu(x)
x = x.view(-1, 32 * 8 * 8)
x = self.fc1(x)
x = nn.functional.relu(x)
x = self.fc2(x)
return x
在上述代码中,我们将 conv1 的权重和偏置的 requires_grad 属性设置为 False,这样在 backward() 计算梯度时,这些参数不会被更新。而其他参数的 requires_grad 属性默认为 True,会进行更新。
就比如 我们建立网络的卷积层,全连接层,requires_grad 属性默认都是 True,
在 PyTorch 中,nn.Linear 的默认 requires_grad 属性是 True,即默认情况下会对线性层的权重和偏置进行梯度计算。nn.Conv2d 的默认 requires_grad 属性是 True,即默认情况下会对卷积层的权重和偏置进行梯度计算。
所以固定参数,那就直接设置为requires_grad False