Dataset preprocessing for the BERT-based text classification model
parent c553948cc5
commit 756b1cfaad
@@ -0,0 +1,43 @@
## Dataset Download

https://github.com/aceimnorstuvwxz/toutiao-text-classfication-dataset

## Data Format

```text
6552431613437805063_!_102_!_news_entertainment_!_谢娜为李浩菲澄清网络谣言,之后她的两个行为给自己加分_!_佟丽娅,网络谣言,快乐大本营,李浩菲,谢娜,观众们
```

One sample per line, with the fields separated by `_!_`. From left to right, the fields are: news ID, category code (see below), category name (see below), news title text, and news keywords.
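For illustration, one raw line can be parsed with a few lines of Python (a minimal sketch, assuming every line carries all five `_!_`-separated fields; this helper is not part of the repository):

```python
def parse_raw_line(line: str) -> dict:
    """Split one raw toutiao_cat_data.txt line into its five fields (illustrative only)."""
    news_id, category_code, category_name, title, keywords = line.rstrip("\n").split("_!_")
    return {
        "news_id": news_id,               # e.g. "6552431613437805063"
        "category_code": category_code,   # e.g. "102"
        "category_name": category_name,   # e.g. "news_entertainment"
        "title": title,                   # the news title text
        "keywords": keywords,             # comma-separated keyword string
    }
```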
## Categories and Names

```text
100 民生 故事 news_story
101 文化 文化 news_culture
102 娱乐 娱乐 news_entertainment
103 体育 体育 news_sports
104 财经 财经 news_finance
106 房产 房产 news_house
107 汽车 汽车 news_car
108 教育 教育 news_edu
109 科技 科技 news_tech
110 军事 军事 news_military
112 旅游 旅游 news_travel
113 国际 国际 news_world
114 证券 股票 stock
115 农业 三农 news_agriculture
116 电竞 游戏 news_game
```

## Dataset Size

382,688 samples in total, across 15 categories.

## Data Preprocessing

After downloading the raw dataset, run the `format.py` script in this folder to split the raw data at a `7:2:1` ratio into a training set `toutiao_train.txt`, a validation set `toutiao_val.txt`, and a test set `toutiao_test.txt`.

The processed data looks like this:

```text
轻松一刻:带你看全球最噩梦监狱,每天进几百人,审讯时已过几年_!_11
千万不要乱申请网贷,否则后果很严重_!_4
10年前的今年,纪念5.12汶川大地震10周年_!_11
怎么看待杨毅在一NBA直播比赛中说詹姆斯的球场统治力已经超过乔丹、伯德和科比?_!_3
戴安娜王妃的车祸有什么谜团?_!_2
```
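For a quick sanity check, the expected split sizes follow directly from the 7:2:1 ratio and the 382,688 samples (a back-of-the-envelope calculation, assuming the raw file is read in full and split exactly as `format.py` below does):

```python
num_samples = 382688
num_train = int(0.7 * num_samples)              # 267881
num_val = int(0.2 * num_samples)                # 76537
num_test = num_samples - num_train - num_val    # 38270
```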
@@ -0,0 +1,60 @@
import numpy as np

# Map the raw category codes to consecutive class labels 0-14
label_map = {
    "100": "0",
    "101": "1",
    "102": "2",
    "103": "3",
    "104": "4",
    "106": "5",
    "107": "6",
    "108": "7",
    "109": "8",
    "110": "9",
    "112": "10",
    "113": "11",
    "114": "12",
    "115": "13",
    "116": "14",
}


def format(filepath="./toutiao_cat_data.txt"):
    np.random.seed(42)
    raw_data = open(filepath, "r", encoding="utf-8").readlines()
    num_samples = len(raw_data)
    num_train, num_val = int(0.7 * num_samples), int(0.2 * num_samples)
    num_test = num_samples - num_train - num_val
    # Generate a random permutation and the sample indices of the three splits
    idx = np.random.permutation(num_samples)
    train_idx, val_idx, test_idx = (
        idx[:num_train],
        idx[num_train : num_train + num_val],
        idx[-num_test:],
    )
    # Write the three splits to disk
    f_train = open("./toutiao_train.txt", "w", encoding="utf-8")
    f_val = open("./toutiao_val.txt", "w", encoding="utf-8")
    f_test = open("./toutiao_test.txt", "w", encoding="utf-8")

    for i in train_idx:
        line = raw_data[i].strip("\n").split("_!_")
        label, text = label_map[line[1]], line[3]
        f_train.write(text + "_!_" + label + "\n")
    f_train.close()

    for i in val_idx:
        line = raw_data[i].strip("\n").split("_!_")
        label, text = label_map[line[1]], line[3]
        f_val.write(text + "_!_" + label + "\n")
    f_val.close()

    for i in test_idx:
        line = raw_data[i].strip("\n").split("_!_")
        label, text = label_map[line[1]], line[3]
        f_test.write(text + "_!_" + label + "\n")
    f_test.close()


if __name__ == "__main__":
    format()
@@ -0,0 +1,58 @@
import os
import sys

sys.path.append(os.getcwd())

import torch
import logging
from model.BERT.config import BertConfig
from utils.log_helper import logger_init


class ModelConfig:
    """Configuration class for the BERT-based text classification model"""

    def __init__(self):
        self.project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.dataset_dir = os.path.join(
            self.project_dir, "data", "sentence_classification"
        )
        self.pretrained_model_dir = os.path.join(
            self.project_dir, "archive", "bert_base_chinese"
        )
        self.vocab_path = os.path.join(self.pretrained_model_dir, "vocab.txt")
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.train_filepath = os.path.join(self.dataset_dir, "toutiao_train.txt")
        self.val_filepath = os.path.join(self.dataset_dir, "toutiao_val.txt")
        self.test_filepath = os.path.join(self.dataset_dir, "toutiao_test.txt")

        self.model_save_dir = os.path.join(self.project_dir, "cache")
        if not os.path.exists(self.model_save_dir):
            os.makedirs(self.model_save_dir)
        self.log_save_dir = os.path.join(self.project_dir, "logs")

        self.epochs = 10
        self.batch_size = 64
        self.num_labels = 15
        self.split_sep = "_!_"
        self.is_sample_shuffle = True  # whether to shuffle the dataset
        self.max_sen_len = None  # padding mode (see LoadSenClsDataset)
        self.eval_per_epoch = 2  # evaluate the model every this many epochs

        logger_init(
            log_filename="sen_cls",
            log_level=logging.INFO,
            log_dir=self.log_save_dir,
        )

        # Import part of the BERT model configuration
        bert_config_path = os.path.join(self.pretrained_model_dir, "config.json")
        bert_config = BertConfig.from_json_file(bert_config_path)
        for key, value in bert_config.__dict__.items():
            self.__dict__[key] = value

        # Print the current configuration to the log file
        logging.info("=" * 20)
        logging.info("### Printing the current configuration to the log file")
        for key, value in self.__dict__.items():
            logging.info(f"### {key} = {value}")
@@ -4,7 +4,7 @@ import sys
 sys.path.append(os.getcwd())
 import logging
 import torch
-from util.log_helper import logger_init
+from utils.log_helper import logger_init
 from model.BERT.config import BertConfig
 from model.BERT.bert import BertEmbedding
 from model.BERT.bert import BertAttention
@@ -14,7 +14,7 @@ from model.BERT.bert import BertModel
 if __name__ == "__main__":
-    logger_init(log_filename="test", log_level=logging.DEBUG)
+    logger_init(log_filename="test_bert", log_level=logging.DEBUG)
 
     json_file = "./archive/bert_base_chinese/config.json"
     config = BertConfig.from_json_file(json_file)
@@ -0,0 +1,43 @@
import os
import sys

sys.path.append(os.getcwd())

from tasks.sentence_classification import ModelConfig
from utils.data_helpers import LoadSenClsDataset
from transformers import BertTokenizer  # borrow the tokenizer from the transformers library

# Load and preprocess the dataset: 1. tokenize (with BertTokenizer); 2. build the vocabulary (read the existing vocab.txt file);
# 3. convert tokens to index sequences and add the [CLS] and [SEP] tokens;
# 4. pad the sequences; 5. build the DataLoaders

if __name__ == "__main__":
    model_config = ModelConfig()
    dataset = LoadSenClsDataset(
        vocab_path=model_config.vocab_path,
        tokenizer=BertTokenizer.from_pretrained(
            model_config.pretrained_model_dir
        ).tokenize,
        batch_size=model_config.batch_size,
        max_sen_len=model_config.max_sen_len,
        split_sep=model_config.split_sep,
        max_position_embeddings=model_config.max_position_embeddings,
        pad_index=model_config.pad_token_id,
        is_sample_shuffle=model_config.is_sample_shuffle,
    )

    train_loader, test_loader, val_loader = dataset.data_loader(
        model_config.train_filepath,
        model_config.val_filepath,
        model_config.test_filepath,
    )

    for sample, label in train_loader:
        print(sample.shape)  # [seq_len, batch_size]
        print(sample.transpose(0, 1))
        print(label)

        # [batch_size, seq_len]
        padding_mask = (sample == dataset.PAD_IDX).transpose(0, 1)
        print(padding_mask)

        break
@@ -1,43 +0,0 @@
import os
import sys
import logging
from datetime import datetime


def logger_init(
    log_filename="monitor", log_level=logging.DEBUG, log_dir="./log/", only_file=False
):
    """
    Args:
        log_filename: name of the log file.
        log_level: logging level.
        log_dir: log directory.
        only_file: whether to write only to the log file.
    """
    # Build the log file path
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    log_filepath = os.path.join(
        log_dir, log_filename + "_" + str(datetime.now())[:10] + ".txt"
    )
    # Set the log format
    formatter = "[%(asctime)s] - %(levelname)s: %(message)s"
    # Write only to the log file
    if only_file:
        logging.basicConfig(
            filename=log_filepath,
            level=log_level,
            format=formatter,
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    # Write to the log file and also to the terminal
    else:
        logging.basicConfig(
            level=log_level,
            format=formatter,
            datefmt="%Y-%m-%d %H:%M:%S",
            handlers=[
                logging.FileHandler(log_filepath),
                logging.StreamHandler(sys.stdout),
            ],
        )
@@ -0,0 +1,236 @@
import os
import time
import logging
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm


class Vocab:
    """
    Build the vocabulary from the local vocab file -- when fine-tuning a pretrained
    model there is no need to build a vocabulary from the training data.
    """

    UNK = "[UNK]"

    def __init__(self, vocab_path):
        self.stoi = {}  # dict mapping each token to its index
        self.itos = []  # list of all tokens in the vocabulary
        with open(vocab_path, "r", encoding="utf-8") as f:
            for idx, token in enumerate(f):
                token = token.strip("\n")
                self.stoi[token] = idx
                self.itos.append(token)

    def __getitem__(self, token):
        """
        Return the index of a token, supporting vocab[token]-style access.
        """
        return self.stoi.get(token, self.stoi.get(Vocab.UNK))

    def __len__(self):
        """
        Return the vocabulary size, supporting len(vocab)-style access.
        """
        return len(self.itos)


def pad_sequence(sequences, padding_value=0, max_len=None, batch_first=False):
    """
    Pad a batch of sequences.
    :param sequences: a batch of sequences
    :param padding_value: value used for padding
    :param max_len: target length to pad to; if `None`, pad to the longest sequence in
        the batch; if an `int`, pad to that length and truncate anything longer
    :param batch_first: whether batch_size is the first dimension of the returned tensor
    """
    if max_len is None:
        max_len = max([seq.size(0) for seq in sequences])

    out_tensor = []
    # Compare each sequence with max_len, then pad or truncate it
    for seq in sequences:
        if seq.size(0) < max_len:
            seq = torch.cat(
                [seq, torch.tensor([padding_value] * (max_len - seq.size(0)))],
                dim=0,
            )
        else:
            seq = seq[:max_len]

        out_tensor.append(seq)
    out_tensor = torch.stack(out_tensor, dim=1)
    # Put batch_size as the first dimension
    if batch_first:
        out_tensor = out_tensor.transpose(0, 1)

    return out_tensor


def cache_decorator(func):
    """
    Decorator that caches the result of converting tokens to indices.
    """

    def wrapper(*args, **kwargs):
        filepath = kwargs["filepath"]  # path of the data file
        filename = "".join(filepath.split(os.sep)[-1].split(".")[:-1])  # file name without the extension
        filedir = f"{os.sep}".join(filepath.split(os.sep)[:-1])  # directory containing the file

        cache_filename = f"cache_{filename}_token2idx.pt"
        cache_path = os.path.join(filedir, cache_filename)

        start_time = time.time()
        if not os.path.exists(cache_path):
            logging.info(f"Cache file {cache_path} does not exist; processing the dataset and caching it!")
            data = func(*args, **kwargs)  # convert tokens to indices
            with open(cache_path, "wb") as f:
                torch.save(data, f)  # cache the result
        else:
            logging.info(f"Cache file {cache_path} exists; loading the cache!")
            with open(cache_path, "rb") as f:
                data = torch.load(f)
        end_time = time.time()

        logging.info(f"Data preprocessing took {(end_time - start_time):.3f}s in total")

        return data

    return wrapper


class LoadSenClsDataset:
    """Load the text classification dataset"""

    def __init__(
        self,
        vocab_path="./vocab.txt",
        tokenizer=None,
        batch_size=32,
        max_sen_len=None,
        split_sep="\n",
        max_position_embeddings=512,
        pad_index=0,
        is_sample_shuffle=True,
    ):
        """
        :param vocab_path: path of the local vocabulary file
        :param tokenizer: the tokenizer
        :param batch_size: batch size
        :param max_sen_len: padding mode; `"same"` pads every sample to the longest
            sequence in the whole dataset, `None` pads to the longest sequence within
            each batch, and an `int` pads to that fixed length and truncates the rest
        :param split_sep: separator between the text and the label
        :param max_position_embeddings: maximum sequence length; anything longer is truncated
        :param pad_index: index of the padding token
        :param is_sample_shuffle: whether to shuffle the dataset; note that this only
            shuffles the training set, never the validation or test set
        """

        self.vocab = Vocab(vocab_path)  # build the vocabulary from the local vocab.txt file
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        # min(max_sen_len, max_position_embeddings) caps the maximum sequence length
        if isinstance(max_sen_len, int) and max_sen_len > max_position_embeddings:
            max_sen_len = max_position_embeddings
        self.max_sen_len = max_sen_len

        self.split_sep = split_sep
        self.max_position_embeddings = max_position_embeddings
        self.PAD_IDX = pad_index
        self.CLS_IDX = self.vocab["[CLS]"]
        self.SEP_IDX = self.vocab["[SEP]"]
        self.is_sample_shuffle = is_sample_shuffle

    @cache_decorator
    def token_to_idx(self, filepath=None):
        """
        Convert the token sequences into index sequences and return the longest sequence length.
        """
        with open(filepath, encoding="utf8") as f:
            raw_iter = f.readlines()
        data = []  # each element of data is one index sequence together with its label
        max_len = 0  # length of the longest sequence

        for raw_line in tqdm(raw_iter, ncols=80):
            # Extract the text sequence and the class label
            line = raw_line.rstrip("\n").split(self.split_sep)
            text, label = line[0], line[1]

            # Tokenize, convert to an index sequence, and add the [CLS] and [SEP] tokens
            idx_seq = [self.CLS_IDX] + [
                self.vocab[token] for token in self.tokenizer(text)
            ]
            # The BERT model supports sequences of at most 512 tokens
            if len(idx_seq) > self.max_position_embeddings - 1:
                idx_seq = idx_seq[: self.max_position_embeddings - 1]
            idx_seq += [self.SEP_IDX]
            idx_seq = torch.tensor(idx_seq, dtype=torch.long)
            label = torch.tensor(int(label), dtype=torch.long)  # class labels 0-14

            max_len = max(max_len, idx_seq.size(0))
            data.append((idx_seq, label))

        return data, max_len

    def data_loader(
        self,
        train_filepath=None,
        val_filepath=None,
        test_filepath=None,
        only_test=False,
    ):
        """
        Build the DataLoaders.
        :param only_test: whether to return only the test DataLoader
        """
        train_data, max_len = self.token_to_idx(filepath=train_filepath)
        if self.max_sen_len == "same":
            self.max_sen_len = max_len

        test_data, _ = self.token_to_idx(filepath=test_filepath)
        test_loader = DataLoader(
            test_data,
            batch_size=self.batch_size,
            shuffle=False,  # never shuffle the test set
            collate_fn=self.generate_batch,
        )
        if only_test:
            return test_loader

        train_loader = DataLoader(
            train_data,
            batch_size=self.batch_size,
            shuffle=self.is_sample_shuffle,
            collate_fn=self.generate_batch,
        )

        val_data, _ = self.token_to_idx(filepath=val_filepath)
        val_loader = DataLoader(
            val_data,
            batch_size=self.batch_size,
            shuffle=False,  # never shuffle the validation set
            collate_fn=self.generate_batch,
        )

        return train_loader, test_loader, val_loader

    def generate_batch(self, data_batch):
        """
        Process the samples in one batch; passed to the DataLoader constructor as its collate_fn.
        :param data_batch: one batch of data
        """
        batch_seqs, batch_labels = [], []

        # Iterate over the samples in the batch, collecting the sequences and labels
        for seq, label in data_batch:
            batch_seqs.append(seq)
            batch_labels.append(label)

        batch_seqs = pad_sequence(
            batch_seqs,
            padding_value=self.PAD_IDX,
            max_len=self.max_sen_len,
            batch_first=False,
        )

        batch_labels = torch.tensor(batch_labels, dtype=torch.long)

        return batch_seqs, batch_labels
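As a quick check of the `pad_sequence` helper above, a toy call might look like this (illustrative only, run from a context where the function is importable; token ids are arbitrary):

```python
import torch

seqs = [torch.tensor([101, 8, 7, 102]), torch.tensor([101, 9, 102])]
batch = pad_sequence(seqs, padding_value=0, max_len=None, batch_first=True)
# batch -> tensor([[101, 8,   7, 102],
#                  [101, 9, 102,   0]])
```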