
"""
######################## train lenet dataparallel example ########################
train lenet and get network model files(.ckpt)
The training of the intelligent computing network currently supports single dataset training, and does not require
the obs copy process.It only needs to define two parameters and then call it directly
train_dir = '/cache/output' #The location of the output
data_dir = '/cache/dataset' #The location of the dataset
"""
import os
import argparse

import mindspore.nn as nn
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank

from dataset_distributed import create_dataset_parallel
from config import mnist_cfg as cfg
from lenet import LeNet5

# set device_id and init
device_id = int(os.getenv('DEVICE_ID', '0'))  # fall back to card 0 if unset
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id)
init()
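# init() brings up HCCL communication between the Ascend cards; DEVICE_ID is
# expected to be exported per card by the platform scheduler before launch.
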
parser = argparse.ArgumentParser(description='MindSpore LeNet Example')
parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'CPU'],
    help='device where the code will be executed (default: Ascend); '
         'to use the CPU on the Qizhi platform, pass --device_target=CPU')
parser.add_argument('--epoch_size',
                    type=int,
                    default=5,
                    help='Training epochs.')
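
# Fix the global seed so every card starts from identical weight
# initialization; data parallel training relies on replicas staying in sync.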
set_seed(114514)

if __name__ == "__main__":
    args = parser.parse_args()

    ### define the two path parameters and use them directly ###
    train_dir = '/cache/output'   # location of the training output
    data_dir = '/cache/dataset'   # location of the dataset
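
    # DATA_PARALLEL replicates the whole network on every card;
    # gradients_mean=True averages gradients across cards during AllReduce.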
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True)
    ds_train = create_dataset_parallel(os.path.join(data_dir, "train"),
                                       cfg.batch_size)
    if ds_train.get_dataset_size() == 0:
        raise ValueError(
            "Please check dataset size > 0 and batch_size <= dataset size")

    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
    if args.device_target != "Ascend":
        model = Model(network,
                      net_loss,
                      net_opt,
                      metrics={"accuracy": Accuracy()})
    else:
        model = Model(network,
                      net_loss,
                      net_opt,
                      metrics={"accuracy": Accuracy()},
                      amp_level="O2")
    config_ck = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_steps,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    # Note that this method saves a model file on each card, so each card needs
    # its own save path; here get_rank() is appended to keep the paths distinct.
    ckpoint_cb = ModelCheckpoint(prefix="data_parallel",
                                 directory=os.path.join(train_dir, str(get_rank())),
                                 config=config_ck)
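    # In data parallel mode every replica holds identical weights, so for
    # inference it is usually enough to keep only the rank-0 checkpoint.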
print("============== Starting Training ==============")
epoch_size = cfg['epoch_size']
if (args.epoch_size):
epoch_size = args.epoch_size
print('epoch_size is: ', epoch_size)
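    # dataset_sink_mode=False feeds batches from Python step by step, which
    # keeps per-step LossMonitor output; sinking the pipeline to the device
    # is faster but reports less frequently.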
    model.train(epoch_size,
                ds_train,
                callbacks=[time_cb, ckpoint_cb, LossMonitor()],
                dataset_sink_mode=False)