From bebbae27143ec888e45c1ae9026a1333dbc600b9 Mon Sep 17 00:00:00 2001 From: liuzx Date: Mon, 10 Jul 2023 15:44:58 +0800 Subject: [PATCH] for multi node or multi card --- pretrain_for_c2net.py | 15 ++++++++++----- train_for_c2net.py | 15 ++++++++++----- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/pretrain_for_c2net.py b/pretrain_for_c2net.py index 46195be..fd670d5 100644 --- a/pretrain_for_c2net.py +++ b/pretrain_for_c2net.py @@ -139,10 +139,15 @@ if __name__ == "__main__": epoch_size = args.epoch_size print('epoch_size is: ', epoch_size) # set callback functions - callback =[time_cb,LossMonitor()] - local_rank=int(os.getenv('RANK_ID')) + # callback =[time_cb,LossMonitor()] + # local_rank=int(os.getenv('RANK_ID')) # for data parallel, only save checkpoint on rank 0 - if local_rank==0 : - callback.append(ckpoint_cb) - model.train(epoch_size,ds_train,callbacks=callback) + # if local_rank==0 : + # callback.append(ckpoint_cb) + # model.train(epoch_size,ds_train,callbacks=callback) + + model.train(epoch_size, + ds_train, + callbacks=[time_cb, ckpoint_cb, + LossMonitor()]) ###智算不需要回传训练结果,会在任务结束后自动回传 \ No newline at end of file diff --git a/train_for_c2net.py b/train_for_c2net.py index 1230bcc..e3230b7 100644 --- a/train_for_c2net.py +++ b/train_for_c2net.py @@ -122,10 +122,15 @@ if __name__ == "__main__": if (args.epoch_size): epoch_size = args.epoch_size print('epoch_size is: ', epoch_size) + # set callback functions - callback =[time_cb,LossMonitor()] - local_rank=int(os.getenv('RANK_ID')) + # callback =[time_cb,LossMonitor()] + # local_rank=int(os.getenv('RANK_ID')) # for data parallel, only save checkpoint on rank 0 - if local_rank==0 : - callback.append(ckpoint_cb) - model.train(epoch_size,ds_train,callbacks=callback) \ No newline at end of file + # if local_rank==0 : + # callback.append(ckpoint_cb) + # model.train(epoch_size,ds_train,callbacks=callback) + model.train(epoch_size, + ds_train, + callbacks=[time_cb, ckpoint_cb, + LossMonitor()]) \ No newline at end of file