add dump.py

This commit is contained in:
liangliangou 2025-04-18 09:03:02 +08:00
parent ce94fb9b22
commit 7642e53a1b
1 changed files with 95 additions and 0 deletions

95
netrans_py/dump.py Normal file
View File

@ -0,0 +1,95 @@
import os
import sys
import subprocess
from utils import check_path, AttributeCopier, creat_cla
class Infer(AttributeCopier):
def __init__(self, source_obj) -> None:
super().__init__(source_obj)
@check_path
def inference_network(self):
netrans = self.netrans
quantized = self.quantize_type
name = self.model_name
# print(self.__dict__)
netrans += " dump"
# 进入模型目录
# 定义类型和量化类型
if quantized == 'float':
type_ = 'float32'
quantization_type = 'float32'
elif quantized == 'uint8':
quantization_type = 'asymmetric_affine'
type_ = 'quantized'
elif quantized == 'int8':
quantization_type = 'dynamic_fixed_point-8'
type_ = 'quantized'
elif quantized == 'int16':
quantization_type = 'dynamic_fixed_point-16'
type_ = 'quantized'
else:
print("=========== wrong quantization_type ! ( float / uint8 / int8 / int16 )===========")
sys.exit(-1)
# 构建推理命令
inf_path = './inf'
cmd = f"{netrans} \
--dtype {type_} \
--batch-size 1 \
--model-quantize {name}_{quantization_type}.quantize \
--model {name}.json \
--model-data {name}.data \
--output-dir {inf_path} \
--with-input-meta {name}_inputmeta.yml \
--device CPU"
# 执行推理命令
if self.verbose is True:
print(cmd)
result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
# 检查执行结果
if result.returncode == 0:
print("\033[32m SUCCESS \033[0m")
else:
print(f"\033[31m ERROR: {result.stderr} \033[0m")
# 返回原始目录
def main():
# 检查命令行参数数量
if len(sys.argv) < 3:
print("Input a network name and quantized type ( float / uint8 / int8 / int16 )")
sys.exit(-1)
# 检查网络目录是否存在
network_name = sys.argv[1]
if not os.path.exists(network_name):
print(f"Directory {network_name} does not exist !")
sys.exit(-2)
# print("here")
# 定义 netrans 路径
# netrans = os.path.join(os.environ['NETRANS_PATH'], 'pnnacc')
network_name = sys.argv[1]
# check_env(network_name)
netrans_path = os.environ['NETRANS_PATH']
# netrans = os.path.join(netrans_path, 'pnnacc')
quantize_type = sys.argv[2]
cla = creat_cla(netrans_path, network_name,quantize_type,False)
# 调用量化函数
func = Infer(cla)
func.inference_network()
# 定义数据集文件路径
# dataset_path = './dataset.txt'
# 调用推理函数
# inference_network(network_name, sys.argv[2])
if __name__ == '__main__':
# print("main")
main()