diff --git a/train.py b/train.py
index 5011ac7..2655951 100644
--- a/train.py
+++ b/train.py
@@ -179,8 +179,12 @@ if __name__ == "__main__":
                                  keep_checkpoint_max=cfg.keep_checkpoint_max)
     # Note that this method saves the model file on each card. You need to specify the save path on each card.
     # In this example, get_rank() is added to distinguish different paths.
+    if device_num == 1:
+        outputDirectory = train_dir + "/"
+    if device_num > 1:
+        outputDirectory = train_dir + "/" + str(get_rank()) + "/"
     ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
-                                 directory=train_dir + "/" + str(get_rank()) + "/",
+                                 directory=outputDirectory,
                                  config=config_ck)
     print("============== Starting Training ==============")
     epoch_size = cfg['epoch_size']
diff --git a/train_for_c2net.py b/train_for_c2net.py
index 4a4459e..d17f15c 100644
--- a/train_for_c2net.py
+++ b/train_for_c2net.py
@@ -81,9 +81,13 @@ if __name__ == "__main__":
                                  keep_checkpoint_max=cfg.keep_checkpoint_max)
     # Note that this method saves the model file on each card. You need to specify the save path on each card.
     # In the example, get_rank() is added to distinguish different paths.
-    ckpoint_cb = ModelCheckpoint(prefix="data_parallel",
-                                 directory=train_dir + "/" + str(get_rank()) + "/",
-                                 config=config_ck)
+    if device_num == 1:
+        outputDirectory = train_dir + "/"
+    if device_num > 1:
+        outputDirectory = train_dir + "/" + str(get_rank()) + "/"
+    ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
+                                 directory=outputDirectory,
+                                 config=config_ck)
     print("============== Starting Training ==============")
     epoch_size = cfg['epoch_size']
     if (args.epoch_size):
diff --git a/train_for_multidataset.py b/train_for_multidataset.py
index 531827f..a1ab2e8 100644
--- a/train_for_multidataset.py
+++ b/train_for_multidataset.py
@@ -195,13 +195,16 @@ if __name__ == "__main__":
         model = Model(network, net_loss, net_opt, metrics={"accuracy": Accuracy()})
     else:
         model = Model(network, net_loss, net_opt, metrics={"accuracy": Accuracy()}, amp_level="O2")
-
     config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
                                  keep_checkpoint_max=cfg.keep_checkpoint_max)
     # Note that this method saves the model file on each card. You need to specify the save path on each card.
     # In this example, get_rank() is added to distinguish different paths.
+    if device_num == 1:
+        outputDirectory = train_dir + "/"
+    if device_num > 1:
+        outputDirectory = train_dir + "/" + str(get_rank()) + "/"
     ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
-                                 directory=train_dir + "/" + str(get_rank()) + "/",
+                                 directory=outputDirectory,
                                  config=config_ck)
     print("============== Starting Training ==============")
     epoch_size = cfg['epoch_size']
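The change in all three scripts is the same pattern: checkpoints go directly into train_dir on a single card, and into a per-rank subdirectory (train_dir/<rank>/) when running on multiple cards, so each card's files do not collide. Below is a minimal sketch of that pattern using MindSpore's standard ModelCheckpoint/CheckpointConfig callbacks and get_rank(); the helper name build_checkpoint_callback and the example step/max values are illustrative only, and train_dir / device_num stand in for the values the training scripts already compute.

# Sketch of the per-card checkpoint-directory selection used in the patch above.
import os
from mindspore.communication.management import get_rank
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

def build_checkpoint_callback(train_dir, device_num, prefix="checkpoint_lenet"):
    """Save checkpoints to train_dir on one card, or to train_dir/<rank>/ on multiple cards."""
    if device_num > 1:
        # Each card writes to its own subdirectory, distinguished by get_rank().
        output_directory = os.path.join(train_dir, str(get_rank()))
    else:
        output_directory = train_dir
    config_ck = CheckpointConfig(save_checkpoint_steps=125, keep_checkpoint_max=10)
    return ModelCheckpoint(prefix=prefix, directory=output_directory, config=config_ck)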