# ResNet.py — builds the ResNet network (构建ResNet网络)
from torch import nn
def conv3x3(in_channels, out_channnels, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with one pixel of zero padding.

    Args:
        in_channels: number of input channels.
        out_channnels: number of output channels (the misspelled name is kept
            so existing keyword callers keep working).
        stride: convolution stride; with stride 1 the spatial size is kept.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channnels,
        kernel_size=(3, 3),
        stride=stride,
        padding=(1, 1),  # one pixel each side preserves H and W for a 3x3 kernel at stride 1
        bias=False,  # every use in this file is followed by BatchNorm2d, whose shift makes a bias redundant
    )
    return conv
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block computing ``relu(F(x) + shortcut(x))``.

    ``F`` is conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm. The shortcut
    is the identity unless a ``downsample`` module is supplied, in which case
    it is applied to the block input.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module mapping the input onto ``F(x)``'s shape.
            It is needed whenever ``stride != 1`` or
            ``in_channels != out_channels``; otherwise the residual addition
            would see mismatched shapes.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        self.layer = nn.Sequential(
            conv3x3(in_channels, out_channels, stride),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            conv3x3(out_channels, out_channels),
            nn.BatchNorm2d(out_channels),
        )
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.layer(x)
        # Test identity against None rather than module truthiness: an
        # nn.Sequential defines __len__, so its truth value means "non-empty",
        # not "a shortcut module was supplied".
        residual = x if self.downsample is None else self.downsample(x)
        out += residual
        return self.relu(out)
class ResNet(nn.Module):
    """Three-stage ResNet with 16 -> 32 -> 64 channels and a linear classifier.

    Args:
        block: residual block class. It is called as
            ``block(in_channels, out_channels, stride, downsample)`` for the
            first block of a stage and ``block(out_channels, out_channels)``
            for the remaining blocks.
        layers: number of blocks per stage; stage ``i`` uses ``layers[i]``,
            so at least three entries are required (extra entries are ignored).
        num_classes: size of the output logits.

    The stem convolution expects 3-channel input. The fixed ``AvgPool2d(8)``
    plus ``Linear(64, ...)`` head assumes 32x32 images: the stem keeps 32x32,
    and the two stride-2 stages reduce it to the 8x8 map that is pooled
    to 1x1.
    """

    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 16  # channel count fed into the next make_layer call
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        # Each stage takes its own depth from ``layers``.
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2)
        self.layer3 = self.make_layer(block, 64, layers[2], 2)
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of residual blocks producing ``out_channels``.

        Only the first block may change stride or width. It receives a
        projection shortcut (3x3 conv + BatchNorm) whenever ``stride != 1`` or
        the incoming channel count differs from ``out_channels``. The method
        updates ``self.in_channels`` so the next stage chains on. A ``blocks``
        value below 1 still yields that first block.
        """
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
        layers = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels
        layers.extend(block(out_channels, out_channels) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.relu(self.bn(self.conv(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)  # flatten to (batch, features) for the linear head
        return self.fc(out)
# main.py — main entry point (主函数)
import torch
from torch import nn,optim
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import ResNet
# Hyper-parameters.
num_classes = 10      # CIFAR-10 has ten target classes
batch_size = 100      # mini-batch size
learning_rate = 1e-3  # Adam step size
num_epochs = 80       # number of training epochs
def _cifar10_loader(train, batch_size, shuffle):
    """Return a DataLoader over the CIFAR-10 train (``train=True``) or test split.

    The dataset lives under ``./cifar`` and is downloaded if missing. Images
    are resized to 32x32 and converted with ``ToTensor`` (no normalization).
    """
    dataset = datasets.CIFAR10(
        'cifar',
        train,
        transform=transforms.Compose([
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
        ]),
        download=True,
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def _evaluate(model, loader, device):
    """Return the fraction of samples in ``loader`` whose argmax logit matches the label."""
    model.eval()
    total_correct = 0
    total_num = 0
    with torch.no_grad():
        for x, label in loader:
            x, label = x.to(device), label.to(device)
            pred = model(x).argmax(dim=1)
            total_correct += torch.eq(pred, label).float().sum().item()
            total_num += x.size(0)
    return total_correct / total_num


def main():
    """Train a ResNet on CIFAR-10 and print test accuracy after every epoch.

    Reads the module-level ``num_classes``, ``batch_size``, ``learning_rate``
    and ``num_epochs`` settings. Runs on CUDA when available, otherwise on CPU.
    """
    cifar_train = _cifar10_loader(True, batch_size, shuffle=True)
    # Accuracy does not depend on evaluation order, so the test set is not shuffled.
    cifar_test = _cifar10_loader(False, batch_size, shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The model has three stages; two residual blocks in each.
    model = ResNet.ResNet(ResNet.ResidualBlock, [2, 2, 2], num_classes=num_classes).to(device)
    criteon = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        model.train()
        for x, label in cifar_train:
            x, label = x.to(device), label.to(device)
            logits = model(x)
            loss = criteon(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        acc = _evaluate(model, cifar_test, device)
        # ``loss`` is the final mini-batch's loss of this epoch, not an epoch average.
        print('epoch:{:d}, loss:{:.4f}, acc:{:.2f}%'.format(epoch, loss.item(), acc * 100))


if __name__ == '__main__':
    main()