# MyBERT/test/test_bert_for_mlm.py

import os
import sys
sys.path.append(os.getcwd())  # make the project's packages importable when run from the repo root
import torch
import logging
import numpy as np
from model import BertConfig, BertForMLM
from utils import logger_init

class ModelConfig(object):
    def __init__(self):
        self.project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.pretrained_model_dir = os.path.join(
            self.project_dir, "archive", "bert_base_uncased_english"
        )
        self.log_save_dir = os.path.join(self.project_dir, "logs")
        logger_init(
            log_filename="test_bert_for_mlm",
            log_level=logging.INFO,
            log_dir=self.log_save_dir,
        )
        # Whether to reuse the BERT token-embedding weights as the weights
        # of the output (prediction) layer, i.e. weight tying
        self.use_embedding_weight = True
        # Load the relevant parts of the pretrained BERT configuration
        bert_config_path = os.path.join(self.pretrained_model_dir, "config.json")
        bert_config = BertConfig.from_json_file(bert_config_path)
        for key, value in bert_config.__dict__.items():
            self.__dict__[key] = value
        # Print the current configuration to the log file
        logging.info("=" * 20)
        logging.info("### Printing the current configuration to the log file")
        for key, value in self.__dict__.items():
            logging.info(f"### {key} = {value}")
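
# Note: when use_embedding_weight is True, the common pattern is to tie the
# MLM output projection to the input token embeddings, roughly:
#     classifier.weight = bert.embeddings.word_embeddings.weight
# (a sketch of the usual weight-tying idiom; the actual attribute names and
# wiring live inside BertForMLM and may differ).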

def generate_data():
    # Random token ids standing in for tokenized text
    ids = np.random.randint(0, 30000, (512, 3))
    input_ids = torch.tensor(ids, dtype=torch.long)  # [src_len, batch_size] = [512, 3]
    # Random labels for a shape/smoke test; real MLM labels would be
    # vocabulary ids at the masked positions (see the sketch below)
    labels = np.random.randint(0, 2, (512, 3))
    mlm_labels = torch.tensor(labels, dtype=torch.long)  # [src_len, batch_size]
    return input_ids, mlm_labels
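
# A minimal sketch (not part of the original test) of how MLM labels are
# usually built in practice: mask roughly 15% of positions and exclude the
# rest from the loss. Both defaults are assumptions for illustration:
# 103 is [MASK] in the standard bert-base-uncased vocabulary, and -100 must
# match the ignore_index actually used by BertForMLM's loss function.
def generate_masked_data(mask_token_id=103, ignore_index=-100):
    input_ids, _ = generate_data()
    labels = input_ids.clone()
    mask = torch.rand(input_ids.shape) < 0.15  # pick ~15% of positions
    labels[~mask] = ignore_index               # compute loss only on masked positions
    input_ids[mask] = mask_token_id            # replace chosen tokens with [MASK]
    return input_ids, labels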

if __name__ == "__main__":
    config = ModelConfig()
    input_ids, mlm_labels = generate_data()
    model = BertForMLM(config, config.pretrained_model_dir)
    # Without labels the model returns the raw prediction scores over the
    # vocabulary; with labels it also computes the MLM loss (the exact
    # return value depends on BertForMLM's implementation)
    output = model(input_ids=input_ids, mlm_labels=None)
    print(output)
    output = model(input_ids=input_ids, mlm_labels=mlm_labels)
    print(output)
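
    # Illustrative follow-up (a sketch, not in the original test): assuming
    # the call without labels returns logits of shape
    # [src_len, batch_size, vocab_size], the predicted token ids fall out of
    # an argmax over the vocabulary dimension.
    logits = model(input_ids=input_ids, mlm_labels=None)
    if isinstance(logits, torch.Tensor):  # guard against tuple/other returns
        pred_ids = logits.argmax(dim=-1)  # [src_len, batch_size]
        print(pred_ids.shape)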