MNIST_Example/train.py

169 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
######################## train lenet example ########################
Train LeNet and save the network model files (.ckpt).
"""
"""
######################## 训练环境使用说明 ########################
假设已经使用Ascend NPU调试环境调试完代码，欲将调试环境的代码迁移到训练环境进行训练，需要做以下工作：
1、调试环境的镜像和训练环境的镜像是两个不同的镜像，所处的运行目录不一致，需要将data_url和train_url的路径进行变换。
在调试环境中:
args.data_url = '/home/ma-user/work/data/'  # 数据集位置
args.train_url = '/home/ma-user/work/model/'  # 训练输出的模型位置
在训练环境变换为:
args.data_url = '/home/work/user-job-dir/data/'
args.train_url = '/home/work/user-job-dir/model/'
2、在训练环境中，需要将数据集从obs拷贝到训练镜像中；训练完以后，需要将输出的模型拷贝到obs。
将数据集从obs拷贝到训练镜像中
obs_data_url = args.data_url
args.data_url = '/home/work/user-job-dir/data/'
if not os.path.exists(args.data_url):
    os.mkdir(args.data_url)
try:
    mox.file.copy_parallel(obs_data_url, args.data_url)
    print("Successfully Download {} to {}".format(obs_data_url,
                                                  args.data_url))
except Exception as e:
    print('moxing download {} to {} failed: '.format(
        obs_data_url, args.data_url) + str(e))
将输出的模型拷贝到obs
obs_train_url = args.train_url
args.train_url = '/home/work/user-job-dir/model/'
if not os.path.exists(args.train_url):
    os.mkdir(args.train_url)
try:
    mox.file.copy_parallel(args.train_url, obs_train_url)
    print("Successfully Upload {} to {}".format(args.train_url,
                                                obs_train_url))
except Exception as e:
    print('moxing upload {} to {} failed: '.format(args.train_url,
                                                   obs_train_url) + str(e))
"""
import os
import argparse
import moxing as mox
from config import mnist_cfg as cfg
from dataset import create_dataset
from lenet import LeNet5
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.common import set_seed
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
# Two parameters for running on ModelArts: data_url and train_url are the fixed
# parameter names the platform uses to pass the dataset path and the output
# model path, respectively.
parser.add_argument('--data_url',
                    help='path to training/inference dataset folder',
                    default='/home/work/user-job-dir/data/')
parser.add_argument('--train_url',
                    help='model folder to save/load',
                    default='/home/work/user-job-dir/model/')
# Target device; the default is Ascend (NPU). The Chinese part of the help text
# says: to use an NPU on the Qizhi (启智) platform, add the run parameter
# device_target=Ascend in the platform's training UI.
parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'CPU'],
    help='device where the code will be implemented (default: Ascend),若要在启智平台上使用NPU需要在启智平台训练界面上加上运行参数device_target=Ascend')
# ModelArts already supplies data_url and train_url by default.
# Number of training epochs; overrides the config value when non-zero.
parser.add_argument('--epoch_size',
                    type=int,
                    default=5,
                    help='Training epochs.')
# Set MindSpore's global random seed to a fixed value (1) so runs are repeatable.
set_seed(1)
def _obs_copy(src, dst, action):
    """Copy ``src`` to ``dst`` with ``mox.file.copy_parallel`` and report the outcome.

    Failures are printed rather than raised, matching the platform's template.

    Args:
        src: Source path (an OBS URL or a local directory).
        dst: Destination path.
        action: ``"download"`` or ``"upload"``; only used in the printed messages.
    """
    try:
        mox.file.copy_parallel(src, dst)
        print("Successfully {} {} to {}".format(action.capitalize(), src, dst))
    except Exception as e:  # best-effort by design of the platform template
        print('moxing {} {} to {} failed: '.format(action, src, dst) + str(e))


if __name__ == "__main__":
    args = parser.parse_args()

    # ---- Copy the dataset from OBS into the training image (fixed template) ----
    # Remember the OBS locations passed in via data_url / train_url, then point
    # args at fixed local directories in the training image. The data goes to
    # /home/work/user-job-dir/data/ here; this can be changed to another dir.
    obs_data_url = args.data_url
    args.data_url = '/home/work/user-job-dir/data/'
    # makedirs(exist_ok=True) also creates missing parents, unlike os.mkdir.
    os.makedirs(args.data_url, exist_ok=True)
    # Local directory where the trained model (checkpoints) is written.
    obs_train_url = args.train_url
    args.train_url = '/home/work/user-job-dir/model/'
    os.makedirs(args.train_url, exist_ok=True)
    _obs_copy(obs_data_url, args.data_url, "download")

    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args.device_target)

    # ---- Dataset ----
    ds_train = create_dataset(os.path.join(args.data_url, "train"),
                              cfg.batch_size)
    dataset_size = ds_train.get_dataset_size()
    if dataset_size == 0:
        raise ValueError(
            "Please check dataset size > 0 and batch_size <= dataset size")

    # ---- Network, loss, optimizer and model ----
    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    time_cb = TimeMonitor(data_size=dataset_size)
    # amp_level="O2" is only requested on Ascend; other targets keep Model's
    # default amp_level (the kwarg is omitted entirely, as before).
    model_kwargs = {"amp_level": "O2"} if args.device_target == "Ascend" else {}
    model = Model(network,
                  net_loss,
                  net_opt,
                  metrics={"accuracy": Accuracy()},
                  **model_kwargs)

    # ---- Checkpointing: write checkpoints to the local model directory ----
    config_ck = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_steps,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
                                 directory=args.train_url,
                                 config=config_ck)

    # ---- Training ----
    print("============== Starting Training ==============")
    # --epoch_size wins when it is non-zero. NOTE(review): its argparse default
    # is 5, so cfg.epoch_size is only used when --epoch_size 0 is passed.
    epoch_size = args.epoch_size if args.epoch_size else cfg.epoch_size
    print('epoch_size is: ', epoch_size)
    model.train(epoch_size,
                ds_train,
                callbacks=[time_cb, ckpoint_cb, LossMonitor()])

    # ---- Copy the trained model back to OBS (fixed template) ----
    # The platform offers the uploaded files for download from the training job.
    _obs_copy(args.train_url, obs_train_url, "upload")