MNIST_Example/pretrain_for_c2net.py

"""
######################## Attention! ########################
使用注意事项:
1、本示例需要用户定义的参数有--multi_data_url,--pretrain_url,--train_url,这3个参数在任务中必须定义
具体的含义如下:
--multi_data_url是启智平台上选择的数据集的obs路径
--pretrain_url是启智平台上选择的预训练模型文件的obs路径
--train_url是训练结果回传到启智平台的obs路径
2、用户需要调用OpenI.C2NETMultiDatasetToEnv等函数来实现数据集、预训练模型文件的拷贝
3、智算网络区别于启智
(1)智算的数据集拷贝到训练镜像后需要解压请使用C2NETMultiDatasetToEnv函数
(2)智算任务结果不需要用户调用函数回传,会在训练结束后自动回传结果
"""
import os
import argparse
from config import mnist_cfg as cfg
from dataset import create_dataset
from dataset_distributed import create_dataset_parallel
from lenet import LeNet5
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore import load_checkpoint, load_param_into_net
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank
import time
from openi import c2net_multidataset_to_env as DatasetToEnv
from openi import pretrain_to_env
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
parser.add_argument('--multi_data_url',
                    help='obs path(s) of the dataset(s); must be defined when training with datasets',
                    default='[{}]')
parser.add_argument('--pretrain_url',
                    help='obs path(s) of the pretrained model file(s); must be defined when using one or more pretrained files',
                    default='[{}]')
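# --train_url is also injected by the platform when the task is created. It is
# not declared here because C2Net sends the results back automatically (see
# note 3 in the docstring) and parse_known_args() below tolerates undeclared
# arguments.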
parser.add_argument('--device_target',
                    type=str,
                    default="Ascend",
                    choices=['Ascend', 'CPU'],
                    help='device where the code will run (default: Ascend); '
                         'to use the CPU on the Qizhi platform, set device_target=CPU')
parser.add_argument('--epoch_size',
                    type=int,
                    default=5,
                    help='Training epochs.')
if __name__ == "__main__":
    args, unknown = parser.parse_known_args()
    data_dir = '/cache/data'
    train_dir = '/cache/output'
    pretrain_dir = '/cache/pretrain'
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    if not os.path.exists(train_dir):
        os.makedirs(train_dir)
    if not os.path.exists(pretrain_dir):
        os.makedirs(pretrain_dir)
    ### Copy the pretrained model file(s) into the training environment
    pretrain_to_env(args.pretrain_url, pretrain_dir)
    device_num = int(os.getenv('RANK_SIZE'))
    # Single-card training
    if device_num == 1:
        ### Copy the dataset(s) into the training image
        DatasetToEnv(args.multi_data_url, data_dir)
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
        ds_train = create_dataset(os.path.join(data_dir, "MNISTData", "train"), cfg.batch_size)
    # Multi-card training
    if device_num > 1:
        # Set device_id and init for multi-card training
        context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
                            device_id=int(os.getenv('ASCEND_DEVICE_ID')))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          parameter_broadcast=True)
        init()
        # The obs data only needs to be copied once per node: let the 0th card
        # copy it (RANK_ID % 8 == 0 assumes 8 cards per node sharing /cache)
        local_rank = int(os.getenv('RANK_ID'))
        if local_rank % 8 == 0:
            DatasetToEnv(args.multi_data_url, data_dir)
            # Create a flag file to signal that the data has been copied. During
            # multi-card training, the other cards wait for this file instead of
            # copying the dataset again.
            with open("/cache/download_input.txt", 'w'):
                pass
        # The other cards block here until the 0th card has finished copying
        while not os.path.exists("/cache/download_input.txt"):
            time.sleep(1)
        print("download_input succeed")
        ds_train = create_dataset_parallel(os.path.join(data_dir, "MNISTData", "train"), cfg.batch_size)
    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    ### If a pretrained model file was selected, load it from pretrain_dir. Note
    ### that the ckpt_url way of passing a checkpoint is still supported, but it
    ### will be deprecated gradually.
    load_param_into_net(network, load_checkpoint(os.path.join(pretrain_dir, "checkpoint_lenet-1_1875.ckpt")))
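    # The hard-coded file name must match the pretrained checkpoint actually
    # selected for the task; "checkpoint_lenet-1_1875.ckpt" presumably comes from
    # a 1-epoch LeNet run with 1875 steps (60000 MNIST images / batch size 32).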
    if args.device_target != "Ascend":
        model = Model(network, net_loss, net_opt, metrics={"accuracy"})
    else:
        model = Model(network, net_loss, net_opt, metrics={"accuracy"}, amp_level="O2")
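    # amp_level="O2" turns on MindSpore mixed-precision training on Ascend:
    # the network is cast to float16 while BatchNorm stays in float32, with
    # dynamic loss scaling.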
    config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                 keep_checkpoint_max=cfg.keep_checkpoint_max)
    # With data parallelism each card saves its own model files, so give every
    # rank its own subdirectory; only one copy then needs to be uploaded
    outputDirectory = train_dir
    if device_num > 1:
        outputDirectory = os.path.join(train_dir, str(get_rank()))
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
                                 directory=outputDirectory,
                                 config=config_ck)
    print("============== Starting Training ==============")
    epoch_size = cfg['epoch_size']
    if args.epoch_size:
        epoch_size = args.epoch_size
    print('epoch_size is: ', epoch_size)
    callback = [time_cb, ckpoint_cb, LossMonitor()]
    model.train(epoch_size, ds_train, callbacks=callback)
    ### On C2Net the training results do not need to be sent back manually; they
    ### are sent back automatically after the task ends.