问题描述
import torch
import tensorwatch as tw
resnet18 = models.resnet18(pretrained=True)
tw.draw_model(resnet18, (1, 3, 224, 224))
使用tensorwatch展示网络结构时报错。
问题原因
当前环境中Pytorch版本为1.8。
高于1.6的版本中,模块’torch.onnx’去掉了 'set_training’这个属性。
解决办法
将Pytorch降到1.6以下。
Conda 安装 v1.5.1
# CUDA 9.2
conda install pytorch==1.5.1 torchvision==0.6.1 cudatoolkit=9.2 -c pytorch
# CUDA 10.1
conda install pytorch==1.5.1 torchvision==0.6.1 cudatoolkit=10.1 -c pytorch
# CUDA 10.2
conda install pytorch==1.5.1 torchvision==0.6.1 cudatoolkit=10.2 -c pytorch
# CPU Only
conda install pytorch==1.5.1 torchvision==0.6.1 cpuonly -c pytorch
pip 安装 v1.5.1
# CUDA 10.2
pip install torch==1.5.1 torchvision==0.6.1
# CUDA 10.1
pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
# CUDA 9.2
pip install torch==1.5.1+cu92 torchvision==0.6.1+cu92 -f https://download.pytorch.org/whl/torch_stable.html
# CPU only
pip install torch==1.5.1+cpu torchvision==0.6.1+cpu -f https://download.pytorch.org/whl/torch_stable.html