169 lines
7.2 KiB
Python
169 lines
7.2 KiB
Python
"""
######################## train lenet example ########################
train lenet and get network model files(.ckpt)
"""
|
||
"""
######################## Training environment usage notes ########################
Assume the code has already been debugged in the Ascend NPU debugging
environment and is now to be migrated to the training environment for
training. The following work is required:

1. The debugging image and the training image are two different images with
   different working directories, so the data_url and train_url paths must be
   changed.
   In the debugging environment:
       args.data_url = '/home/ma-user/work/data/'    # dataset location
       args.train_url = '/home/ma-user/work/model/'  # location of the trained model output
   In the training environment they become:
       args.data_url = '/home/work/user-job-dir/data/'
       args.train_url = '/home/work/user-job-dir/model/'

2. In the training environment the dataset must be copied from OBS into the
   training image, and after training the output model must be copied back
   to OBS.
   Copy the dataset from OBS into the training image:

       obs_data_url = args.data_url
       args.data_url = '/home/work/user-job-dir/data/'
       if not os.path.exists(args.data_url):
           os.mkdir(args.data_url)
       try:
           mox.file.copy_parallel(obs_data_url, args.data_url)
           print("Successfully Download {} to {}".format(obs_data_url,
                                                         args.data_url))
       except Exception as e:
           print('moxing download {} to {} failed: '.format(
               obs_data_url, args.data_url) + str(e))

   Copy the output model to OBS:

       obs_train_url = args.train_url
       args.train_url = '/home/work/user-job-dir/model/'
       if not os.path.exists(args.train_url):
           os.mkdir(args.train_url)
       try:
           mox.file.copy_parallel(args.train_url, obs_train_url)
           print("Successfully Upload {} to {}".format(args.train_url,
                                                       obs_train_url))
       except Exception as e:
           print('moxing upload {} to {} failed: '.format(args.train_url,
                                                          obs_train_url) + str(e))
"""
|
||
|
||
import os
|
||
import argparse
|
||
import moxing as mox
|
||
from config import mnist_cfg as cfg
|
||
from dataset import create_dataset
|
||
from lenet import LeNet5
|
||
import mindspore.nn as nn
|
||
from mindspore import context
|
||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||
from mindspore.train import Model
|
||
from mindspore.nn.metrics import Accuracy
|
||
from mindspore.common import set_seed
|
||
|
||
# Command-line interface for the LeNet training script.
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')

# Define the two parameters used when running on ModelArts.
# data_url and train_url are the fixed parameter names used for training on
# ModelArts; they carry the dataset path and the output-model path.
parser.add_argument('--data_url',
                    help='path to training/inference dataset folder',
                    default='/home/work/user-job-dir/data/')

parser.add_argument('--train_url',
                    help='model folder to save/load',
                    default='/home/work/user-job-dir/model/')

# NOTE: the help text previously advertised "default: CPU" although the
# actual default below is "Ascend"; the help text now matches the default.
parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'CPU'],
    help='device where the code will be implemented (default: Ascend),若要在启智平台上使用NPU,需要在启智平台训练界面上加上运行参数device_target=Ascend')

# ModelArts already supplies data_url and train_url by default; epoch_size is
# an additional, script-specific option.
parser.add_argument('--epoch_size',
                    type=int,
                    default=5,
                    help='Training epochs.')
|
||
|
||
# Fix the global random seed so that training runs are reproducible.
set_seed(1)

if __name__ == "__main__":
    # Script entry point: stage the dataset from OBS into the training image,
    # train LeNet5, then upload the resulting checkpoints back to OBS.
    args = parser.parse_args()

    ######################## copy the dataset from OBS into the training image (fixed pattern) ########################
    # In the training environment data_url/train_url arrive as OBS locations.
    # Remember them, then repoint both arguments at fixed local directories.
    # The local directories below may be changed to other paths.
    obs_data_url = args.data_url
    args.data_url = '/home/work/user-job-dir/data/'
    obs_train_url = args.train_url
    args.train_url = '/home/work/user-job-dir/model/'

    # Create the local dataset and model directories. makedirs(exist_ok=True)
    # also creates missing parent directories and does not fail when the
    # directory already exists (the previous exists()+mkdir() pair raised if
    # the parent was missing and was open to a check-then-create race).
    os.makedirs(args.data_url, exist_ok=True)
    os.makedirs(args.train_url, exist_ok=True)

    # Copy the dataset into the training environment. This is deliberately
    # best-effort: a failure is reported and the script continues.
    try:
        mox.file.copy_parallel(obs_data_url, args.data_url)
    except Exception as e:
        print('moxing download {} to {} failed: '.format(
            obs_data_url, args.data_url) + str(e))
    else:
        print("Successfully Download {} to {}".format(obs_data_url,
                                                      args.data_url))
    ######################## end of dataset copy ########################

    # Point dataset_path at data_url (checkpoints are written under train_url).
    args.dataset_path = args.data_url

    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args.device_target)

    # Create the training dataset.
    ds_train = create_dataset(os.path.join(args.data_url, "train"),
                              cfg.batch_size)
    steps_per_epoch = ds_train.get_dataset_size()
    if steps_per_epoch == 0:
        raise ValueError(
            "Please check dataset size > 0 and batch_size <= dataset size")

    # Create the network, loss function and optimizer.
    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    time_cb = TimeMonitor(data_size=steps_per_epoch)

    # amp_level="O2" is requested only on Ascend; on any other target the
    # argument is omitted so Model keeps its own default.
    model_kwargs = {"metrics": {"accuracy": Accuracy()}}
    if args.device_target == "Ascend":
        model_kwargs["amp_level"] = "O2"
    model = Model(network, net_loss, net_opt, **model_kwargs)

    config_ck = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_steps,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    # Checkpoints are written to the local model directory.
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
                                 directory=args.train_url,
                                 config=config_ck)

    # Start training.
    print("============== Starting Training ==============")
    # The command-line value wins whenever it is non-zero; the config value
    # is only used as a fallback.
    epoch_size = cfg['epoch_size']
    if args.epoch_size:
        epoch_size = args.epoch_size
        print('epoch_size is: ', epoch_size)

    model.train(epoch_size,
                ds_train,
                callbacks=[time_cb, ckpoint_cb, LossMonitor()])

    ######################## copy the output model to OBS (fixed pattern) ########################
    # Copy the trained model files from the local runtime environment back to
    # OBS; the platform offers them for download from the training task.
    try:
        mox.file.copy_parallel(args.train_url, obs_train_url)
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(args.train_url,
                                                       obs_train_url) + str(e))
    else:
        print("Successfully Upload {} to {}".format(args.train_url,
                                                    obs_train_url))
    ######################## end of model upload ########################
|