PyTorch的hook机制

1. 简介

钩子编程(hooking),也称作「挂钩」,计算机程序设计术语,指通过拦截软件模块间的「函数调用」、「消息传递」、「事件传递」来修改或扩展操作系统、应用程序或其他软件组件的行为的各种技术。处理被拦截的函数调用、事件、消息的代码,被称为钩子(hook)

hook 在 PyTorch 中是一个非常有用的特性,利用它可以在不必改变网络输入输出结构的情况下,方便地获取、改变网络中间层变量的值和梯度。

2. Hook for Tensors

在 PyTorch 的计算图中,只有叶节点的变量会保留梯度,而所有中间变量的梯度只在反向传播中使用,一旦把向传播完成,中间变量的梯度将自动释放,从而节约内存。如果想在反向传播后保留他们的梯度,则需要使用 retain_grad() 函数特殊指定。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch

x = torch.Tensor([0, 1, 2, 3]).requires_grad_()
y = torch.Tensor([4, 5, 6, 7]).requires_grad_()
w = torch.Tensor([1, 2, 3, 4]).requires_grad_()
z = x + y
o = torch.matmul(w, z)
# z.retain_grad()
# o.retain_grad()
o.backward()

print('x.requires_grad:', x.requires_grad) # True
print('y.requires_grad:', y.requires_grad) # True
print('z.requires_grad:', z.requires_grad) # True
print('w.requires_grad:', w.requires_grad) # True
print('o.requires_grad:', o.requires_grad) # True

# print grad after backward w/o retain_grad
print('x.grad:', x.grad) # tensor([1., 2., 3., 4.])
print('y.grad:', y.grad) # tensor([1., 2., 3., 4.])
print('w.grad:', w.grad) # tensor([4., 6., 8., 10.])
print('z.grad:', z.grad) # None / tensor([1., 2., 3., 4.])
print('o.grad:', o.grad) # None / tensor(1.)

使用 retain_grad() 会增加内存的占用,更好的方法是使用 hook 保存中间变量的梯度。比如对于中间变量 zz,hook 的使用方法为:z.register_hook(hook_fn),其中 hook_fn 是一个用户自定义的函数,为简化表示也可以直接使用 lambda 表达式代替函数:

1
2
3
4
def hook_fn(grad): -> Tensor or None
z.register_hook(hook_fn)
# z.register_hook(lambda x: 2*x) # output Tensor
# z.register_hook(lambda x: print(x)) # output None

该函数输入为变量 zz 的梯度,输出为一个 Tensor 或 None。反向传播时,梯度传播到变量 zz 后,再继续往前传播之前,会先传入 hook_fn 函数。如果 hook_fn 返回值是 None,则梯度不改变,继续向前传播;如果 hook_fn 的返回值是 Tensor 类型,则该 Tensor 将取代变量 zz 原有的梯度,继续向前传播。

3. Hook for Modules

网络模块的 hook 不像 Tensor 一样具有显式的变量名可以访问,而是被封装在神经网络中。PyTorch 设计了两种 hook:register_forward_hookregister_backward_hook,分别用来获取前向传播和反向传播时中间层模块的输入和输出特征及梯度。

register_forward_hook

register_forward_hook 的作用是获取前向传播过程中,网络各模块的输入和输出。对于模块 module,其使用方法为:module.register_forward_hook(hook_fn),其中 hook_fn 为一个用户自定义的函数。

1
def hook_fn(module, input, output): -> Tensor or None

其中,hook_fn 函数的输入变量分别为模块、模块的输入和模块的输出,输出为 None 或 Tensor,用于修改模块的输出。使用该 hook,可以主方便地使用预训练的神经网格提取特征,而不用改变预训练网络的结构。

register_backward_hook

register_backward_hook 的作用是获取反向传播过程中,网络各模块的输入和输出。对于模块 module,其使用方法为:module.register_backward_hook(hook_fn),其中 hook_fn 为一个用户自定义的函数。

1
def hook_fn(module, grad_input, grad_output): -> Tensor or None

其中,hook_fn 函数的输入变量分别为模块、模块输入端的梯度和模块输出端的梯度,这里的输入输出端是站在前向传播的角度来说的。如果模块有多个输入端或输出端,则对应的梯度是 tuple 类型。比如对于线性模块,其 grad_input 是一个三元组,排列顺序分别为:对 bias 的导数、对输入 x 的导数、对权重 w 的导数。

附录


本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!