# LeNet (MindSpore) data-parallel training launcher for the Qizhi platform.
import os
import argparse
import moxing as mox
from config import mnist_cfg as cfg
from dataset import create_dataset
from dataset_distributed import create_dataset_parallel
from lenet import LeNet5
import mindspore.nn as nn
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank
import mindspore.ops as ops
import time
|
|
### Copy the output to obs###
|
|
def EnvToObs(train_dir, obs_train_url):
|
|
try:
|
|
mox.file.copy_parallel(train_dir, obs_train_url)
|
|
print("Successfully Upload {} to {}".format(train_dir,obs_train_url))
|
|
except Exception as e:
|
|
print('moxing upload {} to {} failed: '.format(train_dir,obs_train_url) + str(e))
|
|
return
|
|
|
|
def UploadToQizhi(train_dir, obs_train_url):
|
|
EnvToObs(train_dir, obs_train_url)
|
|
return
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
|
|
parser.add_argument('--multi_data_url',
|
|
help='path to training/inference dataset folder',
|
|
default= '/cache/data/')
|
|
|
|
parser.add_argument('--train_url',
|
|
help='output folder to save/load',
|
|
default= '/cache/output/')
|
|
|
|
parser.add_argument(
|
|
'--device_target',
|
|
type=str,
|
|
default="Ascend",
|
|
choices=['Ascend', 'CPU'],
|
|
help='device where the code will be implemented (default: Ascend),if to use the CPU on the Qizhi platform:device_target=CPU')
|
|
|
|
parser.add_argument('--epoch_size',
|
|
type=int,
|
|
default=5,
|
|
help='Training epochs.')
|
|
if __name__ == "__main__":
|
|
args, unknown = parser.parse_known_args()
|
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=int(os.getenv('ASCEND_DEVICE_ID')))
|
|
context.reset_auto_parallel_context()
|
|
context.set_auto_parallel_context(device_num = 4, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, parameter_broadcast=True)
|
|
init()
|
|
print("begin:")
|
|
data_dir = '/cache/data'
|
|
train_dir = '/cache/output'
|
|
if not os.path.exists(data_dir):
|
|
os.makedirs(data_dir)
|
|
if not os.path.exists(train_dir):
|
|
os.makedirs(train_dir)
|
|
#train_dir = '/cache/output'
|
|
outputdir = "/cache/output/{}/".format((str(get_rank())))
|
|
if not os.path.exists(outputdir):
|
|
os.makedirs(outputdir)
|
|
os.chdir("/cache/output/{}/".format((str(get_rank()))))
|
|
os.system("touch {}.txt".format((get_rank())))
|
|
os.system("cd /cache/output;ls -al;")
|
|
|
|
print("end:")
|
|
UploadToQizhi(train_dir, args.train_url) |