forked from baladiwei/COVID-19KG
50 lines
1.8 KiB
Python
50 lines
1.8 KiB
Python
|
|
import tensorflow as tf
|
|
|
|
from bert_base.train.bert_lstm_ner import train
|
|
|
|
|
|
flags = tf.flags
|
|
|
|
FLAGS = flags.FLAGS
|
|
|
|
#输入输出地址
|
|
flags.DEFINE_string('data_dir', 'data', '数据集地址')
|
|
flags.DEFINE_string('output_dir', 'output', '输出地址')
|
|
|
|
#Bert相关参数
|
|
flags.DEFINE_string('bert_config_file', 'chinese_L-12_H-768_A-12/bert_config.json', 'Bert配置文件')
|
|
flags.DEFINE_string('vocab_file', 'chinese_L-12_H-768_A-12/vocab.txt','vocab_file')
|
|
flags.DEFINE_string('init_checkpoint','chinese_L-12_H-768_A-12/bert_model.ckpt', 'init_checkpoint')
|
|
|
|
#训练和校验的相关参数
|
|
flags.DEFINE_bool('do_train', False, '是否开始训练')
|
|
flags.DEFINE_bool('do_dev', False, '是否开始校验')
|
|
flags.DEFINE_bool('do_test', True, '是否开始测试')
|
|
|
|
flags.DEFINE_bool('do_lower_case', True, '是否转换小写')
|
|
|
|
#模型相关的
|
|
flags.DEFINE_integer('lstm_size', 128, 'lstm_size')
|
|
flags.DEFINE_integer('num_layers', 1, 'num_layers')
|
|
flags.DEFINE_integer('max_seq_length', 128, 'max_seq_length')
|
|
flags.DEFINE_integer('train_batch_size', 64, 'train_batch_size')
|
|
flags.DEFINE_integer('dev_batch_size',64, 'dev_batch_size')
|
|
flags.DEFINE_integer('test_batch_size', 32, 'test_batch_size')
|
|
flags.DEFINE_integer('save_checkpoints_steps', 500, 'save_checkpoints_steps')
|
|
flags.DEFINE_integer('iterations_per_loop', 500, 'iterations_per_loop')
|
|
flags.DEFINE_integer('save_summary_steps', 500, 'save_summary_steps')
|
|
|
|
flags.DEFINE_string('cell', 'lstm', 'cell')
|
|
|
|
flags.DEFINE_float('learning_rate', 5e-5, 'learning_rate')
|
|
flags.DEFINE_float('dropout_rate', 0.5, 'dropout_rate')
|
|
flags.DEFINE_float('clip', 0.5, 'clip')
|
|
flags.DEFINE_float('num_train_epochs', 10.0, 'num_train_epochs')
|
|
flags.DEFINE_float("warmup_proportion", 0.1,'warmup_proportion')
|
|
|
|
def train_ner():
|
|
train(FLAGS)
|
|
|
|
if __name__ == "__main__":
|
|
train_ner() |