七月和奶茶 2023-09-10 15:13 采纳率: 14.3%
浏览 7

pytorch图片分类

在进行训练时,使用torch.cat((x,t),dim=1)拼接通道,x的尺寸大小为torch.size([32,2048,7,7]),t为torch.size([32,384,7,7]),在第一个epoch训练结束时遇到了以下错误:RuntimeError: Given groups=1, weight of size [2048, 384, 1, 1], expected input[32, 324, 7, 7] to have 384 channels, but got 324 channels instead。请问该如何解决呢,谢谢?

  • 写回答

4条回答 默认 最新

  • 专家-赤兔[在线] 全栈领域优质创作者 2023-09-10 16:26
    关注

    引用 皆我百晓生 小程序回复内容作答:

    根据错误提示,可以看出输入的通道数不匹配。在torch.cat((x,t),dim=1)中,dim=1表示在通道维度上进行拼接。根据你提供的信息,x的通道数为2048,t的通道数为384。因此,你需要确保t的通道数与x的通道数相同,即384。可以通过调整t的通道数来解决这个问题,例如使用torch.nn.Conv2d()函数将t的通道数调整为384。

    评论

报告相同问题?

问题事件

  • 创建了问题 9月10日