更新训练代码注释
This commit is contained in:
parent
b33255d6ff
commit
10cbce7aab
57
train.py
57
train.py
|
@ -14,15 +14,24 @@ args.data_url = '/home/work/user-job-dir/inputs/data/'
|
|||
args.train_url = '/home/work/user-job-dir/outputs/model/'
|
||||
2、在训练环境中,需要将数据集从obs拷贝到训练镜像中,训练完以后,需要将输出的模型拷贝到obs.
|
||||
将数据集从obs拷贝到训练镜像中:
|
||||
try:
|
||||
mox.file.copy_parallel(obs_data_url, args.data_url)
|
||||
print("Successfully Download {} to {}".format(obs_data_url,
|
||||
args.data_url))
|
||||
except Exception as e:
|
||||
print('moxing download {} to {} failed: '.format(
|
||||
obs_data_url, args.data_url) + str(e))
|
||||
|
||||
obs_data_url = args.data_url
|
||||
args.data_url = '/home/work/user-job-dir/data/'
|
||||
if not os.path.exists(args.data_url):
|
||||
os.mkdir(args.data_url)
|
||||
try:
|
||||
mox.file.copy_parallel(obs_data_url, args.data_url)
|
||||
print("Successfully Download {} to {}".format(obs_data_url,
|
||||
args.data_url))
|
||||
except Exception as e:
|
||||
print('moxing download {} to {} failed: '.format(
|
||||
obs_data_url, args.data_url) + str(e))
|
||||
|
||||
将输出的模型拷贝到obs:
|
||||
obs_train_url = args.train_url
|
||||
args.train_url = '/home/work/user-job-dir/model/'
|
||||
if not os.path.exists(args.train_url):
|
||||
os.mkdir(args.train_url)
|
||||
try:
|
||||
mox.file.copy_parallel(args.train_url, obs_train_url)
|
||||
print("Successfully Upload {} to {}".format(args.train_url,
|
||||
|
@ -30,6 +39,7 @@ try:
|
|||
except Exception as e:
|
||||
print('moxing upload {} to {} failed: '.format(args.train_url,
|
||||
obs_train_url) + str(e))
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
|
@ -51,11 +61,11 @@ parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
|
|||
# data_url,train_url是固定用于在modelarts上训练的参数,表示数据集的路径和输出模型的路径
|
||||
parser.add_argument('--data_url',
|
||||
help='path to training/inference dataset folder',
|
||||
default='./data')
|
||||
default='/home/work/user-job-dir/data/')
|
||||
|
||||
parser.add_argument('--train_url',
|
||||
help='model folder to save/load',
|
||||
default='./model')
|
||||
default='/home/work/user-job-dir/model/')
|
||||
|
||||
parser.add_argument(
|
||||
'--device_target',
|
||||
|
@ -64,16 +74,7 @@ parser.add_argument(
|
|||
choices=['Ascend', 'GPU', 'CPU'],
|
||||
help='device where the code will be implemented (default: Ascend)')
|
||||
|
||||
#用户可自定义的参数,dataset_path在示例中指向data_url,save_checkpoint_path指向train_url,在添加超参数时这两个参数可不添加,modelarts已经默认使用data_url和train_url
|
||||
parser.add_argument('--dataset_path',
|
||||
type=str,
|
||||
default="./Data",
|
||||
help='path where the dataset is saved')
|
||||
parser.add_argument('--save_checkpoint_path',
|
||||
type=str,
|
||||
default="./ckpt",
|
||||
help='if is test, must provide\
|
||||
path where the trained ckpt file')
|
||||
#modelarts已经默认使用data_url和train_url
|
||||
parser.add_argument('--epoch_size',
|
||||
type=int,
|
||||
default=5,
|
||||
|
@ -86,11 +87,18 @@ if __name__ == "__main__":
|
|||
args = parser.parse_args()
|
||||
|
||||
######################## 将数据集从obs拷贝到训练镜像中 (固定写法)########################
|
||||
# 在训练环境中定义data_url和train_url,并把数据从obs拷贝到相应的固定路径
|
||||
# 在训练环境中定义data_url和train_url,并把数据从obs拷贝到相应的固定路径,以下写法是将数据拷贝到/home/work/user-job-dir/data/目录下,可修改为其他目录
|
||||
#创建数据存放的位置
|
||||
obs_data_url = args.data_url
|
||||
args.data_url = '/home/work/user-job-dir/inputs/data/'
|
||||
args.data_url = '/home/work/user-job-dir/data/'
|
||||
if not os.path.exists(args.data_url):
|
||||
os.mkdir(args.data_url)
|
||||
#创建模型存放的位置
|
||||
obs_train_url = args.train_url
|
||||
args.train_url = '/home/work/user-job-dir/outputs/model/'
|
||||
args.train_url = '/home/work/user-job-dir/model/'
|
||||
if not os.path.exists(args.train_url):
|
||||
os.mkdir(args.train_url)
|
||||
#将数据拷贝到训练环境
|
||||
try:
|
||||
mox.file.copy_parallel(obs_data_url, args.data_url)
|
||||
print("Successfully Download {} to {}".format(obs_data_url,
|
||||
|
@ -102,12 +110,11 @@ if __name__ == "__main__":
|
|||
|
||||
#将dataset_path指向data_url,save_checkpoint_path指向train_url
|
||||
args.dataset_path = args.data_url
|
||||
args.save_checkpoint_path = args.train_url
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target=args.device_target)
|
||||
#创建数据集
|
||||
ds_train = create_dataset(os.path.join(args.dataset_path, "train"),
|
||||
ds_train = create_dataset(os.path.join(args.data_url, "train"),
|
||||
cfg.batch_size)
|
||||
if ds_train.get_dataset_size() == 0:
|
||||
raise ValueError(
|
||||
|
@ -135,7 +142,7 @@ if __name__ == "__main__":
|
|||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
#定义模型输出路径
|
||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
|
||||
directory=args.save_checkpoint_path,
|
||||
directory=args.train_url,
|
||||
config=config_ck)
|
||||
#开始训练
|
||||
print("============== Starting Training ==============")
|
||||
|
|
Loading…
Reference in New Issue