Implementation of a BERT-based text-pair classification model

huaian_zhou 2023-12-14 17:48:06 +08:00
parent bf0728cbd9
commit 747e381472
5 changed files with 180 additions and 15 deletions
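For reference, the text-pair input built in this commit follows the standard BERT encoding: [CLS] sentence1 [SEP] sentence2 [SEP], with token_type_ids distinguishing the two segments. Below is a minimal sketch of the same encoding using the transformers tokenizer ("bert-base-chinese" is an assumed placeholder checkpoint; the repo constructs these inputs itself in LoadPairSenClsDataset):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")  # assumed checkpoint, illustration only
enc = tokenizer("sentence one", "sentence two")
print(enc["input_ids"])       # ids for: [CLS] sentence one [SEP] sentence two [SEP]
print(enc["token_type_ids"])  # 0 for the first segment (incl. [CLS] and the first [SEP]), 1 for the second

Note that the code below keeps batches in [seq_len, batch_size] layout and builds a padding mask in which True marks [PAD] positions, which differs from the transformers attention_mask convention (1 for real tokens, 0 for padding).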

View File

@@ -3,10 +3,15 @@ import sys
sys.path.append(os.getcwd())
import time
import copy
import logging
import torch
from model.BERT.config import BertConfig
from model.downstream.bert_for_sen_cls import BertForSenCls
from utils.data_helpers import LoadPairSenClsDataset
from utils.log_helper import logger_init
from transformers import BertTokenizer, get_scheduler
class ModelConfig:
@@ -43,7 +48,7 @@ class ModelConfig:
self.split_sep = "_!_"
self.is_sample_shuffle = True
self.max_sen_len = None # padding mode
self.eval_per_epoch = 2
self.eval_per_epoch = 1
self.learning_rate = 3.5e-5
# Import part of the BERT model configuration
@@ -57,3 +62,163 @@ class ModelConfig:
logging.info("### 将当前配置打印到日志文件中")
for key, value in self.__dict__.items():
logging.info(f"### {key} = {value}")
def evaluate(model, data_loader, pad_idx, device):
model.eval()
with torch.no_grad():
corrects, total = 0.0, 0
for seqs, segs, labels in data_loader:
seqs, segs, labels = seqs.to(device), segs.to(device), labels.to(device)
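# True marks [PAD] positions; after the transpose the mask shape is [batch_size, seq_len]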
padding_mask = (seqs == pad_idx).transpose(0, 1)
logits = model(seqs, token_type_ids=segs, attention_mask=padding_mask)
corrects += (logits.argmax(1) == labels).float().sum().item()
total += len(labels)
model.train()
return corrects / total
def train(config: ModelConfig):
# 1. Load the dataset and preprocess it
# Reuse the tokenizer from the transformers library
tokenizer = BertTokenizer.from_pretrained(
config.pretrained_model_dir
).tokenize
dataset = LoadPairSenClsDataset(
vocab_path=config.vocab_path,
tokenizer=tokenizer,
batch_size=config.batch_size,
max_sen_len=config.max_sen_len,
split_sep=config.split_sep,
max_position_embeddings=config.max_position_embeddings,
pad_index=config.pad_token_id,
is_sample_shuffle=config.is_sample_shuffle,
)
train_loader, val_loader, test_loader = dataset.data_loader(
config.train_filepath, config.val_filepath, config.test_filepath
)
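# Each loader yields (seqs, segs, labels): token ids and token_type ids shaped [seq_len, batch_size], labels shaped [batch_size]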
# 2. Build the text-pair classification model from the local BERT model files
model = BertForSenCls(config, config.pretrained_model_dir)
# If this is not the first training run, load the existing weights
model_save_file = os.path.join(config.model_save_dir, "pair_sen_cls_model.pt")
if os.path.exists(model_save_file):
loaded_params = torch.load(model_save_file)
model.load_state_dict(loaded_params)
logging.info("## 成功载入已有模型,继续训练......")
model = model.to(config.device)
# 3. Define the optimizer and learning-rate scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
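# With zero warmup, the 'linear' schedule below decays the learning rate linearly from config.learning_rate to 0 over epochs * len(train_loader) steps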
lr_scheduler = get_scheduler(
name="linear",
optimizer=optimizer,
num_warmup_steps=int(len(train_loader) * 0), # no warmup steps
num_training_steps=int(config.epochs * len(train_loader)),
)
# 4. Run training
model.train()
best_eval_acc = 0
for epoch in range(config.epochs):
losses = 0
start_time = time.time()
for idx, (seqs, segs, labels) in enumerate(train_loader):
seqs = seqs.to(config.device) # #[seq_len, batch_size]
segs = segs.to(config.device) # #[seq_len, batch_size]
labels = labels.to(config.device) # #[batch_size,]
# #[batch_size, seq_len]
padding_mask = (seqs == dataset.PAD_IDX).transpose(0, 1)
# feed forward
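# When labels are passed, the model returns (loss, logits); evaluate() calls it without labels and receives only logits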
loss, logits = model(
input_ids=seqs,
position_ids=None,
token_type_ids=segs,
attention_mask=padding_mask,
labels=labels,
)
optimizer.zero_grad() # clear the gradients computed for the previous batch
loss.backward() # backpropagate to compute gradients
optimizer.step() # update the model parameters
lr_scheduler.step() # update the learning rate
# # acc = (logits.argmax(1) == labels).float().mean()
if (idx + 1) % 50 == 0:
logging.info(
f"Epoch: [{epoch + 1}/{config.epochs}], Batch: [{idx + 1}/{len(train_loader)}], "
f"Batch loss: {loss.item():.3f}"
)
losses += loss.item()
end_time = time.time()
train_loss = losses / len(train_loader)
logging.info("=" * 20)
logging.info(
f"Epoch: [{epoch + 1}/{config.epochs}], Train loss: {train_loss:.3f}, Epoch time: {(end_time - start_time):.3f}s"
)
logging.info("=" * 20)
# Evaluate model performance
if (epoch + 1) % config.eval_per_epoch == 0:
eval_acc = evaluate(model, val_loader, dataset.PAD_IDX, config.device)
# Save the weights of the best-performing model
if eval_acc > best_eval_acc:
best_eval_acc = eval_acc
state_dict = copy.deepcopy(model.state_dict())
torch.save(state_dict, model_save_file)
logging.info("=" * 20)
logging.info(
f"Epoch: [{epoch + 1}/{config.epochs}], Eval acc: {eval_acc:.3f}, Best eval acc: {best_eval_acc:.3f}"
)
logging.info("=" * 20)
def inference(config: ModelConfig):
# 1. Load the dataset and preprocess it
dataset = LoadPairSenClsDataset(
vocab_path=config.vocab_path,
tokenizer=BertTokenizer.from_pretrained(config.pretrained_model_dir).tokenize,
batch_size=config.batch_size,
max_sen_len=config.max_sen_len,
split_sep=config.split_sep,
max_position_embeddings=config.max_position_embeddings,
pad_index=config.pad_token_id,
is_sample_shuffle=config.is_sample_shuffle,
)
# Load only the test set (only_test=True returns just the test DataLoader)
test_loader = dataset.data_loader(
test_filepath=config.test_filepath, only_test=True
)
# 2. Build the model and load the saved weights
model = BertForSenCls(config, config.pretrained_model_dir)
model_save_file = os.path.join(config.model_save_dir, "pair_sen_cls_model.pt")
if os.path.exists(model_save_file):
loaded_params = torch.load(model_save_file)
model.load_state_dict(loaded_params)
logging.info("## 成功载入已有模型,执行推理......")
else:
raise FileNotFoundError("Model weight file not found; please run training first ...")
model = model.to(config.device)
# 3. Run inference
test_acc = evaluate(
model, test_loader, pad_idx=dataset.PAD_IDX, device=config.device
)
return test_acc
if __name__ == "__main__":
model_config = ModelConfig()
# Training
train(model_config)
# Inference
infer_ct = 5
for idx in range(infer_ct):
start_time = time.time()
infer_acc = inference(model_config)
end_time = time.time()
print(
f"Infer number: [{idx + 1}/{infer_ct}], Acc: {infer_acc:.3f}, Cost time: {end_time - start_time:.4f}s"
)

View File

@@ -94,7 +94,7 @@ def train(config: ModelConfig):
pad_index=config.pad_token_id,
is_sample_shuffle=config.is_sample_shuffle,
)
train_loader, test_loader, val_loader = dataset.data_loader(
train_loader, val_loader, test_loader = dataset.data_loader(
config.train_filepath, config.val_filepath, config.test_filepath
)

View File

@@ -22,20 +22,20 @@ if __name__ == "__main__":
is_sample_shuffle=model_config.is_sample_shuffle,
)
train_loader, test_loader, val_loader = dataset.data_loader(
train_loader, val_loader, test_loader = dataset.data_loader(
model_config.train_filepath,
model_config.val_filepath,
model_config.test_filepath,
)
for sample, seg, label in train_loader:
print(sample.shape) # #[seq_len, batch_size]
print(sample.transpose(0, 1)) # #[batch_size, seq_len]
print(seg.shape) # #[seq_len, batch_size]
print(label.shape) # #[batch_size,]
print(label)
for seqs, segs, labels in train_loader:
print(seqs.shape) # #[seq_len, batch_size]
print(seqs.transpose(0, 1)) # #[batch_size, seq_len]
print(segs.shape) # #[seq_len, batch_size]
print(labels.shape) # #[batch_size,]
print(labels)
padding_mask = (sample == dataset.PAD_IDX).transpose(0, 1)
padding_mask = (seqs == dataset.PAD_IDX).transpose(0, 1)
print(padding_mask.shape) # #[batch_size, seq_len]
break

View File

@@ -25,7 +25,7 @@ if __name__ == "__main__":
is_sample_shuffle=model_config.is_sample_shuffle,
)
train_loader, test_loader, val_loader = dataset.data_loader(
train_loader, val_loader, test_loader = dataset.data_loader(
model_config.train_filepath,
model_config.val_filepath,
model_config.test_filepath,

View File

@@ -210,7 +210,7 @@ class LoadSenClsDataset:
collate_fn=self.generate_batch,
)
return train_loader, test_loader, val_loader
return train_loader, val_loader, test_loader
def generate_batch(self, data_batch):
"""
@@ -243,7 +243,7 @@ class LoadPairSenClsDataset(LoadSenClsDataset):
super().__init__(**kwargs)
pass
# Override the token_to_idx and generate_batch methods of the subclass
# Override the token_to_idx and generate_batch methods of the parent class LoadSenClsDataset
@cache_decorator
def token_to_idx(self, filepath=None):
"""
@@ -264,14 +264,14 @@ class LoadPairSenClsDataset(LoadSenClsDataset):
# Concatenate the two index sequences into one and add the [CLS] and [SEP] tokens
idx_seq = [self.CLS_IDX] + idx_seq1 + [self.SEP_IDX] + idx_seq2
# BERT supports sequences of at most 512 tokens
# BERT supports sequences of at most 512 tokens; truncate if longer
if len(idx_seq) > self.max_position_embeddings - 1:
idx_seq = idx_seq[: self.max_position_embeddings - 1]
idx_seq += [self.SEP_IDX]
# Create the token_type_id sequence to mark which sequence each token belongs to
seg_seq1 = [0] * (len(idx_seq1) + 2) # the leading [CLS] and the middle [SEP] belong to the first sequence
seg_seq2 = [1] * (len(idx_seq) - len(seg_seq1)) # the trailing [SEP] belongs to the second sequence
idx_seq = torch.tensor(idx_seq, dtype=torch.long)
seg_seq = torch.tensor(seg_seq1 + seg_seq2, dtype=torch.long)
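As an illustration of the segment construction above (token names a1..a3 and b1..b4 are hypothetical), a pair with 3 and 4 tokens is encoded as:

# idx_seq:  [CLS] a1 a2 a3 [SEP] b1 b2 b3 b4 [SEP]
# seg_seq1: [0, 0, 0, 0, 0]                 -> [CLS], a1..a3 and the middle [SEP]
# seg_seq2: [1, 1, 1, 1, 1]                 -> b1..b4 and the trailing [SEP]
# seg_seq:  [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]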