diff --git a/train.py b/train.py index 6d53739..50bc69d 100644 --- a/train.py +++ b/train.py @@ -14,15 +14,24 @@ args.data_url = '/home/work/user-job-dir/inputs/data/' args.train_url = '/home/work/user-job-dir/outputs/model/' 2、在训练环境中,需要将数据集从obs拷贝到训练镜像中,训练完以后,需要将输出的模型拷贝到obs. 将数据集从obs拷贝到训练镜像中: -try: - mox.file.copy_parallel(obs_data_url, args.data_url) - print("Successfully Download {} to {}".format(obs_data_url, - args.data_url)) -except Exception as e: - print('moxing download {} to {} failed: '.format( - obs_data_url, args.data_url) + str(e)) + + obs_data_url = args.data_url + args.data_url = '/home/work/user-job-dir/data/' + if not os.path.exists(args.data_url): + os.mkdir(args.data_url) + try: + mox.file.copy_parallel(obs_data_url, args.data_url) + print("Successfully Download {} to {}".format(obs_data_url, + args.data_url)) + except Exception as e: + print('moxing download {} to {} failed: '.format( + obs_data_url, args.data_url) + str(e)) 将输出的模型拷贝到obs: + obs_train_url = args.train_url + args.train_url = '/home/work/user-job-dir/model/' + if not os.path.exists(args.train_url): + os.mkdir(args.train_url) try: mox.file.copy_parallel(args.train_url, obs_train_url) print("Successfully Upload {} to {}".format(args.train_url, @@ -30,6 +39,7 @@ try: except Exception as e: print('moxing upload {} to {} failed: '.format(args.train_url, obs_train_url) + str(e)) + """ import os @@ -51,11 +61,11 @@ parser = argparse.ArgumentParser(description='MindSpore Lenet Example') # data_url,train_url是固定用于在modelarts上训练的参数,表示数据集的路径和输出模型的路径 parser.add_argument('--data_url', help='path to training/inference dataset folder', - default='./data') + default='/home/work/user-job-dir/data/') parser.add_argument('--train_url', help='model folder to save/load', - default='./model') + default='/home/work/user-job-dir/model/') parser.add_argument( '--device_target', @@ -64,16 +74,7 @@ parser.add_argument( choices=['Ascend', 'GPU', 'CPU'], help='device where the code will be implemented (default: Ascend)') -#用户可自定义的参数,dataset_path在示例中指向data_url,save_checkpoint_path指向train_url,在添加超参数时这两个参数可不添加,modelarts已经默认使用data_url和train_url -parser.add_argument('--dataset_path', - type=str, - default="./Data", - help='path where the dataset is saved') -parser.add_argument('--save_checkpoint_path', - type=str, - default="./ckpt", - help='if is test, must provide\ - path where the trained ckpt file') +#modelarts已经默认使用data_url和train_url parser.add_argument('--epoch_size', type=int, default=5, @@ -86,11 +87,18 @@ if __name__ == "__main__": args = parser.parse_args() ######################## 将数据集从obs拷贝到训练镜像中 (固定写法)######################## - # 在训练环境中定义data_url和train_url,并把数据从obs拷贝到相应的固定路径 + # 在训练环境中定义data_url和train_url,并把数据从obs拷贝到相应的固定路径,以下写法是将数据拷贝到/home/work/user-job-dir/data/目录下,可修改为其他目录 + #创建数据存放的位置 obs_data_url = args.data_url - args.data_url = '/home/work/user-job-dir/inputs/data/' + args.data_url = '/home/work/user-job-dir/data/' + if not os.path.exists(args.data_url): + os.mkdir(args.data_url) + #创建模型存放的位置 obs_train_url = args.train_url - args.train_url = '/home/work/user-job-dir/outputs/model/' + args.train_url = '/home/work/user-job-dir/model/' + if not os.path.exists(args.train_url): + os.mkdir(args.train_url) + #将数据拷贝到训练环境 try: mox.file.copy_parallel(obs_data_url, args.data_url) print("Successfully Download {} to {}".format(obs_data_url, @@ -102,12 +110,11 @@ if __name__ == "__main__": #将dataset_path指向data_url,save_checkpoint_path指向train_url args.dataset_path = args.data_url - args.save_checkpoint_path = args.train_url context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) #创建数据集 - ds_train = create_dataset(os.path.join(args.dataset_path, "train"), + ds_train = create_dataset(os.path.join(args.data_url, "train"), cfg.batch_size) if ds_train.get_dataset_size() == 0: raise ValueError( @@ -135,7 +142,7 @@ if __name__ == "__main__": keep_checkpoint_max=cfg.keep_checkpoint_max) #定义模型输出路径 ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", - directory=args.save_checkpoint_path, + directory=args.train_url, config=config_ck) #开始训练 print("============== Starting Training ==============")