From e32fe199c450b934db17b4cff0442de01e5222bb Mon Sep 17 00:00:00 2001 From: Charles Jacobs Date: Fri, 12 Oct 2018 15:07:01 -0700 Subject: [PATCH] ... --- beginner_source/blitz/cifar10_tutorial.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/beginner_source/blitz/cifar10_tutorial.py b/beginner_source/blitz/cifar10_tutorial.py index e0cda390657..d21243222cf 100644 --- a/beginner_source/blitz/cifar10_tutorial.py +++ b/beginner_source/blitz/cifar10_tutorial.py @@ -55,10 +55,16 @@ Using ``torchvision``, it’s extremely easy to load CIFAR10. """ + +import os + import torch import torchvision import torchvision.transforms as transforms +DATASET_DIR = "/Users/cjacobs/Work/ID/Datasets/pytorch" +CIFAR_DATASET_DIR = os.path.join(DATASET_DIR, "cifar-10") + ######################################################################## # The output of torchvision datasets are PILImage images of range [0, 1]. # We transform them to Tensors of normalized range [-1, 1]. @@ -67,12 +73,12 @@ [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) -trainset = torchvision.datasets.CIFAR10(root='./data', train=True, +trainset = torchvision.datasets.CIFAR10(root=CIFAR_DATASET_DIR, train=True, download=True, transform=transform) trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2) -testset = torchvision.datasets.CIFAR10(root='./data', train=False, +testset = torchvision.datasets.CIFAR10(root=CIFAR_DATASET_DIR, train=False, download=True, transform=transform) testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2) @@ -155,7 +161,8 @@ def forward(self, x): # We simply have to loop over our data iterator, and feed the inputs to the # network and optimize. -for epoch in range(2): # loop over the dataset multiple times +NUM_EPOCHS = 10 +for epoch in range(NUM_EPOCHS): # loop over the dataset multiple times running_loss = 0.0 for i, data in enumerate(trainloader, 0): @@ -259,6 +266,10 @@ def forward(self, x): print('Accuracy of %5s : %2d %%' % ( classes[i], 100 * class_correct[i] / class_total[i])) +input_tensor = trainset[0][0] +torch.onnx._export(net, input_tensor, './cifar-10.onnx', export_params=True, verbose=True) + + ######################################################################## # Okay, so what next? #