博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
pytorch笔记:pytorch中的钩子(Hook)有何作用?
阅读量:3905 次
发布时间:2019-05-23

本文共 1332 字,大约阅读时间需要 4 分钟。

引言:

作者:知乎用户

链接:https://www.zhihu.com/question/61044004/answer/183682138
来源:知乎

问:

这个hook设计初衷是啥,一般在什么场景下应用?

答:

首先明确一点,有哪些hook?我看到的有3个:

  1. (Python method, in Automatic differentiation package
  2. (Python method, in torch.nn)

第一个是register_hook,是针对Variable对象的,

后面的两个:register_backward_hook和register_forward_hook是针对nn.Module这个对象的。

其次,明确一下,为什么需要用hook?

打个比方,有这么个函数

在这里插入图片描述
你想通过梯度下降法求最小值。在PyTorch里面很容易实现,你只需要:

import torchfrom torch.autograd import Variablex = Variable(torch.randn(2, 1), requires_grad=True)y = x+2z = torch.mean(torch.pow(y, 2))lr = 1e-3z.backward()x.data -= lr*x.grad.data

但问题是,如果我想要求中间变量 y 的梯度,系统会返回错误。

事实上,如果你输入:

type(y.grad)

系统会告诉你:NoneType

这个问题在PyTorch的论坛上有人提问过,开发者说是因为当初开发时设计的是,对于中间变量,一旦它们完成了自身反传的使命,就会被释放掉。

在这里插入图片描述
因此,hook就派上用场了。简而言之,register_hook的作用是,当反传时,除了完成原有的反传,额外多完成一些任务。你可以定义一个中间变量的hook,将它的grad值打印出来,当然你也可以定义一个全局列表,将每次的grad值添加到里面去。

import torchfrom torch.autograd import Variablegrad_list = []def print_grad(grad):    grad_list.append(grad)x = Variable(torch.randn(2, 1), requires_grad=True)y = x+2z = torch.mean(torch.pow(y, 2))lr = 1e-3y.register_hook(print_grad)z.backward()x.data -= lr*x.grad.data

需要注意的是,register_hook函数接收的是一个函数,这个函数有如下的形式:

hook(grad) -> Variable or None

也就是说,这个函数是拥有改变梯度值的威力的!

至于register_forward_hook和register_backward_hook的用法和这个大同小异。只不过对象从Variable改成了你自己定义的nn.Module。当你训练一个网络,想要提取中间层的参数、或者特征图的时候,使用hook就能派上用场了。

转载地址:http://zdxen.baihongyu.com/

你可能感兴趣的文章
对SOAP消息头的处理
查看>>
webservice TCP Monitor
查看>>
各系统下查看cpu物理和逻辑个数
查看>>
Oracle中sysdate的时区偏差
查看>>
Oracle的时区
查看>>
oracle 时区
查看>>
oracle sysdate,current_date,current_timestamp
查看>>
java轻松开发http server
查看>>
JDK6.0的新特性:轻量级Http Server
查看>>
Http协议客户端的JAVA简单实现
查看>>
ava URLConnection 总结
查看>>
HTTP 文件上传的基本原理
查看>>
java System.in 使用
查看>>
递归倒序输出字符串
查看>>
临近毕业,图像类SCI源刊哪本审稿快?
查看>>
【每日一算】二分查找
查看>>
【每日一算】旋转有序数组
查看>>
【每日一算】两数之和
查看>>
深入理解Mysql索引底层数据结构与算法
查看>>
B+tree结构详解
查看>>