前言
Hi,我是GISerLiu🙂, 这篇文章是参加2025年5月datawhale学习赛的打卡文章!💡 本文详细解析了LSTM的工作原理,包括LSTM的结构设计、门控机制,以及在时序数据处理中的实际应用。
长短期记忆网络(Long Short-Term Memory, LSTM)是一种特殊的RNN结构,能有效解决传统循环神经网络中的长期依赖问题。本文将基于Staudemeyer和Morris的教程详细解析LSTM的工作原理,并结合PyTorch实现展示其在处理含缺失值时序数据中的应用。
一、时序数据处理的挑战
时序数据在现实世界中随处可见,从股票价格到医疗监测信号,从语音识别到自然语言处理。然而,这类数据通常面临两大挑战:
1.数据"瑕疵"问题
一般采集到的时序数据可能会有"瑕疵",例如包含缺失值、异常数据点和样本等等。虽然这些"瑕疵"也属于数据的特性(例如缺失值的模式可能代表了设备的异常情况),但在一般的建模中是有害的,不被我们需要的,会阻碍算法的学习甚至让我们无法使用神经网络直接对其进行端到端的建模(例如很多的时序预测算法和分类算法无法处理缺失值)。
2.长期依赖问题
时序数据通常需要模型理解长时间跨度的模式和依赖关系。例如,在语句"我在中国生活了很多年,所以我的汉语…“中,要预测"很流利”,模型需要记住较远位置的"中国"信息。
二、RNN及其局限性
0.通俗理解RNN
想象一下你在读一本小说。当你读到第100页时,你能理解当前发生的事情,是因为你记得前面99页的内容。你的大脑不是每翻一页就重新思考,而是持续更新对故事的理解。
这就是循环神经网络(RNN)的核心思想!
RNN就像一个有"记忆"的网络:
- 📚 记忆传递:每读完一个词,它不仅产生当前的理解,还会更新"记忆笔记"传给下一步
- 🔄 循环连接:前一时刻的状态会影响当前的判断(就像你读小说时前面的情节会影响你对当前情节的理解)
- ⏱ 时序处理:特别适合处理有先后顺序的数据,如文本、语音、股票价格等
简单例子:假设RNN在处理句子"今天_很好":
- 看到"今天"→更新记忆→输出初步理解
- 看到"天气"→结合"今天"的记忆,更新记忆→输出更新的理解
- 看到"很好"→结合"今天天气"的记忆,给出最终输出
但RNN也有"健忘"的问题——当句子很长时,它可能已经忘记了开头的内容,这就是著名的"长期依赖问题"。
1.RNN的基本结构
循环神经网络(RNN)是处理序列数据的基础架构:
RNN的核心公式:
h t = tanh ( W x h x t + W h h h t − 1 + b h ) h_t = \tanh(W_{xh}x_t + W_{hh}h_{t-1} + b_h) ht=tanh(Wxhxt+Whhht−1+bh)
y t = W h y h t + b y y_t = W_{hy}h_t + b_y yt=Whyht+by
2.梯度消失与爆炸问题
问题 | 描述 | 影响 |
---|---|---|
梯度消失 | 梯度在反向传播过程中逐渐变得非常小 | 远距离依赖无法学习 |
梯度爆炸 | 梯度在反向传播过程中无限增大 | 权重更新不稳定,模型发散 |
如Staudemeyer和Morris在论文中指出,传统RNN通常只能有效学习5-10个时间步长的依赖关系,这远远不够处理实际应用中的长序列数据。
三、LSTM的核心原理
0.通俗理解LSTM
想象RNN是一个健忘的笔记员,而LSTM是一个更聪明的助手,配备了精密的"记忆管理系统"。
LSTM如何解决"健忘"问题?
LSTM像是给神经网络配备了一个智能记忆管理系统,包含三个关键"开关"(门控机制):
-
遗忘门 🔄:相当于一个智能"删除键"
- 功能:决定哪些旧信息值得保留,哪些应该丢弃
- 例子:在阅读"天气转晴,温度升高"后,可能要淡忘"下雨"的信息
-
输入门 ➕:相当于一个智能"保存键"
- 功能:决定哪些新信息值得记录
- 例子:在获取"股票价格上涨10%"信息时,这是重要变化,应该记录下来
-
输出门 📤:相当于一个智能"分享键"
- 功能:决定当前要输出哪些记忆内容
- 例子:回答问题时,从记忆中提取相关信息,而非全部内容
还有一条特殊的"长期记忆"通道(单元状态),就像是一条高速公路,信息可以几乎不变地从头传到尾,解决了长距离依赖问题。
生活类比:
- 传统RNN就像用一个小本子记笔记,每写一页都会回顾前一页,但很快就会忘记前面的内容
- LSTM就像一个智能笔记本,有不同的标记系统,重要的信息用荧光笔标出并定期复习,确保长期记忆
这种设计让LSTM特别擅长学习长序列中的重要模式,无论是文本、股票数据还是医疗信号。
1.整体架构
LSTM通过引入门控机制和记忆单元解决了长期依赖问题:
2.门控机制详解
LSTM引入了三个门来控制信息流:
- 遗忘门:决定丢弃哪些旧信息
f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft=σ(Wf⋅[ht−1,xt]+bf)
- 输入门:决定更新哪些新信息
i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it=σ(Wi⋅[ht−1,xt]+bi)
C ~ t = tanh ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C) C~t=tanh(WC⋅[ht−1,xt]+bC)
- 单元状态更新:保持长期记忆
C t = f t ⊙ C t − 1 + i t ⊙ C ~ t C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t Ct=ft⊙Ct−1+it⊙C~t
- 输出门:决定输出哪些信息
o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot=σ(Wo⋅[ht−1,xt]+bo)
h t = o t ⊙ tanh ( C t ) h_t = o_t \odot \tanh(C_t) ht=ot⊙tanh(Ct)
其中, σ \sigma σ是sigmoid函数, ⊙ \odot ⊙表示逐元素乘法。
四、两阶段时序处理实战
时序数据处理常采用两阶段策略:
- 上游阶段:数据预处理(如缺失值插补)
- 下游阶段:目标任务建模(如分类、预测)
1.基于PyTorch实现LSTM分类器
以下是本文示例中完整的LSTM分类器实现:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from copy import deepcopy
# 设置模型的运行设备为cpu, 如果你有gpu设备可以设置为cuda
DEVICE='cpu'
class LoadImputedDataAndLabel(Dataset):
def __init__(self, imputed_data, labels):
self.imputed_data = imputed_data
self.labels = labels
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
return (
torch.from_numpy(self.imputed_data[idx]).to(torch.float32),
torch.tensor(self.labels[idx]).to(torch.long),
)
class ClassificationLSTM(torch.nn.Module):
def __init__(self, n_features, rnn_hidden_size, n_classes):
super().__init__()
self.rnn = torch.nn.LSTM(
n_features,
hidden_size=rnn_hidden_size,
batch_first=True,
)
self.fcn = torch.nn.Linear(rnn_hidden_size, n_classes)
def forward(self, data):
hidden_states, _ = self.rnn(data)
logits = self.fcn(hidden_states[:, -1, :])
prediction_probabilities = torch.sigmoid(logits)
return prediction_probabilities
2.模型训练与优化策略
下面是训练函数实现,包含了早停策略:
def train(model, train_dataloader, val_dataloader, test_loader):
n_epochs = 20
patience = 5
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
current_patience = patience
best_loss = float("inf")
for epoch in range(n_epochs):
model.train()
for idx, data in enumerate(train_dataloader):
X, y = map(lambda x: x.to(DEVICE), data)
optimizer.zero_grad()
probabilities = model(X)
loss = F.cross_entropy(probabilities, y.reshape(-1))
loss.backward()
optimizer.step()
model.eval()
loss_collector = []
with torch.no_grad():
for idx, data in enumerate(val_dataloader):
X, y = map(lambda x: x.to(DEVICE), data)
probabilities = model(X)
loss = F.cross_entropy(probabilities, y.reshape(-1))
loss_collector.append(loss.item())
loss = np.asarray(loss_collector).mean()
if best_loss > loss:
current_patience = patience
best_loss = loss
best_model = deepcopy(model.state_dict())
else:
current_patience -= 1
if current_patience == 0:
break
model.load_state_dict(best_model)
model.eval()
probability_collector = []
for idx, data in enumerate(test_loader):
X, y = map(lambda x: x.to(DEVICE), data)
probabilities = model.forward(X)
probability_collector += probabilities.cpu().tolist()
probability_collector = np.asarray(probability_collector)
return probability_collector
上述训练过程中的关键技术点:
技术 | 实现方式 | 作用 |
---|---|---|
优化器选择 | Adam | 自适应学习率调整,加速收敛 |
损失函数 | 交叉熵 | 适合分类任务的损失度量 |
早停策略 | patience机制 | 防止过拟合,节省训练时间 |
模型保存 | best_model | 保存验证集上表现最佳的模型参数 |
3.数据加载与处理
from pypots.data.saving import pickle_load
def get_dataloaders(train_X, train_y, val_X, val_y, test_X, test_y, batch_size=128):
train_set = LoadImputedDataAndLabel(train_X, train_y)
val_set = LoadImputedDataAndLabel(val_X, val_y)
test_set = LoadImputedDataAndLabel(test_X, test_y)
train_loader = DataLoader(train_set, batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size, shuffle=False)
test_loader = DataLoader(test_set, batch_size, shuffle=False)
return train_loader, val_loader, test_loader
# 重新加载插补后的数据
imputed_physionet2012 = pickle_load('result_saving/imputed_physionet2012.pkl')
train_X, val_X, test_X = imputed_physionet2012['train_set_imputation'], imputed_physionet2012['val_set_imputation'], imputed_physionet2012['test_set_imputation']
# 这里我们只需要原数据集中相应样本的标签
train_y, val_y, test_y = imputed_physionet2012['train_set_labels'], imputed_physionet2012['val_set_labels'], imputed_physionet2012['test_set_labels']
# 转换成torch dataloader
train_loader, val_loader, test_loader = get_dataloaders(
train_X,
train_y,
val_X,
val_y,
test_X,
test_y,
)
4.模型初始化与训练
# 初始化LSTM分类器
rnn_classifier = ClassificationLSTM(
n_features=37,
rnn_hidden_size=128,
n_classes=2, # physionet2012是一个二分类数据集
)
# 训练LSTM分类器
proba_predictions = train(rnn_classifier, train_loader, val_loader, test_loader)
五、结果分析
在Physionet2012医疗数据集上的性能评估如下:
from pypots.nn.functional.classification import calc_binary_classification_metrics
pos_num = test_y.sum()
neg_num = len(test_y) - test_y.sum()
print(f'test_set中的正负样本比例为{pos_num}:{neg_num}, 正样本占样本数量的{pos_num/len(test_y)}, 所以这是一个不平衡的二分类问题, 故我们在此使用ROC-AUC和PR-AUC作为评价指标\n')
classification_metrics=calc_binary_classification_metrics(
proba_predictions, test_y
)
print(f"LSTM在测试集上的ROC-AUC为: {classification_metrics['roc_auc']:.4f}\n")
print(f"LSTM在测试集上的PR-AUC为: {classification_metrics['pr_auc']:.4f}\n")
由于医疗数据中的类别不平衡问题,我们采用了ROC-AUC和PR-AUC作为评价指标:
指标 | 优势 | 适用场景 |
---|---|---|
ROC-AUC | 考虑所有阈值下的真阳性率与假阳性率权衡 | 综合评估模型区分能力 |
PR-AUC | 关注正样本的精确率与召回率权衡 | 正样本稀少的不平衡分类 |
六、LSTM的变体与拓展
Staudemeyer和Morris在论文中还提到了LSTM的多种变体:
总结
LSTM以其精巧的门控机制解决了传统RNN难以处理长期依赖的问题,本文通过代码实例展示了LSTM在处理含缺失值时序数据中的应用。
然而,随着深度学习的发展,序列建模领域正在快速进化:
尽管新型模型层出不穷,LSTM因其结构简洁、训练高效以及在特定应用中的出色表现,仍然是时序建模的重要工具。
参考资料
- Ralf C. Staudemeyer, Eric Rothstein Morris. (2019). Understanding LSTM – a tutorial into Long Short-Term Memory Recurrent Neural Networks. arXiv:1909.09586.
- Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. Neural computation, 9(8), 1735-1780.
- Graves, A. (2013). Generating sequences with recurrent neural networks. arXiv:1308.0850.
- GISerLiu的专栏文章
如果觉得我的文章对您有帮助,三连+关注便是对我创作的最大鼓励!或者一个star🌟也可以😂.