0%

GRU模型

GRU模型

GRU

GRU(Gated Recurrent Unit)也称门控制单元结构,它也是传统RNN的变体,同LSTM一样能够有效捕捉长序列之间的语义关联,缓解梯度消失或爆炸现象。同时它的结构和计算要比LSTM更简单,它的核心结构可以分为两个部分去解析:

  • 更新门
  • 重置门

GRU的内部结构图和计算公式

\[ z_t = \sigma(W_z\cdot [h_{t-1}, x_t]) \\ r_t=\sigma(W_r\cdot [h_{t-1},x_t]) \\ \widetilde h_t = \tanh (W\cdot [r_t*h_{t-1}, x_t]) \\ h_t = (1-z_t)*h_{t-1} + z_t * \widetilde h_t \]

## 结构解释图

GRU的更新门和重置门结构图

内部结构分析

  • 和之前分析过的LSTM中的门控一样,首先计算更新门和重置门的门值,分别是z(t)和r(t),计算方法就是使用X(t)与h(t-1)拼接进行线性变换,再经过sigmoid激活
  • 之后更新门门值作用在了h(t-1)上,代表控制上一时间步传来的信息有多少可以被利用
  • 接着就是使用这个更新后的h(t-1)进行基本的RNN计算,即与x(t)拼接进行线性变换,经过tanh激活,得到新的h(t)
  • 最后重置门的门值会作用在新的h(t),而1-门值会作用在h(t-1)上,随后将两者的结果相加,得到最终的隐含状态输出h(t)
  • 这个过程意味着重置门有能力重置之前所有的计算,当门值趋于1时,输出就是新的h(t),而当门值趋于0时,输出就是上一时间步的h(t-1)

Bi-GRU

  • Bi-GRU和Bi-LSTM的逻辑相同,都是不改变其内部结构,而是将模型应用两次且方向不同,再将再次得到的GRU结果进行拼接作为最终输出

Pytorch中的GRU工具使用

  • 位置:在torch.nn工具包中,通过torch.nn.GRU可调用

nn.GRU类初始化主要参数解释

  • input_size:输入张量x中的特征维度的大小
  • hidden_size:隐层张量h中特征维度的大小
  • num_layers:隐含层的数量
  • bidirectional:是否选择使用双向GRU,如果为True,则使用,默认不使用

nn.GRU类实例化对象主要参数解释

  • input:输入张量x
  • h0:初始化的隐层张量h

nn.GRU使用示例

1
2
3
4
5
6
7
8
9
10
11
import torch
import torch.nn as nn
# 5:输入维度;6:隐层维度;2:隐藏层数量
rnn = nn.GRU(5, 6, 2)
# 1:序列长度;3:批次大小;5:输入维度
input = torch.randn(1, 3, 5)
# 2:隐藏层层数;3:批次大小;6:隐层维度
h0 = torch.randn(2, 3, 6)
output, hn = rnn(input, h0)
print(output)
print(hn)

tensor([[[-0.3712, 0.5142, -0.3155, 0.0723, -0.4794, -0.3006], [ 1.0281, 0.8697, -2.3055, -0.3992, -0.7650, -0.5211], [ 0.5593, -0.2648, -0.6986, -0.4224, 0.9587, 0.2137]]], grad_fn=)

tensor([[[ 0.4705, 0.5385, -0.5387, 0.2447, -0.0955, -0.0045], [-0.4243, 1.7435, -0.6828, -1.5498, -1.0110, 0.1387], [ 0.4374, -0.6985, 0.3958, -0.5070, 0.0546, 0.2563]], [[-0.3712, 0.5142, -0.3155, 0.0723, -0.4794, -0.3006], [ 1.0281, 0.8697, -2.3055, -0.3992, -0.7650, -0.5211], [ 0.5593, -0.2648, -0.6986, -0.4224, 0.9587, 0.2137]]], grad_fn=)

GRU的优势

  • GRU和LSTM作用相同,在捕捉长序列语义关联时,能有效抑制梯度消失或爆炸,效果优于传统RNN且计算复杂度相比LSTM要小

GRU的缺点

  • GRU仍然不能完全解决梯度消失问题,同时其作用RNN的变体,有着RNN结构本身的一大弊端,即不可并行计算,这在数据量和模型体量逐渐增大的未来,是RNN发展的关键瓶颈。