'''
The ImageNet-1k dataset is mounted as a disk in the training image, so it can be
read directly, as in the example code below.
Mount path:
.
└── cache/
    ├── ascend
    ├── outputs
    ├── user-job-dir
    └── sfs/
        └── data/
            └── imagenet/
                ├── train/
                │   └── n01440764/
                │       └── n01440764_11063.JPEG
                └── val/
                    └── n01440764/
                        └── ILSVRC2012_val_00011993.JPEG
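
A minimal sketch, assuming the mount layout above, of how a script might resolve
and sanity-check the train/ and val/ directories before building a pipeline.
`imagenet_split_dirs` is a hypothetical helper, not part of MindSpore or OpenI.

```python
import os


def imagenet_split_dirs(root="/cache/sfs/data/imagenet"):
    # Hypothetical helper: resolve train/ and val/ under the mount root and
    # fail early with a clear message if either directory is missing.
    dirs = {}
    for split in ("train", "val"):
        path = os.path.join(root, split)
        if not os.path.isdir(path):
            raise FileNotFoundError("missing ImageNet split directory: " + path)
        dirs[split] = path
    return dirs
```

Failing here gives a clearer error than an empty or failing dataset later on
when the mount is absent.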
mindspore.dataset.ImageFolderDataset
- Reads ImageNet-1k; all images in the same folder belong to the same class.
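
The folder-to-class rule can be sketched in plain Python. This assumes the
default labelling described for ImageFolderDataset's `class_indexing` argument
(folder names sorted alphabetically, indices starting from 0); the helpers
`folder_class_indexing` and `list_labelled_images` are hypothetical and are not
MindSpore's implementation.

```python
import os


def folder_class_indexing(split_dir):
    # Every immediate sub-directory (e.g. "n01440764") is one class.
    classes = sorted(
        name for name in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, name))
    )
    # Alphabetical order, indices starting from 0.
    return {name: index for index, name in enumerate(classes)}


def list_labelled_images(split_dir):
    # Pair every file with the index of the folder that contains it.
    # No file-extension filtering or image validation is done here.
    samples = []
    for name, label in folder_class_indexing(split_dir).items():
        folder = os.path.join(split_dir, name)
        for filename in sorted(os.listdir(folder)):
            samples.append((os.path.join(folder, filename), label))
    return samples
```

The returned dict has the shape expected by `class_indexing`, so passing it
explicitly is one way to keep train and val labels aligned.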
mindspore.dataset.vision.c_transforms
- Image decoding, data augmentation, and preprocessing operations.
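
A NumPy sketch of what `Normalize` followed by `HWC2CHW` computes, assuming
decoded uint8 images in HWC layout. It shows why the ImageNet mean/std below are
multiplied by 255; it is an illustration, not MindSpore's implementation, and
`normalize_hwc_to_chw` is a hypothetical name.

```python
import numpy as np

# ImageNet channel statistics on a 0-1 scale, rescaled to the 0-255 range of
# decoded uint8 pixels (the same values the script uses).
MEAN = np.array([0.485, 0.456, 0.406]) * 255
STD = np.array([0.229, 0.224, 0.225]) * 255


def normalize_hwc_to_chw(image_hwc):
    # Per-channel (x - mean) / std, broadcasting over the last (channel) axis.
    normalized = (image_hwc.astype(np.float32) - MEAN) / STD
    # HWC -> CHW: move the channel axis to the front.
    return normalized.transpose(2, 0, 1).astype(np.float32)
```

A pixel equal to the channel mean maps to 0, and a white pixel (255) in the
first channel maps to (1 - 0.485) / 0.229.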
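ImageFolderDataset methods (inherited from mindspore.dataset.Dataset)
- map: applies a list of data-augmentation operations to the dataset, in order.
- batch: combines batch_size consecutive rows of the dataset into one batch.
- to_json: serializes the data-processing pipeline to a JSON string and, if a
  filename is given, also writes it to that file.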
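
The map and batch semantics can be sketched on plain Python lists. This is a
simplified model under that assumption (no parallelism, columns, or shuffling),
and `apply_in_order` and `batch_rows` are hypothetical names, not MindSpore APIs.

```python
def apply_in_order(sample, operations):
    # map(): each operation receives the output of the previous one.
    for op in operations:
        sample = op(sample)
    return sample


def batch_rows(rows, batch_size, drop_remainder=False):
    # batch(): group consecutive rows. With drop_remainder=True a final short
    # batch is discarded, so every batch has exactly batch_size rows.
    batches = [rows[i:i + batch_size] for i in range(0, len(rows), batch_size)]
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches
```

For example, apply_in_order(2, [lambda x: x + 1, lambda x: x * 10]) returns 30,
and with batch_size=16 and drop_remainder=True, 35 rows give two batches of 16
while the last 3 rows are dropped.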
'''
import os
import argparse
import moxing as mox
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision.c_transforms as transforms
from openi import env_to_openi
parser = argparse.ArgumentParser(description='Read big dataset ImageNet Example')
parser.add_argument('--train_url',
                    help='output folder to save/load',
                    default='/cache/output/')
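
if __name__ == "__main__":
    # parse_known_args ignores unrecognized arguments (e.g. ones added by the platform).
    args, unknown = parser.parse_known_args()
    # Mount point of ImageNet-1k (see the tree in the module docstring).
    data_path = '/cache/sfs/data/imagenet/'
    # Local output directory for the serialized pipeline.
    modelart_output = '/cache/output'
    os.makedirs(modelart_output, exist_ok=True)

    # ImageNet channel mean/std, scaled to the 0-255 range of decoded uint8 images.
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # Each sub-folder of train/ is one class; images are read as undecoded bytes.
    dataset_train = ImageFolderDataset(os.path.join(data_path, "train"),
                                       shuffle=True)
    trans_train = [
        # Decode, randomly crop 8%-100% of the area with aspect ratio 0.75-1.333,
        # then resize to 224x224.
        transforms.RandomCropDecodeResize(size=224,
                                          scale=(0.08, 1.0),
                                          ratio=(0.75, 1.333)),
        transforms.RandomHorizontalFlip(prob=0.5),
        transforms.Normalize(mean=mean, std=std),
        # HWC -> CHW layout for the network input.
        transforms.HWC2CHW()
    ]
    # Apply the operations above, in order, to the "image" column.
    dataset_train = dataset_train.map(operations=trans_train,
                                      input_columns=["image"])
    # Group 16 consecutive samples per batch and drop the final incomplete batch.
    dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)
    # Serialize the pipeline to a JSON string and also write it to data_info.json.
    data_info = dataset_train.to_json(filename=os.path.join(modelart_output, 'data_info.json'))
    print(data_info)
    # Hand the local outputs over to OpenI (destination: --train_url).
    env_to_openi(modelart_output, args.train_url)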