For multi-node or multi-card training: register the checkpoint callback on every rank instead of only rank 0

liuzx 2023-07-10 15:44:58 +08:00
parent d4fdc614b1
commit bebbae2714
2 changed files with 20 additions and 10 deletions

@@ -139,10 +139,15 @@ if __name__ == "__main__":
         epoch_size = args.epoch_size
     print('epoch_size is: ', epoch_size)
     # set callback functions
-    callback = [time_cb, LossMonitor()]
-    local_rank = int(os.getenv('RANK_ID'))
+    # callback = [time_cb, LossMonitor()]
+    # local_rank = int(os.getenv('RANK_ID'))
     # for data parallel, only save checkpoint on rank 0
-    if local_rank == 0:
-        callback.append(ckpoint_cb)
-    model.train(epoch_size, ds_train, callbacks=callback)
+    # if local_rank == 0:
+    #     callback.append(ckpoint_cb)
+    # model.train(epoch_size, ds_train, callbacks=callback)
+    model.train(epoch_size,
+                ds_train,
+                callbacks=[time_cb, ckpoint_cb,
+                           LossMonitor()])
+    ### Zhisuan (intelligent computing) jobs do not need to send training results back; they are returned automatically when the task ends
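For reference, the rank-gated pattern that this commit comments out looks roughly like the following minimal sketch. It is MindSpore-style and hedged: the CheckpointConfig values, the 'train' prefix, and the per-rank directory are illustrative assumptions, not taken from these scripts.

import os
from mindspore.train.callback import (
    ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor)

# MindSpore's distributed launcher exports RANK_ID; default to 0 for single-card runs.
local_rank = int(os.getenv('RANK_ID', '0'))

time_cb = TimeMonitor()
# Hypothetical checkpoint settings; the actual scripts configure ckpoint_cb elsewhere.
config_ck = CheckpointConfig(save_checkpoint_steps=1000, keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix='train',
                             directory=f'./ckpt_rank_{local_rank}',
                             config=config_ck)

callback = [time_cb, LossMonitor()]
if local_rank == 0:
    # In pure data parallelism every rank holds identical weights,
    # so one copy saved by rank 0 is sufficient.
    callback.append(ckpoint_cb)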

@@ -122,10 +122,15 @@ if __name__ == "__main__":
     if (args.epoch_size):
         epoch_size = args.epoch_size
     print('epoch_size is: ', epoch_size)
     # set callback functions
-    callback = [time_cb, LossMonitor()]
-    local_rank = int(os.getenv('RANK_ID'))
+    # callback = [time_cb, LossMonitor()]
+    # local_rank = int(os.getenv('RANK_ID'))
     # for data parallel, only save checkpoint on rank 0
-    if local_rank == 0:
-        callback.append(ckpoint_cb)
-    model.train(epoch_size, ds_train, callbacks=callback)
+    # if local_rank == 0:
+    #     callback.append(ckpoint_cb)
+    # model.train(epoch_size, ds_train, callbacks=callback)
+    model.train(epoch_size,
+                ds_train,
+                callbacks=[time_cb, ckpoint_cb,
+                           LossMonitor()])
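With the rank gate removed, both training scripts now register ckpoint_cb on every rank. In pure data parallelism the ranks hold identical weights, so each rank writes an identical copy of the checkpoint; if the ranks share a filesystem path, giving each rank its own checkpoint directory (as in the sketch above) is one way to avoid write conflicts.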