for multi-node or multi-card training
parent d4fdc614b1
commit bebbae2714
@@ -139,10 +139,15 @@ if __name__ == "__main__":
         epoch_size = args.epoch_size
         print('epoch_size is: ', epoch_size)
         # set callback functions
+        callback =[time_cb,LossMonitor()]
+        local_rank=int(os.getenv('RANK_ID'))
+        # callback =[time_cb,LossMonitor()]
+        # local_rank=int(os.getenv('RANK_ID'))
         # for data parallel, only save checkpoint on rank 0
+        if local_rank==0 :
+            callback.append(ckpoint_cb)
+        model.train(epoch_size,ds_train,callbacks=callback)
+        # if local_rank==0 :
+        #     callback.append(ckpoint_cb)
+        # model.train(epoch_size,ds_train,callbacks=callback)
-        model.train(epoch_size,
-                    ds_train,
-                    callbacks=[time_cb, ckpoint_cb,
-                               LossMonitor()])
-
         ### the Zhisuan platform does not need training results sent back; they are returned automatically after the job ends
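Both hunks make the same change: instead of passing a fixed callback list to model.train, the script now builds the list at run time, reads the process rank from the RANK_ID environment variable, and appends the checkpoint callback only on rank 0, so that in a data-parallel run a single worker writes checkpoint files. A minimal runnable sketch of that pattern follows, assuming the MindSpore 1.x callback API; the checkpoint prefix, directory, and save intervals are illustrative, not taken from this commit.

import os

from mindspore.train.callback import (CheckpointConfig, LossMonitor,
                                      ModelCheckpoint, TimeMonitor)

# RANK_ID is assumed to be exported by the distributed launcher (e.g. an
# Ascend rank-table launch); fall back to 0 for plain single-card runs.
local_rank = int(os.getenv('RANK_ID', '0'))

time_cb = TimeMonitor()
callback = [time_cb, LossMonitor()]

if local_rank == 0:
    # Only rank 0 appends the checkpoint callback, so parallel workers
    # never race to write the same checkpoint files.
    ck_config = CheckpointConfig(save_checkpoint_steps=1000,
                                 keep_checkpoint_max=10)
    ckpoint_cb = ModelCheckpoint(prefix='net', directory='./ckpt',
                                 config=ck_config)
    callback.append(ckpoint_cb)

# model.train(epoch_size, ds_train, callbacks=callback)  # as in the diff

One caveat worth noting: the committed code calls int(os.getenv('RANK_ID')) with no fallback, so it raises a TypeError whenever RANK_ID is unset (for example in a plain single-card run); the default used in the sketch above avoids that.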
@@ -122,10 +122,15 @@ if __name__ == "__main__":
     if (args.epoch_size):
         epoch_size = args.epoch_size
         print('epoch_size is: ', epoch_size)

         # set callback functions
+        callback =[time_cb,LossMonitor()]
+        local_rank=int(os.getenv('RANK_ID'))
+        # callback =[time_cb,LossMonitor()]
+        # local_rank=int(os.getenv('RANK_ID'))
         # for data parallel, only save checkpoint on rank 0
+        if local_rank==0 :
+            callback.append(ckpoint_cb)
+        model.train(epoch_size,ds_train,callbacks=callback)
+        # if local_rank==0 :
+        #     callback.append(ckpoint_cb)
+        # model.train(epoch_size,ds_train,callbacks=callback)
-        model.train(epoch_size,
-                    ds_train,
-                    callbacks=[time_cb, ckpoint_cb,
-                               LossMonitor()])
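The RANK_ID variable this change relies on is only populated when the job is launched in distributed mode. For context, here is a hedged sketch of the data-parallel setup a MindSpore 1.x training script on Ascend typically performs before calling model.train; none of this appears in the commit itself, and gradients_mean is the MindSpore 1.2+ name for the gradient-averaging flag.

import os

from mindspore import context
from mindspore.communication.management import get_rank, init
from mindspore.context import ParallelMode

# Bind this process to its NPU; DEVICE_ID is set by the launcher.
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend',
                    device_id=int(os.getenv('DEVICE_ID', '0')))

init()  # initialize HCCL from the rank table supplied by the launcher

# Average gradients across cards so every replica applies the same update.
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True)

# get_rank() should agree with the RANK_ID environment variable that the
# training script reads when deciding which process saves checkpoints.
print('rank', get_rank(), '/ RANK_ID env:', os.getenv('RANK_ID'))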