Implement BERT-based text classification

huaian_zhou 2023-12-13 16:29:52 +08:00
parent 756b1cfaad
commit 2db654deed
5 changed files with 241 additions and 4 deletions

@@ -0,0 +1,57 @@
import torch.nn as nn

from ..BERT.bert import BertModel


class BertForSenCls(nn.Module):
    """BERT-based text classification model."""

    def __init__(self, config, pretrained_model_dir=None):
        """
        :param pretrained_model_dir: directory containing the pretrained BERT model files
        """
        super().__init__()
        # If a pretrained model directory is given, build BERT from it;
        # otherwise build a new BERT with randomly initialized weights.
        if pretrained_model_dir is not None:
            self.bert = BertModel.from_pretrained(config, pretrained_model_dir)
        else:
            self.bert = BertModel(config)
        self.num_labels = config.num_labels  # number of target classes
        # Classification head added on top of BERT
        self.classifier = nn.Sequential(
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, self.num_labels),
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
    ):
        """
        :param input_ids: `#[src_len, batch_size]`
        :param position_ids: `#[1, src_len]`
        :param token_type_ids: `#[src_len, batch_size]`; in sentence classification
            all input tokens belong to the same sequence, so this is left as None
        :param attention_mask: padding mask, `#[batch_size, src_len]`
        :param labels: ground-truth labels of the sentences, `#[batch_size,]`
        """
        # Use the [CLS] token's embedding (or the mean of all token embeddings)
        # as the semantic representation of the whole sequence:
        # #[batch_size, hidden_size]
        pooled_out, _ = self.bert(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        logits = self.classifier(pooled_out)  # #[batch_size, num_labels]
        # Compute the loss only if ground-truth labels are provided
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()  # cross-entropy loss
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss, logits
        else:
            return logits
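
A quick standalone sketch (not part of this commit) of the shape contract behind the loss call above: `nn.CrossEntropyLoss` expects raw logits of shape `[N, num_labels]` and integer class indices of shape `[N]`, which is why the forward pass flattens both with `view(-1, ...)`:

import torch
import torch.nn as nn

logits = torch.randn(2, 10)    # #[batch_size=2, num_labels=10], raw scores
labels = torch.tensor([3, 7])  # #[batch_size], integer class indices
loss = nn.CrossEntropyLoss()(logits.view(-1, 10), labels.view(-1))
print(loss.item())             # scalar; roughly ln(10) ≈ 2.3 for random logits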

@@ -3,10 +3,15 @@ import sys
sys.path.append(os.getcwd())
import time
import copy
import torch
import logging

from model.BERT.config import BertConfig
from utils.log_helper import logger_init
from transformers import BertTokenizer
from utils.data_helpers import LoadSenClsDataset


class ModelConfig:
@@ -56,3 +61,150 @@ class ModelConfig:
logging.info("### 将当前配置打印到日志文件中")
for key, value in self.__dict__.items():
logging.info(f"### {key} = {value}")
def evaluate(model, data_loader, pad_idx, device="cpu"):
model.eval()
with torch.no_grad():
corrects, total = 0.0, 0
for seqs, labels in data_loader:
seqs, labels = seqs.to(device), labels.to(device)
padding_mask = (seqs == pad_idx).transpose(0, 1)
logits = model(seqs, attention_mask=padding_mask)
corrects += (logits.argmax(1) == labels).float().sum().item()
total += len(labels)
model.train()
return corrects / total
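
The padding mask built here follows PyTorch's `key_padding_mask` convention: `True` marks positions to ignore. A tiny sketch (assumed toy shapes, pad index 0) of what `(seqs == pad_idx).transpose(0, 1)` produces:

import torch

pad_idx = 0
seqs = torch.tensor([[5, 9],
                     [2, 0],
                     [0, 0]])                 # #[src_len=3, batch_size=2]
mask = (seqs == pad_idx).transpose(0, 1)      # #[batch_size=2, src_len=3]
print(mask)  # tensor([[False, False,  True],
             #         [False,  True,  True]])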

def train(config: ModelConfig):
    """Training procedure."""
    # 1. Load and preprocess the dataset,
    #    borrowing the tokenizer from the transformers library
    tokenizer = BertTokenizer.from_pretrained(config.pretrained_model_dir).tokenize
    dataset = LoadSenClsDataset(
        vocab_path=config.vocab_path,
        tokenizer=tokenizer,
        batch_size=config.batch_size,
        max_sen_len=config.max_sen_len,
        split_sep=config.split_sep,
        max_position_embeddings=config.max_position_embeddings,
        pad_index=config.pad_token_id,
        is_sample_shuffle=config.is_sample_shuffle,
    )
    train_loader, test_loader, val_loader = dataset.data_loader(
        config.train_filepath, config.val_filepath, config.test_filepath
    )
    # 2. Build the classification model from the local BERT checkpoint
    model = BertForSenCls(config, config.pretrained_model_dir)
    # If this is not the first training run, load the existing weights
    model_save_path = os.path.join(config.model_save_dir, "sen_cls_model.pt")
    if os.path.exists(model_save_path):
        loaded_params = torch.load(model_save_path)
        model.load_state_dict(loaded_params)
        logging.info("## Existing model loaded successfully, resuming training......")
    model = model.to(config.device)
    # 3. Define the optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
    # 4. Run the training loop
    model.train()
    best_eval_acc = 0.0
    for epoch in range(config.epochs):
        losses = 0.0
        start_time = time.time()
        for idx, (seqs, labels) in enumerate(train_loader):
            seqs = seqs.to(config.device)  # #[src_len, batch_size]
            labels = labels.to(config.device)
            # True marks padded positions: #[batch_size, src_len]
            padding_mask = (seqs == dataset.PAD_IDX).transpose(0, 1)
            # Forward pass
            loss, logits = model(
                input_ids=seqs,
                position_ids=None,
                token_type_ids=None,  # the input is a single text sequence
                attention_mask=padding_mask,
                labels=labels,
            )
            optimizer.zero_grad()  # clear gradients from the previous batch
            loss.backward()  # backpropagate to compute gradients
            optimizer.step()  # update the weights
            # # acc = (logits.argmax(1) == labels).float().mean()
            if (idx + 1) % 50 == 0:
                logging.info(
                    f"Epoch: [{epoch + 1}/{config.epochs}], Batch: [{idx + 1}/{len(train_loader)}], "
                    f"Batch loss: {loss.item():.3f}"
                )
            losses += loss.item()
        end_time = time.time()
        train_loss = losses / len(train_loader)
        logging.info(
            f"Epoch: [{epoch + 1}/{config.epochs}], Train loss: {train_loss:.3f}, "
            f"Epoch time: {(end_time - start_time):.3f}s"
        )
        # Evaluate model performance periodically
        if (epoch + 1) % config.eval_per_epoch == 0:
            eval_acc = evaluate(model, val_loader, dataset.PAD_IDX, config.device)
            # Keep the weights of the best-performing model
            if eval_acc > best_eval_acc:
                best_eval_acc = eval_acc
                state_dict = copy.deepcopy(model.state_dict())
                torch.save(state_dict, model_save_path)
            logging.info(
                f"Epoch: [{epoch + 1}/{config.epochs}], Eval acc: {eval_acc:.3f}, "
                f"Best eval acc: {best_eval_acc:.3f}"
            )

def inference(config: ModelConfig):
    """Inference procedure."""
    # 1. Load and preprocess the dataset
    tokenizer = BertTokenizer.from_pretrained(config.pretrained_model_dir).tokenize
    dataset = LoadSenClsDataset(
        vocab_path=config.vocab_path,
        tokenizer=tokenizer,
        batch_size=config.batch_size,
        max_sen_len=config.max_sen_len,
        split_sep=config.split_sep,
        max_position_embeddings=config.max_position_embeddings,
        pad_index=config.pad_token_id,
        is_sample_shuffle=config.is_sample_shuffle,
    )
    # Only the test set is needed here
    test_loader = dataset.data_loader(
        test_filepath=config.test_filepath, only_test=True
    )
    # 2. Build the model and load the trained weights
    model = BertForSenCls(config, config.pretrained_model_dir)
    model_save_path = os.path.join(config.model_save_dir, "sen_cls_model.pt")
    if os.path.exists(model_save_path):
        loaded_params = torch.load(model_save_path)
        model.load_state_dict(loaded_params)
        print("## Existing model loaded successfully, running inference......")
    else:
        raise FileNotFoundError("Model weights not found; please train the model first......")
    model = model.to(config.device)
    test_acc = evaluate(model, test_loader, dataset.PAD_IDX, config.device)
    return test_acc

if __name__ == "__main__":
    model_config = ModelConfig()
    # Training
    train(model_config)
    # Inference
    infer_ct = 5
    for idx in range(infer_ct):
        start_time = time.time()
        infer_acc = inference(model_config)
        end_time = time.time()
        print(
            f"Infer number: [{idx + 1}/{infer_ct}], Acc: {infer_acc:.3f}, "
            f"Cost time: {end_time - start_time:.4f}s"
        )
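
For reference, `BertTokenizer.from_pretrained(...).tokenize` hands the dataset a plain text-to-WordPiece function. A quick sketch, using the public `bert-base-chinese` checkpoint as a stand-in for the local `config.pretrained_model_dir`:

from transformers import BertTokenizer

# Any directory or hub id with a BERT vocab works here
tokenize = BertTokenizer.from_pretrained("bert-base-chinese").tokenize
print(tokenize("今天天气不错"))  # e.g. ['今', '天', '天', '气', '不', '错']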

@@ -0,0 +1,28 @@
import os
import sys

sys.path.append(os.getcwd())
import torch

from model.downstream.bert_for_sen_cls import BertForSenCls
from model.BERT.config import BertConfig

if __name__ == "__main__":
    json_file = "./archive/bert_base_chinese/config.json"
    config = BertConfig.from_json_file(json_file)
    config.__dict__["num_labels"] = 10
    # # config.__dict__["num_hidden_layers"] = 3
    model = BertForSenCls(config)
    # #[src_len, batch_size] [6, 2]
    input_ids = torch.tensor(
        [[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]], dtype=torch.long
    ).transpose(0, 1)
    # Padding mask, True marks padded positions: #[batch_size, src_len] [2, 6]
    attention_mask = torch.tensor(
        [
            [False, False, False, False, False, True],
            [False, False, False, True, True, True],
        ]
    )
    logits = model(input_ids=input_ids, attention_mask=attention_mask)
    print(logits.shape)  # #[batch_size, num_labels] [2, 10]

The change below moves the test-set loading ahead of the `only_test` early return, so a test-only run no longer touches the training file.

@@ -181,10 +181,6 @@ class LoadSenClsDataset:
        Create the DataLoader(s)
        :param only_test: whether to return only the test-set loader
        """
-       train_data, max_len = self.token_to_idx(filepath=train_filepath)
-       if self.max_sen_len == "same":
-           self.max_sen_len = max_len
        test_data, _ = self.token_to_idx(filepath=test_filepath)
        test_loader = DataLoader(
            test_data,
@@ -195,6 +191,10 @@
        if only_test:
            return test_loader
+       train_data, max_len = self.token_to_idx(filepath=train_filepath)
+       if self.max_sen_len == "same":
+           self.max_sen_len = max_len
        train_loader = DataLoader(
            train_data,
            batch_size=self.batch_size,