81 lines
2.4 KiB
Python
81 lines
2.4 KiB
Python
import os
|
|
import sys
|
|
|
|
sys.path.append(os.getcwd())
|
|
|
|
import torch
|
|
import logging
|
|
from model import BertConfig
|
|
from model import BertForNSP
|
|
from utils import logger_init
|
|
|
|
|
|
class ModelConfig(object):
    """Configuration object for the BertForNSP smoke test.

    On construction it:
      1. resolves the project directory (the parent of the directory that
         contains this file) and derives the pretrained-model and log
         directories from it;
      2. initialises logging through ``logger_init``;
      3. loads ``config.json`` from the pretrained-model directory via
         ``BertConfig.from_json_file`` and copies every attribute of the
         resulting object onto this instance;
      4. writes every attribute of this instance to the log.
    """

    def __init__(self):
        # Absolute path of this file -> its directory -> that directory's parent.
        this_file = os.path.abspath(__file__)
        self.project_dir = os.path.dirname(os.path.dirname(this_file))
        self.pretrained_model_dir = os.path.join(
            self.project_dir, "archive", "bert_base_uncased_english"
        )
        self.log_save_dir = os.path.join(self.project_dir, "logs")

        logger_init(
            log_filename="test_bert_for_nsp",
            log_level=logging.INFO,
            log_dir=self.log_save_dir,
        )

        # Import the BERT model's settings: each attribute of the loaded
        # BertConfig becomes an attribute of this object (later keys win).
        config_file = os.path.join(self.pretrained_model_dir, "config.json")
        bert_config = BertConfig.from_json_file(config_file)
        self.__dict__.update(bert_config.__dict__)

        # Dump the complete current configuration to the log, one
        # attribute per line.
        logging.info("=" * 20)
        logging.info("### 将当前配置打印到日志文件中")
        for name, setting in self.__dict__.items():
            logging.info(f"### {name} = {setting}")
|
|
|
|
|
|
def generate_data():
    """Build a tiny fixed batch of dummy inputs for BertForNSP.

    Returns:
        A tuple ``(input_ids, token_type_ids, padding_mask, nsp_labels)``:
          * input_ids      -- long tensor, shape [src_len, batch_size] = [5, 3]
          * token_type_ids -- long tensor, shape [src_len, batch_size] = [5, 3]
          * padding_mask   -- bool tensor, shape [batch_size, src_len] = [3, 5]
                              (presumably True marks padded positions --
                              confirm against BertForNSP's attention_mask
                              convention)
          * nsp_labels     -- long tensor, shape [batch_size] = [3]
    """
    # All rows are written batch-major ([batch_size, src_len]); the token-id
    # and segment-id tensors are then transposed to sequence-major views.
    id_rows = [[1, 1, 1, 4, 5], [6, 7, 8, 7, 2], [5, 3, 4, 3, 4]]
    segment_rows = [[0, 0, 0, 1, 1], [0, 0, 1, 1, 0], [0, 0, 0, 1, 1]]
    mask_rows = [
        [False, False, False, False, True],
        [False, False, False, True, True],
        [False, False, False, False, True],
    ]

    input_ids = torch.tensor(id_rows, dtype=torch.long).t()
    token_type_ids = torch.tensor(segment_rows, dtype=torch.long).t()
    padding_mask = torch.tensor(mask_rows)  # stays [batch_size, src_len]
    nsp_labels = torch.tensor([0, 1, 0], dtype=torch.long)

    return input_ids, token_type_ids, padding_mask, nsp_labels
|
|
|
|
|
|
if __name__ == "__main__":
    config = ModelConfig()
    input_ids, token_type_ids, padding_mask, nsp_labels = generate_data()
    model = BertForNSP(config, config.pretrained_model_dir)

    # Run the same batch through the model twice -- first with
    # nsp_labels=None, then with the real labels -- printing each result.
    for labels in (None, nsp_labels):
        output = model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=padding_mask,
            nsp_labels=labels,
        )
        print(output)
|