直接从github上面下下来的代码,配置好环境之后运行就报错了。
P1是报错
P2是源码,在/libcity/model/traffic_flow_prediction/PDFormer.py", line 439,在PDFormer类的forward方法中
之前有次在循环前面加了下面的几行扩展self.pattern_keys维度,是可以运行了,
但是当用的是PEMS04/07/08数据集的时候,实验结果有很大差距;数据集是NYCTaxi,CHIBike,T-Drive的时候实验结果是正常的
#添加的代码
if self.pattern_keys.shape[-1] == 1:
self.pattern_keys = self.pattern_keys.expand(-1, -1, self.output_dim)
#源码
x_pattern_list = []
pattern_key_list = []
for i in range(self.output_dim):
x_pattern_list.append(self.pattern_embeddings[i](x_patterns[..., i]).unsqueeze(-1))
pattern_key_list.append(self.pattern_embeddings[i](self.pattern_keys[..., i]).unsqueeze(-1))
有没有朋友可以帮帮忙~