Implement the BertModel class

parent b2136f965b
commit c553948cc5

@@ -0,0 +1,376 @@
import os
import logging
import torch
import torch.nn as nn
from copy import deepcopy
from torch.nn.init import normal_
from .config import BertConfig
from .embedding import BertEmbedding
from .transformer import MultiheadAttention


def get_activation(activation_string):
    """Map an activation name to the corresponding activation module."""
    activation = activation_string.lower()
    if activation == "linear":
        return None
    elif activation == "relu":
        return nn.ReLU()
    elif activation == "gelu":
        return nn.GELU()
    elif activation == "tanh":
        return nn.Tanh()
    else:
        raise ValueError("Unsupported activation: %s" % activation)
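
# Annotation (not part of this commit): get_activation maps the config's hidden_act string to a
# module instance, e.g. get_activation("gelu") returns nn.GELU(); "linear" deliberately returns
# None so that BertIntermediate below can treat it as the identity.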


class BertSelfAttention(nn.Module):
    """Multi-head self-attention module."""

    def __init__(self, config: BertConfig):
        super(BertSelfAttention, self).__init__()
        # Use the multi-head attention module from PyTorch
        if "use_torch_multi_head" in config.__dict__ and config.use_torch_multi_head:
            MultiHeadAttention = nn.MultiheadAttention
        # Use the locally implemented multi-head attention module
        else:
            MultiHeadAttention = MultiheadAttention

        self.multi_head_attention = MultiHeadAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_probs_dropout_prob,
        )

    def forward(self, query, key, value, attn_mask=None, key_padding_mask=None):
        return self.multi_head_attention(
            query, key, value, attn_mask=attn_mask, key_padding_mask=key_padding_mask
        )


class BertSelfOutput(nn.Module):
    """Residual connection and layer normalization applied after the self-attention module."""

    def __init__(self, config: BertConfig):
        super().__init__()
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """
        Args:
            hidden_states: output of the multi-head self-attention module `#[src_len, batch_size, hidden_size]`
            input_tensor: input of the multi-head self-attention module `#[src_len, batch_size, hidden_size]`
        """
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.layer_norm(input_tensor + hidden_states)
        return hidden_states


class BertAttention(nn.Module):
    """The complete self-attention block."""

    def __init__(self, config: BertConfig):
        super().__init__()
        self.self_attention = BertSelfAttention(config)
        self.self_output = BertSelfOutput(config)

    def forward(self, hidden_states, attention_mask=None):
        """
        Args:
            hidden_states: input of the self-attention block `#[src_len, batch_size, hidden_size]`
            attention_mask: padding mask; tokens that should be masked are marked with `True`, all others with `False` `#[batch_size, src_len]`
        """
        # self_attention returns the encoded result and the attention weight matrix
        attn_outputs = self.self_attention(
            hidden_states,
            hidden_states,
            hidden_states,
            attn_mask=None,
            key_padding_mask=attention_mask,  # Note: attention_mask is a padding mask, not an attention mask
        )
        # attn_outputs[0]: #[src_len, batch_size, hidden_size]
        output = self.self_output(attn_outputs[0], hidden_states)
        return output
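
# Annotation (not part of this commit): the attention_mask expected by BertAttention is a padding
# mask of shape [batch_size, src_len] with True marking padded positions. Assuming a pad token id
# of 0 and input_ids of shape [src_len, batch_size], it could be built as
#     attention_mask = (input_ids == 0).transpose(0, 1)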


class BertIntermediate(nn.Module):
    """
    Linear layer after the self-attention block, i.e. the first linear layer of the Transformer FFN.
    """

    def __init__(self, config: BertConfig):
        super().__init__()
        # Linear layer
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # Activation function
        if isinstance(config.hidden_act, str):
            self.inter_activation = get_activation(config.hidden_act)
        else:
            self.inter_activation = config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        if self.inter_activation is not None:
            hidden_states = self.inter_activation(hidden_states)
        return hidden_states


class BertOutput(nn.Module):
    """
    Second linear layer of the Transformer FFN, followed by the residual connection and layer normalization.
    """

    def __init__(self, config: BertConfig):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """
        Args:
            hidden_states: output of the first linear layer `#[src_len, batch_size, intermediate_size]`
            input_tensor: input of the first linear layer `#[src_len, batch_size, hidden_size]`
        Return:
            `#[src_len, batch_size, hidden_size]`
        """
        # #[src_len, batch_size, hidden_size]
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.layer_norm(input_tensor + hidden_states)
        return hidden_states


class BertLayer(nn.Module):
    """A single encoder layer."""

    def __init__(self, config):
        super().__init__()
        self.bert_attention = BertAttention(config)
        self.bert_intermediate = BertIntermediate(config)
        self.bert_output = BertOutput(config)

    def forward(self, hidden_states, attention_mask=None):
        """
        Args:
            hidden_states: `#[src_len, batch_size, hidden_size]`
            attention_mask: padding mask `#[batch_size, src_len]`
        Return:
            `#[src_len, batch_size, hidden_size]`
        """
        # #[src_len, batch_size, hidden_size]
        attn_output = self.bert_attention(hidden_states, attention_mask)
        # #[src_len, batch_size, intermediate_size]
        inter_output = self.bert_intermediate(attn_output)
        # #[src_len, batch_size, hidden_size]
        output = self.bert_output(inter_output, attn_output)

        return output


class BertEncoder(nn.Module):
    """
    The encoder: a stack of encoder layers.
    """

    def __init__(self, config: BertConfig):
        super().__init__()
        self.config = config
        # Create num_hidden_layers encoder layers
        self.bert_layers = nn.ModuleList(
            [BertLayer(config) for _ in range(config.num_hidden_layers)]
        )

    def forward(self, hidden_states, attention_mask=None):
        all_encoder_layers = []  # collect the outputs of all encoder layers
        output = hidden_states
        for layer in self.bert_layers:
            output = layer(output, attention_mask)
            all_encoder_layers.append(output)

        return all_encoder_layers


class BertPooler(nn.Module):
    """
    Summarizes the whole sentence into a single vector.
    """

    def __init__(self, config: BertConfig):
        super().__init__()
        self.config = config
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        if "pooler_type" not in self.config.__dict__:
            raise ValueError(
                "pooler_type must be in ['first_token_transform', 'all_token_average']; "
                "please add a pooler_type entry to the config file config.json"
            )
        # Take the first token, i.e. the [CLS] token embedding
        if self.config.pooler_type == "first_token_transform":
            # #[batch_size, hidden_size]
            token_tensor = hidden_states[0, :].reshape(-1, self.config.hidden_size)
        # Average over all token embeddings
        elif self.config.pooler_type == "all_token_average":
            token_tensor = torch.mean(hidden_states, dim=0)
        else:
            raise ValueError(f"Unsupported pooler_type: {self.config.pooler_type}")

        # #[batch_size, hidden_size]
        output = self.dense(token_tensor)
        output = self.activation(output)

        return output
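
# Annotation (not part of this commit): if the loaded config.json does not define pooler_type,
# BertPooler.forward raises the ValueError above; it can also be supplied programmatically,
# e.g. config.__dict__["pooler_type"] = "first_token_transform".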


def format_params_for_torch(loaded_params_names, loaded_params):
    """
    Reformat the loaded pretrained parameters into the layout expected by the MultiheadAttention
    of the torch (1.12.0) framework, where the Q, K and V weights/biases live in a single tensor.
    """
    qkv_weight_names = ["query.weight", "key.weight", "value.weight"]
    qkv_bias_names = ["query.bias", "key.bias", "value.bias"]
    qkv_weight, qkv_bias = [], []
    torch_params = []
    for i in range(len(loaded_params_names)):
        param_name_in_pretrained = loaded_params_names[i]
        param_name = ".".join(param_name_in_pretrained.split(".")[-2:])
        if param_name in qkv_weight_names:
            qkv_weight.append(loaded_params[param_name_in_pretrained])
        elif param_name in qkv_bias_names:
            qkv_bias.append(loaded_params[param_name_in_pretrained])
        else:
            torch_params.append(loaded_params[param_name_in_pretrained])
        if len(qkv_weight) == 3:
            torch_params.append(torch.cat(qkv_weight, dim=0))
            qkv_weight = []
        if len(qkv_bias) == 3:
            torch_params.append(torch.cat(qkv_bias, dim=0))
            qkv_bias = []

    return torch_params
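
# Annotation (not part of this commit), a sketch of what the reordering above produces, assuming a
# bert-base checkpoint with hidden_size=768: per layer, the three [768, 768] query/key/value weights
# are concatenated along dim 0 into one [2304, 768] tensor (and the biases into one [2304] tensor),
# matching the in_proj_weight / in_proj_bias layout of nn.MultiheadAttention:
#     in_proj_weight = torch.cat([w_query, w_key, w_value], dim=0)  # [3 * hidden_size, hidden_size]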


def load_512_position(init_embedding, loaded_embedding):
    """
    The pretrained BERT model only supports up to 512 `position_ids`, while `max_position_embeddings`
    in a custom model config may be larger than 512. When loading, the first 512 rows of the randomly
    initialized positional embedding matrix are replaced with the pretrained positional embedding matrix.
    Args:
        init_embedding: randomly initialized positional embedding matrix, possibly with more than 512 rows
        loaded_embedding: positional embedding matrix loaded from the pretrained model, exactly 512 rows
    """
    logging.info("Model config max_position_embeddings > 512")
    init_embedding[:512, :] = loaded_embedding[:512, :]
    return init_embedding


class BertModel(nn.Module):
    def __init__(self, config: BertConfig):
        super().__init__()
        self.config = config
        self.bert_embedding = BertEmbedding(config)
        self.bert_encoder = BertEncoder(config)
        self.bert_pooler = BertPooler(config)
        self._reset_parameters()

    def forward(
        self,
        input_ids=None,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
    ):
        """
        Args:
            input_ids: `#[src_len, batch_size]`
            position_ids: `#[1, src_len]`
            token_type_ids: `#[src_len, batch_size]`
            attention_mask: `#[batch_size, src_len]`
        Return:
            pooled_output `#[batch_size, hidden_size]` and the list of encoder layer outputs,
            each of shape `#[src_len, batch_size, hidden_size]`
        """
        input_embed = self.bert_embedding(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
        )

        # encoder_outputs holds the outputs of all num_hidden_layers encoder layers
        encoder_outputs = self.bert_encoder(
            hidden_states=input_embed, attention_mask=attention_mask
        )
        # Feed the output of the last encoder layer into the pooler to summarize the whole sentence
        sequence_output = encoder_outputs[-1]  # #[src_len, batch_size, hidden_size]
        pooled_output = self.bert_pooler(sequence_output)  # #[batch_size, hidden_size]

        return pooled_output, encoder_outputs

    def _reset_parameters(self):
        """Initialize the parameters."""
        for param in self.parameters():
            if param.dim() > 1:
                normal_(param, mean=0.0, std=self.config.initializer_range)

    @classmethod
    def from_pretrained(cls, config: BertConfig, pretrained_model_dir=None):
        """Create a model from a pretrained checkpoint."""
        model = cls(config)  # create the model; cls refers to the class BertModel
        # Load the pretrained model
        pretrained_model_path = os.path.join(pretrained_model_dir, "pytorch_model.bin")
        if not os.path.exists(pretrained_model_path):
            raise ValueError(
                f"<The model at path {pretrained_model_path} does not exist, please check carefully!>\n"
                f"Chinese model download: https://huggingface.co/bert-base-chinese/tree/main\n"
                f"English model download: https://huggingface.co/bert-base-uncased/tree/main\n"
            )
        loaded_params = torch.load(pretrained_model_path)
        loaded_params_names = list(loaded_params.keys())[:-8]
        model_params = deepcopy(model.state_dict())
        model_params_names = list(model_params.keys())[1:]

        if "use_torch_multi_head" in config.__dict__ and config.use_torch_multi_head:
            logging.info("## Note: using the MultiheadAttention implementation from the torch framework")

            torch_params = format_params_for_torch(loaded_params_names, loaded_params)
            for i in range(len(model_params_names)):
                logging.debug(
                    f"## Successfully assigned parameter {model_params_names[i]} with shape {torch_params[i].size()}"
                )
                if "position_embedding" in model_params_names[i]:
                    if config.max_position_embeddings > 512:
                        new_embedding = load_512_position(
                            model_params[model_params_names[i]],
                            torch_params[i],
                        )
                        model_params[model_params_names[i]] = new_embedding
                        continue

                model_params[model_params_names[i]] = torch_params[i]
        else:
            logging.info(
                "## Note: using the local MultiheadAttention implementation from transformer.py; "
                "to use the MultiheadAttention module from the torch framework, set config.__dict__['use_torch_multi_head'] = True"
            )

            for i in range(len(loaded_params_names)):
                logging.debug(
                    f"## Successfully assigned parameter {loaded_params_names[i]} to {model_params_names[i]} "
                    f"with shape {loaded_params[loaded_params_names[i]].size()}"
                )
                if "position_embedding" in model_params_names[i]:
                    if config.max_position_embeddings > 512:
                        new_embedding = load_512_position(
                            model_params[model_params_names[i]],
                            loaded_params[loaded_params_names[i]],
                        )
                        model_params[model_params_names[i]] = new_embedding
                        continue
                # Assign the loaded pretrained parameter values to the newly created model
                model_params[model_params_names[i]] = loaded_params[
                    loaded_params_names[i]
                ]

        model.load_state_dict(model_params)
        return model
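
A minimal usage sketch of the class above (an annotation, not part of the diff; it assumes the
checkpoint directory used by the test script in this commit and a config.json that defines
pooler_type):

    import torch
    from model.BERT.config import BertConfig
    from model.BERT.bert import BertModel

    config = BertConfig.from_json_file("./archive/bert_base_chinese/config.json")
    model = BertModel.from_pretrained(config=config, pretrained_model_dir="./archive/bert_base_chinese")

    input_ids = torch.randint(0, config.vocab_size, (10, 2))   # [src_len, batch_size]
    token_type_ids = torch.zeros(10, 2, dtype=torch.long)      # [src_len, batch_size]
    attention_mask = torch.zeros(2, 10, dtype=torch.bool)      # [batch_size, src_len], no padding
    pooled_output, encoder_outputs = model(
        input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
    )
    # pooled_output: [batch_size, hidden_size]; encoder_outputs[-1]: [src_len, batch_size, hidden_size]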
@@ -1,7 +1,5 @@
import json
import copy

# import six
import logging

@@ -85,19 +83,3 @@ class BertConfig(object):
        """Convert the object into a dict."""
        out = copy.deepcopy(self.__dict__)
        return out


if __name__ == "__main__":
    import os
    import sys

    sys.path.append(os.getcwd())

    json_file = "./archive/bert_base_chinese/config.json"
    config = BertConfig.from_json_file(json_file)

    for key, value in config.__dict__.items():
        print(f"{key} = {value}")

    print("=" * 20)
    print(config.to_json_str())
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
from .config import BertConfig
from torch.nn.init import normal_
from bert_config import BertConfig


class TokenEmbedding(nn.Module):
@@ -126,7 +126,7 @@ class BertEmbedding(nn.Module):
            initializer_range=config.initializer_range,
        )

        self.layernorm = nn.LayerNorm(config.hidden_size)  # layer normalization
        self.layer_norm = nn.LayerNorm(config.hidden_size)  # layer normalization
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Create all position ids ahead of time  #[1, max_position_embeddings]
@@ -161,55 +161,7 @@ class BertEmbedding(nn.Module):

        # Sum the three embeddings
        input_embed = token_embed + pos_embed + seg_embed
        input_embed = self.layernorm(input_embed)
        input_embed = self.layer_norm(input_embed)
        input_embed = self.dropout(input_embed)

        return input_embed


if __name__ == "__main__":
    import os
    import sys

    sys.path.append(os.getcwd())

    json_file = "./archive/bert_base_chinese/config.json"
    config = BertConfig.from_json_file(json_file)

    src = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]], dtype=torch.long)
    src = src.transpose(0, 1)  # #[src_len, batch_size] [5, 2]

    print("***** --------- Test TokenEmbedding ------------")
    token_embedding = TokenEmbedding(vocab_size=16, hidden_size=32)
    token_embed = token_embedding(input_ids=src)
    print("src shape #[src_len, batch_size]: ", src.shape)
    print(
        f"token embedding shape #[src_len, batch_size, hidden_size]: {token_embed.shape}\n"
    )

    print("***** --------- Test PositionalEmbedding ------------")
    # #[1, src_len]
    position_ids = torch.arange(src.shape[0]).expand((1, -1))
    position_embedding = PositionalEmbedding(max_position_embeddings=8, hidden_size=32)
    pos_embed = position_embedding(position_ids=position_ids)
    # print(position_embedding.embedding.weight)  # the embedding matrix
    print("position_ids shape #[1, src_len]: ", position_ids.shape)
    print(f"positional embedding shape #[src_len, 1, hidden_size]: {pos_embed.shape}\n")

    print("***** --------- Test SegmentEmbedding ------------")
    token_type_ids = torch.tensor(
        [[0, 0, 0, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.long
    ).transpose(0, 1)
    segment_embedding = SegmentEmbedding(type_vocab_size=2, hidden_size=32)
    seg_embed = segment_embedding(token_type_ids)
    print("token_type_ids shape #[src_len, batch_size]: ", token_type_ids.shape)
    print(
        f"segment embedding shape #[src_len, batch_size, hidden_size]: {seg_embed.shape}\n"
    )

    print("***** --------- Test BertEmbedding ------------")
    bert_embedding = BertEmbedding(config)
    input_embed = bert_embedding(src, token_type_ids=token_type_ids)
    print(
        f"input embedding shape #[src_len, batch_size, hidden_size]: {input_embed.shape}"
    )
@@ -0,0 +1,565 @@
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_


class MultiheadAttention(nn.Module):
    """
    Multi-head attention mechanism.
    """

    def __init__(self, embed_dim=512, num_heads=8, dropout=0.0, bias=True):
        """
        :param embed_dim: embedding dimension, i.e. the parameter d_model
        :param num_heads: number of attention heads, i.e. the parameter nhead
        :param bias: whether the linear projections use a bias
        """
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        # head_dim is the number of columns of each single-head projection matrix, i.e. the dimension of the q, k, v vectors
        self.head_dim = embed_dim // num_heads
        self.kdim = self.head_dim
        self.vdim = self.head_dim
        self.dropout = dropout

        assert (
            self.head_dim * self.num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        # The constraint d_k = d_v = d_model / nhead from the original paper

        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
        # Projection matrix W_q, embed_dim = num_heads * kdim, kdim = qdim
        # The second dimension is embed_dim because num_heads W_q matrices (one per head) are initialized together and concatenated horizontally
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
        # W_k, embed_dim = num_heads * kdim
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
        # W_v, embed_dim = num_heads * vdim

        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
        # One more linear transform applied to the (horizontally concatenated) multi-head attention result before output

        self._reset_parameters()
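
        # Annotation (not part of this commit): for a bert-base sized model, embed_dim=768 and
        # num_heads=12, so head_dim = 768 // 12 = 64 and the assert above holds; an embed_dim that
        # is not divisible by num_heads would fail the check.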

    def forward(
        self,
        query,
        key,
        value,
        attn_mask=None,
        key_padding_mask=None,
        training=True,
        is_print_shape=False,
    ):
        """
        In the Encoder, query, key and value are all the source sequence (src seq).
        In the Decoder, query, key and value are all the target sequence (tgt seq).
        When the Encoder and Decoder interact, key and value are the Encoder memory and query is the tgt seq.
        :param query: # [tgt_len, batch_size, embed_dim]
        :param key: # [src_len, batch_size, embed_dim]
        :param value: # [src_len, batch_size, embed_dim]
        :param attn_mask: attention mask matrix # [tgt_len, src_len] or [batch_size * num_heads, tgt_len, src_len]
            Usually only used during Decoder training, where all tgt tokens are fed in parallel and the tokens after the current position must be masked out
        :param key_padding_mask: mask for the padding tokens # [batch_size, src_len]
        :return:
            attn_output: multi-head attention result # [tgt_len, batch_size, embed_dim]
            attn_output_weights: averaged multi-head attention weight matrix # [batch_size, tgt_len, src_len]
        """

        # 1. Compute Q, K, V
        # Note: query, key and value are the sequences before the linear projections; in the Encoder, for example, they are all the src seq
        Q = self.q_proj(query)
        # [tgt_len, batch_size, embed_dim] x [embed_dim, num_heads * kdim] = [tgt_len, batch_size, num_heads * kdim]
        K = self.k_proj(key)
        # [src_len, batch_size, embed_dim] x [embed_dim, num_heads * kdim] = [src_len, batch_size, num_heads * kdim]
        V = self.v_proj(value)
        # [src_len, batch_size, embed_dim] x [embed_dim, num_heads * vdim] = [src_len, batch_size, num_heads * vdim]

        if is_print_shape:
            print("=" * 80)
            print("Start computing multi-head attention:")
            print(
                f"\t num_heads = {self.num_heads}, d_model = {query.size(-1)}, d_k = d_v = d_model/num_heads = {query.size(-1) // self.num_heads}"
            )
            print(f"\t query shape ([tgt_len, batch_size, embed_dim]): {query.shape}")
            print(
                f"\t W_q shape ([embed_dim, num_heads * kdim]): {self.q_proj.weight.shape}"
            )
            print(f"\t Q shape ([tgt_len, batch_size, num_heads * kdim]): {Q.shape}")
            print("\t" + "-" * 70)

            print(f"\t key shape ([src_len, batch_size, embed_dim]): {key.shape}")
            print(
                f"\t W_k shape ([embed_dim, num_heads * kdim]): {self.k_proj.weight.shape}"
            )
            print(f"\t K shape ([src_len, batch_size, num_heads * kdim]): {K.shape}")
            print("\t" + "-" * 70)

            print(f"\t value shape ([src_len, batch_size, embed_dim]): {value.shape}")
            print(
                f"\t W_v shape ([embed_dim, num_heads * vdim]): {self.v_proj.weight.shape}"
            )
            print(f"\t V shape ([src_len, batch_size, num_heads * vdim]): {V.shape}")
            print("\t" + "-" * 70)
            print(
                "\t ***** Note: W_q, W_k and W_v here are concatenations of the per-head projection matrices, so Q, K and V are likewise concatenations of the per-head q, k, v vectors *****"
            )

        # 2. Scale Q and check that attn_mask has a valid shape
        scaling = float(self.head_dim) ** -0.5  # scaling factor
        Q = Q * scaling
        # [query_len, batch_size, num_heads * kdim], where query_len is tgt_len

        src_len = key.size(0)
        tgt_len, bsz, _ = query.size()  # [tgt_len, batch_size, embed_dim]

        if attn_mask is not None:
            # [tgt_len, src_len] or [batch_size * num_heads, tgt_len, src_len]
            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)  # [1, tgt_len, src_len]
                if list(attn_mask.size()) != [1, tgt_len, src_len]:
                    raise RuntimeError("The size of the 2D attn_mask is not correct.")
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [
                    bsz * self.num_heads,
                    tgt_len,
                    src_len,
                ]:
                    raise RuntimeError("The size of the 3D attn_mask is not correct.")
            # At this point attn_mask is 3-dimensional

        # 3. Compute the attention scores
        # Reshape so that bmm can be applied below
        Q = (
            Q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.kdim)
            .transpose(0, 1)
        )
        # [batch_size * num_heads, tgt_len, kdim]
        K = (
            K.contiguous()
            .view(src_len, bsz * self.num_heads, self.kdim)
            .transpose(0, 1)
        )
        # [batch_size * num_heads, src_len, kdim]
        V = (
            V.contiguous()
            .view(src_len, bsz * self.num_heads, self.vdim)
            .transpose(0, 1)
        )
        # [batch_size * num_heads, src_len, vdim]

        attn_weights = torch.bmm(Q, K.transpose(1, 2))  # bmm performs batched matrix multiplication on 3D tensors
        # [batch_size * num_heads, tgt_len, kdim] x [batch_size * num_heads, kdim, src_len]
        # -> [batch_size * num_heads, tgt_len, src_len], the attention matrices of the num_heads Q-K products

        # 4. Apply the masks
        # Attention mask
        if attn_mask is not None:
            attn_weights += attn_mask
            # [batch_size * num_heads, tgt_len, src_len]

        # Padding mask (masks columns)
        if key_padding_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            # [batch_size, num_heads, tgt_len, src_len]
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf")
            )
            # key_padding_mask is expanded from [batch_size, src_len] to [batch_size, 1, 1, src_len]
            # and used to mask attn_weights; masked_fill masks the columns whose value is True (or non-zero)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
            # [batch_size * num_heads, tgt_len, src_len]

        # 5. Compute the multi-head attention output
        # Attention weights
        attn_weights = F.softmax(attn_weights, dim=-1)
        # [batch_size * num_heads, tgt_len, src_len]
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=training)

        # Compute MultiheadAttention(Q, K, V)
        attn_output = torch.bmm(attn_weights, V)
        # [batch_size * num_heads, tgt_len, src_len] x [batch_size * num_heads, src_len, vdim]
        # -> [batch_size * num_heads, tgt_len, vdim]

        # Apply the final linear transform and return
        attn_output = (
            attn_output.transpose(0, 1)
            .contiguous()
            .view(tgt_len, bsz, self.num_heads * self.vdim)
        )
        # [tgt_len, batch_size, num_heads * vdim]
        Z = self.out_proj(attn_output)
        # [tgt_len, batch_size, embed_dim]

        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        # [batch_size, num_heads, tgt_len, src_len]

        if is_print_shape:
            print(
                f"\t Shape of the multi-head attention result (horizontal concatenation) ([tgt_len, batch_size, num_heads * vdim]): {attn_output.shape}"
            )
            print(
                f"\t Shape of the output projection W_o ([num_heads * vdim, embed_dim]): {self.out_proj.weight.shape}"
            )
            print(f"\t Shape of the multi-head attention result after the output projection ([tgt_len, batch_size, embed_dim]): {Z.shape}")

        return (
            Z,
            attn_weights.sum(dim=1) / self.num_heads,  # return the average of the per-head attention weight matrices
        )

    def _reset_parameters(self):
        """
        Initialize the parameters.
        """
        for param in self.parameters():
            if param.dim() > 1:
                xavier_uniform_(param)


def _get_clones(module, N):
    """
    Make N deep copies of module.
    """
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class EncoderLayer(nn.Module):
    """
    A single encoder layer.
    """

    def __init__(self, d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1):
        """
        :param d_model: model (word embedding) dimension
        :param nhead: number of attention heads
        :param dim_feedforward: output dimension of the fully connected (feed-forward) layer
        :param dropout: dropout rate
        """
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            embed_dim=d_model, num_heads=nhead, dropout=dropout
        )

        # Add & Norm after the multi-head attention output
        self.dropout1 = nn.Dropout(dropout)
        # Note: LayerNorm normalizes along the feature dimension; BatchNorm normalizes along the batch dimension
        self.norm1 = nn.LayerNorm(d_model)

        # Feed Forward Network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.dropout3 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        :param src: # [src_len, batch_size, embed_dim]
        :param src_mask: None; the Encoder does not need an attention mask
        :param src_key_padding_mask: # [batch_size, src_len]
        :return: # [src_len, batch_size, embed_dim] <==> [src_len, batch_size, num_heads * kdim]
        """
        # Multi-head attention
        src1 = self.self_attn(
            src,
            src,
            src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )[0]
        # [src_len, batch_size, embed_dim], where embed_dim = num_heads * kdim

        # Residual connection and LayerNorm
        src = src + self.dropout1(src1)
        src = self.norm1(src)

        # Feed Forward
        src1 = self.activation(self.linear1(src))
        # [src_len, batch_size, dim_feedforward]
        src1 = self.linear2(self.dropout2(src1))
        # [src_len, batch_size, embed_dim]

        # Residual connection and LayerNorm
        src = src + self.dropout3(src1)
        src = self.norm2(src)

        return src


class Encoder(nn.Module):
    """
    The encoder: a stack of encoder layers.
    """

    def __init__(self, encoder_layer, num_layers=6, norm=None):
        """
        :param encoder_layer: a single encoder layer
        :param num_layers: number of encoder layers
        :param norm: normalization layer
        """
        super(Encoder, self).__init__()
        # Clone the encoder layer several times to obtain the list of layers
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        :param src: # [src_len, batch_size, embed_dim]
        :param src_mask: None; the Encoder does not need an attention mask
        :param src_key_padding_mask: # [batch_size, src_len]
        :return: # [src_len, batch_size, embed_dim] <==> [src_len, batch_size, num_heads * kdim]
        """
        output = src
        # Run every encoder layer in turn, feeding each output into the next layer
        for layer in self.layers:
            output = layer(
                output, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask
            )
        # Apply the final Norm to the output of the last layer
        if self.norm is not None:
            output = self.norm(output)

        return output


class DecoderLayer(nn.Module):
    """
    A single decoder layer.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        """
        :param d_model: model (word embedding) dimension
        :param nhead: number of attention heads
        :param dim_feedforward: output dimension of the fully connected (feed-forward) layer
        :param dropout: dropout rate
        """
        super(DecoderLayer, self).__init__()
        # Masked multi-head attention over the decoder input sequence
        self.self_attn = MultiheadAttention(
            embed_dim=d_model, num_heads=nhead, dropout=dropout
        )
        # Multi-head attention where the encoder output (memory) interacts with the decoder layer
        self.multihead_attn = MultiheadAttention(
            embed_dim=d_model, num_heads=nhead, dropout=dropout
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        # Feed Forward Network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = nn.ReLU()
        self.dropout4 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
    ):
        """
        :param tgt: decoder input sequence # [tgt_len, batch_size, embed_dim]
        :param memory: encoder output (memory) # [src_len, batch_size, embed_dim]
        :param tgt_mask: attention mask of the decoder self-attention # [tgt_len, tgt_len]
        :param memory_mask: attention mask of the encoder-decoder attention, usually None
        :param tgt_key_padding_mask: padding of the decoder input sequence # [batch_size, tgt_len]
        :param memory_key_padding_mask: padding of the encoder input sequence # [batch_size, src_len]
        :return: # [tgt_len, batch_size, embed_dim] <==> [tgt_len, batch_size, num_heads * kdim]
        """
        # Masked multi-head attention
        tgt1 = self.self_attn(
            tgt,
            tgt,
            tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )[0]

        # Residual connection & LayerNorm
        tgt = tgt + self.dropout1(tgt1)
        tgt = self.norm1(tgt)

        # Encoder-decoder multi-head attention
        tgt1 = self.multihead_attn(
            tgt,
            memory,
            memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]

        # Residual connection & LayerNorm
        tgt = tgt + self.dropout2(tgt1)
        tgt = self.norm2(tgt)

        # Feed Forward
        tgt1 = self.activation(self.linear1(tgt))
        # [tgt_len, batch_size, dim_feedforward]
        tgt1 = self.linear2(self.dropout4(tgt1))
        # [tgt_len, batch_size, embed_dim]

        # Residual connection & LayerNorm
        tgt = tgt + self.dropout3(tgt1)
        tgt = self.norm3(tgt)

        return tgt


class Decoder(nn.Module):
    """
    The decoder: a stack of decoder layers.
    """

    def __init__(self, decoder_layer, num_layers, norm=None):
        """
        :param decoder_layer: a single decoder layer
        :param num_layers: number of decoder layers
        :param norm: normalization layer
        """
        super(Decoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
    ):
        """
        :param tgt: decoder input sequence # [tgt_len, batch_size, embed_dim]
        :param memory: encoder output (memory) # [src_len, batch_size, embed_dim]
        :param tgt_mask: attention mask of the decoder self-attention # [tgt_len, tgt_len]
        :param memory_mask: attention mask of the encoder-decoder attention, usually None
        :param tgt_key_padding_mask: padding of the decoder input sequence # [batch_size, tgt_len]
        :param memory_key_padding_mask: padding of the encoder input sequence # [batch_size, src_len]
        :return: # [tgt_len, batch_size, embed_dim] <==> [tgt_len, batch_size, num_heads * kdim]
        """
        output = tgt  # [tgt_len, batch_size, embed_dim]
        # Run every decoder layer in turn, feeding each output into the next layer
        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
            )

        # Apply the final Norm to the output of the last layer
        if self.norm is not None:
            output = self.norm(output)

        return output


class Transformer(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
    ):
        """
        :param d_model: model (word embedding) dimension
        :param nhead: number of attention heads
        :param num_encoder_layers: number of stacked EncoderLayers
        :param num_decoder_layers: number of stacked DecoderLayers
        :param dim_feedforward: output dimension of the fully connected (feed-forward) layer
        :param dropout: dropout rate
        """
        super(Transformer, self).__init__()
        # Encoder
        encoder_layer = EncoderLayer(d_model, nhead, dim_feedforward, dropout)
        encoder_norm = nn.LayerNorm(d_model)
        self.encoder = Encoder(encoder_layer, num_encoder_layers, encoder_norm)

        # Decoder
        decoder_layer = DecoderLayer(d_model, nhead, dim_feedforward, dropout)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def forward(
        self,
        src,
        tgt,
        src_mask=None,
        tgt_mask=None,
        memory_mask=None,
        src_key_padding_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
    ):
        """
        :param src: # [src_len, batch_size, embed_dim]
        :param tgt: # [tgt_len, batch_size, embed_dim]
        :param src_mask: None
        :param tgt_mask: # [tgt_len, tgt_len]
        :param memory_mask: None
        :param src_key_padding_mask: # [batch_size, src_len]
        :param tgt_key_padding_mask: # [batch_size, tgt_len]
        :param memory_key_padding_mask: # [batch_size, src_len]
        :return: [tgt_len, batch_size, embed_dim] <==> [tgt_len, batch_size, num_heads * kdim]
        """

        # Encoding: produce the memory
        memory = self.encoder(
            src, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask
        )

        # Decoding
        output = self.decoder(
            tgt=tgt,
            memory=memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )

        return output

    def _reset_parameters(self):
        """
        Initialize the parameters.
        """
        for param in self.parameters():
            if param.dim() > 1:
                xavier_uniform_(param)

    def generate_attn_mask(self, sz):
        """
        Generate a causal attention mask matrix.
        """
        mask = torch.tril(torch.ones(sz, sz))  # tril keeps the lower triangle (including the diagonal)
        mask = mask.masked_fill(mask == 0, float("-inf")).masked_fill(
            mask == 1, float(0.0)
        )
        return mask
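
As a quick worked check of generate_attn_mask above (an annotation, not part of the diff): for
sz=3 the lower triangle of ones is mapped to 0.0 and the upper triangle to -inf, so after the mask
is added to the attention scores each position can only attend to itself and earlier positions:

    import torch
    sz = 3
    mask = torch.tril(torch.ones(sz, sz))
    mask = mask.masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0)
    # tensor([[0., -inf, -inf],
    #         [0., 0., -inf],
    #         [0., 0., 0.]])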
@@ -0,0 +1,93 @@
import os
import sys

sys.path.append(os.getcwd())
import logging
import torch
from util.log_helper import logger_init
from model.BERT.config import BertConfig
from model.BERT.bert import BertEmbedding
from model.BERT.bert import BertAttention
from model.BERT.bert import BertLayer
from model.BERT.bert import BertEncoder
from model.BERT.bert import BertModel


if __name__ == "__main__":
    logger_init(log_filename="test", log_level=logging.DEBUG)

    json_file = "./archive/bert_base_chinese/config.json"
    config = BertConfig.from_json_file(json_file)
    # # Use the MultiheadAttention implementation from the torch framework
    # config.__dict__["use_torch_multi_head"] = True
    config.max_position_embeddings = 518  # test the case where max_position_embeddings is larger than 512

    # #[src_len, batch_size]
    src = torch.tensor(
        [[1, 3, 5, 7, 9, 2, 3], [2, 4, 6, 8, 10, 0, 0]], dtype=torch.long
    ).transpose(0, 1)
    print("input shape #[src_len, batch_size]: ", src.shape)
    # #[src_len, batch_size]
    token_type_ids = torch.tensor(
        [[0, 0, 0, 1, 1, 1, 1], [0, 0, 1, 1, 1, 0, 0]], dtype=torch.long
    ).transpose(0, 1)

    # attention_mask is essentially a padding mask #[batch_size, src_len]
    attention_mask = torch.tensor(
        [
            [False, False, False, False, False, True, True],
            [False, False, False, False, False, False, True],
        ]
    )

    print("------ Test BertEmbedding -------")
    bert_embedding = BertEmbedding(config)
    bert_embed_out = bert_embedding(src, token_type_ids=token_type_ids)
    print(
        f"BertEmbedding output shape #[src_len, batch_size, hidden_size]: {bert_embed_out.shape}"
    )

    print("------ Test BertAttention -------")
    bert_attention = BertAttention(config)
    bert_attn_out = bert_attention(bert_embed_out, attention_mask=attention_mask)
    print(
        f"BertAttention output shape #[src_len, batch_size, hidden_size]: {bert_attn_out.shape}",
    )

    print("------ Test BertLayer -------")
    bert_layer = BertLayer(config)
    bert_layer_out = bert_layer(bert_embed_out, attention_mask)
    print(
        f"BertLayer output shape #[src_len, batch_size, hidden_size]: {bert_layer_out.shape}",
    )

    print("------ Test BertEncoder -------")
    bert_encoder = BertEncoder(config)
    bert_encoder_outs = bert_encoder(bert_embed_out, attention_mask)
    print(
        f"Num of BertEncoder outputs #[config.num_hidden_layers]: {len(bert_encoder_outs)}",
    )
    print(
        f"Each output shape of BertEncoder #[src_len, batch_size, hidden_size]: {bert_encoder_outs[0].shape}",
    )

    print("------ Test BertModel -------")
    position_ids = torch.arange(src.size()[0]).expand((1, -1))  # [1, src_len]
    bert_model = BertModel(config)
    bert_pooler_out = bert_model(
        input_ids=src,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
    )[0]
    print(
        f"BertPooler output shape #[batch_size, hidden_size]: {bert_pooler_out.shape}",
    )
    print("======= BertModel parameters: ========")
    for param in bert_model.state_dict():
        print(param, "\t #", bert_model.state_dict()[param].size())

    print("------ Test loading a pretrained model into BertModel -------")
    model = BertModel.from_pretrained(
        config=config, pretrained_model_dir="./archive/bert_base_chinese"
    )
@@ -0,0 +1,15 @@
import os
import sys

sys.path.append(os.getcwd())
from model.BERT.config import BertConfig

if __name__ == "__main__":
    json_file = "./archive/bert_base_chinese/config.json"
    config = BertConfig.from_json_file(json_file)

    for key, value in config.__dict__.items():
        print(f"{key} = {value}")

    print("=" * 20)
    print(config.to_json_str())
@@ -0,0 +1,53 @@
import os
import sys

sys.path.append(os.getcwd())
import torch
from model.BERT.config import BertConfig
from model.BERT.embedding import TokenEmbedding
from model.BERT.embedding import PositionalEmbedding
from model.BERT.embedding import SegmentEmbedding
from model.BERT.embedding import BertEmbedding


if __name__ == "__main__":
    json_file = "./archive/bert_base_chinese/config.json"
    config = BertConfig.from_json_file(json_file)

    src = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]], dtype=torch.long)
    src = src.transpose(0, 1)  # #[src_len, batch_size] [5, 2]

    print("***** --------- Test TokenEmbedding ------------")
    token_embedding = TokenEmbedding(vocab_size=16, hidden_size=32)
    token_embed = token_embedding(input_ids=src)
    print("src shape #[src_len, batch_size]: ", src.shape)
    print(
        f"token embedding shape #[src_len, batch_size, hidden_size]: {token_embed.shape}\n"
    )

    print("***** --------- Test PositionalEmbedding ------------")
    # #[1, src_len]
    position_ids = torch.arange(src.shape[0]).expand((1, -1))
    position_embedding = PositionalEmbedding(max_position_embeddings=8, hidden_size=32)
    pos_embed = position_embedding(position_ids=position_ids)
    # print(position_embedding.embedding.weight)  # the embedding matrix
    print("position_ids shape #[1, src_len]: ", position_ids.shape)
    print(f"positional embedding shape #[src_len, 1, hidden_size]: {pos_embed.shape}\n")

    print("***** --------- Test SegmentEmbedding ------------")
    token_type_ids = torch.tensor(
        [[0, 0, 0, 1, 1], [0, 0, 1, 1, 1]], dtype=torch.long
    ).transpose(0, 1)
    segment_embedding = SegmentEmbedding(type_vocab_size=2, hidden_size=32)
    seg_embed = segment_embedding(token_type_ids)
    print("token_type_ids shape #[src_len, batch_size]: ", token_type_ids.shape)
    print(
        f"segment embedding shape #[src_len, batch_size, hidden_size]: {seg_embed.shape}\n"
    )

    print("***** --------- Test BertEmbedding ------------")
    bert_embedding = BertEmbedding(config)
    input_embed = bert_embedding(src, token_type_ids=token_type_ids)
    print(
        f"input embedding shape #[src_len, batch_size, hidden_size]: {input_embed.shape}"
    )
@@ -0,0 +1,76 @@
import os
import sys

sys.path.append(os.getcwd())
import torch
import torch.nn as nn
from model.BERT.transformer import MultiheadAttention
from model.BERT.transformer import EncoderLayer, Encoder
from model.BERT.transformer import DecoderLayer, Decoder
from model.BERT.transformer import Transformer


if __name__ == "__main__":
    batch_size = 2
    src_len = 5
    tgt_len = 7
    d_model = 32
    nhead = 4

    src = torch.rand((src_len, batch_size, d_model))  # [src_len, batch_size, embed_dim]
    src_key_padding_mask = torch.tensor(
        [[False, False, False, True, True], [False, False, False, False, True]]
    )  # [batch_size, src_len]

    tgt = torch.rand((tgt_len, batch_size, d_model))  # [tgt_len, batch_size, embed_dim]
    tgt_key_padding_mask = torch.tensor(
        [
            [False, False, False, False, False, False, True],
            [False, False, False, False, True, True, True],
        ]
    )  # [batch_size, tgt_len]

    print("============ Test MultiheadAttention ============")
    my_mh = MultiheadAttention(embed_dim=d_model, num_heads=nhead)
    my_mh_out = my_mh(src, src, src, key_padding_mask=src_key_padding_mask)
    print(my_mh_out[0].shape)  # [5, 2, 32]

    mh = torch.nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead)
    mh_out = mh(src, src, src, key_padding_mask=src_key_padding_mask)
    print(mh_out[0].shape)  # [5, 2, 32]

    print("============ Test Encoder ============")
    enlayer = EncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=512)
    encoder = Encoder(enlayer, num_layers=4, norm=nn.LayerNorm(d_model))
    memory = encoder(src, src_key_padding_mask=src_key_padding_mask)
    print(memory.shape)  # [5, 2, 32]

    print("============ Test Decoder ============")
    delayer = DecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=256)
    decoder = Decoder(delayer, num_layers=6, norm=nn.LayerNorm(d_model))
    out = decoder(
        tgt,
        memory,
        tgt_key_padding_mask=tgt_key_padding_mask,
        memory_key_padding_mask=src_key_padding_mask,
    )
    print(out.shape)  # [7, 2, 32]

    print("============ Test Transformer ============")
    my_transformer = Transformer(
        d_model=d_model,
        nhead=nhead,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=256,
    )
    tgt_mask = my_transformer.generate_attn_mask(tgt_len)
    output = my_transformer(
        src=src,
        tgt=tgt,
        tgt_mask=tgt_mask,
        src_key_padding_mask=src_key_padding_mask,
        tgt_key_padding_mask=tgt_key_padding_mask,
        memory_key_padding_mask=src_key_padding_mask,
    )
    print(output.shape)  # [7, 2, 32]
@@ -0,0 +1,43 @@
import os
import sys
import logging
from datetime import datetime


def logger_init(
    log_filename="monitor", log_level=logging.DEBUG, log_dir="./log/", only_file=False
):
    """
    Args:
        log_filename: name of the log file.
        log_level: logging level.
        log_dir: directory for log files.
        only_file: whether to write to the log file only.
    """
    # Build the log file path
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    log_filepath = os.path.join(
        log_dir, log_filename + "_" + str(datetime.now())[:10] + ".txt"
    )
    # Log format
    formatter = "[%(asctime)s] - %(levelname)s: %(message)s"
    # Write to the log file only
    if only_file:
        logging.basicConfig(
            filename=log_filepath,
            level=log_level,
            format=formatter,
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    # Write to the log file and also print to the terminal
    else:
        logging.basicConfig(
            level=log_level,
            format=formatter,
            datefmt="%Y-%m-%d %H:%M:%S",
            handlers=[
                logging.FileHandler(log_filepath),
                logging.StreamHandler(sys.stdout),
            ],
        )
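
A small usage sketch for logger_init (an annotation, not part of the diff): it creates log_dir if
needed and writes to a date-stamped file, echoing to stdout as well unless only_file=True.

    import logging
    from util.log_helper import logger_init

    logger_init(log_filename="train", log_level=logging.INFO)
    logging.info("training started")  # goes to ./log/train_YYYY-MM-DD.txt and to stdout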