更新训练代码注释

This commit is contained in:
liuzx 2022-03-09 15:55:47 +08:00
parent b33255d6ff
commit 10cbce7aab
1 changed files with 32 additions and 25 deletions

View File

@ -14,15 +14,24 @@ args.data_url = '/home/work/user-job-dir/inputs/data/'
args.train_url = '/home/work/user-job-dir/outputs/model/'
2在训练环境中需要将数据集从obs拷贝到训练镜像中训练完以后需要将输出的模型拷贝到obs.
将数据集从obs拷贝到训练镜像中
try:
mox.file.copy_parallel(obs_data_url, args.data_url)
print("Successfully Download {} to {}".format(obs_data_url,
args.data_url))
except Exception as e:
print('moxing download {} to {} failed: '.format(
obs_data_url, args.data_url) + str(e))
obs_data_url = args.data_url
args.data_url = '/home/work/user-job-dir/data/'
if not os.path.exists(args.data_url):
os.mkdir(args.data_url)
try:
mox.file.copy_parallel(obs_data_url, args.data_url)
print("Successfully Download {} to {}".format(obs_data_url,
args.data_url))
except Exception as e:
print('moxing download {} to {} failed: '.format(
obs_data_url, args.data_url) + str(e))
将输出的模型拷贝到obs
obs_train_url = args.train_url
args.train_url = '/home/work/user-job-dir/model/'
if not os.path.exists(args.train_url):
os.mkdir(args.train_url)
try:
mox.file.copy_parallel(args.train_url, obs_train_url)
print("Successfully Upload {} to {}".format(args.train_url,
@ -30,6 +39,7 @@ try:
except Exception as e:
print('moxing upload {} to {} failed: '.format(args.train_url,
obs_train_url) + str(e))
"""
import os
@ -51,11 +61,11 @@ parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
# data_url,train_url是固定用于在modelarts上训练的参数表示数据集的路径和输出模型的路径
parser.add_argument('--data_url',
help='path to training/inference dataset folder',
default='./data')
default='/home/work/user-job-dir/data/')
parser.add_argument('--train_url',
help='model folder to save/load',
default='./model')
default='/home/work/user-job-dir/model/')
parser.add_argument(
'--device_target',
@ -64,16 +74,7 @@ parser.add_argument(
choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)')
#用户可自定义的参数dataset_path在示例中指向data_urlsave_checkpoint_path指向train_url在添加超参数时这两个参数可不添加modelarts已经默认使用data_url和train_url
parser.add_argument('--dataset_path',
type=str,
default="./Data",
help='path where the dataset is saved')
parser.add_argument('--save_checkpoint_path',
type=str,
default="./ckpt",
help='if is test, must provide\
path where the trained ckpt file')
#modelarts已经默认使用data_url和train_url
parser.add_argument('--epoch_size',
type=int,
default=5,
@ -86,11 +87,18 @@ if __name__ == "__main__":
args = parser.parse_args()
######################## 将数据集从obs拷贝到训练镜像中 (固定写法)########################
# 在训练环境中定义data_url和train_url并把数据从obs拷贝到相应的固定路径
# 在训练环境中定义data_url和train_url并把数据从obs拷贝到相应的固定路径以下写法是将数据拷贝到/home/work/user-job-dir/data/目录下,可修改为其他目录
#创建数据存放的位置
obs_data_url = args.data_url
args.data_url = '/home/work/user-job-dir/inputs/data/'
args.data_url = '/home/work/user-job-dir/data/'
if not os.path.exists(args.data_url):
os.mkdir(args.data_url)
#创建模型存放的位置
obs_train_url = args.train_url
args.train_url = '/home/work/user-job-dir/outputs/model/'
args.train_url = '/home/work/user-job-dir/model/'
if not os.path.exists(args.train_url):
os.mkdir(args.train_url)
#将数据拷贝到训练环境
try:
mox.file.copy_parallel(obs_data_url, args.data_url)
print("Successfully Download {} to {}".format(obs_data_url,
@ -102,12 +110,11 @@ if __name__ == "__main__":
#将dataset_path指向data_urlsave_checkpoint_path指向train_url
args.dataset_path = args.data_url
args.save_checkpoint_path = args.train_url
context.set_context(mode=context.GRAPH_MODE,
device_target=args.device_target)
#创建数据集
ds_train = create_dataset(os.path.join(args.dataset_path, "train"),
ds_train = create_dataset(os.path.join(args.data_url, "train"),
cfg.batch_size)
if ds_train.get_dataset_size() == 0:
raise ValueError(
@ -135,7 +142,7 @@ if __name__ == "__main__":
keep_checkpoint_max=cfg.keep_checkpoint_max)
#定义模型输出路径
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
directory=args.save_checkpoint_path,
directory=args.train_url,
config=config_ck)
#开始训练
print("============== Starting Training ==============")