Implement a BERT-based text classification task
parent 756b1cfaad
commit 2db654deed
model/downstream/bert_for_sen_cls.py
@@ -0,0 +1,57 @@
import torch.nn as nn

from ..BERT.bert import BertModel


class BertForSenCls(nn.Module):
    """BERT-based text classification model."""

    def __init__(self, config, pretrained_model_dir=None):
        """
        :param pretrained_model_dir: directory containing the pretrained BERT model files
        """
        super().__init__()
        # If a pretrained model directory is given, build BERT from those weights;
        # otherwise create a new BERT model with randomly initialized weights
        if pretrained_model_dir is not None:
            self.bert = BertModel.from_pretrained(config, pretrained_model_dir)
        else:
            self.bert = BertModel(config)

        self.num_labels = config.num_labels  # number of classes
        # classification head added on top of BERT
        self.classifier = nn.Sequential(
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, self.num_labels),
        )

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
    ):
        """
        :param input_ids: `#[src_len, batch_size]`
        :param position_ids: `#[1, src_len]`
        :param token_type_ids: `#[src_len, batch_size]`; in single-sentence
            classification all input tokens belong to one sequence, so this is left as None
        :param attention_mask: padding mask, `#[batch_size, src_len]`
        :param labels: ground-truth labels of the sentences, `#[batch_size,]`
        """
        # Take the embedding of the [CLS] token (or the mean of all token embeddings)
        # as the representation of the whole sequence
        # #[batch_size, hidden_size]
        pooled_out, _ = self.bert(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        logits = self.classifier(pooled_out)  # #[batch_size, num_labels]

        # If ground-truth labels were passed in, also compute the loss
        if labels is not None:
            loss_fc = nn.CrossEntropyLoss()  # cross-entropy loss
            loss = loss_fc(logits.view(-1, self.num_labels), labels.view(-1))
            return loss, logits
        else:
            return logits
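
A note on the loss computation above: `nn.CrossEntropyLoss` expects logits of shape `[N, num_labels]` and integer class indices of shape `[N]`, which is what the `.view(-1, ...)` calls guarantee. A minimal standalone sketch with made-up shapes:

import torch
import torch.nn as nn

loss_fc = nn.CrossEntropyLoss()
logits = torch.randn(4, 10)          # #[batch_size, num_labels]
labels = torch.tensor([1, 0, 9, 3])  # #[batch_size,] class indices
print(loss_fc(logits, labels))       # scalar tensor, averaged over the batch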
@@ -3,10 +3,15 @@ import sys

sys.path.append(os.getcwd())

import time
import copy
import torch
import logging
from model.BERT.config import BertConfig
from utils.log_helper import logger_init
from transformers import BertTokenizer
from utils.data_helpers import LoadSenClsDataset
from model.downstream.bert_for_sen_cls import BertForSenCls


class ModelConfig:
@@ -56,3 +61,150 @@ class ModelConfig:
        logging.info("### Printing the current configuration to the log file")
        for key, value in self.__dict__.items():
            logging.info(f"### {key} = {value}")


def evaluate(model, data_loader, pad_idx, device="cpu"):
    model.eval()
    with torch.no_grad():
        corrects, total = 0.0, 0
        for seqs, labels in data_loader:
            seqs, labels = seqs.to(device), labels.to(device)
            padding_mask = (seqs == pad_idx).transpose(0, 1)
            logits = model(seqs, attention_mask=padding_mask)
            corrects += (logits.argmax(1) == labels).float().sum().item()
            total += len(labels)

    model.train()
    return corrects / total
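
The padding mask in `evaluate` compares every token id against the pad index and transposes the result from `#[src_len, batch_size]` to `#[batch_size, src_len]`, so `True` marks a padded position. A small sketch, assuming a hypothetical pad id of 0:

import torch

pad_idx = 0  # hypothetical pad token id
seqs = torch.tensor([[5, 7],
                     [6, 0],
                     [0, 0]])                     # #[src_len, batch_size] = [3, 2]
padding_mask = (seqs == pad_idx).transpose(0, 1)  # #[batch_size, src_len] = [2, 3]
print(padding_mask)
# tensor([[False, False,  True],
#         [False,  True,  True]])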
def train(config: ModelConfig):
    """Training procedure"""

    # 1. Load and preprocess the dataset
    # Reuse the tokenizer from the transformers library
    tokenizer = BertTokenizer.from_pretrained(config.pretrained_model_dir).tokenize
    dataset = LoadSenClsDataset(
        vocab_path=config.vocab_path,
        tokenizer=tokenizer,
        batch_size=config.batch_size,
        max_sen_len=config.max_sen_len,
        split_sep=config.split_sep,
        max_position_embeddings=config.max_position_embeddings,
        pad_index=config.pad_token_id,
        is_sample_shuffle=config.is_sample_shuffle,
    )
    train_loader, test_loader, val_loader = dataset.data_loader(
        config.train_filepath, config.val_filepath, config.test_filepath
    )

    # 2. Build the classification model from the local BERT checkpoint
    model = BertForSenCls(config, config.pretrained_model_dir)
    # If this is not the first training run, load the existing weights
    model_save_path = os.path.join(config.model_save_dir, "sen_cls_model.pt")
    if os.path.exists(model_save_path):
        loaded_params = torch.load(model_save_path)
        model.load_state_dict(loaded_params)
        logging.info("## Successfully loaded an existing model, resuming training......")
    model = model.to(config.device)

    # 3. Define the optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

    # 4. Run the training loop
    model.train()
    best_eval_acc = 0.0
    for epoch in range(config.epochs):
        losses = 0.0
        start_time = time.time()
        for idx, (seqs, labels) in enumerate(train_loader):
            seqs = seqs.to(config.device)  # #[src_len, batch_size]
            labels = labels.to(config.device)
            # #[batch_size, src_len]
            padding_mask = (seqs == dataset.PAD_IDX).transpose(0, 1)
            # feed forward
            loss, logits = model(
                input_ids=seqs,
                position_ids=None,
                token_type_ids=None,  # the input is a single text sequence
                attention_mask=padding_mask,
                labels=labels,
            )
            optimizer.zero_grad()  # clear the gradients computed for the previous batch
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()  # update the weights

            # acc = (logits.argmax(1) == labels).float().mean()
            if (idx + 1) % 50 == 0:
                logging.info(
                    f"Epoch: [{epoch + 1}/{config.epochs}], Batch: [{idx + 1}/{len(train_loader)}], "
                    f"Batch loss: {loss.item():.3f}"
                )
            losses += loss.item()

        end_time = time.time()
        train_loss = losses / len(train_loader)
        logging.info(
            f"Epoch: [{epoch + 1}/{config.epochs}], Train loss: {train_loss:.3f}, Epoch time: {(end_time - start_time):.3f}s"
        )
        # Evaluate model performance
        if (epoch + 1) % config.eval_per_epoch == 0:
            eval_acc = evaluate(model, val_loader, dataset.PAD_IDX, config.device)
            # Keep the weights of the best-performing model
            if eval_acc > best_eval_acc:
                best_eval_acc = eval_acc
                state_dict = copy.deepcopy(model.state_dict())
                torch.save(state_dict, model_save_path)
            logging.info(
                f"Epoch: [{epoch + 1}/{config.epochs}], Eval acc: {eval_acc:.3f}, Best eval acc: {best_eval_acc:.3f}"
            )
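
Note the `copy.deepcopy(model.state_dict())` when snapshotting the best weights: `state_dict()` returns references to the live parameter tensors, so a plain snapshot would silently drift as training continues. A small sketch of the aliasing:

import copy
import torch
import torch.nn as nn

m = nn.Linear(2, 2)
snap_ref = m.state_dict()                  # references the live tensors
snap_copy = copy.deepcopy(m.state_dict())  # independent copy
with torch.no_grad():
    m.weight.add_(1.0)                     # simulate a later weight update
print(torch.equal(snap_ref["weight"], m.weight))   # True: the snapshot drifted
print(torch.equal(snap_copy["weight"], m.weight))  # False: the copy kept the old values

Since `torch.save` serializes the values at call time, the deep copy mainly matters when the snapshot is held in memory rather than written out immediately.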
def inference(config: ModelConfig):
    """Inference procedure"""

    # 1. Load and preprocess the dataset
    tokenizer = BertTokenizer.from_pretrained(config.pretrained_model_dir).tokenize
    dataset = LoadSenClsDataset(
        vocab_path=config.vocab_path,
        tokenizer=tokenizer,
        batch_size=config.batch_size,
        max_sen_len=config.max_sen_len,
        split_sep=config.split_sep,
        max_position_embeddings=config.max_position_embeddings,
        pad_index=config.pad_token_id,
        is_sample_shuffle=config.is_sample_shuffle,
    )
    # Only the test split is needed
    test_loader = dataset.data_loader(
        test_filepath=config.test_filepath, only_test=True
    )

    # 2. Build the model and load the trained weights
    model = BertForSenCls(config, config.pretrained_model_dir)
    model_save_path = os.path.join(config.model_save_dir, "sen_cls_model.pt")
    if os.path.exists(model_save_path):
        loaded_params = torch.load(model_save_path)
        model.load_state_dict(loaded_params)
        print("## Successfully loaded an existing model, running inference......")
    else:
        raise FileNotFoundError("Model weights not found, please run training first......")
    model = model.to(config.device)

    test_acc = evaluate(model, test_loader, dataset.PAD_IDX, config.device)
    return test_acc
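
One portability caveat in both `train` and `inference`: `torch.load(model_save_path)` restores tensors to the device they were saved from, so a checkpoint written on a GPU machine fails to load on a CPU-only one. If that matters, the usual fix is to map the tensors at load time, e.g.:

loaded_params = torch.load(model_save_path, map_location=config.device)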
if __name__ == "__main__":
    model_config = ModelConfig()
    # training
    train(model_config)
    # inference
    infer_ct = 5
    for idx in range(infer_ct):
        start_time = time.time()
        infer_acc = inference(model_config)
        end_time = time.time()
        print(
            f"Infer number: [{idx + 1}/{infer_ct}], Acc: {infer_acc:.3f}, Cost time: {end_time - start_time:.4f}s"
        )
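
Both `train` and `inference` pass `BertTokenizer.from_pretrained(...).tokenize` around as a plain callable that splits raw text into WordPiece tokens; converting tokens to ids is handled separately by `LoadSenClsDataset` with its own vocab file. A quick sketch of what that callable returns, assuming the Hugging Face bert-base-chinese checkpoint (which tokenizes Chinese character by character):

from transformers import BertTokenizer

tokenize = BertTokenizer.from_pretrained("bert-base-chinese").tokenize
print(tokenize("今天天气不错"))  # ['今', '天', '天', '气', '不', '错']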
@@ -0,0 +1,28 @@
import os
import sys

sys.path.append(os.getcwd())
import torch
from model.downstream.bert_for_sen_cls import BertForSenCls
from model.BERT.config import BertConfig

if __name__ == "__main__":
    json_file = "./archive/bert_base_chinese/config.json"
    config = BertConfig.from_json_file(json_file)
    config.__dict__["num_labels"] = 10
    # config.__dict__["num_hidden_layers"] = 3
    model = BertForSenCls(config)

    # #[src_len, batch_size] [6, 2]
    input_ids = torch.tensor(
        [[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]], dtype=torch.long
    ).transpose(0, 1)
    # #[batch_size, src_len] [2, 6]
    attention_mask = torch.tensor(
        [
            [False, False, False, False, False, True],
            [False, False, False, True, True, True],
        ]
    )
    logits = model(input_ids=input_ids, attention_mask=attention_mask)
    print(logits.shape)  # #[batch_size, num_labels] [2, 10]
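
Note the mask convention exercised by this test: `attention_mask` uses `True` for padded positions and `False` for real tokens, consistent with the `(seqs == pad_idx)` masks built in the training script but opposite to the Hugging Face convention (1 = attend, 0 = pad). If a mask ever needs to cross between the two conventions, a boolean flip converts it:

hf_style_mask = (~attention_mask).long()  # 1 = real token, 0 = padding (HF convention)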
utils/data_helpers.py
@@ -181,10 +181,6 @@ class LoadSenClsDataset:
         Create the DataLoader
         :param only_test: whether to return only the test set
         """
-        train_data, max_len = self.token_to_idx(filepath=train_filepath)
-        if self.max_sen_len == "same":
-            self.max_sen_len = max_len
-
         test_data, _ = self.token_to_idx(filepath=test_filepath)
         test_loader = DataLoader(
             test_data,

@@ -195,6 +191,10 @@ class LoadSenClsDataset:
         if only_test:
             return test_loader

+        train_data, max_len = self.token_to_idx(filepath=train_filepath)
+        if self.max_sen_len == "same":
+            self.max_sen_len = max_len
+
         train_loader = DataLoader(
             train_data,
             batch_size=self.batch_size,
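
The reordering in `LoadSenClsDataset.data_loader` moves tokenization of the training split to after the `only_test` early return, so inference no longer pays to preprocess training data it never uses. The call site is unchanged; the inference path simply remains:

# inference path: only the test split is tokenized and batched
test_loader = dataset.data_loader(test_filepath=config.test_filepath, only_test=True)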