MNIST_Example/train_for_c2net.py

92 lines
3.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
######################## train lenet example ########################
train lenet and get network model files(.ckpt)
The training of the intelligent computing network currently supports single dataset training, and does not require
the obs copy process.It only needs to define two parameters and then call it directly
train_dir = '/cache/output' #The location of the output
data_dir = '/cache/dataset' #The location of the dataset
"""
#!/usr/bin/python
#coding=utf-8
import os
import argparse
from config import mnist_cfg as cfg
from dataset import create_dataset
from lenet import LeNet5
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.common import set_seed
def _build_parser():
    """Create the command-line parser for this training script.

    Options:
        --device_target  'Ascend' (default) or 'CPU'.
        --epoch_size     number of training epochs (int, default 5).
    """
    cli = argparse.ArgumentParser(description='MindSpore Lenet Example')
    cli.add_argument(
        '--device_target',
        type=str,
        default="Ascend",
        choices=['Ascend', 'CPU'],
        help=('device where the code will be implemented (default: Ascend),'
              'if to use the CPU on the Qizhi platform:device_target=CPU'),
    )
    cli.add_argument('--epoch_size', type=int, default=5,
                     help='Training epochs.')
    return cli


# Module-level parser shared by the script's entry point.
parser = _build_parser()
# Set MindSpore's global random seed to a fixed value so runs are repeatable.
# This is at module level, so it executes as soon as the file is imported.
set_seed(1)
def main():
    """Train LeNet-5 on the training split and save checkpoints.

    Reads training data from ``<data_dir>/train`` (``data_dir`` is
    '/cache/dataset') and trains on the device chosen by ``--device_target``.
    The run lasts ``--epoch_size`` epochs, or ``cfg['epoch_size']`` when that
    option is 0. Checkpoints prefixed ``checkpoint_lenet`` are written to
    ``train_dir`` ('/cache/output').

    Raises:
        ValueError: if the training dataset yields no batches.
    """
    # parse_known_args (rather than parse_args) tolerates extra arguments
    # instead of aborting; they are ignored.
    # NOTE(review): presumably the launching platform passes such extras -- confirm.
    args, _unknown = parser.parse_known_args()
    print('args:')
    print(args)

    # The platform only requires these two paths to be defined.
    train_dir = '/cache/output'   # checkpoint output directory
    data_dir = '/cache/dataset'   # dataset root

    # Select the device (CPU or Ascend NPU) used for training.
    context.set_context(mode=context.GRAPH_MODE,
                        device_target=args.device_target)

    ds_train = create_dataset(os.path.join(data_dir, "train"), cfg.batch_size)
    # Query the size once; it is used both for validation and by TimeMonitor.
    dataset_size = ds_train.get_dataset_size()
    if dataset_size == 0:
        raise ValueError(
            "Please check dataset size > 0 and batch_size <= dataset size")

    network = LeNet5(cfg.num_classes)
    net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
    time_cb = TimeMonitor(data_size=dataset_size)

    # The targets differ only in amp_level: Ascend trains with amp_level="O2",
    # every other target uses Model's default.
    amp_kwargs = {"amp_level": "O2"} if args.device_target == "Ascend" else {}
    model = Model(network,
                  net_loss,
                  net_opt,
                  metrics={"accuracy": Accuracy()},
                  **amp_kwargs)

    config_ck = CheckpointConfig(
        save_checkpoint_steps=cfg.save_checkpoint_steps,
        keep_checkpoint_max=cfg.keep_checkpoint_max)
    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
                                 directory=train_dir,
                                 config=config_ck)

    print("============== Starting Training ==============")
    # A truthy --epoch_size overrides the configured value.
    # NOTE(review): --epoch_size defaults to 5, so cfg['epoch_size'] is only
    # used when 0 is passed explicitly -- confirm this is intended.
    epoch_size = cfg['epoch_size']
    if args.epoch_size:
        epoch_size = args.epoch_size
    print('epoch_size is: ', epoch_size)
    model.train(epoch_size,
                ds_train,
                callbacks=[time_cb, ckpoint_cb, LossMonitor()])
    print("============== Finish Training ==============")


if __name__ == "__main__":
    main()