This commit is contained in:
liuzx 2022-08-10 17:03:24 +08:00
parent 382b2863b6
commit 5bf0ca074d
3 changed files with 17 additions and 6 deletions

View File

@ -179,8 +179,12 @@ if __name__ == "__main__":
keep_checkpoint_max=cfg.keep_checkpoint_max)
#Note that this method saves the model file on each card. You need to specify the save path on each card.
# In this example, get_rank() is added to distinguish different paths.
if device_num == 1:
outputDirectory = train_dir + "/"
if device_num > 1:
outputDirectory = train_dir + "/" + str(get_rank()) + "/"
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
directory=train_dir + "/" + str(get_rank()) + "/",
directory=outputDirectory,
config=config_ck)
print("============== Starting Training ==============")
epoch_size = cfg['epoch_size']

View File

@ -81,9 +81,13 @@ if __name__ == "__main__":
keep_checkpoint_max=cfg.keep_checkpoint_max)
#Note that this method saves the model file on each card. You need to specify the save path on each card.
# In the example, get_rank() is added to distinguish different paths.
ckpoint_cb = ModelCheckpoint(prefix="data_parallel",
directory=train_dir + "/" + str(get_rank()) + "/",
config=config_ck)
if device_num == 1:
outputDirectory = train_dir + "/"
if device_num > 1:
outputDirectory = train_dir + "/" + str(get_rank()) + "/"
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
directory=outputDirectory,
config=config_ck)
print("============== Starting Training ==============")
epoch_size = cfg['epoch_size']
if (args.epoch_size):

View File

@ -195,13 +195,16 @@ if __name__ == "__main__":
model = Model(network,net_loss,net_opt,metrics={"accuracy": Accuracy()})
else:
model = Model(network, net_loss,net_opt,metrics={"accuracy": Accuracy()},amp_level="O2")
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
#Note that this method saves the model file on each card. You need to specify the save path on each card.
# In this example, get_rank() is added to distinguish different paths.
if device_num == 1:
outputDirectory = train_dir + "/"
if device_num > 1:
outputDirectory = train_dir + "/" + str(get_rank()) + "/"
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
directory=train_dir + "/" + str(get_rank()) + "/",
directory=outputDirectory,
config=config_ck)
print("============== Starting Training ==============")
epoch_size = cfg['epoch_size']