import os
import torchvision
import torchvision.transforms as transforms
# 下载MNIST数据集
transform = transforms.Compose([transforms.ToTensor()])
mnist_trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 下载MNIST数据集
# 创建一个目录来保存图像(如果不存在的话)
os.makedirs('./mnist_images/train', exist_ok=True)
os.makedirs('./mnist_images/test', exist_ok=True)
# 遍历数据集,保存图像
for idx,(image,label) in enumerate(mnist_trainset):
# 创建类别文件(如果不存在的话)
label_dir = os.path.join('./mnist_images/train', str(label))
os.makedirs(label_dir, exist_ok=True)
# 转换为PIL格式,保存图像
pil_image = transforms.ToPILImage()(image)
pil_image.save(os.path.join(label_dir, f'{idx}.png'))
#遍历数据集,保存图像
for idx,(image,label) in enumerate(mnist_testset):
# 创建类别文件(如果不存在的话)
label_dir = os.path.join('./mnist_images/test', str(label))
os.makedirs(label_dir, exist_ok=True)
# 转换为PIL格式,保存图像
pil_image = transforms.ToPILImage()(image)
pil_image.save(os.path.join(label_dir, f'{idx}.png'))
# 保存完毕
# 打印完成信息
print('MNIST images saved to ./mnist_images')```
MNIST保存为本地图片
最新推荐文章于 2025-05-31 16:28:19 发布