MyBERT/utils/create_pretraining_data.py


import os
import random
import logging
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from .data_helpers import Vocab
from .data_helpers import pad_sequence
from .data_helpers import process_cache


def format_wikitext2(filepath=None, sep=" . "):
    """
    Format the raw wikitext2 dataset.
    :return: a nested list; each outer element is one text paragraph and each inner
             element is one sentence of that paragraph, i.e.
             `[[para1_sen1, para1_sen2, ...], [para2_sen1, para2_sen2, ...], ...]`
    """
with open(filepath, "r", encoding="utf-8") as f:
        lines = f.readlines()  # read all lines; each line is one text paragraph
paragraphs = []
    for line in tqdm(lines, ncols=80, desc="## Reading raw wikitext2 data"):
        # lowercase the paragraph and split it into sentences by the separator
        sentences = line.lower().split(sep)
        # re-append the separator to every sentence and drop empty ones
tmp_sens = []
for sen in sentences:
sen = sen.strip()
if len(sen) == 0:
continue
sen += sep
tmp_sens.append(sen)
        # discard paragraphs with fewer than two sentences, since NSP needs a sentence pair
if len(tmp_sens) < 2:
continue
paragraphs.append(tmp_sens)
    random.shuffle(paragraphs)  # shuffle all paragraphs
return paragraphs
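
# Illustrative example (not part of the original code): with sep=" . ", a raw line
#   "the cat sat . the dog ran . \n"
# becomes the paragraph ["the cat sat . ", "the dog ran . "], and format_wikitext2()
# returns a shuffled list of such paragraphs.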


def format_songci(filepath=None, sep="。"):
    """Format the raw Song Ci (宋词) dataset."""
    with open(filepath, "r", encoding="utf-8") as f:
        lines = f.readlines()  # read all lines at once; each line is one poem
    paragraphs = []
    for line in tqdm(lines, ncols=80, desc="## Reading raw Song Ci data"):
        # skip paragraphs that contain corrupted or placeholder characters
        if "□" in line or "……" in line:
continue
sentences = line.split(sep)
        # re-append the separator to every sentence and drop empty ones
tmp_sens = []
for sen in sentences:
sen = sen.strip()
if len(sen) == 0:
continue
sen += sep
tmp_sens.append(sen)
        # discard paragraphs with fewer than two sentences
if len(tmp_sens) < 2:
continue
paragraphs.append(tmp_sens)
    random.shuffle(paragraphs)  # shuffle all paragraphs
return paragraphs


def format_custom(filepath=None, sep=None):
    """Format a custom dataset."""
    raise NotImplementedError(
        "Not implemented. Please implement it following the return format of "
        "`format_wikitext2()` or `format_songci()`."
    )


class LoadPretrainingDataset(object):
    """Load the pretraining dataset."""

def __init__(
self,
vocab_path="./vocab.txt",
tokenizer=None,
batch_size=32,
max_sen_len=None,
max_position_embeddings=512,
split_sep="",
pad_index=0,
is_sample_shuffle=True,
dataset_name="wikitext2",
masked_rate=0.15,
masked_token_rate=0.8,
masked_token_unchanged_rate=0.5,
random_seed=2023,
):
self.vocab = Vocab(vocab_path)
self.tokenizer = tokenizer
self.batch_size = batch_size
        # min(max_sen_len, max_position_embeddings) determines the input sequence length
if isinstance(max_sen_len, int) and max_sen_len > max_position_embeddings:
max_sen_len = max_position_embeddings
self.max_sen_len = max_sen_len
self.max_position_embeddings = max_position_embeddings
self.split_sep = split_sep
self.PAD_IDX = pad_index
self.CLS_IDX = self.vocab["[CLS]"]
self.SEP_IDX = self.vocab["[SEP]"]
self.MASK_IDX = self.vocab["[MASK]"]
self.is_sample_shuffle = is_sample_shuffle
self.dataset_name = dataset_name
self.masked_rate = masked_rate
self.masked_token_rate = masked_token_rate
self.masked_token_unchanged_rate = masked_token_unchanged_rate
self.random_seed = random_seed

    def format_data(self, filepath):
        """
        Format the raw dataset into the standard form.
        :return: `[[para1_sen1, para1_sen2, ...], [para2_sen1, para2_sen2, ...], ...]`
        """
        # Dispatch to the formatting function matching the dataset name.
        # Note: every formatting function must return the same structure.
        # wikitext2 dataset
        if self.dataset_name == "wikitext2":
return format_wikitext2(filepath, self.split_sep)
        # Song Ci dataset
        elif self.dataset_name == "songci":
return format_songci(filepath, self.split_sep)
        # other custom datasets
        elif self.dataset_name == "custom":
            return format_custom(filepath, self.split_sep)
else:
            raise ValueError(
                f"No formatting function exists for dataset {self.dataset_name}; "
                f"please implement one following `format_wikitext2()`."
            )

    @staticmethod
    def get_next_sentence_sample(sentence, next_sentence, paragraphs):
        """Build one sentence-pair sample from two consecutive sentences and all paragraphs."""
        # Positive and negative samples are generated in equal proportion.
        # The two sentences passed in form a positive sample (label True).
        if random.random() < 0.5:
            is_next = True
        # Otherwise build a negative sample (label False).
        else:
            # Randomly pick a paragraph first, then randomly pick a sentence from it,
            # avoiding a pick identical to the given next sentence.
            new_next_sentence = next_sentence
            while next_sentence == new_next_sentence:
                new_next_sentence = random.choice(random.choice(paragraphs))
            next_sentence = new_next_sentence
            is_next = False
return sentence, next_sentence, is_next
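
    # Illustrative example (not in the original code): given
    # paragraphs = [["a . ", "b . ", "c . "]], calling
    # get_next_sentence_sample("a . ", "b . ", paragraphs) returns
    # ("a . ", "b . ", True) roughly half of the time; otherwise the second
    # element is a randomly drawn sentence and the label is False.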

    def masking_tokens(self, token_ids, candidate_mask_positions, num_mask_ids):
        """Mask token_ids given the required number of masked tokens and the candidate positions."""
        # MLM labels: if a token is masked, its label is its original vocabulary id;
        # otherwise the label is PAD_IDX, which is ignored when computing the loss.
        mlm_label = [self.PAD_IDX] * len(token_ids)
        mask_ct = 0  # number of tokens masked so far
for mask_pos in candidate_mask_positions:
            # stop once the required number of tokens has been masked
            if mask_ct >= num_mask_ids:
                break
            new_token_id = None  # the replacement token id used for masking
            # 80% of the selected (~15% of all) tokens are replaced with the [MASK] token
if random.random() < self.masked_token_rate:
new_token_id = self.MASK_IDX
            else:
                # 10% of the selected tokens are kept unchanged
                # (20% * 0.5 = 10%)
                if random.random() < self.masked_token_unchanged_rate:
                    new_token_id = token_ids[mask_pos]
                # the remaining 10% are replaced with a random token from the vocabulary
                else:
                    new_token_id = random.randint(0, len(self.vocab.itos) - 1)
            # record the original id as the MLM label, then mask the token
mlm_label[mask_pos] = token_ids[mask_pos]
token_ids[mask_pos] = new_token_id
mask_ct += 1
return token_ids, mlm_label
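
    # Worked example of the masking probabilities above (assuming the default
    # masked_token_rate=0.8 and masked_token_unchanged_rate=0.5): a selected token
    # becomes [MASK] with probability 0.8, stays unchanged with probability
    # 0.2 * 0.5 = 0.1, and is replaced by a random vocabulary token with
    # probability 0.2 * 0.5 = 0.1, i.e. BERT's 80%/10%/10% masking scheme.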

    def get_masked_sample(self, token_ids):
        """
        Apply masking to token_ids.
        :param token_ids: e.g. `[101, 1031, 4895, 2243, 1033, 10029, 2000, 2624, 1031, ...]`
        :return mlm_id_seq: `[101, 1031, 103, 2243, 1033, 10029, 2000, 103, 1031, ...]`
                mlm_label:  `[  0,    0, 4895,   0,    0,     0,    0, 2624,    0, ...]`
        """
        candidate_mask_positions = []  # candidate positions for masking
        for pos, id in enumerate(token_ids):
            # special tokens are never masked in the MLM task
            if id in [self.CLS_IDX, self.SEP_IDX]:
                continue
            candidate_mask_positions.append(pos)  # e.g. [2, 3, 4, 5, ...]
        random.shuffle(candidate_mask_positions)  # shuffle the candidate mask positions
        # number of tokens to mask; BERT's default masking rate is 15%
        num_mask_ids = max(1, round(len(token_ids) * self.masked_rate))
        logging.debug(f"## Number of masked tokens: {num_mask_ids}")
mlm_id_seq, mlm_label = self.masking_tokens(
token_ids, candidate_mask_positions, num_mask_ids
)
return mlm_id_seq, mlm_label

    @process_cache(
unique_keys=[
"max_sen_len",
"masked_rate",
"masked_token_rate",
"masked_token_unchanged_rate",
"random_seed",
]
)
    def data_process(self, filepath=None):
        """Build samples in the format required by the NSP and MLM pretraining tasks."""
        paragraphs = self.format_data(filepath)  # format the raw data
        data = []  # each element is one sample: masked id sequence, token_type_ids, MLM labels and NSP label
        max_len = 0  # track the longest sequence length
        desc = f"## Building NSP and MLM pretraining data from {filepath.split(os.sep)[-1]}"
        for para in tqdm(paragraphs, ncols=80, desc=desc):  # iterate over paragraphs
            for i in range(len(para) - 1):  # iterate over consecutive sentence pairs
                # build one sentence-pair sample and its NSP label
                sen, next_sen, is_next = self.get_next_sentence_sample(
                    para[i], para[i + 1], paragraphs
                )
                logging.debug(f"## Current sentence: {sen}")
                logging.debug(f"## Next sentence: {next_sen}")
                logging.debug(f"## Sentence-pair label: {is_next}")
                # discard the sample if the next sentence is empty or a single character
                if len(next_sen) < 2:
                    logging.warning(
                        f"The next sentence '{next_sen}' of '{sen}' is empty and the sample "
                        f"is discarded; its NSP label was {is_next}"
                    )
                    continue
                # tokenize and convert both sentences to index sequences
                id_seq1 = [self.vocab[token] for token in self.tokenizer(sen)]
                id_seq2 = [self.vocab[token] for token in self.tokenizer(next_sen)]
                # concatenate the two index sequences and add the [CLS] and [SEP] tokens
                id_seq = [self.CLS_IDX] + id_seq1 + [self.SEP_IDX] + id_seq2
                # BERT supports at most max_position_embeddings (512) tokens; truncate if longer,
                # leaving room for the final [SEP]
                if len(id_seq) > self.max_position_embeddings - 1:
                    id_seq = id_seq[: self.max_position_embeddings - 1]
                id_seq += [self.SEP_IDX]
                assert len(id_seq) <= self.max_position_embeddings
                # build the token_type_id sequence marking which segment each token belongs to
                seg1 = [0] * (len(id_seq1) + 2)  # the leading [CLS] and middle [SEP] belong to the first segment
                seg2 = [1] * (len(id_seq) - len(seg1))  # the trailing [SEP] belongs to the second segment
seg = seg1 + seg2
assert len(seg) == len(id_seq)
                logging.debug(
                    f"## Tokens before masking: {[self.vocab.itos[id] for id in id_seq]}"
                )
                logging.debug(f"## Token ids before masking: {id_seq}")
                logging.debug(f"## Segment ids: {seg}, sequence length: {len(seg)}")
                # mask the token id sequence to obtain the MLM sample and its labels
                mlm_id_seq, mlm_label = self.get_masked_sample(id_seq)
                logging.debug(
                    f"## Tokens after masking: {[self.vocab.itos[id] for id in mlm_id_seq]}"
                )
                logging.debug(f"## Token ids after masking: {mlm_id_seq}")
                logging.debug(f"## Labels after masking: {mlm_label}")
                logging.debug("=" * 20)
id_seq = torch.tensor(mlm_id_seq, dtype=torch.long)
seg = torch.tensor(seg, dtype=torch.long)
mlm_label = torch.tensor(mlm_label, dtype=torch.long)
nsp_label = torch.tensor(int(is_next), dtype=torch.long)
max_len = max(max_len, id_seq.size(0))
data.append([id_seq, seg, mlm_label, nsp_label])
return {"data": data, "max_len": max_len}

    def generate_batch(self, data_batch):
        """
        Collate one batch of samples; passed to the DataLoader constructor as `collate_fn`.
        :param data_batch: one batch of data
        """
        b_id_seqs, b_segs, b_mlm_labels, b_nsp_labels = [], [], [], []
        # collect the id sequence, token_type_ids and the MLM/NSP labels of every sample in the batch
for id_seq, seg, mlm_label, nsp_label in data_batch:
b_id_seqs.append(id_seq)
b_segs.append(seg)
b_mlm_labels.append(mlm_label)
b_nsp_labels.append(nsp_label)
        # pad the sequences
        # shape: [max_sen_len, batch_size]
b_id_seqs = pad_sequence(
b_id_seqs,
padding_value=self.PAD_IDX,
max_len=self.max_sen_len,
batch_first=False,
)
b_segs = pad_sequence(
b_segs,
padding_value=self.PAD_IDX,
max_len=self.max_sen_len,
batch_first=False,
)
b_mlm_labels = pad_sequence(
b_mlm_labels,
padding_value=self.PAD_IDX,
max_len=self.max_sen_len,
batch_first=False,
)
        # build the padding mask
        # shape: [batch_size, max_sen_len]
        b_mask = (b_id_seqs == self.PAD_IDX).transpose(0, 1)
        # shape: [batch_size, ]
        b_nsp_labels = torch.tensor(b_nsp_labels, dtype=torch.long)
return b_id_seqs, b_segs, b_mask, b_mlm_labels, b_nsp_labels
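
    # Shape summary (for reference): with batch_first=False the padded id, segment
    # and MLM-label tensors are [max_sen_len, batch_size], the padding mask b_mask
    # is [batch_size, max_sen_len], and b_nsp_labels is [batch_size].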

    def data_loader(
self,
train_filepath=None,
val_filepath=None,
test_filepath=None,
only_test=False,
):
"""
创建DataLoader
:param only_test: 是否只返回测试集
"""
test_data = self.data_process(filepath=test_filepath)["data"]
test_loader = DataLoader(
test_data,
batch_size=self.batch_size,
            shuffle=False,  # do not shuffle the test set
collate_fn=self.generate_batch,
)
if only_test:
logging.info(f"## 成功返回测试集,包含样本{len(test_loader.dataset)}")
return test_loader
tmp_data = self.data_process(filepath=train_filepath)
train_data, max_len = tmp_data["data"], tmp_data["max_len"]
if self.max_sen_len == "same":
self.max_sen_len = max_len
train_loader = DataLoader(
train_data,
batch_size=self.batch_size,
shuffle=self.is_sample_shuffle,
collate_fn=self.generate_batch,
)
val_data = self.data_process(filepath=val_filepath)["data"]
val_loader = DataLoader(
val_data,
batch_size=self.batch_size,
            shuffle=False,  # do not shuffle the validation set
collate_fn=self.generate_batch,
)
        logging.info(
            f"## Returned {len(train_loader.dataset)} training samples, "
            f"{len(val_loader.dataset)} validation samples and "
            f"{len(test_loader.dataset)} test samples"
        )
return train_loader, val_loader, test_loader

    def get_inference_samples(self, sentences=None, is_masked=False):
        """
        Build the model inputs for the inference stage.
        :param sentences: a list in which each element is one text paragraph
        :param is_masked: whether the given sentences have already been masked
        """
        # sentences may be a single paragraph string; convert it to a list
        if not isinstance(sentences, list):
            sentences = [sentences]
        mask_token = self.vocab.itos[self.MASK_IDX]  # the [MASK] token
        b_id_seq = []  # index sequences of all sample sentences
        b_mask_pos = []  # positions of the [MASK] tokens in every sample sentence
        for sentence in sentences:
            # tokenize; at inference time only the MLM ability is examined, so the paragraph
            # is simply split into tokens without considering the sentence-pair relation
            token_seq = self.tokenizer(sentence)
            # if the given sentence has not been masked yet, mask it here
            if not is_masked:
                # candidate mask positions
                candidate_mask_positions = [pos for pos in range(len(token_seq))]
                random.shuffle(candidate_mask_positions)  # shuffle for random masking
                # note: the masking rate at inference time can be set to any value
                num_mask_tokens = max(1, round(len(token_seq) * self.masked_rate))
                # apply the mask
                for pos in candidate_mask_positions[:num_mask_tokens]:
                    token_seq[pos] = mask_token
            # convert to an index sequence
id_seq = [self.vocab[token] for token in token_seq]
            # add the [CLS] and [SEP] tokens
            id_seq = [self.CLS_IDX] + id_seq + [self.SEP_IDX]
            # record the masked token positions (within the sequence that includes [CLS] and [SEP])
            b_mask_pos.append(self.get_mask_pos(id_seq))
            b_id_seq.append(torch.tensor(id_seq, dtype=torch.long))
        # pad to the longest sequence length within the batch
b_id_seq = pad_sequence(
b_id_seq,
padding_value=self.PAD_IDX,
max_len=None,
batch_first=False,
)
b_mask = (b_id_seq == self.PAD_IDX).transpose(0, 1)
return b_id_seq, b_mask_pos, b_mask

    def get_mask_pos(self, id_seq):
        """Return the positions of the [MASK] tokens in id_seq."""
mask_positions = []
for pos, id in enumerate(id_seq):
if id == self.MASK_IDX:
mask_positions.append(pos)
return mask_positions
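

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module): the
# tokenizer built via `transformers.BertTokenizer` and the file paths below are
# assumptions; substitute the vocabulary, tokenizer and data files your project
# actually uses.
if __name__ == "__main__":
    from transformers import BertTokenizer

    logging.basicConfig(level=logging.INFO)
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased").tokenize
    dataset = LoadPretrainingDataset(
        vocab_path="./vocab.txt",  # assumed vocabulary file
        tokenizer=tokenizer,
        batch_size=16,
        max_sen_len=None,  # pad each batch to its own longest sequence
        dataset_name="wikitext2",
        split_sep=" . ",
    )
    train_iter, val_iter, test_iter = dataset.data_loader(
        train_filepath="./wiki.train.tokens",  # assumed data paths
        val_filepath="./wiki.valid.tokens",
        test_filepath="./wiki.test.tokens",
    )
    for ids, segs, mask, mlm_labels, nsp_labels in train_iter:
        print(ids.shape, segs.shape, mask.shape, mlm_labels.shape, nsp_labels.shape)
        break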