Implement the BertModel class

This commit is contained in:
huaian_zhou 2023-12-11 11:23:04 +08:00
parent b2136f965b
commit c553948cc5
13 changed files with 1224 additions and 70 deletions


@ -1 +0,0 @@
![Alt text](img/bert_framework.png)

0
model/BERT/__init__.py Normal file

376
model/BERT/bert.py Normal file

@ -0,0 +1,376 @@
import os
import logging
import torch
import torch.nn as nn
from copy import deepcopy
from torch.nn.init import normal_
from .config import BertConfig
from .embedding import BertEmbedding
from .transformer import MultiheadAttention
def get_activation(activation_string):
"""将字符串转换为激活函数"""
activation = activation_string.lower()
if activation == "linear":
return None
elif activation == "relu":
return nn.ReLU()
elif activation == "gelu":
return nn.GELU()
elif activation == "tanh":
return nn.Tanh()
else:
raise ValueError("Unsupported activation: %s" % activation)
class BertSelfAttention(nn.Module):
"""多头自注意力模块"""
def __init__(self, config: BertConfig):
super(BertSelfAttention, self).__init__()
# use the multi-head attention module from PyTorch
if "use_torch_multi_head" in config.__dict__ and config.use_torch_multi_head:
MultiHeadAttention = nn.MultiheadAttention
# use the locally implemented multi-head attention module
else:
MultiHeadAttention = MultiheadAttention
self.multi_head_attention = MultiHeadAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_probs_dropout_prob,
)
def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
return self.multi_head_attention(
query, key, value, attn_mask=attn_mask, key_padding_mask=key_padding_mask
)
class BertSelfOutput(nn.Module):
"""自注意力模块后的残差连接和标准化"""
def __init__(self, config: BertConfig):
super().__init__()
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
"""
Args:
hidden_states: output of the multi-head self-attention module `#[src_len, batch_size, hidden_size]`
input_tensor: input to the multi-head self-attention module `#[src_len, batch_size, hidden_size]`
"""
hidden_states = self.dropout(hidden_states)
hidden_states = self.layer_norm(input_tensor + hidden_states)
return hidden_states
class BertAttention(nn.Module):
"""完整的自注意力模块"""
def __init__(self, config: BertConfig):
super().__init__()
self.self_attention = BertSelfAttention(config)
self.self_output = BertSelfOutput(config)
def forward(self, hidden_states, attention_mask=None):
"""
Args:
hidden_states: input to the self-attention block `#[src_len, batch_size, hidden_size]`
attention_mask: padding mask; tokens to be masked are marked `True`, the rest `False` `#[batch_size, src_len]`
"""
# self_attention returns the encoded result and the attention weight matrix
attn_outputs = self.self_attention(
hidden_states,
hidden_states,
hidden_states,
attn_mask=None,
key_padding_mask=attention_mask,  # note: attention_mask is a padding mask, not an attention mask
)
# attn_outputs[0]: #[src_len, batch_size, hidden_size]
output = self.self_output(attn_outputs[0], hidden_states)
return output
class BertIntermediate(nn.Module):
"""
Linear layer after the self-attention block, i.e., the first linear layer of the Transformer FFN
"""
def __init__(self, config: BertConfig):
super().__init__()
# linear layer
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
# activation function
if isinstance(config.hidden_act, str):
self.inter_activation = get_activation(config.hidden_act)
else:
self.inter_activation = config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
if self.inter_activation is None:
hidden_states = hidden_states
else:
hidden_states = self.inter_activation(hidden_states)
return hidden_states
class BertOutput(nn.Module):
"""
The second linear layer plus residual connection and layer normalization, i.e., the second linear layer of the Transformer FFN
"""
def __init__(self, config: BertConfig):
super().__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
"""
Args:
hidden_states: output of the first linear layer `#[src_len, batch_size, intermediate_size]`
input_tensor: input to the first linear layer `#[src_len, batch_size, hidden_size]`
Return:
`#[src_len, batch_size, hidden_size]`
"""
# #[src_len, batch_size, hidden_size]
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.layer_norm(input_tensor + hidden_states)
return hidden_states
class BertLayer(nn.Module):
"""单个Encoder Layer"""
def __init__(self, config):
super().__init__()
self.bert_attention = BertAttention(config)
self.bert_intermediate = BertIntermediate(config)
self.bert_output = BertOutput(config)
def forward(self, hidden_states, attention_mask=None):
"""
Args:
hidden_states: `#[src_len, batch_size, hidden_size]`
attention_mask: padding mask `#[batch_size, src_len]`
Return:
`#[src_len, batch_size, hidden_size]`
"""
# #[src_len, batch_size, hidden_size]
attn_output = self.bert_attention(hidden_states, attention_mask)
# #[src_len, batch_size, intermediate_size]
inter_output = self.bert_intermediate(attn_output)
# #[src_len, batch_size, hidden_size]
output = self.bert_output(inter_output, attn_output)
return output
class BertEncoder(nn.Module):
"""
The encoder: a stack of multiple encoder layers
"""
def __init__(self, config: BertConfig):
super().__init__()
self.config = config
# create num_hidden_layers encoder layers
self.bert_layers = nn.ModuleList(
[BertLayer(config) for _ in range(config.num_hidden_layers)]
)
def forward(self, hidden_states, attention_mask=None):
all_encoder_layers = []  # store the outputs of all encoder layers
output = hidden_states
for layer in self.bert_layers:
output = layer(output, attention_mask)
all_encoder_layers.append(output)
return all_encoder_layers
class BertPooler(nn.Module):
"""
Pools the encoder output to obtain a sentence-level representation
"""
def __init__(self, config: BertConfig):
super().__init__()
self.config = config
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
def forward(self, hidden_states):
if "pooler_type" not in self.config.__dict__:
raise ValueError(
"pooler_type must be in ['first_token_transform', 'all_token_average']"
"请在配置文件config.json中添加一个pooler_type参数"
)
# take the first token, i.e., the [CLS] token embedding
if self.config.pooler_type == "first_token_transform":
# #[batch_size, hidden_size]
token_tensor = hidden_states[0, :].reshape(-1, self.config.hidden_size)
# average all token embeddings
elif self.config.pooler_type == "all_token_average":
token_tensor = torch.mean(hidden_states, dim=0)
# #[batch_size, hidden_size]
output = self.dense(token_tensor)
output = self.activation(output)
return output
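# Usage sketch for BertPooler (illustrative, hypothetical shapes): given the last
# encoder layer's output of shape [src_len, batch_size, hidden_size],
# "first_token_transform" pools the [CLS] position into a [batch_size, hidden_size]
# tensor, e.g.
#   pooled = BertPooler(config)(encoder_outputs[-1])  # [batch_size, hidden_size]
# while "all_token_average" instead averages over the src_len dimension before the
# dense + tanh transform.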
def format_params_for_torch(loaded_params_names, loaded_params):
"""
Format the loaded pretrained parameters to match the MultiheadAttention layout of torch (1.12.0), where the Q/K/V weights/biases are packed into a single tensor
"""
qkv_weight_names = ["query.weight", "key.weight", "value.weight"]
qkv_bias_names = ["query.bias", "key.bias", "value.bias"]
qkv_weight, qkv_bias = [], []
torch_params = []
for i in range(len(loaded_params_names)):
param_name_in_pretrained = loaded_params_names[i]
param_name = ".".join(param_name_in_pretrained.split(".")[-2:])
if param_name in qkv_weight_names:
qkv_weight.append(loaded_params[param_name_in_pretrained])
elif param_name in qkv_bias_names:
qkv_bias.append(loaded_params[param_name_in_pretrained])
else:
torch_params.append(loaded_params[param_name_in_pretrained])
if len(qkv_weight) == 3:
torch_params.append(torch.cat(qkv_weight, dim=0))
qkv_weight = []
if len(qkv_bias) == 3:
torch_params.append(torch.cat(qkv_bias, dim=0))
qkv_bias = []
return torch_params
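# Illustrative sketch of the repacking above (the dotted parameter names are assumed
# to follow the usual BERT checkpoint layout and are abbreviated here): three
# consecutive entries such as
#   "...layer.0.attention.self.query.weight"  # [hidden_size, hidden_size]
#   "...layer.0.attention.self.key.weight"    # [hidden_size, hidden_size]
#   "...layer.0.attention.self.value.weight"  # [hidden_size, hidden_size]
# are concatenated along dim=0 into one [3 * hidden_size, hidden_size] tensor, which
# matches the packed in_proj_weight expected by torch's nn.MultiheadAttention; the
# three biases are packed the same way into a single [3 * hidden_size] tensor.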
def load_512_position(init_embedding, loaded_embedding):
"""
The pretrained BERT model only supports at most 512 `position_ids`, while the custom model config's
`max_position_embeddings` may be larger than 512, so when loading, the first 512 rows of the randomly
initialized `positional embedding` matrix are replaced with the pretrained `positional embedding` matrix.
Args:
init_embedding: the randomly initialized positional embedding matrix (may have more than 512 rows)
loaded_embedding: the positional embedding matrix loaded from the pretrained model (exactly 512 rows)
"""
logging.info("Model config max_position_embeddings > 512")
init_embedding[:512, :] = loaded_embedding[:512, :]
return init_embedding
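# Minimal sketch (assumed sizes, matching the 518 used in test/test_bert.py): with
# max_position_embeddings = 518, the randomly initialized matrix has shape
# [518, hidden_size]; only its first 512 rows are overwritten by the pretrained
# [512, hidden_size] matrix, e.g.
#   merged = load_512_position(torch.randn(518, 768), torch.randn(512, 768))
#   # merged[:512] equals the loaded embedding, merged[512:] keeps its random init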
class BertModel(nn.Module):
def __init__(self, config: BertConfig):
super().__init__()
self.config = config
self.bert_embedding = BertEmbedding(config)
self.bert_encoder = BertEncoder(config)
self.bert_pooler = BertPooler(config)
self._reset_parameters()
def forward(
self,
input_ids=None,
position_ids=None,
token_type_ids=None,
attention_mask=None,
):
"""
Args:
input_ids: `#[src_len, batch_size]`
position_ids: `#[1, src_len]`
token_type_ids: `#[src_len, batch_size]`
attention_mask: `#[batch_size, src_len]`
Return:
`#[src_len, batch_size, hidden_size]`
"""
input_embed = self.bert_embedding(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
)
# encoder_outputs contains the outputs of all num_hidden_layers encoder layers
encoder_outputs = self.bert_encoder(
hidden_states=input_embed, attention_mask=attention_mask
)
# feed the last encoder layer's output into the pooler to obtain the sentence-level representation
sequence_output = encoder_outputs[-1] # #[src_len, batch_size, hidden_size]
pooled_output = self.bert_pooler(sequence_output) # #[batch_size, hidden_size]
return pooled_output, encoder_outputs
def _reset_parameters(self):
"""初始化参数"""
for param in self.parameters():
if param.dim() > 1:
normal_(param, mean=0.0, std=self.config.initializer_range)
@classmethod
def from_pretrained(cls, config: BertConfig, pretrained_model_dir=None):
"""从预训练模型文件创建模型"""
model = cls(config) # 创建模型cls表示类名BertModel
# 加载预训练模型
pretrained_model_path = os.path.join(pretrained_model_dir, "pytorch_model.bin")
if not os.path.exists(pretrained_model_path):
raise ValueError(
f"<路径:{pretrained_model_path} 中的模型不存在,请仔细检查!>\n"
f"中文模型下载地址https://huggingface.co/bert-base-chinese/tree/main\n"
f"英文模型下载地址https://huggingface.co/bert-base-uncased/tree/main\n"
)
loaded_params = torch.load(pretrained_model_path)
loaded_params_names = list(loaded_params.keys())[:-8]
model_params = deepcopy(model.state_dict())
model_params_names = list(model_params.keys())[1:]
if "use_torch_multi_head" in config.__dict__ and config.use_torch_multi_head:
logging.info(f"## 注意正在使用torch框架中的MultiHeadAttention实现")
torch_params = format_params_for_torch(loaded_params_names, loaded_params)
for i in range(len(model_params_names)):
logging.debug(
f"## 成功赋值参数 {model_params_names[i]} 参数形状为 {torch_params[i].size()}"
)
if "position_embedding" in model_params_names[i]:
if config.max_position_embeddings > 512:
new_embedding = load_512_position(
model_params[model_params_names[i]],
torch_params[i],
)
model_params[model_params_names[i]] = new_embedding
continue
model_params[model_params_names[i]] = torch_params[i]
else:
logging.info(
f"## 注意正在使用本地transformer.py中的MultiheadAttention实现"
f"如需使用torch框架中的MultiHeadAttention模块可设置config.__dict__['use_torch_multi_head'] = True实现"
)
for i in range(len(loaded_params_names)):
logging.debug(
f"## 成功将参数 {loaded_params_names[i]} 赋值给 {model_params_names[i]} "
f"参数形状为 {loaded_params[loaded_params_names[i]].size()}"
)
if "position_embedding" in model_params_names[i]:
if config.max_position_embeddings > 512:
new_embedding = load_512_position(
model_params[model_params_names[i]],
loaded_params[loaded_params_names[i]],
)
model_params[model_params_names[i]] = new_embedding
continue
# assign the loaded pretrained parameter values to the newly created model
model_params[model_params_names[i]] = loaded_params[
loaded_params_names[i]
]
model.load_state_dict(model_params)
return model


@ -1,7 +1,5 @@
import json
import copy
# import six
import logging
@ -85,19 +83,3 @@ class BertConfig(object):
"""把对象转换为字典"""
out = copy.deepcopy(self.__dict__)
return out
if __name__ == "__main__":
import os
import sys
sys.path.append(os.getcwd())
json_file = "./archive/bert_base_chinese/config.json"
config = BertConfig.from_json_file(json_file)
for key, value in config.__dict__.items():
print(f"{key} = {value}")
print("=" * 20)
print(config.to_json_str())


@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from .config import BertConfig
from torch.nn.init import normal_
from bert_config import BertConfig
class TokenEmbedding(nn.Module):
@ -126,7 +126,7 @@ class BertEmbedding(nn.Module):
initializer_range=config.initializer_range,
)
self.layernorm = nn.LayerNorm(config.hidden_size) # layer normalization
self.layer_norm = nn.LayerNorm(config.hidden_size) # layer normalization
self.dropout = nn.Dropout(config.hidden_dropout_prob)
# create all position ids in advance #[1, max_position_embeddings]
@ -161,55 +161,7 @@ class BertEmbedding(nn.Module):
# sum the three embeddings
input_embed = token_embed + pos_embed + seg_embed
input_embed = self.layernorm(input_embed)
input_embed = self.layer_norm(input_embed)
input_embed = self.dropout(input_embed)
return input_embed
if __name__ == "__main__":
import os
import sys
sys.path.append(os.getcwd())
json_file = "./archive/bert_base_chinese/config.json"
config = BertConfig.from_json_file(json_file)
src = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]], dtype=torch.long)
src = src.transpose(0, 1) # #[src_len, batch_size] [5, 2]
print("***** --------- 测试TokenEmbedding ------------")
token_embedding = TokenEmbedding(vocab_size=16, hidden_size=32)
token_embed = token_embedding(input_ids=src)
print("src shape #[src_len, batch_size]: ", src.shape)
print(
f"token embedding shape #[src_len, batch_size, hidden_size]: {token_embed.shape}\n"
)
print("***** --------- 测试PositionalEmbedding ------------")
# #[1, src_len]
position_ids = torch.arange(src.shape[0]).expand((1, -1))
position_embedding = PositionalEmbedding(max_position_embeddings=8, hidden_size=32)
pos_embed = position_embedding(position_ids=position_ids)
# print(position_embedding.embedding.weight) # embedding matrix
print("position_ids shape #[1, src_len]: ", position_ids.shape)
print(f"positional embedding shape #[src_len, 1, hidden_size]: {pos_embed.shape}\n")
print("***** --------- 测试SegmentEmbedding ------------")
token_type_ids = torch.tensor(
[[0, 0, 0, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.long
).transpose(0, 1)
segmet_embedding = SegmentEmbedding(type_vocab_size=2, hidden_size=32)
seg_embed = segmet_embedding(token_type_ids)
print("token_type_ids shape #[src_len, batch_size]: ", token_type_ids.shape)
print(
f"segment embedding shape #[src_len, batch_size, hidden_size]: {seg_embed.shape}\n"
)
print("***** --------- 测试BertEmbedding ------------")
bert_embedding = BertEmbedding(config)
input_embed = bert_embedding(src, token_type_ids=token_type_ids)
print(
f"input embedding shape #[src_len, batch_size, hidden_size]: {input_embed.shape}"
)

565
model/BERT/transformer.py Normal file

@ -0,0 +1,565 @@
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_
class MultiheadAttention(nn.Module):
"""
Multi-head attention mechanism
"""
def __init__(self, embed_dim=512, num_heads=8, dropout=0.0, bias=True):
"""
:param embed_dim: word embedding dimension, i.e., the d_model parameter
:param num_heads: number of heads in the multi-head attention, i.e., the nhead parameter
:param dropout: dropout probability applied to the attention weights
:param bias: whether to use a bias in the linear projections
"""
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads  # head_dim is the number of columns of each single-head projection matrix, i.e., the dimension of the q/k/v vectors
self.kdim = self.head_dim
self.vdim = self.head_dim
self.dropout = dropout
assert (
self.head_dim * self.num_heads == self.embed_dim
), "embed_dim除以num_heads必须为整数"
# 原论文中的 d_k = d_v = d_model/nhead 限制条件
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
# projection matrix W_q: embed_dim = num_heads * kdim (kdim = qdim)
# the second dimension is embed_dim because num_heads W_q matrices (one per head) are initialized at once and concatenated horizontally
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
# W_k: embed_dim = num_heads * kdim
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
# W_v: embed_dim = num_heads * vdim
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
# apply one more linear projection to the (horizontally concatenated) multi-head attention result before output
self._reset_parameters()
def forward(
self,
query,
key,
value,
attn_mask=None,
key_padding_mask=None,
training=True,
is_print_shape=False,
):
"""
In the Encoder, query, key and value are all the source sequence (src seq).\\
In the Decoder, query, key and value are all the target sequence (tgt seq).\\
In the Encoder-Decoder interaction, key and value are the Encoder memory, while query is the tgt seq.\\
:param query: # [tgt_len, batch_size, embed_dim]
:param key: # [src_len, batch_size, embed_dim]
:param value: # [src_len, batch_size, embed_dim]
:param attn_mask: attention mask matrix # [tgt_len, src_len] or [batch_size * num_heads, tgt_len, src_len]
usually only used during Decoder training, because all tgt tokens are fed in parallel and the tokens after the current position must be masked
:param key_padding_mask: mask for padding tokens # [batch_size, src_len]
:return:
attn_output: multi-head attention result # [tgt_len, batch_size, embed_dim]
attn_output_weights: attention weights averaged over heads # [batch_size, tgt_len, src_len]
"""
# 1. Compute Q, K, V
# note: query, key, value are the sequences before the linear projections; in the Encoder, for example, they are all the source sequence (src seq)
Q = self.q_proj(query)
# [tgt_len, batch_size, embed_dim] x [embed_dim, num_heads * kdim] = [tgt_len, batch_size, num_heads * kdim]
K = self.k_proj(key)
# [src_len, batch_size, embed_dim] x [embed_dim, num_heads * kdim] = [src_len, batch_size, num_heads * kdim]
V = self.v_proj(value)
# [src_len, batch_size, embed_dim] x [embed_dim, num_heads * vdim] = [src_len, batch_size, num_heads * vdim]
if is_print_shape:
print("=" * 80)
print("开始计算多头注意力:")
print(
f"\t 多头数num_heads = {self.num_heads}d_model={query.size(-1)}d_k = d_v = d_model/num_heads={query.size(-1) // self.num_heads}"
)
print(f"\t query的shape([tgt_len, batch_size, embed_dim]){query.shape}")
print(
f"\t W_q的shape([embed_dim, num_heads * kdim]){self.q_proj.weight.shape}"
)
print(f"\t Q的shape([tgt_len, batch_size, num_heads * kdim]){Q.shape}")
print("\t" + "-" * 70)
print(f"\t key的shape([src_len, batch_size, embed_dim]){key.shape}")
print(
f"\t W_k的shape([embed_dim, num_heads * kdim]){self.k_proj.weight.shape}"
)
print(f"\t K的shape([src_len, batch_size, num_heads * kdim]){K.shape}")
print("\t" + "-" * 70)
print(f"\t value的shape([src_len, batch_size, embed_dim]){value.shape}")
print(
f"\t W_v的shape([embed_dim, num_heads * vdim]){self.v_proj.weight.shape}"
)
print(f"\t V的shape([src_len, batch_size, num_heads * vdim]){V.shape}")
print("\t" + "-" * 70)
print(
"\t ***** 注意这里的W_q、W_k、W_v是多头注意力变换矩阵拼接的因此Q、K、V也是多个q、k、v向量拼接的结果 *****"
)
# 2. Scale Q and validate the attn_mask dimensions
scaling = float(self.head_dim) ** -0.5  # scaling factor
Q = Q * scaling
# [query_len, batch_size, num_heads * kdim], where query_len equals tgt_len
src_len = key.size(0)
tgt_len, bsz, _ = query.size() # [tgt_len, batch_size, embed_dim]
if attn_mask is not None:
# [tgt_len, src_len] or [batch_size * num_heads, tgt_len, src_len]
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0) # [1, tgt_len, src_len]
if list(attn_mask.size()) != [1, tgt_len, src_len]:
raise RuntimeError("The size of the 2D attn_mask is not correct.")
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [
bsz * self.num_heads,
tgt_len,
src_len,
]:
raise RuntimeError("The size of the 3D attn_mask is not correct.")
# at this point attn_mask is 3-dimensional
# 3. Compute attention scores
# reshape Q, K, V so that bmm can be applied below
Q = (
Q.contiguous()
.view(tgt_len, bsz * self.num_heads, self.kdim)
.transpose(0, 1)
)
# [batch_size * num_heads, tgt_len, kdim]
K = (
K.contiguous()
.view(src_len, bsz * self.num_heads, self.kdim)
.transpose(0, 1)
)
# [batch_size * num_heads, src_len, kdim]
V = (
V.contiguous()
.view(src_len, bsz * self.num_heads, self.vdim)
.transpose(0, 1)
)
# [batch_size * num_heads, src_len, vdim]
attn_weights = torch.bmm(Q, K.transpose(1, 2))  # bmm performs batched matrix multiplication on 3D tensors
# [batch_size * num_heads, tgt_len, kdim] x [batch_size * num_heads, kdim, src_len]
# -> [batch_size * num_heads, tgt_len, src_len], the attention matrices of the num_heads Q-K products
# 4. Apply masks
# Attention mask
if attn_mask is not None:
attn_weights += attn_mask
# [batch_size * num_heads, tgt_len, src_len]
# Padding mask (masks the padded key columns)
if key_padding_mask is not None:
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
# [batch_size, num_heads, tgt_len, src_len]
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
)
# key_padding_mask is expanded from [batch_size, src_len] to [batch_size, 1, 1, src_len]
# and then used to mask attn_weights; masked_fill masks the positions whose mask value is True (non-zero)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
# [batch_size * num_heads, tgt_len, src_len]
# 5. Compute the multi-head attention output
# compute the attention weights
attn_weights = F.softmax(attn_weights, dim=-1)
# [batch_size * num_heads, tgt_len, src_len]
attn_weights = F.dropout(attn_weights, p=self.dropout, training=training)
# compute MultiheadAttention(Q, K, V)
attn_output = torch.bmm(attn_weights, V)
# [batch_size * num_heads, tgt_len, src_len] x [batch_size * num_heads, src_len, vdim]
# -> [batch_size * num_heads, tgt_len, vdim]
# finally apply the output linear projection
attn_output = (
attn_output.transpose(0, 1)
.contiguous()
.view(tgt_len, bsz, self.num_heads * self.vdim)
)
# [tgt_len, batch_size, num_heads * vdim]
Z = self.out_proj(attn_output)
# [tgt_len, batch_size, embed_dim]
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
# [batch_size, num_heads, tgt_len, src_len]
if is_print_shape:
print(
f"\t 多头注意力计算结束后的形状(横向拼接)为([tgt_len, batch_size, num_heads * vdim]){attn_output.shape}"
)
print(
f"\t 对多头注意力计算结果进行线性变换的权重W_o形状为([num_heads * vdim, embed_dim]){self.out_proj.weight.shape}"
)
print(f"\t 多头注意力计算结果线性变换后的形状为([tgt_len, batch_size, embed_dim]){Z.shape}")
return (
Z,
attn_weights.sum(dim=1) / self.num_heads,  # return the attention weights averaged over heads
)
def _reset_parameters(self):
"""
Initialize parameters
"""
for param in self.parameters():
if param.dim() > 1:
xavier_uniform_(param)
def _get_clones(module, N):
"""
Make N deep copies of module
"""
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
class EncoderLayer(nn.Module):
"""
A single encoder layer
"""
def __init__(self, d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1):
"""
:param d_model: model dimension, i.e., the word embedding dimension
:param nhead: number of heads in the multi-head attention
:param dim_feedforward: output dimension of the fully connected (feed-forward) layer
:param dropout: dropout rate
"""
super(EncoderLayer, self).__init__()
self.self_attn = MultiheadAttention(
embed_dim=d_model, num_heads=nhead, dropout=dropout
)
# Add & Norm after the multi-head attention output
self.dropout1 = nn.Dropout(dropout)
# note: LayerNorm normalizes along the feature dimension, whereas BatchNorm normalizes along the batch dimension
self.norm1 = nn.LayerNorm(d_model)
# Feed Forward Network
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.activation = nn.ReLU()
self.dropout2 = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.dropout3 = nn.Dropout(dropout)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, src, src_mask=None, src_key_padding_mask=None):
"""
:param src: # [src_len, batch_size, embed_dim]
:param src_mask: None; the Encoder does not need an attention mask
:param src_key_padding_mask: # [batch_size, src_len]
:return: # [src_len, batch_size, embed_dim] <==> [src_len, batch_size, num_heads * kdim]
"""
# compute multi-head attention
src1 = self.self_attn(
src,
src,
src,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
# [src_len, batch_size, embed_dim], where embed_dim = num_heads * kdim
# residual connection and LayerNorm
src = src + self.dropout1(src1)
src = self.norm1(src)
# Feed Forward
src1 = self.activation(self.linear1(src))
# [src_len, batch_size, dim_feedforward]
src1 = self.linear2(self.dropout2(src1))
# [src_len, batch_size, embed_dim]
# residual connection and LayerNorm
src = src + self.dropout3(src1)
src = self.norm2(src)
return src
class Encoder(nn.Module):
"""
The encoder: a stack of multiple encoder layers
"""
def __init__(self, encoder_layer, num_layers=6, norm=None):
"""
:param encoder_layer: a single encoder layer
:param num_layers: number of encoder layers
:param norm: normalization layer
"""
super(Encoder, self).__init__()
# clone the encoder layer several times to obtain a list of encoder layers
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(self, src, src_mask=None, src_key_padding_mask=None):
"""
:param src: # [src_len, batch_size, embed_dim]
:param src_mask: None; the Encoder does not need an attention mask
:param src_key_padding_mask: # [batch_size, src_len]
:return: # [src_len, batch_size, embed_dim] <==> [src_len, batch_size, num_heads * kdim]
"""
output = src
# run each encoder layer in turn and feed its output to the next layer
for layer in self.layers:
output = layer(
output, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask
)
# apply normalization to the output of the last layer
if self.norm is not None:
output = self.norm(output)
return output
class DecoderLayer(nn.Module):
"""
A single decoder layer
"""
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
"""
:param d_model: model dimension, i.e., the word embedding dimension
:param nhead: number of heads in the multi-head attention
:param dim_feedforward: output dimension of the fully connected (feed-forward) layer
:param dropout: dropout rate
"""
super(DecoderLayer, self).__init__()
# masked multi-head attention over the decoder input sequence
self.self_attn = MultiheadAttention(
embed_dim=d_model, num_heads=nhead, dropout=dropout
)
# multi-head attention between the encoder output (memory) and the decoder layer
self.multihead_attn = MultiheadAttention(
embed_dim=d_model, num_heads=nhead, dropout=dropout
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
# Feed Forward Network
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.activation = nn.ReLU()
self.dropout4 = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
def forward(
self,
tgt,
memory,
tgt_mask=None,
memory_mask=None,
tgt_key_padding_mask=None,
memory_key_padding_mask=None,
):
"""
:param tgt: decoder input sequence # [tgt_len, batch_size, embed_dim]
:param memory: encoder output (memory) # [src_len, batch_size, embed_dim]
:param tgt_mask: attention mask for the decoder self-attention # [tgt_len, tgt_len]
:param memory_mask: attention mask for the encoder-decoder attention, usually None
:param tgt_key_padding_mask: padding mask of the decoder input sequence # [batch_size, tgt_len]
:param memory_key_padding_mask: padding mask of the encoder input sequence # [batch_size, src_len]
:return: # [tgt_len, batch_size, embed_dim] <==> [tgt_len, batch_size, num_heads * kdim]
"""
# masked multi-head attention
tgt1 = self.self_attn(
tgt,
tgt,
tgt,
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask,
)[0]
# residual connection & LayerNorm
tgt = tgt + self.dropout1(tgt1)
tgt = self.norm1(tgt)
# encoder-decoder multi-head attention
tgt1 = self.multihead_attn(
tgt,
memory,
memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
# residual connection & LayerNorm
tgt = tgt + self.dropout2(tgt1)
tgt = self.norm2(tgt)
# Feed Forward
tgt1 = self.activation(self.linear1(tgt))
# [tgt_len, batch_size, dim_feedforward]
tgt1 = self.linear2(self.dropout4(tgt1))
# [tgt_len, batch_size, embed_dim]
# residual connection & LayerNorm
tgt = tgt + self.dropout3(tgt1)
tgt = self.norm3(tgt)
return tgt
class Decoder(nn.Module):
"""
The decoder: a stack of multiple decoder layers
"""
def __init__(self, decoder_layer, num_layers, norm=None):
"""
:param decoder_layer: a single decoder layer
:param num_layers: number of decoder layers
:param norm: normalization layer
"""
super(Decoder, self).__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(
self,
tgt,
memory,
tgt_mask=None,
memory_mask=None,
tgt_key_padding_mask=None,
memory_key_padding_mask=None,
):
"""
:param tgt: decoder input sequence # [tgt_len, batch_size, embed_dim]
:param memory: encoder output (memory) # [src_len, batch_size, embed_dim]
:param tgt_mask: attention mask for the decoder self-attention # [tgt_len, tgt_len]
:param memory_mask: attention mask for the encoder-decoder attention, usually None
:param tgt_key_padding_mask: padding mask of the decoder input sequence # [batch_size, tgt_len]
:param memory_key_padding_mask: padding mask of the encoder input sequence # [batch_size, src_len]
:return: # [tgt_len, batch_size, embed_dim] <==> [tgt_len, batch_size, num_heads * kdim]
"""
output = tgt # [tgt_len,batch_size, embed_dim]
# run each decoder layer in turn and feed its output to the next layer
for layer in self.layers:
output = layer(
output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
)
# apply normalization to the output of the last layer
if self.norm is not None:
output = self.norm(output)
return output
class Transformer(nn.Module):
def __init__(
self,
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
):
"""
:param d_model: model dimension, i.e., the word embedding dimension
:param nhead: number of heads in the multi-head attention
:param num_encoder_layers: number of stacked EncoderLayers
:param num_decoder_layers: number of stacked DecoderLayers
:param dim_feedforward: output dimension of the fully connected (feed-forward) layer
:param dropout: dropout rate
"""
super(Transformer, self).__init__()
# Encoder
encoder_layer = EncoderLayer(d_model, nhead, dim_feedforward, dropout)
encoder_norm = nn.LayerNorm(d_model)
self.encoder = Encoder(encoder_layer, num_encoder_layers, encoder_norm)
# Decoder
decoder_layer = DecoderLayer(d_model, nhead, dim_feedforward, dropout)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm)
self._reset_parameters()
self.d_model = d_model
self.nhead = nhead
def forward(
self,
src,
tgt,
src_mask=None,
tgt_mask=None,
memory_mask=None,
src_key_padding_mask=None,
tgt_key_padding_mask=None,
memory_key_padding_mask=None,
):
"""
:param src: # [src_len, batch_size, embed_dim]
:param tgt: # [tgt_len, batch_size, embed_dim]
:param src_mask: None
:param tgt_mask: # [tgt_len, tgt_len]
:param memory_mask: None
:param src_key_padding_mask: # [batch_size, src_len]
:param tgt_key_padding_mask: # [batch_size, tgt_len]
:param memory_key_padding_mask: # [batch_size, src_len]
:return: [tgt_len, batch_size, embed_dim] <==> [tgt_len, batch_size, num_heads * kdim]
"""
# Encoding: produce memory
memory = self.encoder(
src, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask
)
# Decoding
output = self.decoder(
tgt=tgt,
memory=memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
)
return output
def _reset_parameters(self):
"""
Initialize parameters
"""
for param in self.parameters():
if param.dim() > 1:
xavier_uniform_(param)
def generate_attn_mask(self, sz):
"""
Generate the (causal) attention mask matrix
"""
mask = torch.tril(torch.ones(sz, sz))  # tril keeps the lower triangle of the matrix, including the diagonal
mask = mask.masked_fill(mask == 0, float("-inf")).masked_fill(
mask == 1, float(0.0)
)
return mask
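# Illustrative sketch: generate_attn_mask(3) returns
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
# so that, once added to the attention scores, position i can only attend to
# positions <= i.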

0
model/__init__.py Normal file

93
test/test_bert.py Normal file

@ -0,0 +1,93 @@
import os
import sys
sys.path.append(os.getcwd())
import logging
import torch
from util.log_helper import logger_init
from model.BERT.config import BertConfig
from model.BERT.bert import BertEmbedding
from model.BERT.bert import BertAttention
from model.BERT.bert import BertLayer
from model.BERT.bert import BertEncoder
from model.BERT.bert import BertModel
if __name__ == "__main__":
logger_init(log_filename="test", log_level=logging.DEBUG)
json_file = "./archive/bert_base_chinese/config.json"
config = BertConfig.from_json_file(json_file)
# # use the MultiheadAttention implementation from the torch framework
# config.__dict__["use_torch_multi_head"] = True
config.max_position_embeddings = 518  # test the case where max_position_embeddings exceeds 512
# #[src_len, batch_size]
src = torch.tensor(
[[1, 3, 5, 7, 9, 2, 3], [2, 4, 6, 8, 10, 0, 0]], dtype=torch.long
).transpose(0, 1)
print(f"input shape #[src_len, batch_size]: ", src.shape)
# #[src_len, batch_size]
token_type_ids = torch.tensor(
[[0, 0, 0, 1, 1, 1, 1], [0, 0, 1, 1, 1, 0, 0]], dtype=torch.long
).transpose(0, 1)
# attention_mask is essentially a padding mask #[batch_size, src_len]
attention_mask = torch.tensor(
[
[False, False, False, False, False, True, True],
[False, False, False, False, False, False, True],
]
)
print("------ 测试BertEmbedding -------")
bert_embedding = BertEmbedding(config)
bert_embed_out = bert_embedding(src, token_type_ids=token_type_ids)
print(
f"BertEmbedding output shape #[src_len, batch_size, hidden_size]: {bert_embed_out.shape}"
)
print("------ 测试BertAttention -------")
bert_attention = BertAttention(config)
bert_attn_out = bert_attention(bert_embed_out, attention_mask=attention_mask)
print(
f"BertAttention output shape #[src_len, batch_size, hidden_size]: {bert_attn_out.shape}",
)
print("------ 测试BertLayer -------")
bert_layer = BertLayer(config)
bert_layer_out = bert_layer(bert_embed_out, attention_mask)
print(
f"BertLayer output shape #[src_len, batch_size, hidden_size]: {bert_layer_out.shape}",
)
print("------ 测试BertEncoder -------")
bert_encoder = BertEncoder(config)
bert_encoder_outs = bert_encoder(bert_embed_out, attention_mask)
print(
f"Num of BertEncoder outputs #[config.num_hidden_layers]: {len(bert_encoder_outs)}",
)
print(
f"Each output shape of BertEncoder #[src_len, batch_size, hidden_size]: {bert_encoder_outs[0].shape}",
)
print("------ 测试BertModel -------")
position_ids = torch.arange(src.size()[0]).expand((1, -1)) # [1, src_len]
bert_model = BertModel(config)
bert_pooler_out = bert_model(
input_ids=src,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
position_ids=position_ids,
)[0]
print(
f"BertPooler output shape #[batch_size, hidden_size]: {bert_pooler_out.shape}",
)
print("======= BertModel参数: ========")
for param in bert_model.state_dict():
print(param, "\t #", bert_model.state_dict()[param].size())
print("------ 测试BertModel载入预训练模型 -------")
model = BertModel.from_pretrained(
config=config, pretrained_model_dir="./archive/bert_base_chinese"
)

15
test/test_config.py Normal file

@ -0,0 +1,15 @@
import os
import sys
sys.path.append(os.getcwd())
from model.BERT.config import BertConfig
if __name__ == "__main__":
json_file = "./archive/bert_base_chinese/config.json"
config = BertConfig.from_json_file(json_file)
for key, value in config.__dict__.items():
print(f"{key} = {value}")
print("=" * 20)
print(config.to_json_str())

53
test/test_embedding.py Normal file

@ -0,0 +1,53 @@
import os
import sys
sys.path.append(os.getcwd())
import torch
from model.BERT.config import BertConfig
from model.BERT.embedding import TokenEmbedding
from model.BERT.embedding import PositionalEmbedding
from model.BERT.embedding import SegmentEmbedding
from model.BERT.embedding import BertEmbedding
if __name__ == "__main__":
json_file = "./archive/bert_base_chinese/config.json"
config = BertConfig.from_json_file(json_file)
src = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]], dtype=torch.long)
src = src.transpose(0, 1) # #[src_len, batch_size] [5, 2]
print("***** --------- 测试TokenEmbedding ------------")
token_embedding = TokenEmbedding(vocab_size=16, hidden_size=32)
token_embed = token_embedding(input_ids=src)
print("src shape #[src_len, batch_size]: ", src.shape)
print(
f"token embedding shape #[src_len, batch_size, hidden_size]: {token_embed.shape}\n"
)
print("***** --------- 测试PositionalEmbedding ------------")
# #[1, src_len]
position_ids = torch.arange(src.shape[0]).expand((1, -1))
position_embedding = PositionalEmbedding(max_position_embeddings=8, hidden_size=32)
pos_embed = position_embedding(position_ids=position_ids)
# print(position_embedding.embedding.weight) # embedding matrix
print("position_ids shape #[1, src_len]: ", position_ids.shape)
print(f"positional embedding shape #[src_len, 1, hidden_size]: {pos_embed.shape}\n")
print("***** --------- 测试SegmentEmbedding ------------")
token_type_ids = torch.tensor(
[[0, 0, 0, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.long
).transpose(0, 1)
segmet_embedding = SegmentEmbedding(type_vocab_size=2, hidden_size=32)
seg_embed = segmet_embedding(token_type_ids)
print("token_type_ids shape #[src_len, batch_size]: ", token_type_ids.shape)
print(
f"segment embedding shape #[src_len, batch_size, hidden_size]: {seg_embed.shape}\n"
)
print("***** --------- 测试BertEmbedding ------------")
bert_embedding = BertEmbedding(config)
input_embed = bert_embedding(src, token_type_ids=token_type_ids)
print(
f"input embedding shape #[src_len, batch_size, hidden_size]: {input_embed.shape}"
)

76
test/test_transformer.py Normal file

@ -0,0 +1,76 @@
import os
import sys
sys.path.append(os.getcwd())
import torch
import torch.nn as nn
from model.BERT.transformer import MultiheadAttention
from model.BERT.transformer import EncoderLayer, Encoder
from model.BERT.transformer import DecoderLayer, Decoder
from model.BERT.transformer import Transformer
if __name__ == "__main__":
batch_size = 2
src_len = 5
tgt_len = 7
d_model = 32
nhead = 4
src = torch.rand((src_len, batch_size, d_model)) # [src_len, batch_size, embed_dim]
src_key_padding_mask = torch.tensor(
[[False, False, False, True, True], [False, False, False, False, True]]
) # [batch_size, src_len]
tgt = torch.rand((tgt_len, batch_size, d_model)) # [tgt_len, batch_size, embed_dim]
tgt_key_padding_mask = torch.tensor(
[
[False, False, False, False, False, False, True],
[False, False, False, False, True, True, True],
]
) # [batch_size, tgt_len]
print("============ 测试 MultiheadAttention ============")
my_mh = MultiheadAttention(embed_dim=d_model, num_heads=nhead)
my_mh_out = my_mh(src, src, src, key_padding_mask=src_key_padding_mask)
print(my_mh_out[0].shape) # [5, 2, 32]
mh = torch.nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead)
mh_out = mh(src, src, src, key_padding_mask=src_key_padding_mask)
print(mh_out[0].shape) # [5, 2, 32]
print("============ 测试 Encoder ============")
enlayer = EncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=512)
encoder = Encoder(enlayer, num_layers=4, norm=nn.LayerNorm(d_model))
memory = encoder(src, src_key_padding_mask=src_key_padding_mask)
print(memory.shape) # [5, 2, 32]
print("============ 测试 Decoder ============")
delayer = DecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=256)
decoder = Decoder(delayer, num_layers=6, norm=nn.LayerNorm(d_model))
out = decoder(
tgt,
memory,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=src_key_padding_mask,
)
print(out.shape) # [7, 2, 32]
print("============ 测试 Transformer ============")
my_transformer = Transformer(
d_model=d_model,
nhead=nhead,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=256,
)
tgt_mask = my_transformer.generate_attn_mask(tgt_len)
output = my_transformer(
src=src,
tgt=tgt,
tgt_mask=tgt_mask,
src_key_padding_mask=src_key_padding_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=src_key_padding_mask,
)
print(output.shape) # [7, 2, 32]

0
util/__init__.py Normal file

43
util/log_helper.py Normal file

@ -0,0 +1,43 @@
import os
import sys
import logging
from datetime import datetime
def logger_init(
log_filename="monitor", log_level=logging.DEBUG, log_dir="./log/", only_file=False
):
"""
Args:
log_filename: log file name.
log_level: logging level.
log_dir: log directory.
only_file: whether to write only to the log file.
"""
# set the log file path
if not os.path.exists(log_dir):
os.makedirs(log_dir)
log_filepath = os.path.join(
log_dir, log_filename + "_" + str(datetime.now())[:10] + ".txt"
)
# set the log format
formatter = "[%(asctime)s] - %(levelname)s: %(message)s"
# write only to the log file
if only_file:
logging.basicConfig(
filename=log_filepath,
level=log_level,
format=formatter,
datefmt="%Y-%m-%d %H:%M:%S",
)
# write to the log file and also print to the terminal
else:
logging.basicConfig(
level=log_level,
format=formatter,
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[
logging.FileHandler(log_filepath),
logging.StreamHandler(sys.stdout),
],
)
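# Minimal usage sketch (mirrors test/test_bert.py):
#   logger_init(log_filename="test", log_level=logging.DEBUG)
# writes to ./log/test_<date>.txt and echoes to stdout, after which plain
# logging.info(...) / logging.debug(...) calls are captured by both handlers.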