Upload output each epoch
This commit is contained in:
parent
a80ddea2e9
commit
ebd9777f20
11
train.py
11
train.py
|
@ -56,6 +56,7 @@ from mindspore.context import ParallelMode
|
|||
from mindspore.communication.management import init, get_rank
|
||||
import mindspore.ops as ops
|
||||
import time
|
||||
from upload import UploadOutput
|
||||
|
||||
### Copy a single dataset from OBS to the training image ###
|
||||
def ObsToEnv(obs_data_url, data_dir):
|
||||
|
@ -138,7 +139,7 @@ parser.add_argument('--epoch_size',
|
|||
help='Training epochs.')
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
args, unknown = parser.parse_known_args()
|
||||
data_dir = '/cache/data'
|
||||
train_dir = '/cache/output'
|
||||
if not os.path.exists(data_dir):
|
||||
|
@ -190,12 +191,14 @@ if __name__ == "__main__":
|
|||
if (args.epoch_size):
|
||||
epoch_size = args.epoch_size
|
||||
print('epoch_size is: ', epoch_size)
|
||||
|
||||
# Custom callback: upload the output directory after each epoch
|
||||
uploadOutput = UploadOutput(train_dir,args.train_url)
|
||||
model.train(epoch_size,
|
||||
ds_train,
|
||||
callbacks=[time_cb, ckpoint_cb,
|
||||
LossMonitor()])
|
||||
LossMonitor(), uploadOutput])
|
||||
|
||||
###Copy the trained output data from the local running environment back to obs,
|
||||
###and download it in the training task corresponding to the Qizhi platform
|
||||
UploadToQizhi(train_dir,args.train_url)
|
||||
#This step is not required if UploadOutput is called
|
||||
UploadToQizhi(train_dir,args.train_url)
|
|
@ -0,0 +1,14 @@
|
|||
from mindspore.train.callback import Callback
|
||||
import moxing as mox
|
||||
|
||||
class UploadOutput(Callback):
    """Training callback that copies the local output directory to OBS at each epoch end.

    Args:
        train_dir: Local directory that holds the training output.
        obs_train_url: Destination URL handed to ``mox.file.copy_parallel``.
    """

    def __init__(self, train_dir, obs_train_url):
        # Run the base Callback initializer before adding our own state,
        # as custom MindSpore callbacks are expected to do.
        super().__init__()
        self.train_dir = train_dir
        self.obs_train_url = obs_train_url

    def epoch_end(self, run_context):
        """Copy ``train_dir`` to ``obs_train_url``.

        ``run_context`` is accepted to satisfy the callback hook signature
        and is not used. Any failure is printed instead of raised, so a
        failed upload does not propagate out of this hook.
        """
        try:
            mox.file.copy_parallel(self.train_dir, self.obs_train_url)
            print("Successfully Upload {} to {}".format(self.train_dir, self.obs_train_url))
        except Exception as e:
            # Deliberately best-effort: report the problem and carry on.
            print('moxing upload {} to {} failed: '.format(self.train_dir, self.obs_train_url) + str(e))
|
Loading…
Reference in New Issue