陈伟@航天科技智慧城市研究院 chenwei@ascs.tech
本项目通过Tensorflow对Cifar10数据集进行读写操作。
包括完整代码和详细注释。
Cifar10数据集介绍
- 由60000个图片组成
- 6万个图片中,5万张用于训练,1万张用于测试
- 每个图片是32x32像素
- 所有图片可以分成10类
- 每个图片都有一个标签,标记属于哪一个类
- 测试集中一个类对应1000张图
- 训练集中将5万张图分为5份
- 类之间的图片是互斥的,不存在类别重叠的情况
Cifar10数据集分类
Cifar10数据集下载
将Cifar10数据转换成图片:convert_cifar10_image.py
"""
本脚本对cifar10数据进行解析,转换成图片,生成训练图片和测试图片。
"""
import urllib.request
import os
import sys
import tarfile
import glob
import pickle
import numpy as np
import cv2
# 通过这个函数完成对数据集的下载和解压
# tarball_url 表示cifar10数据集的下载链接
# dataset_dir 表示存储的路径
# 执行下面的代码可以完成数据集的下载和解压
# DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
# DATA_DIR = 'data'
# download_and_uncompress_tarball(DATA_URL, DATA_DIR)
def download_and_uncompress_tarball(tarball_url, dataset_dir):
"""Downloads the `tarball_url` and uncompresses it locally.
Args:
tarball_url: The URL of a tarball file.
dataset_dir: The directory where the temporary files are stored.
"""
# tarball_url='http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
filename = tarball_url.split('/')[-1] # 文件名,通过/拆分字符串,取最后一节,也就是cifar-10-python.tar.gz
# dataset_dir = 'data'
# os.path.join()路径拼接,/data/cifar-10-python.tar.gz
filepath = os.path.join(dataset_dir, filename)
# 定义进度函数,分块下载