本文共 1332 字,大约阅读时间需要 4 分钟。
作者:知乎用户
链接:https://www.zhihu.com/question/61044004/answer/183682138 来源:知乎这个hook设计初衷是啥,一般在什么场景下应用?
首先明确一点,有哪些hook?我看到的有3个:
第一个是register_hook,是针对Variable对象的,
后面的两个:register_backward_hook和register_forward_hook是针对nn.Module这个对象的。其次,明确一下,为什么需要用hook?
打个比方,有这么个函数
你想通过梯度下降法求最小值。在PyTorch里面很容易实现,你只需要:import torchfrom torch.autograd import Variablex = Variable(torch.randn(2, 1), requires_grad=True)y = x+2z = torch.mean(torch.pow(y, 2))lr = 1e-3z.backward()x.data -= lr*x.grad.data
但问题是,如果我想要求中间变量 y 的梯度,系统会返回错误。
事实上,如果你输入:
type(y.grad)
系统会告诉你:NoneType
这个问题在PyTorch的论坛上有人提问过,开发者说是因为当初开发时设计的是,对于中间变量,一旦它们完成了自身反传的使命,就会被释放掉。
因此,hook就派上用场了。简而言之,register_hook的作用是,当反传时,除了完成原有的反传,额外多完成一些任务。你可以定义一个中间变量的hook,将它的grad值打印出来,当然你也可以定义一个全局列表,将每次的grad值添加到里面去。import torchfrom torch.autograd import Variablegrad_list = []def print_grad(grad): grad_list.append(grad)x = Variable(torch.randn(2, 1), requires_grad=True)y = x+2z = torch.mean(torch.pow(y, 2))lr = 1e-3y.register_hook(print_grad)z.backward()x.data -= lr*x.grad.data
需要注意的是,register_hook函数接收的是一个函数,这个函数有如下的形式:
hook(grad) -> Variable or None
也就是说,这个函数是拥有改变梯度值的威力的!
至于register_forward_hook和register_backward_hook的用法和这个大同小异。只不过对象从Variable改成了你自己定义的nn.Module。当你训练一个网络,想要提取中间层的参数、或者特征图的时候,使用hook就能派上用场了。
转载地址:http://zdxen.baihongyu.com/