for multi-node or multi-card

This commit is contained in:
liuzx 2023-07-10 15:44:58 +08:00
parent d4fdc614b1
commit bebbae2714
2 changed files with 20 additions and 10 deletions

View File

@@ -139,10 +139,15 @@ if __name__ == "__main__":
     epoch_size = args.epoch_size
     print('epoch_size is: ', epoch_size)
     # set callback functions
-    callback =[time_cb,LossMonitor()]
+    # callback =[time_cb,LossMonitor()]
-    local_rank=int(os.getenv('RANK_ID'))
+    # local_rank=int(os.getenv('RANK_ID'))
     # for data parallel, only save checkpoint on rank 0
-    if local_rank==0 :
+    # if local_rank==0 :
-        callback.append(ckpoint_cb)
+    #     callback.append(ckpoint_cb)
-    model.train(epoch_size,ds_train,callbacks=callback)
+    # model.train(epoch_size,ds_train,callbacks=callback)
+    model.train(epoch_size,
+                ds_train,
+                callbacks=[time_cb, ckpoint_cb,
+                           LossMonitor()])
     ###智算不需要回传训练结果,会在任务结束后自动回传

View File

@@ -122,10 +122,15 @@ if __name__ == "__main__":
     if (args.epoch_size):
         epoch_size = args.epoch_size
     print('epoch_size is: ', epoch_size)
     # set callback functions
-    callback =[time_cb,LossMonitor()]
+    # callback =[time_cb,LossMonitor()]
-    local_rank=int(os.getenv('RANK_ID'))
+    # local_rank=int(os.getenv('RANK_ID'))
     # for data parallel, only save checkpoint on rank 0
-    if local_rank==0 :
+    # if local_rank==0 :
-        callback.append(ckpoint_cb)
+    #     callback.append(ckpoint_cb)
-    model.train(epoch_size,ds_train,callbacks=callback)
+    # model.train(epoch_size,ds_train,callbacks=callback)
+    model.train(epoch_size,
+                ds_train,
+                callbacks=[time_cb, ckpoint_cb,
+                           LossMonitor()])