问题描述:设置的epoch=100,batchsize=128
网络是resnet,输入输出一维信号
在第一个epoch,前几个输出都正常值,后面开始全是nan
解决办法(顺序不是步骤,是挨个尝试的):
降低学习率,我降了,没用
我发现出现nan的地方尽管都在第一个epoch里,但是每次训练出具体地方不一样。猜测是label或者input有问题,但是如果是这样的话应该是在某一个点处出现nan。发现是DataLoader里shuffle=true,我把它改成了false
此时,总是固定的地方出现nan。我用了一个判断
res=torch.isnan(inputs).any()
print(res)
True就是有nan,反之没有
找到了发现input存在nan, 我也不知道为啥存在,我直接把这一小块input删了,反正几万个数据不差这一个