# File: MNIST_Example/train_for_multinode.py
#
# 74 lines · 2.7 KiB · Python

import os
import argparse
import moxing as mox
from config import mnist_cfg as cfg
from dataset import create_dataset
from dataset_distributed import create_dataset_parallel
from lenet import LeNet5
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank
import mindspore.ops as ops
import time
### Copy the output to obs###
def EnvToObs(train_dir, obs_train_url):
    """Copy the local directory *train_dir* to *obs_train_url* with moxing.

    This is best-effort. Any exception raised by the copy is caught and
    reported with ``print`` instead of being propagated to the caller.

    Args:
        train_dir: Local source directory to upload.
        obs_train_url: Destination passed to ``mox.file.copy_parallel``.
    """
    try:
        mox.file.copy_parallel(train_dir, obs_train_url)
        print(f"Successfully Upload {train_dir} to {obs_train_url}")
    except Exception as err:  # deliberate best-effort: report, don't raise
        print(f"moxing upload {train_dir} to {obs_train_url} failed: {err}")
def UploadToQizhi(train_dir, obs_train_url):
    """Upload *train_dir* to *obs_train_url*.

    This is a thin wrapper that delegates all work to ``EnvToObs`` and hands
    back whatever that helper returns.
    """
    return EnvToObs(train_dir, obs_train_url)
# Command-line interface. The script reads it with parse_known_args(), so
# flags added by the launching platform but not declared here are tolerated.
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')

# Input / output locations (plain string values).
parser.add_argument(
    '--multi_data_url',
    default='/cache/data/',
    help='path to training/inference dataset folder',
)
parser.add_argument(
    '--train_url',
    default='/cache/output/',
    help='output folder to save/load',
)

# Execution settings.
parser.add_argument(
    '--device_target',
    type=str,
    choices=['Ascend', 'CPU'],
    default="Ascend",
    help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU',
)
parser.add_argument(
    '--epoch_size',
    type=int,
    default=5,
    help='Training epochs.',
)
if __name__ == "__main__":
    import subprocess  # only needed for the debug directory listing below

    args, _unknown = parser.parse_known_args()

    # ASCEND_DEVICE_ID is expected to be exported per process by the launcher.
    # Fall back to device 0 instead of crashing with an opaque int(None)
    # TypeError when the variable is missing.
    device_id = int(os.getenv('ASCEND_DEVICE_ID', '0'))
    # Honour --device_target. It was previously hard-coded to "Ascend", which
    # silently ignored the command-line choice. The default is still "Ascend".
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=device_id)
    context.reset_auto_parallel_context()
    # NOTE(review): device_num is fixed at 4. Confirm it matches the number of
    # devices the job actually launches.
    context.set_auto_parallel_context(
        device_num=4,
        parallel_mode=ParallelMode.DATA_PARALLEL,
        gradients_mean=True,
        parameter_broadcast=True,
    )
    init()
    rank = get_rank()  # query once instead of on every use

    print("begin:")
    data_dir = '/cache/data'
    train_dir = '/cache/output'
    outputdir = os.path.join(train_dir, str(rank))
    # Several ranks on one node can run this concurrently. exist_ok=True avoids
    # the exists()/makedirs() race, which could raise FileExistsError.
    for directory in (data_dir, train_dir, outputdir):
        os.makedirs(directory, exist_ok=True)

    # Create the per-rank marker file, or refresh its timestamp if it exists.
    # This matches `touch <rank>.txt` in the rank directory without a shell.
    marker = os.path.join(outputdir, "{}.txt".format(rank))
    with open(marker, "a"):
        pass
    os.utime(marker, None)

    # Debug listing of the shared output directory (was `cd ...; ls -al`).
    subprocess.run(["ls", "-al"], cwd=train_dir, check=False)
    print("end:")
    # NOTE(review): every rank uploads the whole train_dir to the same
    # train_url, with no barrier beforehand. Confirm this is intended.
    UploadToQizhi(train_dir, args.train_url)