函数机制:不改变主体,实现额外功能,像一个挂件
1. TensorObject.register_hook
-
功能:注册一个反向传播hook函数
-
hook函数仅输入一个参数,为张量的梯度,输出为Tensor(会原位覆盖原梯度)或None
e.g. 获取梯度
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)
a_grad = list()
def grad_hook(grad):
a_grad.append(grad)
handle = a.register_hook(grad_hook)
y.backward()
# 查看梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("a_grad[0]: ", a_grad[0])
handle.remove()
##########输出############
# gradient: tensor([5.]) tensor([2.]) None None None
# a_grad[0]: tensor([2.])
"""特别地,如果hook函数这样写"""
def grad_hook(grad):
grad *= 2
return grad*3
handle = w.register_hook(grad_hook)
y.backward()
# 查看梯度
print("w.grad: ", w.grad) # tensor([30.])
handle.remove()
# 【重点】"*="调用的是__mul__或者__imul__,优先寻找__imul__,其中,__imul__是原位操作
# 可变对象才允许有__imul__这种原位操作
# 【重点】添加的hook函数的return Tensor的话,会在原位覆盖
在Python中,对 += 或者 *=所做的增量赋值来说,如果左边的变量绑定的是不可变对象,会创建新对象;如果是可变对象,会就地修改
2 ModuleObject.register_forward_hook
- 功能:注册Module的前向传播的hook函数,返回None
- 主要参数:
- module: 当前网络层
- input: 当前网络层的输入数据
- output: 当前网络层的输出数据
3 ModuleObject.register_forward_pre_hook
- 功能: 注册Module的前向传播前的hook函数,返回None
- 主要参数:
- module: 当前网络层
- input: 当前层的输入数据
4 ModuleObject.register_backward_hook
- 功能: 注册module反向传播的hook函数,返回Tensor or None
- 主要参数:
- module: 当前网络层
- grad_input:当前网络层输入梯度数据
- grad_output: 当前网络层输出梯度数据
这三个module对象的hook函数,在一个module中,总的执行流程为:
整体类似面向切面编程
#===========================定义网络=================================
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x
#==========================定义hook函数==============================
def forward_hook(module, data_input, data_output):
fmap_block.append(data_output)
input_block.append(data_input)
def forward_pre_hook(module, data_input):
print("forward_pre_hook input:{}".format(data_input))
def backward_hook(module, grad_input, grad_output):
print("backward hook input:{}".format(grad_input))
print("backward hook output:{}".format(grad_output))
#=============================初始化网络==============================
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
#======================注册hook函数到具体module=======================
fmap_block = list()
input_block = list()
net.conv1.register_forward_hook(forward_hook)
net.conv1.register_forward_pre_hook(forward_pre_hook)
net.conv1.register_backward_hook(backward_hook)
# inference
fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W
output = net(fake_img)
loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()
# 观察
print("output shape: {}\noutput value: {}\n".format(output.shape, output))
print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))