add sent back tool
This commit is contained in:
parent
ebd9777f20
commit
dee04b5fe1
|
@ -25,6 +25,7 @@ from mindspore.nn.metrics import Accuracy
|
|||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
import mindspore.ops as ops
|
||||
from upload_for_c2net import UploadOutput
|
||||
|
||||
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
|
||||
parser.add_argument(
|
||||
|
@ -94,6 +95,8 @@ if __name__ == "__main__":
|
|||
epoch_size = args.epoch_size
|
||||
print('epoch_size is: ', epoch_size)
|
||||
|
||||
model.train(epoch_size,ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()], dataset_sink_mode=False)
|
||||
#Custom callback, upload output after each epoch
|
||||
uploadOutput = UploadOutput(train_dir,args.train_url)
|
||||
model.train(epoch_size,ds_train,callbacks=[time_cb, ckpoint_cb,LossMonitor(), uploadOutput])
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
from mindspore.train.callback import Callback
|
||||
import os
|
||||
|
||||
class UploadOutput(Callback):
|
||||
def epoch_end(self,run_context):
|
||||
os.system("cd /tmp/script_for_grampus/ &&./uploader_for_gpu " + "/tmp/output/")
|
Loading…
Reference in New Issue