『PyTorch』Hook函数

函数机制:不改变主体,实现额外功能,像一个挂件

1. TensorObject.register_hook

  • 功能:注册一个反向传播hook函数

  • hook函数仅输入一个参数,为张量的梯度,输出为Tensor(会原位覆盖原梯度)或None

e.g. 获取梯度

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)
b = torch.add(w, 1)
y = torch.mul(a, b)

a_grad = list()

def grad_hook(grad):
    a_grad.append(grad)

handle = a.register_hook(grad_hook)

y.backward()
    
# 查看梯度
print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
print("a_grad[0]: ", a_grad[0])
handle.remove()   
    
##########输出############
# gradient: tensor([5.]) tensor([2.]) None None None
# a_grad[0]:  tensor([2.])

"""特别地,如果hook函数这样写"""
def grad_hook(grad):
    grad *= 2
    return grad*3

handle = w.register_hook(grad_hook)

y.backward()

# 查看梯度
print("w.grad: ", w.grad) # tensor([30.])
handle.remove()

# 【重点】"*="调用的是__mul__或者__imul__,优先寻找__imul__,其中,__imul__是原位操作
# 可变对象才允许有__imul__这种原位操作
# 【重点】添加的hook函数的return Tensor的话,会在原位覆盖

在Python中,对 += 或者 *=所做的增量赋值来说,如果左边的变量绑定的是不可变对象,会创建新对象;如果是可变对象,会就地修改

2 ModuleObject.register_forward_hook

  • 功能:注册Module的前向传播的hook函数,返回None
  • 主要参数:
    • module: 当前网络层
    • input: 当前网络层的输入数据
    • output: 当前网络层的输出数据

3 ModuleObject.register_forward_pre_hook

  • 功能: 注册Module的前向传播的hook函数,返回None
  • 主要参数:
    • module: 当前网络层
    • input: 当前层的输入数据

4 ModuleObject.register_backward_hook

  • 功能: 注册module反向传播的hook函数,返回Tensor or None
  • 主要参数:
    • module: 当前网络层
    • grad_input:当前网络层输入梯度数据
    • grad_output: 当前网络层输出梯度数据

这三个module对象的hook函数,在一个module中,总的执行流程为:

Created with Raphaël 2.2.0 input forward_pre_hooks forward forward_hooks backward_hooks output

整体类似面向切面编程

#===========================定义网络=================================
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.pool1 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        return x

#==========================定义hook函数==============================
def forward_hook(module, data_input, data_output):
    fmap_block.append(data_output)
    input_block.append(data_input)


def forward_pre_hook(module, data_input):
    print("forward_pre_hook input:{}".format(data_input))


def backward_hook(module, grad_input, grad_output):
    print("backward hook input:{}".format(grad_input))
    print("backward hook output:{}".format(grad_output))
    
#=============================初始化网络==============================
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()

#======================注册hook函数到具体module=======================
fmap_block = list()
input_block = list()
net.conv1.register_forward_hook(forward_hook)
net.conv1.register_forward_pre_hook(forward_pre_hook)
net.conv1.register_backward_hook(backward_hook)

# inference
fake_img = torch.ones((1, 1, 4, 4))  # batch size * channel * H * W
output = net(fake_img)

loss_fnc = nn.L1Loss()
target = torch.randn_like(output)
loss = loss_fnc(target, output)
loss.backward()

# 观察
print("output shape: {}\noutput value: {}\n".format(output.shape, output))
print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值