update
This commit is contained in:
parent
382b2863b6
commit
5bf0ca074d
6
train.py
6
train.py
|
@ -179,8 +179,12 @@ if __name__ == "__main__":
|
|||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
#Note that this method saves the model file on each card. You need to specify the save path on each card.
|
||||
# In this example, get_rank() is added to distinguish different paths.
|
||||
if device_num == 1:
|
||||
outputDirectory = train_dir + "/"
|
||||
if device_num > 1:
|
||||
outputDirectory = train_dir + "/" + str(get_rank()) + "/"
|
||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
|
||||
directory=train_dir + "/" + str(get_rank()) + "/",
|
||||
directory=outputDirectory,
|
||||
config=config_ck)
|
||||
print("============== Starting Training ==============")
|
||||
epoch_size = cfg['epoch_size']
|
||||
|
|
|
@ -81,9 +81,13 @@ if __name__ == "__main__":
|
|||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
#Note that this method saves the model file on each card. You need to specify the save path on each card.
|
||||
# In the example, get_rank() is added to distinguish different paths.
|
||||
ckpoint_cb = ModelCheckpoint(prefix="data_parallel",
|
||||
directory=train_dir + "/" + str(get_rank()) + "/",
|
||||
config=config_ck)
|
||||
if device_num == 1:
|
||||
outputDirectory = train_dir + "/"
|
||||
if device_num > 1:
|
||||
outputDirectory = train_dir + "/" + str(get_rank()) + "/"
|
||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
|
||||
directory=outputDirectory,
|
||||
config=config_ck)
|
||||
print("============== Starting Training ==============")
|
||||
epoch_size = cfg['epoch_size']
|
||||
if (args.epoch_size):
|
||||
|
|
|
@ -195,13 +195,16 @@ if __name__ == "__main__":
|
|||
model = Model(network,net_loss,net_opt,metrics={"accuracy": Accuracy()})
|
||||
else:
|
||||
model = Model(network, net_loss,net_opt,metrics={"accuracy": Accuracy()},amp_level="O2")
|
||||
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
#Note that this method saves the model file on each card. You need to specify the save path on each card.
|
||||
# In this example, get_rank() is added to distinguish different paths.
|
||||
if device_num == 1:
|
||||
outputDirectory = train_dir + "/"
|
||||
if device_num > 1:
|
||||
outputDirectory = train_dir + "/" + str(get_rank()) + "/"
|
||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet",
|
||||
directory=train_dir + "/" + str(get_rank()) + "/",
|
||||
directory=outputDirectory,
|
||||
config=config_ck)
|
||||
print("============== Starting Training ==============")
|
||||
epoch_size = cfg['epoch_size']
|
||||
|
|
Loading…
Reference in New Issue