# MNIST_Example/pretrain.py
"""
使用注意事项:
1、本示例只支持单数据集训练如果是多数据集请参考多数据集训练示例train_for_multidataset.py
2、本示例支持选择预训练模型单文件或多文件
3、本示例需要用户定义的参数有--multi_data_url,--pretrain_url,--train_url这3个参数在单数据集任务中必须定义
具体的含义如下:
--multi_data_url是启智平台上选择的数据集的obs路径
--pretrain_url是启智平台上选择的预训练模型文件的obs路径
--train_url是训练结果回传到启智平台的obs路径
用户需要调用openi.py下的openi_multidataset_To_env,pretrain_to_env,env_to_openi等函数来实现数据集、预训练模型文件、训练结果的拷贝和回传
"""
import os
import argparse
from config import mnist_cfg as cfg
from dataset import create_dataset
from dataset_distributed import create_dataset_parallel
from lenet import LeNet5
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank
import time
from openi import openi_multidataset_to_env as DatasetToEnv
from openi import env_to_openi
from openi import pretrain_to_env
from openi import EnvToOpenIEpochEnd
# Command-line interface for the single-dataset LeNet training example.
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')

# OBS paths injected by the OpenI platform.
parser.add_argument('--multi_data_url', default='[{}]',
                    help='使用数据集时,需要定义的参数')
parser.add_argument('--pretrain_url', default='[{}]',
                    help='使用预训练模型时,需要定义的参数')
parser.add_argument('--train_url', default='',
                    help='回传结果到启智,需要定义的参数')

# Hardware target and training hyper-parameters.
parser.add_argument('--device_target', type=str, default="Ascend",
                    choices=['Ascend', 'CPU'],
                    help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU')
parser.add_argument('--epoch_size', type=int, default=5,
                    help='Training epochs.')
if __name__ == "__main__":
#请在代码中加入args, unknown = parser.parse_known_args(),可忽略掉--ckpt_url参数报错等参数问题
args, unknown = parser.parse_known_args()
data_dir = '/cache/data'
train_dir = '/cache/output'
pretrain_dir = '/cache/pretrain'
if not os.path.exists(data_dir):
os.makedirs(data_dir)
if not os.path.exists(train_dir):
os.makedirs(train_dir)
if not os.path.exists(pretrain_dir):
os.makedirs(pretrain_dir)
###拷贝数据集到训练环境
DatasetToEnv(args.multi_data_url, data_dir)
###拷贝多个预训练模型文件到训练环境
pretrain_to_env(args.pretrain_url, pretrain_dir)
device_num = int(os.getenv('RANK_SIZE'))
#使用单卡时
if device_num == 1:
DatasetToEnv(args.multi_data_url,data_dir)
context.set_context(mode=context.GRAPH_MODE,device_target=args.device_target)
#使用数据集的方式
ds_train = create_dataset(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size)
#使用多卡时
if device_num > 1:
# set device_id and init for multi-card training
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=int(os.getenv('ASCEND_DEVICE_ID')))
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num = device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, parameter_broadcast=True)
init()
#Copying obs data does not need to be executed multiple times, just let the 0th card copy the data
local_rank=int(os.getenv('RANK_ID'))
if local_rank%8==0:
DatasetToEnv(args.multi_data_url,data_dir)
#Set a cache file to determine whether the data has been copied to obs.
#If this file exists during multi-card training, there is no need to copy the dataset multiple times.
f = open("/cache/download_input.txt", 'w')
f.close()
try:
if os.path.exists("/cache/download_input.txt"):
print("download_input succeed")
except Exception as e:
print("download_input failed")
while not os.path.exists("/cache/download_input.txt"):
time.sleep(1)
ds_train = create_dataset_parallel(os.path.join(data_dir + "/MNISTData", "train"), cfg.batch_size)
network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
###假如选择了模型文件使用pretrain_url的方式,注意ckpt_url的方式依然保留你依然可以使用ckpt_url的方式但是这种方式将会逐渐废弃
load_param_into_net(network, load_checkpoint(os.path.join(pretrain_dir, "checkpoint_lenet-1_1875.ckpt")))
if args.device_target != "Ascend":
model = Model(network,
net_loss,
net_opt,
metrics={"accuracy"})
else:
model = Model(network,
net_loss,
net_opt,
metrics={"accuracy"},
amp_level="O2")
config_ck = CheckpointConfig(
save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
#若是数据并行,则只需要上传一次模型文件
outputDirectory = train_dir + "/"
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
directory=outputDirectory,
config=config_ck)
print("============== Starting Training ==============")
epoch_size = cfg['epoch_size']
if (args.epoch_size):
epoch_size = args.epoch_size
print('epoch_size is: ', epoch_size)
#Custom callback, upload output after each epoch
uploadOutput = EnvToOpenIEpochEnd(train_dir,args.train_url)
model.train(epoch_size, ds_train,callbacks=[time_cb, ckpoint_cb,LossMonitor(), uploadOutput])
###上传训练结果到启智平台
env_to_openi(train_dir,args.train_url)